# Source: Lora-ace-step / lora_train.py (commit bd37cca, pushed by Andrew)
"""
ACE-Step 1.5 LoRA Training Engine
Handles dataset building, VAE encoding, and flow-matching LoRA training
of the DiT decoder. Designed to work with the existing AceStepHandler.
"""
import os
import sys
import json
import math
import time
import random
import hashlib
import argparse
import tempfile
from pathlib import Path
from dataclasses import dataclass, field, asdict
from typing import Optional, List, Dict, Any, Tuple
import torch
import torch.nn.functional as F
import torchaudio
import soundfile as sf
import numpy as np
from loguru import logger
from tqdm import tqdm
# ---------------------------------------------------------------------------
# Dataset helpers
# ---------------------------------------------------------------------------
# File extensions accepted as trainable audio when scanning datasets/uploads.
AUDIO_EXTENSIONS = {".wav", ".flac", ".mp3", ".ogg", ".opus", ".m4a", ".aac"}
@dataclass
class TrackEntry:
    """One audio file + its metadata."""

    audio_path: str  # path to the audio file on disk
    caption: str = ""  # free-text description used as the text condition
    lyrics: str = ""  # lyric text (empty for instrumentals)
    bpm: Optional[int] = None  # beats per minute, if known
    keyscale: str = ""  # musical key/scale label; format not validated here
    timesignature: str = "4/4"  # time signature string, defaults to common time
    vocal_language: str = "en"  # language code of the vocals
    duration: Optional[float] = None  # seconds (measured at scan time)
def _load_track_entry(audio_path: Path) -> TrackEntry:
    """Load one track plus its optional ``<stem>.json`` sidecar metadata.

    The audio duration is measured with ``torchaudio.info``; when the file
    cannot be probed we fall back to the sidecar's ``duration`` field (if
    any). A malformed sidecar is logged and treated as empty metadata.

    Args:
        audio_path: Path to the audio file.

    Returns:
        A populated :class:`TrackEntry`.
    """
    sidecar = audio_path.with_suffix(".json")
    meta: Dict[str, Any] = {}
    if sidecar.exists():
        try:
            meta = json.loads(sidecar.read_text(encoding="utf-8"))
        except Exception as exc:
            logger.warning(f"Bad sidecar {sidecar}: {exc}")
    try:
        info = torchaudio.info(str(audio_path))
        duration = info.num_frames / info.sample_rate
    except Exception:
        duration = meta.get("duration")
    # Coerce bpm the same way scan_uploaded_files does, so a sidecar that
    # stores bpm as a string (e.g. "120" or "120.0") still yields an int
    # and anything non-integer degrades to None instead of leaking into
    # the Optional[int] field.
    bpm = meta.get("bpm")
    if isinstance(bpm, str) and bpm.strip():
        try:
            bpm = int(float(bpm))
        except Exception:
            bpm = None
    return TrackEntry(
        audio_path=str(audio_path),
        caption=meta.get("caption", ""),
        lyrics=meta.get("lyrics", ""),
        bpm=bpm if isinstance(bpm, int) else None,
        keyscale=meta.get("keyscale", ""),
        timesignature=meta.get("timesignature", "4/4"),
        vocal_language=meta.get("vocal_language", "en"),
        duration=duration,
    )
def scan_dataset_folder(folder: str) -> List[TrackEntry]:
    """Recursively collect audio files (plus optional JSON sidecars) under *folder*.

    Every file whose extension is in ``AUDIO_EXTENSIONS`` becomes a
    :class:`TrackEntry`; a ``<stem>.json`` sidecar next to the audio file
    supplies metadata, and missing sidecars simply yield empty metadata.

    Raises:
        FileNotFoundError: If *folder* does not exist or is not a directory.
    """
    root = Path(folder)
    if not root.is_dir():
        raise FileNotFoundError(f"Dataset folder not found: {root}")
    entries = [
        _load_track_entry(candidate)
        for candidate in sorted(root.rglob("*"))
        if candidate.suffix.lower() in AUDIO_EXTENSIONS
    ]
    logger.info(f"Scanned {len(entries)} audio files in {root}")
    return entries
