# magenta-retry / jam_worker.py
# (commit 1c7440e, thecollabagepatch): massive rewrite...pink noise warmup,
# no more crossfading needed, one_shot_generation fixes, jam_worker overhaul...
# still gotta fix html web tester now due to system.py changes in mrt
# jam_worker.py - Bar-locked spool worker (MagentaRT crossfade handled by system.py)
from __future__ import annotations
import os
import threading
import time
from dataclasses import dataclass
from fractions import Fraction
from typing import Optional, Dict, Tuple
import numpy as np
from magenta_rt import audio as au
from utils import (
StreamingResampler,
match_loudness_to_reference,
take_bar_aligned_tail,
wav_bytes_base64,
)
# -----------------------------
# Data classes
# -----------------------------
@dataclass
class JamParams:
    """Configuration for one jam session (knob fields are updated under JamWorker's lock)."""
    bpm: float                     # tempo in beats per minute
    beats_per_bar: int             # time-signature numerator (beats per bar)
    bars_per_chunk: int            # number of bars emitted per JamChunk
    target_sr: int                 # sample rate of the emitted chunks / spool
    loudness_mode: str = "auto"    # "none" disables per-chunk loudness matching in _emit_ready
    headroom_db: float = 1.0       # headroom passed to match_loudness_to_reference
    style_vec: Optional[np.ndarray] = None       # target style embedding; worker glides toward it
    ref_loop: Optional[au.Waveform] = None       # reference loop (not referenced in this file)
    combined_loop: Optional[au.Waveform] = None  # seed loop: model context + loudness reference
    guidance_weight: float = 1.1   # MagentaRT guidance knob
    temperature: float = 1.1       # MagentaRT sampling temperature
    topk: int = 40                 # MagentaRT top-k sampling
    # style glide
    style_ramp_seconds: float = 8.0  # glide time toward style_vec; 0 => instant
@dataclass
class JamChunk:
    """One emitted, bar-aligned chunk of generated audio."""
    index: int           # monotonically increasing chunk index (delivery order)
    audio_bytes: bytes   # RAW WAV bytes (not base64)
    metadata: dict       # bpm/bars/sample_rate/etc. — built in JamWorker._emit_ready
# -----------------------------
# Helpers
# -----------------------------
class BarClock:
    """
    Drift-free bar clock in the sample domain.

    Bar boundaries are carried as exact Fractions and only rounded to an
    integer sample index at the point of use, so no floating-point error
    can accumulate over the course of a long jam.
    """
    def __init__(self, target_sr: int, bpm: float, beats_per_bar: int, base_offset_samples: int = 0):
        self.sr = int(target_sr)
        # Parse bpm via str() so decimal tempos (e.g. 93.5) stay exact.
        self.bpm = Fraction(str(bpm))
        self.beats_per_bar = int(beats_per_bar)
        # samples per bar = sr * 60 * beats_per_bar / bpm, kept as an exact ratio
        self.bar_samps = Fraction(self.sr * 60 * self.beats_per_bar, 1) / self.bpm
        self.base = int(base_offset_samples)

    def bounds_for_chunk(self, chunk_index: int, bars_per_chunk: int) -> Tuple[int, int]:
        """Return (start, end) sample indices covering chunk `chunk_index`."""
        first_bar = chunk_index * bars_per_chunk
        lo = self.base + self.bar_samps * first_bar
        hi = self.base + self.bar_samps * (first_bar + bars_per_chunk)
        return int(round(lo)), int(round(hi))

    def seconds_per_bar(self) -> float:
        """Duration of one bar, in seconds, as a float."""
        return (60.0 / float(self.bpm)) * float(self.beats_per_bar)
