Quick Start Tutorial

The GluonTS toolkit contains components and tools for building time series models using MXNet. The models currently included are forecasting models, but the components also support other time series use cases, such as classification or anomaly detection.

The toolkit is not intended as a forecasting solution for businesses or end users; rather, it targets scientists and engineers who want to tweak algorithms or build and experiment with their own models.

GluonTS contains:

  • Components for building new models (likelihoods, feature processing pipelines, calendar features etc.)

  • Data loading and processing

  • A number of pre-built models

  • Plotting and evaluation facilities

  • Artificial and real datasets (external datasets are included only if their license permits it)

In [1]:
%matplotlib inline
import mxnet as mx
from mxnet import gluon
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import json


Provided datasets

GluonTS comes with a number of publicly available datasets.

In [2]:
from gluonts.dataset.repository.datasets import get_dataset, dataset_recipes
from gluonts.dataset.util import to_pandas
In [3]:
print(f"Available datasets: {list(dataset_recipes.keys())}")
Available datasets: ['constant', 'exchange_rate', 'solar-energy', 'electricity', 'traffic', 'exchange_rate_nips', 'electricity_nips', 'traffic_nips', 'solar_nips', 'wiki-rolling_nips', 'taxi_30min', 'kaggle_web_traffic_with_missing', 'kaggle_web_traffic_without_missing', 'kaggle_web_traffic_weekly', 'm1_yearly', 'm1_quarterly', 'm1_monthly', 'nn5_daily_with_missing', 'nn5_daily_without_missing', 'nn5_weekly', 'tourism_monthly', 'tourism_quarterly', 'tourism_yearly', 'm3_monthly', 'm3_quarterly', 'm3_yearly', 'm3_other', 'm4_hourly', 'm4_daily', 'm4_weekly', 'm4_monthly', 'm4_quarterly', 'm4_yearly', 'm5']

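To download one of the built-in datasets, simply call get_dataset with one of the above names. GluonTS can reuse a previously saved copy so that the dataset does not need to be downloaded again: simply set regenerate=False. The call below uses regenerate=True, which rebuilds the dataset even if a local copy already exists.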

In [4]:
dataset = get_dataset("m4_hourly", regenerate=True)
saving time-series into /home/runner/.mxnet/gluon-ts/datasets/m4_hourly/train/data.json
saving time-series into /home/runner/.mxnet/gluon-ts/datasets/m4_hourly/test/data.json

In general, the datasets provided by GluonTS are objects that consist of three main members:

  • dataset.train is an iterable collection of data entries used for training. Each entry corresponds to one time series

  • dataset.test is an iterable collection of data entries used for inference. The test dataset is an extended version of the train dataset that contains a window at the end of each time series that was not seen during training. This window has a length equal to the recommended prediction length.

  • dataset.metadata contains metadata of the dataset such as the frequency of the time series, a recommended prediction horizon, associated features, etc.
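The relationship between dataset.train and dataset.test described above can be sketched with plain Python lists. This is an illustration under simplifying assumptions, not GluonTS code: entries are modeled as dicts holding only a "target" list, and check_test_extends_train is a hypothetical helper.

```python
def check_test_extends_train(train_entries, test_entries, prediction_length):
    """Hypothetical helper: True if every test series equals its training
    series followed by exactly `prediction_length` additional values."""
    train_entries, test_entries = list(train_entries), list(test_entries)
    if len(train_entries) != len(test_entries):
        return False  # both splits should contain the same series
    for train, test in zip(train_entries, test_entries):
        train_target, test_target = train["target"], test["target"]
        # the test series must begin with exactly the training values ...
        if test_target[: len(train_target)] != train_target:
            return False
        # ... and extend them by one window of length prediction_length
        if len(test_target) - len(train_target) != prediction_length:
            return False
    return True


# two toy series; the last 3 values of each are held out of training
full_series = [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [10.0, 11.0, 12.0, 13.0, 14.0, 15.0]]
test_entries = [{"target": s} for s in full_series]
train_entries = [{"target": s[:-3]} for s in full_series]

print(check_test_extends_train(train_entries, test_entries, prediction_length=3))  # True
```

The same property is what the length comparison in In [7] checks for the first M4 series.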

In [5]:
entry = next(iter(dataset.train))
train_series = to_pandas(entry)
train_series.plot()
plt.grid(which="both")
plt.legend(["train series"], loc="upper left")
plt.show()
In [6]:
entry = next(iter(dataset.test))
test_series = to_pandas(entry)
test_series.plot()
plt.axvline(train_series.index[-1], color='r')  # end of train dataset
plt.grid(which="both")
plt.legend(["test series", "end of train series"], loc="upper left")
plt.show()
In [7]:
print(f"Length of forecasting window in test dataset: {len(test_series) - len(train_series)}")
print(f"Recommended prediction horizon: {dataset.metadata.prediction_length}")
print(f"Frequency of the time series: {dataset.metadata.freq}")
Length of forecasting window in test dataset: 48
Recommended prediction horizon: 48
Frequency of the time series: H

Custom datasets

At this point, it is important to emphasize that GluonTS does not require this specific format for a user's custom dataset. The only requirements are that the dataset is iterable and that each entry has a “target” and a “start” field. To make this clearer, assume the common case where the dataset is a numpy.array and the start of each time series is given as a pandas.Timestamp (possibly different for each time series):

In [8]:
N = 10  # number of time series
T = 100  # number of timesteps
prediction_length = 24
freq = "1H"
custom_dataset = np.random.normal(size=(N, T))
start = pd.Timestamp("01-01-2019", freq=freq)  # can be different for each time series

Now you can split your dataset and bring it into a GluonTS-compatible format with just two statements:

In [9]:
from gluonts.dataset.common import ListDataset
In [10]:
# train dataset: cut the last window of length "prediction_length", add "target" and "start" fields
train_ds = ListDataset(
    [{'target': x, 'start': start} for x in custom_dataset[:, :-prediction_length]],
    freq=freq
)
# test dataset: use the whole dataset, add "target" and "start" fields
test_ds = ListDataset(
    [{'target': x, 'start': start} for x in custom_dataset],
    freq=freq
)

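The two requirements stated above (an iterable of entries, each with a “target” and a “start” field) can be checked with a small dependency-free helper. The sketch below is illustrative only: check_entries and generate_entries are hypothetical names, and datetime objects stand in for the pandas.Timestamp values used in the cell above. It also shows that a generator, not just a list, satisfies the "iterable" requirement.

```python
from datetime import datetime


def check_entries(entries):
    """Hypothetical helper: confirm every entry of an iterable dataset has
    "target" and "start" fields; returns the number of entries seen."""
    count = 0
    for entry in entries:
        missing = {"target", "start"} - set(entry)
        if missing:
            raise ValueError(f"entry {count} is missing fields: {sorted(missing)}")
        count += 1
    return count


def generate_entries():
    # a generator is enough: the dataset only has to be iterable,
    # and each series may start at a different time
    for i in range(3):
        yield {"target": [float(i)] * 5, "start": datetime(2019, 1, 1 + i)}


print(check_entries(generate_entries()))  # 3
```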
Training an existing model (Estimator)

GluonTS comes with a number of pre-built models. All the user needs to do is configure some hyperparameters. The existing models focus on (but are not limited to) probabilistic forecasting. Probabilistic forecasts are predictions in the form of a probability distribution, rather than simply a single point estimate.

We will begin with GluonTS’s pre-built feedforward neural network estimator, a simple but powerful forecasting model. We will use this model to demonstrate the process of training a model, producing forecasts, and evaluating the results.

GluonTS’s built-in feedforward neural network (SimpleFeedForwardEstimator) accepts an input window of length context_length and predicts the distribution of the subsequent prediction_length values. In GluonTS parlance, the feedforward neural network model is an example of an Estimator. In GluonTS, Estimator objects represent a forecasting model as well as details such as its coefficients, weights, etc.

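To make the roles of context_length and prediction_length concrete, here is a sketch of how one series can be cut into (input window, target window) pairs. It is an assumption-laden illustration: make_training_windows is hypothetical and enumerates every possible window, whereas GluonTS builds its training instances with its own sampling transformations.

```python
def make_training_windows(series, context_length, prediction_length):
    """Hypothetical helper: every (context, future) pair in one series, where
    `context` has context_length values and `future` holds the
    prediction_length values that immediately follow it."""
    window = context_length + prediction_length
    pairs = []
    for start in range(len(series) - window + 1):
        context = series[start : start + context_length]
        future = series[start + context_length : start + window]
        pairs.append((context, future))
    return pairs


pairs = make_training_windows(list(range(10)), context_length=4, prediction_length=2)
print(len(pairs))  # 5
print(pairs[0])    # ([0, 1, 2, 3], [4, 5])
```

A series shorter than context_length + prediction_length yields no pairs at all, which is one reason the context length should stay well below the typical series length.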
In general, each estimator (pre-built or custom) is configured by a number of hyperparameters that can be either common to all estimators (but not binding), e.g., the prediction_length, or specific to the particular estimator, e.g., the number of layers for a neural network or the stride in a CNN.

Finally, each estimator is configured by a Trainer, which defines how the model will be trained, i.e., the number of epochs, the learning rate, etc.

In [11]:
from gluonts.model.simple_feedforward import SimpleFeedForwardEstimator
from gluonts.mx import Trainer
In [12]:
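estimator = SimpleFeedForwardEstimator(
    num_hidden_dimensions=[10],  # one hidden layer with 10 units
    prediction_length=dataset.metadata.prediction_length,
    context_length=100,  # length of the input window
    freq=dataset.metadata.freq,
    trainer=Trainer(
        ctx="cpu",
        epochs=5,  # matches "epoch=1/5 ... 5/5" in the training log below
        learning_rate=1e-3,
        num_batches_per_epoch=100,  # matches the 100 iterations per epoch in the log
    ),
)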

After specifying our estimator with all the necessary hyperparameters, we can train it on our training dataset, dataset.train, by invoking the estimator's train method. Training returns a fitted model (a Predictor in GluonTS parlance) that can be used to construct forecasts.

In [13]:
predictor = estimator.train(dataset.train)
100%|██████████| 100/100 [00:01<00:00, 71.96it/s, epoch=1/5, avg_epoch_loss=5.53]
100%|██████████| 100/100 [00:01<00:00, 74.36it/s, epoch=2/5, avg_epoch_loss=4.86]
100%|██████████| 100/100 [00:01<00:00, 78.44it/s, epoch=3/5, avg_epoch_loss=4.63]
100%|██████████| 100/100 [00:01<00:00, 76.69it/s, epoch=4/5, avg_epoch_loss=4.73]
100%|██████████| 100/100 [00:01<00:00, 79.78it/s, epoch=5/5, avg_epoch_loss=4.8]
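The estimator/predictor split used above (configure an Estimator, call train, get back a Predictor) can be mimicked in a few lines. The classes below are hypothetical stand-ins, not GluonTS classes: the "training" step just averages all training values, and the predictor returns one constant point forecast per input series instead of sample paths.

```python
class ConstantPredictor:
    """Hypothetical stand-in for a Predictor: a fitted model that maps
    input series to forecasts."""

    def __init__(self, level, prediction_length):
        self.level = level
        self.prediction_length = prediction_length

    def predict(self, entries):
        # one forecast per input entry, produced lazily
        for _ in entries:
            yield [self.level] * self.prediction_length


class ConstantEstimator:
    """Hypothetical stand-in for an Estimator: holds hyperparameters and,
    when trained, returns a predictor."""

    def __init__(self, prediction_length):
        self.prediction_length = prediction_length

    def train(self, training_data):
        # "fitting" here is only averaging every training value
        values = [v for entry in training_data for v in entry["target"]]
        return ConstantPredictor(sum(values) / len(values), self.prediction_length)


train_data = [{"target": [1.0, 2.0, 3.0]}, {"target": [5.0, 9.0]}]
predictor = ConstantEstimator(prediction_length=2).train(train_data)
print(list(predictor.predict(train_data)))  # [[4.0, 4.0], [4.0, 4.0]]
```

Keeping hyperparameters on the estimator and learned state on the predictor is what lets one trained predictor be reused on new data without retraining.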

Visualize and evaluate forecasts

With a predictor in hand, we can now predict the last window of dataset.test and evaluate our model’s performance.

GluonTS comes with the make_evaluation_predictions function that automates the process of prediction and model evaluation. Roughly, this function performs the following steps:

  • Removes the final window of length prediction_length of the dataset.test that we want to predict

  • The predictor uses the remaining data to predict (in the form of sample paths) the “future” window that was just removed

  • The function outputs the forecast sample paths and the dataset.test series (as Python generator objects)

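The three steps listed above can be sketched without GluonTS. In this illustration, evaluation_predictions and noisy_last_value_predictor are hypothetical helpers; the toy predictor simply repeats the last observed value plus Gaussian noise, which is far cruder than a trained model, but the data flow (hide the last window, sample future paths, return two generators) follows the description.

```python
import random


def noisy_last_value_predictor(context, prediction_length, num_samples, rng):
    """Hypothetical toy predictor: each sample path repeats the last observed
    value with independent Gaussian noise added at every step."""
    last = context[-1]
    return [
        [last + rng.gauss(0.0, 1.0) for _ in range(prediction_length)]
        for _ in range(num_samples)
    ]


def evaluation_predictions(test_targets, prediction_length, num_samples, seed=0):
    """Sketch of the steps above; `test_targets` must be a list (it is read twice)."""
    rng = random.Random(seed)

    def forecasts():
        for target in test_targets:
            # step 1: remove the final window that we want to predict
            context = target[:-prediction_length]
            # step 2: forecast that window as sample paths from the remaining data
            yield noisy_last_value_predictor(context, prediction_length, num_samples, rng)

    def series():
        # the full series, including the window hidden from the predictor
        yield from test_targets

    # step 3: return forecasts and series as two generators
    return forecasts(), series()


targets = [[float(t) for t in range(20)], [5.0] * 20]
toy_forecast_it, toy_ts_it = evaluation_predictions(targets, prediction_length=4, num_samples=10)
first = next(toy_forecast_it)
print(len(first), len(first[0]))  # 10 4
```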
In [14]:
from gluonts.evaluation import make_evaluation_predictions
In [15]:
forecast_it, ts_it = make_evaluation_predictions(
    dataset=dataset.test,  # test dataset
    predictor=predictor,  # predictor
    num_samples=100,  # number of sample paths we want for evaluation
)

First, we can convert these generators to lists to ease the subsequent computations.

In [16]:
forecasts = list(forecast_it)
tss = list(ts_it)
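The conversion matters because generators such as forecast_it and ts_it can be traversed only once. A minimal illustration in plain Python (numbers is a throwaway example function, unrelated to GluonTS):

```python
def numbers():
    # throwaway generator function used only for this illustration
    yield from [1, 2, 3]


gen = numbers()
print(sum(gen))  # 6
print(sum(gen))  # 0, because the generator is already exhausted

cached = list(numbers())
print(sum(cached), sum(cached))  # 6 6, since a list can be traversed repeatedly
```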

We can examine the first element of these lists (that corresponds to the first time series of the dataset). Let’s start with the list containing the time series, i.e., tss. We expect the first entry of tss to contain the (target of the) first time series of dataset.test.

In [17]:
# first entry of the time series list
ts_entry = tss[0]
In [18]:
# first 5 values of the time series (convert from pandas to numpy)
np.array(ts_entry[:5]).reshape(-1,)
array([605., 586., 586., 559., 511.], dtype=float32)
In [19]:
# first entry of dataset.test
dataset_test_entry = next(iter(dataset.test))
In [20]:
# first 5 values
dataset_test_entry['target'][:5]
array([605., 586., 586., 559., 511.], dtype=float32)

The entries in the forecast list are a bit more complex. They are objects that contain all the sample paths in the form of a numpy.ndarray with dimension (num_samples, prediction_length), the start date of the forecast, the frequency of the time series, etc. We can access all of this information through the corresponding attributes of the forecast object.

In [21]:
# first entry of the forecast list
forecast_entry = forecasts[0]
In [22]:
print(f"Number of sample paths: {forecast_entry.num_samples}")
print(f"Dimension of samples: {forecast_entry.samples.shape}")
print(f"Start date of the forecast window: {forecast_entry.start_date}")
print(f"Frequency of the time series: {forecast_entry.freq}")
Number of sample paths: 100
Dimension of samples: (100, 48)
Start date of the forecast window: 1750-01-30 04:00:00
Frequency of the time series: H
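Each of the 48 columns of forecast_entry.samples corresponds to one timestamp in the forecast window, determined by the start date and the frequency. The sketch below lists those timestamps using Python's datetime and timedelta as a stand-in for pandas frequencies; forecast_index is a hypothetical helper, and the start date and hourly step are taken from the output above.

```python
from datetime import datetime, timedelta


def forecast_index(start_date, step, prediction_length):
    """Hypothetical helper: timestamps covered by a forecast window that
    begins at `start_date` and advances by `step` per value."""
    return [start_date + i * step for i in range(prediction_length)]


index = forecast_index(datetime(1750, 1, 30, 4), timedelta(hours=1), 48)
print(index[0])   # 1750-01-30 04:00:00
print(index[-1])  # 1750-02-01 03:00:00
```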

We can also compute summaries of the sample paths, such as the mean or a quantile for each of the 48 time steps in the forecast window.

In [23]:
print(f"Mean of the future window:\n {forecast_entry.mean}")
print(f"0.5-quantile (median) of the future window:\n {forecast_entry.quantile(0.5)}")
Mean of the future window:
 [ 672.8497   613.29663  408.92044  408.7472   473.79575  478.10278
  442.7516   480.20114  574.4379   552.1841   654.4995   686.7996
  812.1526   809.09607  893.7823  1061.1001   829.9734   803.3656
  815.19025  837.9185   855.30884  884.32404  746.14685  607.4099
  585.4521   576.35065  609.587    415.09277  360.20532  464.05334
  375.74634  423.8973   464.73593  470.3606   575.05975  629.2411
  735.3183   813.8365   792.67334  730.78406  814.73224  814.33545
  967.03143  868.4563   832.43475  792.10315  678.98413  714.4641 ]
0.5-quantile (median) of the future window:
 [ 674.39734  597.1892   410.16452  416.51376  440.16696  497.52975
  462.16727  497.64038  572.7849   543.5753   665.4903   679.0552
  833.44586  804.2619   881.7074  1090.3456   832.3505   813.21094
  819.67535  837.524    872.3092   836.6189   762.95074  612.43787
  592.9262   603.5703   601.07623  421.95715  381.2906   468.84097
  378.67023  415.2866   450.44235  473.39133  557.0327   636.6891
  722.41095  830.7027   782.19275  746.50397  819.3159   830.87897
  967.76385  849.4127   813.921    811.2827   691.1297   700.58655]
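Both summaries are computed independently for each time step, across the sample paths. The sketch below shows the idea on a small grid of plain Python lists shaped (num_samples, prediction_length). The quantile rule used here (take the sorted value at index round((num_samples - 1) * q)) is one simple empirical choice and is an assumption; GluonTS's own interpolation may differ. step_means and step_quantiles are hypothetical helpers.

```python
def step_means(samples):
    """Mean of each column (time step) of a num_samples x prediction_length grid."""
    num_samples = len(samples)
    return [sum(column) / num_samples for column in zip(*samples)]


def step_quantiles(samples, q):
    """Empirical q-quantile of each time step: sort the values at that step
    and pick the one at index round((num_samples - 1) * q)."""
    idx = int(round((len(samples) - 1) * q))
    return [sorted(column)[idx] for column in zip(*samples)]


# 5 sample paths over a 3-step forecast window
samples = [
    [1.0, 10.0, 100.0],
    [2.0, 20.0, 200.0],
    [3.0, 30.0, 300.0],
    [4.0, 40.0, 400.0],
    [5.0, 50.0, 500.0],
]
print(step_means(samples))           # [3.0, 30.0, 300.0]
print(step_quantiles(samples, 0.5))  # [3.0, 30.0, 300.0]
print(step_quantiles(samples, 0.9))  # [5.0, 50.0, 500.0]
```

With skewed sample paths the mean and the median can differ noticeably, which is why both appear in the output above.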

Forecast objects have a plot method that can summarize the forecast paths as the mean, prediction intervals, etc. The prediction intervals are shaded in different colors as a “fan chart”.

In [24]:
def plot_prob_forecasts(ts_entry, forecast_entry):
    plot_length = 150
    prediction_intervals = (50.0, 90.0)
    legend = ["observations", "median prediction"] + [f"{k}% prediction interval" for k in prediction_intervals][::-1]

    fig, ax = plt.subplots(1, 1, figsize=(10, 7))
    ts_entry[-plot_length:].plot(ax=ax)  # plot the time series
    forecast_entry.plot(prediction_intervals=prediction_intervals, color='g')
    plt.legend(legend, loc="upper left")
In [25]:
plot_prob_forecasts(ts_entry, forecast_entry)
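The shaded bands in such a fan chart are central prediction intervals: a k% interval runs from the (1 - k/100)/2 quantile to the (1 + k/100)/2 quantile, so the prediction_intervals = (50.0, 90.0) used above correspond to the quantile pairs (0.25, 0.75) and (0.05, 0.95). A small sketch, with hypothetical helper names and the same simple empirical quantile rule as an assumption:

```python
def interval_quantile_levels(coverage_percent):
    """Quantile levels bounding a central interval, e.g. 50 -> (0.25, 0.75)."""
    tail = (100.0 - coverage_percent) / 200.0
    return tail, 1.0 - tail


def interval_bounds(step_values, coverage_percent):
    """Lower and upper interval bound for one time step, using the sorted
    value at index round((n - 1) * q) as the empirical q-quantile."""
    ordered = sorted(step_values)
    n = len(ordered)
    lo_q, hi_q = interval_quantile_levels(coverage_percent)
    return ordered[int(round((n - 1) * lo_q))], ordered[int(round((n - 1) * hi_q))]


for k in (50.0, 90.0):
    print(k, interval_quantile_levels(k))

values = [float(v) for v in range(101)]  # 101 sample values at a single time step
print(interval_bounds(values, 50.0))  # (25.0, 75.0)
print(interval_bounds(values, 90.0))  # (5.0, 95.0)
```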

We can also evaluate the quality of our forecasts numerically. In GluonTS, the Evaluator class can compute aggregate performance metrics, as well as metrics per time series (which can be useful for analyzing performance across heterogeneous time series).

In [26]:
from gluonts.evaluation import Evaluator
In [27]:
evaluator = Evaluator(quantiles=[0.1, 0.5, 0.9])
agg_metrics, item_metrics = evaluator(iter(tss), iter(forecasts), num_series=len(dataset.test))
Running evaluation: 100%|██████████| 414/414 [00:00<00:00, 13984.17it/s]

Aggregate metrics are aggregated both across time steps and across time series.

In [28]:
print(json.dumps(agg_metrics, indent=4))
{
    "MSE": 14260369.53530052,
    "abs_error": 11468353.000846863,
    "abs_target_sum": 145558863.59960938,
    "abs_target_mean": 7324.822041043147,
    "seasonal_error": 336.9046924038302,
    "MASE": 5.616673211069731,
    "MAPE": 0.2843449467886089,
    "sMAPE": 0.2207723625136266,
    "MSIS": 71.38723225357148,
    "QuantileLoss[0.1]": 6737793.847027113,
    "Coverage[0.1]": 0.08489331723027377,
    "QuantileLoss[0.5]": 11468352.941949368,
    "Coverage[0.5]": 0.4316123188405798,
    "QuantileLoss[0.9]": 7381112.835865307,
    "Coverage[0.9]": 0.8621678743961353,
    "RMSE": 3776.290446364066,
    "NRMSE": 0.5155470570075823,
    "ND": 0.07878842083016674,
    "wQuantileLoss[0.1]": 0.04628913472120014,
    "wQuantileLoss[0.5]": 0.07878842042553666,
    "wQuantileLoss[0.9]": 0.05070878305404079,
    "mean_absolute_QuantileLoss": 8529086.541613929,
    "mean_wQuantileLoss": 0.05859544606692587,
    "MAE_Coverage": 0.04044216317767038,
    "OWA": NaN
}

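Several of these entries are derived from others. Working backwards from the printed numbers, RMSE is the square root of MSE, NRMSE is RMSE divided by abs_target_mean, ND is abs_error divided by abs_target_sum, mean_wQuantileLoss averages the quantile losses after dividing each by abs_target_sum, and MAE_Coverage is the mean absolute gap between each Coverage[q] and q. The sketch below recomputes them from the values above; derive_aggregate_metrics is a hypothetical helper, and these relationships are inferred from this output rather than taken from the GluonTS source.

```python
import math


def derive_aggregate_metrics(agg, quantiles):
    """Hypothetical helper: recompute derived aggregate entries from base ones.
    The relationships are inferred from the printed output, not from GluonTS code."""
    rmse = math.sqrt(agg["MSE"])
    weighted = [agg[f"QuantileLoss[{q}]"] / agg["abs_target_sum"] for q in quantiles]
    return {
        "RMSE": rmse,
        "NRMSE": rmse / agg["abs_target_mean"],
        "ND": agg["abs_error"] / agg["abs_target_sum"],
        "mean_wQuantileLoss": sum(weighted) / len(weighted),
        "MAE_Coverage": sum(abs(agg[f"Coverage[{q}]"] - q) for q in quantiles) / len(quantiles),
    }


# base aggregates copied from the output above
agg = {
    "MSE": 14260369.53530052,
    "abs_error": 11468353.000846863,
    "abs_target_sum": 145558863.59960938,
    "abs_target_mean": 7324.822041043147,
    "QuantileLoss[0.1]": 6737793.847027113,
    "Coverage[0.1]": 0.08489331723027377,
    "QuantileLoss[0.5]": 11468352.941949368,
    "Coverage[0.5]": 0.4316123188405798,
    "QuantileLoss[0.9]": 7381112.835865307,
    "Coverage[0.9]": 0.8621678743961353,
}
for name, value in derive_aggregate_metrics(agg, [0.1, 0.5, 0.9]).items():
    print(name, value)
```

Dividing by abs_target_sum or abs_target_mean makes ND, NRMSE and the weighted quantile losses comparable across datasets with different scales, unlike MSE or abs_error.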
Individual metrics are aggregated only across time-steps.

In [29]:
item_metrics.head()
item_id MSE abs_error abs_target_sum abs_target_mean seasonal_error MASE MAPE sMAPE MSIS QuantileLoss[0.1] Coverage[0.1] QuantileLoss[0.5] Coverage[0.5] QuantileLoss[0.9] Coverage[0.9]
0 0.0 5015.815755 2708.087402 31644.0 659.250000 42.371302 1.331526 0.087151 0.087506 15.535559 1390.559863 0.000000 2708.087372 0.562500 1532.351172 1.000000
1 1.0 163256.854167 16005.593750 124149.0 2586.437500 165.107988 2.019587 0.129468 0.119604 16.714720 4307.569458 0.062500 16005.594116 0.895833 8736.129932 1.000000
2 2.0 54241.260417 8714.230469 65030.0 1354.791667 78.889053 2.301288 0.125179 0.136120 15.340832 4233.171997 0.000000 8714.231445 0.187500 2988.941272 0.729167
3 3.0 351751.333333 24109.828125 235783.0 4912.145833 258.982249 1.939469 0.102611 0.106455 16.602951 11793.812988 0.000000 24109.827148 0.416667 8335.885547 0.937500
4 4.0 114743.125000 11790.869141 131088.0 2731.000000 200.494083 1.225189 0.091969 0.090235 16.401613 6137.071960 0.020833 11790.868408 0.645833 7753.793896 1.000000
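The per-series columns follow standard forecasting-metric definitions. For instance, in row 0, abs_error / 48 / seasonal_error = 2708.087 / 48 / 42.371 ≈ 1.3315, matching MASE, and QuantileLoss[0.5] equals abs_error up to rounding because the 0.5-quantile loss (scaled by 2) reduces to the absolute error of the median. The sketch below computes these metrics for one series with plain Python; item_metrics_for_series is a hypothetical helper that mirrors the definitions as we understand them, not GluonTS's implementation, and season_length is left as an explicit parameter.

```python
def item_metrics_for_series(history, future, median, quantile_forecasts, season_length):
    """Hypothetical helper: metrics for one series over one forecast window.

    history: in-sample values before the window (for the seasonal error)
    future: true values in the forecast window
    median: point forecast per step (the 0.5-quantile)
    quantile_forecasts: {q: forecast of the q-quantile per step}
    season_length: seasonal period used for the naive seasonal error
    Assumes no zero targets (MAPE and sMAPE divide by them) and a history
    longer than season_length with at least one nonzero seasonal difference.
    """
    h = len(future)
    errors = [y - f for y, f in zip(future, median)]
    abs_error = sum(abs(e) for e in errors)
    # mean absolute difference between each value and the value one season earlier
    diffs = [abs(history[t] - history[t - season_length]) for t in range(season_length, len(history))]
    seasonal_error = sum(diffs) / len(diffs)
    metrics = {
        "MSE": sum(e * e for e in errors) / h,
        "abs_error": abs_error,
        "abs_target_sum": sum(abs(y) for y in future),
        "MASE": (abs_error / h) / seasonal_error,
        "MAPE": sum(abs(e) / abs(y) for e, y in zip(errors, future)) / h,
        "sMAPE": 2 * sum(abs(e) / (abs(y) + abs(f)) for e, y, f in zip(errors, future, median)) / h,
    }
    for q, qf in quantile_forecasts.items():
        # pinball loss scaled by 2, so that the 0.5-quantile loss equals abs_error
        metrics[f"QuantileLoss[{q}]"] = 2 * sum(
            abs((y - f) * (float(y <= f) - q)) for y, f in zip(future, qf)
        )
        # fraction of steps where the true value falls below the quantile forecast
        metrics[f"Coverage[{q}]"] = sum(y < f for y, f in zip(future, qf)) / h
    return metrics


m = item_metrics_for_series(
    history=[1.0, 2.0, 3.0, 4.0],
    future=[2.0, 4.0],
    median=[1.0, 5.0],
    quantile_forecasts={0.5: [1.0, 5.0], 0.9: [3.0, 6.0]},
    season_length=1,
)
print(m["MASE"], m["QuantileLoss[0.5]"])  # 1.0 2.0
```

A well-calibrated model should give Coverage[q] close to q on average; the aggregate Coverage[0.5] of about 0.43 above suggests the median forecasts tend to sit slightly low.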
In [30]:
item_metrics.plot(x='MSIS', y='MASE', kind='scatter')