File size: 22,439 Bytes
48abd32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
"""
TimesFM Forecasting Module

This module provides a simplified and robust interface for TimesFM forecasting,
handling both basic and covariates-enhanced forecasting with consistent quantile output.

Key Features:
- Single forecast method with optional covariates
- Always returns quantiles (never "maybe")
- Simplified logic: IF covariates -> use covariates, ELSE -> use basic
- Consistent return format: (point_forecast, quantile_forecast)
"""

import numpy as np
import pandas as pd
import logging
from typing import List, Dict, Optional, Tuple, Any, Union
import timesfm

logger = logging.getLogger(__name__)


class Forecaster:
    """
    Simplified TimesFM Forecaster with consistent quantile output.

    This class provides a single forecast method that handles both basic and
    covariates-enhanced forecasting, always returning quantiles.

    Example:
        >>> forecaster = Forecaster(model)
        >>> point_forecast, quantile_forecast = forecaster.forecast(
        ...     inputs=[1,2,3,4,5],
        ...     use_covariates=True,
        ...     dynamic_numerical_covariates={'feature1': [[1,2,3,4,5]]}
        ... )
    """

    def __init__(self, model: "timesfm.TimesFm"):
        """
        Initialize the Forecaster with a loaded TimesFM model.

        Args:
            model: Initialized TimesFM model instance. It must provide ``forecast``;
                ``forecast_with_covariates`` is only needed when covariates are used.
        """
        self.model = model
        self.capabilities = self._detect_capabilities()
        logger.info(f"Forecaster initialized with capabilities: {list(self.capabilities.keys())}")

    def _detect_capabilities(self) -> Dict[str, bool]:
        """Report which forecasting entry points the wrapped model exposes (via ``hasattr``)."""
        return {
            'basic_forecasting': True,
            'quantile_forecasting': hasattr(self.model, 'experimental_quantile_forecast'),
            'covariates_support': hasattr(self.model, 'forecast_with_covariates')
        }

    def forecast(
        self,
        inputs: Union[List[float], List[List[float]]],
        freq: Union[int, List[int]] = 0,
        dynamic_numerical_covariates: Optional[Dict[str, List[List[float]]]] = None,
        dynamic_categorical_covariates: Optional[Dict[str, List[List[str]]]] = None,
        static_numerical_covariates: Optional[Dict[str, List[float]]] = None,
        static_categorical_covariates: Optional[Dict[str, List[str]]] = None,
        use_covariates: bool = False,
        xreg_mode: str = "xreg + timesfm",
        ridge: float = 0.0,
        normalize_xreg_target_per_input: bool = True
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Perform TimesFM forecasting with optional covariates support.

        This is the main forecasting method that handles both basic and covariates-enhanced
        forecasting. Quantiles are always returned regardless of covariates usage.

        Args:
            inputs: Either one series (a flat sequence of numbers, numpy scalars included)
                or a sequence of series. A single series is wrapped into a one-element batch.
            freq: Frequency indicator shared by all series (int) or one per series (list).
            dynamic_numerical_covariates: Dynamic numerical covariates (if use_covariates=True)
            dynamic_categorical_covariates: Dynamic categorical covariates (if use_covariates=True)
            static_numerical_covariates: Static numerical covariates (if use_covariates=True)
            static_categorical_covariates: Static categorical covariates (if use_covariates=True)
            use_covariates: Whether to use covariates-enhanced forecasting. Ignored (with a
                warning) when no covariate dict is non-empty.
            xreg_mode: Covariate integration mode ("xreg + timesfm" or "timesfm + xreg")
            ridge: Ridge regression parameter for covariates
            normalize_xreg_target_per_input: Whether to normalize covariates

        Returns:
            Tuple of (point_forecast, quantile_forecast) as numpy arrays - both always returned.

        Raises:
            ValueError: If ``inputs`` is empty, if covariates are requested but the model has
                no ``forecast_with_covariates``, or if a covariate does not provide one entry
                per input series.
            Exception: Any error raised by the model is logged and re-raised.
        """
        logger.info(f"Performing TimesFM forecasting (covariates={use_covariates})...")

        if len(inputs) == 0:
            raise ValueError("inputs must contain at least one value or series")

        # A flat sequence of scalars is a single series. np.number is included because numpy
        # scalars such as np.float32 / np.int64 are not subclasses of Python int/float.
        if isinstance(inputs[0], (int, float, np.number)):
            inputs_norm = [inputs]
        else:
            inputs_norm = inputs

        if isinstance(freq, (int, np.integer)):
            freq_norm = [freq] * len(inputs_norm)
        else:
            freq_norm = freq

        has_covariates = any([
            dynamic_numerical_covariates, dynamic_categorical_covariates,
            static_numerical_covariates, static_categorical_covariates
        ])
        if use_covariates and not has_covariates:
            logger.warning("use_covariates=True but no covariates were provided; using basic forecasting")

        try:
            if use_covariates and has_covariates:
                if not self.capabilities['covariates_support']:
                    raise ValueError("Model does not support covariates forecasting")

                self._validate_covariates(
                    inputs_norm, dynamic_numerical_covariates, dynamic_categorical_covariates,
                    static_numerical_covariates, static_categorical_covariates
                )

                logger.info(f"Using covariates-enhanced forecasting (mode: {xreg_mode})...")
                logger.info(f"Inputs shape: {[len(x) for x in inputs_norm]}")
                logger.info(f"Inputs type: {type(inputs)}")

                covariates_result = self.model.forecast_with_covariates(
                    inputs=inputs_norm,
                    dynamic_numerical_covariates=dynamic_numerical_covariates or {},
                    dynamic_categorical_covariates=dynamic_categorical_covariates or {},
                    static_numerical_covariates=static_numerical_covariates or {},
                    static_categorical_covariates=static_categorical_covariates or {},
                    freq=freq_norm,
                    xreg_mode=xreg_mode,
                    ridge=ridge,
                    normalize_xreg_target_per_input=normalize_xreg_target_per_input
                )

                if isinstance(covariates_result, tuple) and len(covariates_result) == 2:
                    point_forecast = np.array(covariates_result[0])
                    quantile_forecast = np.array(covariates_result[1])

                    logger.info("✅ Covariates forecasting completed.")
                    logger.info(f"  Point forecast shape: {point_forecast.shape}")
                    logger.info(f"  Quantile forecast shape: {quantile_forecast.shape}")

                    # Outputs are stacked per input series along axis 0, so a real quantile
                    # result must be (series, horizon, num_quantiles) with >1 level - the same
                    # layout as the basic forecast. A 2-D (series, horizon) array has only one
                    # value per step. NOTE(review): TimesFM documents the second element of
                    # forecast_with_covariates as the xreg outputs - confirm for your version.
                    has_quantile_axis = quantile_forecast.ndim == 3 and quantile_forecast.shape[-1] > 1
                    if has_quantile_axis:
                        logger.info("✅ Using quantiles from covariates forecasting")
                    else:
                        logger.warning("⚠️ Covariates forecasting returned insufficient quantiles, falling back to basic forecast for quantiles")
                        quantile_forecast = self._basic_quantiles(inputs_norm, freq_norm)
                        logger.info(f"✅ Basic forecast quantiles obtained. Shape: {quantile_forecast.shape}")
                else:
                    # forecast_with_covariates returned only point forecasts: fetch quantiles separately.
                    logger.warning("⚠️ Covariates forecasting didn't return quantiles, getting them separately")
                    point_forecast = np.array(covariates_result)
                    quantile_forecast = self._basic_quantiles(inputs_norm, freq_norm)

            else:
                logger.info("Using basic forecasting...")

                # The basic forecast returns (point, quantiles).
                point_forecast, quantile_forecast = self.model.forecast(inputs=inputs_norm, freq=freq_norm)
                point_forecast = np.array(point_forecast)
                quantile_forecast = np.array(quantile_forecast)

                logger.info("✅ Basic forecasting completed.")

            return point_forecast, quantile_forecast

        except Exception as e:
            logger.error(f"❌ Forecasting failed: {str(e)}")
            raise

    def _basic_quantiles(self, inputs_norm: List, freq_norm: List) -> np.ndarray:
        """Run the basic (non-covariate) forecast and return only its quantile output."""
        _, quantile_forecast = self.model.forecast(inputs=inputs_norm, freq=freq_norm)
        return np.array(quantile_forecast)

    def _validate_covariates(
        self,
        inputs: List[List[float]],
        dynamic_numerical: Optional[Dict],
        dynamic_categorical: Optional[Dict],
        static_numerical: Optional[Dict],
        static_categorical: Optional[Dict]
    ) -> None:
        """
        Check that every covariate provides exactly one entry per input series.

        The covariate group (dynamic vs static) decides the wording of the error, not the
        shape of the data. The data is never indexed, so empty covariates are reported
        as a size mismatch rather than crashing with IndexError.

        Raises:
            ValueError: If a covariate's length differs from ``len(inputs)``.
        """
        logger.info("Validating covariates data structure...")

        num_series = len(inputs)
        groups = (
            ("Dynamic", "series", dynamic_numerical),
            ("Dynamic", "series", dynamic_categorical),
            ("Static", "values", static_numerical),
            ("Static", "values", static_categorical),
        )
        for kind, unit, cov_data in groups:
            for name, data in (cov_data or {}).items():
                if len(data) != num_series:
                    raise ValueError(f"{kind} covariate '{name}' has {len(data)} {unit}, expected {num_series}")

        logger.info("✅ Covariates validation passed")
    

def run_forecast(
    forecaster: 'Forecaster',
    target_inputs: List[List[float]],
    covariates: Optional[Dict[str, Any]] = None,
    use_covariates: bool = False,
    freq: Union[int, List[int]] = 0
) -> Dict[str, Any]:
    """
    Centralized forecasting function that handles both basic and covariates-enhanced forecasting.

    Covariates-enhanced forecasting is used only when ``use_covariates`` is True AND at least
    one of the recognised covariate entries in ``covariates`` is non-empty. Otherwise the
    basic forecast runs, and ``method`` / ``metadata['covariates_used']`` report that.

    Args:
        forecaster: Initialized Forecaster instance
        target_inputs: Input time series data (one series or a list of series)
        covariates: Dictionary that may contain the keys 'dynamic_numerical_covariates',
            'dynamic_categorical_covariates', 'static_numerical_covariates' and
            'static_categorical_covariates'; other keys are ignored.
        use_covariates: Whether to use covariates-enhanced forecasting
        freq: Frequency indicator(s)

    Returns:
        Dictionary containing forecast results with keys:
        - 'point_forecast': Main forecast array
        - 'quantile_forecast': Quantile forecast array (always present)
        - 'method': 'covariates_enhanced' or 'basic_timesfm'
        - 'metadata': input_series_count, forecast_length, covariates_used, quantiles_available

    Raises:
        ValueError: If the point or quantile forecast contains NaN values.
        Exception: Any forecasting error is logged and re-raised.
    """
    logger.info(f"🚀 Running centralized forecast (covariates={use_covariates})...")

    covariate_keys = (
        'dynamic_numerical_covariates',
        'dynamic_categorical_covariates',
        'static_numerical_covariates',
        'static_categorical_covariates',
    )

    try:
        results = {}

        covariate_kwargs = {key: (covariates or {}).get(key) for key in covariate_keys}
        # Decide on what is actually present, so the reported method/metadata match the
        # forecast that really ran (an empty dict or all-None values means "basic").
        covariates_active = bool(use_covariates and any(covariate_kwargs.values()))

        if covariates_active:
            logger.info("Using covariates-enhanced forecasting...")
            point_forecast, quantile_forecast = forecaster.forecast(
                inputs=target_inputs,
                freq=freq,
                use_covariates=True,
                **covariate_kwargs
            )
            results['point_forecast'] = point_forecast
            results['method'] = 'covariates_enhanced'
        else:
            logger.info("Using basic forecasting...")
            point_forecast, quantile_forecast = forecaster.forecast(
                inputs=target_inputs,
                freq=freq,
                use_covariates=False
            )
            results['point_forecast'] = point_forecast
            results['method'] = 'basic_timesfm'

        # Reject NaN output before anything downstream consumes it.
        if np.any(np.isnan(point_forecast)):
            logger.error(f"❌ NaN values detected in point_forecast: {np.isnan(point_forecast).sum()} out of {point_forecast.size}")
            logger.error(f"Point forecast values: {point_forecast}")
            raise ValueError("Forecasting produced NaN values in point forecast. This may be due to insufficient data or model issues.")

        if np.any(np.isnan(quantile_forecast)):
            logger.error(f"❌ NaN values detected in quantile_forecast: {np.isnan(quantile_forecast).sum()} out of {quantile_forecast.size}")
            logger.error(f"Quantile forecast shape: {quantile_forecast.shape}")
            raise ValueError("Forecasting produced NaN values in quantile forecast. This may be due to insufficient data or model issues.")

        results['quantile_forecast'] = quantile_forecast
        logger.info(f"✅ Quantile forecast obtained. Shape: {quantile_forecast.shape}")

        # A flat sequence of numbers is one series (same rule as Forecaster.forecast).
        single_series = isinstance(target_inputs[0], (int, float, np.number))
        results['metadata'] = {
            'input_series_count': 1 if single_series else len(target_inputs),
            'forecast_length': point_forecast.shape[-1],
            'covariates_used': covariates_active,
            'quantiles_available': True
        }

        logger.info("✅ Centralized forecast completed successfully!")
        logger.info(f"   Method: {results['method']}")
        logger.info(f"   Forecast length: {results['metadata']['forecast_length']}")
        logger.info(f"   Quantiles: Yes (shape: {quantile_forecast.shape})")
        logger.info(f"   Point forecast range: {np.min(point_forecast):.2f} to {np.max(point_forecast):.2f}")

        return results

    except Exception as e:
        logger.error(f"❌ Centralized forecasting failed: {str(e)}")
        raise


def process_quantile_bands(
    quantile_forecast: np.ndarray,
    selected_indices: Optional[List[int]] = None,
    series_index: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Centralized function to process quantile forecasts into quantile bands for visualization.

    Column 0 of the quantile axis is the legacy mean forecast and is never used for bands.
    Columns 1..num_quantiles-1 are the actual quantiles. The real quantile columns are ranked
    by their median value, and the k-th selected quantile is read from the k-th ranked column,
    so each label always matches its data. One band is created between each pair of adjacent
    selected indices.

    Args:
        quantile_forecast: Quantile forecasts (any array-like). Accepted layouts:
            - 3D (series, horizon, num_quantiles); the series axis must have size 1
              unless ``series_index`` is given.
            - 2D (horizon, num_quantiles) or (num_quantiles, horizon); the strictly longer
              axis is taken as the horizon, otherwise the array is transposed.
            - 1D (horizon,): a single quantile, which therefore yields no bands.
        selected_indices: Quantile indices to use for bands (default: [1, 3, 5, 7, 9]).
            An empty sequence means "no bands". Indices are deduplicated and sorted;
            index 0 and out-of-range indices are ignored.
        series_index: Series to use from a 3D input that holds more than one series.

    Returns:
        Dictionary of quantile bands ready for visualization with keys:
        - 'quantile_band_0_lower', 'quantile_band_0_upper', 'quantile_band_0_label'
        - 'quantile_band_1_lower', 'quantile_band_1_upper', 'quantile_band_1_label'
        - etc.

    Raises:
        ValueError: If a 3D input holds several series and ``series_index`` is None, or the
            input has an unsupported number of dimensions.
    """
    logger.info("🔄 Processing quantile bands...")
    logger.info(f"Input quantile_forecast type: {type(quantile_forecast)}")
    logger.info(f"Input quantile_forecast shape: {quantile_forecast.shape if hasattr(quantile_forecast, 'shape') else 'N/A'}")

    if quantile_forecast is None:
        logger.warning("No quantile forecast provided")
        return {}

    try:
        # Only fall back to the default when nothing was chosen (None); an explicit empty
        # selection means the user wants no bands.
        if selected_indices is None:
            selected_indices = [1, 3, 5, 7, 9]  # Q10, Q30, Q50, Q70, Q90
        elif len(selected_indices) == 0:
            logger.info("No quantiles selected by user - returning empty quantile bands")
            return {}

        q_mat = _as_horizon_by_quantile(np.asarray(quantile_forecast), series_index)

        horizon_len, num_quantiles = q_mat.shape
        logger.info(f"📊 Available quantiles: {num_quantiles} (indices 0-{num_quantiles-1})")
        logger.info(f"📊 Note: Index 0 is legacy mean forecast, using indices 1-{num_quantiles-1} for actual quantiles")

        # Need the legacy column plus at least two real quantiles to form one band.
        if num_quantiles < 3:
            logger.warning(f"Not enough quantiles for band creation. Have {num_quantiles}, need at least 3")
            return {}

        valid_indices = sorted({int(idx) for idx in selected_indices if 1 <= idx < num_quantiles})
        if not valid_indices:
            logger.warning("No valid quantile indices selected (after skipping legacy index 0)")
            return {}

        # Rank only the real quantile columns. The legacy mean (column 0) is excluded because
        # its median sits next to Q50's and would shift the rank of every column after it.
        # For monotone quantiles this is the identity mapping (index k -> column k).
        column_medians = np.median(q_mat[:, 1:], axis=0)
        rank_to_column = 1 + np.argsort(column_medians, kind="stable")

        quantile_bands: Dict[str, Any] = {}
        for band_no, (lower_idx, upper_idx) in enumerate(zip(valid_indices, valid_indices[1:])):
            lower_quantile = q_mat[:, rank_to_column[lower_idx - 1]]
            upper_quantile = q_mat[:, rank_to_column[upper_idx - 1]]

            lower_pct = idx_to_percent(lower_idx, num_quantiles)
            upper_pct = idx_to_percent(upper_idx, num_quantiles)
            band_label = f"Q{lower_pct:02d}–Q{upper_pct:02d}"

            quantile_bands[f'quantile_band_{band_no}_lower'] = lower_quantile.tolist()
            quantile_bands[f'quantile_band_{band_no}_upper'] = upper_quantile.tolist()
            quantile_bands[f'quantile_band_{band_no}_label'] = band_label

            logger.info(f"   Band {band_no}: {band_label} - Lower: {len(lower_quantile)}, Upper: {len(upper_quantile)}")

        band_count = len(valid_indices) - 1
        logger.info(f"✅ Created {band_count} quantile bands from indices: {valid_indices}")
        for i in range(band_count):
            logger.info(f"   Band {i}: {quantile_bands[f'quantile_band_{i}_label']}")

        return quantile_bands

    except Exception as e:
        logger.error(f"❌ Quantile band processing failed: {str(e)}")
        raise


def _as_horizon_by_quantile(q_arr: np.ndarray, series_index: Optional[int]) -> np.ndarray:
    """Normalize a 1D/2D/3D quantile array into a 2D (horizon, num_quantiles) matrix."""
    if q_arr.ndim == 3:
        if series_index is not None:
            q_mat = q_arr[series_index]
            logger.info(f"3D array detected, selected series {series_index}: {q_mat.shape}")
        elif q_arr.shape[0] == 1:
            q_mat = q_arr[0]
            logger.info(f"3D array detected, squeezed to shape: {q_mat.shape}")
        else:
            raise ValueError(
                f"quantile_forecast holds {q_arr.shape[0]} series; pass series_index to choose one"
            )
        return q_mat
    if q_arr.ndim == 1:
        # (horizon,) is one quantile per step -> (horizon, 1).
        q_mat = q_arr.reshape(-1, 1)
        logger.info(f"1D array detected, reshaped to: {q_mat.shape}")
        return q_mat
    if q_arr.ndim == 2:
        # Heuristic: a strictly longer first axis is the horizon. NOTE(review): this is
        # ambiguous when horizon <= num_quantiles, in which case the array is transposed.
        if q_arr.shape[0] > q_arr.shape[1]:
            logger.info(f"2D array kept as is (horizon, quantiles): {q_arr.shape}")
            return q_arr
        q_mat = q_arr.T
        logger.info(f"2D array transposed from {q_arr.shape} to {q_mat.shape}")
        return q_mat
    raise ValueError(f"Unsupported quantile_forecast with {q_arr.ndim} dimensions; expected 1, 2 or 3")


def idx_to_percent(idx: int, num_quantiles: int) -> int:
    """
    Convert quantile index to percentage for labeling.

    Index 0 is the legacy mean forecast and should be skipped. The remaining
    ``num_quantiles - 1`` columns are assumed to be evenly spaced quantile levels strictly
    between 0 and 100, so index ``k`` maps to ``100 * k / num_quantiles``. For the standard
    10-column layout this gives 1->Q10, 2->Q20, ..., 9->Q90.

    Args:
        idx: Quantile index (1-based for actual quantiles, 0 is legacy)
        num_quantiles: Total number of quantile columns (including legacy index 0)

    Returns:
        Percentage value rounded to the nearest integer (e.g., 10 for Q10, 90 for Q90)

    Raises:
        ValueError: If num_quantiles < 2 (there are no actual quantile columns).
    """
    if num_quantiles < 2:
        raise ValueError(f"num_quantiles must be at least 2, got {num_quantiles}")
    return int(round(100 * idx / num_quantiles))