import pkgutil
import re
from typing import Optional

import yaml
from transformers import AutoConfig, AutoModelForCausalLM

from stripedhyena.model import StripedHyena
from stripedhyena.tokenizer import CharLevelTokenizer
from stripedhyena.utils import dotdict


MODEL_NAMES = [
    'evo-1.5-8k-base',
    'evo-1-8k-base',
    'evo-1-131k-base',
    'evo-1-8k-crispr',
    'evo-1-8k-transposon',
]


class Evo:
    def __init__(self, model_name: str = MODEL_NAMES[1], device: Optional[str] = None):
        """
        Loads an Evo model checkpoint given a model name.

        If the checkpoint is not cached locally, it is downloaded
        automatically from HuggingFace.
        """
        self.device = device

        # Validate the model name.
        if model_name not in MODEL_NAMES:
            raise ValueError(
                f'Invalid model name {model_name}. Should be one of: '
                f'{", ".join(MODEL_NAMES)}.'
            )

        # Select the inference config bundled with this package.
        if model_name in (
            'evo-1.5-8k-base',
            'evo-1-8k-base',
            'evo-1-8k-crispr',
            'evo-1-8k-transposon',
        ):
            config_path = 'configs/evo-1-8k-base_inference.yml'
        elif model_name == 'evo-1-131k-base':
            config_path = 'configs/evo-1-131k-base_inference.yml'
        else:
            raise ValueError(
                f'Invalid model name {model_name}. Should be one of: '
                f'{", ".join(MODEL_NAMES)}.'
            )

        # Load the checkpoint into a StripedHyena model.
        self.model = load_checkpoint(
            model_name=model_name,
            config_path=config_path,
            device=self.device,
        )

        # Evo uses a byte-level (character) tokenizer.
        self.tokenizer = CharLevelTokenizer(512)
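
# Example usage (a sketch, not part of the original module; it assumes torch is
# installed, a CUDA device is available, and the checkpoint download succeeds):
#
#   import torch
#
#   evo = Evo(model_name='evo-1-8k-base', device='cuda:0')
#   input_ids = torch.tensor(
#       evo.tokenizer.tokenize('ACGTACGT'), dtype=torch.long,
#   ).unsqueeze(0).to('cuda:0')
#   logits, _ = evo.model(input_ids)  # per-position next-token logits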


HF_MODEL_NAME_MAP = {
    'evo-1.5-8k-base': 'evo-design/evo-1.5-8k-base',
    'evo-1-8k-base': 'togethercomputer/evo-1-8k-base',
    'evo-1-131k-base': 'togethercomputer/evo-1-131k-base',
    'evo-1-8k-crispr': 'LongSafari/evo-1-8k-crispr',
    'evo-1-8k-transposon': 'LongSafari/evo-1-8k-transposon',
}


def load_checkpoint(
    model_name: str = MODEL_NAMES[1],
    config_path: str = 'evo/configs/evo-1-131k-base_inference.yml',
    device: Optional[str] = None,
    *args, **kwargs
):
    """
    Loads a checkpoint from HuggingFace and places it into a StripedHyena model.
    Falls back to randomly initialized weights if the download fails.
    """
    hf_model_name = HF_MODEL_NAME_MAP[model_name]

    # Load the inference config packaged with this module.
    config = yaml.safe_load(pkgutil.get_data(__name__, config_path))
    global_config = dotdict(config)

    # The original evo-1 base models pin the fixed '1.1_fix' weight revision;
    # all other models use the default branch.
    revision = '1.1_fix' if re.match(r'evo-1-.*-base', model_name) else 'main'

    try:
        model_config = AutoConfig.from_pretrained(
            hf_model_name,
            trust_remote_code=True,
            revision=revision,
        )
        model_config.use_cache = True

        # Download the HuggingFace checkpoint.
        hf_model = AutoModelForCausalLM.from_pretrained(
            hf_model_name,
            config=model_config,
            trust_remote_code=True,
            revision=revision,
        )

        # Keep only the StripedHyena backbone weights and free the HF wrapper.
        state_dict = hf_model.backbone.state_dict()
        del hf_model
        del model_config

        model = StripedHyena(global_config)
        model.load_state_dict(state_dict, strict=True)

        _fix_hf_tokenizer_cache(hf_model_name)

    except Exception as e:
        print(f'Warning: Could not load pretrained weights from HuggingFace: {e}')
        print('Initializing model with random weights...')
        model = StripedHyena(global_config)

    # Keep poles and residues in full precision; cast everything else to bf16.
    model.to_bfloat16_except_poles_residues()
    if device is not None:
        model = model.to(device)

    return model
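
# Calling load_checkpoint directly (a sketch; note that `config_path` is
# resolved relative to this package via pkgutil.get_data, so it must name a
# config shipped with the package):
#
#   model = load_checkpoint(
#       model_name='evo-1-131k-base',
#       config_path='configs/evo-1-131k-base_inference.yml',
#       device='cuda:0',
#   )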


def _fix_hf_tokenizer_cache(hf_model_name):
    """Copy tokenizer files to the HuggingFace cache after download."""
    import shutil
    from pathlib import Path

    try:
        # Dynamic modules fetched with `trust_remote_code=True` are cached here.
        hf_cache = Path.home() / ".cache" / "huggingface" / "modules" / "transformers_modules"

        # Use the locally installed stripedhyena package as the source of truth.
        import stripedhyena
        stripedhyena_path = Path(stripedhyena.__file__).parent
        local_tokenizer = stripedhyena_path / "tokenizer.py"
        local_utils = stripedhyena_path / "utils.py"

        if not local_tokenizer.exists():
            return

        model_short_name = hf_model_name.split("/")[-1]
        model_cache = hf_cache / hf_model_name

        if model_cache.exists():
            # Overwrite the cached copies in every downloaded revision.
            for version_dir in model_cache.iterdir():
                if version_dir.is_dir():
                    shutil.copy2(local_tokenizer, version_dir / "tokenizer.py")
                    shutil.copy2(local_utils, version_dir / "utils.py")
            print(f"✓ Fixed tokenizer cache for {model_short_name}")
    except Exception as e:
        print(f"Warning: Could not fix HF cache: {e}")
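

# Minimal smoke test (a sketch, not part of the original module; it assumes
# network access to HuggingFace and will download a checkpoint on first run):
if __name__ == '__main__':
    evo = Evo(model_name='evo-1-8k-base')
    print(f'Loaded {type(evo.model).__name__} with {type(evo.tokenizer).__name__}')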