|
|
""" |
|
|
Chart generation for forecast visualization |
|
|
""" |
|
|
|
|
|
from typing import Any, List, Optional

import pandas as pd
import plotly.graph_objs as go
from plotly.subplots import make_subplots

from config.constants import COLORS, CHART_CONFIG
|
|
|
|
|
|
|
|
def _hover_template(value_label: str) -> str:
    """Return the hover template shared by the line traces: date plus y value to 2 decimals."""
    # The empty <extra></extra> tag hides Plotly's secondary (trace-name) hover box.
    return f'<b>Date:</b> %{{x}}<br><b>{value_label}:</b> %{{y:.2f}}<extra></extra>'


def _add_backtest_traces(fig: go.Figure, backtest_data: pd.DataFrame, y_axis_label: str) -> None:
    """Overlay backtest actual vs. predicted lines (reads 'timestamp', 'actual', 'predicted')."""
    fig.add_trace(go.Scatter(
        x=backtest_data['timestamp'],
        y=backtest_data['actual'],
        mode='lines',
        name='Backtest Actual',
        line=dict(color='rgba(100, 100, 100, 0.6)', width=2, dash='dot'),
        hovertemplate=_hover_template(f'{y_axis_label} (Actual)'),
    ))
    fig.add_trace(go.Scatter(
        x=backtest_data['timestamp'],
        y=backtest_data['predicted'],
        mode='lines',
        name='Backtest Predicted',
        line=dict(color='rgba(255, 100, 100, 0.8)', width=2),
        hovertemplate=_hover_template(f'{y_axis_label} (Predicted)'),
    ))


def _add_confidence_bands(fig: go.Figure, forecast_data: pd.DataFrame, confidence_levels: List[int]) -> None:
    """Add one filled band per confidence level that has matching lower/upper columns."""
    dates = forecast_data['ds'].tolist()
    # Highest level first (presumably the widest band), so narrower bands are drawn on top.
    for cl in sorted(confidence_levels, reverse=True):
        lower_col = f'lower_{cl}'
        upper_col = f'upper_{cl}'
        if lower_col not in forecast_data.columns or upper_col not in forecast_data.columns:
            continue  # no bound columns for this level: nothing to draw

        # Walk the upper bound forward and the lower bound backward so that
        # fill='toself' closes a single polygon enclosing the band.
        fig.add_trace(go.Scatter(
            x=dates + dates[::-1],
            y=forecast_data[upper_col].tolist() + forecast_data[lower_col].tolist()[::-1],
            fill='toself',
            fillcolor=COLORS['confidence'][cl],
            line=dict(width=0),
            name=f'{cl}% Confidence',
            showlegend=True,
            hoverinfo='skip',
        ))


def _add_forecast_start_marker(fig: go.Figure, start_date) -> None:
    """Draw a dashed full-height vertical line labelled 'Forecast Start' at ``start_date``."""
    fig.add_shape(
        type="line",
        x0=start_date,
        x1=start_date,
        y0=0,
        y1=1,
        yref="paper",  # span the whole plot height regardless of the y data range
        line=dict(color=COLORS['separator'], dash="dash", width=1),
    )
    fig.add_annotation(
        x=start_date,
        y=1.0,
        yref="paper",
        text="Forecast Start",
        showarrow=False,
        yanchor="bottom",
    )


