# jam_worker.py - Bar-locked spool worker (MagentaRT crossfade handled by system.py)
#
# Changelog (commit 1c7440e): massive rewrite — pink-noise warmup, no more
# crossfading needed, one_shot_generation fixes, jam_worker overhaul.
# TODO: fix the HTML web tester, which broke due to system.py changes in MRT.
from __future__ import annotations

import os
import threading
import time
from dataclasses import dataclass
from fractions import Fraction
from typing import Dict, Optional, Tuple

import numpy as np

from magenta_rt import audio as au
from utils import (
    StreamingResampler,
    match_loudness_to_reference,
    take_bar_aligned_tail,
    wav_bytes_base64,
)
| # ----------------------------- | |
| # Data classes | |
| # ----------------------------- | |
| class JamParams: | |
| bpm: float | |
| beats_per_bar: int | |
| bars_per_chunk: int | |
| target_sr: int | |
| loudness_mode: str = "auto" | |
| headroom_db: float = 1.0 | |
| style_vec: Optional[np.ndarray] = None | |
| ref_loop: Optional[au.Waveform] = None | |
| combined_loop: Optional[au.Waveform] = None | |
| guidance_weight: float = 1.1 | |
| temperature: float = 1.1 | |
| topk: int = 40 | |
| # style glide | |
| style_ramp_seconds: float = 8.0 # 0 => instant | |
| class JamChunk: | |
| index: int | |
| audio_bytes: bytes # RAW WAV bytes (not base64) | |
| metadata: dict | |
| # ----------------------------- | |
| # Helpers | |
| # ----------------------------- | |
| class BarClock: | |
| """ | |
| Sample-domain bar clock with drift-free absolute boundaries. | |
| We use Fraction to avoid floating drift, then round to integer samples. | |
| """ | |
| def __init__(self, target_sr: int, bpm: float, beats_per_bar: int, base_offset_samples: int = 0): | |
| self.sr = int(target_sr) | |
| self.bpm = Fraction(str(bpm)) # exact decimal | |
| self.beats_per_bar = int(beats_per_bar) | |
| self.bar_samps = Fraction(self.sr * 60 * self.beats_per_bar, 1) / self.bpm | |
| self.base = int(base_offset_samples) | |
| def bounds_for_chunk(self, chunk_index: int, bars_per_chunk: int) -> Tuple[int, int]: | |
| start_f = self.base + self.bar_samps * (chunk_index * bars_per_chunk) | |
| end_f = self.base + self.bar_samps * ((chunk_index + 1) * bars_per_chunk) | |
| return int(round(start_f)), int(round(end_f)) | |
| def seconds_per_bar(self) -> float: | |
| return float(self.beats_per_bar) * (60.0 / float(self.bpm)) | |
| def wav_bytes_raw(samples: np.ndarray, sr: int) -> tuple[bytes, int, int]: | |
| """ | |
| Convert numpy samples to raw WAV bytes. | |
| Returns: (wav_bytes, total_samples, channels) | |
| """ | |
| import io | |
| import soundfile as sf | |
| if samples.ndim == 1: | |
| samples = samples[:, None] | |
| channels = samples.shape[1] | |
| total_samples = samples.shape[0] | |
| buf = io.BytesIO() | |
| sf.write(buf, samples, sr, format="WAV", subtype="PCM_16") | |
| wav_bytes = buf.getvalue() | |
| return wav_bytes, total_samples, channels | |
| # ----------------------------- | |
| # Worker | |
| # ----------------------------- | |
| class JamWorker(threading.Thread): | |
| FRAMES_PER_SECOND: float | None = None | |
| """ | |
| Generates continuous audio with MagentaRT, spools it at target SR, | |
| and emits bar-aligned chunks. | |
| IMPORTANT: | |
| MagentaRT's system.py already performs crossfade internally and returns a chunk | |
| of exactly config.chunk_length_samples. The caller should *not* do extra overlap/correction. | |
| """ | |
| def __init__(self, mrt, params: JamParams): | |
| import logging | |
| super().__init__(daemon=True) | |
| self.mrt = mrt | |
| self.params = params | |
| # Setup logging | |
| self._log = logging.getLogger("JamWorker") | |
| self._log.setLevel(logging.DEBUG) | |
| # external callers (FastAPI endpoints) use this for atomic updates | |
| self._lock = threading.RLock() | |
| # generation state | |
| self.state = self.mrt.init_state() | |
| self.mrt.guidance_weight = float(self.params.guidance_weight) | |
| self.mrt.temperature = float(self.params.temperature) | |
| self.mrt.topk = int(self.params.topk) | |
| # codec/setup | |
| self._codec_fps = float(self.mrt.codec.frame_rate) | |
| JamWorker.FRAMES_PER_SECOND = self._codec_fps | |
| self._ctx_frames = int(self.mrt.config.context_length_frames) | |
| self._ctx_seconds = self._ctx_frames / self._codec_fps | |
| self._model_sr = int(self.mrt.sample_rate) | |
| # style vector (already normalized upstream) | |
| self._style_vec = None if self.params.style_vec is None else np.array(self.params.style_vec, dtype=np.float32, copy=True) | |
| self._chunk_secs = ( | |
| self.mrt.config.chunk_length_frames * self.mrt.config.frame_length_samples | |
| ) / float(self._model_sr) | |
| # target-SR spool | |
| target_sr = int(self.params.target_sr) | |
| if target_sr != self._model_sr: | |
| self._rs = StreamingResampler(self._model_sr, target_sr, channels=2) | |
| else: | |
| self._rs = None | |
| self._spool = np.zeros((0, 2), dtype=np.float32) # target SR | |
| self._spool_written = 0 | |
| # bar clock (assumes the input loop is downbeat-aligned at t=0) | |
| # NOTE: MagentaRT's internal 40ms crossfades happen every 2 seconds in the | |
| # continuous stream, not at bar boundaries. The flam effect on downbeats | |
| # may need to be addressed on the Swift side by adjusting swap timing. | |
| self._bar_clock = BarClock(target_sr, float(self.params.bpm), int(self.params.beats_per_bar), base_offset_samples=500) | |
| # emission counters | |
| self.idx = 0 | |
| self._next_to_deliver = 0 | |
| self._last_consumed_index = -1 | |
| # outbox and synchronization | |
| self._outbox: Dict[int, JamChunk] = {} | |
| self._cv = threading.Condition() | |
| # control flags | |
| self._stop_event = threading.Event() | |
| self._max_buffer_ahead = 1 | |
| # reseed queues | |
| self._pending_reseed: Optional[dict] = None | |
| self._pending_token_splice: Optional[dict] = None | |
| # original context tokens snapshot (used by reseed_splice) | |
| self._original_context_tokens: Optional[np.ndarray] = None | |
| # Prepare initial context from combined loop | |
| if self.params.combined_loop is not None: | |
| self._install_context_from_loop(self.params.combined_loop) | |
| # ---------- lifecycle ---------- | |
| def set_buffer_seconds(self, seconds: float): | |
| chunk_secs = float(self.params.bars_per_chunk) * self._bar_clock.seconds_per_bar() | |
| max_chunks = max(0, int(round(seconds / max(chunk_secs, 1e-6)))) | |
| with self._cv: | |
| self._max_buffer_ahead = max_chunks | |
| def set_buffer_chunks(self, k: int): | |
| with self._cv: | |
| self._max_buffer_ahead = max(0, int(k)) | |
| def stop(self): | |
| self._stop_event.set() | |
| def get_next_chunk(self, timeout: float = 30.0) -> Optional[JamChunk]: | |
| deadline = time.time() + timeout | |
| with self._cv: | |
| while True: | |
| c = self._outbox.get(self._next_to_deliver) | |
| if c is not None: | |
| self._next_to_deliver += 1 | |
| return c | |
| remaining = deadline - time.time() | |
| if remaining <= 0: | |
| return None | |
| self._cv.wait(timeout=min(0.25, remaining)) | |
| def mark_chunk_consumed(self, chunk_index: int): | |
| with self._cv: | |
| self._last_consumed_index = max(self._last_consumed_index, int(chunk_index)) | |
| for k in list(self._outbox.keys()): | |
| if k < self._last_consumed_index - 1: | |
| self._outbox.pop(k, None) | |
| def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None): | |
| with self._lock: | |
| if guidance_weight is not None: | |
| self.params.guidance_weight = float(guidance_weight) | |
| if temperature is not None: | |
| self.params.temperature = float(temperature) | |
| if topk is not None: | |
| self.params.topk = int(topk) | |
| self.mrt.guidance_weight = float(self.params.guidance_weight) | |
| self.mrt.temperature = float(self.params.temperature) | |
| self.mrt.topk = int(self.params.topk) | |
| # ---------- context ---------- | |
| def _expected_token_shape(self) -> Tuple[int, int]: | |
| F = int(self._ctx_frames) | |
| D = int(self.mrt.config.decoder_codec_rvq_depth) | |
| return F, D | |
| def _coerce_tokens(self, toks: np.ndarray) -> np.ndarray: | |
| """ | |
| Coerce tokens to the expected shape (F, D) for the model. | |
| If tokens are too short, we TILE them to fill the context window, | |
| preserving the looping pattern rather than padding with repeated | |
| last tokens at the front. | |
| """ | |
| import math | |
| F, D = self._expected_token_shape() | |
| toks = np.asarray(toks) | |
| if toks.ndim != 2: | |
| toks = np.atleast_2d(toks) | |
| # depth | |
| if toks.shape[1] > D: | |
| toks = toks[:, :D] | |
| elif toks.shape[1] < D: | |
| pad_cols = np.tile(toks[:, -1:], (1, D - toks.shape[1])) | |
| toks = np.concatenate([toks, pad_cols], axis=1) | |
| # frames - TILE instead of padding with repeated last token | |
| if toks.shape[0] < F: | |
| if toks.shape[0] == 0: | |
| toks = np.zeros((1, D), dtype=np.int32) | |
| # Tile to fill the context, then take the tail so END aligns properly | |
| reps = int(math.ceil(F / toks.shape[0])) + 1 | |
| tiled = np.tile(toks, (reps, 1)) | |
| toks = tiled[-F:, :] # Take the last F frames | |
| elif toks.shape[0] > F: | |
| toks = toks[-F:, :] | |
| if toks.dtype != np.int32: | |
| toks = toks.astype(np.int32, copy=False) | |
| return toks | |
| def _encode_exact_context_tokens(self, loop: au.Waveform) -> np.ndarray: | |
| """ | |
| Build exactly context_length_frames tokens, ensuring the *end* of the audio | |
| lands on a bar boundary before encoding. | |
| For now we rely on take_bar_aligned_tail() which works in the sample domain, | |
| then coerce tokens to exact shape expected by MagentaRTState. | |
| """ | |
| wav = loop.as_stereo().resample(self._model_sr) | |
| tail = take_bar_aligned_tail( | |
| wav, | |
| bpm=float(self.params.bpm), | |
| beats_per_bar=int(self.params.beats_per_bar), | |
| ctx_seconds=float(self._ctx_seconds), | |
| ) | |
| # Store the context audio for debug purposes | |
| self._debug_context_audio = tail | |
| toks_full = self.mrt.codec.encode(tail).astype(np.int32) | |
| depth = int(self.mrt.config.decoder_codec_rvq_depth) | |
| ctx = toks_full[:, :depth] | |
| return self._coerce_tokens(ctx) | |
| def debug_export_context_audio(self, path: str = "/tmp/debug_context.wav") -> dict: | |
| """ | |
| Export the 10-second context audio that was fed to the model. | |
| Use this to verify the context doesn't have silence gaps at the beginning. | |
| Returns metadata about the context audio. | |
| """ | |
| import soundfile as sf | |
| if not hasattr(self, '_debug_context_audio') or self._debug_context_audio is None: | |
| return {"ok": False, "error": "No context audio available"} | |
| ctx = self._debug_context_audio | |
| samples = ctx.samples | |
| sr = ctx.sample_rate | |
| # Analyze the context for silence | |
| if samples.ndim == 1: | |
| samples = samples[:, None] | |
| # Calculate RMS in 100ms windows to detect silence gaps | |
| window_samples = int(0.1 * sr) # 100ms windows | |
| num_windows = samples.shape[0] // window_samples | |
| rms_values = [] | |
| for i in range(num_windows): | |
| window = samples[i * window_samples: (i + 1) * window_samples] | |
| rms = float(np.sqrt(np.mean(window ** 2))) | |
| rms_values.append(rms) | |
| # Find silence (RMS < threshold) | |
| silence_threshold = 0.001 | |
| silent_windows = [i for i, rms in enumerate(rms_values) if rms < silence_threshold] | |
| # Save to file | |
| sf.write(path, samples, sr, subtype="PCM_16") | |
| return { | |
| "ok": True, | |
| "path": path, | |
| "sample_rate": sr, | |
| "duration_seconds": float(samples.shape[0] / sr), | |
| "num_samples": int(samples.shape[0]), | |
| "channels": int(samples.shape[1]) if samples.ndim > 1 else 1, | |
| "analysis": { | |
| "num_windows": num_windows, | |
| "window_ms": 100, | |
| "silent_windows": silent_windows, | |
| "silence_at_start": any(i < 5 for i in silent_windows), # First 500ms | |
| "silence_at_end": any(i >= num_windows - 5 for i in silent_windows), # Last 500ms | |
| "rms_first_500ms": rms_values[:5] if len(rms_values) >= 5 else rms_values, | |
| "rms_last_500ms": rms_values[-5:] if len(rms_values) >= 5 else rms_values, | |
| } | |
| } | |
| def _install_context_from_loop(self, loop: au.Waveform): | |
| """Install context tokens and crossfade samples from the input loop.""" | |
| # Log input loop stats | |
| loop_rms = float(np.sqrt(np.mean(loop.samples ** 2))) if loop.samples.size > 0 else 0 | |
| self._log.info( | |
| "Installing context from loop: samples=%d, sr=%d, rms=%.6f, duration=%.3fs", | |
| loop.samples.shape[0], loop.sample_rate, loop_rms, | |
| loop.samples.shape[0] / loop.sample_rate | |
| ) | |
| # 1) Build context tokens (your existing exact/coerced logic) | |
| ctx = self._encode_exact_context_tokens(loop) | |
| # Log context token stats | |
| self._log.info( | |
| "Context tokens: shape=%s, min=%d, max=%d, unique=%d", | |
| ctx.shape, int(ctx.min()), int(ctx.max()), len(np.unique(ctx)) | |
| ) | |
| # Save as "original" context for future reseed_splice operations. | |
| # (This represents the bar-locked tail that actually fed the model.) | |
| self._original_context_tokens = np.copy(ctx) | |
| # 2) Build initial crossfade_samples from the same loop tail | |
| # so MagentaRT's first internal crossfade fades from REAL audio, not silence. | |
| n = int(self.mrt.config.crossfade_length_samples) # validated by MagentaRTState | |
| wav = loop.as_stereo().resample(self._model_sr) | |
| tail = wav.samples.astype(np.float32, copy=False) | |
| if tail.ndim == 1: | |
| tail = tail[:, None] | |
| if tail.shape[1] == 1: | |
| tail = np.repeat(tail, 2, axis=1) | |
| elif tail.shape[1] > 2: | |
| tail = tail[:, :2] | |
| # If loop too short for crossfade, TILE it instead of padding with zeros. | |
| original_tail_len = tail.shape[0] | |
| if tail.shape[0] < n and tail.shape[0] > 0: | |
| import math | |
| reps = int(math.ceil(n / tail.shape[0])) + 1 | |
| tail = np.tile(tail, (reps, 1)) | |
| self._log.info( | |
| "Tiled crossfade tail: %d -> %d samples (needed %d, reps=%d)", | |
| original_tail_len, tail.shape[0], n, reps | |
| ) | |
| xfade = tail[-n:, :] | |
| xfade_rms = float(np.sqrt(np.mean(xfade ** 2))) if xfade.size > 0 else 0 | |
| self._log.info( | |
| "Crossfade samples: shape=%s, rms=%.6f (silence=%.6f threshold)", | |
| xfade.shape, xfade_rms, 0.001 | |
| ) | |
| if xfade_rms < 0.001: | |
| self._log.warning("⚠️ Crossfade samples appear to be SILENCE! This may cause first chunk issues.") | |
| # 3) Install into a fresh state (session start only) | |
| s = self.mrt.init_state() | |
| s.context_tokens = ctx | |
| s.crossfade_samples = au.Waveform(xfade, self._model_sr) # must match codec_sample_rate + length | |
| self.state = s | |
| self._log.info("Context installation complete. State chunk_index=%d", s.chunk_index) | |
| # ---------- reseed / token splice ---------- | |
| def _extract_crossfade_from_loop(self, loop: au.Waveform) -> au.Waveform: | |
| """Extract a non-silent crossfade tail from `loop` (tiling if needed).""" | |
| n = int(self.mrt.config.crossfade_length_samples) | |
| wav = loop.as_stereo().resample(self._model_sr) | |
| tail = wav.samples.astype(np.float32, copy=False) | |
| if tail.ndim == 1: | |
| tail = tail[:, None] | |
| if tail.shape[1] == 1: | |
| tail = np.repeat(tail, 2, axis=1) | |
| elif tail.shape[1] > 2: | |
| tail = tail[:, :2] | |
| if tail.shape[0] < n and tail.shape[0] > 0: | |
| import math | |
| reps = int(math.ceil(n / tail.shape[0])) + 1 | |
| tail = np.tile(tail, (reps, 1)) | |
| xfade = tail[-n:, :] | |
| return au.Waveform(xfade, self._model_sr) | |
| def _apply_context_tokens( | |
| self, | |
| *, | |
| ctx_tokens: np.ndarray, | |
| crossfade: Optional[au.Waveform] = None, | |
| update_original: bool = False, | |
| preserve_chunk_index: bool = True, | |
| ) -> None: | |
| """Apply new context tokens to the live state (optionally updating crossfade/original).""" | |
| ctx_tokens = self._coerce_tokens(ctx_tokens) | |
| with self._lock: | |
| # Preserve chunk index so RNG progression stays consistent across reseeds. | |
| cur_idx = int(getattr(self.state, "chunk_index", 0)) | |
| self.state.context_tokens = ctx_tokens | |
| if crossfade is not None: | |
| self.state.crossfade_samples = crossfade | |
| if preserve_chunk_index: | |
| try: | |
| setattr(self.state, "_chunk_index", cur_idx) | |
| except Exception: | |
| pass | |
| if update_original: | |
| self._original_context_tokens = np.copy(ctx_tokens) | |
| def reseed_from_waveform(self, wav: au.Waveform) -> None: | |
| """ | |
| Reseed the model context from a waveform. | |
| New spool-worker behavior: | |
| - Do NOT reset the thread or spool. | |
| - Queue the reseed and apply it at the *next emitted chunk boundary* | |
| so the audible stream stays seamless. | |
| """ | |
| ctx = self._encode_exact_context_tokens(wav) | |
| xfade = self._extract_crossfade_from_loop(wav) | |
| with self._lock: | |
| self._pending_reseed = { | |
| "ctx": ctx, | |
| "xfade": xfade, | |
| "update_original": True, | |
| } | |
| def reseed_splice(self, recent_wav: Optional[au.Waveform], anchor_bars: float = 2.0) -> None: | |
| """ | |
| Seamless reseed by splicing tokens at the next chunk boundary. | |
| Intended behavior: | |
| - Keep the *first* `anchor_bars` worth of context from the original loop | |
| - Keep the *remainder* from the CURRENT state.context_tokens (no re-encoding!) | |
| - This ensures the tail of the spliced context is token-identical to what | |
| MagentaRT was already working with, avoiding codec round-trip artifacts. | |
| """ | |
| F, D = self._expected_token_shape() | |
| with self._lock: | |
| orig = self._original_context_tokens | |
| cur = self._coerce_tokens(self.state.context_tokens.copy()) # LIVE tokens | |
| if orig is None: | |
| # No original context saved, can't do anchor-based splice | |
| self._log.warning("reseed_splice called but no original context saved, ignoring") | |
| return | |
| spb = self._bar_clock.seconds_per_bar() | |
| frames_per_bar = max(1, int(round(self._codec_fps * spb))) | |
| anchor_bars = max(0.0, float(anchor_bars)) | |
| anchor_frames = int(round(anchor_bars * frames_per_bar)) | |
| anchor_frames = max(1, min(anchor_frames, F - 1)) | |
| orig = self._coerce_tokens(orig) | |
| # Anchor from ORIGINAL, continuation from CURRENT (no re-encoding!) | |
| left = orig[:anchor_frames, :] | |
| right = cur[anchor_frames:, :] # Keep current tokens for the rest | |
| spliced = np.concatenate([left, right], axis=0) | |
| spliced = self._coerce_tokens(spliced) | |
| self._log.info( | |
| "reseed_splice: anchor_frames=%d from orig, continuation_frames=%d from current", | |
| anchor_frames, F - anchor_frames | |
| ) | |
| with self._lock: | |
| self._pending_token_splice = { | |
| "tokens": spliced, | |
| "crossfade": None, # Keep existing crossfade_samples for continuity | |
| "debug": { | |
| "F": int(F), | |
| "D": int(D), | |
| "frames_per_bar": int(frames_per_bar), | |
| "anchor_bars": float(anchor_bars), | |
| "anchor_frames": int(anchor_frames), | |
| }, | |
| } | |
| def _apply_pending_at_boundary(self) -> None: | |
| """Apply any queued reseed/token-splice right after a chunk boundary.""" | |
| pending_reseed = None | |
| pending_splice = None | |
| with self._lock: | |
| if self._pending_reseed is not None: | |
| pending_reseed = self._pending_reseed | |
| self._pending_reseed = None | |
| if self._pending_token_splice is not None: | |
| pending_splice = self._pending_token_splice | |
| self._pending_token_splice = None | |
| if pending_reseed is not None: | |
| try: | |
| self._log.info( | |
| "🔄 RESEED applying: spool_written=%d, idx=%d, next_bounds=%s", | |
| self._spool_written, self.idx, | |
| self._bar_clock.bounds_for_chunk(self.idx, self.params.bars_per_chunk) | |
| ) | |
| self._apply_context_tokens( | |
| ctx_tokens=pending_reseed["ctx"], | |
| crossfade=pending_reseed.get("xfade"), | |
| update_original=bool(pending_reseed.get("update_original", False)), | |
| preserve_chunk_index=True, | |
| ) | |
| self._post_reseed_pending = True | |
| self._log.info("Applied reseed at boundary (chunk_index=%d)", int(getattr(self.state, "chunk_index", 0))) | |
| except Exception as e: | |
| self._log.exception("Failed applying reseed: %s", e) | |
| if pending_splice is not None: | |
| try: | |
| self._log.info( | |
| "✂️ SPLICE applying: spool_written=%d, idx=%d, next_bounds=%s", | |
| self._spool_written, self.idx, | |
| self._bar_clock.bounds_for_chunk(self.idx, self.params.bars_per_chunk) | |
| ) | |
| self._apply_context_tokens( | |
| ctx_tokens=pending_splice["tokens"], | |
| crossfade=pending_splice.get("crossfade"), # Phase-aligned crossfade | |
| update_original=False, | |
| preserve_chunk_index=True, | |
| ) | |
| self._post_splice_pending = True | |
| self._log.info( | |
| "Applied token splice at boundary: %s", | |
| pending_splice.get("debug", {}) | |
| ) | |
| except Exception as e: | |
| self._log.exception("Failed applying token splice: %s", e) | |
| # ---------- core streaming helpers ---------- | |
| def _append_model_chunk_and_spool(self, wav: au.Waveform) -> None: | |
| """ | |
| Append one MagentaRT chunk into the target-SR spool. | |
| NOTE: system.py already crossfades internally and returns exactly chunk_length_samples, | |
| so we do NOT do any overlap/correction here. We only resample (streaming) if needed. | |
| """ | |
| w = wav.as_stereo() | |
| x = w.samples.astype(np.float32, copy=False) | |
| if x.ndim == 1: | |
| x = x[:, None] | |
| if x.shape[1] == 1: | |
| x = np.repeat(x, 2, axis=1) | |
| elif x.shape[1] > 2: | |
| x = x[:, :2] | |
| # Log stats for each MagentaRT chunk BEFORE resampling | |
| chunk_rms = float(np.sqrt(np.mean(x ** 2))) if x.size > 0 else 0 | |
| chunk_peak = float(np.max(np.abs(x))) if x.size > 0 else 0 | |
| # Check for silence in 100ms windows | |
| window_samples = int(0.1 * self._model_sr) | |
| num_windows = x.shape[0] // window_samples | |
| silent_windows = 0 | |
| for i in range(num_windows): | |
| window = x[i * window_samples: (i + 1) * window_samples] | |
| if np.sqrt(np.mean(window ** 2)) < 0.001: | |
| silent_windows += 1 | |
| mrt_chunk_idx = self.state.chunk_index - 1 | |
| # NEW: Detect first chunk after reseed/splice | |
| post_reseed = getattr(self, '_post_reseed_pending', False) | |
| post_splice = getattr(self, '_post_splice_pending', False) | |
| if post_reseed or post_splice: | |
| event_type = "RESEED" if post_reseed else "SPLICE" | |
| self._log.info( | |
| "🎯 FIRST MRT CHUNK POST-%s: mrt_idx=%d, spool_written_before=%d, appending=%d samples", | |
| event_type, mrt_chunk_idx, self._spool_written, x.shape[0] | |
| ) | |
| self._post_reseed_pending = False | |
| self._post_splice_pending = False | |
| if silent_windows > num_windows // 2: | |
| self._log.warning( | |
| "⚠️ MRT chunk %d has >50%% silence! (%d/%d windows silent)", | |
| mrt_chunk_idx, silent_windows, num_windows | |
| ) | |
| if self._rs is not None: | |
| y = self._rs.process(x, final=False) | |
| else: | |
| y = x | |
| if y is None or y.size == 0: | |
| return | |
| if y.ndim == 1: | |
| y = y[:, None] | |
| if y.shape[1] == 1: | |
| y = np.repeat(y, 2, axis=1) | |
| elif y.shape[1] > 2: | |
| y = y[:, :2] | |
| self._spool = np.concatenate([self._spool, y.astype(np.float32, copy=False)], axis=0) | |
| self._spool_written += int(y.shape[0]) | |
| def _should_generate_next_chunk(self) -> bool: | |
| with self._cv: | |
| ahead = self.idx - self._last_consumed_index | |
| return ahead <= self._max_buffer_ahead | |
| def _emit_ready(self): | |
| while True: | |
| start, end = self._bar_clock.bounds_for_chunk(self.idx, self.params.bars_per_chunk) | |
| if end > self._spool_written: | |
| break | |
| loop = self._spool[start:end] | |
| # Log emitted chunk stats BEFORE loudness matching | |
| emit_rms_before = float(np.sqrt(np.mean(loop ** 2))) if loop.size > 0 else 0 | |
| emit_peak_before = float(np.max(np.abs(loop))) if loop.size > 0 else 0 | |
| # Check for silence in the emitted chunk | |
| sr = int(self.params.target_sr) | |
| window_samples = int(0.1 * sr) # 100ms windows | |
| num_windows = loop.shape[0] // window_samples | |
| silent_windows = 0 | |
| for i in range(num_windows): | |
| window = loop[i * window_samples: (i + 1) * window_samples] | |
| if np.sqrt(np.mean(window ** 2)) < 0.001: | |
| silent_windows += 1 | |
| self._log.info( | |
| "Emitting chunk %d: spool[%d:%d], samples=%d, rms=%.4f, peak=%.4f, silent=%d/%d windows", | |
| self.idx, start, end, loop.shape[0], emit_rms_before, emit_peak_before, silent_windows, num_windows | |
| ) | |
| if silent_windows > num_windows // 4: | |
| self._log.warning( | |
| "⚠️ Emitted chunk %d has >25%% silence! (%d/%d windows)", | |
| self.idx, silent_windows, num_windows | |
| ) | |
| # optional loudness match (per emitted chunk) | |
| if self.params.loudness_mode != "none" and self.params.combined_loop is not None: | |
| comb = self.params.combined_loop.as_stereo().resample(sr).samples.astype(np.float32, copy=False) | |
| if comb.ndim == 1: | |
| comb = comb[:, None] | |
| if comb.shape[1] == 1: | |
| comb = np.repeat(comb, 2, axis=1) | |
| elif comb.shape[1] > 2: | |
| comb = comb[:, :2] | |
| need = end - start | |
| if comb.shape[0] > 0 and need > 0: | |
| s = start % comb.shape[0] | |
| if s + need <= comb.shape[0]: | |
| ref_slice = comb[s:s + need] | |
| else: | |
| part1 = comb[s:] | |
| part2 = comb[:max(0, need - part1.shape[0])] | |
| ref_slice = np.vstack([part1, part2]) | |
| ref = au.Waveform(ref_slice, sr) | |
| tgt = au.Waveform(loop.copy(), sr) | |
| matched, _stats = match_loudness_to_reference( | |
| ref, tgt, | |
| method=self.params.loudness_mode, | |
| headroom_db=self.params.headroom_db, | |
| ) | |
| loop = matched.samples | |
| emit_rms_after = float(np.sqrt(np.mean(loop ** 2))) if loop.size > 0 else 0 | |
| self._log.debug( | |
| "Chunk %d loudness matched: rms %.4f -> %.4f, gain=%.2fdB", | |
| self.idx, emit_rms_before, emit_rms_after, _stats.get("applied_gain_db", 0) | |
| ) | |
| # Use raw bytes instead of base64 | |
| audio_bytes, total_samples, channels = wav_bytes_raw(loop, int(self.params.target_sr)) | |
| meta = { | |
| "bpm": float(self.params.bpm), | |
| "bars": int(self.params.bars_per_chunk), | |
| "beats_per_bar": int(self.params.beats_per_bar), | |
| "sample_rate": int(self.params.target_sr), | |
| "channels": int(channels), | |
| "total_samples": int(total_samples), | |
| "seconds_per_bar": self._bar_clock.seconds_per_bar(), | |
| "loop_duration_seconds": self.params.bars_per_chunk * self._bar_clock.seconds_per_bar(), | |
| "guidance_weight": float(self.params.guidance_weight), | |
| "temperature": float(self.params.temperature), | |
| "topk": int(self.params.topk), | |
| } | |
| chunk = JamChunk(index=self.idx, audio_bytes=audio_bytes, metadata=meta) | |
| with self._cv: | |
| self._outbox[self.idx] = chunk | |
| self._cv.notify_all() | |
| self.idx += 1 | |
| # Apply any queued reseed/splice exactly at this chunk boundary so the audible | |
| # stream remains seamless (changes affect *next* generated audio). | |
| self._apply_pending_at_boundary() | |
| # ---------- main loop ---------- | |
| def run(self): | |
| while not self._stop_event.is_set(): | |
| if not self._should_generate_next_chunk(): | |
| self._emit_ready() | |
| time.sleep(0.01) | |
| continue | |
| # snapshot style vector (with optional glide) | |
| with self._lock: | |
| target = self.params.style_vec | |
| if target is None: | |
| style_to_use = None | |
| else: | |
| if self._style_vec is None: | |
| self._style_vec = np.array(target, dtype=np.float32, copy=True) | |
| else: | |
| ramp = float(self.params.style_ramp_seconds or 0.0) | |
| step = 1.0 if ramp <= 0.0 else min(1.0, self._chunk_secs / ramp) | |
| self._style_vec += step * (target.astype(np.float32, copy=False) - self._style_vec) | |
| style_to_use = self._style_vec | |
| wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_to_use) | |
| self._append_model_chunk_and_spool(wav) | |
| self._emit_ready() | |
| # flush resampler if active | |
| if self._rs is not None: | |
| tail = self._rs.process(np.zeros((0, 2), np.float32), final=True) | |
| if tail is not None and tail.size: | |
| self._spool = np.concatenate([self._spool, tail], axis=0) | |
| self._spool_written += int(tail.shape[0]) | |
| self._emit_ready() | |