# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import functools
import inspect
import logging
from collections import OrderedDict
from functools import singledispatch
from typing import Any, Type, TypeVar
import numpy as np
from gluonts.core import fqname_for
from gluonts.exceptions import GluonTSHyperparametersError
from gluonts.pydantic import (
BaseConfig,
BaseModel,
ValidationError,
create_model,
)
# Type variable tying the class passed to ``from_hyperparameters`` to the
# type of the instance it returns.
A = TypeVar("A")

# Logger named after this module.
logger = logging.getLogger(name=__name__)
def from_hyperparameters(cls: Type[A], **hyperparameters) -> A:
    """
    Reflectively create an instance of ``cls`` from keyword hyperparameters.

    If ``cls.__init__`` carries a ``Model`` attribute (as attached by
    :func:`validated`), the hyperparameters are first passed through that
    model, and the resulting model instance's ``__dict__`` is forwarded to
    ``cls``. Otherwise the hyperparameters are forwarded to ``cls`` as given.

    Parameters
    ----------
    cls
        The type ``A`` of the component to be instantiated.
    hyperparameters
        Keyword arguments used as parameters to the component initializer.

    Returns
    -------
    A
        An instance of the given class.

    Raises
    ------
    GluonTSHyperparametersError
        Wraps a :class:`ValidationError` raised either while building the
        validation model or while constructing the instance.
    """
    validation_model = getattr(cls.__init__, "Model", None)
    try:
        if validation_model is None:
            return cls(**hyperparameters)  # type: ignore
        validated_kwargs = validation_model(**hyperparameters).__dict__
        return cls(**validated_kwargs)  # type: ignore
    except ValidationError as error:
        raise GluonTSHyperparametersError from error
@singledispatch
def equals(this: Any, that: Any) -> bool:
    """
    Check whether two objects of arbitrary type are structurally equal.

    Unless a specialization is registered, this delegates to
    :func:`equals_default_impl`. Being a :func:`functools.singledispatch`
    function, it selects specialized implementations based on the type of
    the *first* argument only; those implementations may apply criteria
    that differ from the default ones.

    Parameters
    ----------
    this, that
        The objects to compare.

    Returns
    -------
    bool
        ``True`` if ``this`` and ``that`` are structurally equal, ``False``
        otherwise.

    See Also
    --------
    equals_default_impl
        Default semantics of a structural equality check between two objects
        of arbitrary type.
    equals_representable_block
        Specialization for Gluon :class:`~mxnet.gluon.HybridBlock` input
        arguments.
    equals_parameter_dict
        Specialization for Gluon :class:`~mxnet.gluon.ParameterDict` input
        arguments.
    """
    return equals_default_impl(this=this, that=that)
def equals_default_impl(this: Any, that: Any) -> bool:
    """
    Default semantics of a structural equality check between two objects of
    arbitrary type.

    ``this`` and ``that`` are considered structurally equal if and only if:

    1. their exact types are the same, and
    2. if both carry recorded initializer arguments (``__init_args__``, or
       failing that ``__init_passed_kwargs__``), those recorded arguments are
       structurally equal according to :func:`equals`;
    3. otherwise, ``this == that`` holds (i.e. the objects' own ``__eq__``
       decides).

    Parameters
    ----------
    this, that
        The objects to compare.

    Returns
    -------
    bool
        ``True`` if ``this`` and ``that`` are structurally equal, ``False``
        otherwise.
    """
    if type(this) != type(that):
        return False

    # Prefer the arguments recorded by ``validated``; fall back to
    # ``__init_passed_kwargs__`` when only that attribute is present on both.
    for attr_name in ("__init_args__", "__init_passed_kwargs__"):
        if hasattr(this, attr_name) and hasattr(that, attr_name):
            return equals(getattr(this, attr_name), getattr(that, attr_name))

    return this == that
@equals.register(list)
def equals_list(this: list, that: list) -> bool:
    """
    Structural equality specialization for lists.

    ``this`` and ``that`` are equal if ``that`` is also a list of the same
    length and their elements are pairwise structurally equal according to
    :func:`equals`.

    Parameters
    ----------
    this
        The list whose type selected this specialization.
    that
        The object to compare against; may be of any type.

    Returns
    -------
    bool
        Whether ``this`` and ``that`` are structurally equal.
    """
    # Dispatch only looks at the type of ``this``, so ``that`` can be
    # anything. Reject non-lists up front: previously ``len(that)`` raised
    # TypeError for e.g. ``None``, and a tuple compared equal to a list,
    # contradicting the "types match" rule of ``equals_default_impl``.
    if not isinstance(that, list):
        return False
    if len(this) != len(that):
        return False
    return all(equals(x, y) for x, y in zip(this, that))
