# Source code for gluonts.dataset.artificial.generate_synthetic

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import csv
import json
import os
from typing import List, TextIO

import holidays
import pandas as pd

from gluonts.dataset.artificial._base import (
    ArtificialDataset,
    ComplexSeasonalTimeSeries,
    ConstantDataset,
)
from gluonts.dataset.field_names import FieldName

def write_csv_row(
    time_series: List,
    freq: str,
    csv_file: TextIO,
    is_missing: bool,
    num_missing: int,
) -> None:
    """Write the observations of every series in ``time_series`` as CSV rows.

    One row is emitted per observation, laid out as
    ``[item_id, timestamp, target, *dynamic_features]``. The dynamic features
    are the ``j``-th value of each series stored under
    ``FieldName.FEAT_DYNAMIC_REAL``; they are appended only when that field is
    present in the entry.

    Parameters
    ----------
    time_series
        Iterable of mapping-like entries providing ``FieldName.START``,
        ``FieldName.ITEM_ID`` (convertible with ``int``) and
        ``FieldName.TARGET``, and optionally ``FieldName.FEAT_DYNAMIC_REAL``.
    freq
        Pandas frequency string used to advance the timestamp by one step per
        observation. A plain ``"W"`` is anchored on the weekday of the entry's
        start date. For ``"W"``, ``"D"`` and ``"M"`` only the calendar date is
        written; any other frequency writes the full timestamp.
    csv_file
        Open, writable text handle that receives the rows.
    is_missing
        When true, every observation whose index is a non-zero multiple of
        ``num_missing`` is left out of the file. The timestamp still advances
        over the skipped observation.
    num_missing
        Spacing of the skipped observations. It is used as a modulus, so it
        must be non-zero whenever ``is_missing`` is true.
    """
    csv_writer = csv.writer(csv_file)
    # pandas numbers weekdays MON == 0, ..., SUN == 6
    week_dict = {
        0: "MON",
        1: "TUE",
        2: "WED",
        3: "THU",
        4: "FRI",
        5: "SAT",
        6: "SUN",
    }
    # Frequencies coarse enough that the time of day carries no information.
    date_only = freq in ("W", "D", "M")

    for data in time_series:
        timestamp = pd.Timestamp(data[FieldName.START])
        freq_week_start = freq
        if freq_week_start == "W":
            # Anchor the weekly frequency on the start date's own weekday so
            # that each step lands exactly seven days later.
            freq_week_start = f"W-{week_dict[timestamp.weekday()]}"
        # Explicit offset instead of ``Timestamp(..., freq=...)`` followed by
        # ``+= 1``: current pandas releases support neither of those.
        step = pd.tseries.frequencies.to_offset(freq_week_start)
        item_id = int(data[FieldName.ITEM_ID])

        # Related time series are optional.
        if FieldName.FEAT_DYNAMIC_REAL in data:
            dynamic_features = data[FieldName.FEAT_DYNAMIC_REAL]
        else:
            dynamic_features = []

        for j, target in enumerate(data[FieldName.TARGET]):
            # Convention: there are no missing values before the start date,
            # hence index 0 is never skipped.
            is_gap = is_missing and j != 0 and j % num_missing == 0
            if not is_gap:
                timestamp_row = timestamp.date() if date_only else timestamp
                row = [item_id, timestamp_row, target]
                row.extend(feature[j] for feature in dynamic_features)
                csv_writer.writerow(row)
            # Time moves forward whether or not the observation was written.
            timestamp += step
def generate_sf2(
    filename: str, time_series: List, is_missing: bool, num_missing: int
) -> None:
    """Dump ``time_series`` to ``filename`` as one JSON object per line.

    These are the train/test JSON files that are later converted to CSV.

    Parameters
    ----------
    filename
        Destination path. Missing parent directories are created. Dynamic
        features are truncated to the target length only when the path
        contains the substring ``"train"``.
    time_series
        Iterable of mapping entries. Each entry is serialized without
        ``FieldName.FEAT_STATIC_CAT`` and ``FieldName.FEAT_STATIC_REAL``.
        The entries themselves are not modified.
    is_missing
        When true, every target value whose index is a non-zero multiple of
        ``num_missing`` is written as ``null``.
    num_missing
        Spacing of the nulled values. It is used as a modulus, so it must be
        non-zero whenever ``is_missing`` is true.
    """
    directory = os.path.dirname(filename)
    # A bare file name has an empty dirname and ``os.makedirs("")`` raises.
    # ``exist_ok`` also removes the check-then-create race.
    if directory:
        os.makedirs(directory, exist_ok=True)

    with open(filename, "w") as json_file:
        for ts in time_series:
            # Shallow copy: the caller's entry keeps its target, static
            # features and dynamic features exactly as they were passed in.
            record = dict(ts)

            if is_missing:
                # Index 0 is always kept; later multiples of num_missing are
                # marked as missing.
                record[FieldName.TARGET] = [
                    None if j != 0 and j % num_missing == 0 else val
                    for j, val in enumerate(record[FieldName.TARGET])
                ]

            # For Forecast don't output feat_static_cat and feat_static_real
            record.pop(FieldName.FEAT_STATIC_CAT, None)
            record.pop(FieldName.FEAT_STATIC_REAL, None)

            # Chop features in training set
            if FieldName.FEAT_DYNAMIC_REAL in record and "train" in filename:
                # TODO: Fix for missing values
                target_length = len(record[FieldName.TARGET])
                record[FieldName.FEAT_DYNAMIC_REAL] = [
                    feat_dynamic_real[:target_length]
                    for feat_dynamic_real in record[FieldName.FEAT_DYNAMIC_REAL]
                ]

            json.dump(record, json_file)
            json_file.write("\n")
