| | |
| | |
| | |
| |
|
| | import argparse |
| | import math |
| | import os |
| | import torch |
| | from safetensors.torch import load_file, save_file, safe_open |
| | from tqdm import tqdm |
| | from library import train_util, model_util |
| | import numpy as np |
| | from library.utils import setup_logging |
| | setup_logging() |
| | import logging |
| | logger = logging.getLogger(__name__) |
| |
|
def load_state_dict(file_name):
    """Load a LoRA state dict and its metadata from a checkpoint file.

    Returns (state_dict, metadata). For safetensors files the embedded
    metadata dict is returned; plain torch checkpoints have none, so
    metadata is None in that case.
    """
    if not model_util.is_safetensors(file_name):
        # torch .ckpt path: always load onto CPU, no metadata available
        return torch.load(file_name, map_location="cpu"), None

    state_dict = load_file(file_name)
    with safe_open(file_name, framework="pt") as handle:
        meta = handle.metadata()
    return state_dict, meta
| |
|
| |
|
def save_to_file(file_name, model, metadata):
    """Persist a state dict; safetensors output keeps the metadata dict."""
    if not model_util.is_safetensors(file_name):
        # torch checkpoint format cannot carry metadata
        torch.save(model, file_name)
        return
    save_file(model, file_name, metadata)
| |
|
| |
|
def split_lora_model(lora_sd, unit):
    """Slice a DyLoRA state dict into every rank that is a multiple of *unit*.

    Scans the state dict for the largest ``lora_down`` rank, then for each
    rank in ``unit, 2*unit, ...`` (strictly below the max) builds a state
    dict whose down/up weights are truncated to that rank. Alpha and any
    other tensors are carried through unchanged.

    Returns (max_rank, split_models) where split_models is a list of
    (state_dict, rank, alpha) tuples; alpha is always None here.
    """
    # Determine the largest rank present among the lora_down weights.
    max_rank = 0
    for name, tensor in lora_sd.items():
        if "lora_down" in name:
            dim = tensor.size()[0]
            if dim > max_rank:
                max_rank = dim
    logger.info(f"Max rank: {max_rank}")

    split_models = []
    new_alpha = None
    # Ranks strictly below max_rank; the full-rank model is the input itself.
    for rank in range(unit, max_rank, unit):
        logger.info(f"Splitting rank {rank}")
        truncated = {}
        for name, tensor in lora_sd.items():
            if "lora_down" in name:
                # keep the first `rank` rows of the down projection
                truncated[name] = tensor[:rank].contiguous()
            elif "lora_up" in name:
                # keep the first `rank` columns of the up projection
                truncated[name] = tensor[:, :rank].contiguous()
            else:
                # alpha etc. are copied through as-is
                truncated[name] = tensor
        split_models.append((truncated, rank, new_alpha))

    return max_rank, split_models
| |
|
| |
|
def split(args):
    """Split a DyLoRA checkpoint into one output file per rank unit.

    Loads ``args.model``, slices it at every multiple of ``args.unit``
    below the model's maximum rank, and writes each slice to
    ``args.save_to`` with a ``-{rank:04d}`` suffix, with per-file metadata
    describing the new rank.
    """
    logger.info("loading Model...")
    lora_sd, metadata = load_state_dict(args.model)

    logger.info("Splitting Model...")
    original_rank, split_models = split_lora_model(lora_sd, args.unit)

    # FIX: metadata is None for non-safetensors inputs; guard before .get()
    # (the original crashed here on .ckpt inputs).
    comment = metadata.get("ss_training_comment", "") if metadata is not None else ""
    for state_dict, new_rank, new_alpha in split_models:
        # start each output's metadata from a copy of the source metadata
        new_metadata = {} if metadata is None else metadata.copy()

        new_metadata["ss_training_comment"] = f"split from DyLoRA, rank {original_rank} to {new_rank}; {comment}"
        new_metadata["ss_network_dim"] = str(new_rank)

        # FIX: hash the metadata that is actually saved (new_metadata) and
        # store the hashes in it — the original hashed and mutated the shared
        # `metadata` dict, so saved files carried wrong/missing sshs_* hashes.
        model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, new_metadata)
        new_metadata["sshs_model_hash"] = model_hash
        new_metadata["sshs_legacy_hash"] = legacy_hash

        filename, ext = os.path.splitext(args.save_to)
        model_file_name = filename + f"-{new_rank:04d}{ext}"

        logger.info(f"saving model to: {model_file_name}")
        save_to_file(model_file_name, state_dict, new_metadata)
| |
|
| |
|
def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the DyLoRA split script."""
    parser = argparse.ArgumentParser()

    # (flag, value type, help text) — all options default to None
    option_specs = [
        ("--unit", int, "size of rank to split into / rankを分割するサイズ"),
        (
            "--save_to",
            str,
            "destination base file name: ckpt or safetensors file / 保存先のファイル名のbase、ckptまたはsafetensors",
        ),
        (
            "--model",
            str,
            "DyLoRA model to resize at to new rank: ckpt or safetensors file / 読み込むDyLoRAモデル、ckptまたはsafetensors",
        ),
    ]
    for flag, value_type, help_text in option_specs:
        parser.add_argument(flag, type=value_type, default=None, help=help_text)

    return parser
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: parse CLI options and run the split.
    cli_args = setup_parser().parse_args()
    split(cli_args)
| |
|