[Fix] Fix unit test in windows (#515)

pull/527/head
Mashiro 2022-09-13 11:46:21 +08:00 committed by GitHub
parent 0fb2b8ca8c
commit 6b1b8a3751
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 90 additions and 54 deletions

View File

@ -118,5 +118,10 @@ jobs:
pip install -r requirements/tests.txt
pip install openmim
mim install 'mmcv>=2.0.0rc1'
- name: Run unittests
- name: Run CPU unittests
run: pytest tests/
if: ${{ matrix.platform == 'cpu' }}
- name: Run GPU unittests
# Skip testing distributed related unit tests since the memory of windows CI is limited
run: pytest tests/ --ignore tests/test_dist --ignore tests/test_optim/test_optimizer/test_optimizer_wrapper.py --ignore tests/test_model/test_wrappers/test_model_wrapper.py
if: ${{ matrix.platform == 'cu111' }}

View File

@ -175,7 +175,7 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
f'RGB or gray image, but got {len(mean)} values')
assert len(std) == 3 or len(std) == 1, ( # type: ignore
'`std` should have 1 or 3 values, to be compatible with RGB ' # type: ignore # noqa: E501
f'or gray image, but got {len(std)} values')
f'or gray image, but got {len(std)} values') # type: ignore
self._enable_normalize = True
self.register_buffer('mean',
torch.tensor(mean).view(-1, 1, 1), False)

View File

@ -165,6 +165,11 @@ class InstanceData(BaseDataElement):
if isinstance(item, list):
item = np.array(item)
if isinstance(item, np.ndarray):
# The default int type of numpy is platform dependent, int32 for
# windows and int64 for linux. `torch.Tensor` requires the index
# should be int64, therefore we simply convert it to int64 here.
# More details in https://github.com/numpy/numpy/issues/9464
item = item.astype(np.int64) if item.dtype == np.int32 else item
item = torch.from_numpy(item)
assert isinstance(
item, (str, slice, int, torch.LongTensor, torch.cuda.LongTensor,

View File

@ -4,6 +4,7 @@ import os.path as osp
import subprocess
import sys
from collections import OrderedDict, defaultdict
from distutils import errors
import cv2
import numpy as np
@ -103,7 +104,7 @@ def collect_env():
sys.stdout.fileno()) or locale.getpreferredencoding()
env_info['MSVC'] = cc.decode(encoding).partition('\n')[0].strip()
env_info['GCC'] = 'n/a'
except subprocess.CalledProcessError:
except (subprocess.CalledProcessError, errors.DistutilsPlatformError):
env_info['GCC'] = 'n/a'
env_info['PyTorch'] = torch.__version__

View File

@ -9,7 +9,7 @@ from mmengine.config.utils import (_get_external_cfg_base_path,
def test_get_external_cfg_base_path(tmp_path):
package_path = tmp_path
rel_cfg_path = 'cfg_dir/cfg_file'
rel_cfg_path = os.path.join('cfg_dir', 'cfg_file')
with pytest.raises(FileNotFoundError):
_get_external_cfg_base_path(str(package_path), rel_cfg_path)
cfg_dir = tmp_path / '.mim' / 'configs' / 'cfg_dir'

View File

@ -131,25 +131,21 @@ class TestDataUtils(TestCase):
self.assertEqual(tuple(batch_inputs_0.shape), (2, 1, 3, 5))
self.assertEqual(tuple(batch_inputs_1.shape), (2, 1, 3, 5))
self.assertTrue(
torch.allclose(batch_inputs_0, torch.stack([input1, input1])))
self.assertTrue(
torch.allclose(batch_inputs_1, torch.stack([input2, input2])))
self.assertTrue(
torch.allclose(batch_value_0,
torch.stack([torch.tensor(1),
torch.tensor(1)])))
self.assertTrue(
torch.allclose(batch_value_1,
torch.stack([torch.tensor(2),
torch.tensor(2)])))
target1 = torch.stack([torch.tensor(1), torch.tensor(1)])
target2 = torch.stack([torch.tensor(2), torch.tensor(2)])
self.assertTrue(
torch.allclose(batch_array_0,
torch.stack([torch.tensor(1),
torch.tensor(1)])))
torch.allclose(batch_value_0.to(target1.dtype), target1))
self.assertTrue(
torch.allclose(batch_array_1,
torch.stack([torch.tensor(2),
torch.tensor(2)])))
torch.allclose(batch_value_1.to(target2.dtype), target2))
self.assertTrue(
torch.allclose(batch_array_0.to(target1.dtype), target1))
self.assertTrue(
torch.allclose(batch_array_1.to(target2.dtype), target2))

