Source code for gluonts.mx.context
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import functools
import logging
import re
from typing import Callable, Iterator, Union

import mxnet as mx
logger = logging.getLogger(__name__)
class MXContext:
    """
    Defines `custom data type validation
    <https://pydantic-docs.helpmanual.io/#custom-data-types>`_ for
    the :class:`~mxnet.context.Context` data type.
    """

    @classmethod
    def validate(cls, v: Union[str, mx.Context]) -> mx.Context:
        """
        Coerce ``v`` into an :class:`~mxnet.context.Context`.

        Parameters
        ----------
        v
            Either an existing ``mx.Context`` (returned unchanged) or a
            string of the form ``cpu``, ``gpu``, ``cpu(N)`` or ``gpu(N)``,
            where ``N`` is a sequence of digits. When the ``(N)`` suffix
            is omitted the device id defaults to ``0``.

        Returns
        -------
        mx.Context
            The context described by ``v``.

        Raises
        ------
        TypeError
            If ``v`` is neither an ``mx.Context`` nor a string.
        ValueError
            If ``v`` is a string that does not match the expected format.
        """
        if isinstance(v, mx.Context):
            return v

        # Without this guard ``re.search`` raises a TypeError of its own
        # for non-string input, but with a message that says nothing about
        # contexts. Keep the exception type, improve the message.
        if not isinstance(v, str):
            raise TypeError(
                f"bad MXNet context {v!r} of type {type(v).__name__}, "
                "expected either an mx.context.Context or its string "
                "representation"
            )

        m = re.search(r"^(?P<dev_type>cpu|gpu)(\((?P<dev_id>\d+)\))?$", v)
        if m is None:
            raise ValueError(
                f"bad MXNet context {v}, expected either an "
                "mx.context.Context or its string representation"
            )

        # ``dev_id`` is ``None`` when the "(N)" suffix is absent.
        return mx.Context(m["dev_type"], int(m["dev_id"] or 0))

    @classmethod
    def __get_validators__(cls) -> Iterator[Callable[..., mx.Context]]:
        """
        Yield the validator callables for this type (only
        :meth:`validate`).
        """
        yield cls.validate
# Attach the validation hooks defined on MXContext directly to mx.Context,
# so that mx.Context itself exposes ``validate`` and ``__get_validators__``.
# Both are accessed through MXContext, so the attributes stored here are
# classmethods already bound to MXContext (``cls`` is MXContext when called).
mx.Context.validate = MXContext.validate
mx.Context.__get_validators__ = MXContext.__get_validators__
# Cached result of the last GPU count query; ``None`` until the first call
# to ``num_gpus``.
NUM_GPUS = None


def num_gpus(refresh=False):
    """
    Return the number of GPUs reported by ``mx.context.num_gpus()``.

    The value is cached in the module-level ``NUM_GPUS`` and reused on
    later calls unless ``refresh`` is truthy. If the query raises
    ``mx.base.MXNetError``, the error is logged and ``0`` is cached and
    returned.
    """
    global NUM_GPUS

    # Fast path: a previous call already populated the cache.
    if not refresh and NUM_GPUS is not None:
        return NUM_GPUS

    try:
        detected = mx.context.num_gpus()
    except mx.base.MXNetError as e:
        logger.error(f"Failure when querying GPU: {e}")
        detected = 0

    NUM_GPUS = detected
    return NUM_GPUS
@functools.lru_cache()
def get_mxnet_context(gpu_number=0) -> mx.Context:
    """
    Return the GPU context with id ``gpu_number`` when ``num_gpus()``
    reports at least one GPU, and the CPU context otherwise.

    The result is memoized per ``gpu_number`` by ``functools.lru_cache``,
    so the device check and the log line only happen on the first call
    for a given argument.
    """
    if not num_gpus():
        logger.info("Using CPU")
        return mx.context.cpu()

    logger.info("Using GPU")
    return mx.context.gpu(gpu_number)
def check_gpu_support() -> bool:
    """
    Emit a log line and return a boolean that indicates whether GPU
    support is available.

    GPU support is considered "ON" when ``num_gpus()`` reports a count
    greater than zero, and "OFF" otherwise.

    Returns
    -------
    bool
        ``True`` if at least one GPU is reported, ``False`` otherwise.
    """
    # Derive the log text and the return value from one predicate so the
    # two can never disagree.
    has_gpu = num_gpus() > 0
    state = "ON" if has_gpu else "OFF"
    logger.info(f"MXNet GPU support is {state}")
    return has_gpu