mmpretrain/mmcls/apis/train.py
Ayush Thakur ccdbc82e39
[Feature] Dedicated MMClsWandbHook for MMClassification (Weights and Biases Integration) (#764)
* wandb integration

* visualize using wandb tables

* wandb tables enhanced

* Refactor MMClsWandbHook (#1)

* [Enhance] Add extra dataloader settings in configs. (#752)

* Use `train_dataloader`, `val_dataloader` and `test_dataloader` settings
in the `data` field to specify different arguments.

* Fix bug

* Fix bug

* [Enhance] Improve CPE performance by reduce memory copy. (#762)

* [Feature] Support resize relative position embedding in `SwinTransformer`. (#749)

* [Feature]: Add resize rel pos embed

* [Refactor]: Create a separated resize_rel_pos_bias_table func

* [Refactor]: Refactor rel pos embed bias

* [Refactor]: Move interpolate into func

* Remove index buffer only when window_size changes

Co-authored-by: mzr1996 <mzr1996@163.com>

* [Feature] Add PoolFormer backbone and checkpoints. (#746)

* add PoolFormer

* fix some typos in PoolFormer

* fix lint error

* modify out_indices and gap

* fix typo

* fix lint

* fix typo

* fix typo in poolforemr README

* fix lint

* Update some paths

* Refactor freeze_stages method

* Add unit tests

* Fix lint

Co-authored-by: mzr1996 <mzr1996@163.com>

* Bump version to v0.22.1 (#785)

* [Docs] Refine API reference. (#774)

* [Docs] Refine API reference

* Add PoolFormer

* [Docs] Fix docs.

* [Enhance] Reduce the memory usage of unit tests for Swin-Transformer. (#759)

* [Feature] Support VAN. (#739)

* add van

* fix config

* add metafile

* add test

* model convert script

* fix review

* fix lint

* fix the configs and improve docs

* rm debug lines

* add VAN into api

Co-authored-by: Yu Zhaohui <1105212286@qq.com>

* [Feature] Support DenseNet. (#750)

* init add densenet implementation

* Add config and converted models

* update meta

* add test for memory efficient

* Add docs

* add doc for jit

* Update checkpoint path

* Update readthedocs

Co-authored-by: mzr1996 <mzr1996@163.com>

* [Fix] Use symbolic link in the API reference of Chinese docs.

* [Enhance] Support training on IPU and add fine-tuning configs of ViT. (#723)

* implement training and evaluation on IPU

* fp16 SOTA

* Tput reaches 5600

* 123

* add poptorch dataloder

* change ipu_replicas to ipu-replicas

* add noqa to config long line(website)

* remove ipu dataloder test code

* del one blank line in test_builder

* refine the dataloder initialization

* fix a typo

* refine args for dataloder

* remove an annoted line

* process one more conflict

* adjust code structure in mmcv.ipu

* adjust ipu code structure in mmcv

* IPUDataloader to IPUDataLoader

* align with mmcv

* adjust according to mmcv

* mmcv code structre fixed

Co-authored-by: hudi <dihu@graphcore.ai>

* [Fix] Fix lint and mmcv version requirement for IPU.

* Bump version to v0.23.0 (#809)

* Refacoter Wandb hook and refine docstring

Co-authored-by: XiaobingZhang <xiaobing.zhang@intel.com>
Co-authored-by: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com>
Co-authored-by: Weihao Yu <1090924009@qq.com>
Co-authored-by: takuoko <to78314910@gmail.com>
Co-authored-by: Yu Zhaohui <1105212286@qq.com>
Co-authored-by: Hubert <42952108+yingfhu@users.noreply.github.com>
Co-authored-by: Hu Di <476658825@qq.com>
Co-authored-by: hudi <dihu@graphcore.ai>

* shuffle val data

* minor updates

* minor fix

Co-authored-by: Ma Zerun <mzr1996@163.com>
Co-authored-by: XiaobingZhang <xiaobing.zhang@intel.com>
Co-authored-by: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com>
Co-authored-by: Weihao Yu <1090924009@qq.com>
Co-authored-by: takuoko <to78314910@gmail.com>
Co-authored-by: Yu Zhaohui <1105212286@qq.com>
Co-authored-by: Hubert <42952108+yingfhu@users.noreply.github.com>
Co-authored-by: Hu Di <476658825@qq.com>
Co-authored-by: hudi <dihu@graphcore.ai>
2022-06-02 17:58:49 +08:00


# Copyright (c) OpenMMLab. All rights reserved.
import random
import warnings

import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, Fp16OptimizerHook,
                         build_optimizer, build_runner, get_dist_info)

from mmcls.core import DistEvalHook, DistOptimizerHook, EvalHook
from mmcls.datasets import build_dataloader, build_dataset
from mmcls.utils import get_root_logger


def init_random_seed(seed=None, device='cuda'):
    """Initialize random seed.

    If the seed is not set, the seed will be automatically randomized,
    and then broadcast to all processes to prevent some potential bugs.

    Args:
        seed (int, optional): The seed. Defaults to None.
        device (str): The device on which the seed tensor is placed.
            Defaults to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    if seed is not None:
        return seed

    # Make sure all ranks share the same random seed to prevent
    # some potential bugs. Please refer to
    # https://github.com/open-mmlab/mmdetection/issues/6339
    rank, world_size = get_dist_info()
    seed = np.random.randint(2**31)
    if world_size == 1:
        return seed

    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()


def set_random_seed(seed, deterministic=False):
    """Set random seed.

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set the deterministic option for
            the CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            Defaults to False.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def train_model(model,
                dataset,
                cfg,
                distributed=False,
                validate=False,
                timestamp=None,
                device=None,
                meta=None):
    """Train a model.

    This method will build dataloaders, wrap the model and build a runner
    according to the provided config.

    Args:
        model (:obj:`torch.nn.Module`): The model to be run.
        dataset (:obj:`mmcls.datasets.BaseDataset` | List[BaseDataset]):
            The dataset used to train the model. It can be a single dataset,
            or a list of datasets with the same length as the workflow.
        cfg (:obj:`mmcv.utils.Config`): The configs of the experiment.
        distributed (bool): Whether to train the model in a distributed
            environment. Defaults to False.
        validate (bool): Whether to do validation with
            :obj:`mmcv.runner.EvalHook`. Defaults to False.
        timestamp (str, optional): The timestamp string to auto generate the
            name of log files. Defaults to None.
        device (str, optional): The device to train on. ``'ipu'`` trains on
            IPUs with ``cfg.ipu_replicas`` replicas; ``'cpu'`` (deprecated)
            moves the model to the CPU; otherwise, in non-distributed mode,
            the model is wrapped with ``MMDataParallel`` on ``cfg.gpu_ids``.
            Defaults to None.
        meta (dict, optional): A dict that records some important information
            such as environment info and seed, which will be logged in the
            logger hook. Defaults to None.
    """
    logger = get_root_logger()

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    # The default loader config
    loader_cfg = dict(
        # `num_gpus` is ignored if distributed
        num_gpus=cfg.ipu_replicas if device == 'ipu' else len(cfg.gpu_ids),
        dist=distributed,
        round_up=True,
        seed=cfg.get('seed'),
        sampler_cfg=cfg.get('sampler', None),
    )
    # The overall dataloader settings
    loader_cfg.update({
        k: v
        for k, v in cfg.data.items() if k not in [
            'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
            'test_dataloader'
        ]
    })
    # The specific dataloader settings
    train_loader_cfg = {**loader_cfg, **cfg.data.get('train_dataloader', {})}

    data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        if device == 'cpu':
            warnings.warn(
                'The argument `device` is deprecated. To train on CPU, '
                'please refer to https://mmclassification.readthedocs.io/en'
                '/latest/getting_started.html#train-a-model')
            model = model.cpu()
        elif device == 'ipu':
            model = model.cpu()
        else:
            model = MMDataParallel(model, device_ids=cfg.gpu_ids)
            if not model.device_ids:
                from mmcv import __version__, digit_version
                assert digit_version(__version__) >= (1, 4, 4), \
                    'To train with CPU, please confirm your mmcv version ' \
                    'is not lower than v1.4.4'

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)

    if cfg.get('runner') is None:
        cfg.runner = {
            'type': 'EpochBasedRunner',
            'max_epochs': cfg.total_epochs
        }
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)

    if device == 'ipu':
        if not cfg.runner['type'].startswith('IPU'):
            cfg.runner['type'] = 'IPU' + cfg.runner['type']
        if 'options_cfg' not in cfg.runner:
            cfg.runner['options_cfg'] = {}
        cfg.runner['options_cfg']['replicationFactor'] = cfg.ipu_replicas
        cfg.runner['fp16_cfg'] = cfg.get('fp16', None)

    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            batch_processor=None,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))

    # an ugly workaround to make the .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        if device == 'ipu':
            from mmcv.device.ipu import IPUFp16OptimizerHook
            optimizer_config = IPUFp16OptimizerHook(
                **cfg.optimizer_config,
                loss_scale=fp16_cfg['loss_scale'],
                distributed=distributed)
        else:
            optimizer_config = Fp16OptimizerHook(
                **cfg.optimizer_config,
                loss_scale=fp16_cfg['loss_scale'],
                distributed=distributed)
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = DistOptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    # register hooks
    runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get('momentum_config', None),
        custom_hooks_config=cfg.get('custom_hooks', None))
    if distributed and cfg.runner['type'] == 'EpochBasedRunner':
        runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        # The specific dataloader settings
        val_loader_cfg = {
            **loader_cfg,
            'shuffle': False,  # Do not shuffle by default
            'sampler_cfg': None,  # Do not use a sampler by default
            **cfg.data.get('val_dataloader', {}),
        }
        val_dataloader = build_dataloader(val_dataset, **val_loader_cfg)
        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        # `EvalHook` needs to be executed after `IterTimerHook`.
        # Otherwise, it will cause a bug when using `IterBasedRunner`.
        # Refer to https://github.com/open-mmlab/mmcv/issues/1261
        runner.register_hook(
            eval_hook(val_dataloader, **eval_cfg), priority='LOW')

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)
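
The order in which `train_model()` layers dataloader settings (built-in defaults, then the shared keys of `cfg.data`, then `train_dataloader` or `val_dataloader`) can be illustrated with plain dicts. This is a minimal sketch, not part of mmcls: `merge_loader_cfgs`, `SPLIT_KEYS`, `HypotheticalSampler` and the sample values are hypothetical, and no mmcv `Config` or `build_dataloader` call is involved.

```python
# Hypothetical sketch of how train_model() layers dataloader settings.
# Plain dicts stand in for the mmcv Config object.
SPLIT_KEYS = ('train', 'val', 'test',
              'train_dataloader', 'val_dataloader', 'test_dataloader')


def merge_loader_cfgs(defaults, data):
    """Return (train_cfg, val_cfg) using train_model()'s precedence."""
    # Shared settings in `data` override the defaults; dataset definitions
    # and per-split loader sections are skipped.
    loader_cfg = dict(defaults)
    loader_cfg.update({k: v for k, v in data.items() if k not in SPLIT_KEYS})
    # `train_dataloader` overrides the shared settings for training only.
    train_cfg = {**loader_cfg, **data.get('train_dataloader', {})}
    # Validation disables shuffling and custom samplers unless
    # `val_dataloader` explicitly sets them again.
    val_cfg = {
        **loader_cfg,
        'shuffle': False,
        'sampler_cfg': None,
        **data.get('val_dataloader', {}),
    }
    return train_cfg, val_cfg


defaults = dict(num_gpus=1, dist=False, round_up=True, seed=0,
                sampler_cfg=dict(type='HypotheticalSampler'))
data = dict(samples_per_gpu=32, workers_per_gpu=4,
            train=dict(type='ImageNet'),
            val=dict(type='ImageNet'),
            val_dataloader=dict(samples_per_gpu=64))
train_cfg, val_cfg = merge_loader_cfgs(defaults, data)
print(train_cfg['samples_per_gpu'], val_cfg['samples_per_gpu'])  # 32 64
print(val_cfg['shuffle'], val_cfg['sampler_cfg'])                # False None
```

Because the per-split dict is unpacked last, anything set in `val_dataloader` (including `shuffle` or `sampler_cfg`) wins over both the shared settings and the validation defaults.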
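
The runner-config handling in `train_model()` (falling back to an `EpochBasedRunner` when `runner` is missing, then rewriting the type and options for IPU training) can be sketched on a plain dict. `normalize_runner_cfg` is a hypothetical helper written for illustration only; it does not call `build_runner` and is not an mmcls API.

```python
import warnings

# Hypothetical helper mirroring the runner-config handling in train_model();
# it operates on a plain dict instead of an mmcv Config.


def normalize_runner_cfg(cfg, device=None):
    """Fill in a default runner and apply the IPU-specific rewrites."""
    if cfg.get('runner') is None:
        # Legacy configs only set `total_epochs`; fall back to an
        # epoch-based runner and warn, as train_model() does.
        cfg['runner'] = {'type': 'EpochBasedRunner',
                         'max_epochs': cfg['total_epochs']}
        warnings.warn('config is now expected to have a `runner` section, '
                      'please set `runner` in your config.', UserWarning)
    if device == 'ipu':
        runner = cfg['runner']
        # Prefix the type only once, so calling this twice is harmless.
        if not runner['type'].startswith('IPU'):
            runner['type'] = 'IPU' + runner['type']
        runner.setdefault('options_cfg', {})
        runner['options_cfg']['replicationFactor'] = cfg['ipu_replicas']
        runner['fp16_cfg'] = cfg.get('fp16', None)
    return cfg['runner']


cfg = {'runner': {'type': 'EpochBasedRunner', 'max_epochs': 100},
       'ipu_replicas': 4}
runner = normalize_runner_cfg(cfg, device='ipu')
print(runner['type'])                             # IPUEpochBasedRunner
print(runner['options_cfg']['replicationFactor'])  # 4
```

The `startswith('IPU')` guard is what lets a config that already names an IPU runner pass through unchanged.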
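
The choice of optimizer hook in `train_model()` is a small decision table: fp16 wins first, then `DistOptimizerHook` for distributed training without an explicit `type`, otherwise the raw `optimizer_config` dict is handed to the runner. The sketch below returns hook names instead of constructing mmcv objects; `select_optimizer_hook` and `MyOptimizerHook` are hypothetical, and the `OptimizerHook` fallback is an assumption about how mmcv 1.x runners build a dict without `type`.

```python
# Hypothetical sketch of the optimizer-hook selection in train_model().
# It returns hook type names instead of building mmcv hook objects.


def select_optimizer_hook(optimizer_config, fp16_cfg=None,
                          distributed=False, device=None):
    """Return the name of the optimizer hook train_model() ends up using."""
    if fp16_cfg is not None:
        # Mixed precision takes priority over everything else.
        if device == 'ipu':
            return 'IPUFp16OptimizerHook'
        return 'Fp16OptimizerHook'
    if distributed and 'type' not in optimizer_config:
        return 'DistOptimizerHook'
    # Otherwise the dict is passed through to register_training_hooks().
    # Assumption: the runner defaults a missing `type` to `OptimizerHook`.
    return optimizer_config.get('type', 'OptimizerHook')


print(select_optimizer_hook({'grad_clip': None}))                    # OptimizerHook
print(select_optimizer_hook({'grad_clip': None}, distributed=True))  # DistOptimizerHook
print(select_optimizer_hook({}, fp16_cfg={'loss_scale': 512.0},
                            device='ipu'))                           # IPUFp16OptimizerHook
```

One consequence of this ordering: setting an explicit `type` in `optimizer_config` is the way to bypass `DistOptimizerHook` in distributed mode, while any `fp16` section overrides that `type` entirely.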