import pkgutil
import re
import shutil
from pathlib import Path
from typing import Optional

import yaml
from transformers import AutoConfig, AutoModelForCausalLM

import stripedhyena
from stripedhyena.utils import dotdict
from stripedhyena.model import StripedHyena
from stripedhyena.tokenizer import CharLevelTokenizer


MODEL_NAMES = [
    'evo-1.5-8k-base',
    'evo-1-8k-base',
    'evo-1-131k-base',
    'evo-1-8k-crispr',
    'evo-1-8k-transposon',
]

class Evo:
    def __init__(self, model_name: str = MODEL_NAMES[1], device: Optional[str] = None):
        """
        Loads an Evo model checkpoint given a model name.
        If the checkpoint is not cached locally, it is downloaded automatically
        from HuggingFace.
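
        Example (a minimal sketch; assumes a CUDA device is available,
        otherwise pass device='cpu' or leave device=None):
            >>> evo_model = Evo(model_name='evo-1-8k-base', device='cuda:0')
            >>> model, tokenizer = evo_model.model, evo_model.tokenizer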
        """
        self.device = device

        # Check model name.

        if model_name not in MODEL_NAMES:
            raise ValueError(
                f'Invalid model name {model_name}. Should be one of: '
                f'{", ".join(MODEL_NAMES)}.'
            )

        # Assign config path.

        if model_name in (
            'evo-1-8k-base',
            'evo-1-8k-crispr',
            'evo-1-8k-transposon',
            'evo-1.5-8k-base',
        ):
            config_path = 'configs/evo-1-8k-base_inference.yml'
        elif model_name == 'evo-1-131k-base':
            config_path = 'configs/evo-1-131k-base_inference.yml'
        else:
            raise ValueError(
                f'Invalid model name {model_name}. Should be one of: '
                f'{", ".join(MODEL_NAMES)}.'
            )

        # Load model.

        self.model = load_checkpoint(
            model_name=model_name,
            config_path=config_path,
            device=self.device
        )

        # Load tokenizer. Evo uses a character-level tokenizer with a fixed
        # vocabulary size of 512, matching the pretrained checkpoints.

        self.tokenizer = CharLevelTokenizer(512)


HF_MODEL_NAME_MAP = {
    'evo-1.5-8k-base': 'evo-design/evo-1.5-8k-base',
    'evo-1-8k-base': 'togethercomputer/evo-1-8k-base',
    'evo-1-131k-base': 'togethercomputer/evo-1-131k-base',
    'evo-1-8k-crispr': 'LongSafari/evo-1-8k-crispr',
    'evo-1-8k-transposon': 'LongSafari/evo-1-8k-transposon',
}

def load_checkpoint(
    model_name: str = MODEL_NAMES[1],
    config_path: str = 'configs/evo-1-131k-base_inference.yml',
    device: Optional[str] = None,
    *args, **kwargs
):
    """
    Download a checkpoint from HuggingFace and load its weights into a
    StripedHyena model.
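
    Example (a minimal sketch; weights are downloaded on first use, and
    config_path is resolved relative to this package via pkgutil):
        >>> model = load_checkpoint(
        ...     model_name='evo-1-8k-base',
        ...     config_path='configs/evo-1-8k-base_inference.yml',
        ...     device='cuda:0',
        ... )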
    """

    # Map model name to HuggingFace model name.
    hf_model_name = HF_MODEL_NAME_MAP[model_name]

    # Load the local StripedHyena config bundled with this package.
    # (The YAML is already fully parsed by safe_load, so no Loader is needed.)
    config = yaml.safe_load(pkgutil.get_data(__name__, config_path))
    global_config = dotdict(config)

    try:
        # The Evo 1 base models (evo-1-8k-base, evo-1-131k-base) live on the
        # '1.1_fix' revision; all other models use the default branch.
        revision = '1.1_fix' if re.match(r'evo-1-.*-base', model_name) else 'main'

        # Try to load the config from the HuggingFace Hub.
        model_config = AutoConfig.from_pretrained(
            hf_model_name,
            trust_remote_code=True,
            revision=revision,
        )
        model_config.use_cache = True

        # Load the pretrained model from HuggingFace.
        hf_model = AutoModelForCausalLM.from_pretrained(
            hf_model_name,
            config=model_config,
            trust_remote_code=True,
            revision=revision,
        )

        # Extract state dict from HuggingFace model
        state_dict = hf_model.backbone.state_dict()
        del hf_model
        del model_config

        # Load into StripedHyena model with our config
        model = StripedHyena(global_config)
        model.load_state_dict(state_dict, strict=True)

        # Work around a remote-code import issue by copying the local
        # stripedhyena tokenizer/utils modules into the HF dynamic-module cache.
        _fix_hf_tokenizer_cache(hf_model_name)

    except Exception as e:
        # If HuggingFace download fails, initialize from scratch
        print(f"Warning: Could not load pretrained weights from HuggingFace: {e}")
        print("Initializing model with random weights...")
        model = StripedHyena(global_config)

    # Cast weights to bfloat16 for inference, keeping the Hyena filter poles
    # and residues in higher precision.
    model.to_bfloat16_except_poles_residues()
    if device is not None:
        model = model.to(device)

    return model


def _fix_hf_tokenizer_cache(hf_model_name):
    """Copy tokenizer files to HuggingFace cache after download."""
    import shutil
    from pathlib import Path
    
    try:
        hf_cache = Path.home() / ".cache" / "huggingface" / "modules" / "transformers_modules"
        # Get our local files
        import stripedhyena
        stripedhyena_path = Path(stripedhyena.__file__).parent
        local_tokenizer = stripedhyena_path / "tokenizer.py"
        local_utils = stripedhyena_path / "utils.py"
        
        if not local_tokenizer.exists():
            return
        
        # Find the model cache directory
        model_short_name = hf_model_name.split("/")[-1]  # e.g., "evo-1-8k-base"
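        # The dynamic-module cache typically mirrors the repo id, with one
        # subdirectory per downloaded revision (an assumption about the HF
        # transformers_modules layout, which may vary across versions).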
        model_cache = hf_cache / hf_model_name
        
        if model_cache.exists():
            # Copy to all version subdirectories
            for version_dir in model_cache.iterdir():
                if version_dir.is_dir():
                    shutil.copy2(local_tokenizer, version_dir / "tokenizer.py")
                    shutil.copy2(local_utils, version_dir / "utils.py")
                    print(f"✓ Fixed tokenizer cache for {model_short_name}")
    except Exception as e:
        print(f"Warning: Could not fix HF cache: {e}")