def scan_uploaded_files(file_paths: List[str]) -> List[TrackEntry]:
    """Build :class:`TrackEntry` objects from dropped/uploaded files.

    Audio files may be uploaded together with optional ``.json`` sidecars;
    a sidecar is matched to its audio file by basename stem
    (``song.mp3`` <-> ``song.json``).
    """
    # Pass 1: index every readable uploaded JSON sidecar by its stem.
    meta_by_stem: Dict[str, Dict[str, Any]] = {}
    for raw in file_paths:
        candidate = Path(raw)
        if candidate.exists() and candidate.suffix.lower() == ".json":
            try:
                meta_by_stem[candidate.stem] = json.loads(
                    candidate.read_text(encoding="utf-8")
                )
            except Exception as exc:
                logger.warning(f"Bad uploaded sidecar {candidate}: {exc}")
    # Pass 2: turn each uploaded audio file into an entry.
    entries: List[TrackEntry] = []
    for raw in file_paths:
        audio = Path(raw)
        if not audio.exists() or audio.suffix.lower() not in AUDIO_EXTENSIONS:
            continue
        meta = meta_by_stem.get(audio.stem)
        if meta is None:
            # No uploaded sidecar: fall back to on-disk sidecar discovery.
            entries.append(_load_track_entry(audio))
            continue
        try:
            info = torchaudio.info(str(audio))
            duration = info.num_frames / info.sample_rate
        except Exception:
            duration = meta.get("duration")
        # Accept bpm given as a numeric string; anything else becomes None.
        bpm = meta.get("bpm")
        if isinstance(bpm, str) and bpm.strip():
            try:
                bpm = int(float(bpm))
            except Exception:
                bpm = None
        entries.append(
            TrackEntry(
                audio_path=str(audio),
                caption=meta.get("caption", "") or "",
                lyrics=meta.get("lyrics", "") or "",
                bpm=bpm if isinstance(bpm, int) else None,
                keyscale=meta.get("keyscale", "") or "",
                timesignature=meta.get("timesignature", "4/4") or "4/4",
                vocal_language=meta.get("vocal_language", meta.get("language", "en")) or "en",
                duration=duration,
            )
        )
    logger.info(
        "Loaded {} uploaded audio files ({} uploaded sidecars detected)".format(
            len(entries), len(meta_by_stem)
        )
    )
    return entries
# ---------------------------------------------------------------------------
# Training hyper-parameters
# ---------------------------------------------------------------------------
@dataclass
class LoRATrainConfig:
    """All tuneable knobs for a LoRA run."""

    # LoRA architecture
    lora_rank: int = 64
    lora_alpha: int = 64
    lora_dropout: float = 0.1
    # Attention projection suffixes to adapt; resolved against the actual
    # decoder module names (with fallbacks) in LoRATrainer.
    lora_target_modules: List[str] = field(
        default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
    )
    # Optimiser
    learning_rate: float = 1e-4
    weight_decay: float = 0.01
    optimizer: str = "adamw_8bit"  # "adamw" | "adamw_8bit"
    max_grad_norm: float = 1.0
    # Schedule
    warmup_ratio: float = 0.03  # fraction of total optimiser steps spent warming up
    scheduler: str = "constant_with_warmup"  # "constant_with_warmup" | "linear" | "cosine"
    # Training loop
    num_epochs: int = 50
    batch_size: int = 1
    gradient_accumulation_steps: int = 1
    save_every_n_epochs: int = 10
    log_every_n_steps: int = 5
    # Flow matching
    shift: float = 3.0  # timestep shift factor
    # Audio pre-processing
    max_duration_sec: float = 240.0  # clamp audio to this length
    sample_rate: int = 48000
    # Paths
    output_dir: str = "lora_output"
    resume_from: Optional[str] = None  # adapter dir saved by a previous run
    # Device
    device: str = "auto"
    dtype: str = "bf16"  # "bf16" | "fp16" | "fp32"
    # NOTE(review): dtype/mixed_precision are not read by LoRATrainer in this
    # file (the handler's device/dtype win) — confirm whether they are used
    # elsewhere before removing.
    mixed_precision: bool = True
# ---------------------------------------------------------------------------
# Core trainer
# ---------------------------------------------------------------------------
class LoRATrainer:
    """Thin training loop that wraps the existing AceStepHandler."""

    def __init__(self, handler, config: LoRATrainConfig):
        """
        Args:
            handler: Initialised ``AceStepHandler`` (model, vae, text_encoder loaded).
            config: Training hyper-parameters.
        """
        self.handler = handler
        self.cfg = config
        # Reuse the handler's device/dtype so new tensors match the loaded model.
        self.device = handler.device
        self.dtype = handler.dtype
        # Will be set during prepare()
        self.peft_model = None
        self.optimizer = None
        self.scheduler = None
        self.global_step = 0
        self.current_epoch = 0
        # Loss history for UI
        self.loss_history: List[Dict[str, Any]] = []
        # Cooperative stop flag; set by request_stop() and polled by train().
        self._stop_requested = False
