mmpretrain/mmcls/apis/train.py

118 lines
3.9 KiB
Python
Raw Normal View History

2020-05-21 21:21:43 +08:00
import random
import numpy as np
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import DistSamplerSeedHook, EpochBasedRunner, build_optimizer
2020-05-21 21:21:43 +08:00
from mmcls.core import (DistEvalHook, DistOptimizerHook, EvalHook,
Fp16OptimizerHook)
2020-05-21 21:21:43 +08:00
from mmcls.datasets import build_dataloader, build_dataset
from mmcls.utils import get_root_logger
def set_random_seed(seed, deterministic=False):
"""Set random seed.
2020-05-21 21:21:43 +08:00
Args:
seed (int): Seed to be used.
deterministic (bool): Whether to set the deterministic option for
CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
to True and `torch.backends.cudnn.benchmark` to False.
Default: False.
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if deterministic:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def train_model(model,
dataset,
cfg,
distributed=False,
validate=False,
timestamp=None,
meta=None):
logger = get_root_logger(cfg.log_level)
# prepare data loaders
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
data_loaders = [
build_dataloader(
ds,
cfg.data.samples_per_gpu,
cfg.data.workers_per_gpu,
# cfg.gpus will be ignored if distributed
num_gpus=len(cfg.gpu_ids),
2020-05-21 21:21:43 +08:00
dist=distributed,
round_up=True,
2020-05-21 21:21:43 +08:00
seed=cfg.seed) for ds in dataset
]
# put model on gpus
if distributed:
find_unused_parameters = cfg.get('find_unused_parameters', False)
# Sets the `find_unused_parameters` parameter in
# torch.nn.parallel.DistributedDataParallel
model = MMDistributedDataParallel(
model.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters)
else:
model = MMDataParallel(
model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
# build runner
optimizer = build_optimizer(model, cfg.optimizer)
runner = EpochBasedRunner(
2020-05-21 21:21:43 +08:00
model,
optimizer=optimizer,
work_dir=cfg.work_dir,
2020-05-21 21:21:43 +08:00
logger=logger,
meta=meta)
# an ugly walkaround to make the .log and .log.json filenames the same
runner.timestamp = timestamp
# fp16 setting
fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is not None:
optimizer_config = Fp16OptimizerHook(
**cfg.optimizer_config, **fp16_cfg, distributed=distributed)
elif distributed and 'type' not in cfg.optimizer_config:
optimizer_config = DistOptimizerHook(**cfg.optimizer_config)
else:
optimizer_config = cfg.optimizer_config
# register hooks
runner.register_training_hooks(cfg.lr_config, optimizer_config,
cfg.checkpoint_config, cfg.log_config,
cfg.get('momentum_config', None))
if distributed:
runner.register_hook(DistSamplerSeedHook())
# register eval hooks
if validate:
val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
val_dataloader = build_dataloader(
val_dataset,
samples_per_gpu=cfg.data.samples_per_gpu,
2020-05-21 21:21:43 +08:00
workers_per_gpu=cfg.data.workers_per_gpu,
dist=distributed,
shuffle=False,
round_up=False)
2020-05-21 21:21:43 +08:00
eval_cfg = cfg.get('evaluation', {})
eval_hook = DistEvalHook if distributed else EvalHook
runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
runner.load_checkpoint(cfg.load_from)
runner.run(data_loaders, cfg.workflow, cfg.total_epochs)