EasyCV/tests/tools/test_yolox_train.py

140 lines
4.5 KiB
Python
Raw Normal View History

2022-04-02 20:01:06 +08:00
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import glob
import json
2022-04-02 20:01:06 +08:00
import logging
import os
import sys
import tempfile
import unittest
import torch
from mmcv import Config
from tests.ut_config import (DET_DATA_MANIFEST_OSS, DET_DATA_SMALL_COCO_LOCAL,
PRETRAINED_MODEL_YOLOXS)
2022-04-02 20:01:06 +08:00
from easycv.file import io
from easycv.file.utils import get_oss_config
from easycv.utils.test_util import run_in_subprocess
# Append the directory containing this test file to the module search path.
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
# Emit INFO-level records, e.g. the training command logged before each run.
logging.basicConfig(level=logging.INFO)

# Dataset roots normalised to end in exactly one '/', so file names can be
# appended with plain string concatenation.
SMALL_COCO_DATA_ROOT = '{}/'.format(DET_DATA_SMALL_COCO_LOCAL.rstrip('/'))
SMALL_COCO_ITAG_DATA_ROOT = '{}/'.format(DET_DATA_MANIFEST_OSS.rstrip('/'))
# Dotted-key config overrides shared by every entry in TRAIN_CONFIGS:
# one epoch, checkpoint/eval interval of 1, 8 images per GPU, and the
# pretrained YOLOX-S weights as the starting point.
# NOTE(review): the learning rate is forced to 0.0, presumably so the
# metric thresholds checked after training reflect the pretrained weights
# rather than one epoch of fine-tuning — confirm.
_COMMON_OPTIONS = dict([
    ('checkpoint_config.interval', 1),
    ('eval_config.interval', 1),
    ('total_epochs', 1),
    ('data.imgs_per_gpu', 8),
    ('load_from', PRETRAINED_MODEL_YOLOXS),
    ('optimizer.lr', 0.0),
])
# Training scenarios. Each entry names a config file plus the dotted-key
# overrides merged into it: the shared _COMMON_OPTIONS first, then the
# train/val data sources pointed at small local/OSS subsets (the "_20" file
# names suggest 20 samples each).
TRAIN_CONFIGS = [
    # [0] iTAG manifest data source (used by test_yolox_itag*).
    dict(
        config_file='configs/detection/yolox/yolox_s_8xb16_300e_coco_pai.py',
        cfg_options={
            **_COMMON_OPTIONS,
            'data.train.data_source.path':
            SMALL_COCO_ITAG_DATA_ROOT + 'train2017_20.manifest',
            'data.val.data_source.path':
            SMALL_COCO_ITAG_DATA_ROOT + 'val2017_20.manifest',
        },
    ),
    # [1] COCO image folders + annotation JSON (used by test_yolox_coco).
    dict(
        config_file='configs/detection/yolox/yolox_s_8xb16_300e_coco.py',
        cfg_options={
            **_COMMON_OPTIONS,
            'data.train.data_source.img_prefix':
            SMALL_COCO_DATA_ROOT + 'train2017',
            'data.val.data_source.img_prefix':
            SMALL_COCO_DATA_ROOT + 'val2017',
            'data.train.data_source.ann_file':
            SMALL_COCO_DATA_ROOT + 'instances_train2017_20.json',
            'data.val.data_source.ann_file':
            SMALL_COCO_DATA_ROOT + 'instances_val2017_20.json',
        },
    ),
]
class YOLOXTrainTest(unittest.TestCase):
    """End-to-end YOLOX-S training tests.

    Each test merges overrides into a config file, launches training in a
    subprocess (``tools/train.py`` or ``tools/dist_train.sh``), then checks
    that ``epoch_1.pth`` was written and that the logged detection metrics
    exceed fixed thresholds.
    """

    def setUp(self):
        super().setUp()
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

    def check_metric(self, work_dir):
        """Assert that the detection metrics logged in ``work_dir`` pass.

        Args:
            work_dir (str): Training output directory expected to contain a
                ``*.log.json`` file.

        Raises:
            AssertionError: If no JSON log is found, the log has no second
                line, or any metric is at or below its threshold.
        """
        # sorted() makes the choice deterministic if several logs exist.
        json_files = sorted(glob.glob(os.path.join(work_dir, '*.log.json')))
        # Fail with a readable message rather than an IndexError when the
        # training run produced no JSON log.
        self.assertTrue(json_files,
                        'no *.log.json file found in %s' % work_dir)
        with io.open(json_files[0], 'r') as f:
            content = f.readlines()
        # NOTE(review): the metric record is read from the second line of the
        # log (the first line is skipped) — confirm this matches the logger
        # hook's output layout.
        self.assertGreater(
            len(content), 1,
            'expected a metric record on line 2 of %s' % json_files[0])
        res = json.loads(content[1])
        self.assertGreater(res['DetectionBoxes_Precision/mAP'], 0.4)
        self.assertGreater(res['DetectionBoxes_Precision/mAP@.50IOU'], 0.5)
        self.assertGreater(res['DetectionBoxes_Precision/mAP@.75IOU'], 0.4)

    def _base_train(self, train_cfgs, dist=False, dist_eval=False):
        """Run one training job in a subprocess and validate its outputs.

        Args:
            train_cfgs (dict): Consumed (mutated) by this call. Requires
                ``config_file``; optional ``cfg_options`` (merged into the
                config) and ``work_dir``. Every remaining key/value pair is
                appended to the command line as ``key=value``.
            dist (bool): Launch through ``tools/dist_train.sh`` with 2
                processes per node instead of ``tools/train.py``.
            dist_eval (bool): Written to ``eval_pipelines[0].dist_eval``.
        """
        cfg_file = train_cfgs.pop('config_file')
        cfg_options = train_cfgs.pop('cfg_options', None)
        work_dir = train_cfgs.pop('work_dir', None)
        if not work_dir:
            # mkdtemp really creates (and keeps) the directory.
            # TemporaryDirectory().name only borrows a path whose directory
            # is deleted as soon as the unreferenced object is collected.
            work_dir = tempfile.mkdtemp()
        # Registered up front so outputs are removed even if training or an
        # assertion below fails.
        self.addCleanup(io.remove, work_dir)

        cfg = Config.fromfile(cfg_file)
        if cfg_options is not None:
            cfg.merge_from_dict(cfg_options)
        cfg.eval_pipelines[0].data = dict(**cfg.data.val)  # imgs_per_gpu=1
        cfg.eval_pipelines[0].dist_eval = dist_eval

        # Same reasoning as above: mkstemp creates a real file. Close the raw
        # descriptor because cfg.dump() reopens the path itself.
        fd, tmp_cfg_file = tempfile.mkstemp(suffix='.py')
        os.close(fd)
        self.addCleanup(io.remove, tmp_cfg_file)
        cfg.dump(tmp_cfg_file)

        args_str = ' '.join(
            ['='.join((str(k), str(v))) for k, v in train_cfgs.items()])
        if dist:
            nproc_per_node = 2
            cmd = ('bash tools/dist_train.sh %s %s --launcher pytorch '
                   '--work_dir=%s %s --fp16') % (
                       tmp_cfg_file, nproc_per_node, work_dir, args_str)
        else:
            cmd = 'python tools/train.py %s --work_dir=%s %s --fp16' % (
                tmp_cfg_file, work_dir, args_str)

        logging.info('run command: %s', cmd)
        run_in_subprocess(cmd)

        output_files = io.listdir(work_dir)
        self.assertIn('epoch_1.pth', output_files)
        self.check_metric(work_dir)

    def test_yolox_itag(self):
        train_cfgs = copy.deepcopy(TRAIN_CONFIGS[0])
        train_cfgs['cfg_options'].update(dict(oss_io_config=get_oss_config()))
        self._base_train(train_cfgs)

    def test_yolox_coco(self):
        train_cfgs = copy.deepcopy(TRAIN_CONFIGS[1])
        self._base_train(train_cfgs)

    @unittest.skipIf(torch.cuda.device_count() <= 1, 'distributed unittest')
    def test_yolox_itag_dist(self):
        """Distributed variant of test_yolox_itag (needs >= 2 GPUs)."""
        train_cfgs = copy.deepcopy(TRAIN_CONFIGS[0])
        train_cfgs['cfg_options'].update(dict(oss_io_config=get_oss_config()))
        # dist=True is required for this to exercise the distributed launcher;
        # without it the run was single-process despite the multi-GPU skip.
        self._base_train(train_cfgs, dist=True, dist_eval=True)
2022-04-02 20:01:06 +08:00
if __name__ == '__main__':
    # Invoked as a script: run every test case defined in this module.
    unittest.main(module='__main__')