@equals.register(dict)
def equals_dict(this: dict, that: dict) -> bool:
    """
    Structural equality specialization for dictionaries.

    ``this`` and ``that`` are equal if ``that`` is also a dict, both have the
    same set of keys (insertion order is not compared), and the values
    stored under each key are structurally equal according to
    :func:`equals`.

    Parameters
    ----------
    this
        The dict whose type selected this specialization.
    that
        The object to compare against; may be of any type.

    Returns
    -------
    bool
        Whether ``this`` and ``that`` are structurally equal.
    """
    # Dispatch only looks at the type of ``this``; without this guard a
    # non-dict ``that`` (e.g. ``None``) raised AttributeError on ``.keys()``.
    if not isinstance(that, dict):
        return False
    # Key views compare as sets, so ordering does not matter here.
    if this.keys() != that.keys():
        return False
    return all(equals(this[key], that[key]) for key in this)
@equals.register(np.ndarray)
def equals_ndarray(this: np.ndarray, that: np.ndarray) -> bool:
    """
    Structural equality specialization for NumPy arrays.

    Delegates to :func:`numpy.array_equal` with its default arguments: the
    arrays must have the same shape and element-wise equal values (with the
    default ``equal_nan=False``, ``NaN`` entries never compare equal).
    """
    same_shape_and_values = np.array_equal(this, that)
    return same_shape_and_values
@singledispatch
def tensor_to_numpy(tensor) -> np.ndarray:
    """
    Convert a tensor-like object to a :class:`numpy.ndarray`.

    This is the generic fallback of a :func:`functools.singledispatch`
    function; conversions are provided by registering implementations for
    concrete types.

    Parameters
    ----------
    tensor
        The object to convert.

    Raises
    ------
    NotImplementedError
        If no conversion is registered for ``type(tensor)``.
    """
    # Name the offending type so the caller can tell which conversion is
    # missing (a bare NotImplementedError gave no hint).
    raise NotImplementedError(
        "tensor_to_numpy is not implemented for objects of type "
        f"{type(tensor).__name__}"
    )
@tensor_to_numpy.register(np.ndarray)
def _numpy_to_numpy(tensor: np.ndarray) -> np.ndarray:
    """Return a NumPy array as-is; it already has the target type (no copy)."""
    converted = tensor
    return converted
@singledispatch
def skip_encoding(v: Any) -> bool:
    """
    Decide whether ``v`` should be left out (i.e. *not* encoded with
    :func:`~gluonts.core.serde.encode`) when initializer arguments are
    recorded.

    :func:`validated` calls this for every initializer argument and drops
    the ones for which it returns ``True`` from ``__init_args__``.

    This generic fallback keeps every value (returns ``False``); register
    type-specific implementations to skip values of particular types.
    """
    should_skip = False
    return should_skip
class BaseValidatedInitializerModel(BaseModel):
    """
    Base Pydantic model for components with :func:`validated` initializers.

    Its nested ``Config`` is the configuration that :func:`validated` gives
    to synthesized models when no custom ``base_model`` is supplied.

    See Also
    --------
    validated
        Decorates an initializer method with argument validation logic.
    """

    class Config(BaseConfig):
        """
        Pydantic `model config
        <https://pydantic-docs.helpmanual.io/#model-config>`_ shared by the
        models synthesized for :func:`validated` initializers.

        Setting ``arbitrary_types_allowed`` lets initializer parameters be
        annotated with arbitrary types.
        """

        arbitrary_types_allowed = True
def _validated_model_name(init) -> str:
    """
    Return the name of the Pydantic model synthesized for ``init``: the name
    of the class defining ``init`` followed by ``"Model"``.

    The class name is the last component of ``init.__qualname__`` before the
    method name, ignoring ``<locals>`` markers. For example,
    ``"Cls.__init__"`` yields ``Cls``, ``"Outer.Inner.__init__"`` yields
    ``Inner``, and ``"func.<locals>.Cls.__init__"`` yields ``Cls``. Nested
    and function-local classes therefore get their own name instead of the
    name of the outermost scope. A qualname with a single component is used
    as is.
    """
    parts = [
        part for part in init.__qualname__.split(".") if part != "<locals>"
    ]
    owner = parts[-2] if len(parts) > 1 else parts[0]
    return f"{owner}Model"


def _validated_model_fields(init_params) -> dict:
    """
    Return Pydantic field definitions ``{name: (annotation, default)}`` for
    the positional-or-keyword parameters in ``init_params`` other than
    ``self``.

    Unannotated parameters are typed as ``Any``. Parameters without a
    default get ``...``, which marks the Pydantic field as required.
    """
    empty = inspect.Parameter.empty
    fields = {}
    for param in init_params.values():
        if param.name == "self":
            continue
        if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
            continue
        # Identity checks against the ``empty`` sentinel: ``!=`` would call
        # the default's own ``__ne__``, which e.g. for NumPy arrays is
        # element-wise and raises when used as a condition.
        annotation = Any if param.annotation is empty else param.annotation
        default = ... if param.default is empty else param.default
        fields[param.name] = (annotation, default)
    return fields


