import json
import logging
import os
import time
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch

from backends_base import ChatBackend, ImagesBackend
from config import settings

logger = logging.getLogger(__name__)


# ---------- helpers ----------

def _parse_series(series: Any) -> np.ndarray:
    """
    Accepts: list[float|int], list[dict{'y'|'value'}], or dict with 'values'/'y'.
    Returns: 1D float32 numpy array.
    """
    if series is None:
        raise ValueError("series is required")
    if isinstance(series, dict):
        series = series.get("values") or series.get("y")
    vals: List[float] = []
    if isinstance(series, (list, tuple)):
        if series and isinstance(series[0], dict):
            for item in series:
                if "y" in item:
                    vals.append(float(item["y"]))
                elif "value" in item:
                    vals.append(float(item["value"]))
        else:
            vals = [float(x) for x in series]
    else:
        raise ValueError("series must be a list/tuple or dict with 'values'/'y'")
    if not vals:
        raise ValueError("series is empty")
    return np.asarray(vals, dtype=np.float32)


def _extract_json_from_text(s: str) -> Optional[Dict[str, Any]]:
    """Best-effort extraction of a JSON object from raw text or a ```json fence."""
    s = s.strip()
    if (s.startswith("{") and s.endswith("}")) or (s.startswith("[") and s.endswith("]")):
        try:
            obj = json.loads(s)
            return obj if isinstance(obj, dict) else None
        except Exception:
            pass
    if "```" in s:
        parts = s.split("```")
        for i in range(1, len(parts), 2):  # odd indices are the fenced blocks
            block = parts[i]
            if block.lstrip().lower().startswith("json"):
                block = block.split("\n", 1)[-1]  # drop the language tag line
            try:
                obj = json.loads(block.strip())
                return obj if isinstance(obj, dict) else None
            except Exception:
                continue
    return None


def _merge_openai_message_json(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Merge a JSON object embedded in the last user message into the payload."""
    msgs = payload.get("messages")
    if not isinstance(msgs, list):
        return payload
    for m in reversed(msgs):
        if not isinstance(m, dict) or m.get("role") != "user":
            continue
        content = m.get("content")
        texts: List[str] = []
        if isinstance(content, list):
            texts = [
                p.get("text")
                for p in content
                if isinstance(p, dict) and p.get("type") == "text" and isinstance(p.get("text"), str)
            ]
        elif isinstance(content, str):
            texts = [content]
        for t in reversed(texts):
            obj = _extract_json_from_text(t)
            if isinstance(obj, dict):
                return {**payload, **obj}
        break  # only inspect the most recent user message
    return payload
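
# Quick reference for the helper contracts above (illustrative comments only,
# not executed at import time):
#
#   _parse_series([1, 2, 3.5])                   -> float32 array [1.0, 2.0, 3.5]
#   _parse_series([{"y": 1.0}, {"value": 2.0}])  -> float32 array [1.0, 2.0]
#   _parse_series({"values": [1, 2, 3]})         -> float32 array [1.0, 2.0, 3.0]
#
#   _extract_json_from_text('{"horizon": 8}')                -> {"horizon": 8}
#   _extract_json_from_text('```json\n{"horizon": 8}\n```')  -> {"horizon": 8}
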
# ---------- backend ----------

class TimesFMBackend(ChatBackend):
    """
    TimesFM 2.5 backend.

    Input JSON can be in top-level keys, in CloudEvents .data, or embedded
    in the last user message. Keys:
        series:  list[float|int|{y|value}] OR list of such lists for batch
        horizon: int (>0)
    Optional:
        quantiles: bool (default True) -> include quantile forecasts
        max_context, max_horizon: ints to override defaults
    """

    def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
        # HF id for bookkeeping only
        self.model_id = model_id or "google/timesfm-2.5-200m-pytorch"
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self._model = None  # lazy

    def _ensure_model(self) -> None:
        if self._model is not None:
            return
        try:
            import timesfm  # 2.5 API

            hf_token = getattr(settings, "HF_TOKEN", None) or os.environ.get("HF_TOKEN")
            cache_dir = getattr(settings, "TIMESFM_CACHE_DIR", None)
            model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(
                self.model_id,
                token=hf_token,
                cache_dir=cache_dir,
                local_files_only=False,
            )
            try:
                # .model holds the underlying nn.Module; fall back to the
                # instance itself if the attribute is absent.
                target = getattr(model, "model", model)
                target.to(self.device)  # type: ignore[arg-type]
            except Exception:
                pass
            cfg = timesfm.ForecastConfig(
                max_context=1024,
                max_horizon=256,
                normalize_inputs=True,
                use_continuous_quantile_head=True,
                force_flip_invariance=True,
                infer_is_positive=True,
                fix_quantile_crossing=True,
            )
            model.compile(cfg)
            self._model = model
            logger.info("TimesFM 2.5 model loaded on %s", self.device)
        except Exception as e:
            logger.exception("TimesFM 2.5 init failed")
            raise RuntimeError(f"timesfm 2.5 init failed: {e}") from e

    def _prepare_inputs(
        self, payload: Dict[str, Any]
    ) -> Tuple[List[np.ndarray], int, bool, Dict[str, int]]:
        # unwrap CloudEvents and nested keys
        if isinstance(payload.get("data"), dict):
            payload = {**payload, **payload["data"]}
        if isinstance(payload.get("timeseries"), dict):
            payload = {**payload, **payload["timeseries"]}
        # merge JSON embedded in the last user message
        payload = _merge_openai_message_json(payload)

        horizon = int(payload.get("horizon", 0))
        if horizon <= 0:
            raise ValueError("horizon must be a positive integer")
        quantiles = bool(payload.get("quantiles", True))
        mc = int(payload.get("max_context", 1024))
        mh = int(payload.get("max_horizon", 256))

        series = payload.get("series")
        inputs: List[np.ndarray]
        is_batch = False
        if isinstance(series, list) and series:
            head = series[0]
            if isinstance(head, (list, tuple)):
                is_batch = True
            elif isinstance(head, dict):
                # {'y': 1.0} / {'value': 2.0} are points of a single series;
                # only a dict wrapping a list ({'values': [...]} or
                # {'y': [...]}) marks a batch element.
                inner = head.get("values", head.get("y"))
                is_batch = isinstance(inner, (list, tuple))
        if is_batch:
            inputs = [_parse_series(s) for s in series]  # batch input
        else:
            inputs = [_parse_series(series)]  # single series -> batch of 1
        return inputs, horizon, quantiles, {"max_context": mc, "max_horizon": mh}

    async def forecast(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        inputs, horizon, want_quantiles, cfg_overrides = self._prepare_inputs(payload)
        self._ensure_model()

        # if the caller asks for larger limits, recompile with the overrides
        try:
            import timesfm

            if cfg_overrides["max_context"] != 1024 or cfg_overrides["max_horizon"] != 256:
                cfg = timesfm.ForecastConfig(
                    max_context=cfg_overrides["max_context"],
                    max_horizon=cfg_overrides["max_horizon"],
                    normalize_inputs=True,
                    use_continuous_quantile_head=want_quantiles,
                    force_flip_invariance=True,
                    infer_is_positive=True,
                    fix_quantile_crossing=True,
                )
                self._model.compile(cfg)
        except Exception:
            pass

        try:
            point, quant = self._model.forecast(horizon=horizon, inputs=inputs)
            point_list = [row.astype(float).tolist() for row in point]  # shape (B, H)
            quant_list = None
            if want_quantiles and quant is not None:
                # shape (B, H, 10): mean, q10..q90
                quant_list = [row.astype(float).tolist() for row in quant]
        except Exception as e:
            logger.exception("TimesFM 2.5 forecast failed")
            raise RuntimeError(f"forecast failed: {e}") from e

        # If the input was a single series, unwrap the batch dim for convenience
        single = len(inputs) == 1
        return {
            "model": self.model_id,
            "horizon": horizon,
            "forecast": point_list[0] if single else point_list,
            "quantiles": (quant_list[0] if single else quant_list) if want_quantiles else None,
            "backend": "timesfm-2.5",
        }

    async def stream(self, request: Dict[str, Any]):
        rid = f"chatcmpl-timesfm-{int(time.time())}"
        now = int(time.time())
        try:
            result = await self.forecast(dict(request) if isinstance(request, dict) else {})
            content = json.dumps(result, separators=(",", ":"), ensure_ascii=False)
        except Exception as e:
            content = json.dumps({"error": str(e)}, separators=(",", ":"), ensure_ascii=False)
        yield {
            "id": rid,
            "object": "chat.completion.chunk",
            "created": now,
            "model": self.model_id,
            "choices": [
                {"index": 0, "delta": {"role": "assistant", "content": content}, "finish_reason": "stop"}
            ],
        }
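
# What a consumer sees from TimesFMBackend.stream (a sketch in comments): a
# single OpenAI-style chat.completion.chunk whose delta content carries the
# forecast JSON, or {"error": ...} on failure.
#
#   async for chunk in backend.stream({"series": [1, 2, 3], "horizon": 2}):
#       body = json.loads(chunk["choices"][0]["delta"]["content"])
#       body["forecast"]  # list of 2 floats for this request
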
class StubImagesBackend(ImagesBackend):
    async def generate_b64(self, request: Dict[str, Any]) -> str:
        logger.warning("Image generation not supported in TimesFM backend.")
        # tiny transparent placeholder PNG
        return "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="
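

if __name__ == "__main__":
    # Minimal smoke test: a sketch, not part of the serving path. It assumes
    # the `timesfm` package is installed and will download the checkpoint
    # from Hugging Face on first run.
    import asyncio

    async def _demo() -> None:
        backend = TimesFMBackend()
        out = await backend.forecast({"series": [1.0, 2.0, 3.0, 4.0, 5.0], "horizon": 3})
        print(json.dumps(out, indent=2))

    asyncio.run(_demo())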