# Source code for gluonts.nursery.few_shot_prediction.src.meta.datasets.datasets

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import itertools
from sklearn.model_selection import train_test_split
from math import isclose
import numpy as np


# Frequency tags used when building the M1 and M3 dataset names.
FREQ_M1_M3 = ["monthly", "quarterly", "yearly"]
# Frequency tags used when building the M4 dataset names.
FREQ_M4 = ["daily", "weekly", "monthly", "quarterly", "yearly"]

# Category tags used when building the M1 dataset names.
CATEGORIES_M1 = [
    "micro1",
    "micro2",
    "micro3",
    "indust",
    "macro1",
    "macro2",
    "demogr",
]

# Category tags shared by the M3 and M4 dataset names.
CATEGORIES_M3_M4 = [
    "micro",
    "industry",
    "macro",
    "finance",
    "demographic",
]


# Dataset names follow the pattern "<prefix>_<frequency>_<category>"; each
# list enumerates every frequency/category combination, with the frequency
# varying slowest.
M1 = [
    f"m1_{period}_{group}"
    for period, group in itertools.product(FREQ_M1_M3, CATEGORIES_M1)
]
M3 = [
    f"m3_{period}_{group}"
    for period, group in itertools.product(FREQ_M1_M3, CATEGORIES_M3_M4)
]
M4 = [
    f"m4_{period}_{group}"
    for period, group in itertools.product(FREQ_M4, CATEGORIES_M3_M4)
]

# Datasets that are listed individually by name. Commented-out names are
# currently disabled; where a reason was recorded it is kept next to the entry.
DATASETS_SINGLE = [
    # "exchange_rate",
    "exchange_rate_nips",
    # "solar-energy",
    "solar_nips",
    # "electricity",
    "electricity_nips",
    # "traffic",
    "traffic_nips",
    "wiki-rolling_nips",  # a very challenging dataset
    "taxi_30min",
    "kaggle_web_traffic_without_missing",
    "kaggle_web_traffic_weekly",
    # NN5 (ATM cash withdrawals):
    # http://www.neural-forecasting-competition.com/NN5/
    "nn5_daily_without_missing",
    "nn5_weekly",
    # Tourism: no need to split, see
    # https://robjhyndman.com/papers/forecompijf.pdf
    "tourism_monthly",
    "tourism_quarterly",
    "tourism_yearly",
    # CIF 2016: 72 time series (24 real, 48 artificial), all monthly,
    # banking domain.
    # https://irafm.osu.cz/cif/main.php?c=Static&page=download
    # https://zenodo.org/record/3904073#.YeE6ZMYo8UE
    "cif_2016",
    "london_smart_meters_without_missing",
    "wind_farms_without_missing",
    # "car_parts_without_missing",  # intermittent time series
    # "dominick",
    # Dominick contains 115704 weekly time series representing the profit of
    # individual stock keeping units from a retailer.
    # https://www.chicagobooth.edu/research/kilts/datasets/dominicks
    # It is disabled because it has intermittency and other problems.
    # TODO: Could this be clustered? (as step 2, it is already from one
    # category only)
    # FRED-MD (possibly too mixed up?):
    # https://s3.amazonaws.com/real.stlouisfed.org/wp/2015/2015-012.pdf
    "fred_md",
    # Hourly pedestrian counts captured from 66 sensors in Melbourne city,
    # starting from May 2009.
    "pedestrian_counts",
    # 767 monthly time series that represent the patient counts related to
    # medical products from January 2000 to December 2006; extracted from the
    # R expsmooth package.
    # https://zenodo.org/record/4656014#.YeFBT8Yo8UE
    "hospital",
    # 266 daily time series that represent the COVID-19 deaths in a set of
    # countries and states from 22/01/2020 to 20/08/2020; extracted from the
    # Johns Hopkins repository.
    # https://zenodo.org/record/4656009#.YeFBq8Yo8UE
    "covid_deaths",
    # Used in the KDD Cup 2018 forecasting competition. Contains long hourly
    # time series representing the air quality levels in 59 stations in
    # 2 cities: Beijing (35 stations) and London (24 stations), from
    # 01/01/2017 to 31/03/2018.
    # https://zenodo.org/record/4656756#.YeFB7cYo8UE
    # TODO: could be clustered by cities
    "kdd_cup_2018_without_missing",
    # 3010 daily time series representing the variations of four weather
    # variables (rain, mintemp, maxtemp and solar radiation), measured at
    # weather stations in Australia.
    # https://zenodo.org/record/4654822#.YeFCXcYo8UE
    # TODO: could be clustered by variable (as step 2, it is already from one
    # category only)
    "weather",
    # "m5",
    # M5 is disabled: intermittent time series, Walmart product data; could
    # be split by product or store.
    # See https://mofc.unic.ac.cy/m5-competition/
]