def create_forecast_chart(
    historical_data: pd.DataFrame,
    forecast_data: pd.DataFrame,
    confidence_levels: List[int],
    title: str = "Time Series Forecast",
    y_axis_label: str = "Value",
    backtest_data: Optional[pd.DataFrame] = None
) -> go.Figure:
    """
    Create an interactive forecast chart with confidence intervals

    Traces are added in this order (which fixes legend order and drawing
    order): historical line, optional backtest actual/predicted lines,
    confidence bands, forecast line. A dashed "Forecast Start" marker is
    placed at the date in the last row of ``historical_data``.

    Args:
        historical_data: DataFrame with columns ['ds', 'y']. The last row is
            taken as the forecast start, so this presumably should be sorted
            by 'ds' (not checked here).
        forecast_data: DataFrame with columns 'ds' and 'forecast', plus
            optional 'lower_<cl>' / 'upper_<cl>' columns for each confidence level
        confidence_levels: List of confidence levels to plot. Levels without
            matching bound columns are skipped. Each plotted level must be a
            key of COLORS['confidence'].
        title: Chart title
        y_axis_label: Label for y-axis (variable name being forecasted)
        backtest_data: Optional DataFrame with backtest results, with columns
            'timestamp', 'actual' and 'predicted'. It is ignored when None or
            when it has no rows.

    Returns:
        Plotly figure
    """
    fig = go.Figure()

    fig.add_trace(go.Scatter(
        x=historical_data['ds'],
        y=historical_data['y'],
        mode='lines',
        name='Historical',
        line=dict(color=COLORS['historical'], width=2),
        hovertemplate=_hover_template(y_axis_label),
    ))

    if backtest_data is not None and len(backtest_data) > 0:
        _add_backtest_traces(fig, backtest_data, y_axis_label)

    _add_confidence_bands(fig, forecast_data, confidence_levels)

    fig.add_trace(go.Scatter(
        x=forecast_data['ds'],
        y=forecast_data['forecast'],
        mode='lines',
        name='Forecast',
        line=dict(color=COLORS['forecast'], width=2),
        hovertemplate=_hover_template(f'{y_axis_label} (Forecast)'),
    ))

    if len(historical_data) > 0:
        _add_forecast_start_marker(fig, historical_data['ds'].iloc[-1])

    fig.update_layout(
        title=dict(text=title, x=0.5, xanchor='center'),
        xaxis_title="Date",
        yaxis_title=y_axis_label,
        hovermode='x unified',
        template='plotly_white',
        height=700,
        showlegend=True,
        # Horizontal legend just above the plot area, right-aligned.
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
        ),
        margin=dict(l=50, r=50, t=80, b=150),
        xaxis=dict(
            rangeslider=dict(visible=True, thickness=0.12),
            type='date',
        ),
    )

    # Extra modebar buttons: hover-mode toggle and spike-line toggle.
    fig.update_layout(
        modebar_add=['v1hovermode', 'toggleSpikelines']
    )

    return fig
|
|
|
|
|
|
|
|
def create_empty_chart(message: str = "No data available") -> go.Figure:
    """
    Build a blank placeholder figure that shows only a text message.

    Both axes are hidden. The message is centred in paper coordinates, which
    means relative to the plotting area rather than to any data range.

    Args:
        message: Text shown in the middle of the chart

    Returns:
        Plotly figure with no traces
    """
    placeholder = go.Figure()

    centre_of_plot = dict(xref="paper", yref="paper", x=0.5, y=0.5)
    placeholder.add_annotation(
        text=message,
        showarrow=False,
        font=dict(size=20, color='gray'),
        **centre_of_plot,
    )

    placeholder.update_layout(
        template='plotly_white',
        height=600,
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
    )
    return placeholder
|
|
|
|
|
|
|
|
def create_metrics_display(metrics: dict, inference_time: Optional[float] = None) -> list:
    """
    Create metrics display components

    Builds one Bootstrap column containing a centred card per available value.
    The inference-time card comes first when ``inference_time`` is given. It
    is followed by MAE, RMSE, MAPE and R2, in that order. Metrics that are
    absent from ``metrics`` or set to None are skipped.

    Args:
        metrics: Dictionary of metric values, looked up by the keys
            'MAE', 'RMSE', 'MAPE' and 'R2'
        inference_time: Time taken for inference in seconds. Pass None to
            omit the card.

    Returns:
        List of Dash components (``dbc.Col`` elements each wrapping a card)
    """
    import dash_bootstrap_components as dbc
    from dash import html

    def _card(label: str, value_text: str):
        """Wrap a label/value pair in a centred card inside a width-2 column."""
        return dbc.Col([
            dbc.Card([
                dbc.CardBody([
                    html.H6(label, className="text-muted mb-2"),
                    html.H4(value_text),
                ])
            ], className="text-center")
        ], md=2)

    metric_cards = []

    if inference_time is not None:
        metric_cards.append(_card("Inference Time", f"{inference_time:.2f}s"))

    # (metrics key, display name, format template), in display order.
    metric_specs = (
        ('MAE', 'Mean Absolute Error', '{:.2f}'),
        ('RMSE', 'Root Mean Squared Error', '{:.2f}'),
        ('MAPE', 'Mean Absolute % Error', '{:.2f}%'),
        ('R2', 'R-Squared', '{:.4f}'),
    )

    for key, name, template in metric_specs:
        value = metrics.get(key)
        if value is None:
            continue  # metric not computed: don't render an empty card
        metric_cards.append(_card(name, template.format(value)))

    return metric_cards
