Fix _init_weights and RotaryEmbedding for transformers v5.x compatibility

#10
by apsys - opened

Fix _init_weights and RotaryEmbedding initialization (for transformers 5.x)

_init_weights was calling .data.normal_() directly on tensors, which bypasses the _is_hf_initialized guard in transformers v5.x. Since v5.x loads the model on the meta device first and then calls initialize_weights() after the checkpoint is loaded, this was silently re-randomizing every Linear and Embedding after from_pretrained: the model loads fine but outputs garbage. Switched to torch.nn.init.normal_() / torch.nn.init.zeros_(), which the guard can intercept.
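A minimal sketch of the change, written as a free function for illustration (the real code is a `_init_weights` method on the model class, and `std` comes from `config.initializer_range`; the names here are hypothetical):

```python
import torch
from torch import nn

def init_weights(module: nn.Module, std: float = 0.02) -> None:
    # Key point of the fix: call torch.nn.init.* on the Parameter itself
    # instead of module.weight.data.normal_(), so transformers v5.x can
    # track initialization via its _is_hf_initialized guard and skip
    # modules whose weights were already loaded from the checkpoint.
    if isinstance(module, nn.Linear):
        torch.nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        torch.nn.init.normal_(module.weight, mean=0.0, std=std)
```

Going through the `.data` attribute works on plain torch modules too, but it sidesteps any bookkeeping the framework does on the Parameter, which is exactly what breaks under v5.x.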

Also, RotaryEmbedding.__init__ raises a KeyError on the "default" rope type - ROPE_INIT_FUNCTIONS simply doesn't have that key, and Ring-mini-2.0 has rope_scaling=None, so it always hits this path. The default case is now handled inline. While at it, forced float32 for the inv_freq computation, since rope_theta=600000 can't be represented accurately in bf16 (its 8-bit mantissa rounds 600000 off by roughly 2000).
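A sketch of the inlined default path, with an illustrative function name (the real code lives inside `RotaryEmbedding.__init__`):

```python
import torch

def default_inv_freq(rope_theta: float, head_dim: int) -> torch.Tensor:
    # Standard rope inverse frequencies: theta^(-2i/d) for i = 0..d/2-1.
    # Computing in float32 matters here: bf16 cannot represent a large
    # rope_theta like 600000 exactly, and the intermediate powers lose
    # further precision if done in a low-precision dtype.
    exponent = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    return 1.0 / (rope_theta ** exponent)
```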

apsys changed pull request status to open
inclusionAI org

@apsys Thanks for your attention and for sharing the code. 🤝
I noticed that the partial_rotary_factor parameter doesn’t seem to be handled—was this intentionally omitted?

```python
# code in transformers v4.56
def _compute_default_rope_parameters(
...
    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
    head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
    dim = int(head_dim * partial_rotary_factor)
...
```
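For concreteness, with hypothetical values of head_dim=128 and partial_rotary_factor=0.5, the snippet above computes dim=64, i.e. only a prefix of each head is rotated:

```python
# Worked example with hypothetical values: a partial_rotary_factor below
# 1.0 means only a prefix of each head's dims receives rotary embeddings.
head_dim = 128
partial_rotary_factor = 0.5
dim = int(head_dim * partial_rotary_factor)
print(dim)  # 64: rope applies to the first 64 of 128 dims per head
```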

If you have any before/after comparison results for the change, it would be great if you could share them as well. Thanks again.