# ------------------------------------------------------------------
# Setup
# ------------------------------------------------------------------
@staticmethod
def _resolve_lora_target_modules(model, requested_targets: Optional[List[str]]) -> List[str]:
"""Resolve LoRA target module suffixes against the actual decoder module names."""
linear_module_names = [
name for name, module in model.named_modules() if isinstance(module, torch.nn.Linear)
]
def _exists_as_suffix(target: str) -> bool:
return any(name.endswith(target) for name in linear_module_names)
requested_targets = requested_targets or []
resolved = [target for target in requested_targets if _exists_as_suffix(target)]
if resolved:
return resolved
fallback_groups = [
["q_proj", "k_proj", "v_proj", "o_proj"],
["to_q", "to_k", "to_v", "to_out.0"],
["query", "key", "value", "out_proj"],
["wq", "wk", "wv", "wo"],
["qkv", "proj_out"],
]
for group in fallback_groups:
group_resolved = [target for target in group if _exists_as_suffix(target)]
if len(group_resolved) >= 2:
return group_resolved
sample = ", ".join(linear_module_names[:30])
raise ValueError(
"Could not find LoRA target modules in decoder. "
f"Requested={requested_targets}. "
f"Sample linear modules: {sample}"
)
    def prepare(self):
        """Attach LoRA adapters to the decoder and build the optimiser.

        Fresh runs wrap a pristine copy of the base decoder with new LoRA
        layers; when ``cfg.resume_from`` points at a saved adapter directory
        the adapter is reloaded as trainable and optimiser/step/loss state is
        restored afterwards.
        """
        import copy
        from peft import LoraConfig, PeftModel, TaskType, get_peft_model

        # Keep a backup of the plain base decoder so load/unload logic remains valid.
        if self.handler._base_decoder is None:
            self.handler._base_decoder = copy.deepcopy(self.handler.model.decoder)
        else:
            # Restore a pristine copy so adapters from previous runs never stack.
            self.handler.model.decoder = copy.deepcopy(self.handler._base_decoder)
        self.handler.model.decoder = self.handler.model.decoder.to(self.device).to(self.dtype)
        self.handler.model.decoder.eval()
        # Resume only when the directory actually contains a PEFT adapter.
        resume_adapter = None
        if self.cfg.resume_from:
            adapter_cfg = os.path.join(self.cfg.resume_from, "adapter_config.json")
            if os.path.isfile(adapter_cfg):
                resume_adapter = self.cfg.resume_from
        if resume_adapter:
            logger.info(f"Loading existing LoRA adapter for resume: {resume_adapter}")
            self.peft_model = PeftModel.from_pretrained(
                self.handler.model.decoder,
                resume_adapter,
                is_trainable=True,
            )
        else:
            resolved_targets = self._resolve_lora_target_modules(
                self.handler.model.decoder,
                self.cfg.lora_target_modules,
            )
            logger.info(f"Using LoRA target modules: {resolved_targets}")
            peft_cfg = LoraConfig(
                r=self.cfg.lora_rank,
                lora_alpha=self.cfg.lora_alpha,
                lora_dropout=self.cfg.lora_dropout,
                target_modules=resolved_targets,
                bias="none",
                task_type=TaskType.FEATURE_EXTRACTION,
            )
            self.peft_model = get_peft_model(self.handler.model.decoder, peft_cfg)
        self.peft_model.print_trainable_parameters()
        # Swap the wrapped decoder into the handler and switch to train mode.
        self.handler.model.decoder = self.peft_model
        self.handler.model.decoder.to(self.device).to(self.dtype)
        self.handler.model.decoder.train()
        # Flag the handler so its other code paths know a LoRA is active.
        self.handler.lora_loaded = True
        self.handler.use_lora = True
        # Build optimiser (only LoRA params are trainable)
        trainable_params = [p for p in self.peft_model.parameters() if p.requires_grad]
        if self.cfg.optimizer == "adamw_8bit":
            try:
                import bitsandbytes as bnb
                self.optimizer = bnb.optim.AdamW8bit(
                    trainable_params,
                    lr=self.cfg.learning_rate,
                    weight_decay=self.cfg.weight_decay,
                )
            except ImportError:
                logger.warning("bitsandbytes not found – falling back to standard AdamW")
                self.optimizer = torch.optim.AdamW(
                    trainable_params,
                    lr=self.cfg.learning_rate,
                    weight_decay=self.cfg.weight_decay,
                )
        else:
            self.optimizer = torch.optim.AdamW(
                trainable_params,
                lr=self.cfg.learning_rate,
                weight_decay=self.cfg.weight_decay,
            )
        # Resume checkpoint state (after model/adapter restore).
        if self.cfg.resume_from and os.path.isfile(
            os.path.join(self.cfg.resume_from, "training_state.pt")
        ):
            # weights_only=False: the file stores optimiser dicts, not just tensors.
            state = torch.load(
                os.path.join(self.cfg.resume_from, "training_state.pt"),
                weights_only=False,
            )
            try:
                self.optimizer.load_state_dict(state["optimizer"])
            except Exception as exc:
                logger.warning(f"Could not restore optimizer state, continuing fresh optimizer: {exc}")
            self.global_step = int(state.get("global_step", 0))
            # Saved epoch is completed epoch index; continue from next epoch.
            self.current_epoch = int(state.get("epoch", -1)) + 1
            loss_path = os.path.join(self.cfg.resume_from, "loss_history.json")
            if os.path.isfile(loss_path):
                try:
                    with open(loss_path, "r", encoding="utf-8") as f:
                        self.loss_history = json.load(f)
                except Exception:
                    # Best-effort: a corrupt loss curve never blocks a resume.
                    pass
            logger.info(
                f"Resumed from {self.cfg.resume_from} "
                f"(epoch {self.current_epoch}, step {self.global_step})"
            )
