mirror of https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
* [Refactor]: modify interface of Visualizer.add_datasample (#365)
* [Refactor] Refactor data flow: refine `data_preprocessor` (#359)
  * refine data_preprocessor
  * remove unused BATCH_DATA alias
  * fix type hints
  * rename move_data to cast_data
* [Refactor] Refactor data flow: collate data in `collate_fn` of `DataLoader` (#323)
  * collate data in dataloader
  * fix docstring
  * refine comment
  * fix as comment
  * refactor default collate and pseudo collate
  * format test file
  * rename elem to data_item
  * minor fix
* [Refactor] Refactor data flow: `data_batch` argument of `Evaluator.process` is a `dict` (#360)
  * refine evaluator and metric
  * compatible with new default collate
  * replace default collate with pseudo collate
  * handle data_batch in metric
  * fix unit tests
  * make data_batch optional
  * rename outputs to predictions, then predictions back to outputs
  * fix docstring
  * make outputs and data_batch keyword arguments
  * keep signature of metric
  * rename pred_sample arguments to data_sample (Visualizer)
  * fix loop and ut
* [Refactor]: Refactor model dataflow (#398)
  * refactor model data flow
  * make val_cfg and test_cfg optional
  * roll back runner
  * pass mmdet tests
  * fix as comment; fix CI in DataPreprocessor
  * fix ut
  * fix rebase onto main
* [Fix]: Fix test val ddp (#462)
* [Fix] Fix docstring and type hint of data flow (#463)
  * fix docstring of data flow
  * change signature of hook
  * fix unit test
  * resolve conflicts
  * fix lint
305 lines · 12 KiB · Python
# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from collections import OrderedDict
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn

from mmengine.optim import OptimWrapper
from mmengine.registry import MODELS
from mmengine.utils import is_list_of

from ..base_module import BaseModule
from .data_preprocessor import BaseDataPreprocessor


class BaseModel(BaseModule):
    """Base class for all algorithmic models.

    BaseModel implements the basic functions of an algorithmic model, such as
    weight initialization, batch input preprocessing (see more information in
    :class:`BaseDataPreprocessor`), loss parsing, and model parameter
    updating.

    Subclasses inheriting from BaseModel only need to implement the
    ``forward`` method, which contains the logic to calculate losses and
    predictions; the model can then be trained by the runner.

    Examples:
        >>> @MODELS.register_module()
        >>> class ToyModel(BaseModel):
        >>>
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.backbone = nn.Sequential()
        >>>         self.backbone.add_module('conv1', nn.Conv2d(3, 6, 5))
        >>>         self.backbone.add_module('pool', nn.MaxPool2d(2, 2))
        >>>         self.backbone.add_module('conv2', nn.Conv2d(6, 16, 5))
        >>>         self.backbone.add_module('pool2', nn.MaxPool2d(2, 2))
        >>>         self.backbone.add_module('flatten', nn.Flatten())
        >>>         self.backbone.add_module('fc1', nn.Linear(16 * 5 * 5, 120))
        >>>         self.backbone.add_module('fc2', nn.Linear(120, 84))
        >>>         self.backbone.add_module('fc3', nn.Linear(84, 10))
        >>>
        >>>         self.criterion = nn.CrossEntropyLoss()
        >>>
        >>>     def forward(self, batch_inputs, data_samples=None,
        >>>                 mode='tensor'):
        >>>         if data_samples is not None:
        >>>             data_samples = torch.stack(data_samples)
        >>>         if mode == 'tensor':
        >>>             return self.backbone(batch_inputs)
        >>>         elif mode == 'predict':
        >>>             feats = self.backbone(batch_inputs)
        >>>             predictions = torch.argmax(feats, 1)
        >>>             return predictions
        >>>         elif mode == 'loss':
        >>>             feats = self.backbone(batch_inputs)
        >>>             loss = self.criterion(feats, data_samples)
        >>>             return dict(loss=loss)

    Args:
        data_preprocessor (dict, optional): The pre-process config of
            :class:`BaseDataPreprocessor`.
        init_cfg (dict, optional): The weight initialization config for
            :class:`BaseModule`.

    Attributes:
        data_preprocessor (:obj:`BaseDataPreprocessor`): Used to pre-process
            the data sampled by the dataloader into the format accepted by
            :meth:`forward`.
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self,
                 data_preprocessor: Optional[Union[dict, nn.Module]] = None,
                 init_cfg: Optional[dict] = None):
        super().__init__(init_cfg)
        if data_preprocessor is None:
            data_preprocessor = dict(type='BaseDataPreprocessor')
        if isinstance(data_preprocessor, nn.Module):
            self.data_preprocessor = data_preprocessor
        elif isinstance(data_preprocessor, dict):
            self.data_preprocessor = MODELS.build(data_preprocessor)
        else:
            raise TypeError('data_preprocessor should be a `dict` or '
                            f'`nn.Module` instance, but got '
                            f'{type(data_preprocessor)}')

    def train_step(self, data: Union[dict, tuple, list],
                   optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
        """Implements the default model training process, including
        preprocessing, model forward propagation, loss calculation,
        optimization, and back-propagation.

        During non-distributed training, if subclasses do not override
        :meth:`train_step`, :class:`EpochBasedTrainLoop` or
        :class:`IterBasedTrainLoop` will call this method to update model
        parameters. The default parameter update process is as follows:

        1. Calls ``self.data_preprocessor(data, training=True)`` to collect
           ``batch_inputs`` and the corresponding ``data_samples`` (labels).
        2. Calls ``self(batch_inputs, data_samples, mode='loss')`` to get the
           raw loss ``dict``.
        3. Calls ``self.parse_losses`` to get a ``parsed_losses`` tensor used
           for backward, and a dict of loss tensors used for logging.
        4. Calls ``optim_wrapper.update_params(loss)`` to update the model.

        Args:
            data (dict or tuple or list): Data sampled from the dataset.
            optim_wrapper (OptimWrapper): OptimWrapper instance
                used to update model parameters.

        Returns:
            Dict[str, torch.Tensor]: A ``dict`` of tensors for logging.
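
        Examples:
            >>> # A minimal sketch, assuming the ``ToyModel`` from the class
            >>> # docstring and a simple SGD ``OptimWrapper``; the keys of
            >>> # ``data`` must match the argument names of its ``forward``.
            >>> model = ToyModel()
            >>> optim_wrapper = OptimWrapper(
            ...     torch.optim.SGD(model.parameters(), lr=0.01))
            >>> data = dict(
            ...     batch_inputs=torch.rand(2, 3, 32, 32),
            ...     data_samples=[torch.tensor(0), torch.tensor(1)])
            >>> log_vars = model.train_step(data, optim_wrapper)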
"""
|
|
# Enable automatic mixed precision training context.
|
|
with optim_wrapper.optim_context(self):
|
|
data = self.data_preprocessor(data, True)
|
|
losses = self._run_forward(data, mode='loss') # type: ignore
|
|
parsed_losses, log_vars = self.parse_losses(losses) # type: ignore
|
|
optim_wrapper.update_params(parsed_losses)
|
|
return log_vars
|
|
|
|

    def val_step(self, data: Union[tuple, dict, list]) -> list:
        """Gets the predictions for the given data.

        Calls ``self.data_preprocessor(data, False)`` and
        ``self(inputs, data_samples, mode='predict')`` in order. Returns the
        predictions, which will be passed to the evaluator.

        Args:
            data (dict or tuple or list): Data sampled from the dataset.

        Returns:
            list: The predictions for the given data.
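
        Examples:
            >>> # A minimal sketch, assuming the ``ToyModel`` from the class
            >>> # docstring.
            >>> model = ToyModel()
            >>> data = dict(
            ...     batch_inputs=torch.rand(2, 3, 32, 32),
            ...     data_samples=None)
            >>> predictions = model.val_step(data)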
"""
|
|
data = self.data_preprocessor(data, False)
|
|
return self._run_forward(data, mode='predict') # type: ignore
|
|
|
|

    def test_step(self, data: Union[dict, tuple, list]) -> list:
        """``BaseModel`` implements ``test_step`` the same as ``val_step``.

        Args:
            data (dict or tuple or list): Data sampled from the dataset.

        Returns:
            list: The predictions for the given data.
        """
        data = self.data_preprocessor(data, False)
        return self._run_forward(data, mode='predict')  # type: ignore

    def parse_losses(
        self, losses: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Parses the raw outputs (losses) of the network.

        Args:
            losses (dict): Raw output of the network, which usually contains
                losses and other necessary information.

        Returns:
            tuple[Tensor, dict]: There are two elements. The first is the
            loss tensor passed to the optim_wrapper, which may be a weighted
            sum of all losses, and the second is ``log_vars``, which will be
            sent to the logger.
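
        Examples:
            >>> # A minimal sketch: keys containing ``'loss'`` are summed
            >>> # into the total loss; other keys are only logged.
            >>> losses = dict(
            ...     loss_cls=torch.tensor(1.0),
            ...     loss_bbox=[torch.tensor(0.2), torch.tensor(0.3)])
            >>> loss, log_vars = model.parse_losses(losses)
            >>> loss
            tensor(1.5000)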
"""
|
|
log_vars = OrderedDict()
|
|
for loss_name, loss_value in losses.items():
|
|
if isinstance(loss_value, torch.Tensor):
|
|
log_vars[loss_name] = loss_value.mean()
|
|
elif is_list_of(loss_value, torch.Tensor):
|
|
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
|
|
else:
|
|
raise TypeError(
|
|
f'{loss_name} is not a tensor or list of tensors')
|
|
|
|
loss = sum(value for key, value in log_vars.items() if 'loss' in key)
|
|
log_vars['loss'] = loss
|
|
|
|
return loss, log_vars
|
|
|
|

    def to(self,
           device: Optional[Union[int, str, torch.device]] = None,
           *args,
           **kwargs) -> nn.Module:
        """Overrides this method to call :meth:`BaseDataPreprocessor.to`
        additionally.

        Args:
            device (int, str or torch.device, optional): The desired device
                of the parameters and buffers in this module.

        Returns:
            nn.Module: The model itself.
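
        Examples:
            >>> # A minimal sketch; the ``data_preprocessor`` follows the
            >>> # model's device (``device`` is assumed to be the property
            >>> # exposed by ``BaseDataPreprocessor``). ``cuda()`` and
            >>> # ``cpu()`` below behave analogously.
            >>> model = model.to('cpu')
            >>> model.data_preprocessor.device
            device(type='cpu')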
"""
|
|
if device is not None:
|
|
self._set_device(torch.device(device))
|
|
return super().to(device)
|
|
|
|

    def cuda(
        self,
        device: Optional[Union[int, str, torch.device]] = None,
    ) -> nn.Module:
        """Overrides this method to call :meth:`BaseDataPreprocessor.cuda`
        additionally.

        Returns:
            nn.Module: The model itself.
        """
        if device is None or isinstance(device, int):
            device = torch.device('cuda', index=device)
        self._set_device(torch.device(device))
        return super().cuda(device)

    def cpu(self, *args, **kwargs) -> nn.Module:
        """Overrides this method to call :meth:`BaseDataPreprocessor.cpu`
        additionally.

        Returns:
            nn.Module: The model itself.
        """
        self._set_device(torch.device('cpu'))
        return super().cpu()

    def _set_device(self, device: torch.device) -> None:
        """Recursively set the device for ``BaseDataPreprocessor`` instances.

        Args:
            device (torch.device): The desired device of the parameters and
                buffers in this module.
        """

        def apply_fn(module):
            if not isinstance(module, BaseDataPreprocessor):
                return
            if device is not None:
                module._device = device

        self.apply(apply_fn)

    @abstractmethod
    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[list] = None,
                mode: str = 'tensor') -> Union[Dict[str, torch.Tensor], list]:
        """Returns losses or predictions of the training, validation,
        testing, and simple inference process.

        The ``forward`` method of BaseModel is an abstract method; its
        subclasses must implement it.

        Accepts ``inputs`` and ``data_samples`` processed by
        :attr:`data_preprocessor`, and returns results according to the
        ``mode`` argument.

        During the non-distributed training, validation, and testing process,
        ``forward`` will be called by ``BaseModel.train_step``,
        ``BaseModel.val_step`` and ``BaseModel.test_step`` directly.

        During the distributed data parallel training process,
        ``MMSeparateDistributedDataParallel.train_step`` will first call
        ``DistributedDataParallel.forward`` to enable automatic
        gradient synchronization, and then call ``forward`` to get the
        training loss.

        Args:
            inputs (torch.Tensor): Batch input tensor collated by
                :attr:`data_preprocessor`.
            data_samples (list, optional): Data samples collated by
                :attr:`data_preprocessor`.
            mode (str): Should be one of ``loss``, ``predict`` and
                ``tensor``.

                - ``loss``: Called by ``train_step``; returns a loss ``dict``
                  used for logging.
                - ``predict``: Called by ``val_step`` and ``test_step``;
                  returns a list of results used for computing metrics.
                - ``tensor``: Called for custom use to get ``Tensor``-type
                  results.

        Returns:
            dict or list:
                - If ``mode == loss``, returns a ``dict`` of loss tensors
                  used for backward and logging.
                - If ``mode == predict``, returns a ``list`` of inference
                  results.
                - If ``mode == tensor``, returns a tensor, ``tuple`` of
                  tensors or ``dict`` of tensors for custom use.
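
        Examples:
            >>> # A minimal sketch of calling each mode directly, assuming
            >>> # the ``ToyModel`` from the class docstring.
            >>> model = ToyModel()
            >>> inputs = torch.rand(2, 3, 32, 32)
            >>> feats = model(inputs, mode='tensor')
            >>> predictions = model(inputs, mode='predict')
            >>> losses = model(
            ...     inputs, [torch.tensor(0), torch.tensor(1)], mode='loss')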
"""
|
|
|
|

    def _run_forward(self, data: Union[dict, tuple, list],
                     mode: str) -> Union[Dict[str, torch.Tensor], list]:
        """Unpacks the data for :meth:`forward`.

        Args:
            data (dict or tuple or list): Data sampled from the dataset.
            mode (str): Mode of ``forward``.

        Returns:
            dict or list: Results of the training or testing mode.
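
        Examples:
            >>> # A minimal sketch: dict data is unpacked as keyword
            >>> # arguments, while list or tuple data is unpacked as
            >>> # positional arguments (assuming the ``ToyModel`` above).
            >>> out = model._run_forward(
            ...     dict(batch_inputs=torch.rand(2, 3, 32, 32),
            ...          data_samples=None),
            ...     mode='tensor')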
"""
|
|
if isinstance(data, dict):
|
|
results = self(**data, mode=mode)
|
|
elif isinstance(data, (list, tuple)):
|
|
results = self(*data, mode=mode)
|
|
else:
|
|
raise TypeError('Output of `data_preprocessor` should be '
|
|
f'list, tuple or dict, but got {type(data)}')
|
|
return results
|