|
|
|
|
|
|
|
|
def create_backtest_metrics_display(metrics: dict) -> Any:
    """
    Create backtest metrics display components

    Shows MAE, RMSE, MAPE and R² side by side in a single card. A metric that
    is absent from ``metrics`` or set to None is displayed as "N/A". It is
    never formatted, because formatting None with a numeric spec raises
    TypeError.

    Args:
        metrics: Dictionary of backtest metric values (MAE, RMSE, MAPE, R2)

    Returns:
        Dash component card (a ``dbc.Card``)
    """
    import dash_bootstrap_components as dbc
    from dash import html

    def _format_metric(key: str, spec: str, suffix: str = "") -> str:
        """Format metrics[key] with the given format spec, or 'N/A' when unavailable."""
        value = metrics.get(key)
        if value is None:
            return "N/A"
        return f"{value:{spec}}{suffix}"

    # (display label, metrics key, format spec, suffix), in display order.
    metric_columns = (
        ("MAE", "MAE", ".2f", ""),
        ("RMSE", "RMSE", ".2f", ""),
        ("MAPE", "MAPE", ".2f", "%"),
        ("R²", "R2", ".4f", ""),
    )

    return dbc.Card([
        dbc.CardHeader([
            html.I(className="fas fa-chart-bar me-2"),
            html.Span("Backtest Performance Metrics", className="fw-bold"),
        ]),
        dbc.CardBody([
            html.P("Model performance on historical data validation:", className="text-muted small mb-3"),
            dbc.Row([
                dbc.Col([
                    html.Small(label, className="text-muted"),
                    html.H5(_format_metric(key, spec, suffix), className="mb-0"),
                ], md=3)
                for label, key, spec, suffix in metric_columns
            ]),
            html.Hr(),
            html.Small([
                html.I(className="fas fa-info-circle me-1"),
                "Lower MAE/RMSE/MAPE and higher R² (closer to 1.0) indicate better model performance",
            ], className="text-muted"),
        ]),
    ], className="mt-3")
|
|
|
|
|
|
|
|
def decimate_data(df: pd.DataFrame, max_points: int = 10000) -> pd.DataFrame:
    """
    Reduce number of data points for visualization

    Keeps every ``step``-th row starting from the first, where
    ``step = ceil(len(df) / max_points)``. This guarantees the result has at
    most ``max_points`` rows. The final row is not necessarily kept.

    Args:
        df: Input DataFrame
        max_points: Maximum number of points to keep (must be >= 1)

    Returns:
        ``df`` itself (original index intact) when it already has at most
        ``max_points`` rows. Otherwise, a decimated DataFrame with a fresh
        0..n-1 index.

    Raises:
        ValueError: if ``max_points`` is less than 1
    """
    if max_points < 1:
        # Guard: 0 would divide by zero and a negative value would yield a
        # negative step, silently reversing the frame instead of thinning it.
        raise ValueError(f"max_points must be >= 1, got {max_points}")

    n_rows = len(df)
    if n_rows <= max_points:
        return df

    # Ceiling division. Floor division gives step == 1 whenever
    # n_rows < 2 * max_points (no decimation at all) and can return up to
    # nearly twice max_points rows in general.
    step = (n_rows + max_points - 1) // max_points
    return df.iloc[::step].reset_index(drop=True)
|
|
|