# ------------------------------------------------------------------
# Data loading
# ------------------------------------------------------------------
@staticmethod
def _coerce_audio_tensor(audio: Any) -> torch.Tensor:
"""Coerce decoded audio into torch.Tensor with shape [C, T]."""
if isinstance(audio, list):
audio = np.asarray(audio, dtype=np.float32)
if isinstance(audio, np.ndarray):
audio = torch.from_numpy(audio)
if not torch.is_tensor(audio):
raise TypeError(f"Unsupported audio type: {type(audio)}")
# Ensure floating point for downstream resample/vae encode.
if not torch.is_floating_point(audio):
audio = audio.float()
# Normalize dimensions to [C, T].
if audio.dim() == 1:
audio = audio.unsqueeze(0)
elif audio.dim() == 2:
# Accept either [T, C] or [C, T]; transpose only when clearly [T, C].
if audio.shape[0] > audio.shape[1] and audio.shape[1] <= 8:
audio = audio.transpose(0, 1)
elif audio.dim() == 3:
# If batched, take first item.
audio = audio[0]
else:
raise ValueError(f"Unexpected audio dims: {tuple(audio.shape)}")
return audio.contiguous()
def _load_audio(self, path: str) -> torch.Tensor:
"""Load audio, resample to 48 kHz stereo, clamp to max_duration."""
try:
wav, sr = torchaudio.load(path)
except Exception as torchaudio_exc:
# torchaudio on some Space images requires torchcodec for decode.
# Fallback to soundfile so training can proceed without torchcodec.
try:
audio_np, sr = sf.read(path, dtype="float32", always_2d=True)
wav = torch.from_numpy(audio_np.T)
except Exception as sf_exc:
raise RuntimeError(
f"Failed to decode audio '{path}' with torchaudio ({torchaudio_exc}) "
f"and soundfile ({sf_exc})."
) from sf_exc
wav = self._coerce_audio_tensor(wav)
# Resample if needed
if sr != self.cfg.sample_rate:
wav = torchaudio.functional.resample(wav, sr, self.cfg.sample_rate)
# Convert mono → stereo
if wav.shape[0] == 1:
wav = wav.repeat(2, 1)
elif wav.shape[0] > 2:
wav = wav[:2]
# Clamp length
max_samples = int(self.cfg.max_duration_sec * self.cfg.sample_rate)
if wav.shape[1] > max_samples:
wav = wav[:, :max_samples]
return wav # [2, T]
def _encode_audio(self, wav: torch.Tensor) -> torch.Tensor:
"""Encode raw waveform → VAE latent on device."""
with torch.no_grad():
latent = self.handler._encode_audio_to_latents(wav)
if latent.dim() == 2:
latent = latent.unsqueeze(0)
latent = latent.to(self.dtype)
return latent
def _build_text_embeddings(self, caption: str, lyrics: str):
"""Compute text & lyric embeddings using the text encoder."""
tokenizer = self.handler.text_tokenizer
text_encoder = self.handler.text_encoder
# Caption embedding
text_tokens = tokenizer(
caption or "",
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=512,
).to(self.device)
with torch.no_grad():
text_hidden = text_encoder(
input_ids=text_tokens["input_ids"]
).last_hidden_state.to(self.dtype)
text_mask = text_tokens["attention_mask"].to(self.dtype)
# Lyric embedding (token-level via embed_tokens)
lyric_tokens = tokenizer(
lyrics or "",
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=512,
).to(self.device)
with torch.no_grad():
lyric_hidden = text_encoder.embed_tokens(
lyric_tokens["input_ids"]
).to(self.dtype)
lyric_mask = lyric_tokens["attention_mask"].to(self.dtype)
return text_hidden, text_mask, lyric_hidden, lyric_mask
# ------------------------------------------------------------------
# Flow matching loss
# ------------------------------------------------------------------
    def _flow_matching_loss(
        self,
        x1: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        context_latents: torch.Tensor,
    ) -> torch.Tensor:
        """Compute rectified-flow MSE loss for one sample.

        Notation follows ACE-Step convention:
            x0 = noise, x1 = clean latent
            xt = t * x0 + (1 - t) * x1
            target velocity = x0 - x1

        Args:
            x1: Clean (VAE-encoded) latents; assumed 3-D [B, L, D] by the
                ``t.view(bsz, 1, 1)`` broadcast below.
            encoder_hidden_states: Conditioning sequence from prepare_condition.
            encoder_attention_mask: Mask for the conditioning sequence.
            context_latents: Context latents from prepare_condition.

        Returns:
            Scalar MSE between predicted and target velocity.
        """
        bsz = x1.shape[0]
        # Sample random timestep per element
        t = torch.rand(bsz, device=self.device, dtype=self.dtype)
        # Apply timestep shift: t_shifted = shift * t / (1 + (shift - 1) * t)
        # (shift > 1 pushes sampled t toward 1, i.e. the noisier end)
        if self.cfg.shift != 1.0:
            t = self.cfg.shift * t / (1.0 + (self.cfg.shift - 1.0) * t)
        # Keep t strictly inside (0, 1) to avoid degenerate endpoints.
        t = t.clamp(1e-5, 1.0 - 1e-5)
        # Noise
        x0 = torch.randn_like(x1)
        # Interpolate
        t_expand = t.view(bsz, 1, 1)
        xt = t_expand * x0 + (1.0 - t_expand) * x1
        # Target velocity
        velocity_target = x0 - x1
        # Attention mask. NOTE(review): all-ones even when the batch was
        # padded by _pad_and_stack, so padding positions contribute to the
        # loss — confirm this is intended.
        attention_mask = torch.ones(
            bsz, x1.shape[1], device=self.device, dtype=self.dtype
        )
        # Forward through decoder
        decoder_out = self.handler.model.decoder(
            hidden_states=xt,
            timestep=t,
            timestep_r=t,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            context_latents=context_latents,
            use_cache=False,
            output_attentions=False,
        )
        velocity_pred = decoder_out[0]  # first element is the predicted output
        loss = F.mse_loss(velocity_pred, velocity_target)
        return loss
