diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst
index d079936a..2e14a9b3 100644
--- a/docs/zh_cn/index.rst
+++ b/docs/zh_cn/index.rst
@@ -59,6 +59,12 @@
    design/visualization.md
    design/logging.md
 
+.. toctree::
+   :maxdepth: 1
+   :caption: 迁移指南
+
+   migration/migrate_runner_from_mmcv.md
+
 .. toctree::
    :maxdepth: 2
    :caption: API 文档
diff --git a/docs/zh_cn/migration/migrate_runner_from_mmcv.md b/docs/zh_cn/migration/migrate_runner_from_mmcv.md
new file mode 100644
index 00000000..223e1854
--- /dev/null
+++ b/docs/zh_cn/migration/migrate_runner_from_mmcv.md
@@ -0,0 +1,1478 @@
# Migrate Runner from MMCV to MMEngine

## Introduction

As more and more deep learning tasks are supported and users' needs keep growing, we have higher requirements for the flexibility and versatility of the existing Runner in MMCV. Therefore, MMEngine implements a more general and flexible Runner on top of MMCV to support more complicated pipelines of model training. The Runner in MMEngine covers a wider scope and takes on more responsibilities: we abstract the [training loop controllers (EpochBasedTrainLoop/IterBasedTrainLoop)](mmengine.runner.EpochBasedTrainLoop), the [validation loop controller (ValLoop)](mmengine.runner.ValLoop) and the [test loop controller (TestLoop)](mmengine.runner.TestLoop) so that users can flexibly extend the execution flow of their models.

We will first introduce how to migrate the entry point of training from MMCV to MMEngine, so as to simplify and unify the entry-point code as much as possible. Then we will describe in detail the differences in building the Runner and its inner components for training in MMCV and in MMEngine. Before starting the migration, we recommend reading the [Runner tutorial](../tutorials/runner.md) first.

## Entry point

Taking MMDet as an example, we first show the differences in the configuration files and the training launch scripts before and after the refactoring based on MMEngine:

### Migration of configuration files
<table class="docutils">
<thead>
  <tr>
    <th></th>
    <th>Overview of configuration files based on the MMCV Runner</th>
    <th>Overview of configuration files based on the MMEngine Runner</th>
  </tr>
</thead>
<tbody>
<tr>
  <td>default_runtime.py</td>
  <td valign="top">

```python
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
custom_hooks = [dict(type='NumClassCheckHook')]

dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]

opencv_num_threads = 0
mp_start_method = 'fork'
auto_scale_lr = dict(enable=False, base_batch_size=16)
```

</td>
  <td valign="top">

```python
default_scope = 'mmdet'

default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=1),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='DetVisualizationHook'))

env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'),
)

vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')
log_processor = dict(type='LogProcessor', window_size=50, by_epoch=True)

log_level = 'INFO'
load_from = None
resume = False
```

</td>
</tr>
<tr>
  <td>scheduler.py</td>
  <td valign="top">

```python
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[8, 11])
runner = dict(type='EpochBasedRunner', max_epochs=12)
```

</td>
  <td valign="top">

```python
# training schedule for 1x
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# learning rate
param_scheduler = [
    dict(
        type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500),
    dict(
        type='MultiStepLR',
        begin=0,
        end=12,
        by_epoch=True,
        milestones=[8, 11],
        gamma=0.1)
]

# optimizer
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001))

# Default setting for scaling LR automatically
#   - `enable` means enable scaling LR automatically
#     or not by default.
#   - `base_batch_size` = (8 GPUs) x (2 samples per GPU).
auto_scale_lr = dict(enable=False, base_batch_size=16)
```

</td>
</tr>
<tr>
  <td>coco_detection.py</td>
  <td valign="top">

```python
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))
evaluation = dict(interval=1, metric='bbox')
```

</td>
  <td valign="top">

```python
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'

file_client_args = dict(backend='disk')

train_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PackDetInputs')
]
test_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(type='Resize', scale=(1333, 800), keep_ratio=True),
    # If you don't have a gt annotation, delete the pipeline
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]
train_dataloader = dict(
    batch_size=2,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    batch_sampler=dict(type='AspectRatioBatchSampler'),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/instances_train2017.json',
        data_prefix=dict(img='train2017/'),
        filter_cfg=dict(filter_empty_gt=True, min_size=32),
        pipeline=train_pipeline))
val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/instances_val2017.json',
        data_prefix=dict(img='val2017/'),
        test_mode=True,
        pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(
    type='CocoMetric',
    ann_file=data_root + 'annotations/instances_val2017.json',
    metric='bbox',
    format_only=False)
test_evaluator = val_evaluator
```

</td>
</tr>
</tbody>
</table>
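To make the comparison above concrete, here is a minimal sketch of loading a migrated config with `mmengine.Config` and checking the fields that replace the old `runner`, `optimizer` and `data` settings. The config path is an assumption for illustration; any MMEngine-style config with the fields shown above would work.

```python
# Minimal sketch (the config path is a hypothetical example).
from mmengine.config import Config

cfg = Config.fromfile('configs/retinanet_r50_fpn_1x_coco.py')
print(cfg.train_cfg)                    # loop controller, e.g. EpochBasedTrainLoop
print(cfg.optim_wrapper.optimizer.lr)   # optimizer settings now live in optim_wrapper
print(cfg.train_dataloader.batch_size)  # dataloader settings replace the old `data` field
```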
### Migration of the training launch script

<table class="docutils">
<thead>
  <tr>
    <th></th>
    <th>Training launch script based on the MMCV Runner</th>
    <th>Training launch script based on the MMEngine Runner</th>
  </tr>
</thead>
<tbody>
<tr>
  <td>tools/train.py</td>
  <td valign="top">

```python
args = parse_args()

cfg = Config.fromfile(args.config)

# replace the ${key} with the value of cfg.key
cfg = replace_cfg_vals(cfg)

# update data root according to MMDET_DATASETS
update_data_root(cfg)

if args.cfg_options is not None:
    cfg.merge_from_dict(args.cfg_options)

if args.auto_scale_lr:
    if 'auto_scale_lr' in cfg and \
            'enable' in cfg.auto_scale_lr and \
            'base_batch_size' in cfg.auto_scale_lr:
        cfg.auto_scale_lr.enable = True
    else:
        warnings.warn('Can not find "auto_scale_lr" or '
                      '"auto_scale_lr.enable" or '
                      '"auto_scale_lr.base_batch_size" in your'
                      ' configuration file. Please update all the '
                      'configuration files to mmdet >= 2.24.1.')

# set multi-process settings
setup_multi_processes(cfg)

# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
    torch.backends.cudnn.benchmark = True

# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
    # update configs according to CLI args if args.work_dir is not None
    cfg.work_dir = args.work_dir
elif cfg.get('work_dir', None) is None:
    # use config filename as default work_dir if cfg.work_dir is None
    cfg.work_dir = osp.join('./work_dirs',
                            osp.splitext(osp.basename(args.config))[0])

if args.resume_from is not None:
    cfg.resume_from = args.resume_from
cfg.auto_resume = args.auto_resume
if args.gpus is not None:
    cfg.gpu_ids = range(1)
    warnings.warn('`--gpus` is deprecated because we only support '
                  'single GPU mode in non-distributed training. '
                  'Use `gpus=1` now.')
if args.gpu_ids is not None:
    cfg.gpu_ids = args.gpu_ids[0:1]
    warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
                  'Because we only support single GPU mode in '
                  'non-distributed training. Use the first GPU '
                  'in `gpu_ids` now.')
if args.gpus is None and args.gpu_ids is None:
    cfg.gpu_ids = [args.gpu_id]

# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
    distributed = False
else:
    distributed = True
    init_dist(args.launcher, **cfg.dist_params)
    # re-set gpu_ids with distributed training mode
    _, world_size = get_dist_info()
    cfg.gpu_ids = range(world_size)

# create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
# dump config
cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
# init the logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

# init the meta dict to record some important information such as
# environment info and seed, which will be logged
meta = dict()
# log env info
env_info_dict = collect_env()
env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
dash_line = '-' * 60 + '\n'
logger.info('Environment info:\n' + dash_line + env_info + '\n' +
            dash_line)
meta['env_info'] = env_info
meta['config'] = cfg.pretty_text
# log some basic info
logger.info(f'Distributed training: {distributed}')
logger.info(f'Config:\n{cfg.pretty_text}')

cfg.device = get_device()
# set random seeds
seed = init_random_seed(args.seed, device=cfg.device)
seed = seed + dist.get_rank() if args.diff_seed else seed
logger.info(f'Set random seed to {seed}, '
            f'deterministic: {args.deterministic}')
set_random_seed(seed, deterministic=args.deterministic)
cfg.seed = seed
meta['seed'] = seed
meta['exp_name'] = osp.basename(args.config)

model = build_detector(
    cfg.model,
    train_cfg=cfg.get('train_cfg'),
    test_cfg=cfg.get('test_cfg'))
model.init_weights()

datasets = []
train_detector(
    model,
    datasets,
    cfg,
    distributed=distributed,
    validate=(not args.no_validate),
    timestamp=timestamp,
    meta=meta)
```

</td>
  <td valign="top">

```python
args = parse_args()

# register all modules in mmdet into the registries
# do not init the default scope here because it will be init in the runner
register_all_modules(init_default_scope=False)

# load config
cfg = Config.fromfile(args.config)
cfg.launcher = args.launcher
if args.cfg_options is not None:
    cfg.merge_from_dict(args.cfg_options)

# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
    # update configs according to CLI args if args.work_dir is not None
    cfg.work_dir = args.work_dir
elif cfg.get('work_dir', None) is None:
    # use config filename as default work_dir if cfg.work_dir is None
    cfg.work_dir = osp.join('./work_dirs',
                            osp.splitext(osp.basename(args.config))[0])

# enable automatic-mixed-precision training
if args.amp is True:
    optim_wrapper = cfg.optim_wrapper.type
    if optim_wrapper == 'AmpOptimWrapper':
        print_log(
            'AMP training is already enabled in your config.',
            logger='current',
            level=logging.WARNING)
    else:
        assert optim_wrapper == 'OptimWrapper', (
            '`--amp` is only supported when the optimizer wrapper type is '
            f'`OptimWrapper` but got {optim_wrapper}.')
        cfg.optim_wrapper.type = 'AmpOptimWrapper'
        cfg.optim_wrapper.loss_scale = 'dynamic'

# enable automatically scaling LR
if args.auto_scale_lr:
    if 'auto_scale_lr' in cfg and \
            'enable' in cfg.auto_scale_lr and \
            'base_batch_size' in cfg.auto_scale_lr:
        cfg.auto_scale_lr.enable = True
    else:
        raise RuntimeError('Can not find "auto_scale_lr" or '
                           '"auto_scale_lr.enable" or '
                           '"auto_scale_lr.base_batch_size" in your'
                           ' configuration file.')

cfg.resume = args.resume

# build the runner from config
if 'runner_type' not in cfg:
    # build the default runner
    runner = Runner.from_cfg(cfg)
else:
    # build customized runner from the registry
    # if 'runner_type' is set in the cfg
    runner = RUNNERS.build(cfg)

# start training
runner.train()
```

</td>
</tr>
<tr>
  <td>apis/train.py</td>
  <td valign="top">

```python
def init_random_seed(...):
    ...

def set_random_seed(...):
    ...

# define function tools.
...


def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):

    cfg = compat_cfg(cfg)
    logger = get_root_logger(log_level=cfg.log_level)

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = build_ddp(
            model,
            cfg.device,
            device_ids=[int(os.environ['LOCAL_RANK'])],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)

    # build optimizer
    auto_scale_lr(cfg, distributed, logger)
    optimizer = build_optimizer(model, cfg.optimizer)

    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))

    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    # register hooks
    runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get('momentum_config', None),
        custom_hooks_config=cfg.get('custom_hooks', None))

    if distributed:
        if isinstance(runner, EpochBasedRunner):
            runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        val_dataloader_default_args = dict(
            samples_per_gpu=1,
            workers_per_gpu=2,
            dist=distributed,
            shuffle=False,
            persistent_workers=False)

        val_dataloader_args = {
            **val_dataloader_default_args,
            **cfg.data.get('val_dataloader', {})
        }
        # Support batch_size > 1 in validation

        if val_dataloader_args['samples_per_gpu'] > 1:
            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
            cfg.data.val.pipeline = replace_ImageToTensor(
                cfg.data.val.pipeline)
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))

        val_dataloader = build_dataloader(val_dataset, **val_dataloader_args)
        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        # In this PR (https://github.com/open-mmlab/mmcv/pull/1193), the
        # priority of IterTimerHook has been modified from 'NORMAL' to 'LOW'.
        runner.register_hook(
            eval_hook(val_dataloader, **eval_cfg), priority='LOW')

    resume_from = None
    if cfg.resume_from is None and cfg.get('auto_resume'):
        resume_from = find_latest_checkpoint(cfg.work_dir)
    if resume_from is not None:
        cfg.resume_from = resume_from

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)
```

</td>
  <td valign="top">

```python
# `apis/train.py` is removed in `mmengine`
```

</td>
</tr>
</tbody>
</table>
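The full launch script above still handles argument parsing, AMP and LR scaling, but the core of the MMEngine version boils down to a few lines. Below is a minimal, self-contained sketch of such an entry point; the config path and work_dir are assumptions for illustration.

```python
# Minimal entry-point sketch based on the MMEngine Runner
# (the config path and work_dir are hypothetical).
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('configs/retinanet_r50_fpn_1x_coco.py')
cfg.work_dir = './work_dirs/retinanet_r50_fpn_1x_coco'
runner = Runner.from_cfg(cfg)  # builds model, dataloaders, hooks and loops from the config
runner.train()
```

Note that distributed launch, resuming and mixed-precision training are switched on through config fields (`launcher`, `resume`, `optim_wrapper`) rather than extra code in the script.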
<table class="docutils">
<thead>
  <tr>
    <th>Randomness configuration in MMCV</th>
    <th>Randomness configuration in MMEngine</th>
  </tr>
</thead>
<tbody>
<tr>
  <td valign="top">

```python
seed = 1
deterministic = False
diff_seed = False
```

</td>
  <td valign="top">

```python
randomness = dict(
    seed=1,
    deterministic=False,
    diff_rank_seed=False)
```

</td>
</tr>
</tbody>
</table>
<table class="docutils">
<thead>
  <tr>
    <th>Distributed training configuration in MMCV</th>
    <th>Distributed training configuration in MMEngine</th>
  </tr>
</thead>
<tbody>
<tr>
  <td valign="top">

```python
launcher = 'pytorch'  # enable distributed training
dist_params = dict(backend='nccl')  # choose the communication backend
```

</td>
  <td valign="top">

```python
launcher = 'pytorch'
env_cfg = dict(dist_cfg=dict(backend='nccl'))
```

</td>
</tr>
</tbody>
</table>
<table class="docutils">
<thead>
  <tr>
    <th>Dataloader configuration in MMCV</th>
    <th>Dataloader configuration in MMEngine</th>
  </tr>
</thead>
<tbody>
<tr>
  <td valign="top">

```python
data = dict(
    samples_per_gpu=2,  # batch size of a single GPU
    workers_per_gpu=2,  # number of workers to load data on a single GPU
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))
```

</td>
  <td valign="top">

```python
train_dataloader = dict(
    batch_size=2,  # samples_per_gpu -> batch_size
    num_workers=2,
    # whether to keep the worker processes alive after an epoch of the DataLoader
    persistent_workers=True,
    # configurable sampler
    sampler=dict(type='DefaultSampler', shuffle=True),
    # configurable batch_sampler
    batch_sampler=dict(type='AspectRatioBatchSampler'),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/instances_train2017.json',
        data_prefix=dict(img='train2017/'),
        filter_cfg=dict(filter_empty_gt=True, min_size=32),
        pipeline=train_pipeline))

val_dataloader = dict(
    batch_size=1,  # batch size for validation
    num_workers=2,
    persistent_workers=True,
    drop_last=False,  # whether to drop the last incomplete batch
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/instances_val2017.json',
        data_prefix=dict(img='val2017/'),
        test_mode=True,
        pipeline=test_pipeline))

test_dataloader = val_dataloader
```

</td>
</tr>
</tbody>
</table>
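For reference, the following sketch shows a rough single-process PyTorch equivalent of the `train_dataloader` settings above. The dummy `TensorDataset` stands in for the real detection dataset and the registry-built sampler, so it only illustrates how the parameters map.

```python
# Rough single-process PyTorch equivalent of the `train_dataloader` fields above;
# the dataset is a dummy stand-in for illustration only.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(100, dtype=torch.float32))
loader = DataLoader(
    dataset,
    batch_size=2,             # samples_per_gpu -> batch_size
    num_workers=2,
    persistent_workers=True,  # keep worker processes alive between epochs
    shuffle=True,             # DefaultSampler(shuffle=True) in non-distributed runs
    drop_last=False)
```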
<table class="docutils">
<thead>
  <tr>
    <th>Optimizer configuration in MMCV</th>
    <th>Optimizer wrapper configuration in MMEngine</th>
  </tr>
</thead>
<tbody>
<tr>
  <td valign="top">

```python
optimizer = dict(
    constructor='CustomConstructor',
    type='AdamW',        # optimizer settings are first-level fields
    lr=0.0001,           # optimizer settings are first-level fields
    betas=(0.9, 0.999),  # optimizer settings are first-level fields
    weight_decay=0.05,   # optimizer settings are first-level fields
    paramwise_cfg={      # arguments of the constructor
        'decay_rate': 0.95,
        'decay_type': 'layer_wise',
        'num_layers': 6
    })
# MMCV also needs to configure `optimizer_config`
# to build the optimizer hook, while MMEngine does not
optimizer_config = dict(grad_clip=None)
```

</td>
  <td valign="top">

```python
optim_wrapper = dict(
    constructor='CustomConstructor',
    type='OptimWrapper',  # specify the type of the optimizer wrapper
    optimizer=dict(       # optimizer settings are gathered in `optimizer`
        type='AdamW',
        lr=0.0001,
        betas=(0.9, 0.999),
        weight_decay=0.05),
    paramwise_cfg={
        'decay_rate': 0.95,
        'decay_type': 'layer_wise',
        'num_layers': 6
    })
```

</td>
</tr>
</tbody>
</table>
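The `optim_wrapper` built from the config behaves like a thin wrapper around a regular PyTorch optimizer. Here is a hand-written sketch of what it does, using a toy model purely for illustration:

```python
# Sketch of what the `optim_wrapper` config builds, written out by hand with a toy model.
import torch
import torch.nn as nn
from torch.optim import AdamW
from mmengine.optim import OptimWrapper

model = nn.Linear(2, 2)  # toy model for illustration
optim_wrapper = OptimWrapper(
    optimizer=AdamW(
        model.parameters(), lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05))

loss = model(torch.ones(1, 2)).sum()
optim_wrapper.update_params(loss)  # backward + step + zero_grad in one call
```

Gradient clipping, previously set through `optimizer_config=dict(grad_clip=...)`, is now passed to the optimizer wrapper via its `clip_grad` argument.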
<table class="docutils">
<thead>
  <tr>
    <th>Common training hooks in MMCV</th>
    <th>Default hooks in MMEngine</th>
  </tr>
</thead>
<tbody>
<tr>
  <td valign="top">

```python
# Training hooks are configured separately in MMCV
# Configure LrUpdaterHook, the counterpart of the parameter scheduler in MMEngine
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[8, 11])

# Configure OptimizerHook, which is not needed in MMEngine
optimizer_config = dict(grad_clip=None)

# Configure LoggerHook
log_config = dict(  # LoggerHook
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])

# Configure CheckpointHook
checkpoint_config = dict(interval=1)  # CheckpointHook
```

</td>
  <td valign="top">

```python
# Configure parameter schedulers
param_scheduler = [
    dict(
        type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500),
    dict(
        type='MultiStepLR',
        begin=0,
        end=12,
        by_epoch=True,
        milestones=[8, 11],
        gamma=0.1)
]

# MMEngine configures the default hooks in one place
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=1),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='DetVisualizationHook'))
```

</td>
</tr>
</tbody>
</table>
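Custom hooks still exist in MMEngine: they are registered through the `HOOKS` registry and listed in `custom_hooks` in the config. A toy sketch follows; the hook name and behaviour are made up for illustration only.

```python
# Toy custom hook sketch; the name and behaviour are illustrative, not a real MMDet hook.
from mmengine.hooks import Hook
from mmengine.registry import HOOKS


@HOOKS.register_module()
class ToyLoggingHook(Hook):
    """Logs a message at the end of every training epoch."""

    def after_train_epoch(self, runner):
        runner.logger.info(f'Finished epoch {runner.epoch}')


# In the config: custom_hooks = [dict(type='ToyLoggingHook')]
```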
<table class="docutils">
<thead>
  <tr>
    <th>Configuring the validation process in MMCV</th>
    <th>Configuring the validation process in MMEngine</th>
  </tr>
</thead>
<tbody>
<tr>
  <td valign="top">

```python
eval_cfg = cfg.get('evaluation', {})
eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
eval_hook = DistEvalHook if distributed else EvalHook  # configure EvalHook
runner.register_hook(
    eval_hook(val_dataloader, **eval_cfg), priority='LOW')  # register EvalHook
```

</td>
  <td valign="top">

```python
val_dataloader = val_dataloader  # configure the validation dataloader
val_evaluator = dict(type='ToyAccuracyMetric')  # configure the evaluator
val_cfg = dict(type='ValLoop')  # configure the validation loop controller
```

</td>
</tr>
</tbody>
</table>
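`ToyAccuracyMetric` above is only a placeholder name. An evaluator in MMEngine is typically implemented by subclassing `BaseMetric`; the sketch below shows the shape of such a metric, where the keys accessed on the data samples (`pred`, `gt_label`) are assumptions for illustration.

```python
# Minimal metric sketch; the keys read from the data samples are assumptions.
from mmengine.evaluator import BaseMetric
from mmengine.registry import METRICS


@METRICS.register_module()
class ToyAccuracyMetric(BaseMetric):

    def process(self, data_batch, data_samples):
        # accumulate per-sample correctness into self.results
        for sample in data_samples:
            self.results.append(sample['pred'] == sample['gt_label'])

    def compute_metrics(self, results):
        return dict(accuracy=sum(results) / len(results))
```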
<table class="docutils">
<thead>
  <tr>
    <th></th>
    <th>Checkpoint loading configuration in MMCV</th>
    <th>Checkpoint loading configuration in MMEngine</th>
  </tr>
</thead>
<tbody>
<tr>
  <td>Load a checkpoint</td>
  <td valign="top">

```python
load_from = 'path/to/ckpt'
```

</td>
  <td valign="top">

```python
load_from = 'path/to/ckpt'
resume = False
```

</td>
</tr>
<tr>
  <td>Resume from a checkpoint</td>
  <td valign="top">

```python
resume_from = 'path/to/ckpt'
```

</td>
  <td valign="top">

```python
load_from = 'path/to/ckpt'
resume = True
```

</td>
</tr>
</tbody>
</table>
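In practice, the two fields are simply set on the config before the Runner is built. A short sketch, assuming a migrated config file (the config and checkpoint paths are placeholders):

```python
# Sketch: resume training from a checkpoint with the MMEngine Runner.
# The config path and checkpoint path are placeholders.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('configs/retinanet_r50_fpn_1x_coco.py')
cfg.work_dir = './work_dirs/retinanet_r50_fpn_1x_coco'
cfg.load_from = 'path/to/ckpt'
cfg.resume = True   # also restores optimizer, scheduler and epoch states
Runner.from_cfg(cfg).train()
```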