diff --git a/tests/test_load_model_zoo.py b/tests/test_load_model_zoo.py
index f08bf6913..400864700 100644
--- a/tests/test_load_model_zoo.py
+++ b/tests/test_load_model_zoo.py
@@ -11,6 +11,7 @@ from mmcv.runner.checkpoint import (DEFAULT_CACHE_DIR, ENV_MMCV_HOME,
                                     _load_checkpoint,
                                     get_deprecated_model_names,
                                     get_external_models)
+from mmcv.utils import TORCH_VERSION
 
 
 @patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
@@ -77,13 +78,23 @@ def load(filepath, map_location=None):
 def test_load_external_url():
     # test modelzoo://
     url = _load_checkpoint('modelzoo://resnet50')
-    assert url == 'url:https://download.pytorch.org/models/resnet50-19c8e357' \
-        '.pth'
+    if TORCH_VERSION < '1.9.0':
+        assert url == ('url:https://download.pytorch.org/models/resnet50-19c8e'
+                       '357.pth')
+    else:
+        # filename of checkpoint is renamed in torch1.9.0
+        assert url == ('url:https://download.pytorch.org/models/resnet50-0676b'
+                       'a61.pth')
 
     # test torchvision://
     url = _load_checkpoint('torchvision://resnet50')
-    assert url == 'url:https://download.pytorch.org/models/resnet50-19c8e357' \
-        '.pth'
+    if TORCH_VERSION < '1.9.0':
+        assert url == ('url:https://download.pytorch.org/models/resnet50-19c8e'
+                       '357.pth')
+    else:
+        # filename of checkpoint is renamed in torch1.9.0
+        assert url == ('url:https://download.pytorch.org/models/resnet50-0676b'
+                       'a61.pth')
 
     # test open-mmlab:// with default MMCV_HOME
     os.environ.pop(ENV_MMCV_HOME, None)
diff --git a/tests/test_parallel.py b/tests/test_parallel.py
index 93c8f5705..7d73aa81d 100644
--- a/tests/test_parallel.py
+++ b/tests/test_parallel.py
@@ -1,5 +1,6 @@
 from unittest.mock import MagicMock, patch
 
+import torch
 import torch.nn as nn
 from torch.nn.parallel import DataParallel, DistributedDataParallel
 
@@ -15,7 +16,7 @@ def mock(*args, **kwargs):
 
 @patch('torch.distributed._broadcast_coalesced', mock)
 @patch('torch.distributed.broadcast', mock)
-@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', MagicMock)
+@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
 def test_is_module_wrapper():
 
     class Model(nn.Module):
@@ -27,6 +28,12 @@ def test_is_module_wrapper():
         def forward(self, x):
             return self.conv(x)
 
+    # _verify_model_across_ranks is added in torch1.9.0, so we should check
+    # whether _verify_model_across_ranks is a member of torch.distributed
+    # before mocking
+    if hasattr(torch.distributed, '_verify_model_across_ranks'):
+        torch.distributed._verify_model_across_ranks = mock
+
     model = Model()
     assert not is_module_wrapper(model)
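
For reference, a minimal standalone sketch of the guard pattern the second hunk adds to `test_is_module_wrapper` (the `mock` helper mirrors the no-op stub already defined in `tests/test_parallel.py`; nothing here is new API): `torch.distributed._verify_model_across_ranks` only exists from torch 1.9.0 onward, so the attribute is checked with `hasattr` before it is replaced, keeping the test importable and runnable on older torch releases.

```python
import torch


def mock(*args, **kwargs):
    # no-op stand-in for the real collective-communication helper
    pass


# _verify_model_across_ranks was introduced in torch 1.9.0; replacing it
# unconditionally would raise AttributeError on older torch releases, so
# the attribute is only patched when it actually exists.
if hasattr(torch.distributed, '_verify_model_across_ranks'):
    torch.distributed._verify_model_across_ranks = mock
```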