def wav_bytes_raw(samples: np.ndarray, sr: int) -> tuple[bytes, int, int]:
    """
    Convert numpy samples to raw 16-bit PCM WAV bytes using only the
    standard library (drops the previous `soundfile` dependency).

    Args:
        samples: (N,) or (N, C) array. Float input is expected in [-1, 1]
            and is clipped then scaled to int16; integer input is written
            through as int16.
        sr: sample rate in Hz.

    Returns:
        (wav_bytes, total_samples, channels)
    """
    import io
    import wave
    if samples.ndim == 1:
        samples = samples[:, None]
    channels = int(samples.shape[1])
    total_samples = int(samples.shape[0])
    if np.issubdtype(samples.dtype, np.integer):
        pcm = samples.astype("<i2", copy=False)
    else:
        # Clip before scaling so out-of-range floats can't wrap around.
        # (libsndfile scales by 32768 and clips; x32767 differs by <0.004%.)
        pcm = (np.clip(samples.astype(np.float32, copy=False), -1.0, 1.0) * 32767.0).astype("<i2")
    buf = io.BytesIO()
    with wave.open(buf, "wb") as wf:
        wf.setnchannels(channels)
        wf.setsampwidth(2)  # 16-bit PCM
        wf.setframerate(int(sr))
        wf.writeframes(pcm.tobytes())
    return buf.getvalue(), total_samples, channels
