
Quick Start Tutorial

GluonTS contains:

  • A number of pre-built models

  • Components for building new models (likelihoods, feature processing pipelines, calendar features etc.)

  • Data loading and processing

  • Plotting and evaluation facilities

  • Artificial and real datasets (only external datasets with an approved license)

[1]:
%matplotlib inline
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import json

Datasets

Provided datasets

GluonTS comes with a number of publicly available datasets.

[2]:
from gluonts.dataset.repository import get_dataset, dataset_names
from gluonts.dataset.util import to_pandas
[3]:
print(f"Available datasets: {dataset_names}")
Available datasets: ['constant', 'exchange_rate', 'solar-energy', 'electricity', 'traffic', 'exchange_rate_nips', 'electricity_nips', 'traffic_nips', 'solar_nips', 'wiki2000_nips', 'wiki-rolling_nips', 'taxi_30min', 'kaggle_web_traffic_with_missing', 'kaggle_web_traffic_without_missing', 'kaggle_web_traffic_weekly', 'm1_yearly', 'm1_quarterly', 'm1_monthly', 'nn5_daily_with_missing', 'nn5_daily_without_missing', 'nn5_weekly', 'tourism_monthly', 'tourism_quarterly', 'tourism_yearly', 'cif_2016', 'london_smart_meters_without_missing', 'wind_farms_without_missing', 'car_parts_without_missing', 'dominick', 'fred_md', 'pedestrian_counts', 'hospital', 'covid_deaths', 'kdd_cup_2018_without_missing', 'weather', 'm3_monthly', 'm3_quarterly', 'm3_yearly', 'm3_other', 'm4_hourly', 'm4_daily', 'm4_weekly', 'm4_monthly', 'm4_quarterly', 'm4_yearly', 'm5', 'uber_tlc_daily', 'uber_tlc_hourly', 'airpassengers', 'australian_electricity_demand', 'electricity_hourly', 'electricity_weekly', 'rideshare_without_missing', 'saugeenday', 'solar_10_minutes', 'solar_weekly', 'sunspot_without_missing', 'temperature_rain_without_missing', 'vehicle_trips_without_missing', 'ercot', 'ett_small_15min', 'ett_small_1h']

To download one of the built-in datasets, call get_dataset with one of the names above. GluonTS keeps the downloaded dataset on disk, so it does not need to be downloaded again the next time around.

[4]:
dataset = get_dataset("m4_hourly")

In general, the datasets provided by GluonTS are objects that consist of three main members:

  • dataset.train is an iterable collection of data entries used for training. Each entry corresponds to one time series.

  • dataset.test is an iterable collection of data entries used for inference. The test dataset is an extended version of the train dataset: each time series additionally contains a window at the end that was not seen during training. This window has length equal to the recommended prediction length.

  • dataset.metadata contains metadata of the dataset such as the frequency of the time series, a recommended prediction horizon, associated features, etc.
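The relationship between dataset.train and dataset.test described above can be checked in plain Python. This is a minimal sketch, assuming each entry is a dict with a “target” sequence and that train and test entries come in the same order; the helper check_train_test_split and the toy entries are hypothetical.

```python
from typing import Iterable, Mapping


def check_train_test_split(
    train: Iterable[Mapping], test: Iterable[Mapping], prediction_length: int
) -> bool:
    """Return True if each test target extends its train target by exactly
    prediction_length values and leaves the shared prefix unchanged."""
    train_entries = list(train)
    test_entries = list(test)
    if len(train_entries) != len(test_entries):
        return False
    for train_entry, test_entry in zip(train_entries, test_entries):
        train_target = list(train_entry["target"])
        test_target = list(test_entry["target"])
        # The test series must be longer by exactly one forecast window ...
        if len(test_target) - len(train_target) != prediction_length:
            return False
        # ... and must start with the training series unchanged.
        if test_target[: len(train_target)] != train_target:
            return False
    return True


# Toy entries (made up for illustration): one series, prediction_length = 2
toy_train = [{"start": "2019-01-01 00:00", "target": [1.0, 2.0, 3.0]}]
toy_test = [{"start": "2019-01-01 00:00", "target": [1.0, 2.0, 3.0, 4.0, 5.0]}]
print(check_train_test_split(toy_train, toy_test, prediction_length=2))  # True
```

Applied to the splits loaded above, e.g. check_train_test_split(dataset.train, dataset.test, dataset.metadata.prediction_length), we would expect True if the splits are built as described; series containing NaN values would compare unequal under this simple equality check.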

[5]:
entry = next(iter(dataset.train))
train_series = to_pandas(entry)
train_series.plot()
plt.grid(which="both")
plt.legend(["train series"], loc="upper left")
plt.show()
[Figure: the first time series of dataset.train, labelled “train series”]
[6]:
entry = next(iter(dataset.test))
test_series = to_pandas(entry)
test_series.plot()
plt.axvline(train_series.index[-1], color="r")  # end of train dataset
plt.grid(which="both")
plt.legend(["test series", "end of train series"], loc="upper left")
plt.show()
[Figure: the first time series of dataset.test, with a red vertical line marking the end of the train series]
[7]:
print(
    f"Length of forecasting window in test dataset: {len(test_series) - len(train_series)}"
)
print(f"Recommended prediction horizon: {dataset.metadata.prediction_length}")
print(f"Frequency of the time series: {dataset.metadata.freq}")
Length of forecasting window in test dataset: 48
Recommended prediction horizon: 48
Frequency of the time series: H

Custom datasets

It is important to emphasize that GluonTS does not require this specific format for a custom dataset. The only requirements are that the dataset is iterable and that each entry has a “target” and a “start” field. To make this clearer, consider the common case where a dataset is stored as a numpy.array and the start of each time series is given by a pandas.Period (possibly different for each time series):
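Since iterability and the two fields are all that is required, even a small class whose __iter__ yields dicts qualifies. This is a minimal sketch, assuming the series are held in memory as Python lists; the class name InMemorySeries is hypothetical, and the “start” values are plain strings only to keep the sketch dependency-free (this tutorial uses pandas.Period).

```python
from typing import Dict, Iterator, List


class InMemorySeries:
    """Hypothetical minimal dataset: re-iterable, yields one dict per series."""

    def __init__(self, series: List[List[float]], starts: List[str]) -> None:
        if len(series) != len(starts):
            raise ValueError("need exactly one start per series")
        self._series = series
        self._starts = starts

    def __iter__(self) -> Iterator[Dict[str, object]]:
        # A fresh generator on every call, so the dataset can be
        # iterated more than once.
        for target, start in zip(self._series, self._starts):
            yield {"start": start, "target": target}


ds = InMemorySeries(
    series=[[1.0, 2.0, 3.0], [10.0, 20.0]],
    starts=["2019-01-01 00:00", "2019-03-01 00:00"],  # may differ per series
)
print([len(entry["target"]) for entry in ds])  # [3, 2]
print(sum(1 for _ in ds))  # 2 (a second pass still works)
```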

[8]:
N = 10  # number of time series
T = 100  # number of timesteps
prediction_length = 24
freq = "1h"
custom_dataset = np.random.normal(size=(N, T))
start = pd.Period("01-01-2019", freq=freq)  # can be different for each time series

Now you can split your dataset and bring it into a GluonTS-compatible format with just two statements:

[9]:
from gluonts.dataset.common import ListDataset
[10]:
# train dataset: cut the last window of length "prediction_length", add "target" and "start" fields
train_ds = ListDataset(
    [{"target": x, "start": start} for x in custom_dataset[:, :-prediction_length]],
    freq=freq,
)
# test dataset: use the whole dataset, add "target" and "start" fields
test_ds = ListDataset(
    [{"target": x, "start": start} for x in custom_dataset], freq=freq
)

Training an existing model (Estimator)

GluonTS comes with a number of pre-built models. All the user needs to do is configure some hyperparameters. The existing models focus on (but are not limited to) probabilistic forecasting. Probabilistic forecasts are predictions in the form of a probability distribution, rather than simply a single point estimate.

We will begin with GluonTS’s pre-built feedforward neural network estimator, a simple but powerful forecasting model. We will use this model to demonstrate the process of training a model, producing forecasts, and evaluating the results.

GluonTS’s built-in feedforward neural network (SimpleFeedForwardEstimator) accepts an input window of length context_length and predicts the distribution of the subsequent prediction_length values. In GluonTS parlance, the feedforward neural network model is an example of an Estimator. An Estimator object represents a forecasting model together with the configuration needed to train it; the learned parameters (coefficients, weights, etc.) end up in the Predictor that training returns.
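To make the input and output windows concrete, the sketch below enumerates every (context, future) pair of lengths context_length and prediction_length in a single series. This is a simplified, deterministic illustration under our own assumptions: GluonTS’s training pipeline samples such windows through its own transformations rather than enumerating them, and make_windows is a hypothetical name.

```python
from typing import List, Sequence, Tuple


def make_windows(
    series: Sequence[float], context_length: int, prediction_length: int
) -> List[Tuple[List[float], List[float]]]:
    """Enumerate every (context, future) pair that fits inside `series`.

    The model sees `context` (context_length values) and is trained to
    predict the distribution of `future` (the next prediction_length values).
    """
    window = context_length + prediction_length
    pairs = []
    for begin in range(len(series) - window + 1):
        context = list(series[begin : begin + context_length])
        future = list(series[begin + context_length : begin + window])
        pairs.append((context, future))
    return pairs


pairs = make_windows(list(range(10)), context_length=4, prediction_length=2)
print(len(pairs))  # 5
print(pairs[0])  # ([0, 1, 2, 3], [4, 5])
```

With context_length=100 and the 48-step prediction_length of m4_hourly used in the estimator below, each window spans 148 consecutive hourly values.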

In general, each estimator (pre-built or custom) is configured by a number of hyperparameters that can be either common (but not binding) among all estimators (e.g., the prediction_length) or specific to the particular estimator (e.g., the number of layers for a neural network or the stride in a CNN).

Finally, each estimator is configured by a Trainer, which defines how the model will be trained, i.e., the number of epochs, the learning rate, etc.

[11]:
from gluonts.mx import SimpleFeedForwardEstimator, Trainer
[12]:
estimator = SimpleFeedForwardEstimator(
    num_hidden_dimensions=[10],
    prediction_length=dataset.metadata.prediction_length,
    context_length=100,
    trainer=Trainer(ctx="cpu", epochs=5, learning_rate=1e-3, num_batches_per_epoch=100),
)

After specifying our estimator with all the necessary hyperparameters, we can train it on the training dataset dataset.train by invoking the estimator’s train method. Training returns a fitted model (a Predictor in GluonTS parlance) that can be used to construct forecasts.
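The estimator/predictor split can be illustrated without any neural network. The toy classes below are hypothetical and only mirror the shape of the pattern: a configuration object whose train method returns a fitted object that holds the learned parameters and produces sample paths. They are not GluonTS’s Estimator or Predictor base classes, and the “model” is just a Gaussian whose mean and standard deviation are estimated from all training values.

```python
import random
import statistics
from dataclasses import dataclass
from typing import Iterable, List, Mapping


@dataclass
class ToyPredictor:
    """Fitted object: holds the learned parameters and produces forecasts."""

    mean: float
    std: float
    prediction_length: int

    def predict(self, num_samples: int, seed: int = 0) -> List[List[float]]:
        # Sample paths with shape (num_samples, prediction_length).
        rng = random.Random(seed)
        return [
            [rng.gauss(self.mean, self.std) for _ in range(self.prediction_length)]
            for _ in range(num_samples)
        ]


@dataclass
class ToyEstimator:
    """Configuration only (hyperparameters); no learned state until train()."""

    prediction_length: int

    def train(self, training_data: Iterable[Mapping]) -> ToyPredictor:
        values = [float(v) for entry in training_data for v in entry["target"]]
        return ToyPredictor(
            mean=statistics.fmean(values),
            std=statistics.pstdev(values),
            prediction_length=self.prediction_length,
        )


toy_predictor = ToyEstimator(prediction_length=3).train(
    [{"start": "2019-01-01", "target": [1.0, 2.0, 3.0]}]
)
print(toy_predictor.mean)  # 2.0
toy_paths = toy_predictor.predict(num_samples=5)
print(len(toy_paths), len(toy_paths[0]))  # 5 3
```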

[13]:
predictor = estimator.train(dataset.train)
100%|██████████| 100/100 [00:00<00:00, 142.24it/s, epoch=1/5, avg_epoch_loss=5.43]
100%|██████████| 100/100 [00:00<00:00, 154.52it/s, epoch=2/5, avg_epoch_loss=4.88]
100%|██████████| 100/100 [00:00<00:00, 144.78it/s, epoch=3/5, avg_epoch_loss=4.76]
100%|██████████| 100/100 [00:00<00:00, 152.51it/s, epoch=4/5, avg_epoch_loss=4.82]
100%|██████████| 100/100 [00:00<00:00, 150.28it/s, epoch=5/5, avg_epoch_loss=4.6]

Visualize and evaluate forecasts

With a predictor in hand, we can now predict the last window of dataset.test and evaluate our model’s performance.

GluonTS comes with the make_evaluation_predictions function that automates the process of prediction and model evaluation. Roughly, this function performs the following steps:

  • Removes the final window of length prediction_length from each time series in dataset.test (the window we want to predict)

  • The predictor uses the remaining data to predict (in the form of sample paths) the “future” window that was just removed

  • The function outputs the forecast sample paths and the dataset.test series (as Python generator objects)
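The three steps can be sketched in plain Python. This is a simplified illustration, not the actual implementation of make_evaluation_predictions: evaluation_split and last_value are hypothetical names, entries are assumed to be dicts with a “target” list, and a trivial “repeat the last value” rule stands in for the trained predictor.

```python
from typing import Callable, Iterable, Iterator, List, Mapping, Tuple

# A forecaster maps (history, horizon, num_samples) to sample paths.
Forecaster = Callable[[List[float], int, int], List[List[float]]]


def evaluation_split(
    test_data: Iterable[Mapping],
    forecaster: Forecaster,
    prediction_length: int,
    num_samples: int,
) -> Tuple[Iterator[List[List[float]]], Iterator[List[float]]]:
    """Hypothetical sketch of the three steps listed above."""
    entries = list(test_data)  # materialized so both generators can read it

    def forecast_gen() -> Iterator[List[List[float]]]:
        for entry in entries:
            target = list(entry["target"])
            # Step 1: hide the final window of length prediction_length.
            history = target[:-prediction_length]
            # Step 2: predict the hidden window from the remaining history.
            yield forecaster(history, prediction_length, num_samples)

    # Step 3: hand back forecasts and the full, unmodified series as generators.
    series_gen = (list(entry["target"]) for entry in entries)
    return forecast_gen(), series_gen


def last_value(history: List[float], horizon: int, num_samples: int) -> List[List[float]]:
    # Degenerate stand-in model: every sample path repeats the last observation.
    return [[history[-1]] * horizon for _ in range(num_samples)]


toy_fc_it, toy_ts_it = evaluation_split(
    [{"target": [1.0, 2.0, 3.0, 4.0, 5.0]}],
    last_value,
    prediction_length=2,
    num_samples=3,
)
print(next(toy_fc_it))  # [[3.0, 3.0], [3.0, 3.0], [3.0, 3.0]]
print(next(toy_ts_it))  # [1.0, 2.0, 3.0, 4.0, 5.0]
```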

[14]:
from gluonts.evaluation import make_evaluation_predictions
[15]:
forecast_it, ts_it = make_evaluation_predictions(
    dataset=dataset.test,  # test dataset
    predictor=predictor,  # predictor
    num_samples=100,  # number of sample paths we want for evaluation
)

First, we can convert these generators to lists to ease the subsequent computations.

[16]:
forecasts = list(forecast_it)
tss = list(ts_it)
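Converting to lists matters because a generator can be consumed only once; a second pass over forecast_it would yield nothing. A tiny standard-library illustration, with a stand-in generator rather than GluonTS output:

```python
squares = (x * x for x in range(3))
first_pass = list(squares)
second_pass = list(squares)
print(first_pass)  # [0, 1, 4]
print(second_pass)  # [] (the generator is already exhausted)
```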

We can examine the first element of these lists, which corresponds to the first time series of the dataset. Let’s start with the list containing the time series, i.e., tss. We expect the first entry of tss to contain the (target of the) first time series of dataset.test.

[17]:
# first entry of the time series list
ts_entry = tss[0]
[18]:
# first 5 values of the time series (convert from pandas to numpy)
np.array(ts_entry[:5]).reshape(-1)
[18]:
array([605., 586., 586., 559., 511.], dtype=float32)
[19]:
# first entry of dataset.test
dataset_test_entry = next(iter(dataset.test))
[20]:
# first 5 values
dataset_test_entry["target"][:5]
[20]:
array([605., 586., 586., 559., 511.], dtype=float32)

The entries in the forecast list are a bit more complex. They are objects that contain all the sample paths as a numpy.ndarray of shape (num_samples, prediction_length), the start date of the forecast, the frequency of the time series, etc. We can access all this information through the corresponding attributes of the forecast object.

[21]:
# first entry of the forecast list
forecast_entry = forecasts[0]
[22]:
print(f"Number of sample paths: {forecast_entry.num_samples}")
print(f"Dimension of samples: {forecast_entry.samples.shape}")
print(f"Start date of the forecast window: {forecast_entry.start_date}")
print(f"Frequency of the time series: {forecast_entry.freq}")
Number of sample paths: 100
Dimension of samples: (100, 48)
Start date of the forecast window: 1750-01-30 04:00
Frequency of the time series: <Hour>

We can also do calculations to summarize the sample paths, such as computing the mean or a quantile for each of the 48 time steps in the forecast window.
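Both summaries are computed column by column over the samples array: one value per forecast step, taken across the num_samples paths. This is a minimal standard-library sketch on a made-up 4 × 3 sample matrix; the helper names are hypothetical, and the quantile rule (linear interpolation between order statistics, numpy’s default) is our choice for illustration and may differ in detail from GluonTS’s quantile method.

```python
import statistics
from typing import List


def per_step_mean(samples: List[List[float]]) -> List[float]:
    # samples has shape (num_samples, prediction_length); average each column.
    return [statistics.fmean(column) for column in zip(*samples)]


def per_step_quantile(samples: List[List[float]], q: float) -> List[float]:
    """Linear-interpolation q-quantile of each column (numpy's default rule)."""
    result = []
    for column in zip(*samples):
        ordered = sorted(column)
        pos = q * (len(ordered) - 1)
        lo = int(pos)
        hi = min(lo + 1, len(ordered) - 1)
        result.append(ordered[lo] + (pos - lo) * (ordered[hi] - ordered[lo]))
    return result


# 4 made-up sample paths for a 3-step forecast window
toy_samples = [
    [10.0, 20.0, 30.0],
    [12.0, 18.0, 33.0],
    [11.0, 25.0, 27.0],
    [13.0, 21.0, 30.0],
]
print(per_step_mean(toy_samples))  # [11.5, 21.0, 30.0]
print(per_step_quantile(toy_samples, 0.5))  # [11.5, 20.5, 30.0]
```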

[23]:
print(f"Mean of the future window:\n {forecast_entry.mean}")
print(f"0.5-quantile (median) of the future window:\n {forecast_entry.quantile(0.5)}")
Mean of the future window:
 [691.43665 616.1314  535.5448  535.63245 536.85236 460.21567 450.53052
 503.3305  487.14017 554.0942  577.6539  622.9944  699.02875 855.8484
 905.4359  924.7087  933.27423 897.6598  920.61115 929.56396 845.7369
 804.671   814.56195 730.5156  622.3796  625.06323 601.3767  482.6475
 538.5274  517.2887  579.4265  584.4302  529.604   558.17236 684.3204
 773.47406 779.5009  867.49756 926.01666 916.4293  885.9521  929.0373
 880.40967 845.68463 826.5437  816.805   776.2443  740.25836]
0.5-quantile (median) of the future window:
 [701.4136  624.0384  554.1458  532.92255 533.4006  463.03613 461.16885
 500.52344 486.36057 538.89966 573.3184  619.5366  705.0485  866.58875
 906.6944  928.2669  929.2074  894.6779  911.5794  929.3663  846.1513
 787.9489  804.12085 718.76825 625.4044  629.6601  609.66656 477.6561
 532.4805  526.9441  555.94604 577.8249  530.6848  560.92084 684.0824
 763.2665  764.98157 873.4303  921.3239  933.6672  873.61884 935.4046
 889.3567  849.0873  802.92426 819.0139  775.47064 746.21594]

Forecast objects have a plot method that can summarize the forecast paths as the mean, prediction intervals, etc. The prediction intervals are shaded in different colors as a “fan chart”.
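Fan-chart bands are typically central prediction intervals: for a level of 0.9, the bounds at each step are the (1 - 0.9)/2 = 0.05 and (1 + 0.9)/2 = 0.95 quantiles of the sample paths. The sketch below computes such bounds under that assumption; central_interval and column_quantile are hypothetical helpers, and which levels GluonTS draws by default is not assumed here.

```python
from typing import List, Tuple


def column_quantile(values: List[float], q: float) -> float:
    # Linear interpolation between order statistics (numpy's default rule).
    ordered = sorted(values)
    pos = q * (len(ordered) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(ordered) - 1)
    return ordered[lo] + (pos - lo) * (ordered[hi] - ordered[lo])


def central_interval(
    samples: List[List[float]], level: float
) -> Tuple[List[float], List[float]]:
    """Per-step lower/upper bounds of the central `level` prediction interval."""
    lower_q = (1.0 - level) / 2.0
    upper_q = (1.0 + level) / 2.0
    columns = list(zip(*samples))  # one tuple of sample values per forecast step
    lower = [column_quantile(list(c), lower_q) for c in columns]
    upper = [column_quantile(list(c), upper_q) for c in columns]
    return lower, upper


# 5 made-up sample paths for a 2-step window
band_samples = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0], [5.0, 50.0]]
lower, upper = central_interval(band_samples, level=0.5)
print(lower, upper)  # [2.0, 20.0] [4.0, 40.0]
```

Wider levels give wider bands, which is what produces the nested “fan” when several levels are drawn on top of each other.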

[24]:
plt.plot(ts_entry[-150:].to_timestamp())
forecast_entry.plot(show_label=True)
plt.legend()
plt.show()
[Figure: the last 150 observations of the first test series, followed by the forecast fan chart]

We can also evaluate the quality of our forecasts numerically. In GluonTS, the Evaluator class can compute aggregate performance metrics, as well as metrics per time series (which can be useful for analyzing performance across heterogeneous time series).

[25]:
from gluonts.evaluation import Evaluator
[26]:
evaluator = Evaluator(quantiles=[0.1, 0.5, 0.9])
agg_metrics, item_metrics = evaluator(tss, forecasts)
Running evaluation: 414it [00:00, 10725.07it/s]

The aggregate metrics, agg_metrics, are aggregated both across time steps and across time series.

[27]:
print(json.dumps(agg_metrics, indent=4))
{
    "MSE": 19151565.001189325,
    "abs_error": 13362738.066356659,
    "abs_target_sum": 145558863.59960938,
    "abs_target_mean": 7324.822041043146,
    "seasonal_error": 336.9046924038305,
    "MASE": 4.686514358902478,
    "MAPE": 0.2643111116836228,
    "sMAPE": 0.19426367691029672,
    "MSIS": 60.252643883036775,
    "num_masked_target_values": 0.0,
    "QuantileLoss[0.1]": 4317812.271811676,
    "Coverage[0.1]": 0.14356884057971017,
    "QuantileLoss[0.5]": 13362737.883940697,
    "Coverage[0.5]": 0.6398953301127214,
    "QuantileLoss[0.9]": 7580827.01402273,
    "Coverage[0.9]": 0.8897946859903381,
    "RMSE": 4376.250107248136,
    "NRMSE": 0.5974548026869065,
    "ND": 0.09180298427661343,
    "wQuantileLoss[0.1]": 0.02966368495214924,
    "wQuantileLoss[0.5]": 0.09180298302340248,
    "wQuantileLoss[0.9]": 0.05208083401142378,
    "mean_absolute_QuantileLoss": 8420459.0565917,
    "mean_wQuantileLoss": 0.057849167328991834,
    "MAE_Coverage": 0.39659822866344596,
    "OWA": NaN
}
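Several aggregate metrics are simple functions of others, which gives a quick sanity check on the output above. The relationships checked below (RMSE as the square root of MSE, NRMSE as RMSE divided by abs_target_mean, ND as abs_error divided by abs_target_sum, each wQuantileLoss as the corresponding QuantileLoss divided by abs_target_sum, and the two mean_* entries as averages over the three quantiles) are our reading of the definitions, verified only against the numbers printed above.

```python
import math

# Values copied from the agg_metrics printout above.
agg = {
    "MSE": 19151565.001189325,
    "abs_error": 13362738.066356659,
    "abs_target_sum": 145558863.59960938,
    "abs_target_mean": 7324.822041043146,
    "QuantileLoss[0.1]": 4317812.271811676,
    "QuantileLoss[0.5]": 13362737.883940697,
    "QuantileLoss[0.9]": 7580827.01402273,
    "RMSE": 4376.250107248136,
    "NRMSE": 0.5974548026869065,
    "ND": 0.09180298427661343,
    "wQuantileLoss[0.1]": 0.02966368495214924,
    "wQuantileLoss[0.5]": 0.09180298302340248,
    "wQuantileLoss[0.9]": 0.05208083401142378,
    "mean_absolute_QuantileLoss": 8420459.0565917,
    "mean_wQuantileLoss": 0.057849167328991834,
}
quantiles = ["0.1", "0.5", "0.9"]

# name -> (reported value, value recomputed from other entries)
checks = {
    "RMSE = sqrt(MSE)": (agg["RMSE"], math.sqrt(agg["MSE"])),
    "NRMSE = RMSE / abs_target_mean": (
        agg["NRMSE"],
        agg["RMSE"] / agg["abs_target_mean"],
    ),
    "ND = abs_error / abs_target_sum": (
        agg["ND"],
        agg["abs_error"] / agg["abs_target_sum"],
    ),
    "mean_absolute_QuantileLoss": (
        agg["mean_absolute_QuantileLoss"],
        sum(agg[f"QuantileLoss[{q}]"] for q in quantiles) / len(quantiles),
    ),
    "mean_wQuantileLoss": (
        agg["mean_wQuantileLoss"],
        sum(agg[f"wQuantileLoss[{q}]"] for q in quantiles) / len(quantiles),
    ),
}
for q in quantiles:
    checks[f"wQuantileLoss[{q}] = QuantileLoss / abs_target_sum"] = (
        agg[f"wQuantileLoss[{q}]"],
        agg[f"QuantileLoss[{q}]"] / agg["abs_target_sum"],
    )

results = {
    name: math.isclose(reported, recomputed, rel_tol=1e-6)
    for name, (reported, recomputed) in checks.items()
}
for name, ok in results.items():
    print(f"{name}: {ok}")  # every line prints True
```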

The per-item metrics, item_metrics, are aggregated only across time steps, giving one row per time series.

[28]:
item_metrics.head()
[28]:
   item_id    forecast_start            MSE     abs_error  abs_target_sum  abs_target_mean  seasonal_error      MASE      MAPE     sMAPE  num_masked_target_values        ND       MSIS  QuantileLoss[0.1]  Coverage[0.1]  QuantileLoss[0.5]  Coverage[0.5]  QuantileLoss[0.9]  Coverage[0.9]
0        0  1750-01-30 04:00    4351.303385   2603.800781         31644.0       659.250000       42.371302  1.280250  0.083877  0.079127                       0.0  0.082284  11.925520         654.108093       0.062500        2603.800751       0.812500        1615.507690       1.000000
1        1  1750-01-30 04:00  281753.479167  22727.175781        124149.0      2586.437500      165.107988  2.867716  0.191402  0.169930                       0.0  0.183064  13.736474       10280.554956       0.500000       22727.176514       0.979167        9153.454883       1.000000
2        2  1750-01-30 04:00   23819.419271   5426.807617         65030.0      1354.791667       78.889053  1.433133  0.076811  0.080609                       0.0  0.083451  12.555220        2851.053223       0.000000        5426.807983       0.291667        1927.753052       0.937500
3        3  1750-01-30 04:00  193753.166667  15962.206055        235783.0      4912.145833      258.982249  1.284049  0.070260  0.067739                       0.0  0.067699  13.592661        7837.278516       0.062500       15962.205566       0.583333        8790.279785       0.979167
4        4  1750-01-30 04:00  140006.822917  14615.171875        131088.0      2731.000000      200.494083  1.518662  0.112001  0.104921                       0.0  0.111491  12.345502        4144.776880       0.145833       14615.171265       0.791667        7620.407788       0.979167
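To see what the per-series columns measure, the sketch below recomputes several of them for one made-up series. The formulas are the standard definitions as we understand the Evaluator to apply them, so treat them as assumptions: the point forecast is the median, seasonal_error is the mean absolute seasonal difference of the in-sample history, QuantileLoss[q] is twice the summed pinball loss, and Coverage[q] is the fraction of targets strictly below the q-quantile forecast. The function names are hypothetical, and missing values and short histories are not handled.

```python
from typing import Dict, List


def quantile_loss(target: List[float], q_forecast: List[float], q: float) -> float:
    """Twice the summed pinball loss; at q = 0.5 this equals sum(|y - y_hat|)."""
    return 2.0 * sum(
        abs((f - y) * ((1.0 if y <= f else 0.0) - q))
        for y, f in zip(target, q_forecast)
    )


def series_metrics(
    history: List[float],
    target: List[float],
    median: List[float],
    quantile_forecasts: Dict[float, List[float]],
    season: int,
) -> Dict[str, float]:
    """Per-series metrics for one forecast window (assumes len(history) > season)."""
    n = len(target)
    abs_errors = [abs(y - f) for y, f in zip(target, median)]
    abs_error = sum(abs_errors)
    # Scale for MASE: mean absolute seasonal difference of the history.
    diffs = [abs(history[t] - history[t - season]) for t in range(season, len(history))]
    seasonal_error = sum(diffs) / len(diffs)
    abs_target_sum = sum(abs(y) for y in target)
    metrics = {
        "abs_error": abs_error,
        "abs_target_sum": abs_target_sum,
        "seasonal_error": seasonal_error,
        "MASE": (abs_error / n) / seasonal_error,
        "MAPE": sum(e / abs(y) for e, y in zip(abs_errors, target)) / n,
        "sMAPE": sum(
            2.0 * e / (abs(y) + abs(f)) for e, y, f in zip(abs_errors, target, median)
        ) / n,
        "ND": abs_error / abs_target_sum,
    }
    for q, q_forecast in quantile_forecasts.items():
        metrics[f"QuantileLoss[{q}]"] = quantile_loss(target, q_forecast, q)
        below = sum(1 for y, f in zip(target, q_forecast) if y < f)
        metrics[f"Coverage[{q}]"] = below / n
    return metrics


toy_metrics = series_metrics(
    history=[10.0, 12.0, 14.0, 11.0, 13.0, 15.0],
    target=[12.0, 14.0, 16.0],
    median=[11.0, 14.0, 18.0],
    quantile_forecasts={
        0.1: [10.0, 12.0, 15.0],
        0.5: [11.0, 14.0, 18.0],
        0.9: [14.0, 16.0, 20.0],
    },
    season=3,
)
print(toy_metrics["MASE"], toy_metrics["QuantileLoss[0.5]"], toy_metrics["Coverage[0.9]"])
# 1.0 3.0 1.0
```

As a cross-check on row 0 of the table above: (2603.800781 / 48) / 42.371302 ≈ 1.28025, matching its MASE column, and its QuantileLoss[0.5] equals abs_error up to rounding, as the q = 0.5 case of the formula predicts.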
[29]:
item_metrics.plot(x="MSIS", y="MASE", kind="scatter")
plt.grid(which="both")
plt.show()
[Figure: scatter plot of MASE against MSIS, one point per time series]