mirror of https://github.com/alibaba/EasyCV.git
update yolox template config (#243)
* update yolox config template
* update oss sync hook, support sync all export files
* add sync export model config

release/0.8.0
parent befb23c2d5
commit a36948811f
@@ -156,8 +156,6 @@ eval_pipelines = [
]

checkpoint_config = dict(interval='${interval}')
# export model during training
checkpoint_sync_export = True
# optimizer
# basic_lr_per_img = 0.01 / 64.0
optimizer = dict(

@@ -197,8 +195,6 @@ load_from = None
resume_from = None
workflow = [('train', 1)]

export = dict(use_jit=False)

# oss io config
oss_io_config = dict(
    ak_id='your oss ak id',

@@ -206,3 +202,7 @@ oss_io_config = dict(
    hosts='oss-cn-zhangjiakou.aliyuncs.com',
    buckets=['your_bucket'])
oss_sync_config = dict(other_file_list=['**/events.out.tfevents*', '**/*log*'])

# export model during training
checkpoint_sync_export = True
export = dict(export_type='raw')

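The `other_file_list` entries are glob patterns evaluated against the training work directory, so TensorBoard event files and anything matching `*log*` are synced to OSS along with the checkpoints. A quick local sketch of what such patterns match (the work-dir path is hypothetical, and the hook's internal matching may differ in detail):

import glob
import os

work_dir = 'work_dirs/yolox_s_template'  # hypothetical work directory
for pattern in ['**/events.out.tfevents*', '**/*log*']:
    # recursive=True lets '**' cross nested directory boundaries
    for path in glob.glob(os.path.join(work_dir, pattern), recursive=True):
        print('would sync:', path)
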
@@ -152,8 +152,6 @@ eval_pipelines = [
]

checkpoint_config = dict(interval='${interval}')
# export model during training
checkpoint_sync_export = True
# optimizer
# basic_lr_per_img = 0.01 / 64.0
optimizer = dict(

@@ -193,5 +191,7 @@ load_from = None
resume_from = None
workflow = [('train', 1)]

export = dict(use_jit=False)
oss_sync_config = dict(other_file_list=['**/events.out.tfevents*', '**/*log*'])

# export model during training
checkpoint_sync_export = True
export = dict(export_type='raw')

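These two lines turn on in-training export for this template as well: with `checkpoint_sync_export = True`, each saved checkpoint is expected to be followed by an exported model written next to it in the work directory, and `export_type='raw'` presumably selects the plain export path rather than a jit/blade variant. A hedged sketch for listing the resulting artifacts after a run (paths and filename patterns are illustrative; the exact export name comes from ExportHook's filename template):

import os

work_dir = 'work_dirs/yolox_nano_8xb16_300e_coco'  # illustrative path
for name in sorted(os.listdir(work_dir)):
    # e.g. epoch_300.pth (checkpoint) next to epoch_300_export.pt (exported model)
    if name.endswith(('.pth', '.pt')):
        print(name)
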
@@ -43,6 +43,10 @@ CLASSES = [

# dataset settings
data_root = 'data/coco/'
train_ann_file = data_root + 'annotations/instances_train2017.json'
train_img_prefix = data_root + 'train2017/'
test_ann_file = data_root + 'annotations/instances_val2017.json'
test_img_prefix = data_root + 'val2017/'

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

@@ -83,8 +87,8 @@ train_dataset = dict(
    type='DetImagesMixDataset',
    data_source=dict(
        type='DetSourceCoco',
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        ann_file=train_ann_file,
        img_prefix=train_img_prefix,
        pipeline=[
            dict(type='LoadImageFromFile', to_float32=True),
            dict(type='LoadAnnotations', with_bbox=True)

@@ -100,8 +104,8 @@ val_dataset = dict(
    imgs_per_gpu=2,
    data_source=dict(
        type='DetSourceCoco',
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        ann_file=test_ann_file,
        img_prefix=test_img_prefix,
        pipeline=[
            dict(type='LoadImageFromFile', to_float32=True),
            dict(type='LoadAnnotations', with_bbox=True)

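Because the annotation and image locations are now read from the module-level variables, retargeting this config at a different COCO-format dataset only means editing those four assignments; a minimal sketch, assuming a hypothetical directory layout:

# dataset settings for a custom COCO-format dataset (paths are illustrative)
data_root = 'data/my_dataset/'
train_ann_file = data_root + 'annotations/train.json'
train_img_prefix = data_root + 'images/train/'
test_ann_file = data_root + 'annotations/val.json'
test_img_prefix = data_root + 'images/val/'
# train_dataset / val_dataset then pick these up via ann_file=train_ann_file, etc.
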
@@ -188,4 +192,5 @@ log_config = dict(
        # dict(type='WandbLoggerHookV2'),
    ])

checkpoint_sync_export = True
export = dict(export_type='raw', preprocess_jit=False, batch_size=1, blade_config=dict(enable_fp16=True, fp16_fallback_op_ratio=0.01), use_trt_efficientnms=False)

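Since these configs are ordinary mmcv-style Python files, the new export block can be inspected programmatically before launching an export; a minimal sketch, assuming mmcv is installed and using an illustrative config path:

from mmcv import Config

cfg = Config.fromfile('configs/detection/yolox/yolox_s_8xb16_300e_coco.py')  # illustrative path
print(cfg.export.export_type)               # 'raw'
print(cfg.export.blade_config.enable_fp16)  # True
print(cfg.checkpoint_sync_export)           # True
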
@@ -26,24 +26,32 @@ __all__ = [
]


def export(cfg, ckpt_path, filename, **kwargs):
def export(cfg, ckpt_path, filename, model=None, **kwargs):
    """ export model for inference

    Args:
        cfg: Config object
        ckpt_path (str): path to checkpoint file
        filename (str): filename to save exported models
        model (nn.module): model instance
    """

    if hasattr(cfg.model, 'pretrained'):
        logging.warning(
            'Export needs to set pretrained to false to avoid hanging during distributed training'
            'Export needs to set model.pretrained to false to avoid hanging during distributed training'
        )
        cfg.model['pretrained'] = False
        cfg.model.pretrained = False

    if model is None:
        model = build_model(cfg.model)

    if ckpt_path != 'dummy':
        load_checkpoint(model, ckpt_path, map_location='cpu')
    else:
        if hasattr(cfg.model, 'backbone') and hasattr(cfg.model.backbone,
                                                      'pretrained'):
            logging.warning(
                'Export needs to set model.backbone.pretrained to false to avoid hanging during distributed training'
            )
            cfg.model.backbone.pretrained = False

    if isinstance(model, MOCO) or isinstance(model, DINO):

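The added `model` argument lets callers hand an already-built (or already-trained, in-memory) model to `export` and skip checkpoint loading by passing `ckpt_path='dummy'`, which is how the updated ExportHook calls it below. A hedged usage sketch outside a training run (config and output paths are placeholders, and the import locations mirror the ones used elsewhere in EasyCV):

from mmcv import Config

from easycv.apis.export import export
from easycv.models import build_model

cfg = Config.fromfile('configs/detection/yolox/yolox_s_8xb16_300e_coco.py')  # placeholder path
model = build_model(cfg.model)  # or a model instance you already have in memory
# ckpt_path='dummy' skips load_checkpoint, so the current weights of `model` are exported as-is
export(cfg, ckpt_path='dummy', filename='work_dirs/yolox_export.pt', model=model)
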
@@ -34,18 +34,20 @@ class ExportHook(Hook):
            'export_after_each_ckpt', False)

    def export_model(self, runner, epoch):
        # epoch = runner.epoch
        ckpt_fname = self.ckpt_filename_tmpl.format(epoch)
        export_ckpt_fname = self.export_ckpt_filename_tmpl.format(epoch)
        local_ckpt = os.path.join(self.work_dir, ckpt_fname)
        export_local_ckpt = os.path.join(self.work_dir, export_ckpt_fname)

        if not os.path.exists(local_ckpt):
            runner.logger.warning(f'{local_ckpt} does not exists, skip export')
        else:
            runner.logger.info(f'export {local_ckpt} to {export_local_ckpt}')
            runner.logger.info(f'export model to {export_local_ckpt}')
            from easycv.apis.export import export
            export(self.cfg, local_ckpt, export_local_ckpt)
            if hasattr(runner.model, 'module'):
                model = runner.model.module
            else:
                model = runner.model
            export(
                self.cfg,
                ckpt_path='dummy',
                filename=export_local_ckpt,
                model=model)

    @master_only
    def after_train_iter(self, runner):

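Passing the live `runner.model` requires unwrapping the distributed wrapper first, since DistributedDataParallel-style containers hold the real network under `.module`; the `hasattr` branch above does exactly that, and an equivalent compact form is sketched here (purely illustrative, not part of the change):

def unwrap_model(wrapped):
    # DistributedDataParallel-style wrappers expose the real network as `.module`
    return getattr(wrapped, 'module', wrapped)

With such a helper, the hook body would reduce to `model = unwrap_model(runner.model)` before calling `export(...)`.
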
@@ -111,12 +111,13 @@ class OSSSyncHook(Hook):
        # try to upload exported model
        epoch = runner.epoch
        export_ckpt_fname = self.export_ckpt_filename_tmpl.format(epoch)
        export_local_ckpt = os.path.join(self.work_dir, export_ckpt_fname)
        export_oss_ckpt = os.path.join(self.oss_work_dir, export_ckpt_fname)
        if not os.path.exists(export_local_ckpt):
            runner.logger.warning(
                f'{export_local_ckpt} does not exists, skip upload')
        else:
            runner.logger.info(
                f'upload {export_local_ckpt} to {export_oss_ckpt}')
            io.safe_copy(export_local_ckpt, export_oss_ckpt)

        # upload all export files
        export_files = glob.glob(
            os.path.join(self.work_dir, '*{}*'.format(export_ckpt_fname)),
            recursive=True)
        for export_file in export_files:
            rel_path = os.path.relpath(export_file, self.work_dir)
            target_oss_path = os.path.join(self.oss_work_dir, rel_path)
            runner.logger.info(f'upload {export_file} to {target_oss_path}')
            io.safe_copy(export_file, target_oss_path)

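The new loop uploads every file in the work directory whose name contains the export checkpoint filename, so auxiliary artifacts produced during export travel to OSS together with the main file. Since glob applies fnmatch-style matching to the file names, the selection behaviour can be illustrated with a small sketch using hypothetical filenames (the `.blade`/`.jit` suffixes are assumptions, not taken from this diff):

import fnmatch

export_ckpt_fname = 'epoch_300_export.pt'   # hypothetical export filename
work_dir_files = [
    'epoch_300.pth',                # plain checkpoint: not matched
    'epoch_300_export.pt',          # matched
    'epoch_300_export.pt.blade',    # matched (assumed auxiliary artifact)
    'epoch_300_export.pt.jit',      # matched (assumed auxiliary artifact)
]
pattern = '*{}*'.format(export_ckpt_fname)
print([f for f in work_dir_files if fnmatch.fnmatch(f, pattern)])
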