View File

@ -46,8 +46,8 @@ class TestCheckpointHook:
checkpoint_hook = CheckpointHook(
interval=1, by_epoch=True, out_dir='test_dir')
checkpoint_hook.before_train(runner)
assert checkpoint_hook.out_dir == (
f'test_dir/{osp.basename(work_dir)}')
assert checkpoint_hook.out_dir == osp.join(
'test_dir', osp.join(osp.basename(work_dir)))
runner.message_hub = MessageHub.get_instance('test_before_train')
# no 'best_ckpt_path' in runtime_info
@ -297,20 +297,20 @@ class TestCheckpointHook:
checkpoint_hook.after_train_epoch(runner)
assert (runner.epoch + 1) % 2 == 0
assert 'last_ckpt' in runner.message_hub.runtime_info and \
runner.message_hub.get_info('last_ckpt') == (
f'{work_dir}/epoch_10.pth')
runner.message_hub.get_info('last_ckpt') == \
osp.join(work_dir, 'epoch_10.pth')
last_ckpt_path = osp.join(work_dir, 'last_checkpoint')
assert osp.isfile(last_ckpt_path)
with open(last_ckpt_path) as f:
filepath = f.read()
assert filepath == f'{work_dir}/epoch_10.pth'
assert filepath == osp.join(work_dir, 'epoch_10.pth')
# epoch can not be evenly divided by 2
runner.epoch = 10
checkpoint_hook.after_train_epoch(runner)
assert 'last_ckpt' in runner.message_hub.runtime_info and \
runner.message_hub.get_info('last_ckpt') == (
f'{work_dir}/epoch_10.pth')
runner.message_hub.get_info('last_ckpt') == \
osp.join(work_dir, 'epoch_10.pth')
# by epoch is False
runner.epoch = 9
@ -351,20 +351,20 @@ class TestCheckpointHook:
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
assert (runner.iter + 1) % 2 == 0
assert 'last_ckpt' in runner.message_hub.runtime_info and \
runner.message_hub.get_info('last_ckpt') == (
f'{work_dir}/iter_10.pth')
runner.message_hub.get_info('last_ckpt') == \
osp.join(work_dir, 'iter_10.pth')
# epoch can not be evenly divided by 2
runner.iter = 10
checkpoint_hook.after_train_epoch(runner)
assert 'last_ckpt' in runner.message_hub.runtime_info and \
runner.message_hub.get_info('last_ckpt') == (
f'{work_dir}/iter_10.pth')
runner.message_hub.get_info('last_ckpt') == \
osp.join(work_dir, 'iter_10.pth')
# max_keep_ckpts > 0
runner.iter = 9
runner.work_dir = work_dir
os.system(f'touch {work_dir}/iter_8.pth')
os.system(f'touch {osp.join(work_dir, "iter_8.pth")}')
checkpoint_hook = CheckpointHook(
interval=2, by_epoch=False, max_keep_ckpts=1)
checkpoint_hook.before_train(runner)

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os.path as osp
import tempfile
from unittest import TestCase
@ -10,6 +11,7 @@ from torch.utils.data import Dataset
from mmengine.evaluator import Evaluator
from mmengine.hooks import EMAHook
from mmengine.logging import MMLogger
from mmengine.model import BaseModel, ExponentialMovingAverage
from mmengine.optim import OptimWrapper
from mmengine.registry import DATASETS, MODEL_WRAPPERS
@ -89,6 +91,10 @@ class TestEMAHook(TestCase):
self.temp_dir = tempfile.TemporaryDirectory()
def tearDown(self):
# `FileHandler` should be closed in Windows, otherwise we cannot
# delete the temporary directory
logging.shutdown()
MMLogger._instance_dict.clear()
self.temp_dir.cleanup()
def test_ema_hook(self):

