mirror of https://github.com/open-mmlab/mmcv.git
[Fix] Fix unittest in pt1.9 (#1146)
* fix test.txt * fix unittest in pt1.9 * fix checkpoint filename error * add comment * fix unittest * fix onnxruntime versionpull/1172/head
parent
6c63621a86
commit
4a9f83467c
|
@ -11,6 +11,7 @@ from mmcv.runner.checkpoint import (DEFAULT_CACHE_DIR, ENV_MMCV_HOME,
|
||||||
_load_checkpoint,
|
_load_checkpoint,
|
||||||
get_deprecated_model_names,
|
get_deprecated_model_names,
|
||||||
get_external_models)
|
get_external_models)
|
||||||
|
from mmcv.utils import TORCH_VERSION
|
||||||
|
|
||||||
|
|
||||||
@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
|
@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
|
||||||
|
@ -77,13 +78,23 @@ def load(filepath, map_location=None):
|
||||||
def test_load_external_url():
|
def test_load_external_url():
|
||||||
# test modelzoo://
|
# test modelzoo://
|
||||||
url = _load_checkpoint('modelzoo://resnet50')
|
url = _load_checkpoint('modelzoo://resnet50')
|
||||||
assert url == 'url:https://download.pytorch.org/models/resnet50-19c8e357' \
|
if TORCH_VERSION < '1.9.0':
|
||||||
'.pth'
|
assert url == ('url:https://download.pytorch.org/models/resnet50-19c8e'
|
||||||
|
'357.pth')
|
||||||
|
else:
|
||||||
|
# filename of checkpoint is renamed in torch1.9.0
|
||||||
|
assert url == ('url:https://download.pytorch.org/models/resnet50-0676b'
|
||||||
|
'a61.pth')
|
||||||
|
|
||||||
# test torchvision://
|
# test torchvision://
|
||||||
url = _load_checkpoint('torchvision://resnet50')
|
url = _load_checkpoint('torchvision://resnet50')
|
||||||
assert url == 'url:https://download.pytorch.org/models/resnet50-19c8e357' \
|
if TORCH_VERSION < '1.9.0':
|
||||||
'.pth'
|
assert url == ('url:https://download.pytorch.org/models/resnet50-19c8e'
|
||||||
|
'357.pth')
|
||||||
|
else:
|
||||||
|
# filename of checkpoint is renamed in torch1.9.0
|
||||||
|
assert url == ('url:https://download.pytorch.org/models/resnet50-0676b'
|
||||||
|
'a61.pth')
|
||||||
|
|
||||||
# test open-mmlab:// with default MMCV_HOME
|
# test open-mmlab:// with default MMCV_HOME
|
||||||
os.environ.pop(ENV_MMCV_HOME, None)
|
os.environ.pop(ENV_MMCV_HOME, None)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
||||||
|
|
||||||
|
@ -15,7 +16,7 @@ def mock(*args, **kwargs):
|
||||||
|
|
||||||
@patch('torch.distributed._broadcast_coalesced', mock)
|
@patch('torch.distributed._broadcast_coalesced', mock)
|
||||||
@patch('torch.distributed.broadcast', mock)
|
@patch('torch.distributed.broadcast', mock)
|
||||||
@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', MagicMock)
|
@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
|
||||||
def test_is_module_wrapper():
|
def test_is_module_wrapper():
|
||||||
|
|
||||||
class Model(nn.Module):
|
class Model(nn.Module):
|
||||||
|
@ -27,6 +28,12 @@ def test_is_module_wrapper():
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.conv(x)
|
return self.conv(x)
|
||||||
|
|
||||||
|
# _verify_model_across_ranks is added in torch1.9.0 so we should check
|
||||||
|
# wether _verify_model_across_ranks is the member of torch.distributed
|
||||||
|
# before mocking
|
||||||
|
if hasattr(torch.distributed, '_verify_model_across_ranks'):
|
||||||
|
torch.distributed._verify_model_across_ranks = mock
|
||||||
|
|
||||||
model = Model()
|
model = Model()
|
||||||
assert not is_module_wrapper(model)
|
assert not is_module_wrapper(model)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue