EasyCV/tests/tools/test_classification_train.py
tuofeilun 23f2b0e399
Adapt designer (#235)
1. Use the original config as the startup script (for details, see #225, "refactor config parsing method").
2. Refactor the path-splicing rules of the check_base_cfg_path function in EasyCV/easycv/utils/config_tools.py.
3. Support three ways to pass the class_list parameter.
4. Fix a bug where the cls evaluator could produce wrong results when evaluating top-5.
5. Fix a bug where distributed export failed to export the model.
6. Fix mismatched keys when loading a pretrained model.
7. Support the itag data source for classification (cls).
2022-12-01 17:47:10 +08:00

131 lines
4.1 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import logging
import os
import sys
import tempfile
import unittest
from tests.ut_config import SMALL_IMAGENET_RAW_LOCAL
from easycv.file import io
from easycv.utils.config_tools import mmcv_config_fromfile, pai_config_fromfile
from easycv.utils.test_util import run_in_subprocess
# Put this test file's own directory on the import path.
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
logging.basicConfig(level=logging.INFO)

# Root of the small ImageNet subset, with a trailing slash so that
# sub-paths can be appended directly.
SMALL_IMAGENET_DATA_ROOT = SMALL_IMAGENET_RAW_LOCAL + '/'

# Overrides applied to every training run: one epoch, a checkpoint every
# epoch, 8 images per GPU and plain BN in the backbone.
_COMMON_OPTIONS = {
    'checkpoint_config.interval': 1,
    'total_epochs': 1,
    'data.imgs_per_gpu': 8,
    'model.backbone.norm_cfg.type': 'BN'
}

# Config file shared by all entries of TRAIN_CONFIGS.
_RESNET50_JPG_CONFIG = \
    'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py'

# Dotted-key overrides pointing the train/val data sources at the subset.
_DATA_SOURCE_OPTIONS = {
    'data.train.data_source.root':
    SMALL_IMAGENET_DATA_ROOT + 'train/',
    'data.train.data_source.list_file':
    SMALL_IMAGENET_DATA_ROOT + 'meta/train_labeled_200.txt',
    'data.val.data_source.root':
    SMALL_IMAGENET_DATA_ROOT + 'validation/',
    'data.val.data_source.list_file':
    SMALL_IMAGENET_DATA_ROOT + 'meta/val_labeled_100.txt',
}

TRAIN_CONFIGS = [
    # [0] Plain training with dotted-key overrides.
    {
        'config_file': _RESNET50_JPG_CONFIG,
        'cfg_options': {
            **_COMMON_OPTIONS,
            **_DATA_SOURCE_OPTIONS,
        },
    },
    # [1] Same as [0], plus random-erasing and mixup train preprocessing.
    {
        'config_file': _RESNET50_JPG_CONFIG,
        'cfg_options': {
            **_COMMON_OPTIONS,
            **_DATA_SOURCE_OPTIONS,
            'model.train_preprocess': ['randomErasing', 'mixUp'],
        },
    },
    # [2] Flat, PAI-style parameter names (run with adapt_pai=True).
    {
        'config_file': _RESNET50_JPG_CONFIG,
        'cfg_options': {
            **_COMMON_OPTIONS,
            'data_train_root':
            SMALL_IMAGENET_DATA_ROOT + 'train/',
            'data_train_list':
            SMALL_IMAGENET_DATA_ROOT + 'meta/train_labeled_200.txt',
            'data_test_root':
            SMALL_IMAGENET_DATA_ROOT + 'validation/',
            'data_test_list':
            SMALL_IMAGENET_DATA_ROOT + 'meta/val_labeled_100.txt',
            'image_resize2': [224, 224],
            'save_epochs': 1,
            'eval_epochs': 1,
        },
    },
]
class ClassificationTrainTest(unittest.TestCase):
    """Smoke tests that run ``tools/train.py`` on classification configs.

    Each test takes one entry of ``TRAIN_CONFIGS``, builds the config, dumps
    it to a temporary ``.py`` file, launches training in a subprocess and
    checks that ``epoch_1.pth`` appears in the work directory.
    """

    def setUp(self):
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))

    def _base_train(self, train_cfgs, adapt_pai=False):
        """Train once from ``train_cfgs`` and assert a checkpoint was saved.

        Args:
            train_cfgs (dict): Consumed in place via ``pop``.
                ``config_file`` (required) is the config path.
                ``cfg_options`` (optional) holds config overrides.
                ``work_dir`` (optional) is the output directory; a fresh
                temporary directory is used when it is missing or falsy.
                Every remaining key/value pair is appended to the
                ``train.py`` command line as ``key=value``.
            adapt_pai (bool): If True, build the config with
                ``pai_config_fromfile`` (``cfg_options`` passed as
                ``user_config_params``). Otherwise use
                ``mmcv_config_fromfile`` followed by ``merge_from_dict``.

        The work directory and the dumped config file are removed after the
        test, whether it passes or fails.
        """
        cfg_file = train_cfgs.pop('config_file')
        cfg_options = train_cfgs.pop('cfg_options', None)
        work_dir = train_cfgs.pop('work_dir', None)
        if not work_dir:
            # mkdtemp returns a directory that stays on disk. The former
            # ``TemporaryDirectory().name`` deleted it as soon as the
            # unreferenced TemporaryDirectory object was collected.
            work_dir = tempfile.mkdtemp()
        # addCleanup runs even when the test fails, unlike trailing removes.
        self.addCleanup(io.remove, work_dir)

        if adapt_pai:
            cfg = pai_config_fromfile(cfg_file, user_config_params=cfg_options)
        else:
            cfg = mmcv_config_fromfile(cfg_file)
            if cfg_options is not None:
                cfg.merge_from_dict(cfg_options)
        # Both loading paths evaluate on the (possibly overridden) val split.
        cfg.eval_pipelines[0].data = cfg.data.val

        # Reserve a real, persistent file name. NamedTemporaryFile(...).name
        # deleted the file as soon as the object was garbage-collected.
        fd, tmp_cfg_file = tempfile.mkstemp(suffix='.py')
        os.close(fd)
        self.addCleanup(io.remove, tmp_cfg_file)
        cfg.dump(tmp_cfg_file)

        args_str = ' '.join(
            '='.join((str(k), str(v))) for k, v in train_cfgs.items())
        cmd = 'python tools/train.py %s --work_dir=%s %s --fp16' % \
            (tmp_cfg_file, work_dir, args_str)
        logging.info('run command: %s', cmd)
        run_in_subprocess(cmd)

        output_files = io.listdir(work_dir)
        self.assertIn('epoch_1.pth', output_files)

    def test_classification(self):
        train_cfgs = copy.deepcopy(TRAIN_CONFIGS[0])
        self._base_train(train_cfgs)

    def test_classification_mixup(self):
        train_cfgs = copy.deepcopy(TRAIN_CONFIGS[1])
        self._base_train(train_cfgs)

    def test_classification_pai(self):
        train_cfgs = copy.deepcopy(TRAIN_CONFIGS[2])
        self._base_train(train_cfgs, adapt_pai=True)
if __name__ == '__main__':
    # Running this file directly discovers and runs the tests in this module.
    unittest.main(module='__main__')