def generate_sf2s_and_csv(
    file_path: str,
    folder_name: str,
    artificial_dataset: ArtificialDataset,
    is_missing: bool = False,
    num_missing: int = 4,
) -> None:
    """Write ``train.json``, ``test.json`` and ``input_to_forecast.csv``.

    The three files are placed under ``file_path + folder_name``.

    Parameters
    ----------
    file_path
        Prefix of the output location. It is concatenated with
        ``folder_name`` as a plain string, so separators must be part of the
        arguments.
    folder_name
        Suffix appended to ``file_path``; the file names are appended directly
        to the result.
    artificial_dataset
        Dataset whose ``metadata.freq``, ``train`` and ``test`` attributes
        are read.
    is_missing
        Forwarded to ``generate_sf2`` and ``write_csv_row``.
    num_missing
        Forwarded to ``generate_sf2`` and ``write_csv_row``.
    """
    # Plain concatenation on purpose: callers pass both pieces with their
    # trailing separators already in place.
    file_path += f"{folder_name}"

    directory = os.path.dirname(file_path)
    # An empty dirname means "current directory"; ``os.makedirs("")`` raises.
    # ``exist_ok`` also removes the check-then-create race.
    if directory:
        os.makedirs(directory, exist_ok=True)

    freq = artificial_dataset.metadata.freq

    train_set = artificial_dataset.train
    generate_sf2(file_path + "train.json", train_set, is_missing, num_missing)

    test_set = artificial_dataset.test
    generate_sf2(file_path + "test.json", test_set, is_missing, num_missing)

    # The csv module expects files opened with newline=""; without it extra
    # "\r" characters are written on platforms that use "\r\n" line endings.
    with open(
        file_path + "input_to_forecast.csv", "w", newline=""
    ) as csv_file:
        # Test set has training set with the additional values to predict
        write_csv_row(test_set, freq, csv_file, is_missing, num_missing)
if __name__ == "__main__":
    num_timeseries = 1
    file_path = "../../../datasets/synthetic/"

    # One entry per synthetic dataset, in generation order:
    #   (output folder, zero-argument dataset factory, is_missing)
    # Factories are used instead of ready-made instances so that each dataset
    # is still constructed immediately before its own files are written.
    jobs = [
        ("constant/", ConstantDataset, False),
        ("constant_missing/", ConstantDataset, True),
        (
            "constant_random/",
            lambda: ConstantDataset(is_random_constant=True),
            False,
        ),
        (
            "constant_one_ts/",
            lambda: ConstantDataset(
                num_timeseries=num_timeseries, is_random_constant=True
            ),
            False,
        ),
        (
            "constant_diff_scales/",
            lambda: ConstantDataset(is_different_scales=True),
            False,
        ),
        ("constant_noise/", lambda: ConstantDataset(is_noise=True), False),
        (
            "constant_linear_trend/",
            lambda: ConstantDataset(is_trend=True),
            False,
        ),
        (
            "constant_linear_trend_noise/",
            lambda: ConstantDataset(is_noise=True, is_trend=True),
            False,
        ),
        (
            "constant_noise_long/",
            lambda: ConstantDataset(is_noise=True, is_long=True),
            False,
        ),
        (
            "constant_noise_short/",
            lambda: ConstantDataset(is_noise=True, is_short=True),
            False,
        ),
        (
            "constant_diff_scales_noise/",
            lambda: ConstantDataset(is_noise=True, is_different_scales=True),
            False,
        ),
        (
            "constant_zeros_and_nans/",
            lambda: ConstantDataset(is_nan=True),
            False,
        ),
        (
            # Requires is_random_constant to be set to True
            "constant_piecewise/",
            lambda: ConstantDataset(
                is_piecewise=True, is_random_constant=True
            ),
            False,
        ),
        ("complex_seasonal_noise_scale/", ComplexSeasonalTimeSeries, False),
        (
            "complex_seasonal_noise/",
            lambda: ComplexSeasonalTimeSeries(is_scale=False),
            False,
        ),
        (
            "complex_seasonal/",
            lambda: ComplexSeasonalTimeSeries(is_scale=False, is_noise=False),
            False,
        ),
        (
            "complex_seasonal_missing/",
            lambda: ComplexSeasonalTimeSeries(proportion_missing_values=0.8),
            False,
        ),
        (
            "constant_missing_middle/",
            lambda: ConstantDataset(num_steps=500, num_missing_middle=100),
            False,
        ),
        (
            "complex_seasonal_random_start_dates_weekly/",
            lambda: ComplexSeasonalTimeSeries(
                freq_str="W",
                percentage_unique_timestamps=1,
                is_out_of_bounds_date=True,
            ),
            False,
        ),
        (
            "constant_promotions/",
            lambda: ConstantDataset(
                is_promotions=True,
                freq="M",
                start="2015-11-30",
                num_timeseries=100,
                num_steps=50,
            ),
            False,
        ),
        (
            "constant_holidays/",
            lambda: ConstantDataset(
                start="2017-07-01",
                freq="D",
                holidays=list(
                    holidays.UnitedStates(years=[2017, 2018]).keys()
                ),
                num_steps=365,
            ),
            False,
        ),
    ]

    for folder_name, build_dataset, with_gaps in jobs:
        generate_sf2s_and_csv(
            file_path, folder_name, build_dataset(), is_missing=with_gaps
        )