Spaces:
Sleeping
Sleeping
| from typing import List | |
| import numpy as np | |
| import pandas as pd | |
| import plotly.graph_objects as go | |
def plot_train_test(df1: pd.DataFrame, df2: pd.DataFrame) -> go.Figure:
    """
    Plot the training and test datasets using Plotly.

    Only the first column of each DataFrame is plotted, against that
    DataFrame's index.

    Args:
        df1 (pd.DataFrame): Train dataset; its first column is drawn in steelblue.
        df2 (pd.DataFrame): Test dataset; its first column is drawn in gold.

    Returns:
        go.Figure: Plotly figure with one line trace per dataset.
    """
    # Create a Plotly figure
    fig = go.Figure()

    # Training series, drawn in steelblue.
    fig.add_trace(
        go.Scatter(
            x=df1.index,
            y=df1.iloc[:, 0],
            mode="lines",
            name="Training Data",
            line=dict(color="steelblue"),
            # Markers are not rendered in "lines" mode; the colour is kept so
            # the trace's marker settings stay in sync with the line colour.
            marker=dict(color="steelblue"),
        )
    )

    # Test series, drawn in gold.
    fig.add_trace(
        go.Scatter(
            x=df2.index,
            y=df2.iloc[:, 0],
            mode="lines",
            name="Test Data",
            line=dict(color="gold"),
            marker=dict(color="gold"),
        )
    )

    # Customize the layout
    fig.update_layout(
        title="Univariate Time Series",
        xaxis=dict(title="Date"),
        yaxis=dict(title="Value"),
        showlegend=True,
        template="plotly_white",
    )
    return fig
def _as_timestamps(index) -> pd.DatetimeIndex:
    """
    Convert an index into timestamps suitable for a Plotly x-axis.

    ``pd.to_datetime`` raises a TypeError for period data, so a
    ``pd.PeriodIndex`` is converted with ``to_timestamp()`` instead; any
    other index is passed through ``pd.to_datetime``.
    """
    if isinstance(index, pd.PeriodIndex):
        return index.to_timestamp()
    return pd.to_datetime(index)


def plot_forecast(df: pd.DataFrame, forecasts: List[pd.DataFrame]) -> "go.Figure":
    """
    Plot the true values and sampled forecasts using Plotly.

    Args:
        df (pd.DataFrame): DataFrame with the true values. Its first column is
            plotted against the index and also provides the figure title and
            y-axis label. The index may be datetime-like or a PeriodIndex.
        forecasts (List[pd.DataFrame]): Forecasts to overlay. Despite the
            annotation, each item is only accessed through a ``.samples``
            attribute (iterated as one y-series per sample path, and averaged
            over axis 0 for the mean) and an ``.index`` attribute (the x
            positions, PeriodIndex or datetime-like). NOTE(review): these look
            like GluonTS-style sample forecasts with a
            (num_samples, prediction_length) array -- confirm against callers.

    Returns:
        go.Figure: Plotly figure object.
    """
    # Create a Plotly figure
    fig = go.Figure()

    # True values, drawn in black.
    fig.add_trace(
        go.Scatter(
            x=_as_timestamps(df.index),
            y=df.iloc[:, 0],
            mode="lines",
            name="True values",
            line=dict(color="black"),
        )
    )

    # Each forecast's sample paths use one colour, cycling through this list.
    colors = ["green", "blue", "purple"]
    for i, forecast in enumerate(forecasts):
        color = colors[i % len(colors)]
        # The x positions and samples are shared by every path of this
        # forecast, so fetch/convert them once rather than once per sample.
        x_values = _as_timestamps(forecast.index)
        samples = forecast.samples

        for sample in samples:
            fig.add_trace(
                go.Scatter(
                    x=x_values,
                    y=sample,
                    mode="lines",
                    opacity=0.15,  # Faint lines so overlapping samples show density
                    name=f"Forecast {i + 1}",
                    showlegend=False,  # Hide the individual sample paths from the legend
                    hoverinfo="none",  # Disable hover information for the sample paths
                    line=dict(color=color),
                )
            )

        # Sample mean of this forecast; only the first one gets a legend entry
        # so "Mean Forecast" appears once regardless of len(forecasts).
        mean_forecast = np.mean(samples, axis=0)
        fig.add_trace(
            go.Scatter(
                x=x_values,
                y=mean_forecast,
                mode="lines",
                name="Mean Forecast",
                line=dict(color="red", dash="dash"),
                legendgroup="mean forecast",
                showlegend=i == 0,
            )
        )

    # Customize the layout
    fig.update_layout(
        title=f"{df.columns[0]} Forecast",
        yaxis=dict(title=df.columns[0]),
        showlegend=True,
        legend=dict(x=0, y=1),
        hovermode="x",  # Enable x-axis hover for better interactivity
    )
    return fig