Source code for gluonts.shell.sagemaker.params

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.


from itertools import count
from typing import Any, Union

from toolz import valmap

from gluonts.core.serde import dump_json, load_json
from gluonts.itertools import batcher


def decode_sagemaker_parameter(value: str) -> Union[list, dict, str]:
    """
    Decode a single string value received through the SageMaker API.

    The SageMaker API transports every value as a string, so values that
    look like JSON arrays or objects are decoded eagerly here. Anything else
    (including integers such as `"1"`) is returned as a whitespace-stripped
    string and left for the pydantic models further down the pipeline.
    """
    stripped = value.strip()

    # TODO: is this the right way to do things?
    # What about fields which merely happen to match the bracket pattern
    # without being JSON?
    looks_like_list = stripped.startswith("[") and stripped.endswith("]")
    looks_like_dict = stripped.startswith("{") and stripped.endswith("}")

    if not (looks_like_list or looks_like_dict):
        return stripped

    return load_json(stripped)
def encode_sagemaker_parameter(value: Any) -> str:
    """
    Encode a single value as a string for the SageMaker API.

    Strings are passed through unchanged; every other value is serialized
    with `dump_json`.
    """
    if isinstance(value, str):
        return value

    return dump_json(value)
def decode_sagemaker_parameters(encoded_params: dict) -> dict:
    """
    Decode every value of a SageMaker parameters dictionary.

    All incoming values are expected to be strings; each one is passed
    through `decode_sagemaker_parameter`, keys are kept as they are.

    Example:

    >>> decode_sagemaker_parameters({
    ...     "foo": "[1, 2, 3]",
    ...     "bar": "hello"
    ... })
    {'foo': [1, 2, 3], 'bar': 'hello'}
    """
    return {
        name: decode_sagemaker_parameter(raw)
        for name, raw in encoded_params.items()
    }
def encode_sagemaker_parameters(decoded_params: dict) -> dict:
    """
    Encode every value of a parameters dictionary for the SageMaker API.

    Each value is passed through `encode_sagemaker_parameter`, so the
    resulting dictionary has the same keys and only string values.

    Example:

    >>> encode_sagemaker_parameters({
    ...     "foo": [1, 2, 3],
    ...     "bar": "hello"
    ... })
    {'foo': '[1, 2, 3]', 'bar': 'hello'}
    """
    return {
        name: encode_sagemaker_parameter(obj)
        for name, obj in decoded_params.items()
    }
def detrim_and_decode_sagemaker_parameters(trimmed_params: dict) -> dict:
    """
    Re-assemble trimmed SageMaker parameters and decode their values.

    The chunks are first joined back together with
    `detrim_sagemaker_parameters`; each re-assembled string is then decoded
    with `decode_sagemaker_parameter`.

    Example:

    >>> detrim_and_decode_sagemaker_parameters({
    ...     '_0_foo': '[1, ',
    ...     '_1_foo': '2, 3',
    ...     '_2_foo': ']',
    ...     '_0_bar': 'hell',
    ...     '_1_bar': 'o'
    ... })
    {'foo': [1, 2, 3], 'bar': 'hello'}
    """
    reassembled = detrim_sagemaker_parameters(trimmed_params)
    return {
        name: decode_sagemaker_parameter(raw)
        for name, raw in reassembled.items()
    }
def encode_and_trim_sagemaker_parameters(
    decoded_params: dict, max_len: int = 256
) -> dict:
    """
    Encode the values of a parameters dictionary as strings, then trim them
    to account for the SageMaker character size limit.

    Values are encoded with `encode_sagemaker_parameter`; the encoded
    dictionary is then handed to `trim_encoded_sagemaker_parameters`
    together with ``max_len``.

    Example:

    >>> encode_and_trim_sagemaker_parameters({
    ...     "foo": [1, 2, 3],
    ...     "bar": "hello"
    ... }, max_len = 4)
    {'_0_foo': '[1, ', '_1_foo': '2, 3', '_2_foo': ']', '_0_bar': 'hell', '_1_bar': 'o'}
    """
    encoded_params = {
        name: encode_sagemaker_parameter(obj)
        for name, obj in decoded_params.items()
    }
    return trim_encoded_sagemaker_parameters(encoded_params, max_len)
def trim_encoded_sagemaker_parameters(
    encoded_params: dict, max_len: int = 256
) -> dict:
    """
    Trim parameters that have already been encoded to a given max length.

    A value longer than ``max_len`` is split into consecutive chunks of at
    most ``max_len`` characters, stored under the keys ``_0_<key>``,
    ``_1_<key>``, ... in order. Values that already fit are kept under their
    original key.

    Parameters
    ----------
    encoded_params
        Dictionary mapping parameter names to their string-encoded values.
    max_len
        Maximum number of characters per stored value; must be at least 1.

    Returns
    -------
    dict
        A new dictionary with over-long values split into chunks.

    Raises
    ------
    ValueError
        If ``max_len`` is smaller than 1.

    Example:

    >>> trim_encoded_sagemaker_parameters({
    ...     'foo': '[1, 2, 3]',
    ...     'bar': 'hello'
    ... }, max_len = 4)
    {'_0_foo': '[1, ', '_1_foo': '2, 3', '_2_foo': ']', '_0_bar': 'hell', '_1_bar': 'o'}
    """
    # A chunk size below 1 cannot hold any character: splitting would either
    # produce no chunks at all (losing the value) or fail part-way through.
    if max_len < 1:
        raise ValueError(f"max_len must be at least 1, got {max_len}")

    trimmed_params = {}

    for key, value in encoded_params.items():
        if len(value) <= max_len:
            trimmed_params[key] = value
            continue

        # Slice the string directly instead of batching single characters
        # and joining them back together.
        for idx, start in enumerate(range(0, len(value), max_len)):
            trimmed_params[f"_{idx}_{key}"] = value[start : start + max_len]

    return trimmed_params
def detrim_sagemaker_parameters(trimmed_params: dict) -> dict:
    """
    De-trim parameters that have already been trimmed.

    Every key of the form ``_0_<name>`` marks the first chunk of a trimmed
    parameter. Its chunks ``_0_<name>``, ``_1_<name>``, ... are consumed in
    index order until one is missing, and their concatenation is stored
    under ``<name>``. All other entries are copied over unchanged. The input
    dictionary is not modified.

    Example:

    >>> detrim_sagemaker_parameters({
    ...     '_0_foo': '[1, ',
    ...     '_1_foo': '2, 3',
    ...     '_2_foo': ']',
    ...     '_0_bar': 'hell',
    ...     '_1_bar': 'o'
    ... })
    {'foo': '[1, 2, 3]', 'bar': 'hello'}
    """
    first_chunk_prefix = "_0_"

    # Work on a shallow copy so the caller's dictionary stays untouched.
    detrimmed_params = trimmed_params.copy()

    # Collect the names up front: the dictionary is mutated in the loop below.
    trimmed_param_names = [
        param[len(first_chunk_prefix) :]
        for param in detrimmed_params
        if param.startswith(first_chunk_prefix)
    ]

    for name in trimmed_param_names:
        parts = []
        for idx in count():
            part = detrimmed_params.pop(f"_{idx}_{name}", None)
            if part is None:
                # First missing index marks the end of this parameter.
                break
            parts.append(part)

        # Join once instead of growing a string chunk by chunk.
        detrimmed_params[name] = "".join(parts)

    return detrimmed_params