View File

@ -43,7 +43,10 @@ class TestLogger:
'rank0.pkg3', logger_name='logger_test', log_level='INFO')
assert logger.name == 'logger_test'
assert logger.instance_name == 'rank0.pkg3'
# `FileHandler` should be closed in Windows, otherwise we cannot
# delete the temporary directory
logging.shutdown()
MMLogger._instance_dict.clear()
@patch('mmengine.logging.logger._get_rank', lambda: 1)
def test_init_rank1(self, tmp_path):
@ -62,7 +65,10 @@ class TestLogger:
assert logger.handlers[1].level == logging.INFO
assert len(logger.handlers) == 2
assert os.path.exists(log_path)
# `FileHandler` should be closed in Windows, otherwise we cannot
# delete the temporary directory
logging.shutdown()
MMLogger._instance_dict.clear()
@pytest.mark.parametrize('log_level',
[logging.WARNING, logging.INFO, logging.DEBUG])
@ -92,7 +98,10 @@ class TestLogger:
f' - mmengine - {loglevl_name} - '
f'welcome\n', log_text)
assert match is not None
# `FileHandler` should be closed in Windows, otherwise we cannot
# delete the temporary directory
logging.shutdown()
MMLogger._instance_dict.clear()
def test_error_format(self, capsys):
# test error level log can output file path, function name and
@ -100,7 +109,10 @@ class TestLogger:
logger = MMLogger.get_instance('test_error', log_level='INFO')
logger.error('welcome')
lineno = sys._getframe().f_lineno - 1
file_path = __file__
# replace \ for windows:
# origin: c:\\a\\b\\c.py
# replaced: c:\\\\a\\\\b\\\\c.py for re.match.
file_path = __file__.replace('\\', '\\\\')
function_name = sys._getframe().f_code.co_name
pattern = self.stream_handler_regex_time + \
r' - mmengine - (.*)ERROR(.*) - ' \

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from unittest import TestCase
import torch
@ -106,6 +107,12 @@ class TestBaseModule(TestCase):
conv1d=dict(type='FooConv1d')))
self.model = build_from_cfg(self.model_cfg, FOOMODELS)
self.logger = MMLogger.get_instance(self._testMethodName)
def tearDown(self) -> None:
logging.shutdown()
MMLogger._instance_dict.clear()
return super().tearDown()
def test_is_init(self):
assert self.BaseModule.is_init is False
@ -194,6 +201,10 @@ class TestBaseModule(TestCase):
model2.init_weights()
assert len(os.listdir(dump_dir)) == 1
assert os.stat(log_path).st_size != 0
# `FileHandler` should be closed in Windows, otherwise we cannot
# delete the temporary directory
logging.shutdown()
MMLogger._instance_dict.clear()
shutil.rmtree(dump_dir)

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import logging
import os
import os.path as osp
import shutil
@ -410,6 +411,10 @@ class TestRunner(TestCase):
sampler_seed=dict(type='DistSamplerSeedHook'))
def tearDown(self):
# `FileHandler` should be closed in Windows, otherwise we cannot
# delete the temporary directory
logging.shutdown()
MMLogger._instance_dict.clear()
shutil.rmtree(self.temp_dir)
def test_init(self):
@ -579,8 +584,9 @@ class TestRunner(TestCase):
runner.train()
runner.test()
# 5. Test building multiple runners
if torch.cuda.is_available():
# 5. Test building multiple runners. In Windows, nccl could not be
# available, and this test will be skipped.
if torch.cuda.is_available() and torch.distributed.is_nccl_available():
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init15'
cfg.launcher = 'pytorch'
@ -589,9 +595,9 @@ class TestRunner(TestCase):
os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'
os.environ['LOCAL_RANK'] = '0'
runner = Runner(**cfg)
Runner(**cfg)
cfg.experiment_name = 'test_init16'
runner = Runner(**cfg)
Runner(**cfg)
# 6.1 Test initializing with empty scheduler.
cfg = copy.deepcopy(self.epoch_based_cfg)
@ -680,8 +686,11 @@ class TestRunner(TestCase):
osp.join(runner.work_dir, f'{runner.timestamp}.py'))
# dump config from file.
with tempfile.TemporaryDirectory() as temp_config_dir:
# Set `delete=Flase` and close the file to make it
# work in Windows.
temp_config_file = tempfile.NamedTemporaryFile(
dir=temp_config_dir, suffix='.py')
dir=temp_config_dir, suffix='.py', delete=False)
temp_config_file.close()
file_cfg = Config(
self.epoch_based_cfg._cfg_dict,
filename=temp_config_file.name)
@ -834,7 +843,7 @@ class TestRunner(TestCase):
cfg.model_wrapper_cfg = dict(type='CustomModelWrapper')
runner = Runner.from_cfg(cfg)
self.assertIsInstance(runner.model, BaseModel)
if torch.cuda.is_available():
if torch.cuda.is_available() and torch.distributed.is_nccl_available():
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29515'
os.environ['RANK'] = str(0)

