Mashiro 8770c6c7fc
[Refactor] Refactor data flow to make the interface more natural (#468)
* [Refactor]: modify interface of Visualizer.add_datasample (#365)

* [Refactor] Refactor data flow: refine `data_preprocessor`. (#359)

* refine data_preprocessor

* remove unused BATCH_DATA alias

* Fix type hints

* rename move_data to cast_data

* [Refactor] Refactor data flow: collate data in `collate_fn` of `DataLoader`  (#323)

* collate data in dataloader

* fix docstring

* refine comment

* fix as comment

* refactor default collate and pseudo collate

* format test file

* fix docstring

* fix as comment

* rename elem to data_item

* minor fix

* fix as comment

* [Refactor] Refactor data flow: `data_batch` argument of `Evaluator.process` is a `dict` (#360)

* refine evaluator and metric

* compatible with new default collate

* replace default collate with pseudo

* Handle data_batch in metric

* fix unit test

* fix unit test

* fix unit test

* minor refine

* make data_batch optional

make data_batch optional

* rename outputs to predictions

* fix ut

* rename predictions to outputs

* fix docstring

* fix docstring

* fix unit test

* make outputs and data_batch to kwargs

* fix unit test

* keep signature of metric

* fix ut

* rename pred_sample arguments to data_sample(Visualizer)

* fix loop and ut

* [refactor]: Refactor model dataflow (#398)

* [Refactor] Refactor data flow: refine `data_preprocessor`. (#359)

* refine data_preprocessor

* remove unused BATCH_DATA alias

* Fix type hints

* rename move_data to cast_data

* refactor model data flow

tmp commit

tmp commit

* make val_cfg and test_cfg optional

* roll back runner

* pass test mmdet

* fix as comment

fix as comment

fix ci in DataPreprocessor

* fix ut

* fix ut

* fix rebase main

* [Fix]: Fix test val ddp (#462)

* [Fix] Fix docstring and type hint of data flow (#463)

* Fix docstring of data flow

* change signature of hook

* fix unit test

* resolve conflicts

* fix lint
2022-08-24 22:04:55 +08:00


# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from collections import OrderedDict
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from mmengine.optim import OptimWrapper
from mmengine.registry import MODELS
from mmengine.utils import is_list_of
from ..base_module import BaseModule
from .data_preprocessor import BaseDataPreprocessor
class BaseModel(BaseModule):
"""Base class for all algorithmic models.
BaseModel implements the basic functions of an algorithmic model, such as
initializing weights, preprocessing batch inputs (see more information in
:class:`BaseDataPreprocessor`), parsing losses, and updating model
parameters. Subclasses that inherit from BaseModel only need to implement
the ``forward`` method, which defines the logic to calculate losses and
predictions; the model can then be trained by the runner.
Examples:
>>> @MODELS.register_module()
>>> class ToyModel(BaseModel):
>>>
>>> def __init__(self):
>>> super().__init__()
>>> self.backbone = nn.Sequential()
>>> self.backbone.add_module('conv1', nn.Conv2d(3, 6, 5))
>>> self.backbone.add_module('pool1', nn.MaxPool2d(2, 2))
>>> self.backbone.add_module('conv2', nn.Conv2d(6, 16, 5))
>>> self.backbone.add_module('pool2', nn.MaxPool2d(2, 2))
>>> self.backbone.add_module('flatten', nn.Flatten())
>>> self.backbone.add_module('fc1', nn.Linear(16 * 5 * 5, 120))
>>> self.backbone.add_module('fc2', nn.Linear(120, 84))
>>> self.backbone.add_module('fc3', nn.Linear(84, 10))
>>>
>>> self.criterion = nn.CrossEntropyLoss()
>>>
>>> def forward(self, batch_inputs, data_samples, mode='tensor'):
>>> data_samples = torch.stack(data_samples)
>>> if mode == 'tensor':
>>> return self.backbone(batch_inputs)
>>> elif mode == 'predict':
>>> feats = self.backbone(batch_inputs)
>>> predictions = torch.argmax(feats, 1)
>>> return predictions
>>> elif mode == 'loss':
>>> feats = self.backbone(batch_inputs)
>>> loss = self.criterion(feats, data_samples)
>>> return dict(loss=loss)
Args:
data_preprocessor (dict, optional): The pre-process config of
:class:`BaseDataPreprocessor`.
init_cfg (dict, optional): The weight initialization config for
:class:`BaseModule`.
Attributes:
data_preprocessor (:obj:`BaseDataPreprocessor`): Used for
pre-processing data sampled by the dataloader into the format accepted
by :meth:`forward`.
init_cfg (dict, optional): Initialization config dict.
"""
def __init__(self,
data_preprocessor: Optional[Union[dict, nn.Module]] = None,
init_cfg: Optional[dict] = None):
super().__init__(init_cfg)
if data_preprocessor is None:
data_preprocessor = dict(type='BaseDataPreprocessor')
if isinstance(data_preprocessor, nn.Module):
self.data_preprocessor = data_preprocessor
elif isinstance(data_preprocessor, dict):
self.data_preprocessor = MODELS.build(data_preprocessor)
else:
raise TypeError('data_preprocessor should be a `dict` or '
f'`nn.Module` instance, but got '
f'{type(data_preprocessor)}')
def train_step(self, data: Union[dict, tuple, list],
optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
"""Implements the default model training process including
preprocessing, model forward propagation, loss calculation,
optimization, and back-propagation.
During non-distributed training. If subclasses do not override the
:meth:`train_step`, :class:`EpochBasedTrainLoop` or
:class:`IterBasedTrainLoop` will call this method to update model
parameters. The default parameter update process is as follows:
1. Calls ``self.data_processor(data, training=False) to collect
batch_inputs and corresponding data_samples(labels).
2. Calls ``self(batch_inputs, data_samples, mode='loss')`` to get raw
loss
3. Calls ``self.parse_losses`` to get ``parsed_losses`` tensor used to
backward and dict of loss tensor used to log messages.
4. Calls ``optim_wrapper.update_params(loss)`` to update model.
Args:
data (dict or tuple or list): Data sampled from dataset.
optim_wrapper (OptimWrapper): OptimWrapper instance
used to update model parameters.
Returns:
Dict[str, torch.Tensor]: A ``dict`` of tensors for logging.
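Examples:
>>> # A minimal usage sketch, not part of this class: it assumes the
>>> # ``ToyModel`` from the class docstring, a plain SGD optimizer
>>> # wrapped in ``OptimWrapper``, and dict keys that match the
>>> # arguments of ``ToyModel.forward``.
>>> from torch.optim import SGD
>>> from mmengine.optim import OptimWrapper
>>> model = ToyModel()
>>> optim_wrapper = OptimWrapper(SGD(model.parameters(), lr=0.01))
>>> data = dict(
...     batch_inputs=torch.rand(2, 3, 32, 32),
...     data_samples=[torch.tensor(0), torch.tensor(1)])
>>> log_vars = model.train_step(data, optim_wrapper)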
"""
# Enable automatic mixed precision training context.
with optim_wrapper.optim_context(self):
data = self.data_preprocessor(data, True)
losses = self._run_forward(data, mode='loss') # type: ignore
parsed_losses, log_vars = self.parse_losses(losses) # type: ignore
optim_wrapper.update_params(parsed_losses)
return log_vars
def val_step(self, data: Union[tuple, dict, list]) -> list:
"""Gets the predictions of given data.
Calls ``self.data_preprocessor(data, False)`` and
``self(inputs, data_sample, mode='predict')`` in order. Return the
predictions which will be passed to evaluator.
Args:
data (dict or tuple or list): Data sampled from dataset.
Returns:
list: The predictions of given data.
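Examples:
>>> # A minimal sketch, assuming the ``ToyModel`` from the class
>>> # docstring; the dict keys must match the arguments of its
>>> # ``forward`` method.
>>> model = ToyModel()
>>> data = dict(
...     batch_inputs=torch.rand(2, 3, 32, 32),
...     data_samples=[torch.tensor(0), torch.tensor(1)])
>>> predictions = model.val_step(data)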
"""
data = self.data_preprocessor(data, False)
return self._run_forward(data, mode='predict') # type: ignore
def test_step(self, data: Union[dict, tuple, list]) -> list:
"""``BaseModel`` implements ``test_step`` the same as ``val_step``.
Args:
data (dict or tuple or list): Data sampled from dataset.
Returns:
list: The predictions of given data.
"""
data = self.data_preprocessor(data, False)
return self._run_forward(data, mode='predict') # type: ignore
def parse_losses(
self, losses: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""Parses the raw outputs (losses) of the network.
Args:
losses (dict): Raw output of the network, which usually contains
losses and other necessary information.
Returns:
tuple[Tensor, dict]: There are two elements. The first is the
loss tensor passed to optim_wrapper which may be a weighted sum of
all losses, and the second is log_vars which will be sent to the
logger.
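Examples:
>>> # A minimal sketch, assuming ``model`` is any ``BaseModel``
>>> # instance (e.g. the ToyModel above). Every value whose key
>>> # contains ``loss`` is summed into the returned loss; for list
>>> # values each element is averaged and the results are summed.
>>> losses = dict(
...     loss_cls=torch.tensor(1.0),
...     loss_bbox=[torch.tensor(0.5), torch.tensor(0.5)],
...     accuracy=torch.tensor(0.9))
>>> loss, log_vars = model.parse_losses(losses)
>>> # loss == loss_cls + summed loss_bbox == 2.0; ``accuracy`` is
>>> # logged in ``log_vars`` but excluded from the total loss.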
"""
log_vars = OrderedDict()
for loss_name, loss_value in losses.items():
if isinstance(loss_value, torch.Tensor):
log_vars[loss_name] = loss_value.mean()
elif is_list_of(loss_value, torch.Tensor):
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
else:
raise TypeError(
f'{loss_name} is not a tensor or list of tensors')
loss = sum(value for key, value in log_vars.items() if 'loss' in key)
log_vars['loss'] = loss
return loss, log_vars
def to(self,
device: Optional[Union[int, str, torch.device]] = None,
*args,
**kwargs) -> nn.Module:
"""Overrides this method to call :meth:`BaseDataPreprocessor.to`
additionally.
Args:
device (int, str or torch.device, optional): the desired device
of the parameters and buffers in this module.
Returns:
nn.Module: The model itself.
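Examples:
>>> # A minimal sketch, assuming the ``ToyModel`` from the class
>>> # docstring: moving the model also updates the device recorded by
>>> # its ``data_preprocessor`` (via ``_set_device``), so preprocessed
>>> # batches follow the model's parameters.
>>> model = ToyModel()
>>> model = model.to('cpu')
>>> # model.data_preprocessor._device is now torch.device('cpu')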
"""
if device is not None:
self._set_device(torch.device(device))
return super().to(device)
def cuda(
self,
device: Optional[Union[int, str, torch.device]] = None,
) -> nn.Module:
"""Overrides this method to call :meth:`BaseDataPreprocessor.cuda`
additionally.
Returns:
nn.Module: The model itself.
"""
if device is None or isinstance(device, int):
device = torch.device('cuda', index=device)
self._set_device(torch.device(device))
return super().cuda(device)
def cpu(self, *args, **kwargs) -> nn.Module:
"""Overrides this method to call :meth:`BaseDataPreprocessor.cpu`
additionally.
Returns:
nn.Module: The model itself.
"""
self._set_device(torch.device('cpu'))
return super().cpu()
def _set_device(self, device: torch.device) -> None:
"""Recursively set device for `BaseDataPreprocessor` instance.
Args:
device (torch.device): the desired device of the parameters and
buffers in this module.
"""
def apply_fn(module):
if not isinstance(module, BaseDataPreprocessor):
return
if device is not None:
module._device = device
self.apply(apply_fn)
@abstractmethod
def forward(self,
inputs: torch.Tensor,
data_samples: Optional[list] = None,
mode: str = 'tensor') -> Union[Dict[str, torch.Tensor], list]:
"""Returns losses or predictions of training, validation, testing, and
simple inference process.
``forward`` is an abstract method of BaseModel, and its subclasses
must implement it.
Accepts ``inputs`` and ``data_samples`` processed by
:attr:`data_preprocessor`, and returns results according to the
``mode`` argument.
During non-distributed training, validation, and testing,
``forward`` will be called by ``BaseModel.train_step``,
``BaseModel.val_step`` and ``BaseModel.test_step`` directly.
During distributed data parallel training process,
``MMSeparateDistributedDataParallel.train_step`` will first call
``DistributedDataParallel.forward`` to enable automatic
gradient synchronization, and then call ``forward`` to get training
loss.
Args:
inputs (torch.Tensor): batch input tensor collated by
:attr:`data_preprocessor`.
data_samples (list, optional):
data samples collated by :attr:`data_preprocessor`.
mode (str): ``mode`` should be one of ``loss``, ``predict`` and
``tensor``.
- ``loss``: Called by ``train_step`` and returns a loss ``dict``
used for logging.
- ``predict``: Called by ``val_step`` and ``test_step``
and returns a list of results used for computing metrics.
- ``tensor``: Called by custom code to get ``Tensor`` type
results.
Returns:
dict or list:
- If ``mode == loss``, return a ``dict`` of loss tensors used
for backward and logging.
- If ``mode == predict``, return a ``list`` of inference
results.
- If ``mode == tensor``, return a tensor or ``tuple`` of tensors
or a ``dict`` of tensors for custom use.
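Examples:
>>> # A minimal sketch of the three modes, assuming the ``ToyModel``
>>> # from the class docstring and inputs already processed by the
>>> # data_preprocessor.
>>> model = ToyModel()
>>> inputs = torch.rand(2, 3, 32, 32)
>>> data_samples = [torch.tensor(0), torch.tensor(1)]
>>> losses = model(inputs, data_samples, mode='loss')    # dict of losses
>>> preds = model(inputs, data_samples, mode='predict')  # predictions
>>> feats = model(inputs, data_samples, mode='tensor')   # raw outputs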
"""
def _run_forward(self, data: Union[dict, tuple, list],
mode: str) -> Union[Dict[str, torch.Tensor], list]:
"""Unpacks data for :meth:`forward`
Args:
data (dict or tuple or list): Data sampled from dataset.
mode (str): Mode of forward.
Returns:
dict or list: Results of training or testing mode.
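Examples:
>>> # A minimal sketch, assuming the ``ToyModel`` from the class
>>> # docstring: a dict is forwarded as keyword arguments, while a
>>> # tuple or list is forwarded as positional arguments. Normally
>>> # ``data`` is the output of :attr:`data_preprocessor`.
>>> model = ToyModel()
>>> results = model._run_forward(
...     dict(batch_inputs=torch.rand(2, 3, 32, 32),
...          data_samples=[torch.tensor(0), torch.tensor(1)]),
...     mode='tensor')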
"""
if isinstance(data, dict):
results = self(**data, mode=mode)
elif isinstance(data, (list, tuple)):
results = self(*data, mode=mode)
else:
raise TypeError('Output of `data_preprocessor` should be '
f'list, tuple or dict, but got {type(data)}')
return results