import pkgutil
import re
from typing import Optional

import yaml
from transformers import AutoConfig, AutoModelForCausalLM

from stripedhyena.model import StripedHyena
from stripedhyena.tokenizer import CharLevelTokenizer
from stripedhyena.utils import dotdict

MODEL_NAMES = [
'evo-1.5-8k-base',
'evo-1-8k-base',
'evo-1-131k-base',
'evo-1-8k-crispr',
'evo-1-8k-transposon',
]


class Evo:
    def __init__(
        self,
        model_name: str = MODEL_NAMES[1],
        device: Optional[str] = None,
    ):
"""
        Loads an Evo model checkpoint given a model name.
        If the checkpoint is not cached locally, it is downloaded
        automatically from HuggingFace.
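
        Example (sketch; assumes network access for the first download):
            evo = Evo('evo-1-8k-base', device='cuda:0')
            model, tokenizer = evo.model, evo.tokenizer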
"""
self.device = device
# Check model name.
if model_name not in MODEL_NAMES:
raise ValueError(
f'Invalid model name {model_name}. Should be one of: '
f'{", ".join(MODEL_NAMES)}.'
)
        # Assign config path.
        if model_name in (
            'evo-1.5-8k-base',
            'evo-1-8k-base',
            'evo-1-8k-crispr',
            'evo-1-8k-transposon',
        ):
            config_path = 'configs/evo-1-8k-base_inference.yml'
        elif model_name == 'evo-1-131k-base':
            config_path = 'configs/evo-1-131k-base_inference.yml'
        else:
            # Unreachable after the name check above; kept as a safeguard.
            raise ValueError(
                f'Invalid model name {model_name}. Should be one of: '
                f'{", ".join(MODEL_NAMES)}.'
            )
# Load model.
self.model = load_checkpoint(
model_name=model_name,
config_path=config_path,
device=self.device
)
        # Load the byte-level tokenizer (vocab size 512, matching the model
        # config).
        self.tokenizer = CharLevelTokenizer(512)


# Map short model names to their HuggingFace Hub repositories.
HF_MODEL_NAME_MAP = {
'evo-1.5-8k-base': 'evo-design/evo-1.5-8k-base',
'evo-1-8k-base': 'togethercomputer/evo-1-8k-base',
'evo-1-131k-base': 'togethercomputer/evo-1-131k-base',
'evo-1-8k-crispr': 'LongSafari/evo-1-8k-crispr',
'evo-1-8k-transposon': 'LongSafari/evo-1-8k-transposon',
}


def load_checkpoint(
    model_name: str = MODEL_NAMES[1],
    config_path: str = 'configs/evo-1-131k-base_inference.yml',
    device: Optional[str] = None,
    *args, **kwargs
):
"""
Load checkpoint from HuggingFace and place it into SH model.
"""
# Map model name to HuggingFace model name.
hf_model_name = HF_MODEL_NAME_MAP[model_name]
    # Load the StripedHyena config bundled with this package. (The
    # Loader=yaml.FullLoader kwarg belonged to yaml.load, not dotdict, and is
    # unnecessary with yaml.safe_load.)
    config = yaml.safe_load(pkgutil.get_data(__name__, config_path))
    global_config = dotdict(config)
    try:
        # Try to load from the HuggingFace Hub; the evo-1 base models pin
        # the patched '1.1_fix' revision.
        revision = '1.1_fix' if re.match(r'evo-1-.*-base', model_name) else 'main'
        model_config = AutoConfig.from_pretrained(
            hf_model_name,
            trust_remote_code=True,
            revision=revision,
        )
        # Cache inference state for faster autoregressive generation.
        model_config.use_cache = True
        # Load the pretrained model from HuggingFace.
        hf_model = AutoModelForCausalLM.from_pretrained(
            hf_model_name,
            config=model_config,
            trust_remote_code=True,
            revision=revision,
        )
# Extract state dict from HuggingFace model
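        # (the remote-code wrapper exposes the underlying StripedHyena module
        # as `.backbone`, so its weights map directly onto the local model).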
state_dict = hf_model.backbone.state_dict()
del hf_model
del model_config
# Load into StripedHyena model with our config
model = StripedHyena(global_config)
model.load_state_dict(state_dict, strict=True)
# Fix the tokenizer import issue by copying files to HF cache
_fix_hf_tokenizer_cache(hf_model_name)
    except Exception as e:
        # If the HuggingFace download or weight loading fails, fall back to
        # random initialization; the resulting model is NOT pretrained.
        print(f"Warning: Could not load pretrained weights from HuggingFace: {e}")
        print("Initializing model with random weights...")
model = StripedHyena(global_config)
    # Cast weights to bfloat16, keeping the filter poles and residues in
    # higher precision (as the method name indicates).
    model.to_bfloat16_except_poles_residues()
if device is not None:
model = model.to(device)
return model
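
# Direct-call sketch (mirrors how the Evo class above invokes this function;
# the device value is illustrative):
#
#   model = load_checkpoint(
#       model_name='evo-1-8k-base',
#       config_path='configs/evo-1-8k-base_inference.yml',
#       device='cuda:0',
#   )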


def _fix_hf_tokenizer_cache(hf_model_name):
    """
    Copy the local stripedhyena tokenizer and utils modules into the
    HuggingFace remote-code cache so the downloaded modeling code can import
    them.
    """
import shutil
from pathlib import Path
try:
hf_cache = Path.home() / ".cache" / "huggingface" / "modules" / "transformers_modules"
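        # transformers typically stores remote code under
        # ~/.cache/huggingface/modules/transformers_modules/<org>/<model>/<revision>/,
        # so the local modules are copied into every revision subdirectory
        # found below the model's cache folder.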
# Get our local files
import stripedhyena
stripedhyena_path = Path(stripedhyena.__file__).parent
local_tokenizer = stripedhyena_path / "tokenizer.py"
local_utils = stripedhyena_path / "utils.py"
if not local_tokenizer.exists():
return
# Find the model cache directory
model_short_name = hf_model_name.split("/")[-1] # e.g., "evo-1-8k-base"
model_cache = hf_cache / hf_model_name
if model_cache.exists():
# Copy to all version subdirectories
for version_dir in model_cache.iterdir():
if version_dir.is_dir():
shutil.copy2(local_tokenizer, version_dir / "tokenizer.py")
shutil.copy2(local_utils, version_dir / "utils.py")
print(f"✓ Fixed tokenizer cache for {model_short_name}")
except Exception as e:
print(f"Warning: Could not fix HF cache: {e}")