| import argparse |
| import json |
| from pathlib import Path |
| from safetensors import safe_open |
|
|
|
|
def _read_json(path: Path):
    """Load and return the JSON document stored at *path*.

    Raises OSError if the file cannot be read, and ValueError (this covers
    json.JSONDecodeError and UnicodeDecodeError) if it is not valid UTF-8 JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _config_declares_dual_mlp(config: dict) -> bool:
    """Return True if *config* sets a positive numeric ``intermediate_size_mlp``."""
    value = config.get("intermediate_size_mlp", 0)
    # A JSON config may hold null (or a non-numeric value) here; treat that as
    # "not set" instead of letting ``None > 0`` raise TypeError.
    return isinstance(value, (int, float)) and value > 0


def _list_weight_keys(source: Path) -> list:
    """Return the tensor names stored in a safetensors checkpoint.

    *source* is either a single ``model.safetensors`` file, or a sharded
    checkpoint's ``model.safetensors.index.json``. For the index, the names
    are taken from its ``weight_map`` so no shard has to be opened.
    """
    if source.name.endswith(".index.json"):
        return list(_read_json(source)["weight_map"])
    # Passed as str in case the installed safetensors build does not accept
    # Path objects.
    with safe_open(str(source), framework="mlx") as f:
        return list(f.keys())


def _find_dual_branch_key(keys):
    """Return the first MLP weight name that is not a gate/up/down projection.

    Returns None when no such name exists.
    """
    standard_projections = ("gate_proj", "up_proj", "down_proj")
    return next(
        (
            key
            for key in keys
            if "mlp" in key
            and not any(proj in key for proj in standard_projections)
        ),
        None,
    )


def check_model_shape(model_path: str):
    """Inspect a model's config and weights to determine its MLP structure.

    The report is printed to stdout and the function returns None. Missing
    or unreadable files are reported on stdout and end the check early.

    Args:
        model_path: Directory containing ``config.json`` plus either a single
            ``model.safetensors`` file or a sharded checkpoint described by
            ``model.safetensors.index.json``.
    """
    model_path = Path(model_path)
    config_path = model_path / "config.json"
    weights_path = model_path / "model.safetensors"
    index_path = model_path / "model.safetensors.index.json"

    if not config_path.exists():
        print(f"Error: config.json not found in {model_path}")
        return

    if not weights_path.exists() and not index_path.exists():
        print(f"Error: model.safetensors not found in {model_path}")
        return

    print(f"--- Checking model shape in {model_path} ---")

    # 1. Config: does it declare a second MLP branch?
    try:
        config = _read_json(config_path)
    except (OSError, ValueError) as e:
        print(f"Error: could not parse config.json: {e}")
        return
    if not isinstance(config, dict):
        print("Error: config.json does not contain a JSON object")
        return

    has_dual_mlp_config = _config_declares_dual_mlp(config)
    print(f"Config has 'intermediate_size_mlp': {has_dual_mlp_config}")

    # 2. Weights: is there any MLP tensor beyond gate/up/down projections?
    # A single-file checkpoint takes precedence over a sharded index.
    source = weights_path if weights_path.exists() else index_path
    try:
        weight_keys = _list_weight_keys(source)
    except Exception as e:
        # CLI boundary: the exception types raised by safe_open (or by a
        # malformed index) are not pinned down here, so report and stop.
        print(f"Could not read weights from {source.name}: {e}")
        return

    dual_key = _find_dual_branch_key(weight_keys)
    if dual_key is not None:
        print(f"Found potential dual-branch weight: {dual_key}")
    has_dual_mlp_weights = dual_key is not None
    print(f"Found potential dual-branch MLP weights: {has_dual_mlp_weights}")

    # 3. Conclusion. The config is what decides "dual-branch". Weights found
    # without the config flag do not change the verdict, because the key
    # heuristic is loose: any non-gate/up/down MLP tensor matches.
    print("\n--- Conclusion ---")
    if has_dual_mlp_config and has_dual_mlp_weights:
        print("✅ The model appears to be a DUAL-BRANCH MLP variant.")
    elif has_dual_mlp_config and not has_dual_mlp_weights:
        print(
            "⚠️ The model configuration suggests a dual-branch MLP, but no corresponding weights were found."
        )
        print(" It will likely run as a SINGLE-BRANCH model.")
    else:
        print("✅ The model appears to be a SINGLE-BRANCH MLP variant.")
    print("--------------------\n")
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: one optional positional path, defaulting to
    # the current directory.
    cli = argparse.ArgumentParser(description="Check the MLP shape of a model variant.")
    cli.add_argument(
        "model_path",
        help="Path to the model directory to check.",
        nargs="?",
        default=".",
        type=str,
    )
    check_model_shape(cli.parse_args().model_path)
|
|