@staticmethod
def _pad_and_stack(tensors: List[torch.Tensor], pad_value: float = 0.0) -> torch.Tensor:
"""Pad variable-length tensors on dimension 0 and stack as batch."""
normalized = []
for t in tensors:
if t.dim() >= 2 and t.shape[0] == 1:
normalized.append(t.squeeze(0))
else:
normalized.append(t)
max_len = max(t.shape[0] for t in normalized)
template = normalized[0]
out_shape = (len(normalized), max_len, *template.shape[1:])
out = template.new_full(out_shape, pad_value)
for i, t in enumerate(normalized):
out[i, : t.shape[0]] = t
return out
# ------------------------------------------------------------------
# Main training loop
# ------------------------------------------------------------------
    def request_stop(self):
        """Ask the training loop to stop after the current step.

        train() polls this flag at every micro-batch and epoch boundary.
        """
        self._stop_requested = True
    def train(
        self,
        entries: List[TrackEntry],
        progress_callback=None,
    ) -> str:
        """Run the full LoRA training.

        Phase 1 pre-encodes every track through the frozen VAE/text encoder
        and caches the results on CPU; phase 2 runs the flow-matching loop
        with gradient accumulation, optional LR scheduling, periodic
        checkpoints, and a final save.

        Args:
            entries: List of scanned TrackEntry objects.
            progress_callback: ``fn(step, total_steps, loss, epoch)`` for UI updates.

        Returns:
            Status message.
        """
        self._stop_requested = False
        # NOTE(review): this wipes any loss history that prepare() loaded
        # when resuming from a checkpoint — confirm that is intended.
        self.loss_history.clear()
        os.makedirs(self.cfg.output_dir, exist_ok=True)
        if not entries:
            return "No training data provided."
        num_entries = len(entries)
        # Micro-batch steps across the whole run (before grad accumulation).
        total_steps = (
            math.ceil(num_entries / self.cfg.batch_size)
            * self.cfg.num_epochs
        )
        # ---- Pre-encode all audio & text (fits in CPU RAM) ----
        logger.info("Pre-encoding dataset through VAE & text encoder ...")
        dataset: List[Dict[str, Any]] = []
        failed_encode: List[str] = []
        # Freeze VAE and text encoder (they are not trained)
        self.handler.vae.eval()
        self.handler.text_encoder.eval()
        # Reuse silence reference latent (same as handler's internal fallback path).
        ref_latent = self.handler.silence_latent[:, :750, :].to(self.device).to(self.dtype)
        ref_order_mask = torch.zeros(1, device=self.device, dtype=torch.long)
        for idx, entry in enumerate(tqdm(entries, desc="Encoding dataset")):
            try:
                wav = self._load_audio(entry.audio_path)
                latent = self._encode_audio(wav)
                text_h, text_m, lyric_h, lyric_m = self._build_text_embeddings(
                    entry.caption, entry.lyrics
                )
                # Prepare condition using the model's own prepare_condition
                with torch.no_grad():
                    enc_hs, enc_mask, ctx_lat = self.handler.model.prepare_condition(
                        text_hidden_states=text_h,
                        text_attention_mask=text_m,
                        lyric_hidden_states=lyric_h,
                        lyric_attention_mask=lyric_m,
                        refer_audio_acoustic_hidden_states_packed=ref_latent,
                        refer_audio_order_mask=ref_order_mask,
                        hidden_states=latent,
                        attention_mask=torch.ones(
                            1, latent.shape[1],
                            device=self.device, dtype=self.dtype,
                        ),
                        silence_latent=self.handler.silence_latent,
                        src_latents=latent,
                        chunk_masks=torch.ones_like(latent),
                        is_covers=[False],
                    )
                # Cache on CPU; batches are moved back to the device per step.
                dataset.append(
                    {
                        "latent": latent.cpu(),
                        "enc_hs": enc_hs.cpu(),
                        "enc_mask": enc_mask.cpu(),
                        "ctx_lat": ctx_lat.cpu(),
                        "name": Path(entry.audio_path).stem,
                    }
                )
            except Exception as exc:
                # A single bad file must not kill the run; record it instead.
                reason = f"{Path(entry.audio_path).name}: {exc}"
                failed_encode.append(reason)
                logger.warning(f"Skipping {entry.audio_path}: {exc}")
        if not dataset:
            preview = "\n".join(f"- {msg}" for msg in failed_encode[:8]) or "- (no detailed errors captured)"
            return (
                "All tracks failed to encode. Check audio files.\n"
                "First errors:\n"
                f"{preview}\n"
                "Tip: try WAV/FLAC files and dataset folder scan instead of temporary uploads."
            )
        logger.info(f"Encoded {len(dataset)}/{num_entries} tracks.")
        # ---- Warmup scheduler ----
        total_optim_steps = math.ceil(
            total_steps / self.cfg.gradient_accumulation_steps
        )
        warmup_steps = int(total_optim_steps * self.cfg.warmup_ratio)
        if self.cfg.scheduler in {"constant_with_warmup", "linear", "cosine"}:
            try:
                from transformers import get_scheduler
                self.scheduler = get_scheduler(
                    name=self.cfg.scheduler,
                    optimizer=self.optimizer,
                    num_warmup_steps=warmup_steps,
                    num_training_steps=total_optim_steps,
                )
            except Exception as exc:
                logger.warning(f"Could not create scheduler '{self.cfg.scheduler}', disabling scheduler: {exc}")
                self.scheduler = None
        else:
            self.scheduler = None
        # ---- Training loop ----
        logger.info(
            f"Starting LoRA training: {self.cfg.num_epochs} epochs, "
            f"{len(dataset)} samples, {total_optim_steps} optimiser steps"
        )
        self.peft_model.train()
        accum_loss = 0.0
        step_in_accum = 0
        # current_epoch may be > 0 when resuming from a checkpoint.
        for epoch in range(self.current_epoch, self.cfg.num_epochs):
            if self._stop_requested:
                break
            self.current_epoch = epoch
            indices = list(range(len(dataset)))
            random.shuffle(indices)
            epoch_loss = 0.0
            epoch_steps = 0
            for i in range(0, len(indices), self.cfg.batch_size):
                if self._stop_requested:
                    break
                batch_indices = indices[i : i + self.cfg.batch_size]
                batch_items = [dataset[j] for j in batch_indices]
                # Move batch to device
                latents = self._pad_and_stack([it["latent"] for it in batch_items]).to(self.device, self.dtype)
                enc_hs = self._pad_and_stack([it["enc_hs"] for it in batch_items]).to(self.device, self.dtype)
                enc_mask = self._pad_and_stack([it["enc_mask"] for it in batch_items], pad_value=0.0).to(self.device)
                if enc_mask.dtype != self.dtype:
                    enc_mask = enc_mask.to(self.dtype)
                ctx_lat = self._pad_and_stack([it["ctx_lat"] for it in batch_items]).to(self.device, self.dtype)
                # Forward + loss
                loss = self._flow_matching_loss(latents, enc_hs, enc_mask, ctx_lat)
                loss = loss / self.cfg.gradient_accumulation_steps
                loss.backward()
                accum_loss += loss.item()
                step_in_accum += 1
                if step_in_accum >= self.cfg.gradient_accumulation_steps:
                    # One real optimiser step per accumulation window.
                    torch.nn.utils.clip_grad_norm_(
                        self.peft_model.parameters(), self.cfg.max_grad_norm
                    )
                    self.optimizer.step()
                    if self.scheduler is not None:
                        self.scheduler.step()
                    self.optimizer.zero_grad()
                    self.global_step += 1
                    avg_loss = accum_loss
                    accum_loss = 0.0
                    step_in_accum = 0
                    self.loss_history.append(
                        {
                            "step": self.global_step,
                            "epoch": epoch,
                            "loss": avg_loss,
                            "lr": self.optimizer.param_groups[0]["lr"],
                        }
                    )
                    if self.global_step % self.cfg.log_every_n_steps == 0:
                        logger.info(
                            f"Epoch {epoch+1}/{self.cfg.num_epochs} "
                            f"Step {self.global_step}/{total_optim_steps} "
                            f"Loss {avg_loss:.6f} "
                            f"LR {self.optimizer.param_groups[0]['lr']:.2e}"
                        )
                    if progress_callback:
                        progress_callback(
                            self.global_step, total_optim_steps, avg_loss, epoch
                        )
                # Undo the accumulation scaling for human-readable epoch stats.
                epoch_loss += loss.item() * self.cfg.gradient_accumulation_steps
                epoch_steps += 1
            # Flush remaining micro-batches when len(dataset) is not divisible by grad accumulation.
            if step_in_accum > 0:
                torch.nn.utils.clip_grad_norm_(self.peft_model.parameters(), self.cfg.max_grad_norm)
                self.optimizer.step()
                if self.scheduler is not None:
                    self.scheduler.step()
                self.optimizer.zero_grad()
                self.global_step += 1
                avg_loss = accum_loss
                accum_loss = 0.0
                step_in_accum = 0
                self.loss_history.append(
                    {
                        "step": self.global_step,
                        "epoch": epoch,
                        "loss": avg_loss,
                        "lr": self.optimizer.param_groups[0]["lr"],
                    }
                )
            # End of epoch – checkpoint?
            if (
                (epoch + 1) % self.cfg.save_every_n_epochs == 0
                or epoch == self.cfg.num_epochs - 1
                or self._stop_requested
            ):
                self._save_checkpoint(epoch)
            if epoch_steps > 0:
                avg_epoch_loss = epoch_loss / epoch_steps
                logger.info(
                    f"Epoch {epoch+1} complete – avg loss {avg_epoch_loss:.6f}"
                )
        # Final save
        final_dir = self._save_checkpoint(self.current_epoch, final=True)
        status = (
            "Training stopped early." if self._stop_requested else "Training complete!"
        )
        return f"{status} Adapter saved to {final_dir}"
