# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
import logging
import os
import torch
from collections import OrderedDict
from copy import deepcopy
from torch.nn.parallel import DataParallel, DistributedDataParallel

from basicsr.models import lr_scheduler as lr_scheduler
from basicsr.utils.dist_util import master_only

logger = logging.getLogger('basicsr')

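# A minimal sketch (illustrative, not part of the original file) of the option
# dict fields that BaseModel actually reads; real configs are parsed from a
# YAML file by basicsr and contain many more entries.
#
#   opt = {
#       'num_gpu': 1,            # 0 selects CPU in __init__
#       'is_train': True,
#       'dist': False,           # distributed-training flag
#       'rank': 0,               # used by reduce_loss_dict
#       'world_size': 1,         # used by reduce_loss_dict
#       'path': {
#           'models': 'experiments/example/models',            # save_network
#           'training_states': 'experiments/example/states',   # save_training_state
#       },
#       'train': {
#           'total_iter': 200000,                              # LinearLR / VibrateLR
#           'scheduler': {'type': 'TrueCosineAnnealingLR', 'T_max': 200000},
#       },
#   }
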
class BaseModel():
    """Base model."""

    def __init__(self, opt):
        self.opt = opt
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        self.is_train = opt['is_train']
        self.schedulers = []
        self.optimizers = []

    def feed_data(self, data):
        """Feed input data to the model. Implemented by subclasses."""
        pass

    def optimize_parameters(self):
        """Run one optimization step. Implemented by subclasses."""
        pass

    def get_current_visuals(self):
        """Return the current visual results. Implemented by subclasses."""
        pass

    def save(self, epoch, current_iter):
        """Save networks and training state."""
        pass

    def validation(self, dataloader, current_iter, tb_logger, save_img=False, rgb2bgr=True, use_image=True):
        """Validation function.

        Args:
            dataloader (torch.utils.data.DataLoader): Validation dataloader.
            current_iter (int): Current iteration.
            tb_logger (tensorboard logger): Tensorboard logger.
            save_img (bool): Whether to save images. Default: False.
            rgb2bgr (bool): Whether to convert images from RGB to BGR before
                saving. Default: True.
            use_image (bool): Whether to compute metrics (PSNR, SSIM) on the
                saved images; if False, metrics are computed directly on the
                network output. Default: True.
        """
        if self.opt['dist']:
            return self.dist_validation(dataloader, current_iter, tb_logger,
                                        save_img, rgb2bgr, use_image)
        else:
            return self.nondist_validation(dataloader, current_iter, tb_logger,
                                           save_img, rgb2bgr, use_image)

    def get_current_log(self):
        return self.log_dict

    def model_to_device(self, net):
        """Move the model to the target device and wrap it with
        DistributedDataParallel or DataParallel when appropriate.

        Args:
            net (nn.Module)
        """
        net = net.to(self.device)
        if self.opt['dist']:
            find_unused_parameters = self.opt.get('find_unused_parameters',
                                                  False)
            net = DistributedDataParallel(
                net,
                device_ids=[torch.cuda.current_device()],
                find_unused_parameters=find_unused_parameters)
        elif self.opt['num_gpu'] > 1:
            net = DataParallel(net)
        return net

    def setup_schedulers(self):
        """Set up schedulers."""
        train_opt = self.opt['train']
        scheduler_type = train_opt['scheduler'].pop('type')
        if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepRestartLR(optimizer,
                                                    **train_opt['scheduler']))
        elif scheduler_type == 'CosineAnnealingRestartLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingRestartLR(
                        optimizer, **train_opt['scheduler']))
        elif scheduler_type == 'TrueCosineAnnealingLR':
            logger.info('Using torch.optim.lr_scheduler.CosineAnnealingLR')
            for optimizer in self.optimizers:
                self.schedulers.append(
                    torch.optim.lr_scheduler.CosineAnnealingLR(
                        optimizer, **train_opt['scheduler']))
        elif scheduler_type == 'LinearLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.LinearLR(
                        optimizer, train_opt['total_iter']))
        elif scheduler_type == 'VibrateLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.VibrateLR(
                        optimizer, train_opt['total_iter']))
        else:
            raise NotImplementedError(
                f'Scheduler {scheduler_type} is not implemented yet.')

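    # Illustrative scheduler configs (assumed, not from the original file) that
    # setup_schedulers() can consume: after 'type' is popped, the remaining keys
    # are forwarded as keyword arguments to the chosen scheduler class, e.g.
    #
    #   train:
    #     scheduler:
    #       type: TrueCosineAnnealingLR   # -> torch CosineAnnealingLR(T_max=..., eta_min=...)
    #       T_max: 200000
    #       eta_min: !!float 1e-7
    #
    #   train:
    #     scheduler:
    #       type: LinearLR                # uses train['total_iter'] instead of extra kwargs
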
    def get_bare_model(self, net):
        """Get the bare model, i.e. unwrap a DistributedDataParallel or
        DataParallel wrapper if present.
        """
        if isinstance(net, (DataParallel, DistributedDataParallel)):
            net = net.module
        return net

    @master_only
    def print_network(self, net):
        """Log the network structure (str) and its number of parameters.

        Args:
            net (nn.Module)
        """
        if isinstance(net, (DataParallel, DistributedDataParallel)):
            net_cls_str = (f'{net.__class__.__name__} - '
                           f'{net.module.__class__.__name__}')
        else:
            net_cls_str = f'{net.__class__.__name__}'

        net = self.get_bare_model(net)
        net_str = str(net)
        net_params = sum(map(lambda x: x.numel(), net.parameters()))

        logger.info(
            f'Network: {net_cls_str}, with parameters: {net_params:,d}')
        logger.info(net_str)

    def _set_lr(self, lr_groups_l):
        """Set learning rate for warmup.

        Args:
            lr_groups_l (list): List for lr_groups, each for an optimizer.
        """
        for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
            for param_group, lr in zip(optimizer.param_groups, lr_groups):
                param_group['lr'] = lr

    def _get_init_lr(self):
        """Get the initial lr, which is set by the scheduler."""
        init_lr_groups_l = []
        for optimizer in self.optimizers:
            init_lr_groups_l.append(
                [v['initial_lr'] for v in optimizer.param_groups])
        return init_lr_groups_l

    def update_learning_rate(self, current_iter, warmup_iter=-1):
        """Update learning rate.

        Args:
            current_iter (int): Current iteration.
            warmup_iter (int): Warmup iter numbers. -1 for no warmup.
                Default: -1.
        """
        if current_iter > 1:
            for scheduler in self.schedulers:
                scheduler.step()
        # set up the warm-up learning rate
        if current_iter < warmup_iter:
            # get the initial lr for each param group
            init_lr_g_l = self._get_init_lr()
            # modify the warm-up learning rates;
            # currently only linear warm-up is supported
            warm_up_lr_l = []
            for init_lr_g in init_lr_g_l:
                warm_up_lr_l.append(
                    [v / warmup_iter * current_iter for v in init_lr_g])
            # set the learning rate
            self._set_lr(warm_up_lr_l)

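    # Worked example of the linear warm-up above (values are illustrative):
    # with an initial lr of 2e-4 and warmup_iter = 1000, the lr applied at
    # current_iter = 250 is 2e-4 / 1000 * 250 = 5e-5, reaching the full 2e-4
    # at iteration 1000, after which only the schedulers control the lr.
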
    def get_current_learning_rate(self):
        return [
            param_group['lr']
            for param_group in self.optimizers[0].param_groups
        ]

    @master_only
    def save_network(self, net, net_label, current_iter, param_key='params'):
        """Save networks.

        Args:
            net (nn.Module | list[nn.Module]): Network(s) to be saved.
            net_label (str): Network label.
            current_iter (int): Current iter number.
            param_key (str | list[str]): The parameter key(s) to save network.
                Default: 'params'.
        """
        if current_iter == -1:
            current_iter = 'latest'
        save_filename = f'{net_label}_{current_iter}.pth'
        save_path = os.path.join(self.opt['path']['models'], save_filename)

        net = net if isinstance(net, list) else [net]
        param_key = param_key if isinstance(param_key, list) else [param_key]
        assert len(net) == len(
            param_key), 'The lengths of net and param_key should be the same.'

        save_dict = {}
        for net_, param_key_ in zip(net, param_key):
            net_ = self.get_bare_model(net_)
            state_dict = net_.state_dict()
            # iterate over a copy so renaming keys does not mutate the dict
            # while it is being iterated
            for key, param in list(state_dict.items()):
                if key.startswith('module.'):  # remove unnecessary 'module.'
                    key = key[7:]
                state_dict[key] = param.cpu()
            save_dict[param_key_] = state_dict

        torch.save(save_dict, save_path)

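    # The resulting checkpoint is a plain dict keyed by param_key, e.g.
    # (illustrative):
    #
    #   {'params': OrderedDict([('conv1.weight', tensor(...)), ...])}
    #
    # which is what load_network() below indexes with its own param_key.
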
    def _print_different_keys_loading(self, crt_net, load_net, strict=True):
        """Print keys with different names or sizes when loading models.

        1. Print keys with different names.
        2. If strict=False, print keys that match but have different tensor
           sizes; these keys are also ignored (not loaded).

        Args:
            crt_net (torch model): Current network.
            load_net (dict): Loaded network.
            strict (bool): Whether strictly loaded. Default: True.
        """
        crt_net = self.get_bare_model(crt_net)
        crt_net = crt_net.state_dict()
        crt_net_keys = set(crt_net.keys())
        load_net_keys = set(load_net.keys())

        if crt_net_keys != load_net_keys:
            logger.warning('Current net - loaded net:')
            for v in sorted(list(crt_net_keys - load_net_keys)):
                logger.warning(f'  {v}')
            logger.warning('Loaded net - current net:')
            for v in sorted(list(load_net_keys - crt_net_keys)):
                logger.warning(f'  {v}')

        # check the size for the same keys
        if not strict:
            common_keys = crt_net_keys & load_net_keys
            for k in common_keys:
                if crt_net[k].size() != load_net[k].size():
                    logger.warning(
                        f'Size different, ignore [{k}]: crt_net: '
                        f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
                    load_net[k + '.ignore'] = load_net.pop(k)

    def load_network(self, net, load_path, strict=True, param_key='params'):
        """Load network.

        Args:
            load_path (str): The path of networks to be loaded.
            net (nn.Module): Network.
            strict (bool): Whether strictly loaded.
            param_key (str): The parameter key of the loaded network. If set
                to None, the whole loaded dict is used. Default: 'params'.
        """
        net = self.get_bare_model(net)
        logger.info(
            f'Loading {net.__class__.__name__} model from {load_path}.')
        load_net = torch.load(
            load_path, map_location=lambda storage, loc: storage)
        if param_key is not None:
            load_net = load_net[param_key]
        logger.info(f'Loaded net keys: {list(load_net.keys())}')
        # remove unnecessary 'module.'
        for k, v in deepcopy(load_net).items():
            if k.startswith('module.'):
                load_net[k[7:]] = v
                load_net.pop(k)
        self._print_different_keys_loading(net, load_net, strict)
        net.load_state_dict(load_net, strict=strict)

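    # Typical (illustrative) call from a subclass; the network attribute and
    # checkpoint path here are hypothetical:
    #
    #   self.load_network(self.net_g,
    #                     'experiments/example/models/net_g_latest.pth',
    #                     strict=True, param_key='params')
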
    @master_only
    def save_training_state(self, epoch, current_iter):
        """Save training states during training, which will be used for
        resuming.

        Args:
            epoch (int): Current epoch.
            current_iter (int): Current iteration.
        """
        if current_iter != -1:
            state = {
                'epoch': epoch,
                'iter': current_iter,
                'optimizers': [],
                'schedulers': []
            }
            for o in self.optimizers:
                state['optimizers'].append(o.state_dict())
            for s in self.schedulers:
                state['schedulers'].append(s.state_dict())
            save_filename = f'{current_iter}.state'
            save_path = os.path.join(self.opt['path']['training_states'],
                                     save_filename)
            torch.save(state, save_path)

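    # The saved '<iter>.state' file is a dict of the form (illustrative):
    #
    #   {'epoch': 3, 'iter': 100000,
    #    'optimizers': [<optimizer state_dict>, ...],
    #    'schedulers': [<scheduler state_dict>, ...]}
    #
    # resume_training() below expects exactly this layout.
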
    def resume_training(self, resume_state):
        """Reload the optimizers and schedulers for resumed training.

        Args:
            resume_state (dict): Resume state.
        """
        resume_optimizers = resume_state['optimizers']
        resume_schedulers = resume_state['schedulers']
        assert len(resume_optimizers) == len(
            self.optimizers), 'Wrong lengths of optimizers'
        assert len(resume_schedulers) == len(
            self.schedulers), 'Wrong lengths of schedulers'
        for i, o in enumerate(resume_optimizers):
            self.optimizers[i].load_state_dict(o)
        for i, s in enumerate(resume_schedulers):
            self.schedulers[i].load_state_dict(s)

    def reduce_loss_dict(self, loss_dict):
        """Reduce loss dict.

        In distributed training, it averages the losses among different GPUs.

        Args:
            loss_dict (OrderedDict): Loss dict.
        """
        with torch.no_grad():
            if self.opt['dist']:
                keys = []
                losses = []
                for name, value in loss_dict.items():
                    keys.append(name)
                    losses.append(value)
                losses = torch.stack(losses, 0)
                torch.distributed.reduce(losses, dst=0)
                if self.opt['rank'] == 0:
                    losses /= self.opt['world_size']
                loss_dict = {key: loss for key, loss in zip(keys, losses)}

            log_dict = OrderedDict()
            for name, value in loss_dict.items():
                log_dict[name] = value.mean().item()

            return log_dict
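
# A rough usage sketch (assumptions: a subclass `ImageRestorationModel` with a
# generator `net_g`, a populated self.optimizers list, a `train_loader`, and a
# parsed `opt` dict as outlined above; none of these names come from this file):
#
#   model = ImageRestorationModel(opt)
#   model.net_g = model.model_to_device(model.net_g)
#   model.print_network(model.net_g)
#   model.setup_schedulers()
#   for current_iter, data in enumerate(train_loader, start=1):
#       model.update_learning_rate(current_iter, warmup_iter=1000)
#       model.feed_data(data)
#       model.optimize_parameters()
#       if current_iter % 5000 == 0:
#           model.save_network(model.net_g, 'net_g', current_iter)
#           model.save_training_state(epoch=0, current_iter=current_iter)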