# Copyright (c) Alibaba, Inc. and its affiliates.
import random
import re
from distutils.version import LooseVersion

import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import DistSamplerSeedHook, obj_from_dict
from mmcv.runner.dist_utils import get_dist_info

from easycv.apis.train_misc import build_yolo_optimizer
from easycv.core import optimizer
from easycv.core.evaluation.builder import build_evaluator
from easycv.core.evaluation.metric_registry import METRICS
from easycv.core.optimizer import build_optimizer_constructor
from easycv.datasets import build_dataloader, build_dataset
from easycv.datasets.utils import is_dali_dataset_type
from easycv.hooks import (BestCkptSaverHook, DistEvalHook, EMAHook, EvalHook,
                          ExportHook, OptimizerHook, OSSSyncHook, build_hook)
from easycv.hooks.optimizer_hook import AMPFP16OptimizerHook
from easycv.runner import EVRunner
from easycv.utils.eval_utils import generate_best_metric_name
from easycv.utils.logger import get_root_logger, print_log
from easycv.utils.torchacc_util import is_torchacc_enabled


def init_random_seed(seed=None, device='cuda'):
    """Initialize random seed.

    If the seed is not set, the seed will be automatically randomized
    and then broadcast to all processes to prevent some potential bugs.

    Args:
        seed (int, Optional): The seed. Default to None.
        device (str): The device where the seed will be put on.
            Default to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    if seed is not None:
        return seed

    # Make sure all ranks share the same random seed to prevent
    # some potential bugs. Please refer to
    # https://github.com/open-mmlab/mmdetection/issues/6339
    rank, world_size = get_dist_info()
    seed = np.random.randint(2**31)
    if world_size == 1:
        return seed

    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()


def set_random_seed(seed, deterministic=False):
    """Set random seed.

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            Default: False.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        if hasattr(torch, '_set_deterministic'):
            torch._set_deterministic(True)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
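

# Usage sketch (illustrative, not part of the API): draw one seed, share it
# across ranks, then seed every library. `args.seed` is a hypothetical CLI
# argument; both functions above are defined in this module.
#
#   seed = init_random_seed(args.seed)
#   set_random_seed(seed, deterministic=False)

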
def train_model(model,
                data_loaders,
                cfg,
                distributed=False,
                timestamp=None,
                meta=None,
                use_fp16=False,
                validate=True,
                gpu_collect=True):
    """Training API.

    Args:
        model (:obj:`nn.Module`): user defined model
        data_loaders: a list of dataloaders for training data
        cfg: config object
        distributed: distributed training or not
        timestamp: time str formatted as '%Y%m%d_%H%M%S'
        meta: a dict containing meta data info, such as env_info, seed, iter, epoch
        use_fp16: whether to train with fp16 (AMP)
        validate: whether to evaluate during training; note that it is
            recomputed below from ``cfg.eval_pipelines``
        gpu_collect: use gpu collect or cpu collect for tensor gathering
    """
    logger = get_root_logger(cfg.log_level)
    print('GPU INFO : ', torch.cuda.get_device_name(0))

    # model.cuda() must be called before build_optimizer in torchacc mode
    model = model.cuda()

    if cfg.model.type == 'YOLOX':
        optimizer = build_yolo_optimizer(model, cfg.optimizer)
    else:
        optimizer = build_optimizer(model, cfg.optimizer)

    # When using amp from apex, amp must be initialized with a model that is
    # not yet wrapped by DDP or DP, so we initialize it here. For torch 1.6
    # or later this is unnecessary.
    if use_fp16 and LooseVersion(torch.__version__) < LooseVersion('1.6.0'):
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    # SyncBatchNorm
    open_sync_bn = cfg.get('sync_bn', False)

    if open_sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        logger.info('Using SyncBatchNorm()')
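
    # Config sketch (field names as read above; values are illustrative):
    #
    #   optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=1e-4)
    #   sync_bn = True   # convert BatchNorm layers to SyncBatchNorm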

    # the functions of torchacc DDP are split into OptimizerHook and
    # TorchaccLoaderWrapper
    if not is_torchacc_enabled():
        if distributed:
            find_unused_parameters = cfg.get('find_unused_parameters', False)
            model = MMDistributedDataParallel(
                model,
                find_unused_parameters=find_unused_parameters,
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False)
        else:
            model = MMDataParallel(model, device_ids=range(cfg.gpus))

    # build runner
    runner = EVRunner(
        model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta,
        fp16_enable=use_fp16)
    runner.data_loader = data_loaders

    # an ugly workaround to make the .log and .log.json filenames the same
    runner.timestamp = timestamp
    optimizer_config = cfg.optimizer_config

    if use_fp16:
        assert torch.cuda.is_available(), 'cuda is needed for fp16'
        optimizer_config = AMPFP16OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
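
    # Config sketch (kwargs are forwarded to the chosen hook; grad_clip is a
    # standard mmcv OptimizerHook option, shown for illustration):
    #
    #   optimizer_config = dict(grad_clip=dict(max_norm=10.0))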

    # process tensor type, convert to numpy for dumping logs
    if len(cfg.log_config.get('hooks', [])) > 0:
        cfg.log_config.hooks.insert(0, dict(type='PreLoggerHook'))
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)

    if distributed:
        logger.info('register DistSamplerSeedHook')
        runner.register_hook(DistSamplerSeedHook())

    # register eval hooks; the `validate` argument is recomputed from the
    # config: evaluation runs only when `cfg.eval_pipelines` is non-empty
    validate = False
    if 'eval_pipelines' in cfg:
        if isinstance(cfg.eval_pipelines, dict):
            cfg.eval_pipelines = [cfg.eval_pipelines]
        if len(cfg.eval_pipelines) > 0:
            validate = True
            runner.logger.info('open validate hook')

    # default is eval_pipe.evaluators[0]['type'] + eval_dataset_name + [metric_names]
    best_metric_name = []
    best_metric_type = []
    if validate:
        interval = cfg.eval_config.pop('interval', 1)
        for idx, eval_pipe in enumerate(cfg.eval_pipelines):
            data = eval_pipe.get('data', None) or cfg.data.val
            dist_eval = eval_pipe.get('dist_eval', False)

            evaluator_cfg = eval_pipe.evaluators[0]
            # get the metric names, falling back to the registry defaults
            eval_dataset_name = evaluator_cfg.get('dataset_name', None)
            default_metrics = METRICS.get(evaluator_cfg['type'])['metric_name']
            default_metric_type = METRICS.get(
                evaluator_cfg['type'])['metric_cmp_op']
            if 'metric_names' not in evaluator_cfg:
                evaluator_cfg['metric_names'] = default_metrics
            eval_metric_names = evaluator_cfg['metric_names']

            # get the best-checkpoint metric names
            this_metric_names = generate_best_metric_name(
                evaluator_cfg['type'], eval_dataset_name, eval_metric_names)
            best_metric_name = best_metric_name + this_metric_names

            # get the metric types ('max'/'min'), padded with 'max'
            this_metric_type = evaluator_cfg.pop('metric_type',
                                                 default_metric_type)
            this_metric_type = this_metric_type + ['max'] * (
                len(this_metric_names) - len(this_metric_type))
            best_metric_type = best_metric_type + this_metric_type

            imgs_per_gpu = data.pop('imgs_per_gpu', cfg.data.imgs_per_gpu)
            workers_per_gpu = data.pop('workers_per_gpu',
                                       cfg.data.workers_per_gpu)
            if not is_dali_dataset_type(data['type']):
                val_dataset = build_dataset(data)
                val_dataloader = build_dataloader(
                    val_dataset,
                    imgs_per_gpu=imgs_per_gpu,
                    workers_per_gpu=workers_per_gpu,
                    dist=(distributed and dist_eval),
                    shuffle=False,
                    seed=cfg.seed)
            else:
                default_args = dict(
                    batch_size=imgs_per_gpu,
                    workers_per_gpu=workers_per_gpu,
                    distributed=distributed)
                val_dataset = build_dataset(data, default_args)
                val_dataloader = val_dataset.get_dataloader()

            evaluators = build_evaluator(eval_pipe.evaluators)
            eval_cfg = cfg.eval_config
            eval_cfg['evaluators'] = evaluators
            eval_hook = DistEvalHook if (distributed
                                         and dist_eval) else EvalHook
            if eval_hook is EvalHook:
                eval_cfg.pop('gpu_collect', None)  # only used in DistEvalHook
            logger.info(f'register EvaluationHook {eval_cfg}')
            # only flush the log buffer at the last eval hook
            flush_buffer = (idx == len(cfg.eval_pipelines) - 1)
            runner.register_hook(
                eval_hook(
                    val_dataloader,
                    interval=interval,
                    mode=eval_pipe.mode,
                    flush_buffer=flush_buffer,
                    **eval_cfg))
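
    # Config sketch (evaluator type and fields are illustrative; `data` falls
    # back to cfg.data.val when omitted):
    #
    #   eval_config = dict(interval=1, gpu_collect=True)
    #   eval_pipelines = [
    #       dict(
    #           mode='test',
    #           dist_eval=True,
    #           evaluators=[dict(type='ClsEvaluator', topk=(1, 5))],
    #       )
    #   ]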

    # user-defined hooks
    if cfg.get('custom_hooks', None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(custom_hooks, list), \
            f'custom_hooks expects list type, but got {type(custom_hooks)}'
        for hook_cfg in cfg.custom_hooks:
            assert isinstance(hook_cfg, dict), \
                'Each item in custom_hooks expects dict type, but got ' \
                f'{type(hook_cfg)}'
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop('priority', 'NORMAL')

            if hook_cfg.type == 'DeepClusterHook':
                common_params = dict(
                    dist_mode=distributed, data_loaders=data_loaders)
            else:
                common_params = dict(dist_mode=distributed)

            hook = build_hook(hook_cfg, default_args=common_params)
            runner.register_hook(hook, priority=priority)
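
    # Config sketch (priority value is illustrative; DeepClusterHook is the
    # one type that additionally receives the train dataloaders):
    #
    #   custom_hooks = [dict(type='DeepClusterHook', priority='NORMAL')]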

    if cfg.get('ema', None):
        runner.logger.info('register ema hook')
        runner.register_hook(EMAHook(decay=cfg.ema.decay))

    if len(best_metric_name) > 0:
        runner.register_hook(
            BestCkptSaverHook(
                by_epoch=True,
                save_optimizer=True,
                best_metric_name=best_metric_name,
                best_metric_type=best_metric_type))
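
    # Config sketch (decay value illustrative): enable the EMA hook with
    #
    #   ema = dict(decay=0.9999)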

    # export hook
    if getattr(cfg, 'checkpoint_sync_export', False):
        runner.register_hook(ExportHook(cfg))

    # oss sync hook
    if cfg.oss_work_dir is not None:
        if cfg.checkpoint_config.get('by_epoch', True):
            runner.register_hook(
                OSSSyncHook(
                    cfg.work_dir,
                    cfg.oss_work_dir,
                    interval=cfg.checkpoint_config.interval,
                    **cfg.get('oss_sync_config', {})))
        else:
            runner.register_hook(
                OSSSyncHook(
                    cfg.work_dir,
                    cfg.oss_work_dir,
                    interval=1,
                    iter_interval=cfg.checkpoint_config.interval,
                    **cfg.get('oss_sync_config', {})))

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.logger.info(f'load checkpoint from {cfg.load_from}')
        runner.load_checkpoint(cfg.load_from)

    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
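

# Usage sketch (illustrative): a minimal call from a training script. The
# model builder and dataloader construction are assumed to happen upstream
# (e.g. in a tools/train.py entrypoint); names here are hypothetical.
#
#   train_model(model, [train_loader], cfg,
#               distributed=dist.is_initialized(), use_fp16=False)

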
def get_skip_list_keywords(model):
    skip = {}
    skip_keywords = {}
    if hasattr(model, 'no_weight_decay'):
        skip = model.no_weight_decay()
    if hasattr(model, 'no_weight_decay_keywords'):
        skip_keywords = model.no_weight_decay_keywords()
    return skip, skip_keywords
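
# Protocol sketch (assumed, mirroring Swin-style backbones): a model opts
# parameters out of weight decay by exposing these two methods:
#
#   class MyBackbone(nn.Module):
#       def no_weight_decay(self):
#           return {'absolute_pos_embed'}
#
#       def no_weight_decay_keywords(self):
#           return {'relative_position_bias_table'}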


def _set_weight_decay(model, skip_list=(), skip_keywords=()):
    has_decay = []
    no_decay = []

    for name, param in model.named_parameters():
        print(name)
        if not param.requires_grad:
            continue  # frozen weights
        if len(param.shape) == 1 or name.endswith('.bias') or (name in skip_list) or \
                _check_keywords_in_name(name, skip_keywords):
            no_decay.append(param)
            # print(f'{name} has no weight decay')
        else:
            has_decay.append(param)
    return [{'params': has_decay}, {'params': no_decay, 'weight_decay': 0.}]


def _check_keywords_in_name(name, keywords=()):
    for keyword in keywords:
        if keyword in name:
            return True
    return False


def build_optimizer(model, optimizer_cfg):
    r"""Build optimizer from configs.

    Args:
        model (:obj:`nn.Module`): The model with parameters to be optimized.
        optimizer_cfg (dict): The config dict of the optimizer.

            Positional fields are:
                - type: class name of the optimizer.
                - lr: base learning rate.

            Optional fields are:
                - any arguments of the corresponding optimizer type, e.g.,
                  weight_decay, momentum, etc.
                - paramwise_options: a dict with regular expressions as keys
                  to match parameter names and a dict containing options as
                  values. Options include 6 fields: lr, lr_mult, momentum,
                  momentum_mult, weight_decay, weight_decay_mult.

    Returns:
        torch.optim.Optimizer: The initialized optimizer.

    Example:
        >>> model = torch.nn.modules.Conv1d(1, 1, 1)
        >>> paramwise_options = {
        >>>     '(bn|gn)(\d+)?.(weight|bias)': dict(weight_decay_mult=0.1),
        >>>     '\Ahead.': dict(lr_mult=10, momentum=0)}
        >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
        >>>                      weight_decay=0.0001,
        >>>                      paramwise_options=paramwise_options)
        >>> optimizer = build_optimizer(model, optimizer_cfg)
    """

    if hasattr(model, 'module'):
        model = model.module

    # Some models (e.g. DINO) only optimize part of their parameters and
    # expose a `get_params_groups` attribute that returns ready-made param
    # groups; when present, build the optimizer directly from those groups.
    if hasattr(model, 'get_params_groups'):
        print('type : ', type(model),
              'trigger optimizer model param_groups set for DINO')
        parameters = model.get_params_groups()
        optimizer_cfg = optimizer_cfg.copy()
        optimizer_cls = getattr(optimizer, optimizer_cfg.pop('type'))
        return optimizer_cls(parameters, **optimizer_cfg)
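
    # Protocol sketch (assumed, DINO-style): `get_params_groups` typically
    # returns groups that keep weight decay off norms and biases, e.g.
    #
    #   def get_params_groups(self):
    #       return [{'params': regularized},
    #               {'params': not_regularized, 'weight_decay': 0.}]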

    # For models that use a transformer backbone (swin/shuffle/cswin), biases
    # and 1-D variables should be excluded from weight decay.
    set_var_bias_nowd = optimizer_cfg.pop('set_var_bias_nowd', None)
    if set_var_bias_nowd is None:
        # fallback for older configs: set_var_bias_nowd used to be called
        # trans_weight_decay_set
        set_var_bias_nowd = optimizer_cfg.pop('trans_weight_decay_set', None)
    if set_var_bias_nowd is not None:
        print('type : ', type(model), 'trigger transformer set_var_bias_nowd')
        skip = []
        skip_keywords = []
        assert (type(set_var_bias_nowd) is list)
        for model_part in set_var_bias_nowd:
            mpart = getattr(model, model_part, None)
            if mpart is not None:
                tskip, tskip_keywords = get_skip_list_keywords(mpart)
                skip += tskip
                skip_keywords += tskip_keywords
        parameters = _set_weight_decay(model, skip, skip_keywords)
        optimizer_cfg = optimizer_cfg.copy()
        optimizer_cls = getattr(optimizer, optimizer_cfg.pop('type'))
        return optimizer_cls(parameters, **optimizer_cfg)
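
    # Config sketch (submodule name illustrative): exclude bias/1-D params of
    # the listed submodules from weight decay, e.g.
    #
    #   optimizer = dict(type='AdamW', lr=1e-3, weight_decay=0.05,
    #                    set_var_bias_nowd=['backbone'])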

    constructor_type = optimizer_cfg.pop('constructor', None)
    optimizer_cfg = optimizer_cfg.copy()
    paramwise_options = optimizer_cfg.pop('paramwise_options', None)
    # if no paramwise option is specified, just use the global setting
    if constructor_type is not None:
        optimizer_cls = getattr(optimizer, optimizer_cfg.pop('type'))
        optim_constructor = build_optimizer_constructor(
            dict(
                type=constructor_type,
                optimizer_cfg=optimizer_cfg,
                paramwise_cfg=paramwise_options))
        params = []
        optim_constructor.add_params(params, model)
        return optimizer_cls(params, **optimizer_cfg)
    elif paramwise_options is None:
        return obj_from_dict(optimizer_cfg, optimizer,
                             dict(params=model.parameters()))
    else:
        assert isinstance(paramwise_options, dict)
        params = []
        for name, param in model.named_parameters():
            param_group = {'params': [param]}
            if not param.requires_grad:
                params.append(param_group)
                continue

            for regexp, options in paramwise_options.items():
                if re.search(regexp, name):
                    for key, value in options.items():
                        if key.endswith('_mult'):  # is a multiplier
                            key = key[:-5]
                            assert key in optimizer_cfg, \
                                '{} not in optimizer_cfg'.format(key)
                            value = optimizer_cfg[key] * value
                        param_group[key] = value
                        if not dist.is_initialized() or dist.get_rank() == 0:
                            print_log('paramwise_options -- {}: {}={}'.format(
                                name, key, value))

            # otherwise use the global settings
            params.append(param_group)

        optimizer_cls = getattr(optimizer, optimizer_cfg.pop('type'))
        return optimizer_cls(params, **optimizer_cfg)
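
# Usage sketch (constructor name and kwargs are hypothetical): delegate param
# grouping to a registered optimizer constructor instead of paramwise_options:
#
#   optimizer = dict(type='AdamW', lr=4e-3, weight_decay=0.05,
#                    constructor='MyLayerDecayConstructor',
#                    paramwise_options=dict(layer_decay_rate=0.9))
#   optim = build_optimizer(model, optimizer)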