| | import safetensors.torch
|
| | from safetensors import safe_open
|
| | import torch
|
| |
|
def patch_final_layer_adaLN(state_dict, prefix="lora_unet_final_layer", verbose=True):
    """
    Add dummy adaLN weights if missing, using final_layer_linear shapes as reference.

    If both ``{prefix}_linear`` LoRA tensors (``lora_down`` / ``lora_up``) are
    present and the ``{prefix}_adaLN_modulation_1`` pair is not complete,
    zero-filled tensors created with ``torch.zeros_like`` from the linear
    tensors are stored under the two adaLN keys. If only one of the adaLN
    keys exists, both are (re)written so the pair stays consistent.

    Args:
        state_dict (dict): keys -> tensors. Modified in place.
        prefix (str): base name for final_layer keys
        verbose (bool): print debug info

    Returns:
        dict: patched state_dict (the same object that was passed in)
    """
    adaLN_down_key = f"{prefix}_adaLN_modulation_1.lora_down.weight"
    adaLN_up_key = f"{prefix}_adaLN_modulation_1.lora_up.weight"
    linear_down_key = f"{prefix}_linear.lora_down.weight"
    linear_up_key = f"{prefix}_linear.lora_up.weight"

    if verbose:
        print(f"\n🔍 Checking for final_layer keys with prefix: '{prefix}'")
        print(f" Linear down: {linear_down_key}")
        print(f" Linear up: {linear_up_key}")

    # Missing keys resolve to None, which marks the linear pair as incomplete.
    final_layer_linear_down = state_dict.get(linear_down_key)
    final_layer_linear_up = state_dict.get(linear_up_key)

    has_adaLN = adaLN_down_key in state_dict and adaLN_up_key in state_dict
    has_linear = final_layer_linear_down is not None and final_layer_linear_up is not None

    if verbose:
        print(f" ✅ Has final_layer.linear: {has_linear}")
        print(f" ✅ Has final_layer.adaLN_modulation_1: {has_adaLN}")

    if has_linear and not has_adaLN:
        # Zero tensors copy the linear tensors' shape, dtype and device.
        dummy_down = torch.zeros_like(final_layer_linear_down)
        dummy_up = torch.zeros_like(final_layer_linear_up)
        state_dict[adaLN_down_key] = dummy_down
        state_dict[adaLN_up_key] = dummy_up

        if verbose:
            print("✅ Added dummy adaLN weights:")
            print(f" {adaLN_down_key} (shape: {dummy_down.shape})")
            print(f" {adaLN_up_key} (shape: {dummy_up.shape})")
    elif verbose:
        print("✅ No patch needed — adaLN weights already present or no final_layer.linear found.")

    return state_dict
|
| |
|
| |
|
def main():
    """
    Interactively patch a LoRA .safetensors file with dummy final_layer adaLN weights.

    Prompts for an input and an output path, loads every tensor from the
    input file onto the CPU, tries each known final_layer key prefix with
    ``patch_final_layer_adaLN`` until one adds keys, saves the result to the
    output path (the input file's header metadata is carried over), then
    re-opens the saved file to list its final_layer keys.
    """
    print("🔧 Universal final_layer.adaLN LoRA patcher (.safetensors)")
    input_path = input("Enter path to input LoRA .safetensors file: ").strip()
    output_path = input("Enter path to save patched LoRA .safetensors file: ").strip()

    with safe_open(input_path, framework="pt", device="cpu") as f:
        # Keep the header metadata so it is not lost when the file is re-saved.
        metadata = f.metadata()
        state_dict = {k: f.get_tensor(k) for k in f.keys()}

    print(f"\n✅ Loaded {len(state_dict)} tensors from: {input_path}")

    final_keys = [k for k in state_dict if "final_layer" in k]
    if final_keys:
        print("\n🔍 Found these final_layer-related keys:")
        for k in final_keys:
            print(f" {k}")
    else:
        print("\n⚠️ No keys with 'final_layer' found — will try patch anyway.")

    # Candidate key prefixes, tried in order.
    prefixes = [
        "lora_unet_final_layer",
        "final_layer",
        "base_model.model.final_layer",
    ]
    patched = False

    for prefix in prefixes:
        before = len(state_dict)
        state_dict = patch_final_layer_adaLN(state_dict, prefix=prefix)
        # A grown dict means the helper inserted adaLN keys for this prefix.
        if len(state_dict) > before:
            patched = True
            break

    if not patched:
        print("\nℹ️ No patch applied — either adaLN already exists or no final_layer.linear found.")

    safetensors.torch.save_file(state_dict, output_path, metadata=metadata)
    print(f"\n✅ Patched file saved to: {output_path}")
    print(f" Total tensors now: {len(state_dict)}")

    # Re-read the written file to confirm what actually landed on disk.
    print("\n🔍 Verifying patched keys:")
    with safe_open(output_path, framework="pt", device="cpu") as f:
        keys = list(f.keys())
    for k in keys:
        if "final_layer" in k:
            print(f" {k}")

    has_adaLN_after = any("adaLN_modulation_1" in k for k in keys)
    print(f"✅ Contains adaLN after patch: {has_adaLN_after}")
|
| |
|
| |
|
# Run the interactive patcher only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
| |
|