# sapheneia-timesfm / src / interactive_visualization.py
# rkovashikawa: "Initial Hugging Face Spaces deployment" (48abd32)
"""
Interactive Visualization Module for TimesFM Forecasting using Plotly
This module provides comprehensive interactive visualization capabilities for TimesFM forecasting,
including professional-grade plots with prediction intervals, covariates displays,
and publication-ready styling using Plotly for enhanced interactivity.
Key Features:
- Interactive forecast visualizations with seamless connections
- Prediction intervals with customizable confidence levels
- Covariates subplots integration
- Sapheneia-style professional formatting
- Interactive zoom, pan, and hover capabilities
- Export capabilities for presentations and publications
- Responsive design for web applications
"""
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from plotly.offline import plot
from datetime import datetime
from typing import List, Dict, Optional, Union
import logging
import json
logger = logging.getLogger(__name__)
class InteractiveVisualizer:
    """
    Interactive visualization class for TimesFM forecasting results using Plotly.

    This class provides methods to create interactive, publication-quality visualizations
    of forecasting results, including prediction intervals, covariates analysis,
    and comprehensive time series plots with enhanced user interaction.

    Example:
        >>> viz = InteractiveVisualizer()
        >>> fig = viz.plot_forecast_with_intervals(
        ...     historical_data=historical,
        ...     forecast=point_forecast,
        ...     intervals=prediction_intervals,
        ...     title="Bitcoin Price Forecast"
        ... )
        >>> fig.show()
    """

    def __init__(self, style: str = "professional", theme: str = "plotly_white"):
        """
        Initialize the InteractiveVisualizer with specified styling.

        Args:
            style: Visualization style ("professional", "minimal", "presentation").
                Any other value falls back to the "presentation" palette.
            theme: Plotly theme name ("plotly", "plotly_white", "plotly_dark",
                "ggplot2", "seaborn", "simple_white"). It is stored on the
                instance; the layouts built by this class do not read it.
        """
        self.style = style
        self.theme = theme
        self._setup_style()
        logger.info(f"InteractiveVisualizer initialized with '{style}' style and '{theme}' theme")

    # ------------------------------------------------------------------
    # Styling
    # ------------------------------------------------------------------
    def _setup_style(self) -> None:
        """Populate ``self.colors`` and ``self.layout_config`` for the chosen style."""
        if self.style == "professional":
            # Sapheneia professional style
            self.colors = {
                'historical': '#1f77b4',
                'forecast': '#d62728',
                'actual': '#2ca02c',
                'interval_80': 'rgba(255, 179, 102, 0.3)',
                'interval_50': 'rgba(255, 127, 14, 0.5)',
                'grid': '#e0e0e0',
                'background': '#fafafa',
                'text': '#2c3e50',
                'axis': '#34495e'
            }
            self.layout_config = {
                'width': 1200,
                'height': 800,
                'margin': {'l': 60, 'r': 60, 't': 80, 'b': 60}
            }
        elif self.style == "minimal":
            # Clean minimal style
            self.colors = {
                'historical': '#2E86AB',
                'forecast': '#A23B72',
                'actual': '#F18F01',
                'interval_80': 'rgba(199, 62, 29, 0.3)',
                'interval_50': 'rgba(241, 143, 1, 0.5)',
                'grid': '#f0f0f0',
                'background': 'white',
                'text': '#2c3e50',
                'axis': '#34495e'
            }
            self.layout_config = {
                'width': 1000,
                'height': 700,
                'margin': {'l': 50, 'r': 50, 't': 60, 'b': 50}
            }
        else:
            # "presentation" (also the fallback for unknown styles): high contrast
            self.colors = {
                'historical': '#003f5c',
                'forecast': '#ff6361',
                'actual': '#58508d',
                'interval_80': 'rgba(255, 166, 0, 0.3)',
                'interval_50': 'rgba(255, 99, 97, 0.5)',
                'grid': '#e8e8e8',
                'background': 'white',
                'text': '#2c3e50',
                'axis': '#34495e'
            }
            self.layout_config = {
                'width': 1400,
                'height': 900,
                'margin': {'l': 70, 'r': 70, 't': 100, 'b': 70}
            }

    def _create_base_layout(self, title: str, x_title: str = "Time", y_title: str = "Value") -> Dict:
        """
        Create base layout configuration for plots.

        Args:
            title: Figure title, centred above the plot.
            x_title: X-axis label.
            y_title: Y-axis label.

        Returns:
            A fresh dict suitable for ``fig.update_layout(**layout)``. It includes
            the width/height/margin of the active style.
        """
        axis_font = {'size': 14, 'color': self.colors['axis']}
        tick_font = {'size': 12, 'color': self.colors['axis']}
        return {
            'title': {
                'text': title,
                'x': 0.5,
                'xanchor': 'center',
                'font': {'size': 18, 'color': self.colors['text']}
            },
            'xaxis': {
                'title': {'text': x_title, 'font': dict(axis_font)},
                'tickfont': dict(tick_font),
                'gridcolor': self.colors['grid'],
                'showgrid': True,
                'zeroline': False
            },
            'yaxis': {
                'title': {'text': y_title, 'font': dict(axis_font)},
                'tickfont': dict(tick_font),
                'gridcolor': self.colors['grid'],
                'showgrid': True,
                'zeroline': False
            },
            'plot_bgcolor': self.colors['background'],
            'paper_bgcolor': 'white',
            'font': {'family': 'Arial, sans-serif', 'color': self.colors['text']},
            'showlegend': True,
            'legend': {
                'x': 0.02,
                'y': 0.98,
                'yanchor': 'top',
                'bgcolor': 'rgba(255, 255, 255, 0.8)',
                'bordercolor': 'rgba(0, 0, 0, 0.2)',
                'borderwidth': 1
            },
            'hovermode': 'x unified',
            **self.layout_config
        }

    # ------------------------------------------------------------------
    # Pure helpers (no Plotly objects involved)
    # ------------------------------------------------------------------
    @staticmethod
    def _paired_confidence_levels(intervals: Dict[str, np.ndarray]) -> List[str]:
        """
        Return the confidence levels that have both ``lower_<n>`` and ``upper_<n>`` keys.

        Levels are returned as the strings found in the keys, widest first, so
        that narrower bands are drawn on top of wider ones. Keys whose suffix is
        not a plain integer are ignored.
        """
        levels = set()
        for key in intervals:
            if not key.startswith('lower_'):
                continue
            level = key.split('_')[1]
            if level.isdigit() and f'upper_{level}' in intervals:
                levels.add(level)
        return sorted(levels, key=int, reverse=True)

    @staticmethod
    def _collect_interval_values(intervals: Optional[Dict[str, np.ndarray]]) -> List[float]:
        """
        Flatten interval bounds into one list for y-axis range calculations.

        Includes ``lower_*`` / ``upper_*`` entries (the 50 and 80 levels only when
        both bounds are present) and ``quantile_band_*_lower`` / ``_upper`` entries.
        """
        values: List[float] = []
        if not intervals:
            return values
        for key, bound in intervals.items():
            if key.startswith('lower_') or key.startswith('upper_'):
                level = key.split('_', 1)[1]
                if level in ('50', '80') and not (
                    f'lower_{level}' in intervals and f'upper_{level}' in intervals
                ):
                    continue
                values.extend(bound)
            elif key.startswith('quantile_band_') and key.endswith(('_lower', '_upper')):
                values.extend(bound)
        return values

    @staticmethod
    def _focused_y_range(
        historical_data: Union[List[float], np.ndarray],
        forecast: Union[List[float], np.ndarray],
        intervals: Optional[Dict[str, np.ndarray]],
        context_len: int,
        y_axis_padding: float = 0.1
    ) -> List[float]:
        """
        Compute a padded ``[min, max]`` y-range over context, forecast and intervals.

        Args:
            historical_data: Historical target values.
            forecast: Point forecast values.
            intervals: Optional interval dictionary (see ``_collect_interval_values``).
            context_len: Number of trailing historical points to consider. When it
                is not smaller than the history length, the whole history is used.
            y_axis_padding: Fraction of the data range added on each side.

        Returns:
            ``[lower, upper]`` bounds for the y-axis.
        """
        history = np.asarray(historical_data)
        window = history[-context_len:] if context_len < len(history) else history
        pieces = [window, np.asarray(forecast)]
        interval_values = InteractiveVisualizer._collect_interval_values(intervals)
        if interval_values:
            pieces.append(np.asarray(interval_values))
        values = np.concatenate(pieces)

        data_min = np.min(values)
        data_max = np.max(values)
        data_range = data_max - data_min
        padding = data_range * y_axis_padding
        if data_range == 0:
            # A constant series would otherwise produce a zero-height axis.
            padding = abs(data_max) * y_axis_padding or 1.0
        return [data_min - padding, data_max + padding]

    # ------------------------------------------------------------------
    # Plotly helpers
    # ------------------------------------------------------------------
    def _add_interval_band(
        self,
        fig: "go.Figure",
        x_values,
        anchor_y,
        lower,
        upper,
        fillcolor: str,
        name: str,
        hovertemplate: str,
        row: Optional[int] = None,
        col: Optional[int] = None
    ) -> None:
        """Add one shaded band that starts at the last historical point (``anchor_y``)."""
        # The invisible upper bound goes first: fill='tonexty' on the lower
        # bound shades the area up to the trace added immediately before it.
        fig.add_trace(go.Scatter(
            x=x_values,
            y=[anchor_y] + list(upper),
            mode='lines',
            line=dict(width=0),
            showlegend=False,
            hoverinfo='skip'
        ), row=row, col=col)
        fig.add_trace(go.Scatter(
            x=x_values,
            y=[anchor_y] + list(lower),
            mode='lines',
            line=dict(width=0),
            fill='tonexty',
            fillcolor=fillcolor,
            name=name,
            hovertemplate=hovertemplate
        ), row=row, col=col)

    def _add_confidence_bands(
        self,
        fig: "go.Figure",
        intervals: Dict[str, np.ndarray],
        levels: List[str],
        x_values,
        anchor_y,
        row: Optional[int] = None,
        col: Optional[int] = None
    ) -> None:
        """Draw ``lower_<n>`` / ``upper_<n>`` bands for the given levels, in order."""
        for level in levels:
            wide = int(level) >= 80
            self._add_interval_band(
                fig, x_values, anchor_y,
                intervals[f'lower_{level}'], intervals[f'upper_{level}'],
                self.colors['interval_80'] if wide else self.colors['interval_50'],
                f'{level}% Prediction Interval',
                # This trace carries the lower bound, so label it as such.
                f'<b>{level}% Interval</b><br>Time: %{{x}}<br>Lower: %{{y:.2f}}<extra></extra>',
                row, col
            )

    def _add_quantile_bands(
        self,
        fig: "go.Figure",
        intervals: Dict[str, np.ndarray],
        x_values,
        anchor_y
    ) -> None:
        """Draw every ``quantile_band_<name>_lower`` / ``_upper`` pair found in ``intervals``."""
        bands = {}
        for key in intervals:
            if key.startswith('quantile_band_') and key.endswith('_lower'):
                band_name = key.replace('quantile_band_', '').replace('_lower', '')
                upper_key = f'quantile_band_{band_name}_upper'
                if upper_key in intervals:
                    bands[band_name] = (intervals[key], intervals[upper_key])
        if not bands:
            return

        band_colors = ['rgba(255, 153, 153, 0.3)', 'rgba(153, 204, 255, 0.3)',
                       'rgba(153, 255, 153, 0.3)', 'rgba(255, 204, 153, 0.3)',
                       'rgba(204, 153, 255, 0.3)', 'rgba(255, 255, 153, 0.3)']
        for i, (band_name, (lower, upper)) in enumerate(sorted(bands.items())):
            label_text = intervals.get(f'quantile_band_{band_name}_label')
            if label_text is None:
                # Only build the default label when needed; band names are
                # expected to be zero-based indices but may be arbitrary text.
                if band_name.isdigit():
                    label_text = f'Quantile Band {int(band_name)+1}'
                else:
                    label_text = f'Quantile Band {band_name}'
            self._add_interval_band(
                fig, x_values, anchor_y, lower, upper,
                band_colors[i % len(band_colors)],
                label_text,
                f'<b>{label_text}</b><br>Lower: %{{y:.2f}}<extra></extra>'
            )

    @staticmethod
    def _add_timestamp(fig: "go.Figure") -> None:
        """Append a 'Generated: <time>' footnote in the bottom-right corner."""
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M")
        fig.add_annotation(
            x=1, y=-0.1,
            xref='paper', yref='paper',
            text=f'Generated: {timestamp}',
            showarrow=False,
            font=dict(size=10, color='gray')
        )

    @staticmethod
    def _save_figure(fig: "go.Figure", save_path: str) -> None:
        """Write HTML for ``.html`` paths and a static image for anything else."""
        if save_path.endswith('.html'):
            fig.write_html(save_path)
        else:
            fig.write_image(save_path)

    # ------------------------------------------------------------------
    # Public plotting API
    # ------------------------------------------------------------------
    def plot_forecast_with_intervals(
        self,
        historical_data: Union[List[float], np.ndarray],
        forecast: Union[List[float], np.ndarray],
        intervals: Optional[Dict[str, np.ndarray]] = None,
        actual_future: Optional[Union[List[float], np.ndarray]] = None,
        dates_historical: Optional[List[Union[str, datetime]]] = None,
        dates_future: Optional[List[Union[str, datetime]]] = None,
        title: str = "TimesFM Forecast with Prediction Intervals",
        target_name: str = "Value",
        save_path: Optional[str] = None,
        show_figure: bool = True,
        context_len: Optional[int] = None,
        horizon_len: Optional[int] = None,
        y_axis_padding: float = 0.1
    ) -> "Optional[go.Figure]":
        """
        Create an interactive forecast visualization with prediction intervals.

        Args:
            historical_data: Historical time series data
            forecast: Point forecast values
            intervals: Dictionary containing prediction intervals. Recognised keys
                are ``lower_<n>`` / ``upper_<n>`` pairs and
                ``quantile_band_<name>_lower`` / ``_upper`` (with an optional
                ``quantile_band_<name>_label``). When an 80% pair is present only
                the 80% and 50% bands are drawn.
            actual_future: Optional actual future values for comparison
            dates_historical: Optional dates for historical data
            dates_future: Optional dates for forecast period
            title: Plot title
            target_name: Name of the target variable
            save_path: Optional path to save the plot (``.html`` is written as
                HTML, any other extension as a static image)
            show_figure: Whether to display the figure
            context_len: Length of context window for default view focus
            horizon_len: Length of horizon for default view focus
            y_axis_padding: Padding factor for focused y-axis range (0.1 = 10% padding)

        Returns:
            Plotly Figure object, or ``None`` when the historical data is empty
            or does not match the length of ``dates_historical``.
        """
        logger.info(f"Creating interactive forecast visualization: {title}")

        if actual_future is not None:
            actual_future = np.array(actual_future)

        # Time axis: integer positions unless dates are supplied.
        if dates_historical is None:
            historical_x = np.arange(len(historical_data))
        else:
            historical_x = pd.to_datetime(dates_historical)
        if dates_future is None:
            future_x = np.arange(len(historical_data), len(historical_data) + len(forecast))
        else:
            future_x = pd.to_datetime(dates_future)

        # Validate before anything indexes into the data.
        if len(historical_data) == 0:
            logger.error("Historical data is empty!")
            return None
        if len(historical_x) == 0:
            logger.error("Historical x-axis data is empty!")
            return None
        if len(historical_data) != len(historical_x):
            logger.error(
                f"Mismatch between historical data ({len(historical_data)}) "
                f"and x-axis ({len(historical_x)}) lengths!"
            )
            return None
        logger.debug(
            "Plotting %d historical points and %d forecast points",
            len(historical_data), len(forecast)
        )

        # Default view window; only applied when both lengths are given.
        default_x_range = None
        default_y_range = None
        if context_len is not None and horizon_len is not None:
            if dates_historical is not None:
                if len(future_x) > 0:
                    end_x = future_x[min(horizon_len - 1, len(future_x) - 1)]
                else:
                    end_x = historical_x[-1]
                default_x_range = [historical_x[0], end_x]
            else:
                default_x_range = [0, len(historical_x) + len(forecast)]
            default_y_range = self._focused_y_range(
                historical_data, forecast, intervals, context_len, y_axis_padding
            )

        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=historical_x,
            y=historical_data,
            mode='lines',
            name='Historical Data',
            line=dict(color=self.colors['historical'], width=3),
            hovertemplate='<b>Historical</b><br>Time: %{x}<br>Value: %{y:.2f}<extra></extra>'
        ))

        # Bands and actuals start at the last historical point so they join
        # the history without a gap.
        anchor_y = historical_data[-1]
        if dates_historical is None:
            connection_x = [len(historical_x) - 1] + list(future_x)
        else:
            connection_x = [historical_x[-1]] + list(future_x)

        if intervals:
            levels = self._paired_confidence_levels(intervals)
            if '80' in levels:
                # With an 80% band available only the classic 80/50 pair is shown.
                levels = [level for level in ('80', '50') if level in levels]
            self._add_confidence_bands(fig, intervals, levels, connection_x, anchor_y)
            self._add_quantile_bands(fig, intervals, connection_x, anchor_y)

        # Point forecast, plus a hover-less two-point bridge back to the history.
        fig.add_trace(go.Scatter(
            x=future_x,
            y=forecast,
            mode='lines',
            name='Point Forecast',
            line=dict(color=self.colors['forecast'], width=3, dash='dash'),
            hovertemplate='<b>Forecast</b><br>Time: %{x}<br>Value: %{y:.2f}<extra></extra>',
            legendgroup='forecast'
        ))
        if len(forecast) > 0 and len(future_x) > 0:
            fig.add_trace(go.Scatter(
                x=[historical_x[-1], future_x[0]],
                y=[anchor_y, forecast[0]],
                mode='lines',
                line=dict(color=self.colors['forecast'], width=3, dash='dash'),
                hoverinfo='skip',
                showlegend=False,
                legendgroup='forecast'
            ))

        if actual_future is not None:
            logger.debug("Plotting %d actual future values", len(actual_future))
            # Anchor on the last historical *value* (not the last x position).
            actual_connection = [anchor_y] + list(actual_future)
            fig.add_trace(go.Scatter(
                x=connection_x,
                y=actual_connection,
                mode='lines+markers',
                name='Actual Future',
                line=dict(color=self.colors['actual'], width=3),
                marker=dict(size=8, color=self.colors['actual'],
                            line=dict(width=2, color='white')),
                hovertemplate='<b>Actual Future</b><br>Time: %{x}<br>Value: %{y:.2f}<extra></extra>',
                legendgroup='actual'
            ))

        # Forecast start marker: a datetime on date axes, a plain number otherwise.
        if dates_historical is not None:
            forecast_start = pd.to_datetime(historical_x[-1]).to_pydatetime()
        else:
            forecast_start = int(historical_x[-1])
        fig.add_vline(
            x=forecast_start,
            line_dash="dot", line_color="gray", line_width=1
        )
        fig.add_annotation(
            x=forecast_start,
            y=1,  # top of plotting area
            xref="x",
            yref="paper",
            text="Forecast Start",
            showarrow=False,
            yanchor="bottom"
        )

        layout = self._create_base_layout(title, "Time", target_name)
        if default_x_range is not None:
            layout['xaxis']['range'] = default_x_range
        if default_y_range is not None:
            layout['yaxis']['range'] = default_y_range
        fig.update_layout(**layout)
        # Added after update_layout: passing 'annotations' through the layout
        # dict would be merged into the "Forecast Start" annotation above.
        self._add_timestamp(fig)

        if save_path:
            self._save_figure(fig, save_path)
            logger.info(f"Interactive plot saved to: {save_path}")

        logger.debug("Final figure has %d traces", len(fig.data))

        if show_figure:
            fig.show()
        logger.info("✅ Interactive forecast visualization completed")
        return fig

    def plot_forecast_with_covariates(
        self,
        historical_data: Union[List[float], np.ndarray],
        forecast: Union[List[float], np.ndarray],
        covariates_data: Dict[str, Dict[str, Union[List[float], float, str]]],
        intervals: Optional[Dict[str, np.ndarray]] = None,
        actual_future: Optional[Union[List[float], np.ndarray]] = None,
        dates_historical: Optional[List[Union[str, datetime]]] = None,
        dates_future: Optional[List[Union[str, datetime]]] = None,
        title: str = "TimesFM Forecast with Covariates Analysis",
        target_name: str = "Target Value",
        save_path: Optional[str] = None,
        show_figure: bool = True,
        context_len: Optional[int] = None,
        horizon_len: Optional[int] = None,
        show_full_history: bool = True,
        y_axis_padding: float = 0.1
    ) -> "Optional[go.Figure]":
        """
        Create a comprehensive interactive visualization with main forecast and covariates subplots.

        Args:
            historical_data: Historical time series data
            forecast: Point forecast values
            covariates_data: Dictionary containing covariates information. Only
                entries that are dicts with a ``'historical'`` key are plotted;
                an optional ``'future'`` key adds the forecast-period values.
            intervals: Optional prediction intervals (``lower_<n>`` / ``upper_<n>`` pairs)
            actual_future: Optional actual future values
            dates_historical: Optional historical dates
            dates_future: Optional future dates (required when ``dates_historical`` is given)
            title: Title of the main forecast subplot
            target_name: Name of target variable
            save_path: Optional save path
            show_figure: Whether to display the figure
            context_len: Forwarded to ``plot_forecast_with_intervals`` when there
                are no plottable covariates; otherwise unused
            horizon_len: Same as ``context_len``
            show_full_history: Accepted for interface compatibility; not used
            y_axis_padding: Same as ``context_len``

        Returns:
            Plotly Figure object

        Raises:
            ValueError: If ``dates_historical`` is given without ``dates_future``
                (only checked when covariates are plotted).
        """
        logger.info(f"Creating comprehensive interactive forecast with covariates: {title}")

        plottable = [
            (name, data) for name, data in covariates_data.items()
            if isinstance(data, dict) and 'historical' in data
        ]
        num_covariates = len(plottable)

        if num_covariates == 0:
            # Keyword arguments: the two signatures differ (no show_full_history there).
            return self.plot_forecast_with_intervals(
                historical_data=historical_data,
                forecast=forecast,
                intervals=intervals,
                actual_future=actual_future,
                dates_historical=dates_historical,
                dates_future=dates_future,
                title=title,
                target_name=target_name,
                save_path=save_path,
                show_figure=show_figure,
                context_len=context_len,
                horizon_len=horizon_len,
                y_axis_padding=y_axis_padding
            )

        if dates_historical is not None and dates_future is None:
            raise ValueError("dates_future is required when dates_historical is provided")

        # Grid: row 1 holds the main plot in column 1; covariates fill rows 2+.
        cols = 2
        if num_covariates <= 2:
            rows = 2
        elif num_covariates <= 4:
            rows = 3
        else:
            rows = 4
        capacity = (rows - 1) * cols
        if num_covariates > capacity:
            logger.warning(
                f"Only the first {capacity} of {num_covariates} covariates are plotted"
            )
            plottable = plottable[:capacity]

        # Titles are assigned row-major, so cell (1, 2) needs an empty
        # placeholder to keep each covariate title above its own subplot.
        subplot_titles = [title, ""] + [
            name.replace("_", " ").title() for name, _ in plottable
        ]
        fig = make_subplots(
            rows=rows, cols=cols,
            subplot_titles=subplot_titles,
            vertical_spacing=0.08,
            horizontal_spacing=0.1
        )

        historical_data = np.array(historical_data)
        forecast = np.array(forecast)

        if dates_historical is None:
            historical_x = np.arange(len(historical_data))
            future_x = np.arange(len(historical_data), len(historical_data) + len(forecast))
            connection_x = [len(historical_data) - 1] + list(future_x)
        else:
            historical_x = pd.to_datetime(dates_historical)
            future_x = pd.to_datetime(dates_future)
            connection_x = [historical_x[-1]] + list(future_x)
        anchor_y = historical_data[-1]

        # --- Main forecast panel -------------------------------------
        fig.add_trace(go.Scatter(
            x=historical_x,
            y=historical_data,
            mode='lines',
            name='Historical Data',
            line=dict(color=self.colors['historical'], width=3),
            hovertemplate='<b>Historical</b><br>Time: %{x}<br>Value: %{y:.2f}<extra></extra>'
        ), row=1, col=1)

        if intervals:
            self._add_confidence_bands(
                fig, intervals, self._paired_confidence_levels(intervals),
                connection_x, anchor_y, row=1, col=1
            )

        fig.add_trace(go.Scatter(
            x=connection_x,
            y=[anchor_y] + list(forecast),
            mode='lines',
            name='Point Forecast',
            line=dict(color=self.colors['forecast'], width=3, dash='dash'),
            hovertemplate='<b>Forecast</b><br>Time: %{x}<br>Value: %{y:.2f}<extra></extra>'
        ), row=1, col=1)

        if actual_future is not None:
            actual_future = np.array(actual_future)
            fig.add_trace(go.Scatter(
                x=connection_x,
                y=[anchor_y] + list(actual_future),
                mode='lines+markers',
                name='Actual Future',
                line=dict(color=self.colors['actual'], width=3),
                marker=dict(size=8, color=self.colors['actual'],
                            line=dict(width=2, color='white')),
                hovertemplate='<b>Actual Future</b><br>Time: %{x}<br>Value: %{y:.2f}<extra></extra>'
            ), row=1, col=1)

        # --- Covariate panels ----------------------------------------
        covariate_colors = ['#9467bd', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf', '#d62728']
        for plot_idx, (cov_name, cov_data) in enumerate(plottable):
            row = 2 + plot_idx // cols
            col = 1 + plot_idx % cols
            color = covariate_colors[plot_idx % len(covariate_colors)]
            pretty_name = cov_name.replace("_", " ").title()

            fig.add_trace(go.Scatter(
                x=historical_x,
                y=cov_data['historical'],
                mode='lines',
                name=f'{pretty_name} Historical',
                line=dict(color=color, width=2.5),
                hovertemplate=f'<b>{pretty_name}</b><br>Time: %{{x}}<br>Value: %{{y:.2f}}<extra></extra>',
                showlegend=False
            ), row=row, col=col)

            if 'future' in cov_data:
                combined_data = list(cov_data['historical']) + list(cov_data['future'])
                if dates_historical is None:
                    combined_x = np.arange(len(combined_data))
                else:
                    combined_x = list(historical_x) + list(future_x)
                # Start one point early so the dashed future joins the history.
                future_start_idx = len(cov_data['historical']) - 1
                fig.add_trace(go.Scatter(
                    x=combined_x[future_start_idx:],
                    y=combined_data[future_start_idx:],
                    mode='lines+markers',
                    name=f'{pretty_name} Future',
                    line=dict(color=color, width=2.5, dash='dash'),
                    marker=dict(size=6, color=color),
                    hovertemplate=f'<b>{pretty_name} Future</b><br>Time: %{{x}}<br>Value: %{{y:.2f}}<extra></extra>',
                    showlegend=False
                ), row=row, col=col)

        fig.update_layout(
            title='TimesFM Comprehensive Forecasting Analysis',
            title_x=0.5,
            title_font_size=20,
            height=800,
            showlegend=True,
            hovermode='x unified'
        )

        for i in range(1, rows + 1):
            for j in range(1, cols + 1):
                fig.update_xaxes(
                    title_text="Time" if i == 1 else "",
                    gridcolor=self.colors['grid'],
                    showgrid=True,
                    row=i, col=j
                )
                fig.update_yaxes(
                    title_text=target_name if i == 1 else "Value",
                    gridcolor=self.colors['grid'],
                    showgrid=True,
                    row=i, col=j
                )

        self._add_timestamp(fig)

        if save_path:
            self._save_figure(fig, save_path)
            logger.info(f"Comprehensive interactive plot saved to: {save_path}")

        if show_figure:
            fig.show()
        logger.info("✅ Comprehensive interactive forecast visualization completed")
        return fig

    def plot_forecast_comparison(
        self,
        forecasts_dict: Dict[str, np.ndarray],
        historical_data: Union[List[float], np.ndarray],
        actual_future: Optional[Union[List[float], np.ndarray]] = None,
        title: str = "Forecast Methods Comparison",
        save_path: Optional[str] = None,
        show_figure: bool = True
    ) -> "go.Figure":
        """
        Compare multiple forecasting methods in an interactive plot.

        The x-axis is always the integer position of each point.

        Args:
            forecasts_dict: Dictionary of {method_name: forecast_array}
            historical_data: Historical data for context
            actual_future: Optional actual future values
            title: Plot title
            save_path: Optional save path
            show_figure: Whether to display the figure

        Returns:
            Plotly Figure object
        """
        logger.info(f"Creating interactive forecast comparison plot: {title}")

        fig = go.Figure()
        historical_data = np.array(historical_data)
        history_len = len(historical_data)
        anchor_y = historical_data[-1]

        fig.add_trace(go.Scatter(
            x=np.arange(history_len),
            y=historical_data,
            mode='lines',
            name='Historical Data',
            line=dict(color=self.colors['historical'], width=3),
            hovertemplate='<b>Historical</b><br>Time: %{x}<br>Value: %{y:.2f}<extra></extra>'
        ))

        forecast_colors = ['#d62728', '#ff7f0e', '#2ca02c', '#9467bd', '#8c564b']
        for i, (method, forecast) in enumerate(forecasts_dict.items()):
            forecast = np.array(forecast)
            future_x = np.arange(history_len, history_len + len(forecast))
            fig.add_trace(go.Scatter(
                # Prepend the last historical point for a seamless connection.
                x=[history_len - 1] + list(future_x),
                y=[anchor_y] + list(forecast),
                mode='lines',
                name=f'{method} Forecast',
                line=dict(
                    color=forecast_colors[i % len(forecast_colors)],
                    width=3,
                    dash='dash' if i == 0 else 'dot'
                ),
                hovertemplate=f'<b>{method} Forecast</b><br>Time: %{{x}}<br>Value: %{{y:.2f}}<extra></extra>'
            ))

        if actual_future is not None:
            actual_future = np.array(actual_future)
            future_x = np.arange(history_len, history_len + len(actual_future))
            fig.add_trace(go.Scatter(
                x=[history_len - 1] + list(future_x),
                y=[anchor_y] + list(actual_future),
                mode='lines+markers',
                name='Actual Future',
                line=dict(color=self.colors['actual'], width=3),
                marker=dict(size=8, color=self.colors['actual'],
                            line=dict(width=2, color='white')),
                hovertemplate='<b>Actual Future</b><br>Time: %{x}<br>Value: %{y:.2f}<extra></extra>'
            ))

        fig.add_vline(
            x=history_len - 1,
            line_dash="dot",
            line_color="gray",
            line_width=2,
            annotation_text="Forecast Start",
            annotation_position="top"
        )

        fig.update_layout(**self._create_base_layout(title, "Time", "Value"))
        self._add_timestamp(fig)

        if save_path:
            self._save_figure(fig, save_path)
            logger.info(f"Comparison plot saved to: {save_path}")

        if show_figure:
            fig.show()
        logger.info("✅ Interactive forecast comparison visualization completed")
        return fig

    def create_dashboard(
        self,
        historical_data: Union[List[float], np.ndarray],
        forecast: Union[List[float], np.ndarray],
        intervals: Optional[Dict[str, np.ndarray]] = None,
        covariates_data: Optional[Dict[str, Dict[str, Union[List[float], float, str]]]] = None,
        actual_future: Optional[Union[List[float], np.ndarray]] = None,
        dates_historical: Optional[List[Union[str, datetime]]] = None,
        dates_future: Optional[List[Union[str, datetime]]] = None,
        title: str = "TimesFM Forecasting Dashboard",
        target_name: str = "Value",
        save_path: Optional[str] = None,
        show_figure: bool = True,
        context_len: Optional[int] = None,
        horizon_len: Optional[int] = None,
        show_full_history: bool = True,
        y_axis_padding: float = 0.1
    ) -> "Optional[go.Figure]":
        """
        Create a comprehensive dashboard with multiple visualization panels.

        Delegates to ``plot_forecast_with_covariates`` when covariates are given
        and to ``plot_forecast_with_intervals`` otherwise.

        Args:
            historical_data: Historical time series data
            forecast: Point forecast values
            intervals: Optional prediction intervals
            covariates_data: Optional covariates data
            actual_future: Optional actual future values
            dates_historical: Optional historical dates
            dates_future: Optional future dates
            title: Dashboard title
            target_name: Name of target variable
            save_path: Optional save path
            show_figure: Whether to display the figure
            context_len: Length of context window for default view focus
            horizon_len: Length of horizon for default view focus
            show_full_history: Forwarded to the covariates view only
            y_axis_padding: Padding factor for the focused y-axis range

        Returns:
            Plotly Figure object (see the delegated method for the ``None`` case)
        """
        logger.info(f"Creating interactive forecasting dashboard: {title}")

        shared_kwargs = dict(
            historical_data=historical_data,
            forecast=forecast,
            intervals=intervals,
            actual_future=actual_future,
            dates_historical=dates_historical,
            dates_future=dates_future,
            title=title,
            target_name=target_name,
            save_path=save_path,
            show_figure=show_figure,
            context_len=context_len,
            horizon_len=horizon_len,
            y_axis_padding=y_axis_padding
        )
        if covariates_data and len(covariates_data) > 0:
            return self.plot_forecast_with_covariates(
                covariates_data=covariates_data,
                show_full_history=show_full_history,
                **shared_kwargs
            )
        # plot_forecast_with_intervals has no show_full_history parameter.
        return self.plot_forecast_with_intervals(**shared_kwargs)

    def export_to_json(self, fig: "go.Figure", file_path: str) -> None:
        """
        Export a Plotly figure to JSON format for web integration.

        Args:
            fig: Plotly Figure object
            file_path: Path to save the JSON file
        """
        fig.write_json(file_path)
        logger.info(f"Figure exported to JSON: {file_path}")

    def get_figure_html(self, fig: "go.Figure", include_plotlyjs: bool = True) -> str:
        """
        Get the HTML representation of a figure.

        Args:
            fig: Plotly Figure object
            include_plotlyjs: Whether to include Plotly.js in the HTML

        Returns:
            HTML string representation of the figure
        """
        return fig.to_html(include_plotlyjs=include_plotlyjs)