def validated(base_model=None):
    """
    Decorates an ``__init__`` method with typed parameters with validation and
    auto-conversion logic.

    >>> class ComplexNumber:
    ...     @validated()
    ...     def __init__(self, x: float = 0.0, y: float = 0.0) -> None:
    ...         self.x = x
    ...         self.y = y

    Classes with decorated initializers can be instantiated using arguments of
    another type (e.g. a ``y`` argument of type ``str``). The decorator
    handles the type conversion logic.

    >>> c = ComplexNumber(y='42')
    >>> (c.x, c.y)
    (0.0, 42.0)

    If the bound argument cannot be converted, the decorator throws an error.

    >>> c = ComplexNumber(y=None)
    Traceback (most recent call last):
        ...
    pydantic.v1.error_wrappers.ValidationError: 1 validation error for
    ComplexNumberModel
    y
      none is not an allowed value (type=type_error.none.not_allowed)

    Internally, the decorator delegates all validation and conversion logic to
    `a Pydantic model <https://pydantic-docs.helpmanual.io/>`_, which can be
    accessed through the ``Model`` attribute of the decorated initializer.

    >>> ComplexNumber.__init__.Model
    <class 'pydantic.v1.main.ComplexNumberModel'>

    The Pydantic model is synthesized automatically from the parameter names
    and types of the decorated initializer (positional-or-keyword parameters
    other than ``self`` only). It is named after the class defining the
    initializer. In the ``ComplexNumber`` example, the synthesized Pydantic
    model corresponds to the following definition.

    >>> class ComplexNumberModel(BaseValidatedInitializerModel):
    ...     x: float = 0.0
    ...     y: float = 0.0

    Clients can optionally customize the base class of the synthesized
    Pydantic model using the ``base_model`` decorator parameter. The default
    behavior uses the `model config
    <https://pydantic-docs.helpmanual.io/#config>`_ of
    :class:`BaseValidatedInitializerModel`.

    On each call, the wrapped initializer records the merged, validated
    arguments, sorted by name and excluding values for which
    :func:`skip_encoding` returns ``True``, in ``self.__init_args__``. It
    does so unless a (possibly empty) ``__init_args__`` was already recorded,
    e.g. by a subclass initializer that calls ``super().__init__``. It then
    installs a ``__repr__`` and ``__getnewargs_ex__`` based on those
    arguments on the instance's class.

    See Also
    --------
    BaseValidatedInitializerModel
        Default base class for all synthesized Pydantic models.
    """

    def validator(init):
        init_params = inspect.signature(init).parameters
        init_fields = _validated_model_fields(init_params)
        model_name = _validated_model_name(init)

        if base_model is None:
            PydanticModel = create_model(
                model_name,
                __config__=BaseValidatedInitializerModel.Config,
                **init_fields,
            )
        else:
            PydanticModel = create_model(
                model_name,
                __base__=base_model,
                **init_fields,
            )

        def validated_repr(self) -> str:
            cname = fqname_for(self.__class__)
            kwargs = ", ".join(
                f"{key}={value!r}" for key, value in self.__init_args__.items()
            )
            return f"{cname}({kwargs})"

        def validated_getnewargs_ex(self):
            # Pickle/copy support: rebuild from the recorded keyword args.
            return (), self.__init_args__

        @functools.wraps(init)
        def init_wrapper(*args, **kwargs):
            self, *args = args

            # Name the positional arguments after the initializer's
            # parameters (in declaration order), dropping ``self``.
            nmargs = {
                name: arg
                for name, arg in zip(init_params, [self] + args)
                if name != "self"
            }
            model = PydanticModel(**{**nmargs, **kwargs})

            # Merge nmargs, kwargs and the validated model fields into a
            # single dict; the validated values take precedence.
            all_args = {**nmargs, **kwargs, **model.__dict__}

            # Record the merged arguments for Representable use, but only if
            # ``__init_args__`` has not been recorded yet, so that a value
            # set by a subclass initializer is not overridden in
            # ``super().__init__`` calls. Test against None rather than
            # truthiness: a subclass whose recorded arguments are empty must
            # not be overwritten with the parent's arguments.
            if getattr(self, "__init_args__", None) is None:
                self.__init_args__ = OrderedDict(
                    {
                        name: arg
                        for name, arg in sorted(all_args.items())
                        if not skip_encoding(arg)
                    }
                )
                self.__class__.__getnewargs_ex__ = validated_getnewargs_ex
                self.__class__.__repr__ = validated_repr

            return init(self, **all_args)

        # attach the Pydantic model as the attribute of the initializer wrapper
        setattr(init_wrapper, "Model", PydanticModel)

        return init_wrapper

    return validator