Quick Start Tutorial¶
The GluonTS toolkit contains components and tools for building time series models using MXNet. The models that are currently included are forecasting models but the components also support other time series use cases, such as classification or anomaly detection.
The toolkit is not intended as a forecasting solution for businesses or end users; rather, it targets scientists and engineers who want to tweak algorithms or build and experiment with their own models.
GluonTS contains:
Components for building new models (likelihoods, feature processing pipelines, calendar features etc.)
Data loading and processing
A number of pre-built models
Plotting and evaluation facilities
Artificial and real datasets (only external datasets with a permissive license)
In [1]:
%matplotlib inline
import mxnet as mx
from mxnet import gluon
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import json
Datasets¶
Provided datasets¶
GluonTS comes with a number of publicly available datasets.
In [2]:
from gluonts.dataset.repository.datasets import get_dataset, dataset_recipes
from gluonts.dataset.util import to_pandas
In [3]:
print(f"Available datasets: {list(dataset_recipes.keys())}")
Available datasets: ['constant', 'exchange_rate', 'solar-energy', 'electricity', 'traffic', 'exchange_rate_nips', 'electricity_nips', 'traffic_nips', 'solar_nips', 'wiki-rolling_nips', 'taxi_30min', 'm3_monthly', 'm3_quarterly', 'm3_yearly', 'm3_other', 'm4_hourly', 'm4_daily', 'm4_weekly', 'm4_monthly', 'm4_quarterly', 'm4_yearly', 'm5']
To download one of the built-in datasets, simply call get_dataset with one of the above names. GluonTS can re-use the saved dataset so that it does not need to be downloaded again: simply set regenerate=False.
In [4]:
dataset = get_dataset("m4_hourly", regenerate=True)
saving time-series into /home/runner/.mxnet/gluon-ts/datasets/m4_hourly/train/data.json
saving time-series into /home/runner/.mxnet/gluon-ts/datasets/m4_hourly/test/data.json
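On subsequent runs, the saved copy can be loaded from disk instead of being downloaded again; a minimal sketch of the same call with regenerate=False:
# re-use the dataset that was saved above instead of downloading it again
dataset = get_dataset("m4_hourly", regenerate=False)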
In general, the datasets provided by GluonTS are objects that consist of three main members:
dataset.train is an iterable collection of data entries used for training. Each entry corresponds to one time series.
dataset.test is an iterable collection of data entries used for inference. The test dataset is an extended version of the train dataset that contains a window at the end of each time series that was not seen during training. This window has length equal to the recommended prediction length.
dataset.metadata contains metadata of the dataset such as the frequency of the time series, a recommended prediction horizon, associated features, etc.
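Each data entry behaves like a Python dictionary, so a quick way to get a feel for the format is to print its fields directly (a minimal sketch; the exact set of keys can vary by dataset):
first_entry = next(iter(dataset.train))
print(first_entry.keys())           # typically includes "target" and "start"
print(first_entry["start"])         # timestamp of the first observation
print(first_entry["target"].shape)  # length of this time series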
In [5]:
entry = next(iter(dataset.train))
train_series = to_pandas(entry)
train_series.plot()
plt.grid(which="both")
plt.legend(["train series"], loc="upper left")
plt.show()

In [6]:
entry = next(iter(dataset.test))
test_series = to_pandas(entry)
test_series.plot()
plt.axvline(train_series.index[-1], color='r') # end of train dataset
plt.grid(which="both")
plt.legend(["test series", "end of train series"], loc="upper left")
plt.show()

In [7]:
print(f"Length of forecasting window in test dataset: {len(test_series) - len(train_series)}")
print(f"Recommended prediction horizon: {dataset.metadata.prediction_length}")
print(f"Frequency of the time series: {dataset.metadata.freq}")
Length of forecasting window in test dataset: 48
Recommended prediction horizon: 48
Frequency of the time series: H
Custom datasets¶
At this point, it is important to emphasize that GluonTS does not require this specific format for a custom dataset that a user may have. The only requirements for a custom dataset are to be iterable and to have a “target” and a “start” field. To make this clearer, assume the common case where a dataset is in the form of a numpy.array and the start index of each time series is a pandas.Timestamp (possibly different for each time series):
In [8]:
N = 10 # number of time series
T = 100 # number of timesteps
prediction_length = 24
freq = "1H"
custom_dataset = np.random.normal(size=(N, T))
start = pd.Timestamp("01-01-2019", freq=freq) # can be different for each time series
Now, you can split your dataset and bring it in a GluonTS appropriate format with just two lines of code:
In [9]:
from gluonts.dataset.common import ListDataset
In [10]:
# train dataset: cut the last window of length "prediction_length", add "target" and "start" fields
train_ds = ListDataset(
    [{'target': x, 'start': start} for x in custom_dataset[:, :-prediction_length]],
    freq=freq
)
# test dataset: use the whole dataset, add "target" and "start" fields
test_ds = ListDataset(
    [{'target': x, 'start': start} for x in custom_dataset],
    freq=freq
)
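As a quick sanity check (a minimal sketch), the entries of the resulting datasets carry the fields we just set, and each training series is shorter than its test counterpart by exactly prediction_length:
train_entry = next(iter(train_ds))
test_entry = next(iter(test_ds))
print(len(train_entry["target"]))  # T - prediction_length = 76
print(len(test_entry["target"]))   # T = 100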
Training an existing model (Estimator)¶
GluonTS comes with a number of pre-built models. All the user needs to do is configure some hyperparameters. The existing models focus on (but are not limited to) probabilistic forecasting. Probabilistic forecasts are predictions in the form of a probability distribution, rather than simply a single point estimate.
We will begin with GluonTS’s pre-built feedforward neural network estimator, a simple but powerful forecasting model. We will use this model to demonstrate the process of training a model, producing forecasts, and evaluating the results.
GluonTS’s built-in feedforward neural network (SimpleFeedForwardEstimator) accepts an input window of length context_length and predicts the distribution of the subsequent prediction_length values. In GluonTS parlance, the feedforward neural network model is an example of an Estimator. In GluonTS, Estimator objects represent a forecasting model as well as details such as its coefficients, weights, etc.
In general, each estimator (pre-built or custom) is configured by a number of hyperparameters that can be either common (but not binding) among all estimators (e.g., the prediction_length) or specific to the particular estimator (e.g., the number of layers for a neural network or the stride in a CNN).
Finally, each estimator is configured by a Trainer, which defines how the model will be trained, i.e., the number of epochs, the learning rate, etc.
In [11]:
from gluonts.model.simple_feedforward import SimpleFeedForwardEstimator
from gluonts.mx.trainer import Trainer
In [12]:
estimator = SimpleFeedForwardEstimator(
    num_hidden_dimensions=[10],
    prediction_length=dataset.metadata.prediction_length,
    context_length=100,
    freq=dataset.metadata.freq,
    trainer=Trainer(
        ctx="cpu",
        epochs=5,
        learning_rate=1e-3,
        num_batches_per_epoch=100
    )
)
After specifying our estimator with all the necessary hyperparameters, we can train it using our training dataset dataset.train by invoking the train method of the estimator. The training algorithm returns a fitted model (or a Predictor in GluonTS parlance) that can be used to construct forecasts.
In [13]:
predictor = estimator.train(dataset.train)
0%| | 0/100 [00:00<?, ?it/s]
learning rate from ``lr_scheduler`` has been overwritten by ``learning_rate`` in optimizer.
100%|██████████| 100/100 [00:00<00:00, 104.53it/s, epoch=1/5, avg_epoch_loss=5.57]
100%|██████████| 100/100 [00:00<00:00, 115.73it/s, epoch=2/5, avg_epoch_loss=4.89]
100%|██████████| 100/100 [00:00<00:00, 112.88it/s, epoch=3/5, avg_epoch_loss=4.83]
100%|██████████| 100/100 [00:00<00:00, 112.68it/s, epoch=4/5, avg_epoch_loss=4.82]
100%|██████████| 100/100 [00:00<00:00, 114.48it/s, epoch=5/5, avg_epoch_loss=4.63]
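If the fitted model should be re-used later without retraining, GluonTS predictors can be written to and restored from a directory. The snippet below is a sketch: the target path is an arbitrary placeholder.
from pathlib import Path
from gluonts.model.predictor import Predictor

model_path = Path("/tmp/feedforward_predictor")  # placeholder location
model_path.mkdir(parents=True, exist_ok=True)

predictor.serialize(model_path)                # write the trained predictor to disk
predictor = Predictor.deserialize(model_path)  # load it back later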
Visualize and evaluate forecasts¶
With a predictor in hand, we can now predict the last window of dataset.test and evaluate our model’s performance.
GluonTS comes with the make_evaluation_predictions function that automates the process of prediction and model evaluation. Roughly, this function performs the following steps:
Removes the final window of length prediction_length of dataset.test that we want to predict
The estimator uses the remaining data to predict (in the form of sample paths) the “future” window that was just removed
The module outputs the forecast sample paths and dataset.test (as python generator objects)
In [14]:
from gluonts.evaluation import make_evaluation_predictions
In [15]:
forecast_it, ts_it = make_evaluation_predictions(
dataset=dataset.test, # test dataset
predictor=predictor, # predictor
num_samples=100, # number of sample paths we want for evaluation
)
First, we can convert these generators to lists to ease the subsequent computations.
In [16]:
forecasts = list(forecast_it)
tss = list(ts_it)
We can examine the first element of these lists (that corresponds to the first time series of the dataset). Let’s start with the list containing the time series, i.e., tss. We expect the first entry of tss to contain the (target of the) first time series of dataset.test.
In [17]:
# first entry of the time series list
ts_entry = tss[0]
In [18]:
# first 5 values of the time series (convert from pandas to numpy)
np.array(ts_entry[:5]).reshape(-1,)
Out[18]:
array([605., 586., 586., 559., 511.], dtype=float32)
In [19]:
# first entry of dataset.test
dataset_test_entry = next(iter(dataset.test))
In [20]:
# first 5 values
dataset_test_entry['target'][:5]
Out[20]:
array([605., 586., 586., 559., 511.], dtype=float32)
The entries in the forecast list are a bit more complex. They are objects that contain all the sample paths in the form of numpy.ndarray with dimension (num_samples, prediction_length), the start date of the forecast, the frequency of the time series, etc. We can access all this information by simply invoking the corresponding attribute of the forecast object.
In [21]:
# first entry of the forecast list
forecast_entry = forecasts[0]
In [22]:
print(f"Number of sample paths: {forecast_entry.num_samples}")
print(f"Dimension of samples: {forecast_entry.samples.shape}")
print(f"Start date of the forecast window: {forecast_entry.start_date}")
print(f"Frequency of the time series: {forecast_entry.freq}")
Number of sample paths: 100
Dimension of samples: (100, 48)
Start date of the forecast window: 1750-01-30 04:00:00
Frequency of the time series: H
We can also do calculations to summarize the sample paths, such as computing the mean or a quantile for each of the 48 time steps in the forecast window.
In [23]:
print(f"Mean of the future window:\n {forecast_entry.mean}")
print(f"0.5-quantile (median) of the future window:\n {forecast_entry.quantile(0.5)}")
Mean of the future window:
[657.432 567.8759 509.03964 462.90637 533.49255 474.97253 449.83563
489.14816 520.5916 556.1061 586.52 689.13794 743.9377 755.9347
846.90674 861.06946 884.1995 844.0505 858.2082 851.92267 804.64655
809.32513 808.19763 698.4341 628.41486 571.8308 545.1014 489.84125
533.61194 497.71265 527.5373 463.3403 488.81894 553.01324 617.9387
693.801 704.50995 802.0996 779.80225 842.65845 859.039 937.10315
886.0708 741.0214 912.11237 926.66284 805.56836 730.22266]
0.5-quantile (median) of the future window:
[669.59045 575.02576 531.0032 464.3974 525.8405 476.6602 458.5382
489.29544 519.2877 542.18964 585.2373 681.03265 763.4545 773.0337
843.76996 875.1853 872.0016 834.8163 849.573 841.62476 808.5004
801.8551 802.89844 679.6511 631.09546 573.2768 555.4188 487.28363
518.59454 508.8435 502.41446 456.2519 490.46066 551.9699 612.1071
685.0114 684.63007 789.92535 762.7166 859.23254 848.2905 952.18243
890.4299 755.653 881.6748 933.1342 800.9058 744.48395]
Forecast objects have a plot method that can summarize the forecast paths as the mean, prediction intervals, etc. The prediction intervals are shaded in different colors as a “fan chart”.
In [24]:
def plot_prob_forecasts(ts_entry, forecast_entry):
    plot_length = 150
    prediction_intervals = (50.0, 90.0)
    legend = ["observations", "median prediction"] + [f"{k}% prediction interval" for k in prediction_intervals][::-1]
    fig, ax = plt.subplots(1, 1, figsize=(10, 7))
    ts_entry[-plot_length:].plot(ax=ax)  # plot the time series
    forecast_entry.plot(prediction_intervals=prediction_intervals, color='g')
    plt.grid(which="both")
    plt.legend(legend, loc="upper left")
    plt.show()
In [25]:
plot_prob_forecasts(ts_entry, forecast_entry)

We can also evaluate the quality of our forecasts numerically. In GluonTS, the Evaluator class can compute aggregate performance metrics, as well as metrics per time series (which can be useful for analyzing performance across heterogeneous time series).
In [26]:
from gluonts.evaluation import Evaluator
In [27]:
evaluator = Evaluator(quantiles=[0.1, 0.5, 0.9])
agg_metrics, item_metrics = evaluator(iter(tss), iter(forecasts), num_series=len(dataset.test))
Running evaluation: 100%|██████████| 414/414 [00:00<00:00, 19917.66it/s]
Aggregate metrics are aggregated both across time-steps and across time series.
In [28]:
print(json.dumps(agg_metrics, indent=4))
{
"MSE": 9656613.653655292,
"abs_error": 8926028.686189651,
"abs_target_sum": 145558863.59960938,
"abs_target_mean": 7324.822041043146,
"seasonal_error": 336.9046924038305,
"MASE": 3.3882664642129723,
"MAPE": 0.24571015283273423,
"sMAPE": 0.1852512300890329,
"OWA": NaN,
"MSIS": 63.97941756142848,
"QuantileLoss[0.1]": 5090296.274248887,
"Coverage[0.1]": 0.09686996779388084,
"QuantileLoss[0.5]": 8926028.815477371,
"Coverage[0.5]": 0.46804549114331717,
"QuantileLoss[0.9]": 6892946.789128493,
"Coverage[0.9]": 0.8796296296296295,
"RMSE": 3107.509236294446,
"NRMSE": 0.4242436497272087,
"ND": 0.06132246752587044,
"wQuantileLoss[0.1]": 0.034970706340843864,
"wQuantileLoss[0.5]": 0.0613224684140865,
"wQuantileLoss[0.9]": 0.047355046739640735,
"mean_absolute_QuantileLoss": 6969757.292951584,
"mean_wQuantileLoss": 0.047882740498190364,
"MAE_Coverage": 0.01848497047772416
}
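Since agg_metrics is a plain dictionary, individual values can be picked out directly, e.g. to log a few headline numbers (a minimal sketch using keys from the output above):
print(f"MASE: {agg_metrics['MASE']}")
print(f"sMAPE: {agg_metrics['sMAPE']}")
print(f"mean_wQuantileLoss: {agg_metrics['mean_wQuantileLoss']}")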
Individual metrics are aggregated only across time-steps.
In [29]:
item_metrics.head()
Out[29]:
|   | item_id | MSE | abs_error | abs_target_sum | abs_target_mean | seasonal_error | MASE | MAPE | sMAPE | OWA | MSIS | QuantileLoss[0.1] | Coverage[0.1] | QuantileLoss[0.5] | Coverage[0.5] | QuantileLoss[0.9] | Coverage[0.9] |
|---|---------|-----|-----------|----------------|-----------------|----------------|------|------|-------|-----|------|-------------------|---------------|-------------------|---------------|-------------------|---------------|
| 0 | 0.0 | 2997.847005 | 1756.685791 | 31644.0 | 659.250000 | 42.371302 | 0.863736 | 0.054706 | 0.052378 | NaN | 13.179705 | 972.442920 | 0.020833 | 1756.685730 | 0.666667 | 1419.121350 | 1.000000 |
| 1 | 1.0 | 155417.968750 | 16964.033203 | 124149.0 | 2586.437500 | 165.107988 | 2.140522 | 0.142941 | 0.131103 | NaN | 13.879598 | 3420.709351 | 0.166667 | 16964.032104 | 0.979167 | 8374.049316 | 1.000000 |
| 2 | 2.0 | 29935.947917 | 6447.917969 | 65030.0 | 1354.791667 | 78.889053 | 1.702792 | 0.089015 | 0.095048 | NaN | 13.872233 | 3473.809839 | 0.000000 | 6447.917603 | 0.145833 | 1539.831445 | 0.833333 |
| 3 | 3.0 | 292266.479167 | 19561.388672 | 235783.0 | 4912.145833 | 258.982249 | 1.573579 | 0.080606 | 0.080400 | NaN | 14.879645 | 10134.517090 | 0.041667 | 19561.388916 | 0.416667 | 8175.906152 | 0.979167 |
| 4 | 4.0 | 100070.489583 | 9540.142578 | 131088.0 | 2731.000000 | 200.494083 | 0.991316 | 0.065609 | 0.061433 | NaN | 13.176000 | 4892.513110 | 0.062500 | 9540.142822 | 0.812500 | 7150.268848 | 1.000000 |
In [30]:
item_metrics.plot(x='MSIS', y='MASE', kind='scatter')
plt.grid(which="both")
plt.show()
