mmpretrain/mmcls/apis/train.py

# Copyright (c) OpenMMLab. All rights reserved.
import random
import warnings

import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import (DistSamplerSeedHook, Fp16OptimizerHook,
                         build_optimizer, build_runner, get_dist_info)

from mmcls.core import DistEvalHook, DistOptimizerHook, EvalHook
from mmcls.datasets import build_dataloader, build_dataset
from mmcls.utils import (get_root_logger, wrap_distributed_model,
                         wrap_non_distributed_model)


def init_random_seed(seed=None, device='cuda'):
    """Initialize random seed.

    If the seed is not set, the seed will be automatically randomized,
    and then broadcast to all processes to prevent some potential bugs.

    Args:
        seed (int, Optional): The seed. Default to None.
        device (str): The device where the seed will be put on.
            Default to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    if seed is not None:
        return seed

    # Make sure all ranks share the same random seed to prevent
    # some potential bugs. Please refer to
    # https://github.com/open-mmlab/mmdetection/issues/6339
    rank, world_size = get_dist_info()
    seed = np.random.randint(2**31)
    if world_size == 1:
        return seed

    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()
def set_random_seed(seed, deterministic=False):
    """Set random seed.

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            Default: False.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
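

# Illustrative usage sketch (not part of this module's API): in an entry
# script such as `tools/train.py`, the two seed helpers above are typically
# used together. The names `args.seed` and `args.deterministic` are
# assumptions for the example only.
#
#   seed = init_random_seed(args.seed)
#   set_random_seed(seed, deterministic=args.deterministic)
#   cfg.seed = seed
#   meta = dict(seed=seed)

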
def train_model(model,
                dataset,
                cfg,
                distributed=False,
                validate=False,
                timestamp=None,
                device=None,
                meta=None):
    """Train a model.

    This method will build dataloaders, wrap the model and build a runner
    according to the provided config.

    Args:
        model (:obj:`torch.nn.Module`): The model to be run.
        dataset (:obj:`mmcls.datasets.BaseDataset` | List[BaseDataset]):
            The dataset used to train the model. It can be a single dataset,
            or a list of datasets with the same length as the workflow.
        cfg (:obj:`mmcv.utils.Config`): The configs of the experiment.
        distributed (bool): Whether to train the model in a distributed
            environment. Defaults to False.
        validate (bool): Whether to do validation with
            :obj:`mmcv.runner.EvalHook`. Defaults to False.
        timestamp (str, optional): The timestamp string to auto generate the
            name of log files. Defaults to None.
        device (str, optional): Device used for training. If 'ipu',
            IPU-specific dataloader and runner settings are applied.
            Defaults to None.
        meta (dict, optional): A dict that records some important information
            such as environment info and seed, which will be logged in the
            logger hook. Defaults to None.
    """
    logger = get_root_logger()

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]

    # The default loader config
    loader_cfg = dict(
        # cfg.gpus will be ignored if distributed
        num_gpus=cfg.ipu_replicas if device == 'ipu' else len(cfg.gpu_ids),
        dist=distributed,
        round_up=True,
        seed=cfg.get('seed'),
        sampler_cfg=cfg.get('sampler', None),
    )
    # The overall dataloader settings
    loader_cfg.update({
        k: v
        for k, v in cfg.data.items() if k not in [
            'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
            'test_dataloader'
        ]
    })
    # The specific dataloader settings
    train_loader_cfg = {**loader_cfg, **cfg.data.get('train_dataloader', {})}
    data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]
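
    # For reference, a hypothetical `cfg.data` layout showing how the
    # settings above are layered (the values are placeholders, not defaults):
    #
    #   data = dict(
    #       samples_per_gpu=32,              # overall dataloader setting
    #       workers_per_gpu=4,               # overall dataloader setting
    #       train_dataloader=dict(samples_per_gpu=64),  # train-only override
    #       train=dict(...),
    #       val=dict(...),
    #       test=dict(...))
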
    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = wrap_distributed_model(
            model,
            cfg.device,
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        model = wrap_non_distributed_model(
            model, cfg.device, device_ids=cfg.gpu_ids)

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)

    if cfg.get('runner') is None:
        cfg.runner = {
            'type': 'EpochBasedRunner',
            'max_epochs': cfg.total_epochs
        }
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)
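
    # Example of the `runner` section the warning above asks for (shown for
    # reference only; `max_epochs=100` is a placeholder value):
    #
    #   runner = dict(type='EpochBasedRunner', max_epochs=100)
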
    if device == 'ipu':
        if not cfg.runner['type'].startswith('IPU'):
            cfg.runner['type'] = 'IPU' + cfg.runner['type']
        if 'options_cfg' not in cfg.runner:
            cfg.runner['options_cfg'] = {}
        cfg.runner['options_cfg']['replicationFactor'] = cfg.ipu_replicas
        cfg.runner['fp16_cfg'] = cfg.get('fp16', None)

    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            batch_processor=None,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))

    # an ugly workaround to make the .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        if device == 'ipu':
            from mmcv.device.ipu import IPUFp16OptimizerHook
            optimizer_config = IPUFp16OptimizerHook(
                **cfg.optimizer_config,
                loss_scale=fp16_cfg['loss_scale'],
                distributed=distributed)
        else:
            optimizer_config = Fp16OptimizerHook(
                **cfg.optimizer_config,
                loss_scale=fp16_cfg['loss_scale'],
                distributed=distributed)
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = DistOptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config
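
    # Illustrative config values for the branches above (placeholders, not
    # defaults): mixed precision is enabled with `fp16 = dict(loss_scale=512.)`,
    # in which case this function builds an (IPU)Fp16OptimizerHook from
    # `optimizer_config = dict(grad_clip=None)`.
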
    # register hooks
    runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get('momentum_config', None),
        custom_hooks_config=cfg.get('custom_hooks', None))
    if distributed and cfg.runner['type'] == 'EpochBasedRunner':
        runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        # The specific dataloader settings
        val_loader_cfg = {
            **loader_cfg,
            'shuffle': False,  # Do not shuffle by default
            'sampler_cfg': None,  # Do not use a sampler by default
            **cfg.data.get('val_dataloader', {}),
        }
        val_dataloader = build_dataloader(val_dataset, **val_loader_cfg)

        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        # `EvalHook` needs to be executed after `IterTimerHook`.
        # Otherwise, it will cause a bug if using `IterBasedRunner`.
        # Refer to https://github.com/open-mmlab/mmcv/issues/1261
        runner.register_hook(
            eval_hook(val_dataloader, **eval_cfg), priority='LOW')

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)
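

# Illustrative call sketch (for reference only; the real entry point lives in
# `tools/train.py`). Names such as `args`, `datasets` and `timestamp` are
# assumptions for the example:
#
#   datasets = [build_dataset(cfg.data.train)]
#   train_model(
#       model,
#       datasets,
#       cfg,
#       distributed=distributed,
#       validate=(not args.no_validate),
#       timestamp=timestamp,
#       device=cfg.device,
#       meta=meta)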