# All dataset names: the individually listed ones followed by the generated
# M1, M3 and M4 names.
DATASETS_FULL = [*DATASETS_SINGLE, *M1, *M3, *M4]

# The datasets excluded here contain only time series that are shorter than
# three prediction lengths. In the current setting such time series are
# filtered out, so nothing would remain of these datasets.
DATASETS_FILTERED = [
    name
    for name in DATASETS_FULL
    if name not in ("m1_yearly_macro2", "m3_yearly_micro", "m3_yearly_macro")
]


def sample_datasets(
    n_folds: int,
    train_size: float = 0.7,
    val_size: float = 0.2,
    test_size: float = 0.1,
    seed=42,
):
    """
    Draw ``n_folds`` random train/validation/test partitions of the names in
    ``DATASETS_FILTERED``.

    Each fold is produced by two successive calls to ``train_test_split``:
    the first separates the training names from the remainder, the second
    divides that remainder into validation and test names.

    Parameters
    ----------
    n_folds
        Number of partitions to generate.
    train_size
        Fraction of all names assigned to the training split.
    val_size
        Fraction intended for the validation split. It is only used in the
        sanity check on the three sizes; the validation split receives
        whatever is left of the remainder after the test names are taken.
    test_size
        Fraction of all names assigned to the test split. It is rescaled to a
        fraction of the non-training remainder before the second split.
    seed
        Seed of the ``numpy.random.RandomState`` from which the integer seeds
        of the individual splits are drawn.

    Returns
    -------
    list
        One ``(train, val, test)`` tuple per fold, each element being the
        collection of dataset names returned by ``train_test_split``.

    Raises
    ------
    AssertionError
        If the three sizes do not add up to 1, or if the splits of a fold
        share a name.
    """
    total = train_size + val_size + test_size
    assert isclose(total, 1.0), "sizes need to add up to 1"

    rng = np.random.RandomState(seed)
    # Fraction of all names that is left over once the training part is taken.
    holdout_fraction = 1 - train_size

    result = []
    for _fold_index in range(n_folds):
        # One fresh integer seed is drawn per split, training split first.
        outer_seed = rng.randint(low=0, high=10000)
        train_names, holdout_names = train_test_split(
            DATASETS_FILTERED,
            train_size=train_size,
            random_state=outer_seed,
        )

        inner_seed = rng.randint(low=0, high=10000)
        val_names, test_names = train_test_split(
            holdout_names,
            # Express the test share relative to the holdout part only.
            test_size=test_size / holdout_fraction,
            random_state=inner_seed,
        )
        result.append((train_names, val_names, test_names))

        # Sanity check: no name may appear in more than one split.
        train_set = set(train_names)
        val_set = set(val_names)
        test_set = set(test_names)
        pairwise_disjoint = (
            train_set.isdisjoint(val_set)
            and train_set.isdisjoint(test_set)
            and val_set.isdisjoint(test_set)
        )
        assert pairwise_disjoint, "Splits should not intersect!"

    return result