# Source code for gluonts.testutil.equality

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import pandas as pd
import numpy as np


def assert_recursively_equal(obj_a, obj_b, equal_nan=True):
    """
    Assert that two (possibly nested) objects are equal.

    This behaves like :func:`assert_recursively_close` with both the relative
    and the absolute tolerance fixed to zero, so floats and numpy arrays are
    compared with zero tolerance.

    Parameters
    ----------
    obj_a
    obj_b
        Objects to compare.
    equal_nan
        Whether ``numpy.nan`` values should be considered equal to each other.
    """
    zero_tolerance = dict(rtol=0, atol=0)
    _assert_recursively_close(
        obj_a, obj_b, location="", equal_nan=equal_nan, **zero_tolerance
    )
def assert_recursively_close(
    obj_a, obj_b, rtol=1e-05, atol=1e-08, equal_nan=True
):
    """
    Assert that two (possibly nested) objects are close to each other.

    Both objects must have exactly the same type at every nesting level.
    Strings and ints must be equal. Floats are compared with
    :func:`numpy.isclose`. Numeric numpy arrays must share a dtype and are
    compared with :func:`numpy.allclose`. Lists must have the same length and
    pairwise-close items. Dicts must have identical keys and close values for
    every key.

    Parameters
    ----------
    obj_a
    obj_b
        Objects to compare.
    rtol
    atol
        Relative and absolute tolerance for float comparison; see the
        documentation of ``numpy.isclose``.
    equal_nan
        Whether ``numpy.nan`` values should be considered equal to each other.

    Raises
    ------
    AssertionError
        If a mismatch is found.
    TypeError
        If an unsupported type is encountered.
    """
    tolerances = {"rtol": rtol, "atol": atol, "equal_nan": equal_nan}
    _assert_recursively_close(obj_a, obj_b, location="", **tolerances)
def _assert_recursively_close(obj_a, obj_b, location, *args, **kwargs):
    """
    Recursive worker behind the public ``assert_recursively_*`` helpers.

    Parameters
    ----------
    obj_a
    obj_b
        Objects to compare; they must have exactly the same type.
    location
        Dotted path of the current position inside the top-level objects
        (e.g. ``".data.0"``); used only in assertion messages.
    *args, **kwargs
        Tolerance options (``rtol``, ``atol``, ``equal_nan``) forwarded
        unchanged to ``numpy.isclose`` / ``numpy.allclose``.

    Raises
    ------
    AssertionError
        If the objects differ in type, structure, shape or value.
    TypeError
        If an unsupported type is encountered.
    """
    assert type(obj_a) is type(
        obj_b
    ), f"types don't match (at {location}) {type(obj_a)} != {type(obj_b)}"

    if isinstance(obj_a, (str, int)):
        assert (
            obj_a == obj_b
        ), f"values don't match (at {location}) {obj_a!r} != {obj_b!r}"
    elif isinstance(obj_a, float):
        assert np.isclose(
            obj_a, obj_b, *args, **kwargs
        ), f"floats are not close enough (at {location}) {obj_a!r} != {obj_b!r}"
    elif isinstance(obj_a, (list, tuple)):
        assert len(obj_a) == len(obj_b), f"lengths don't match (at {location})"
        for i, (element_a, element_b) in enumerate(zip(obj_a, obj_b)):
            # `location` is passed positionally: writing `location=...`
            # before `*args` would bind the first forwarded positional
            # argument to `location` and fail with "multiple values".
            _assert_recursively_close(
                element_a, element_b, f"{location}.{i}", *args, **kwargs
            )
    elif isinstance(obj_a, dict):
        assert (
            obj_a.keys() == obj_b.keys()
        ), f"keys don't match (at {location})"
        for k in obj_a:
            _assert_recursively_close(
                obj_a[k], obj_b[k], f"{location}.{k}", *args, **kwargs
            )
    elif isinstance(obj_a, np.ndarray):
        assert (
            obj_a.dtype == obj_b.dtype
        ), f"numpy arrays have different dtype (at {location})"
        # np.allclose broadcasts its operands, so e.g. shapes (3,) and (1,)
        # would otherwise be reported as close; require identical shapes.
        assert obj_a.shape == obj_b.shape, (
            f"numpy arrays have different shape (at {location}) "
            f"{obj_a.shape} != {obj_b.shape}"
        )
        if obj_a.dtype.kind in "biufc":
            # bool / signed / unsigned / float / complex: tolerance-based.
            assert np.allclose(
                obj_a, obj_b, *args, **kwargs
            ), f"numpy arrays are not close enough (at {location})"
        else:
            # Non-numeric arrays (strings, objects, datetimes, ...) are not
            # supported by np.allclose; require exact element-wise equality.
            assert np.array_equal(
                obj_a, obj_b
            ), f"numpy arrays are not equal (at {location})"
    elif isinstance(obj_a, pd.Period):
        assert (
            obj_a == obj_b
        ), f"periods don't match (at {location}) {obj_a} != {obj_b}"
    elif obj_a is None:
        assert obj_b is None
    else:
        raise TypeError(f"unsupported type {type(obj_a)}")