Fix pt1.5 unit tests. (#65)

* Fix pt1.5 unit tests.

* move to mmengine.testing
pull/62/head
RangiLyu 2022-03-01 11:28:21 +08:00 committed by GitHub
parent 1e5056ea7a
commit c2c5664fad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 55 additions and 3 deletions

View File

@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .compare import assert_allclose
__all__ = ['assert_allclose']

View File

@ -0,0 +1,48 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Callable, Optional, Union
from torch.testing import assert_allclose as _assert_allclose
from mmengine.utils import TORCH_VERSION, digit_version
def assert_allclose(
actual: Any,
expected: Any,
rtol: Optional[float] = None,
atol: Optional[float] = None,
equal_nan: bool = True,
msg: Optional[Union[str, Callable]] = '',
) -> None:
"""Asserts that ``actual`` and ``expected`` are close. A wrapper function
of ``torch.testing.assert_allclose``.
Args:
actual (Any): Actual input.
expected (Any): Expected input.
rtol (Optional[float]): Relative tolerance. If specified ``atol`` must
also be specified. If omitted, default values based on the
:attr:`~torch.Tensor.dtype` are selected with the below table.
atol (Optional[float]): Absolute tolerance. If specified :attr:`rtol`
must also be specified. If omitted, default values based on the
:attr:`~torch.Tensor.dtype` are selected with the below table.
equal_nan (bool): If ``True``, two ``NaN`` values will be considered
equal.
msg (Optional[Union[str, Callable]]): Optional error message to use if
the values of corresponding tensors mismatch. Unused when PyTorch
< 1.6.
"""
if 'parrots' not in TORCH_VERSION and \
digit_version(TORCH_VERSION) >= digit_version('1.6'):
_assert_allclose(
actual,
expected,
rtol=rtol,
atol=atol,
equal_nan=equal_nan,
msg=msg)
else:
# torch.testing.assert_allclose has no ``msg`` argument
# when PyTorch < 1.6
_assert_allclose(
actual, expected, rtol=rtol, atol=atol, equal_nan=equal_nan)

View File

@ -5,11 +5,11 @@ from unittest import TestCase
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.testing import assert_allclose
from mmengine.optim.scheduler import (ConstantLR, CosineAnnealingLR,
ExponentialLR, LinearLR, MultiStepLR,
StepLR, _ParamScheduler)
from mmengine.testing import assert_allclose
class ToyModel(torch.nn.Module):

View File

@ -5,13 +5,13 @@ from unittest import TestCase
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.testing import assert_allclose
from mmengine.optim.scheduler import (ConstantMomentum,
CosineAnnealingMomentum,
ExponentialMomentum, LinearMomentum,
MultiStepMomentum, StepMomentum,
_ParamScheduler)
from mmengine.testing import assert_allclose
class ToyModel(torch.nn.Module):

View File

@ -5,7 +5,6 @@ from unittest import TestCase
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.testing import assert_allclose
from mmengine.optim.scheduler import (ConstantParamScheduler,
CosineAnnealingParamScheduler,
@ -13,6 +12,7 @@ from mmengine.optim.scheduler import (ConstantParamScheduler,
LinearParamScheduler,
MultiStepParamScheduler,
StepParamScheduler, _ParamScheduler)
from mmengine.testing import assert_allclose
class ToyModel(torch.nn.Module):