|
|
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
|
|
import os
|
|
|
|
import time
|
|
|
|
from distutils.version import LooseVersion
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from mmcv.runner import EpochBasedRunner
|
|
|
|
from mmcv.runner.log_buffer import LogBuffer
|
|
|
|
|
|
|
|
from easycv.file import io
|
|
|
|
from easycv.utils.checkpoint import load_checkpoint, save_checkpoint
|
|
|
|
|
|
|
|
if LooseVersion(torch.__version__) >= LooseVersion('1.6.0'):
|
|
|
|
from torch.cuda import amp
|
|
|
|
from easycv.hooks.optimizer_hook import AMPFP16OptimizerHook
|
|
|
|
|
|
|
|
|
|
|
|
class EVRunner(EpochBasedRunner):
    """Epoch-based runner for EasyCV.

    Extends ``mmcv.runner.EpochBasedRunner`` with:
      * checkpoint save/load through EasyCV helpers (OSS-capable paths),
      * optional AMP fp16 training (entered in :meth:`train` when
        ``self.fp16_enable`` is set, e.g. by ``AMPFP16OptimizerHook``),
      * a separate ``visualization_buffer`` for logger hooks.
    """

    def __init__(self,
                 model,
                 batch_processor=None,
                 optimizer=None,
                 work_dir=None,
                 logger=None,
                 meta=None):
        """ Epoch Runner for easycv, add support for oss IO and file sync.

        Args:
            model (:obj:`torch.nn.Module`): The model to be run.
            batch_processor (callable): A callable method that process a data
                batch. The interface of this method should be
                `batch_processor(model, data, train_mode) -> dict`
            optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an
                optimizer (in most cases) or a dict of optimizers (in models that
                requires more than one optimizer, e.g., GAN).
            work_dir (str, optional): The working directory to save checkpoints
                and logs. Defaults to None.
            logger (:obj:`logging.Logger`): Logger used during training.
                Defaults to None. (The default value is just for backward
                compatibility)
            meta (dict | None): A dict records some import information such as
                environment info and seed, which will be logged in logger hook.
                Defaults to None.
        """
        super().__init__(model, batch_processor, optimizer, work_dir, logger,
                         meta)
        self.data_loader = None
        # AMP autocast is only entered in train() when this flag is True
        # (set externally, e.g. by an fp16 optimizer hook).
        self.fp16_enable = False
        # Separate log buffer for visualization payloads, consumed by
        # logger/visualization hooks.
        self.visualization_buffer = LogBuffer()

    def run_iter(self, data_batch, train_mode, **kwargs):
        """ Process a single iteration.

        Args:
            data_batch: Batch of dict of data.
            train_mode (bool): If set True, run training step else validation step.

        Raises:
            TypeError: If the processed outputs are not a dict.
        """
        if self.batch_processor is not None:
            outputs = self.batch_processor(
                self.model, data_batch, train_mode=train_mode, **kwargs)
        elif train_mode:
            outputs = self.model.train_step(data_batch, self.optimizer,
                                            **kwargs)
        else:
            outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
        if not isinstance(outputs, dict):
            # NOTE: trailing space added so the implicitly-concatenated
            # message reads '...train_step()" and "model.val_step()"...'.
            raise TypeError('"batch_processor()" or "model.train_step()" '
                            'and "model.val_step()" must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
        self.outputs = outputs

    def train(self, data_loader, **kwargs):
        """ Training process for one epoch which will iterate through all \
        training data and call hooks at different stages.

        Args:
            data_loader: data loader object for training
        """
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            # amp.autocast is only available in amp from pytorch 1.6 or later.
            if self.fp16_enable and LooseVersion(
                    torch.__version__) >= LooseVersion('1.6.0'):
                with amp.autocast():
                    self.run_iter(data_batch, train_mode=True)
            else:
                self.run_iter(data_batch, train_mode=True)

            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')

        self._epoch += 1

    def val(self, data_loader, **kwargs):
        """ Validation step which is deprecated, use evaluation hook instead.

        Args:
            data_loader: data loader object for validation
        """
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            with torch.no_grad():
                self.run_iter(data_batch, train_mode=False)
            self.call_hook('after_val_iter')

        self.call_hook('after_val_epoch')

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        """Save checkpoint to file.

        Args:
            out_dir: Directory where checkpoint files are to be saved.
            filename_tmpl (str, optional): Checkpoint filename pattern,
                formatted with the 1-based epoch number.
            save_optimizer (bool, optional): save optimizer state.
            meta (dict, optional): Metadata to be saved in checkpoint;
                epoch/iter are always injected (and override given keys).
            create_symlink (bool, optional): Kept for interface compatibility
                with mmcv runners; currently unused here.
        """
        # implement save checkpoint to oss
        if meta is None:
            meta = dict(epoch=self.epoch + 1, iter=self.iter)
        else:
            meta.update(epoch=self.epoch + 1, iter=self.iter)

        filename = filename_tmpl.format(self.epoch + 1)
        filepath = os.path.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        # Prefer the EMA (exponential moving average) weights when an EMA
        # hook has attached one to the runner.
        if hasattr(self, 'ema'):
            model = self.ema.model
        else:
            model = self.model

        save_checkpoint(model, filepath, optimizer, meta)

    def current_lr(self):
        """Get current learning rates.

        Returns:
            list[float] | dict[str, list[float]]: Current learning rates of all
                param groups. If the runner has a dict of optimizers, this
                method will return a dict.

        Raises:
            RuntimeError: If no optimizer is attached to the runner.
        """
        # add interface to selfdefine current_lr_fn for lr_hook
        # so that runner can logging correct lrs
        if hasattr(self, 'current_lr_fn'):
            lr = self.current_lr_fn(self.optimizer)
        elif isinstance(self.optimizer, torch.optim.Optimizer):
            lr = [group['lr'] for group in self.optimizer.param_groups]
        elif isinstance(self.optimizer, dict):
            lr = dict()
            for name, optim in self.optimizer.items():
                lr[name] = [group['lr'] for group in optim.param_groups]
        else:
            raise RuntimeError(
                'lr is not applicable because optimizer does not exist.')
        return lr

    # !! notice: map_location should be cpu, otherwise tensors may stack onto
    # some GPU, which can cause OOM.
    def load_checkpoint(self,
                        filename,
                        map_location=torch.device('cpu'),
                        strict=False,
                        logger=None):
        """Load checkpoint from a file or URL.

        Args:
            filename (str): Accept local filepath, URL, ``torchvision://xxx``,
                ``open-mmlab://xxx``, ``oss://xxx``. Please refer to
                ``docs/source/model_zoo.md`` for details.
            map_location (str): Same as :func:`torch.load`.
            strict (bool): Whether to allow different params for the model and
                checkpoint.
            logger (:mod:`logging.Logger` or None): The logger for error message.

        Returns:
            dict or OrderedDict: The loaded checkpoint.
        """
        return load_checkpoint(
            self.model,
            filename=filename,
            map_location=map_location,
            strict=strict,
            logger=logger)

    def resume(self,
               checkpoint,
               resume_optimizer=True,
               map_location='default'):
        """ Resume state dict from checkpoint.

        Args:
            checkpoint: Checkpoint path
            resume_optimizer: Whether to resume optimizer state
            map_location (str): Same as :func:`torch.load`. The special value
                ``'default'`` maps to the current CUDA device when available,
                otherwise to CPU.
        """
        if map_location == 'default':
            if torch.cuda.is_available():
                device_id = torch.cuda.current_device()
                checkpoint = self.load_checkpoint(
                    checkpoint,
                    map_location=lambda storage, loc: storage.cuda(device_id))
            else:
                checkpoint = self.load_checkpoint(checkpoint)
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        # NOTE(review): assumes a single optimizer; a dict of optimizers
        # would need per-name state restoration — confirm against callers.
        if 'optimizer' in checkpoint and resume_optimizer:
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)