View File

@ -2,22 +2,12 @@
import sys
from unittest import TestCase
import torch.cuda
import mmengine
from mmengine.utils.dl_utils import collect_env
from mmengine.utils.dl_utils.parrots_wrapper import _get_cuda_home
class TestCollectEnv(TestCase):
def test_get_cuda_home(self):
CUDA_HOME = _get_cuda_home()
if torch.version.cuda is not None:
self.assertIsNotNone(CUDA_HOME)
else:
self.assertIsNone(CUDA_HOME)
def test_collect_env(self):
env_info = collect_env()
expected_keys = [
@ -31,9 +21,6 @@ class TestCollectEnv(TestCase):
for key in ['CUDA_HOME', 'NVCC']:
assert key in env_info
if sys.platform == 'win32':
assert 'MSVC' in env_info
assert env_info['sys.platform'] == sys.platform
assert env_info['Python'] == sys.version.replace('\n', '')
assert env_info['MMEngine'] == mmengine.__version__

View File

@ -18,10 +18,12 @@ def test_timer_init():
def test_timer_run():
timer = mmengine.Timer()
time.sleep(1)
assert abs(timer.since_start() - 1) < 1e-2
# In Windows, the error could be larger than 20ms. More details in
# https://stackoverflow.com/questions/11657734/sleep-for-exact-time-in-python. # noqa: E501
assert abs(timer.since_start() - 1) < 3e-2
time.sleep(1)
assert abs(timer.since_last_check() - 1) < 1e-2
assert abs(timer.since_start() - 2) < 1e-2
assert abs(timer.since_last_check() - 1) < 3e-2
assert abs(timer.since_start() - 2) < 3e-2
timer = mmengine.Timer(False)
with pytest.raises(mmengine.TimerError):
timer.since_start()
@ -33,7 +35,9 @@ def test_timer_context(capsys):
with mmengine.Timer():
time.sleep(1)
out, _ = capsys.readouterr()
assert abs(float(out) - 1) < 1e-2
# In Windows, the error could be larger than 20ms. More details in
# https://stackoverflow.com/questions/11657734/sleep-for-exact-time-in-python. # noqa: E501
assert abs(float(out) - 1) < 3e-2
with mmengine.Timer(print_tmpl='time: {:.1f}s'):
time.sleep(1)
out, _ = capsys.readouterr()