| """ | |
| ACE-Step 1.5 LoRA Training Engine | |
| Handles dataset building, VAE encoding, and flow-matching LoRA training | |
| of the DiT decoder. Designed to work with the existing AceStepHandler. | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import math | |
| import time | |
| import random | |
| import hashlib | |
| import argparse | |
| import tempfile | |
| from pathlib import Path | |
| from dataclasses import dataclass, field, asdict | |
| from typing import Optional, List, Dict, Any, Tuple | |
| import torch | |
| import torch.nn.functional as F | |
| import torchaudio | |
| import soundfile as sf | |
| import numpy as np | |
| from loguru import logger | |
| from tqdm import tqdm | |
# ---------------------------------------------------------------------------
# Dataset helpers
# ---------------------------------------------------------------------------

AUDIO_EXTENSIONS = {".wav", ".flac", ".mp3", ".ogg", ".opus", ".m4a", ".aac"}

@dataclass
class TrackEntry:
    """One audio file + its metadata."""

    audio_path: str
    caption: str = ""
    lyrics: str = ""
    bpm: Optional[int] = None
    keyscale: str = ""
    timesignature: str = "4/4"
    vocal_language: str = "en"
    duration: Optional[float] = None  # seconds (measured at scan time)

def _load_track_entry(audio_path: Path) -> TrackEntry:
    """Load one track + optional sidecar metadata."""
    sidecar = audio_path.with_suffix(".json")
    meta: Dict[str, Any] = {}
    if sidecar.exists():
        try:
            meta = json.loads(sidecar.read_text(encoding="utf-8"))
        except Exception as exc:
            logger.warning(f"Bad sidecar {sidecar}: {exc}")
    try:
        info = torchaudio.info(str(audio_path))
        duration = info.num_frames / info.sample_rate
    except Exception:
        duration = meta.get("duration")
    return TrackEntry(
        audio_path=str(audio_path),
        caption=meta.get("caption", ""),
        lyrics=meta.get("lyrics", ""),
        bpm=meta.get("bpm"),
        keyscale=meta.get("keyscale", ""),
        timesignature=meta.get("timesignature", "4/4"),
        vocal_language=meta.get("vocal_language", "en"),
        duration=duration,
    )

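# Illustrative sidecar layout; field names mirror TrackEntry, the values below
# are hypothetical examples rather than from a real dataset:
#
#   my_song.wav
#   my_song.json:
#     {
#       "caption": "upbeat synth-pop, female vocals",
#       "lyrics": "[verse] ...",
#       "bpm": 120,
#       "keyscale": "C major",
#       "timesignature": "4/4",
#       "vocal_language": "en"
#     }
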
def scan_dataset_folder(folder: str) -> List[TrackEntry]:
    """Scan *folder* for audio files and optional JSON sidecars.

    For every ``track.wav`` found, if ``track.json`` exists next to it the
    metadata fields are loaded from the sidecar. Missing sidecars are fine –
    the entry will have empty metadata that can be filled later.
    """
    root = Path(folder)
    if not root.is_dir():
        raise FileNotFoundError(f"Dataset folder not found: {folder}")
    entries: List[TrackEntry] = []
    for audio_path in sorted(root.rglob("*")):
        if audio_path.suffix.lower() not in AUDIO_EXTENSIONS:
            continue
        entries.append(_load_track_entry(audio_path))
    logger.info(f"Scanned {len(entries)} audio files in {root}")
    return entries

def scan_uploaded_files(file_paths: List[str]) -> List[TrackEntry]:
    """Build entries from dropped/uploaded files.

    Supports uploading audio files together with optional ``.json`` sidecars.
    Sidecars are matched by basename stem (``song.mp3`` <-> ``song.json``).
    """
    meta_by_stem: Dict[str, Dict[str, Any]] = {}
    for path in file_paths:
        p = Path(path)
        if not p.exists() or p.suffix.lower() != ".json":
            continue
        try:
            meta_by_stem[p.stem] = json.loads(p.read_text(encoding="utf-8"))
        except Exception as exc:
            logger.warning(f"Bad uploaded sidecar {p}: {exc}")
    entries: List[TrackEntry] = []
    for path in file_paths:
        p = Path(path)
        if not p.exists() or p.suffix.lower() not in AUDIO_EXTENSIONS:
            continue
        uploaded_meta = meta_by_stem.get(p.stem)
        if uploaded_meta is None:
            entries.append(_load_track_entry(p))
            continue
        try:
            info = torchaudio.info(str(p))
            duration = info.num_frames / info.sample_rate
        except Exception:
            duration = uploaded_meta.get("duration")
        bpm_val = uploaded_meta.get("bpm")
        if isinstance(bpm_val, str) and bpm_val.strip():
            try:
                bpm_val = int(float(bpm_val))
            except Exception:
                bpm_val = None
        entries.append(
            TrackEntry(
                audio_path=str(p),
                caption=uploaded_meta.get("caption", "") or "",
                lyrics=uploaded_meta.get("lyrics", "") or "",
                bpm=bpm_val if isinstance(bpm_val, int) else None,
                keyscale=uploaded_meta.get("keyscale", "") or "",
                timesignature=uploaded_meta.get("timesignature", "4/4") or "4/4",
                vocal_language=uploaded_meta.get("vocal_language", uploaded_meta.get("language", "en")) or "en",
                duration=duration,
            )
        )
    logger.info(
        "Loaded {} uploaded audio files ({} uploaded sidecars detected)".format(
            len(entries), len(meta_by_stem)
        )
    )
    return entries

# ---------------------------------------------------------------------------
# Training hyper-parameters
# ---------------------------------------------------------------------------

@dataclass
class LoRATrainConfig:
    """All tuneable knobs for a LoRA run."""

    # LoRA architecture
    lora_rank: int = 64
    lora_alpha: int = 64
    lora_dropout: float = 0.1
    lora_target_modules: List[str] = field(
        default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
    )
    # Optimiser
    learning_rate: float = 1e-4
    weight_decay: float = 0.01
    optimizer: str = "adamw_8bit"  # "adamw" | "adamw_8bit"
    max_grad_norm: float = 1.0
    # Schedule
    warmup_ratio: float = 0.03
    scheduler: str = "constant_with_warmup"
    # Training loop
    num_epochs: int = 50
    batch_size: int = 1
    gradient_accumulation_steps: int = 1
    save_every_n_epochs: int = 10
    log_every_n_steps: int = 5
    # Flow matching
    shift: float = 3.0  # timestep shift factor
    # Audio pre-processing
    max_duration_sec: float = 240.0  # clamp audio to this length
    sample_rate: int = 48000
    # Paths
    output_dir: str = "lora_output"
    resume_from: Optional[str] = None
    # Device
    device: str = "auto"
    dtype: str = "bf16"  # "bf16" | "fp16" | "fp32"
    mixed_precision: bool = True

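# Note: PEFT scales the LoRA update by lora_alpha / lora_rank, so the defaults
# above (alpha 64, rank 64) give an effective scale of 1.0; raising alpha
# relative to rank strengthens the adapter's contribution.
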
# ---------------------------------------------------------------------------
# Core trainer
# ---------------------------------------------------------------------------

class LoRATrainer:
    """Thin training loop that wraps the existing AceStepHandler."""

    def __init__(self, handler, config: LoRATrainConfig):
        """
        Args:
            handler: Initialised ``AceStepHandler`` (model, vae, text_encoder loaded).
            config: Training hyper-parameters.
        """
        self.handler = handler
        self.cfg = config
        self.device = handler.device
        self.dtype = handler.dtype
        # Will be set during prepare()
        self.peft_model = None
        self.optimizer = None
        self.scheduler = None
        self.global_step = 0
        self.current_epoch = 0
        # Loss history for UI
        self.loss_history: List[Dict[str, Any]] = []
        self._stop_requested = False

    # ------------------------------------------------------------------
    # Setup
    # ------------------------------------------------------------------

    @staticmethod
    def _resolve_lora_target_modules(model, requested_targets: Optional[List[str]]) -> List[str]:
        """Resolve LoRA target module suffixes against the actual decoder module names."""
        linear_module_names = [
            name for name, module in model.named_modules() if isinstance(module, torch.nn.Linear)
        ]

        def _exists_as_suffix(target: str) -> bool:
            return any(name.endswith(target) for name in linear_module_names)

        requested_targets = requested_targets or []
        resolved = [target for target in requested_targets if _exists_as_suffix(target)]
        if resolved:
            return resolved
        fallback_groups = [
            ["q_proj", "k_proj", "v_proj", "o_proj"],
            ["to_q", "to_k", "to_v", "to_out.0"],
            ["query", "key", "value", "out_proj"],
            ["wq", "wk", "wv", "wo"],
            ["qkv", "proj_out"],
        ]
        for group in fallback_groups:
            group_resolved = [target for target in group if _exists_as_suffix(target)]
            if len(group_resolved) >= 2:
                return group_resolved
        sample = ", ".join(linear_module_names[:30])
        raise ValueError(
            "Could not find LoRA target modules in decoder. "
            f"Requested={requested_targets}. "
            f"Sample linear modules: {sample}"
        )

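    # Matching is by module-name suffix, e.g. a requested "q_proj" matches a
    # decoder module named like "layers.0.self_attn.q_proj" (the name shown is
    # illustrative; actual names depend on the decoder definition).
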
    def prepare(self):
        """Attach LoRA adapters to the decoder and build the optimiser."""
        import copy
        from peft import LoraConfig, PeftModel, TaskType, get_peft_model

        # Keep a backup of the plain base decoder so load/unload logic remains valid.
        if self.handler._base_decoder is None:
            self.handler._base_decoder = copy.deepcopy(self.handler.model.decoder)
        else:
            self.handler.model.decoder = copy.deepcopy(self.handler._base_decoder)
        self.handler.model.decoder = self.handler.model.decoder.to(self.device).to(self.dtype)
        self.handler.model.decoder.eval()

        resume_adapter = None
        if self.cfg.resume_from:
            adapter_cfg = os.path.join(self.cfg.resume_from, "adapter_config.json")
            if os.path.isfile(adapter_cfg):
                resume_adapter = self.cfg.resume_from

        if resume_adapter:
            logger.info(f"Loading existing LoRA adapter for resume: {resume_adapter}")
            self.peft_model = PeftModel.from_pretrained(
                self.handler.model.decoder,
                resume_adapter,
                is_trainable=True,
            )
        else:
            resolved_targets = self._resolve_lora_target_modules(
                self.handler.model.decoder,
                self.cfg.lora_target_modules,
            )
            logger.info(f"Using LoRA target modules: {resolved_targets}")
            peft_cfg = LoraConfig(
                r=self.cfg.lora_rank,
                lora_alpha=self.cfg.lora_alpha,
                lora_dropout=self.cfg.lora_dropout,
                target_modules=resolved_targets,
                bias="none",
                task_type=TaskType.FEATURE_EXTRACTION,
            )
            self.peft_model = get_peft_model(self.handler.model.decoder, peft_cfg)
        self.peft_model.print_trainable_parameters()

        self.handler.model.decoder = self.peft_model
        self.handler.model.decoder.to(self.device).to(self.dtype)
        self.handler.model.decoder.train()
        self.handler.lora_loaded = True
        self.handler.use_lora = True

        # Build optimiser (only LoRA params are trainable)
        trainable_params = [p for p in self.peft_model.parameters() if p.requires_grad]
        if self.cfg.optimizer == "adamw_8bit":
            try:
                import bitsandbytes as bnb
                self.optimizer = bnb.optim.AdamW8bit(
                    trainable_params,
                    lr=self.cfg.learning_rate,
                    weight_decay=self.cfg.weight_decay,
                )
            except ImportError:
                logger.warning("bitsandbytes not found – falling back to standard AdamW")
                self.optimizer = torch.optim.AdamW(
                    trainable_params,
                    lr=self.cfg.learning_rate,
                    weight_decay=self.cfg.weight_decay,
                )
        else:
            self.optimizer = torch.optim.AdamW(
                trainable_params,
                lr=self.cfg.learning_rate,
                weight_decay=self.cfg.weight_decay,
            )

        # Resume checkpoint state (after model/adapter restore).
        if self.cfg.resume_from and os.path.isfile(
            os.path.join(self.cfg.resume_from, "training_state.pt")
        ):
            state = torch.load(
                os.path.join(self.cfg.resume_from, "training_state.pt"),
                weights_only=False,
            )
            try:
                self.optimizer.load_state_dict(state["optimizer"])
            except Exception as exc:
                logger.warning(f"Could not restore optimizer state, continuing with a fresh optimizer: {exc}")
            self.global_step = int(state.get("global_step", 0))
            # The saved epoch is the completed epoch index; continue from the next epoch.
            self.current_epoch = int(state.get("epoch", -1)) + 1
            loss_path = os.path.join(self.cfg.resume_from, "loss_history.json")
            if os.path.isfile(loss_path):
                try:
                    with open(loss_path, "r", encoding="utf-8") as f:
                        self.loss_history = json.load(f)
                except Exception:
                    pass
            logger.info(
                f"Resumed from {self.cfg.resume_from} "
                f"(epoch {self.current_epoch}, step {self.global_step})"
            )

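    # Typical call sequence (sketch; ``handler`` is an already-initialised
    # AceStepHandler and the dataset path is hypothetical):
    #
    #   trainer = LoRATrainer(handler, LoRATrainConfig(num_epochs=10))
    #   trainer.prepare()
    #   trainer.train(scan_dataset_folder("my_dataset/"))
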
    # ------------------------------------------------------------------
    # Data loading
    # ------------------------------------------------------------------

    @staticmethod
    def _coerce_audio_tensor(audio: Any) -> torch.Tensor:
        """Coerce decoded audio into a torch.Tensor with shape [C, T]."""
        if isinstance(audio, list):
            audio = np.asarray(audio, dtype=np.float32)
        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio)
        if not torch.is_tensor(audio):
            raise TypeError(f"Unsupported audio type: {type(audio)}")
        # Ensure floating point for downstream resample/vae encode.
        if not torch.is_floating_point(audio):
            audio = audio.float()
        # Normalize dimensions to [C, T].
        if audio.dim() == 1:
            audio = audio.unsqueeze(0)
        elif audio.dim() == 2:
            # Accept either [T, C] or [C, T]; transpose only when clearly [T, C].
            if audio.shape[0] > audio.shape[1] and audio.shape[1] <= 8:
                audio = audio.transpose(0, 1)
        elif audio.dim() == 3:
            # If batched, take the first item.
            audio = audio[0]
        else:
            raise ValueError(f"Unexpected audio dims: {tuple(audio.shape)}")
        return audio.contiguous()

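    # Shape handling examples (illustrative):
    #   [48000]        -> [1, 48000]   (mono vector gains a channel dim)
    #   [48000, 2]     -> [2, 48000]   (clearly [T, C], so transposed)
    #   [2, 48000]     -> [2, 48000]   (already [C, T], left alone)
    #   [1, 2, 48000]  -> [2, 48000]   (batched input, first item taken)
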
    def _load_audio(self, path: str) -> torch.Tensor:
        """Load audio, resample to 48 kHz stereo, clamp to max_duration."""
        try:
            wav, sr = torchaudio.load(path)
        except Exception as torchaudio_exc:
            # torchaudio on some Space images requires torchcodec for decode.
            # Fall back to soundfile so training can proceed without torchcodec.
            try:
                audio_np, sr = sf.read(path, dtype="float32", always_2d=True)
                wav = torch.from_numpy(audio_np.T)
            except Exception as sf_exc:
                raise RuntimeError(
                    f"Failed to decode audio '{path}' with torchaudio ({torchaudio_exc}) "
                    f"and soundfile ({sf_exc})."
                ) from sf_exc
        wav = self._coerce_audio_tensor(wav)
        # Resample if needed
        if sr != self.cfg.sample_rate:
            wav = torchaudio.functional.resample(wav, sr, self.cfg.sample_rate)
        # Convert mono → stereo
        if wav.shape[0] == 1:
            wav = wav.repeat(2, 1)
        elif wav.shape[0] > 2:
            wav = wav[:2]
        # Clamp length
        max_samples = int(self.cfg.max_duration_sec * self.cfg.sample_rate)
        if wav.shape[1] > max_samples:
            wav = wav[:, :max_samples]
        return wav  # [2, T]

    def _encode_audio(self, wav: torch.Tensor) -> torch.Tensor:
        """Encode raw waveform → VAE latent on device."""
        with torch.no_grad():
            latent = self.handler._encode_audio_to_latents(wav)
        if latent.dim() == 2:
            latent = latent.unsqueeze(0)
        latent = latent.to(self.dtype)
        return latent

    def _build_text_embeddings(self, caption: str, lyrics: str):
        """Compute text & lyric embeddings using the text encoder."""
        tokenizer = self.handler.text_tokenizer
        text_encoder = self.handler.text_encoder
        # Caption embedding
        text_tokens = tokenizer(
            caption or "",
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=512,
        ).to(self.device)
        with torch.no_grad():
            text_hidden = text_encoder(
                input_ids=text_tokens["input_ids"]
            ).last_hidden_state.to(self.dtype)
        text_mask = text_tokens["attention_mask"].to(self.dtype)
        # Lyric embedding (token-level via embed_tokens)
        lyric_tokens = tokenizer(
            lyrics or "",
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=512,
        ).to(self.device)
        with torch.no_grad():
            lyric_hidden = text_encoder.embed_tokens(
                lyric_tokens["input_ids"]
            ).to(self.dtype)
        lyric_mask = lyric_tokens["attention_mask"].to(self.dtype)
        return text_hidden, text_mask, lyric_hidden, lyric_mask

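    # Note the asymmetry above: captions go through the full text encoder
    # (contextual ``last_hidden_state``), while lyrics use only the static
    # ``embed_tokens`` lookup, i.e. token-level embeddings with no
    # self-attention context. Both are padded/truncated to 512 tokens.
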
    # ------------------------------------------------------------------
    # Flow matching loss
    # ------------------------------------------------------------------

    def _flow_matching_loss(
        self,
        x1: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        context_latents: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the rectified-flow MSE loss for one batch.

        Notation follows the ACE-Step convention:
            x0 = noise, x1 = clean latent
            xt = t * x0 + (1 - t) * x1
            target velocity = x0 - x1
        """
        bsz = x1.shape[0]
        # Sample a random timestep per batch element
        t = torch.rand(bsz, device=self.device, dtype=self.dtype)
        # Apply timestep shift: t_shifted = shift * t / (1 + (shift - 1) * t)
        if self.cfg.shift != 1.0:
            t = self.cfg.shift * t / (1.0 + (self.cfg.shift - 1.0) * t)
        t = t.clamp(1e-5, 1.0 - 1e-5)
        # Noise
        x0 = torch.randn_like(x1)
        # Interpolate
        t_expand = t.view(bsz, 1, 1)
        xt = t_expand * x0 + (1.0 - t_expand) * x1
        # Target velocity
        velocity_target = x0 - x1
        # Attention mask
        attention_mask = torch.ones(
            bsz, x1.shape[1], device=self.device, dtype=self.dtype
        )
        # Forward through decoder
        decoder_out = self.handler.model.decoder(
            hidden_states=xt,
            timestep=t,
            timestep_r=t,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            context_latents=context_latents,
            use_cache=False,
            output_attentions=False,
        )
        velocity_pred = decoder_out[0]  # first element is the predicted velocity
        loss = F.mse_loss(velocity_pred, velocity_target)
        return loss

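    # Illustrative sanity check (not called during training): with the linear
    # interpolant xt = t*x0 + (1-t)*x1 and velocity v = x0 - x1, one exact
    # Euler step from t back to 0 recovers the clean latent:
    # xt - t*v = t*x0 + (1-t)*x1 - t*(x0 - x1) = x1.
    @staticmethod
    def _check_flow_identity() -> bool:
        x0 = torch.randn(1, 8, 4)        # noise
        x1 = torch.randn(1, 8, 4)        # "clean" latent
        t = torch.rand(1).view(1, 1, 1)  # random timestep in (0, 1)
        xt = t * x0 + (1.0 - t) * x1     # interpolant
        v = x0 - x1                      # target velocity
        return torch.allclose(xt - t * v, x1, atol=1e-6)
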
    @staticmethod
    def _pad_and_stack(tensors: List[torch.Tensor], pad_value: float = 0.0) -> torch.Tensor:
        """Pad variable-length tensors on dimension 0 and stack them as a batch."""
        normalized = []
        for t in tensors:
            if t.dim() >= 2 and t.shape[0] == 1:
                normalized.append(t.squeeze(0))
            else:
                normalized.append(t)
        max_len = max(t.shape[0] for t in normalized)
        template = normalized[0]
        out_shape = (len(normalized), max_len, *template.shape[1:])
        out = template.new_full(out_shape, pad_value)
        for i, t in enumerate(normalized):
            out[i, : t.shape[0]] = t
        return out

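    # Example (illustrative): stacking latents of shape [1, 100, 64] and
    # [1, 80, 64] squeezes the leading singleton, pads the second to length
    # 100 with pad_value, and returns a [2, 100, 64] batch.
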
    # ------------------------------------------------------------------
    # Main training loop
    # ------------------------------------------------------------------

    def request_stop(self):
        """Ask the training loop to stop after the current step."""
        self._stop_requested = True

    def train(
        self,
        entries: List[TrackEntry],
        progress_callback=None,
    ) -> str:
        """Run the full LoRA training.

        Args:
            entries: List of scanned TrackEntry objects.
            progress_callback: ``fn(step, total_steps, loss, epoch)`` for UI updates.

        Returns:
            Status message.
        """
        self._stop_requested = False
        if not self.cfg.resume_from:
            # Keep the restored history when resuming; otherwise start fresh.
            self.loss_history.clear()
        os.makedirs(self.cfg.output_dir, exist_ok=True)
        if not entries:
            return "No training data provided."

        num_entries = len(entries)
        total_steps = (
            math.ceil(num_entries / self.cfg.batch_size)
            * self.cfg.num_epochs
        )

        # ---- Pre-encode all audio & text (fits in CPU RAM) ----
        logger.info("Pre-encoding dataset through VAE & text encoder ...")
        dataset: List[Dict[str, Any]] = []
        failed_encode: List[str] = []
        # Freeze VAE and text encoder (they are not trained)
        self.handler.vae.eval()
        self.handler.text_encoder.eval()
        # Reuse the silence reference latent (same as the handler's internal fallback path).
        ref_latent = self.handler.silence_latent[:, :750, :].to(self.device).to(self.dtype)
        ref_order_mask = torch.zeros(1, device=self.device, dtype=torch.long)
        for entry in tqdm(entries, desc="Encoding dataset"):
            try:
                wav = self._load_audio(entry.audio_path)
                latent = self._encode_audio(wav)
                text_h, text_m, lyric_h, lyric_m = self._build_text_embeddings(
                    entry.caption, entry.lyrics
                )
                # Prepare conditioning using the model's own prepare_condition
                with torch.no_grad():
                    enc_hs, enc_mask, ctx_lat = self.handler.model.prepare_condition(
                        text_hidden_states=text_h,
                        text_attention_mask=text_m,
                        lyric_hidden_states=lyric_h,
                        lyric_attention_mask=lyric_m,
                        refer_audio_acoustic_hidden_states_packed=ref_latent,
                        refer_audio_order_mask=ref_order_mask,
                        hidden_states=latent,
                        attention_mask=torch.ones(
                            1, latent.shape[1],
                            device=self.device, dtype=self.dtype,
                        ),
                        silence_latent=self.handler.silence_latent,
                        src_latents=latent,
                        chunk_masks=torch.ones_like(latent),
                        is_covers=[False],
                    )
                dataset.append(
                    {
                        "latent": latent.cpu(),
                        "enc_hs": enc_hs.cpu(),
                        "enc_mask": enc_mask.cpu(),
                        "ctx_lat": ctx_lat.cpu(),
                        "name": Path(entry.audio_path).stem,
                    }
                )
            except Exception as exc:
                reason = f"{Path(entry.audio_path).name}: {exc}"
                failed_encode.append(reason)
                logger.warning(f"Skipping {entry.audio_path}: {exc}")

        if not dataset:
            preview = "\n".join(f"- {msg}" for msg in failed_encode[:8]) or "- (no detailed errors captured)"
            return (
                "All tracks failed to encode. Check audio files.\n"
                "First errors:\n"
                f"{preview}\n"
                "Tip: try WAV/FLAC files and dataset folder scan instead of temporary uploads."
            )
        logger.info(f"Encoded {len(dataset)}/{num_entries} tracks.")

        # ---- Warmup scheduler ----
        total_optim_steps = math.ceil(
            total_steps / self.cfg.gradient_accumulation_steps
        )
        warmup_steps = int(total_optim_steps * self.cfg.warmup_ratio)
        if self.cfg.scheduler in {"constant_with_warmup", "linear", "cosine"}:
            try:
                from transformers import get_scheduler
                self.scheduler = get_scheduler(
                    name=self.cfg.scheduler,
                    optimizer=self.optimizer,
                    num_warmup_steps=warmup_steps,
                    num_training_steps=total_optim_steps,
                )
            except Exception as exc:
                logger.warning(f"Could not create scheduler '{self.cfg.scheduler}', disabling scheduler: {exc}")
                self.scheduler = None
        else:
            self.scheduler = None

        # ---- Training loop ----
        logger.info(
            f"Starting LoRA training: {self.cfg.num_epochs} epochs, "
            f"{len(dataset)} samples, {total_optim_steps} optimiser steps"
        )
        self.peft_model.train()
        accum_loss = 0.0
        step_in_accum = 0
        for epoch in range(self.current_epoch, self.cfg.num_epochs):
            if self._stop_requested:
                break
            self.current_epoch = epoch
            indices = list(range(len(dataset)))
            random.shuffle(indices)
            epoch_loss = 0.0
            epoch_steps = 0
            for i in range(0, len(indices), self.cfg.batch_size):
                if self._stop_requested:
                    break
                batch_indices = indices[i : i + self.cfg.batch_size]
                batch_items = [dataset[j] for j in batch_indices]
                # Move the batch to device
                latents = self._pad_and_stack([it["latent"] for it in batch_items]).to(self.device, self.dtype)
                enc_hs = self._pad_and_stack([it["enc_hs"] for it in batch_items]).to(self.device, self.dtype)
                enc_mask = self._pad_and_stack([it["enc_mask"] for it in batch_items], pad_value=0.0).to(self.device)
                if enc_mask.dtype != self.dtype:
                    enc_mask = enc_mask.to(self.dtype)
                ctx_lat = self._pad_and_stack([it["ctx_lat"] for it in batch_items]).to(self.device, self.dtype)
                # Forward + loss
                loss = self._flow_matching_loss(latents, enc_hs, enc_mask, ctx_lat)
                loss = loss / self.cfg.gradient_accumulation_steps
                loss.backward()
                accum_loss += loss.item()
                step_in_accum += 1
                if step_in_accum >= self.cfg.gradient_accumulation_steps:
                    torch.nn.utils.clip_grad_norm_(
                        self.peft_model.parameters(), self.cfg.max_grad_norm
                    )
                    self.optimizer.step()
                    if self.scheduler is not None:
                        self.scheduler.step()
                    self.optimizer.zero_grad()
                    self.global_step += 1
                    avg_loss = accum_loss
                    accum_loss = 0.0
                    step_in_accum = 0
                    self.loss_history.append(
                        {
                            "step": self.global_step,
                            "epoch": epoch,
                            "loss": avg_loss,
                            "lr": self.optimizer.param_groups[0]["lr"],
                        }
                    )
                    if self.global_step % self.cfg.log_every_n_steps == 0:
                        logger.info(
                            f"Epoch {epoch+1}/{self.cfg.num_epochs} "
                            f"Step {self.global_step}/{total_optim_steps} "
                            f"Loss {avg_loss:.6f} "
                            f"LR {self.optimizer.param_groups[0]['lr']:.2e}"
                        )
                    if progress_callback:
                        progress_callback(
                            self.global_step, total_optim_steps, avg_loss, epoch
                        )
                epoch_loss += loss.item() * self.cfg.gradient_accumulation_steps
                epoch_steps += 1

            # Flush remaining micro-batches when len(dataset) is not divisible
            # by the gradient-accumulation factor.
            if step_in_accum > 0:
                torch.nn.utils.clip_grad_norm_(self.peft_model.parameters(), self.cfg.max_grad_norm)
                self.optimizer.step()
                if self.scheduler is not None:
                    self.scheduler.step()
                self.optimizer.zero_grad()
                self.global_step += 1
                avg_loss = accum_loss
                accum_loss = 0.0
                step_in_accum = 0
                self.loss_history.append(
                    {
                        "step": self.global_step,
                        "epoch": epoch,
                        "loss": avg_loss,
                        "lr": self.optimizer.param_groups[0]["lr"],
                    }
                )

            # End of epoch – checkpoint?
            if (
                (epoch + 1) % self.cfg.save_every_n_epochs == 0
                or epoch == self.cfg.num_epochs - 1
                or self._stop_requested
            ):
                self._save_checkpoint(epoch)
            if epoch_steps > 0:
                avg_epoch_loss = epoch_loss / epoch_steps
                logger.info(
                    f"Epoch {epoch+1} complete – avg loss {avg_epoch_loss:.6f}"
                )

        # Final save
        final_dir = self._save_checkpoint(self.current_epoch, final=True)
        status = (
            "Training stopped early." if self._stop_requested else "Training complete!"
        )
        return f"{status} Adapter saved to {final_dir}"

    # ------------------------------------------------------------------
    # Checkpointing
    # ------------------------------------------------------------------

    def _save_checkpoint(self, epoch: int, final: bool = False) -> str:
        tag = "final" if final else f"epoch-{epoch+1}"
        save_dir = os.path.join(self.cfg.output_dir, tag)
        os.makedirs(save_dir, exist_ok=True)
        # Save the PEFT adapter
        self.peft_model.save_pretrained(save_dir)
        # Save training state
        torch.save(
            {
                "optimizer": self.optimizer.state_dict(),
                "global_step": self.global_step,
                "epoch": epoch,
            },
            os.path.join(save_dir, "training_state.pt"),
        )
        # Save the loss curve
        loss_path = os.path.join(save_dir, "loss_history.json")
        with open(loss_path, "w") as f:
            json.dump(self.loss_history, f)
        # Save the config
        cfg_path = os.path.join(save_dir, "train_config.json")
        with open(cfg_path, "w") as f:
            json.dump(asdict(self.cfg), f, indent=2)
        logger.info(f"Checkpoint saved → {save_dir}")
        return save_dir

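    # Resulting checkpoint layout (per save):
    #
    #   <output_dir>/epoch-N/  (or final/)
    #     adapter_config.json + adapter weights   (from save_pretrained)
    #     training_state.pt                       (optimizer, global_step, epoch)
    #     loss_history.json                       (per-step loss curve for the UI)
    #     train_config.json                       (the LoRATrainConfig used)
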
# ---------------------------------------------------------------------------
# Adapter listing
# ---------------------------------------------------------------------------

def list_adapters(output_dir: str = "lora_output") -> List[str]:
    """Return adapter directories inside *output_dir* (recursive)."""
    results = []
    root = Path(output_dir)
    if not root.is_dir():
        return results
    for cfg in sorted(root.rglob("adapter_config.json")):
        d = cfg.parent
        if d.is_dir():
            results.append(str(d))
    return results

def _build_arg_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="ACE-Step 1.5 LoRA trainer (CLI)")
    # Dataset
    parser.add_argument("--dataset-dir", type=str, default="", help="Local dataset folder path")
    parser.add_argument("--dataset-repo", type=str, default="", help="HF dataset repo id (optional)")
    parser.add_argument("--dataset-revision", type=str, default="main", help="HF dataset revision")
    parser.add_argument("--dataset-subdir", type=str, default="", help="Subdirectory inside downloaded dataset")
    # Model init
    parser.add_argument("--model-config", type=str, default="acestep-v15-base", help="DiT config name")
    parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "mps", "xpu", "cpu"])
    parser.add_argument("--offload-to-cpu", action="store_true")
    parser.add_argument("--offload-dit-to-cpu", action="store_true")
    parser.add_argument("--prefer-source", type=str, default="huggingface", choices=["huggingface", "modelscope"])
    # Train config
    parser.add_argument("--output-dir", type=str, default="lora_output")
    parser.add_argument("--resume-from", type=str, default="")
    parser.add_argument("--num-epochs", type=int, default=50)
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--grad-accum", type=int, default=1)
    parser.add_argument("--save-every", type=int, default=10)
    parser.add_argument("--log-every", type=int, default=5)
    parser.add_argument("--max-duration-sec", type=float, default=240.0)
    parser.add_argument("--lora-rank", type=int, default=64)
    parser.add_argument("--lora-alpha", type=int, default=64)
    parser.add_argument("--lora-dropout", type=float, default=0.1)
    parser.add_argument("--learning-rate", type=float, default=1e-4)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--optimizer", type=str, default="adamw_8bit", choices=["adamw", "adamw_8bit"])
    parser.add_argument("--max-grad-norm", type=float, default=1.0)
    parser.add_argument("--warmup-ratio", type=float, default=0.03)
    parser.add_argument("--scheduler", type=str, default="constant_with_warmup", choices=["constant_with_warmup", "linear", "cosine"])
    parser.add_argument("--shift", type=float, default=3.0)
    # Optional upload
    parser.add_argument("--upload-repo", type=str, default="", help="HF model repo to upload final adapter")
    parser.add_argument("--upload-path", type=str, default="", help="Path inside upload repo (optional)")
    parser.add_argument("--upload-private", action="store_true")
    parser.add_argument("--hf-token-env", type=str, default="HF_TOKEN", help="Environment variable name for HF token")
    return parser

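# Example CLI invocation (illustrative; the script filename, dataset path and
# repo id below are hypothetical):
#
#   python lora_trainer.py \
#       --dataset-dir ./my_tracks \
#       --num-epochs 20 --lora-rank 32 --lora-alpha 32 \
#       --output-dir lora_output \
#       --upload-repo your-username/acestep-lora
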
def _resolve_dataset_dir(args) -> str:
    if args.dataset_dir:
        return args.dataset_dir
    if not args.dataset_repo:
        raise ValueError("Provide --dataset-dir or --dataset-repo.")
    from huggingface_hub import snapshot_download

    token = os.getenv(args.hf_token_env)
    temp_root = tempfile.mkdtemp(prefix="acestep_lora_dataset_")
    local_dir = os.path.join(temp_root, "dataset")
    logger.info(f"Downloading dataset repo {args.dataset_repo}@{args.dataset_revision} to {local_dir}")
    snapshot_download(
        repo_id=args.dataset_repo,
        repo_type="dataset",
        revision=args.dataset_revision,
        local_dir=local_dir,
        local_dir_use_symlinks=False,
        token=token,
    )
    if args.dataset_subdir:
        sub = os.path.join(local_dir, args.dataset_subdir)
        if not os.path.isdir(sub):
            raise FileNotFoundError(f"Dataset subdir not found: {sub}")
        return sub
    return local_dir

def _upload_adapter_if_requested(args, final_dir: str):
    if not args.upload_repo:
        return
    from huggingface_hub import HfApi

    token = os.getenv(args.hf_token_env)
    if not token:
        raise RuntimeError(
            f"{args.hf_token_env} is not set. Needed for upload to {args.upload_repo}."
        )
    api = HfApi(token=token)
    api.create_repo(
        repo_id=args.upload_repo,
        repo_type="model",
        exist_ok=True,
        private=bool(args.upload_private),
    )
    path_in_repo = args.upload_path.strip().strip("/") if args.upload_path else ""
    commit_message = f"Upload ACE-Step LoRA adapter from {Path(final_dir).name}"
    logger.info(f"Uploading adapter from {final_dir} to {args.upload_repo}/{path_in_repo}")
    api.upload_folder(
        repo_id=args.upload_repo,
        repo_type="model",
        folder_path=final_dir,
        path_in_repo=path_in_repo,
        commit_message=commit_message,
    )
    logger.info("Upload complete")

def main():
    args = _build_arg_parser().parse_args()
    dataset_dir = _resolve_dataset_dir(args)
    entries = scan_dataset_folder(dataset_dir)
    if not entries:
        raise RuntimeError(f"No audio files found in dataset: {dataset_dir}")

    from acestep.handler import AceStepHandler

    project_root = str(Path(__file__).resolve().parent)
    handler = AceStepHandler()
    status, ok = handler.initialize_service(
        project_root=project_root,
        config_path=args.model_config,
        device=args.device,
        use_flash_attention=False,
        compile_model=False,
        offload_to_cpu=bool(args.offload_to_cpu),
        offload_dit_to_cpu=bool(args.offload_dit_to_cpu),
        prefer_source=args.prefer_source,
    )
    print(status)
    if not ok:
        raise RuntimeError("Model initialization failed")

    cfg = LoRATrainConfig(
        lora_rank=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        optimizer=args.optimizer,
        max_grad_norm=args.max_grad_norm,
        warmup_ratio=args.warmup_ratio,
        scheduler=args.scheduler,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        save_every_n_epochs=args.save_every,
        log_every_n_steps=args.log_every,
        shift=args.shift,
        max_duration_sec=args.max_duration_sec,
        output_dir=args.output_dir,
        resume_from=(args.resume_from.strip() if args.resume_from else None),
        device=args.device,
    )
    trainer = LoRATrainer(handler, cfg)
    trainer.prepare()

    start = time.time()

    def _progress(step, total, loss, epoch):
        elapsed = time.time() - start
        rate = step / elapsed if elapsed > 0 else 0.0
        remaining = max(0.0, total - step)
        eta_sec = remaining / rate if rate > 0 else -1.0
        eta_msg = f"{eta_sec/60:.1f}m" if eta_sec >= 0 else "unknown"
        logger.info(
            f"[progress] step={step}/{total} epoch={epoch+1} loss={loss:.6f} elapsed={elapsed/60:.1f}m eta={eta_msg}"
        )

    msg = trainer.train(entries, progress_callback=_progress)
    print(msg)
    final_dir = os.path.join(cfg.output_dir, "final")
    if os.path.isdir(final_dir):
        _upload_adapter_if_requested(args, final_dir)
        print(f"Final adapter directory: {final_dir}")
    else:
        print(f"Warning: final adapter directory not found at {final_dir}")


if __name__ == "__main__":
    main()