[Fix] Fix unittest in pt1.9 (#1146)

* fix test.txt

* fix unittest in pt1.9

* fix checkpoint filename error

* add comment

* fix unittest

* fix onnxruntime version
Zaida Zhou 2021-07-03 20:47:22 +08:00 committed by GitHub
parent 6c63621a86
commit 4a9f83467c
2 changed files with 23 additions and 5 deletions


@@ -11,6 +11,7 @@ from mmcv.runner.checkpoint import (DEFAULT_CACHE_DIR, ENV_MMCV_HOME,
                                     _load_checkpoint,
                                     get_deprecated_model_names,
                                     get_external_models)
+from mmcv.utils import TORCH_VERSION
 
 
 @patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
@@ -77,13 +78,23 @@ def load(filepath, map_location=None):
 def test_load_external_url():
     # test modelzoo://
     url = _load_checkpoint('modelzoo://resnet50')
-    assert url == 'url:https://download.pytorch.org/models/resnet50-19c8e357' \
-                  '.pth'
+    if TORCH_VERSION < '1.9.0':
+        assert url == ('url:https://download.pytorch.org/models/resnet50-19c8e'
+                       '357.pth')
+    else:
+        # filename of checkpoint is renamed in torch1.9.0
+        assert url == ('url:https://download.pytorch.org/models/resnet50-0676b'
+                       'a61.pth')
 
     # test torchvision://
     url = _load_checkpoint('torchvision://resnet50')
-    assert url == 'url:https://download.pytorch.org/models/resnet50-19c8e357' \
-                  '.pth'
+    if TORCH_VERSION < '1.9.0':
+        assert url == ('url:https://download.pytorch.org/models/resnet50-19c8e'
+                       '357.pth')
+    else:
+        # filename of checkpoint is renamed in torch1.9.0
+        assert url == ('url:https://download.pytorch.org/models/resnet50-0676b'
+                       'a61.pth')
 
     # test open-mmlab:// with default MMCV_HOME
     os.environ.pop(ENV_MMCV_HOME, None)
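
For context, a minimal standalone sketch of the version gate exercised above. It is only an illustration, not part of the test file: expected_resnet50_url is a hypothetical helper, it compares versions with packaging.version rather than mmcv's string-valued TORCH_VERSION, and the sha suffixes (19c8e357 before torch 1.9.0, 0676ba61 from 1.9.0 on) are taken from the asserts in the diff.

    # Hypothetical helper, not part of mmcv: picks the resnet50 checkpoint
    # filename that the torch/torchvision model zoo is expected to resolve to
    # for the installed torch version.
    import torch
    from packaging.version import parse


    def expected_resnet50_url():
        base = 'https://download.pytorch.org/models/'
        # The resnet50 checkpoint file was renamed in the torch 1.9.0 release
        # (see the comment in the diff), so the expected name is version-gated.
        if parse(torch.__version__).release < (1, 9, 0):
            return base + 'resnet50-19c8e357.pth'
        return base + 'resnet50-0676ba61.pth'


    print(expected_resnet50_url())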


@@ -1,5 +1,6 @@
 from unittest.mock import MagicMock, patch
 
+import torch
 import torch.nn as nn
 from torch.nn.parallel import DataParallel, DistributedDataParallel
 
@@ -15,7 +16,7 @@ def mock(*args, **kwargs):
 
 @patch('torch.distributed._broadcast_coalesced', mock)
 @patch('torch.distributed.broadcast', mock)
-@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', MagicMock)
+@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
 def test_is_module_wrapper():
 
     class Model(nn.Module):
@@ -27,6 +28,12 @@ def test_is_module_wrapper():
         def forward(self, x):
             return self.conv(x)
 
+    # _verify_model_across_ranks is added in torch1.9.0, so we should check
+    # whether _verify_model_across_ranks is a member of torch.distributed
+    # before mocking it
+    if hasattr(torch.distributed, '_verify_model_across_ranks'):
+        torch.distributed._verify_model_across_ranks = mock
+
     model = Model()
     assert not is_module_wrapper(model)
 
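
The hasattr guard above is what keeps the mock compatible with both sides of the 1.9.0 boundary: torch.distributed._verify_model_across_ranks only exists from torch 1.9.0 on, so it is replaced only when present. Below is a minimal standalone sketch of the same guard pattern, assuming only that torch is installed (_noop is an illustrative stand-in for the test's mock function):

    # Sketch of the hasattr-guarded patch from the diff. On torch < 1.9.0 the
    # attribute does not exist, so nothing is patched and nothing breaks.
    import torch.distributed as dist


    def _noop(*args, **kwargs):
        # Stand-in that swallows the call DDP makes during initialization.
        pass


    if hasattr(dist, '_verify_model_across_ranks'):
        dist._verify_model_across_ranks = _noop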