2020-05-21 21:21:43 +08:00
|
|
|
import random
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
2020-07-07 19:32:06 +08:00
|
|
|
from mmcv.runner import DistSamplerSeedHook, EpochBasedRunner, build_optimizer
|
2020-05-21 21:21:43 +08:00
|
|
|
|
|
|
|
from mmcls.core import (DistEvalHook, DistOptimizerHook, EvalHook,
|
2020-07-07 19:32:06 +08:00
|
|
|
Fp16OptimizerHook)
|
2020-05-21 21:21:43 +08:00
|
|
|
from mmcls.datasets import build_dataloader, build_dataset
|
|
|
|
from mmcls.utils import get_root_logger
|
|
|
|
|
|
|
|
|
|
|
|
def set_random_seed(seed, deterministic=False):
    """Set random seed.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (including every CUDA device) with the same value.

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            Default: False.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        # The cuDNN flags are only touched on request; when `deterministic`
        # is False they are left at whatever value they already had.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
|
|
|
|
|
|
|
|
|
|
|
|
def train_model(model,
                dataset,
                cfg,
                distributed=False,
                validate=False,
                timestamp=None,
                meta=None):
    """Build data loaders and a runner for ``model``, then run training.

    Args:
        model: Model to train. It is moved to GPU and wrapped in
            ``MMDistributedDataParallel`` or ``MMDataParallel``.
        dataset: A single training dataset, or a list/tuple of datasets;
            one data loader is built per dataset.
        cfg: Config object. Read via attribute access (``cfg.data``,
            ``cfg.gpu_ids``, ``cfg.optimizer``, ``cfg.work_dir``, ...) and
            via ``cfg.get(...)`` for optional keys.
        distributed (bool): Whether to use the distributed wrapper, the
            distributed optimizer/eval hooks and ``DistSamplerSeedHook``.
            Default: False.
        validate (bool): Whether to build a validation data loader from
            ``cfg.data.val`` and register an evaluation hook.
            Default: False.
        timestamp: Value assigned to ``runner.timestamp`` so that the
            ``.log`` and ``.log.json`` filenames match. Default: None.
        meta: Passed through to the runner as ``meta``. Default: None.
    """
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]

    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # cfg.gpu_ids will be ignored if distributed
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            round_up=True,
            # `seed` is optional in the config; fall back to None instead of
            # raising AttributeError when it has not been set.
            seed=cfg.get('seed')) for ds in dataset
    ]

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
    runner = EpochBasedRunner(
        model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta)
    # an ugly workaround to make the .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    elif distributed and 'type' not in cfg.optimizer_config:
        # No explicit hook type was configured, so use the distributed
        # optimizer hook built from the plain optimizer settings.
        optimizer_config = DistOptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=cfg.data.samples_per_gpu,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False,
            round_up=False)
        eval_cfg = cfg.get('evaluation', {})
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    # `resume_from` takes precedence over `load_from` when both are set.
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
|