# ------------------------------------------------------------------
# Checkpointing
# ------------------------------------------------------------------
def _save_checkpoint(self, epoch: int, final: bool = False) -> str:
tag = "final" if final else f"epoch-{epoch+1}"
save_dir = os.path.join(self.cfg.output_dir, tag)
os.makedirs(save_dir, exist_ok=True)
# Save PEFT adapter
self.peft_model.save_pretrained(save_dir)
# Save training state
torch.save(
{
"optimizer": self.optimizer.state_dict(),
"global_step": self.global_step,
"epoch": epoch,
},
os.path.join(save_dir, "training_state.pt"),
)
# Save loss curve
loss_path = os.path.join(save_dir, "loss_history.json")
with open(loss_path, "w") as f:
json.dump(self.loss_history, f)
# Save config
cfg_path = os.path.join(save_dir, "train_config.json")
with open(cfg_path, "w") as f:
json.dump(asdict(self.cfg), f, indent=2)
logger.info(f"Checkpoint saved → {save_dir}")
return save_dir
# ------------------------------------------------------------------
# Adapter listing
# ------------------------------------------------------------------
@staticmethod
def list_adapters(output_dir: str = "lora_output") -> List[str]:
"""Return adapter directories inside *output_dir* (recursive)."""
results = []
root = Path(output_dir)
if not root.is_dir():
return results
for cfg in sorted(root.rglob("adapter_config.json")):
d = cfg.parent
if d.is_dir():
results.append(str(d))
return results
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the CLI argument parser for the LoRA trainer.

    Option defaults mirror :class:`LoRATrainConfig`.
    """
    p = argparse.ArgumentParser(description="ACE-Step 1.5 LoRA trainer (CLI)")
    add = p.add_argument
    # Dataset source (local folder or HF dataset repo)
    add("--dataset-dir", type=str, default="", help="Local dataset folder path")
    add("--dataset-repo", type=str, default="", help="HF dataset repo id (optional)")
    add("--dataset-revision", type=str, default="main", help="HF dataset revision")
    add("--dataset-subdir", type=str, default="", help="Subdirectory inside downloaded dataset")
    # Model initialisation
    add("--model-config", type=str, default="acestep-v15-base", help="DiT config name")
    add("--device", type=str, default="auto", choices=["auto", "cuda", "mps", "xpu", "cpu"])
    add("--offload-to-cpu", action="store_true")
    add("--offload-dit-to-cpu", action="store_true")
    add("--prefer-source", type=str, default="huggingface", choices=["huggingface", "modelscope"])
    # Training hyper-parameters
    add("--output-dir", type=str, default="lora_output")
    add("--resume-from", type=str, default="")
    add("--num-epochs", type=int, default=50)
    add("--batch-size", type=int, default=1)
    add("--grad-accum", type=int, default=1)
    add("--save-every", type=int, default=10)
    add("--log-every", type=int, default=5)
    add("--max-duration-sec", type=float, default=240.0)
    add("--lora-rank", type=int, default=64)
    add("--lora-alpha", type=int, default=64)
    add("--lora-dropout", type=float, default=0.1)
    add("--learning-rate", type=float, default=1e-4)
    add("--weight-decay", type=float, default=0.01)
    add("--optimizer", type=str, default="adamw_8bit", choices=["adamw", "adamw_8bit"])
    add("--max-grad-norm", type=float, default=1.0)
    add("--warmup-ratio", type=float, default=0.03)
    add("--scheduler", type=str, default="constant_with_warmup", choices=["constant_with_warmup", "linear", "cosine"])
    add("--shift", type=float, default=3.0)
    # Optional upload of the final adapter
    add("--upload-repo", type=str, default="", help="HF model repo to upload final adapter")
    add("--upload-path", type=str, default="", help="Path inside upload repo (optional)")
    add("--upload-private", action="store_true")
    add("--hf-token-env", type=str, default="HF_TOKEN", help="Environment variable name for HF token")
    return p
def _resolve_dataset_dir(args) -> str:
    """Return a local dataset folder, downloading the HF dataset repo if needed.

    Precedence: an explicit ``--dataset-dir`` wins; otherwise
    ``--dataset-repo`` is snapshot-downloaded into a fresh temp directory.

    Raises:
        ValueError: If neither a local dir nor a repo id was given.
        FileNotFoundError: If ``--dataset-subdir`` does not exist after download.
    """
    if args.dataset_dir:
        return args.dataset_dir
    if not args.dataset_repo:
        raise ValueError("Provide --dataset-dir or --dataset-repo.")
    from huggingface_hub import snapshot_download

    token = os.getenv(args.hf_token_env)
    local_dir = os.path.join(
        tempfile.mkdtemp(prefix="acestep_lora_dataset_"), "dataset"
    )
    logger.info(f"Downloading dataset repo {args.dataset_repo}@{args.dataset_revision} to {local_dir}")
    snapshot_download(
        repo_id=args.dataset_repo,
        repo_type="dataset",
        revision=args.dataset_revision,
        local_dir=local_dir,
        local_dir_use_symlinks=False,
        token=token,
    )
    if not args.dataset_subdir:
        return local_dir
    subdir = os.path.join(local_dir, args.dataset_subdir)
    if not os.path.isdir(subdir):
        raise FileNotFoundError(f"Dataset subdir not found: {subdir}")
    return subdir
def _upload_adapter_if_requested(args, final_dir: str):
    """Upload *final_dir* to the HF Hub when ``--upload-repo`` was given.

    No-op when no upload repo is configured.

    Raises:
        RuntimeError: If the token env var (``--hf-token-env``) is not set.
    """
    if not args.upload_repo:
        return
    from huggingface_hub import HfApi

    token = os.getenv(args.hf_token_env)
    if not token:
        raise RuntimeError(
            f"{args.hf_token_env} is not set. Needed for upload to {args.upload_repo}."
        )
    api = HfApi(token=token)
    # Idempotent: creating an existing repo is allowed.
    api.create_repo(
        repo_id=args.upload_repo,
        repo_type="model",
        exist_ok=True,
        private=bool(args.upload_private),
    )
    subpath = args.upload_path.strip().strip("/") if args.upload_path else ""
    logger.info(f"Uploading adapter from {final_dir} to {args.upload_repo}/{subpath}")
    api.upload_folder(
        repo_id=args.upload_repo,
        repo_type="model",
        folder_path=final_dir,
        path_in_repo=subpath,
        commit_message=f"Upload ACE-Step LoRA adapter from {Path(final_dir).name}",
    )
    logger.info("Upload complete")
def main():
    """CLI entry point: resolve dataset, init handler, train, optionally upload."""
    args = _build_arg_parser().parse_args()
    dataset_dir = _resolve_dataset_dir(args)
    entries = scan_dataset_folder(dataset_dir)
    if not entries:
        raise RuntimeError(f"No audio files found in dataset: {dataset_dir}")
    # Imported lazily so argument errors surface before the heavy model stack loads.
    from acestep.handler import AceStepHandler
    project_root = str(Path(__file__).resolve().parent)
    handler = AceStepHandler()
    status, ok = handler.initialize_service(
        project_root=project_root,
        config_path=args.model_config,
        device=args.device,
        use_flash_attention=False,
        compile_model=False,
        offload_to_cpu=bool(args.offload_to_cpu),
        offload_dit_to_cpu=bool(args.offload_dit_to_cpu),
        prefer_source=args.prefer_source,
    )
    print(status)
    if not ok:
        raise RuntimeError("Model initialization failed")
    # Map CLI flags onto the training config.
    cfg = LoRATrainConfig(
        lora_rank=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        optimizer=args.optimizer,
        max_grad_norm=args.max_grad_norm,
        warmup_ratio=args.warmup_ratio,
        scheduler=args.scheduler,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        save_every_n_epochs=args.save_every,
        log_every_n_steps=args.log_every,
        shift=args.shift,
        max_duration_sec=args.max_duration_sec,
        output_dir=args.output_dir,
        resume_from=(args.resume_from.strip() if args.resume_from else None),
        device=args.device,
    )
    trainer = LoRATrainer(handler, cfg)
    trainer.prepare()
    start = time.time()

    def _progress(step, total, loss, epoch):
        # Log throughput and a rough ETA based on steps completed so far.
        elapsed = time.time() - start
        rate = step / elapsed if elapsed > 0 else 0.0
        remaining = max(0.0, total - step)
        eta_sec = remaining / rate if rate > 0 else -1.0
        eta_msg = f"{eta_sec/60:.1f}m" if eta_sec >= 0 else "unknown"
        logger.info(
            f"[progress] step={step}/{total} epoch={epoch+1} loss={loss:.6f} elapsed={elapsed/60:.1f}m eta={eta_msg}"
        )

    msg = trainer.train(entries, progress_callback=_progress)
    print(msg)
    # Upload only when the final save actually exists (train() may stop early).
    final_dir = os.path.join(cfg.output_dir, "final")
    if os.path.isdir(final_dir):
        _upload_adapter_if_requested(args, final_dir)
        print(f"Final adapter directory: {final_dir}")
    else:
        print(f"Warning: final adapter directory not found at {final_dir}")


if __name__ == "__main__":
    main()