# -----------------------------
# Worker
# -----------------------------
class JamWorker(threading.Thread):
    # Codec frame rate (Hz) of the most recently constructed worker; set in __init__.
    FRAMES_PER_SECOND: float | None = None
    """
    Generates continuous audio with MagentaRT, spools it at target SR,
    and emits bar-aligned chunks.
    IMPORTANT:
    MagentaRT's system.py already performs crossfade internally and returns a chunk
    of exactly config.chunk_length_samples. The caller should *not* do extra overlap/correction.
    """
    def __init__(self, mrt, params: JamParams):
        import logging
        super().__init__(daemon=True)
        self.mrt = mrt
        self.params = params
        # Setup logging
        self._log = logging.getLogger("JamWorker")
        self._log.setLevel(logging.DEBUG)
        # external callers (FastAPI endpoints) use this for atomic updates
        self._lock = threading.RLock()
        # generation state
        self.state = self.mrt.init_state()
        self.mrt.guidance_weight = float(self.params.guidance_weight)
        self.mrt.temperature = float(self.params.temperature)
        self.mrt.topk = int(self.params.topk)
        # codec/setup
        self._codec_fps = float(self.mrt.codec.frame_rate)
        JamWorker.FRAMES_PER_SECOND = self._codec_fps
        self._ctx_frames = int(self.mrt.config.context_length_frames)
        self._ctx_seconds = self._ctx_frames / self._codec_fps
        self._model_sr = int(self.mrt.sample_rate)
        # style vector (already normalized upstream)
        self._style_vec = None if self.params.style_vec is None else np.array(self.params.style_vec, dtype=np.float32, copy=True)
        # Seconds of audio per model chunk (at model sample rate).
        self._chunk_secs = (
            self.mrt.config.chunk_length_frames * self.mrt.config.frame_length_samples
        ) / float(self._model_sr)
        # target-SR spool
        target_sr = int(self.params.target_sr)
        if target_sr != self._model_sr:
            self._rs = StreamingResampler(self._model_sr, target_sr, channels=2)
        else:
            self._rs = None
        self._spool = np.zeros((0, 2), dtype=np.float32)  # target SR
        self._spool_written = 0
        # bar clock (assumes the input loop is downbeat-aligned at t=0)
        # NOTE: MagentaRT's internal 40ms crossfades happen every 2 seconds in the
        # continuous stream, not at bar boundaries. The flam effect on downbeats
        # may need to be addressed on the Swift side by adjusting swap timing.
        # base_offset_samples=500: small fixed offset — presumably compensates
        # for an onset/warmup transient; TODO confirm where 500 comes from.
        self._bar_clock = BarClock(target_sr, float(self.params.bpm), int(self.params.beats_per_bar), base_offset_samples=500)
        # emission counters
        self.idx = 0
        self._next_to_deliver = 0
        self._last_consumed_index = -1
        # outbox and synchronization
        self._outbox: Dict[int, JamChunk] = {}
        self._cv = threading.Condition()
        # control flags
        self._stop_event = threading.Event()
        self._max_buffer_ahead = 1
        # reseed queues
        self._pending_reseed: Optional[dict] = None
        self._pending_token_splice: Optional[dict] = None
        # original context tokens snapshot (used by reseed_splice)
        self._original_context_tokens: Optional[np.ndarray] = None
        # NOTE: _post_reseed_pending / _post_splice_pending are created lazily
        # (read via getattr with a False default in _append_model_chunk_and_spool).
        # Prepare initial context from combined loop
        if self.params.combined_loop is not None:
            self._install_context_from_loop(self.params.combined_loop)
    # ---------- lifecycle ----------
    def set_buffer_seconds(self, seconds: float):
        """Set the generate-ahead limit in seconds (rounded to whole chunks)."""
        chunk_secs = float(self.params.bars_per_chunk) * self._bar_clock.seconds_per_bar()
        max_chunks = max(0, int(round(seconds / max(chunk_secs, 1e-6))))
        with self._cv:
            self._max_buffer_ahead = max_chunks
    def set_buffer_chunks(self, k: int):
        """Set the generate-ahead limit directly, in chunks."""
        with self._cv:
            self._max_buffer_ahead = max(0, int(k))
    def stop(self):
        """Signal run() to exit after its current iteration."""
        self._stop_event.set()
    def get_next_chunk(self, timeout: float = 30.0) -> Optional[JamChunk]:
        """Block until the next sequential chunk is available; None on timeout."""
        deadline = time.time() + timeout
        with self._cv:
            while True:
                c = self._outbox.get(self._next_to_deliver)
                if c is not None:
                    self._next_to_deliver += 1
                    return c
                remaining = deadline - time.time()
                if remaining <= 0:
                    return None
                # Short waits so a stop/producer notify is picked up promptly.
                self._cv.wait(timeout=min(0.25, remaining))
    def mark_chunk_consumed(self, chunk_index: int):
        """Record consumer progress and prune chunks older than (last-1) from the outbox."""
        with self._cv:
            self._last_consumed_index = max(self._last_consumed_index, int(chunk_index))
            for k in list(self._outbox.keys()):
                if k < self._last_consumed_index - 1:
                    self._outbox.pop(k, None)
    def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None):
        """Atomically update sampling knobs on both params and the live MRT system."""
        with self._lock:
            if guidance_weight is not None:
                self.params.guidance_weight = float(guidance_weight)
            if temperature is not None:
                self.params.temperature = float(temperature)
            if topk is not None:
                self.params.topk = int(topk)
            self.mrt.guidance_weight = float(self.params.guidance_weight)
            self.mrt.temperature = float(self.params.temperature)
            self.mrt.topk = int(self.params.topk)
    # ---------- context ----------
    def _expected_token_shape(self) -> Tuple[int, int]:
        """Return (context_frames F, rvq_depth D) expected by the model state."""
        F = int(self._ctx_frames)
        D = int(self.mrt.config.decoder_codec_rvq_depth)
        return F, D
    def _coerce_tokens(self, toks: np.ndarray) -> np.ndarray:
        """
        Coerce tokens to the expected shape (F, D) for the model.
        If tokens are too short, we TILE them to fill the context window,
        preserving the looping pattern rather than padding with repeated
        last tokens at the front.
        """
        import math
        F, D = self._expected_token_shape()
        toks = np.asarray(toks)
        if toks.ndim != 2:
            toks = np.atleast_2d(toks)
        # depth: truncate extra columns, or pad by repeating the last column
        if toks.shape[1] > D:
            toks = toks[:, :D]
        elif toks.shape[1] < D:
            pad_cols = np.tile(toks[:, -1:], (1, D - toks.shape[1]))
            toks = np.concatenate([toks, pad_cols], axis=1)
        # frames - TILE instead of padding with repeated last token
        if toks.shape[0] < F:
            if toks.shape[0] == 0:
                toks = np.zeros((1, D), dtype=np.int32)
            # Tile to fill the context, then take the tail so END aligns properly
            reps = int(math.ceil(F / toks.shape[0])) + 1
            tiled = np.tile(toks, (reps, 1))
            toks = tiled[-F:, :]  # Take the last F frames
        elif toks.shape[0] > F:
            toks = toks[-F:, :]
        if toks.dtype != np.int32:
            toks = toks.astype(np.int32, copy=False)
        return toks
    def _encode_exact_context_tokens(self, loop: au.Waveform) -> np.ndarray:
        """
        Build exactly context_length_frames tokens, ensuring the *end* of the audio
        lands on a bar boundary before encoding.
        For now we rely on take_bar_aligned_tail() which works in the sample domain,
        then coerce tokens to exact shape expected by MagentaRTState.
        """
        wav = loop.as_stereo().resample(self._model_sr)
        tail = take_bar_aligned_tail(
            wav,
            bpm=float(self.params.bpm),
            beats_per_bar=int(self.params.beats_per_bar),
            ctx_seconds=float(self._ctx_seconds),
        )
        # Store the context audio for debug purposes
        self._debug_context_audio = tail
        toks_full = self.mrt.codec.encode(tail).astype(np.int32)
        depth = int(self.mrt.config.decoder_codec_rvq_depth)
        ctx = toks_full[:, :depth]
        return self._coerce_tokens(ctx)
    def debug_export_context_audio(self, path: str = "/tmp/debug_context.wav") -> dict:
        """
        Export the 10-second context audio that was fed to the model.
        Use this to verify the context doesn't have silence gaps at the beginning.
        Returns metadata about the context audio.
        """
        import soundfile as sf
        if not hasattr(self, '_debug_context_audio') or self._debug_context_audio is None:
            return {"ok": False, "error": "No context audio available"}
        ctx = self._debug_context_audio
        samples = ctx.samples
        sr = ctx.sample_rate
        # Analyze the context for silence
        if samples.ndim == 1:
            samples = samples[:, None]
        # Calculate RMS in 100ms windows to detect silence gaps
        window_samples = int(0.1 * sr)  # 100ms windows
        num_windows = samples.shape[0] // window_samples
        rms_values = []
        for i in range(num_windows):
            window = samples[i * window_samples: (i + 1) * window_samples]
            rms = float(np.sqrt(np.mean(window ** 2)))
            rms_values.append(rms)
        # Find silence (RMS < threshold)
        silence_threshold = 0.001
        silent_windows = [i for i, rms in enumerate(rms_values) if rms < silence_threshold]
        # Save to file
        sf.write(path, samples, sr, subtype="PCM_16")
        return {
            "ok": True,
            "path": path,
            "sample_rate": sr,
            "duration_seconds": float(samples.shape[0] / sr),
            "num_samples": int(samples.shape[0]),
            "channels": int(samples.shape[1]) if samples.ndim > 1 else 1,
            "analysis": {
                "num_windows": num_windows,
                "window_ms": 100,
                "silent_windows": silent_windows,
                "silence_at_start": any(i < 5 for i in silent_windows),  # First 500ms
                "silence_at_end": any(i >= num_windows - 5 for i in silent_windows),  # Last 500ms
                "rms_first_500ms": rms_values[:5] if len(rms_values) >= 5 else rms_values,
                "rms_last_500ms": rms_values[-5:] if len(rms_values) >= 5 else rms_values,
            }
        }
    def _install_context_from_loop(self, loop: au.Waveform):
        """Install context tokens and crossfade samples from the input loop."""
        # Log input loop stats
        loop_rms = float(np.sqrt(np.mean(loop.samples ** 2))) if loop.samples.size > 0 else 0
        self._log.info(
            "Installing context from loop: samples=%d, sr=%d, rms=%.6f, duration=%.3fs",
            loop.samples.shape[0], loop.sample_rate, loop_rms,
            loop.samples.shape[0] / loop.sample_rate
        )
        # 1) Build context tokens (your existing exact/coerced logic)
        ctx = self._encode_exact_context_tokens(loop)
        # Log context token stats
        self._log.info(
            "Context tokens: shape=%s, min=%d, max=%d, unique=%d",
            ctx.shape, int(ctx.min()), int(ctx.max()), len(np.unique(ctx))
        )
        # Save as "original" context for future reseed_splice operations.
        # (This represents the bar-locked tail that actually fed the model.)
        self._original_context_tokens = np.copy(ctx)
        # 2) Build initial crossfade_samples from the same loop tail
        # so MagentaRT's first internal crossfade fades from REAL audio, not silence.
        n = int(self.mrt.config.crossfade_length_samples)  # validated by MagentaRTState
        wav = loop.as_stereo().resample(self._model_sr)
        tail = wav.samples.astype(np.float32, copy=False)
        # Force a (N, 2) stereo layout.
        if tail.ndim == 1:
            tail = tail[:, None]
        if tail.shape[1] == 1:
            tail = np.repeat(tail, 2, axis=1)
        elif tail.shape[1] > 2:
            tail = tail[:, :2]
        # If loop too short for crossfade, TILE it instead of padding with zeros.
        original_tail_len = tail.shape[0]
        if tail.shape[0] < n and tail.shape[0] > 0:
            import math
            reps = int(math.ceil(n / tail.shape[0])) + 1
            tail = np.tile(tail, (reps, 1))
            self._log.info(
                "Tiled crossfade tail: %d -> %d samples (needed %d, reps=%d)",
                original_tail_len, tail.shape[0], n, reps
            )
        xfade = tail[-n:, :]
        xfade_rms = float(np.sqrt(np.mean(xfade ** 2))) if xfade.size > 0 else 0
        self._log.info(
            "Crossfade samples: shape=%s, rms=%.6f (silence=%.6f threshold)",
            xfade.shape, xfade_rms, 0.001
        )
        if xfade_rms < 0.001:
            self._log.warning("⚠️ Crossfade samples appear to be SILENCE! This may cause first chunk issues.")
        # 3) Install into a fresh state (session start only)
        s = self.mrt.init_state()
        s.context_tokens = ctx
        s.crossfade_samples = au.Waveform(xfade, self._model_sr)  # must match codec_sample_rate + length
        self.state = s
        self._log.info("Context installation complete. State chunk_index=%d", s.chunk_index)
    # ---------- reseed / token splice ----------
    def _extract_crossfade_from_loop(self, loop: au.Waveform) -> au.Waveform:
        """Extract a non-silent crossfade tail from `loop` (tiling if needed)."""
        n = int(self.mrt.config.crossfade_length_samples)
        wav = loop.as_stereo().resample(self._model_sr)
        tail = wav.samples.astype(np.float32, copy=False)
        # Force a (N, 2) stereo layout.
        if tail.ndim == 1:
            tail = tail[:, None]
        if tail.shape[1] == 1:
            tail = np.repeat(tail, 2, axis=1)
        elif tail.shape[1] > 2:
            tail = tail[:, :2]
        if tail.shape[0] < n and tail.shape[0] > 0:
            import math
            reps = int(math.ceil(n / tail.shape[0])) + 1
            tail = np.tile(tail, (reps, 1))
        xfade = tail[-n:, :]
        return au.Waveform(xfade, self._model_sr)
    def _apply_context_tokens(
        self,
        *,
        ctx_tokens: np.ndarray,
        crossfade: Optional[au.Waveform] = None,
        update_original: bool = False,
        preserve_chunk_index: bool = True,
    ) -> None:
        """Apply new context tokens to the live state (optionally updating crossfade/original)."""
        ctx_tokens = self._coerce_tokens(ctx_tokens)
        with self._lock:
            # Preserve chunk index so RNG progression stays consistent across reseeds.
            cur_idx = int(getattr(self.state, "chunk_index", 0))
            self.state.context_tokens = ctx_tokens
            if crossfade is not None:
                self.state.crossfade_samples = crossfade
            if preserve_chunk_index:
                # NOTE(review): writes the private backing attr of chunk_index;
                # best-effort — ignored if the state object doesn't allow it.
                try:
                    setattr(self.state, "_chunk_index", cur_idx)
                except Exception:
                    pass
            if update_original:
                self._original_context_tokens = np.copy(ctx_tokens)
    def reseed_from_waveform(self, wav: au.Waveform) -> None:
        """
        Reseed the model context from a waveform.
        New spool-worker behavior:
        - Do NOT reset the thread or spool.
        - Queue the reseed and apply it at the *next emitted chunk boundary*
          so the audible stream stays seamless.
        """
        ctx = self._encode_exact_context_tokens(wav)
        xfade = self._extract_crossfade_from_loop(wav)
        with self._lock:
            self._pending_reseed = {
                "ctx": ctx,
                "xfade": xfade,
                "update_original": True,
            }
    def reseed_splice(self, recent_wav: Optional[au.Waveform], anchor_bars: float = 2.0) -> None:
        """
        Seamless reseed by splicing tokens at the next chunk boundary.
        Intended behavior:
        - Keep the *first* `anchor_bars` worth of context from the original loop
        - Keep the *remainder* from the CURRENT state.context_tokens (no re-encoding!)
        - This ensures the tail of the spliced context is token-identical to what
          MagentaRT was already working with, avoiding codec round-trip artifacts.

        NOTE: `recent_wav` is currently unused by this implementation.
        """
        F, D = self._expected_token_shape()
        with self._lock:
            orig = self._original_context_tokens
            cur = self._coerce_tokens(self.state.context_tokens.copy())  # LIVE tokens
        if orig is None:
            # No original context saved, can't do anchor-based splice
            self._log.warning("reseed_splice called but no original context saved, ignoring")
            return
        spb = self._bar_clock.seconds_per_bar()
        frames_per_bar = max(1, int(round(self._codec_fps * spb)))
        anchor_bars = max(0.0, float(anchor_bars))
        anchor_frames = int(round(anchor_bars * frames_per_bar))
        anchor_frames = max(1, min(anchor_frames, F - 1))
        orig = self._coerce_tokens(orig)
        # Anchor from ORIGINAL, continuation from CURRENT (no re-encoding!)
        left = orig[:anchor_frames, :]
        right = cur[anchor_frames:, :]  # Keep current tokens for the rest
        spliced = np.concatenate([left, right], axis=0)
        spliced = self._coerce_tokens(spliced)
        self._log.info(
            "reseed_splice: anchor_frames=%d from orig, continuation_frames=%d from current",
            anchor_frames, F - anchor_frames
        )
        with self._lock:
            self._pending_token_splice = {
                "tokens": spliced,
                "crossfade": None,  # Keep existing crossfade_samples for continuity
                "debug": {
                    "F": int(F),
                    "D": int(D),
                    "frames_per_bar": int(frames_per_bar),
                    "anchor_bars": float(anchor_bars),
                    "anchor_frames": int(anchor_frames),
                },
            }
    def _apply_pending_at_boundary(self) -> None:
        """Apply any queued reseed/token-splice right after a chunk boundary."""
        pending_reseed = None
        pending_splice = None
        # Snapshot-and-clear under the lock; apply outside minimal critical section.
        with self._lock:
            if self._pending_reseed is not None:
                pending_reseed = self._pending_reseed
                self._pending_reseed = None
            if self._pending_token_splice is not None:
                pending_splice = self._pending_token_splice
                self._pending_token_splice = None
        if pending_reseed is not None:
            try:
                self._log.info(
                    "🔄 RESEED applying: spool_written=%d, idx=%d, next_bounds=%s",
                    self._spool_written, self.idx,
                    self._bar_clock.bounds_for_chunk(self.idx, self.params.bars_per_chunk)
                )
                self._apply_context_tokens(
                    ctx_tokens=pending_reseed["ctx"],
                    crossfade=pending_reseed.get("xfade"),
                    update_original=bool(pending_reseed.get("update_original", False)),
                    preserve_chunk_index=True,
                )
                self._post_reseed_pending = True
                self._log.info("Applied reseed at boundary (chunk_index=%d)", int(getattr(self.state, "chunk_index", 0)))
            except Exception as e:
                self._log.exception("Failed applying reseed: %s", e)
        if pending_splice is not None:
            try:
                self._log.info(
                    "✂️ SPLICE applying: spool_written=%d, idx=%d, next_bounds=%s",
                    self._spool_written, self.idx,
                    self._bar_clock.bounds_for_chunk(self.idx, self.params.bars_per_chunk)
                )
                self._apply_context_tokens(
                    ctx_tokens=pending_splice["tokens"],
                    crossfade=pending_splice.get("crossfade"),  # Phase-aligned crossfade
                    update_original=False,
                    preserve_chunk_index=True,
                )
                self._post_splice_pending = True
                self._log.info(
                    "Applied token splice at boundary: %s",
                    pending_splice.get("debug", {})
                )
            except Exception as e:
                self._log.exception("Failed applying token splice: %s", e)
    # ---------- core streaming helpers ----------
    def _append_model_chunk_and_spool(self, wav: au.Waveform) -> None:
        """
        Append one MagentaRT chunk into the target-SR spool.
        NOTE: system.py already crossfades internally and returns exactly chunk_length_samples,
        so we do NOT do any overlap/correction here. We only resample (streaming) if needed.
        """
        w = wav.as_stereo()
        x = w.samples.astype(np.float32, copy=False)
        # Force a (N, 2) stereo layout.
        if x.ndim == 1:
            x = x[:, None]
        if x.shape[1] == 1:
            x = np.repeat(x, 2, axis=1)
        elif x.shape[1] > 2:
            x = x[:, :2]
        # Log stats for each MagentaRT chunk BEFORE resampling
        chunk_rms = float(np.sqrt(np.mean(x ** 2))) if x.size > 0 else 0
        chunk_peak = float(np.max(np.abs(x))) if x.size > 0 else 0
        # Check for silence in 100ms windows
        window_samples = int(0.1 * self._model_sr)
        num_windows = x.shape[0] // window_samples
        silent_windows = 0
        for i in range(num_windows):
            window = x[i * window_samples: (i + 1) * window_samples]
            if np.sqrt(np.mean(window ** 2)) < 0.001:
                silent_windows += 1
        # state.chunk_index was already advanced by generate_chunk, hence -1.
        mrt_chunk_idx = self.state.chunk_index - 1
        # NEW: Detect first chunk after reseed/splice
        post_reseed = getattr(self, '_post_reseed_pending', False)
        post_splice = getattr(self, '_post_splice_pending', False)
        if post_reseed or post_splice:
            event_type = "RESEED" if post_reseed else "SPLICE"
            self._log.info(
                "🎯 FIRST MRT CHUNK POST-%s: mrt_idx=%d, spool_written_before=%d, appending=%d samples",
                event_type, mrt_chunk_idx, self._spool_written, x.shape[0]
            )
            self._post_reseed_pending = False
            self._post_splice_pending = False
        if silent_windows > num_windows // 2:
            self._log.warning(
                "⚠️ MRT chunk %d has >50%% silence! (%d/%d windows silent)",
                mrt_chunk_idx, silent_windows, num_windows
            )
        if self._rs is not None:
            y = self._rs.process(x, final=False)
        else:
            y = x
        if y is None or y.size == 0:
            return
        # Re-normalize channel layout after resampling.
        if y.ndim == 1:
            y = y[:, None]
        if y.shape[1] == 1:
            y = np.repeat(y, 2, axis=1)
        elif y.shape[1] > 2:
            y = y[:, :2]
        # NOTE(review): the spool grows without bound for the session's lifetime.
        self._spool = np.concatenate([self._spool, y.astype(np.float32, copy=False)], axis=0)
        self._spool_written += int(y.shape[0])
    def _should_generate_next_chunk(self) -> bool:
        """True while we are within _max_buffer_ahead chunks of the consumer."""
        with self._cv:
            ahead = self.idx - self._last_consumed_index
            return ahead <= self._max_buffer_ahead
    def _emit_ready(self):
        """Slice every fully-generated, bar-aligned chunk out of the spool and publish it."""
        while True:
            start, end = self._bar_clock.bounds_for_chunk(self.idx, self.params.bars_per_chunk)
            if end > self._spool_written:
                break
            loop = self._spool[start:end]
            # Log emitted chunk stats BEFORE loudness matching
            emit_rms_before = float(np.sqrt(np.mean(loop ** 2))) if loop.size > 0 else 0
            emit_peak_before = float(np.max(np.abs(loop))) if loop.size > 0 else 0
            # Check for silence in the emitted chunk
            sr = int(self.params.target_sr)
            window_samples = int(0.1 * sr)  # 100ms windows
            num_windows = loop.shape[0] // window_samples
            silent_windows = 0
            for i in range(num_windows):
                window = loop[i * window_samples: (i + 1) * window_samples]
                if np.sqrt(np.mean(window ** 2)) < 0.001:
                    silent_windows += 1
            self._log.info(
                "Emitting chunk %d: spool[%d:%d], samples=%d, rms=%.4f, peak=%.4f, silent=%d/%d windows",
                self.idx, start, end, loop.shape[0], emit_rms_before, emit_peak_before, silent_windows, num_windows
            )
            if silent_windows > num_windows // 4:
                self._log.warning(
                    "⚠️ Emitted chunk %d has >25%% silence! (%d/%d windows)",
                    self.idx, silent_windows, num_windows
                )
            # optional loudness match (per emitted chunk)
            if self.params.loudness_mode != "none" and self.params.combined_loop is not None:
                comb = self.params.combined_loop.as_stereo().resample(sr).samples.astype(np.float32, copy=False)
                if comb.ndim == 1:
                    comb = comb[:, None]
                if comb.shape[1] == 1:
                    comb = np.repeat(comb, 2, axis=1)
                elif comb.shape[1] > 2:
                    comb = comb[:, :2]
                need = end - start
                if comb.shape[0] > 0 and need > 0:
                    # Wrap around the reference loop so its phase tracks the spool.
                    s = start % comb.shape[0]
                    if s + need <= comb.shape[0]:
                        ref_slice = comb[s:s + need]
                    else:
                        part1 = comb[s:]
                        part2 = comb[:max(0, need - part1.shape[0])]
                        ref_slice = np.vstack([part1, part2])
                    ref = au.Waveform(ref_slice, sr)
                    tgt = au.Waveform(loop.copy(), sr)
                    matched, _stats = match_loudness_to_reference(
                        ref, tgt,
                        method=self.params.loudness_mode,
                        headroom_db=self.params.headroom_db,
                    )
                    loop = matched.samples
                    emit_rms_after = float(np.sqrt(np.mean(loop ** 2))) if loop.size > 0 else 0
                    self._log.debug(
                        "Chunk %d loudness matched: rms %.4f -> %.4f, gain=%.2fdB",
                        self.idx, emit_rms_before, emit_rms_after, _stats.get("applied_gain_db", 0)
                    )
            # Use raw bytes instead of base64
            audio_bytes, total_samples, channels = wav_bytes_raw(loop, int(self.params.target_sr))
            meta = {
                "bpm": float(self.params.bpm),
                "bars": int(self.params.bars_per_chunk),
                "beats_per_bar": int(self.params.beats_per_bar),
                "sample_rate": int(self.params.target_sr),
                "channels": int(channels),
                "total_samples": int(total_samples),
                "seconds_per_bar": self._bar_clock.seconds_per_bar(),
                "loop_duration_seconds": self.params.bars_per_chunk * self._bar_clock.seconds_per_bar(),
                "guidance_weight": float(self.params.guidance_weight),
                "temperature": float(self.params.temperature),
                "topk": int(self.params.topk),
            }
            chunk = JamChunk(index=self.idx, audio_bytes=audio_bytes, metadata=meta)
            with self._cv:
                self._outbox[self.idx] = chunk
                self._cv.notify_all()
            self.idx += 1
            # Apply any queued reseed/splice exactly at this chunk boundary so the audible
            # stream remains seamless (changes affect *next* generated audio).
            self._apply_pending_at_boundary()
    # ---------- main loop ----------
    def run(self):
        """Main generation loop: throttle, glide style, generate, spool, emit — until stopped."""
        while not self._stop_event.is_set():
            if not self._should_generate_next_chunk():
                self._emit_ready()
                time.sleep(0.01)
                continue
            # snapshot style vector (with optional glide)
            with self._lock:
                target = self.params.style_vec
                if target is None:
                    style_to_use = None
                else:
                    if self._style_vec is None:
                        self._style_vec = np.array(target, dtype=np.float32, copy=True)
                    else:
                        # Exponential-style glide: one step per generated chunk.
                        ramp = float(self.params.style_ramp_seconds or 0.0)
                        step = 1.0 if ramp <= 0.0 else min(1.0, self._chunk_secs / ramp)
                        self._style_vec += step * (target.astype(np.float32, copy=False) - self._style_vec)
                    style_to_use = self._style_vec
            wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_to_use)
            self._append_model_chunk_and_spool(wav)
            self._emit_ready()
        # flush resampler if active
        if self._rs is not None:
            tail = self._rs.process(np.zeros((0, 2), np.float32), final=True)
            if tail is not None and tail.size:
                self._spool = np.concatenate([self._spool, tail], axis=0)
                self._spool_written += int(tail.shape[0])
                self._emit_ready()