EasyCV/tests/test_tools/test_classification_train.py
Cathy0908 269aa6d01c
update citest yaml and fix unittest (#322)
* update citest yaml and fix unittest
2023-06-14 16:49:27 +08:00

129 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import logging
import os
import sys
import tempfile
import unittest
from tests.ut_config import CLS_TRAIN_TEST
from easycv.file import io
from easycv.utils.config_tools import mmcv_config_fromfile, pai_config_fromfile
from easycv.utils.test_util import run_in_subprocess
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
logging.basicConfig(level=logging.INFO)
_COMMON_OPTIONS = {
'checkpoint_config.interval': 1,
'total_epochs': 1,
'data.imgs_per_gpu': 8,
'model.backbone.norm_cfg.type': 'BN',
'model.head.num_classes': 2,
}
TRAIN_CONFIGS = [{
'config_file':
'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py',
'cfg_options': {
**_COMMON_OPTIONS,
'data.train.data_source.root': '',
'data.train.data_source.list_file': CLS_TRAIN_TEST,
'data.val.data_source.root': '',
'data.val.data_source.list_file': CLS_TRAIN_TEST,
'data.train.data_source.class_list': ['ok', 'ng'],
'data.val.data_source.class_list': ['ok', 'ng'],
}
}, {
'config_file':
'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py',
'cfg_options': {
**_COMMON_OPTIONS,
'data.train.data_source.root': '',
'data.train.data_source.list_file': CLS_TRAIN_TEST,
'data.val.data_source.root': '',
'data.val.data_source.list_file': CLS_TRAIN_TEST,
'model.train_preprocess': ['randomErasing', 'mixUp'],
'data.train.data_source.class_list': ['ok', 'ng'],
'data.val.data_source.class_list': ['ok', 'ng'],
}
}, {
'config_file':
'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py',
'cfg_options': {
**_COMMON_OPTIONS,
'data_train_root': '',
'data_train_list': CLS_TRAIN_TEST,
'data_test_root': '',
'data_test_list': CLS_TRAIN_TEST,
'image_resize2': [224, 224],
'save_epochs': 1,
'eval_epochs': 1,
'class_list': ['ok', 'ng'],
}
}]
class ClassificationTrainTest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
def tearDown(self):
super().tearDown()
def _base_train(self, train_cfgs, adapt_pai=False):
cfg_file = train_cfgs.pop('config_file')
cfg_options = train_cfgs.pop('cfg_options', None)
work_dir = train_cfgs.pop('work_dir', None)
if not work_dir:
work_dir = tempfile.TemporaryDirectory().name
if adapt_pai:
cfg = pai_config_fromfile(cfg_file, user_config_params=cfg_options)
cfg.eval_pipelines[0].data = cfg.data.val
else:
cfg = mmcv_config_fromfile(cfg_file)
if cfg_options is not None:
cfg.merge_from_dict(cfg_options)
cfg.eval_pipelines[0].data = cfg.data.val
tmp_cfg_file = tempfile.NamedTemporaryFile(suffix='.py').name
cfg.dump(tmp_cfg_file)
args_str = ' '.join(
['='.join((str(k), str(v))) for k, v in train_cfgs.items()])
cmd = 'python tools/train.py %s --work_dir=%s %s --fp16' % \
(tmp_cfg_file, work_dir, args_str)
logging.info('run command: %s' % cmd)
# run_in_subprocess(cmd) # 管道缓冲区被写满后面的写入请求都hang住了
import subprocess
subprocess.call(cmd, shell=True)
output_files = io.listdir(work_dir)
self.assertIn('epoch_1.pth', output_files)
io.remove(work_dir)
io.remove(tmp_cfg_file)
def test_classification(self):
train_cfgs = copy.deepcopy(TRAIN_CONFIGS[0])
self._base_train(train_cfgs)
def test_classification_mixup(self):
train_cfgs = copy.deepcopy(TRAIN_CONFIGS[1])
self._base_train(train_cfgs)
def test_classification_pai(self):
train_cfgs = copy.deepcopy(TRAIN_CONFIGS[2])
self._base_train(train_cfgs, adapt_pai=True)
if __name__ == '__main__':
unittest.main()