update yolox template config (#243)

* update yolox config template

* update oss sync hook, support syncing all export files

* add sync export model config
release/0.8.0
Cathy0908 2022-12-05 22:01:44 +08:00 committed by GitHub
parent befb23c2d5
commit a36948811f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 51 additions and 35 deletions

View File

@ -156,8 +156,6 @@ eval_pipelines = [
]
checkpoint_config = dict(interval='${interval}')
# export model during training
checkpoint_sync_export = True
# optimizer
# basic_lr_per_img = 0.01 / 64.0
optimizer = dict(
@ -197,8 +195,6 @@ load_from = None
resume_from = None
workflow = [('train', 1)]
export = dict(use_jit=False)
# oss io config
oss_io_config = dict(
ak_id='your oss ak id',
@ -206,3 +202,7 @@ oss_io_config = dict(
hosts='oss-cn-zhangjiakou.aliyuncs.com',
buckets=['your_bucket'])
oss_sync_config = dict(other_file_list=['**/events.out.tfevents*', '**/*log*'])
# export model during training
checkpoint_sync_export = True
export = dict(export_type='raw')

View File

@ -152,8 +152,6 @@ eval_pipelines = [
]
checkpoint_config = dict(interval='${interval}')
# export model during training
checkpoint_sync_export = True
# optimizer
# basic_lr_per_img = 0.01 / 64.0
optimizer = dict(
@ -193,5 +191,7 @@ load_from = None
resume_from = None
workflow = [('train', 1)]
export = dict(use_jit=False)
oss_sync_config = dict(other_file_list=['**/events.out.tfevents*', '**/*log*'])
# export model during training
checkpoint_sync_export = True
export = dict(export_type='raw')

View File

@ -43,6 +43,10 @@ CLASSES = [
# dataset settings
data_root = 'data/coco/'
train_ann_file = data_root + 'annotations/instances_train2017.json'
train_img_prefix = data_root + 'train2017/'
test_ann_file = data_root + 'annotations/instances_val2017.json'
test_img_prefix = data_root + 'val2017/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
@ -83,8 +87,8 @@ train_dataset = dict(
type='DetImagesMixDataset',
data_source=dict(
type='DetSourceCoco',
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
ann_file=train_ann_file,
img_prefix=train_img_prefix,
pipeline=[
dict(type='LoadImageFromFile', to_float32=True),
dict(type='LoadAnnotations', with_bbox=True)
@ -100,8 +104,8 @@ val_dataset = dict(
imgs_per_gpu=2,
data_source=dict(
type='DetSourceCoco',
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
ann_file=test_ann_file,
img_prefix=test_img_prefix,
pipeline=[
dict(type='LoadImageFromFile', to_float32=True),
dict(type='LoadAnnotations', with_bbox=True)
@ -188,4 +192,5 @@ log_config = dict(
# dict(type='WandbLoggerHookV2'),
])
checkpoint_sync_export = True
export = dict(export_type = 'raw', preprocess_jit = False, batch_size=1, blade_config=dict(enable_fp16=True, fp16_fallback_op_ratio=0.01), use_trt_efficientnms=False)

View File

@ -26,25 +26,33 @@ __all__ = [
]
def export(cfg, ckpt_path, filename, **kwargs):
def export(cfg, ckpt_path, filename, model=None, **kwargs):
""" export model for inference
Args:
cfg: Config object
ckpt_path (str): path to checkpoint file
filename (str): filename to save exported models
model (nn.module): model instance
"""
if hasattr(cfg.model, 'pretrained'):
logging.warning(
'Export needs to set model.pretrained to false to avoid hanging during distributed training'
)
cfg.model.pretrained = False
logging.warning(
'Export needs to set pretrained to false to avoid hanging during distributed training'
)
cfg.model['pretrained'] = False
if model is None:
model = build_model(cfg.model)
model = build_model(cfg.model)
if ckpt_path != 'dummy':
load_checkpoint(model, ckpt_path, map_location='cpu')
else:
cfg.model.backbone.pretrained = False
if hasattr(cfg.model, 'backbone') and hasattr(cfg.model.backbone,
'pretrained'):
logging.warning(
'Export needs to set model.backbone.pretrained to false to avoid hanging during distributed training'
)
cfg.model.backbone.pretrained = False
if isinstance(model, MOCO) or isinstance(model, DINO):
_export_moco(model, cfg, filename, **kwargs)

View File

@ -34,18 +34,20 @@ class ExportHook(Hook):
'export_after_each_ckpt', False)
def export_model(self, runner, epoch):
# epoch = runner.epoch
ckpt_fname = self.ckpt_filename_tmpl.format(epoch)
export_ckpt_fname = self.export_ckpt_filename_tmpl.format(epoch)
local_ckpt = os.path.join(self.work_dir, ckpt_fname)
export_local_ckpt = os.path.join(self.work_dir, export_ckpt_fname)
if not os.path.exists(local_ckpt):
runner.logger.warning(f'{local_ckpt} does not exists, skip export')
runner.logger.info(f'export model to {export_local_ckpt}')
from easycv.apis.export import export
if hasattr(runner.model, 'module'):
model = runner.model.module
else:
runner.logger.info(f'export {local_ckpt} to {export_local_ckpt}')
from easycv.apis.export import export
export(self.cfg, local_ckpt, export_local_ckpt)
model = runner.model
export(
self.cfg,
ckpt_path='dummy',
filename=export_local_ckpt,
model=model)
@master_only
def after_train_iter(self, runner):

View File

@ -111,12 +111,13 @@ class OSSSyncHook(Hook):
# try to upload exported model
epoch = runner.epoch
export_ckpt_fname = self.export_ckpt_filename_tmpl.format(epoch)
export_local_ckpt = os.path.join(self.work_dir, export_ckpt_fname)
export_oss_ckpt = os.path.join(self.oss_work_dir, export_ckpt_fname)
if not os.path.exists(export_local_ckpt):
runner.logger.warning(
f'{export_local_ckpt} does not exists, skip upload')
else:
runner.logger.info(
f'upload {export_local_ckpt} to {export_oss_ckpt}')
io.safe_copy(export_local_ckpt, export_oss_ckpt)
# upload all export files
export_files = glob.glob(
os.path.join(self.work_dir, '*{}*'.format(export_ckpt_fname)),
recursive=True)
for export_file in export_files:
rel_path = os.path.relpath(export_file, self.work_dir)
target_oss_path = os.path.join(self.oss_work_dir, rel_path)
runner.logger.info(f'upload {export_file} to {target_oss_path}')
io.safe_copy(export_file, target_oss_path)