-
Notifications
You must be signed in to change notification settings - Fork 191
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add amazon chronos benchmark (#257)
- Loading branch information
Showing
7 changed files
with
667 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# A Statistical Ensemble of traditional methods is 10% more accurate and 5x faster than Amazon Chronos | ||
|
||
We present a comprehensive evaluation showcasing that a Statistical Ensemble, consisting of AutoARIMA, AutoETS, AutoCES, and DynamicOptimizedTheta, outperforms Amazon Chronos—a foundational model for time series forecasting with over 710 million parameters. Specifically, the **Statistical Ensemble demonstrates 10%, 10%, and 11% superior performance in CRPS, MASE, and SMAPE metrics, respectively**, and it is **5x faster**. This analysis spans over 50,000 unique time series across M1, M3, M4, and Tourism datasets, robustly comparing these models. | ||
|
||
# Introduction | ||
|
||
The rise of foundational models in time series forecasting, such as Amazon Chronos, represents a significant leap forward, leveraging deep learning and massive datasets for model pre-training to enhance predictive accuracy. Amazon Chronos, in particular, is noteworthy for its extensive parameterization and ambitious scope. However, our study shows that a comparatively simpler approach, employing a Statistical Ensemble of traditional forecasting methods, yields better accuracy and computational efficiency. | ||
|
||
## Empirical Evaluation | ||
|
||
This study considers over 50,000 unique time series from the M1, M3, M4, and Tourism datasets, spanning various time series frequencies. Chronos did not use these datasets in the training phase. We have also included comparisons to the Seasonal Naive model to provide a benchmark for traditional forecasting methods. | ||
|
||
## Results | ||
|
||
Our findings are shown in the following table, showcasing the performance across different metrics: CRPS, MASE, SMAPE, and computational time (in seconds). The best results are highlighted in **bold** for ease of reference. | ||
|
||
<img width="1099" alt="image" src="https://github.com/Nixtla/nixtla/assets/10517170/4d4fe9f3-4251-4b95-bd9b-248fc283e97b"> | ||
|
||
|
||
## Reproducibility | ||
|
||
To ensure the reproducibility of our findings, the Statistical Ensemble experiments were conducted on an AWS c5a.24xlarge instance, equipped with 96 vCPUs and 192 GiB of RAM. In contrast, the experiments for Amazon Chronos were carried out on an AWS g5.4xlarge GPU instance, which includes 16 vCPUs, 64 GiB of RAM, and an NVIDIA A10G Tensor Core GPU with 24 GiB. All necessary code and detailed instructions for reproducing the experiments are available in this directory. | ||
|
||
### Instructions | ||
|
||
1. Set up a Python environment: | ||
|
||
```bash | ||
mamba env create -f environment.yml | ||
conda activate amazon-chronos | ||
``` | ||
|
||
2. Run the experiments as reported in the table: | ||
|
||
```bash | ||
python -m src.main --mode fcst_statsforecast | ||
python -m src.main --mode fcst_chronos | ||
``` | ||
|
||
3. Evaluate the results using: | ||
|
||
```bash | ||
python -m src.main --mode evaluation | ||
``` | ||
|
||
### References | ||
- **Statistical Ensemble Paper**: [A Simple Combination of Univariate Models](https://www.sciencedirect.com/science/article/abs/pii/S0169207019300585?via%3Dihub) | ||
- **Amazon Chronos Paper**: [Chronos: Learning the Language of Time Series](https://arxiv.org/abs/2403.07815) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
name: amazon-chronos | ||
channels: | ||
- conda-forge | ||
- defaults | ||
- anaconda | ||
dependencies: | ||
- jupyterlab | ||
- pip | ||
- python=3.10 | ||
- pip: | ||
- datasetsforecast | ||
- fire | ||
- gluonts | ||
- huggingface_hub[cli] | ||
- neuralforecast | ||
- orjson | ||
- statsforecast | ||
- utilsforecast | ||
- git+https://github.com/amazon-science/chronos-forecasting.git | ||
|
122 changes: 122 additions & 0 deletions
122
experiments/amazon-chronos/src/amazon_chronos/forecaster.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import logging | ||
from typing import Iterable, List | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import torch | ||
from chronos import ChronosPipeline | ||
from utilsforecast.processing import make_future_dataframe | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
main_logger = logging.getLogger(__name__) | ||
|
||
|
||
class TimeSeriesDataset: | ||
def __init__( | ||
self, | ||
data: torch.Tensor, | ||
uids: Iterable, | ||
last_times: Iterable, | ||
batch_size: int, | ||
): | ||
self.data = data | ||
self.uids = uids | ||
self.last_times = last_times | ||
self.batch_size = batch_size | ||
self.n_batches = len(data) // self.batch_size + ( | ||
0 if len(data) % self.batch_size == 0 else 1 | ||
) | ||
self.current_batch = 0 | ||
|
||
@classmethod | ||
def from_df(cls, df: pd.DataFrame, batch_size: int): | ||
num_unique_ids = df["unique_id"].nunique() | ||
max_series_length = df["unique_id"].value_counts().max() | ||
padded_tensor = torch.full( | ||
size=(num_unique_ids, max_series_length), | ||
fill_value=torch.nan, | ||
dtype=torch.bfloat16, | ||
) # type: ignore | ||
df_sorted = df.sort_values(by=["unique_id", "ds"]) | ||
for idx, (_, group) in enumerate(df_sorted.groupby("unique_id")): | ||
series_length = len(group) | ||
padded_tensor[idx, -series_length:] = torch.tensor( | ||
group["y"].values, | ||
dtype=torch.bfloat16, | ||
) | ||
uids = df_sorted["unique_id"].unique() | ||
last_times = df_sorted.groupby("unique_id")["ds"].tail(1) | ||
return cls(padded_tensor, uids, last_times, batch_size) | ||
|
||
def __len__(self): | ||
return len(self.data) | ||
|
||
def make_future_dataframe(self, h: int, freq: str) -> pd.DataFrame: | ||
return make_future_dataframe( | ||
uids=self.uids, | ||
last_times=pd.to_datetime(self.last_times), | ||
h=h, | ||
freq=freq, | ||
) # type: ignore | ||
|
||
def __iter__(self): | ||
self.current_batch = 0 # Reset for new iteration | ||
return self | ||
|
||
def __next__(self): | ||
if self.current_batch < self.n_batches: | ||
start_idx = self.current_batch * self.batch_size | ||
end_idx = start_idx + self.batch_size | ||
self.current_batch += 1 | ||
return self.data[start_idx:end_idx] | ||
else: | ||
raise StopIteration | ||
|
||
|
||
class AmazonChronos: | ||
def __init__(self, model_name: str): | ||
self.model_name = model_name | ||
self.model = ChronosPipeline.from_pretrained( | ||
model_name, | ||
device_map="auto", | ||
torch_dtype=torch.bfloat16, | ||
) | ||
|
||
def forecast( | ||
self, | ||
df: pd.DataFrame, | ||
h: int, | ||
freq: str, | ||
batch_size: int = 32, | ||
quantiles: List[float] | None = None, | ||
**predict_kwargs, | ||
) -> pd.DataFrame: | ||
main_logger.info("transforming dataframe to tensor") | ||
dataset = TimeSeriesDataset.from_df(df, batch_size=batch_size) | ||
main_logger.info("forecasting") | ||
fcsts = [self.model.predict(batch, prediction_length=h, **predict_kwargs) for batch in dataset] | ||
fcst = torch.cat(fcsts) | ||
main_logger.info("transforming forecast to dataframe") | ||
fcst = fcst.numpy() | ||
fcst_df = dataset.make_future_dataframe(h=h, freq=freq) | ||
fcst_df[self.model_name] = np.median(fcst, axis=1).reshape(-1, 1) | ||
if quantiles is not None: | ||
for q in quantiles: | ||
q_col = f"{self.model_name}-q-{q}" | ||
fcst_df[q_col] = np.quantile(fcst, q, axis=1).reshape(-1, 1) | ||
return fcst_df | ||
|
||
|
||
if __name__ == "__main__": | ||
import pandas as pd | ||
|
||
df = pd.read_csv( | ||
"https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv" | ||
) | ||
df = df.rename(columns={"#Passengers": "y", "Month": "ds"}) | ||
df["ds"] = pd.to_datetime(df["ds"]) | ||
df.insert(0, "unique_id", "AirPassengers") | ||
df = pd.concat([df, df.assign(unique_id="AirPassengers2")]) | ||
model = AmazonChronos("amazon/chronos-t5-small") | ||
fcst_df = model.forecast(df, h=12, freq="MS") | ||
print(fcst_df) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import os | ||
from time import time | ||
from typing import List, Tuple | ||
|
||
import fire | ||
import pandas as pd | ||
|
||
|
||
from ..utils import ExperimentHandler | ||
from .forecaster import AmazonChronos | ||
|
||
|
||
def run_amazon_chronos( | ||
train_df: pd.DataFrame, | ||
model_name: str, | ||
horizon: int, | ||
freq: str, | ||
quantiles: List[float], | ||
) -> Tuple[pd.DataFrame, float, str]: | ||
ac = AmazonChronos(model_name) | ||
init_time = time() | ||
fcsts_df = ac.forecast( | ||
df=train_df, | ||
h=horizon, | ||
freq=freq, | ||
batch_size=8, | ||
quantiles=quantiles, | ||
# parameters as in https://github.com/amazon-science/chronos-forecasting/blob/73be25042f5f587823d46106d372ba133152fb00/README.md?plain=1#L62-L65 | ||
num_samples=20, | ||
temperature=1.0, | ||
top_k=50, | ||
top_p=1.0, | ||
) | ||
total_time = time() - init_time | ||
return fcsts_df, total_time, model_name | ||
|
||
|
||
def main(dataset: str, model_name: str): | ||
exp = ExperimentHandler(dataset) | ||
fcst_df, total_time, model_name = run_amazon_chronos( | ||
train_df=exp.train_df, | ||
model_name=model_name, | ||
horizon=exp.horizon, | ||
freq=exp.freq, | ||
quantiles=exp.quantiles, | ||
) | ||
exp.save_results(fcst_df, total_time, model_name) | ||
|
||
|
||
if __name__ == "__main__": | ||
fire.Fire(main) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import logging | ||
import subprocess | ||
|
||
import fire | ||
import pandas as pd | ||
|
||
from src.utils import ExperimentHandler | ||
|
||
logger = logging.getLogger(__name__) | ||
logger.setLevel(logging.INFO) | ||
|
||
datasets = [ | ||
"m1_yearly", | ||
"m1_quarterly", | ||
"m1_monthly", | ||
"m3_yearly", | ||
"m3_quarterly", | ||
"m3_monthly", | ||
"m3_other", | ||
"tourism_yearly", | ||
"tourism_quarterly", | ||
"tourism_monthly", | ||
"m4_yearly", | ||
"m4_quarterly", | ||
] | ||
|
||
amazon_chronos_models = [ | ||
"amazon/chronos-t5-large", | ||
"amazon/chronos-t5-tiny", | ||
"amazon/chronos-t5-mini", | ||
"amazon/chronos-t5-small", | ||
"amazon/chronos-t5-base", | ||
] | ||
|
||
|
||
def main(mode: str): | ||
prefix_process = ["python", "-m"] | ||
|
||
eval_df = None | ||
for dataset in datasets: | ||
logger.info(f"Evaluating {dataset}...") | ||
if mode in ["fcst_statsforecast", "fcst_chronos"]: | ||
suffix_process = ["--dataset", dataset] | ||
|
||
def process(middle_process): | ||
return prefix_process + middle_process + suffix_process | ||
|
||
if mode == "fcst_statsforecast": | ||
logger.info("Running StatisticalEnsemble") | ||
subprocess.run(process(["src.statsforecast_pipeline"])) | ||
elif mode == "fcst_chronos": | ||
for model in amazon_chronos_models: | ||
logger.info(f"Running Amazon Chronos {model}") | ||
chronos_process = process(["src.amazon_chronos.pipeline"]) | ||
chronos_process.extend(["--model_name", model]) | ||
subprocess.run(chronos_process) | ||
elif mode == "evaluation": | ||
if eval_df is None: | ||
eval_df = [] | ||
logger.info("Running dataset evaluation") | ||
exp = ExperimentHandler(dataset) | ||
try: | ||
eval_dataset_df = exp.evaluate_models( | ||
amazon_chronos_models + ["StatisticalEnsemble", "SeasonalNaive"] | ||
) | ||
print(eval_dataset_df) | ||
eval_df.append(eval_dataset_df) | ||
except Exception as e: | ||
logger.error(e) | ||
if eval_df is not None: | ||
eval_df = pd.concat(eval_df).reset_index(drop=True) | ||
exp.save_dataframe(eval_df, "complete-results.csv") | ||
|
||
|
||
if __name__ == "__main__": | ||
fire.Fire(main) |
Oops, something went wrong.