mmselfsup/openselfsup/apis/train.py

276 lines
9.6 KiB
Python

import random
import re
from collections import OrderedDict
import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import DistSamplerSeedHook, Runner, obj_from_dict
from openselfsup.datasets import build_dataloader
from openselfsup.hooks import build_hook, DistOptimizerHook
from openselfsup.utils import get_root_logger, optimizers, print_log
def set_random_seed(seed, deterministic=False):
"""Set random seed.
Args:
seed (int): Seed to be used.
deterministic (bool): Whether to set the deterministic option for
CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
to True and `torch.backends.cudnn.benchmark` to False.
Default: False.
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if deterministic:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def parse_losses(losses):
log_vars = OrderedDict()
for loss_name, loss_value in losses.items():
if isinstance(loss_value, torch.Tensor):
log_vars[loss_name] = loss_value.mean()
elif isinstance(loss_value, list):
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
else:
raise TypeError(
'{} is not a tensor or list of tensors'.format(loss_name))
loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)
log_vars['loss'] = loss
for loss_name, loss_value in log_vars.items():
# reduce loss when distributed training
if dist.is_available() and dist.is_initialized():
loss_value = loss_value.data.clone()
dist.all_reduce(loss_value.div_(dist.get_world_size()))
log_vars[loss_name] = loss_value.item()
return loss, log_vars
def batch_processor(model, data, train_mode):
"""Process a data batch.
This method is required as an argument of Runner, which defines how to
process a data batch and obtain proper outputs. The first 3 arguments of
batch_processor are fixed.
Args:
model (nn.Module): A PyTorch model.
data (dict): The data batch in a dict.
train_mode (bool): Training mode or not. It may be useless for some
models.
Returns:
dict: A dict containing losses and log vars.
"""
assert model.training, "Must be in training mode."
losses = model(**data)
loss, log_vars = parse_losses(losses)
outputs = dict(
loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))
return outputs
def train_model(model,
dataset,
cfg,
distributed=False,
timestamp=None,
meta=None):
logger = get_root_logger(cfg.log_level)
# start training
if distributed:
_dist_train(
model, dataset, cfg, logger=logger, timestamp=timestamp, meta=meta)
else:
_non_dist_train(
model, dataset, cfg, logger=logger, timestamp=timestamp, meta=meta)
def build_optimizer(model, optimizer_cfg):
"""Build optimizer from configs.
Args:
model (:obj:`nn.Module`): The model with parameters to be optimized.
optimizer_cfg (dict): The config dict of the optimizer.
Positional fields are:
- type: class name of the optimizer.
- lr: base learning rate.
Optional fields are:
- any arguments of the corresponding optimizer type, e.g.,
weight_decay, momentum, etc.
- paramwise_options: a dict with regular expression as keys
to match parameter names and a dict containing options as
values. Options include 6 fields: lr, lr_mult, momentum,
momentum_mult, weight_decay, weight_decay_mult.
Returns:
torch.optim.Optimizer: The initialized optimizer.
Example:
>>> model = torch.nn.modules.Conv1d(1, 1, 1)
>>> paramwise_options = {
>>> '(bn|gn)(\d+)?.(weight|bias)': dict(weight_decay_mult=0.1),
>>> '\Ahead.': dict(lr_mult=10, momentum=0)}
>>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
>>> weight_decay=0.0001,
>>> paramwise_options=paramwise_options)
>>> optimizer = build_optimizer(model, optimizer_cfg)
"""
if hasattr(model, 'module'):
model = model.module
optimizer_cfg = optimizer_cfg.copy()
paramwise_options = optimizer_cfg.pop('paramwise_options', None)
# if no paramwise option is specified, just use the global setting
if paramwise_options is None:
return obj_from_dict(optimizer_cfg, optimizers,
dict(params=model.parameters()))
else:
assert isinstance(paramwise_options, dict)
params = []
for name, param in model.named_parameters():
param_group = {'params': [param]}
if not param.requires_grad:
params.append(param_group)
continue
for regexp, options in paramwise_options.items():
if re.search(regexp, name):
for key, value in options.items():
if key.endswith('_mult'): # is a multiplier
key = key[:-5]
assert key in optimizer_cfg, \
"{} not in optimizer_cfg".format(key)
value = optimizer_cfg[key] * value
param_group[key] = value
if not dist.is_initialized() or dist.get_rank() == 0:
print_log('paramwise_options -- {}: {}={}'.format(
name, key, value))
# otherwise use the global settings
params.append(param_group)
optimizer_cls = getattr(optimizers, optimizer_cfg.pop('type'))
return optimizer_cls(params, **optimizer_cfg)
def _dist_train(model, dataset, cfg, logger=None, timestamp=None, meta=None):
# prepare data loaders
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
data_loaders = [
build_dataloader(
ds,
cfg.data.imgs_per_gpu,
cfg.data.workers_per_gpu,
dist=True,
shuffle=True,
replace=getattr(cfg.data, 'sampling_replace', False),
seed=cfg.seed,
drop_last=getattr(cfg.data, 'drop_last', False)) for ds in dataset
]
# put model on gpus
model = MMDistributedDataParallel(
model.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False)
# build runner
optimizer = build_optimizer(model, cfg.optimizer)
runner = Runner(
model,
batch_processor,
optimizer,
cfg.work_dir,
logger=logger,
meta=meta)
# an ugly walkaround to make the .log and .log.json filenames the same
runner.timestamp = timestamp
optimizer_config = DistOptimizerHook(**cfg.optimizer_config)
# register hooks
runner.register_training_hooks(cfg.lr_config, optimizer_config,
cfg.checkpoint_config, cfg.log_config)
runner.register_hook(DistSamplerSeedHook())
# register custom hooks
for hook in cfg.get('custom_hooks', ()):
if hook.type == 'DeepClusterHook':
common_params = dict(dist_mode=True, data_loaders=data_loaders)
else:
common_params = dict(dist_mode=True)
runner.register_hook(build_hook(hook, common_params))
if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
runner.load_checkpoint(cfg.load_from)
runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
def _non_dist_train(model,
dataset,
cfg,
validate=False,
logger=None,
timestamp=None,
meta=None):
# prepare data loaders
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
data_loaders = [
build_dataloader(
ds,
cfg.data.imgs_per_gpu,
cfg.data.workers_per_gpu,
cfg.gpus,
dist=False,
shuffle=True,
replace=getattr(cfg.data, 'sampling_replace', False),
seed=cfg.seed,
drop_last=getattr(cfg.data, 'drop_last', False)) for ds in dataset
]
# put model on gpus
model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
# build runner
optimizer = build_optimizer(model, cfg.optimizer)
runner = Runner(
model,
batch_processor,
optimizer,
cfg.work_dir,
logger=logger,
meta=meta)
# an ugly walkaround to make the .log and .log.json filenames the same
runner.timestamp = timestamp
optimizer_config = cfg.optimizer_config
runner.register_training_hooks(cfg.lr_config, optimizer_config,
cfg.checkpoint_config, cfg.log_config)
# register custom hooks
for hook in cfg.get('custom_hooks', ()):
if hook.type == 'DeepClusterHook':
common_params = dict(dist_mode=False, data_loaders=data_loaders)
else:
common_params = dict(dist_mode=False)
runner.register_hook(build_hook(hook, common_params))
if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
runner.load_checkpoint(cfg.load_from)
runner.run(data_loaders, cfg.workflow, cfg.total_epochs)