parent
1e5056ea7a
commit
c2c5664fad
|
@ -0,0 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .compare import assert_allclose
|
||||
|
||||
__all__ = ['assert_allclose']
|
|
@ -0,0 +1,48 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
from torch.testing import assert_allclose as _assert_allclose
|
||||
|
||||
from mmengine.utils import TORCH_VERSION, digit_version
|
||||
|
||||
|
||||
def assert_allclose(
|
||||
actual: Any,
|
||||
expected: Any,
|
||||
rtol: Optional[float] = None,
|
||||
atol: Optional[float] = None,
|
||||
equal_nan: bool = True,
|
||||
msg: Optional[Union[str, Callable]] = '',
|
||||
) -> None:
|
||||
"""Asserts that ``actual`` and ``expected`` are close. A wrapper function
|
||||
of ``torch.testing.assert_allclose``.
|
||||
|
||||
Args:
|
||||
actual (Any): Actual input.
|
||||
expected (Any): Expected input.
|
||||
rtol (Optional[float]): Relative tolerance. If specified ``atol`` must
|
||||
also be specified. If omitted, default values based on the
|
||||
:attr:`~torch.Tensor.dtype` are selected with the below table.
|
||||
atol (Optional[float]): Absolute tolerance. If specified :attr:`rtol`
|
||||
must also be specified. If omitted, default values based on the
|
||||
:attr:`~torch.Tensor.dtype` are selected with the below table.
|
||||
equal_nan (bool): If ``True``, two ``NaN`` values will be considered
|
||||
equal.
|
||||
msg (Optional[Union[str, Callable]]): Optional error message to use if
|
||||
the values of corresponding tensors mismatch. Unused when PyTorch
|
||||
< 1.6.
|
||||
"""
|
||||
if 'parrots' not in TORCH_VERSION and \
|
||||
digit_version(TORCH_VERSION) >= digit_version('1.6'):
|
||||
_assert_allclose(
|
||||
actual,
|
||||
expected,
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
equal_nan=equal_nan,
|
||||
msg=msg)
|
||||
else:
|
||||
# torch.testing.assert_allclose has no ``msg`` argument
|
||||
# when PyTorch < 1.6
|
||||
_assert_allclose(
|
||||
actual, expected, rtol=rtol, atol=atol, equal_nan=equal_nan)
|
|
@ -5,11 +5,11 @@ from unittest import TestCase
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torch.testing import assert_allclose
|
||||
|
||||
from mmengine.optim.scheduler import (ConstantLR, CosineAnnealingLR,
|
||||
ExponentialLR, LinearLR, MultiStepLR,
|
||||
StepLR, _ParamScheduler)
|
||||
from mmengine.testing import assert_allclose
|
||||
|
||||
|
||||
class ToyModel(torch.nn.Module):
|
||||
|
|
|
@ -5,13 +5,13 @@ from unittest import TestCase
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torch.testing import assert_allclose
|
||||
|
||||
from mmengine.optim.scheduler import (ConstantMomentum,
|
||||
CosineAnnealingMomentum,
|
||||
ExponentialMomentum, LinearMomentum,
|
||||
MultiStepMomentum, StepMomentum,
|
||||
_ParamScheduler)
|
||||
from mmengine.testing import assert_allclose
|
||||
|
||||
|
||||
class ToyModel(torch.nn.Module):
|
||||
|
|
|
@ -5,7 +5,6 @@ from unittest import TestCase
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torch.testing import assert_allclose
|
||||
|
||||
from mmengine.optim.scheduler import (ConstantParamScheduler,
|
||||
CosineAnnealingParamScheduler,
|
||||
|
@ -13,6 +12,7 @@ from mmengine.optim.scheduler import (ConstantParamScheduler,
|
|||
LinearParamScheduler,
|
||||
MultiStepParamScheduler,
|
||||
StepParamScheduler, _ParamScheduler)
|
||||
from mmengine.testing import assert_allclose
|
||||
|
||||
|
||||
class ToyModel(torch.nn.Module):
|
||||
|
|
Loading…
Reference in New Issue