[Refactor] Refactor the data flow. (#989)

* [Refactor] Refactor the data flow.

* Add comments about data preprocessor.

* Fix after mmengine folder structure refactoring.
Ma Zerun 2022-08-26 10:40:43 +08:00 committed by GitHub
parent b4e39d51d6
commit 2b88df4484
47 changed files with 300 additions and 306 deletions

View File

@ -3,8 +3,7 @@ import warnings
import torch
from mmengine.config import Config
from mmengine.data import pseudo_collate
from mmengine.dataset import Compose
from mmengine.dataset import Compose, pseudo_collate
from mmengine.runner import load_checkpoint
from mmcls.models import build_classifier
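After the mmengine folder-structure refactoring, `pseudo_collate` lives next to `Compose` under `mmengine.dataset`. For scripts that still have to run against an older mmengine, a guarded import is one option; a minimal sketch (the fallback branch only matters for pre-refactor versions):

try:
    # New layout: both helpers come from `mmengine.dataset`.
    from mmengine.dataset import Compose, pseudo_collate
except ImportError:
    # Old layout kept `pseudo_collate` under `mmengine.data`.
    from mmengine.dataset import Compose
    from mmengine.data import pseudo_collate

# `pseudo_collate` keeps each field as a list instead of stacking tensors,
# which matches the list-of-tensors branch of the new data preprocessor.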

View File

@ -52,12 +52,12 @@ class PackClsInputs(BaseTransform):
**Added Keys:**
- inputs (:obj:`torch.Tensor`): The forward data of models.
- data_sample (:obj:`~mmcls.structures.ClsDataSample`): The annotation info
of the sample.
- data_samples (:obj:`~mmcls.structures.ClsDataSample`): The annotation
info of the sample.
Args:
meta_keys (Sequence[str]): The meta keys to be saved in the
``metainfo`` of the packed ``data_sample``.
``metainfo`` of the packed ``data_samples``.
Defaults to a tuple including the keys:
- ``sample_idx``: The id of the image sample.
@ -99,7 +99,7 @@ class PackClsInputs(BaseTransform):
img_meta = {k: results[k] for k in self.meta_keys if k in results}
data_sample.set_metainfo(img_meta)
packed_results['data_sample'] = data_sample
packed_results['data_samples'] = data_sample
return packed_results
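Only the output key changes here; the packed structure is otherwise the same. A quick sanity check of the renamed key, mirroring the updated unit test below (the import path and meta values are illustrative assumptions):

import numpy as np
from mmcls.datasets.transforms import PackClsInputs  # assumed import path

results = {
    'img': np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8),
    'gt_label': 1,
    'sample_idx': 0,
    'img_shape': (224, 224),
    'ori_shape': (224, 224),
}
packed = PackClsInputs()(results)
assert 'inputs' in packed         # torch.Tensor with shape (3, 224, 224)
assert 'data_samples' in packed   # ClsDataSample, previously stored as 'data_sample'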

View File

@ -54,21 +54,22 @@ class VisualizationHook(Hook):
def _draw_samples(self,
batch_idx: int,
data_batch: Sequence[dict],
outputs: Sequence[ClsDataSample],
data_batch: dict,
data_samples: Sequence[ClsDataSample],
step: int = 0) -> None:
"""Visualize every ``self.interval`` samples from a data batch.
Args:
batch_idx (int): The index of the current batch in the val loop.
data_batch (Sequence[dict]): Data from dataloader.
outputs (Sequence[:obj:`DetDataSample`]): Outputs from model.
data_batch (dict): Data from dataloader.
data_samples (Sequence[:obj:`ClsDataSample`]): Outputs from model.
step (int): Global step value to record. Defaults to 0.
"""
if self.enable is False:
return
batch_size = len(outputs)
batch_size = len(data_samples)
images = data_batch['inputs']
start_idx = batch_size * batch_idx
end_idx = start_idx + batch_size
@ -76,10 +77,10 @@ class VisualizationHook(Hook):
first_sample_id = math.ceil(start_idx / self.interval) * self.interval
for sample_id in range(first_sample_id, end_idx, self.interval):
image = data_batch[sample_id - start_idx]['inputs']
image = images[sample_id - start_idx]
image = image.permute(1, 2, 0).numpy().astype('uint8')
data_sample = outputs[sample_id - start_idx]
data_sample = data_samples[sample_id - start_idx]
if 'img_path' in data_sample:
# osp.basename works on different platforms, even with file clients.
sample_name = osp.basename(data_sample.get('img_path'))
@ -99,15 +100,14 @@ class VisualizationHook(Hook):
**self.draw_args,
)
def after_val_iter(self, runner: Runner, batch_idx: int,
data_batch: Sequence[dict],
def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
outputs: Sequence[ClsDataSample]) -> None:
"""Visualize every ``self.interval`` samples during validation.
Args:
runner (:obj:`Runner`): The runner of the validation process.
batch_idx (int): The index of the current batch in the val loop.
data_batch (Sequence[dict]): Data from dataloader.
data_batch (dict): Data from dataloader.
outputs (Sequence[:obj:`ClsDataSample`]): Outputs from model.
"""
if isinstance(runner.train_loop, EpochBasedTrainLoop):
@ -117,15 +117,14 @@ class VisualizationHook(Hook):
self._draw_samples(batch_idx, data_batch, outputs, step=step)
def after_test_iter(self, runner: Runner, batch_idx: int,
data_batch: Sequence[dict],
def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
outputs: Sequence[ClsDataSample]) -> None:
"""Visualize every ``self.interval`` samples during test.
Args:
runner (:obj:`Runner`): The runner of the testing process.
batch_idx (int): The index of the current batch in the test loop.
data_batch (Sequence[dict]): Data from dataloader.
data_batch (dict): Data from dataloader.
outputs (Sequence[:obj:`ClsDataSample`]): Outputs from model.
"""
self._draw_samples(batch_idx, data_batch, outputs, step=0)
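The hook now receives one batched dict per iteration instead of a list of per-sample dicts, so images are indexed out of a single tensor. A minimal sketch of the new layout (shapes are illustrative):

import torch
from mmcls.structures import ClsDataSample

data_batch = {'inputs': torch.randint(0, 256, (10, 3, 224, 224))}  # whole batch in one tensor
outputs = [ClsDataSample().set_pred_label(2) for _ in range(10)]   # one prediction per image

images = data_batch['inputs']
image = images[0].permute(1, 2, 0).numpy().astype('uint8')  # CHW -> HWC uint8 for the visualizer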

View File

@ -3,8 +3,9 @@ from typing import List, Optional, Sequence, Union
import numpy as np
import torch
from mmengine import LabelData, MMLogger
from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger
from mmengine.structures import LabelData
from mmcls.registry import METRICS
from .single_label import _precision_recall_f1_support, to_tensor
@ -95,15 +96,11 @@ class MultiLabelMetric(BaseMetric):
>>> # ------------------- Use with Evaluator -------------------
>>> from mmcls.structures import ClsDataSample
>>> from mmengine.evaluator import Evaluator
>>> # The `data_batch` won't be used in this case, just use a fake.
>>> data_batch = [
... {'inputs': None, 'data_sample': ClsDataSample()}
... for i in range(1000)]
>>> pred = [
... ClsDataSample().set_pred_score(torch.rand((5, ))).set_gt_score(torch.randint(2, size=(5, )))
... for i in range(1000)]
>>> evaluator = Evaluator(metrics=MultiLabelMetric(thrs=0.5))
>>> evaluator.process(data_batch, pred)
>>> data_samples = [
... ClsDataSample().set_pred_score(pred).set_gt_score(gt)
... for pred, gt in zip(torch.rand(1000, 5), torch.randint(0, 2, (1000, 5)))]
>>> evaluator = Evaluator(metrics=MultiLabelMetric(thr=0.5))
>>> evaluator.process(data_samples)
>>> evaluator.evaluate(1000)
{
'multi-label/precision': 50.72898037055408,
@ -112,26 +109,13 @@ class MultiLabelMetric(BaseMetric):
}
>>> # Evaluate on each class by using topk strategy
>>> evaluator = Evaluator(metrics=MultiLabelMetric(topk=1, average=None))
>>> evaluator.process(data_batch, pred)
>>> evaluator.process(data_samples)
>>> evaluator.evaluate(1000)
{
'multi-label/precision_top1_classwise': [48.22, 50.54, 50.99, 44.18, 52.5],
'multi-label/recall_top1_classwise': [18.92, 19.22, 19.92, 20.0, 20.27],
'multi-label/f1-score_top1_classwise': [27.18, 27.85, 28.65, 27.54, 29.25]
}
>>> # Evaluate by label data got from head
>>> pred = [
... ClsDataSample().set_pred_score(torch.rand((5, ))).set_pred_label(
... torch.randint(2, size=(5, ))).set_gt_score(torch.randint(2, size=(5, )))
... for i in range(1000)]
>>> evaluator = Evaluator(metrics=MultiLabelMetric())
>>> evaluator.process(data_batch, pred)
>>> evaluator.evaluate(1000)
{
'multi-label/precision': 20.28921606216292,
'multi-label/recall': 38.628095855722314,
'multi-label/f1-score': 26.603530359627918
}
""" # noqa: E501
default_prefix: Optional[str] = 'multi-label'
@ -165,20 +149,20 @@ class MultiLabelMetric(BaseMetric):
super().__init__(collect_device=collect_device, prefix=prefix)
def process(self, data_batch: Sequence[dict], predictions: Sequence[dict]):
"""Process one batch of data and predictions.
def process(self, data_batch, data_samples: Sequence[dict]):
"""Process one batch of data samples.
The processed results should be stored in ``self.results``, which will
be used to compute the metrics when all batches have been processed.
Args:
data_batch (Sequence[dict]): A batch of data from the dataloader.
predictions (Sequence[dict]): A batch of outputs from the model.
data_batch: A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
for pred in predictions:
for data_sample in data_samples:
result = dict()
pred_label = pred['pred_label']
gt_label = pred['gt_label']
pred_label = data_sample['pred_label']
gt_label = data_sample['gt_label']
result['pred_score'] = pred_label['score'].clone()
num_classes = result['pred_score'].size()[-1]
@ -459,21 +443,17 @@ class AveragePrecision(BaseMetric):
>>> # ------------------- Use with Evaluator -------------------
>>> from mmcls.structures import ClsDataSample
>>> from mmengine.evaluator import Evaluator
>>> # The `data_batch` won't be used in this case, just use a fake.
>>> data_batch = [
... {'inputs': None, 'data_sample': ClsDataSample()}
... for i in range(4)]
>>> pred = [
>>> data_samples = [
... ClsDataSample().set_pred_score(i).set_gt_score(j)
... for i, j in zip(y_pred, y_true)
... ]
>>> evaluator = Evaluator(metrics=AveragePrecision())
>>> evaluator.process(data_batch, pred)
>>> evaluator.process(data_samples)
>>> evaluator.evaluate(5)
{'multi-label/mAP': 70.83333587646484}
>>> # Evaluate on each class
>>> evaluator = Evaluator(metrics=AveragePrecision(average=None))
>>> evaluator.process(data_batch, pred)
>>> evaluator.process(data_samples)
>>> evaluator.evaluate(5)
{'multi-label/AP_classwise': [100., 83.33, 100., 0.]}
"""
@ -486,21 +466,21 @@ class AveragePrecision(BaseMetric):
super().__init__(collect_device=collect_device, prefix=prefix)
self.average = average
def process(self, data_batch: Sequence[dict], predictions: Sequence[dict]):
"""Process one batch of data and predictions.
def process(self, data_batch, data_samples: Sequence[dict]):
"""Process one batch of data samples.
The processed results should be stored in ``self.results``, which will
be used to compute the metrics when all batches have been processed.
Args:
data_batch (Sequence[dict]): A batch of data from the dataloader.
predictions (Sequence[dict]): A batch of outputs from the model.
data_batch: A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
for pred in predictions:
for data_sample in data_samples:
result = dict()
pred_label = pred['pred_label']
gt_label = pred['gt_label']
pred_label = data_sample['pred_label']
gt_label = data_sample['gt_label']
result['pred_score'] = pred_label['score']
num_classes = result['pred_score'].size()[-1]

View File

@ -88,16 +88,12 @@ class Accuracy(BaseMetric):
>>> # ------------------- Use with Evaluator -------------------
>>> from mmcls.structures import ClsDataSample
>>> from mmengine.evaluator import Evaluator
>>> data_batch = [{
... 'inputs': None, # In this example, the `inputs` is not used.
... 'data_sample': ClsDataSample().set_gt_label(0)
... } for i in range(1000)]
>>> pred = [
... ClsDataSample().set_pred_score(torch.rand(10))
>>> data_samples = [
... ClsDataSample().set_gt_label(0).set_pred_score(torch.rand(10))
... for i in range(1000)
... ]
>>> evaluator = Evaluator(metrics=Accuracy(topk=(1, 5)))
>>> evaluator.process(data_batch, pred)
>>> evaluator.process(data_samples)
>>> evaluator.evaluate(1000)
{
'accuracy/top1': 9.300000190734863,
@ -123,22 +119,21 @@ class Accuracy(BaseMetric):
else:
self.thrs = tuple(thrs)
def process(self, data_batch: Sequence[dict], predictions: Sequence[dict]):
"""Process one batch of data and predictions.
def process(self, data_batch, data_samples: Sequence[dict]):
"""Process one batch of data samples.
The processed results should be stored in ``self.results``, which will
be used to compute the metrics when all batches have been processed.
Args:
data_batch (Sequence[dict]): A batch of data from the dataloader.
predictions (Sequence[dict]): A batch of outputs from the model.
data_batch: A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
for data, pred in zip(data_batch, predictions):
for data_sample in data_samples:
result = dict()
pred_label = pred['pred_label']
# Use gt_label in the pred dict preferentially.
gt_label = pred.get('gt_label', data['data_sample']['gt_label'])
pred_label = data_sample['pred_label']
gt_label = data_sample['gt_label']
if 'score' in pred_label:
result['pred_score'] = pred_label['score'].cpu()
else:
@ -317,45 +312,30 @@ class SingleLabelMetric(BaseMetric):
>>> y_true = [0, 2, 1, 3]
>>> # Output precision, recall, f1-score and support.
>>> SingleLabelMetric.calculate(y_pred, y_true, num_classes=4)
(tensor(62.5000, dtype=torch.float64),
tensor(75., dtype=torch.float64),
tensor(66.6667, dtype=torch.float64),
tensor(4))
(tensor(62.5000), tensor(75.), tensor(66.6667), tensor(4))
>>> # Calculate with different thresholds.
>>> y_score = torch.rand((1000, 10))
>>> y_true = torch.zeros((1000, ))
>>> SingleLabelMetric.calculate(y_score, y_true, thrs=(0., 0.9))
[(tensor(10., dtype=torch.float64),
tensor(1.2100, dtype=torch.float64),
tensor(2.1588, dtype=torch.float64),
tensor(1000)),
(tensor(10., dtype=torch.float64),
tensor(0.8200, dtype=torch.float64),
tensor(1.5157, dtype=torch.float64),
tensor(1000))]
[(tensor(10.), tensor(0.9500), tensor(1.7352), tensor(1000)),
(tensor(10.), tensor(0.5500), tensor(1.0427), tensor(1000))]
>>>
>>> # ------------------- Use with Evaluator -------------------
>>> from mmcls.structures import ClsDataSample
>>> from mmengine.evaluator import Evaluator
>>> data_batch = [{
... 'inputs': None, # In this example, the `inputs` is not used.
... 'data_sample': ClsDataSample().set_gt_label(i%5)
... } for i in range(1000)]
>>> pred = [
... ClsDataSample().set_pred_score(torch.rand(5))
>>> data_samples = [
... ClsDataSample().set_gt_label(i%5).set_pred_score(torch.rand(5))
... for i in range(1000)
... ]
>>> evaluator = Evaluator(metrics=SingleLabelMetric())
>>> evaluator.process(data_batch, pred)
>>> evaluator.process(data_samples)
>>> evaluator.evaluate(1000)
{
'single-label/precision': 10.0,
'single-label/recall': 0.96,
'single-label/f1-score': 1.7518248175182483
}
{'single-label/precision': 19.650691986083984,
'single-label/recall': 19.600000381469727,
'single-label/f1-score': 19.619548797607422}
>>> # Evaluate on each class
>>> evaluator = Evaluator(metrics=SingleLabelMetric(average=None))
>>> evaluator.process(data_batch, pred)
>>> evaluator.process(data_samples)
>>> evaluator.evaluate(1000)
{
'single-label/precision_classwise': [21.1, 18.7, 17.8, 19.4, 16.1],
@ -386,31 +366,28 @@ class SingleLabelMetric(BaseMetric):
self.items = tuple(items)
self.average = average
def process(self, data_batch: Sequence[dict], predictions: Sequence[dict]):
"""Process one batch of data and predictions.
def process(self, data_batch, data_samples: Sequence[dict]):
"""Process one batch of data samples.
The processed results should be stored in ``self.results``, which will
be used to compute the metrics when all batches have been processed.
Args:
data_batch (Sequence[dict]): A batch of data from the dataloader.
predictions (Sequence[dict]): A batch of outputs from the model.
data_batch: A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
for data, pred in zip(data_batch, predictions):
for data_sample in data_samples:
result = dict()
pred_label = pred['pred_label']
# Use gt_label in the pred dict preferentially.
gt_label = pred.get('gt_label', data['data_sample']['gt_label'])
pred_label = data_sample['pred_label']
gt_label = data_sample['gt_label']
if 'score' in pred_label:
result['pred_score'] = pred_label['score'].cpu()
elif ('num_classes' in pred_label
or 'num_classes' in data['data_sample']):
elif ('num_classes' in pred_label):
result['pred_label'] = pred_label['label'].cpu()
result['num_classes'] = pred_label.get(
'num_classes', None) or data['data_sample']['num_classes']
result['num_classes'] = pred_label['num_classes']
else:
raise ValueError('The `pred_label` in predictions do not '
raise ValueError('The `pred_label` in data_samples has neither '
'`score` nor `num_classes`.')
result['gt_label'] = gt_label['label'].cpu()
# Save the result to `self.results`.
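Because the ground truth now travels inside each data sample, the metrics no longer read anything from `data_batch`, so a placeholder can be passed when driving a metric by hand (the updated tests below do exactly that). A minimal sketch, assuming the modules are registered via `register_all_modules`:

import torch
from mmcls.registry import METRICS
from mmcls.structures import ClsDataSample
from mmcls.utils import register_all_modules

register_all_modules()

data_samples = [
    ClsDataSample().set_gt_label(i % 5).set_pred_score(torch.rand(5)).to_dict()
    for i in range(100)
]
metric = METRICS.build(dict(type='Accuracy', topk=(1, )))
metric.process(None, data_samples)  # `data_batch` is unused under the new flow
print(metric.evaluate(100))         # e.g. {'accuracy/top1': ...}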

View File

@ -8,7 +8,7 @@ from mmcv.cnn import build_activation_layer, build_norm_layer
from mmcv.cnn.bricks.drop import DropPath
from mmcv.cnn.bricks.transformer import AdaptivePadding
from mmengine.model import BaseModule
from mmengine.model.utils import trunc_normal_
from mmengine.model.weight_init import trunc_normal_
from mmcls.registry import MODELS
from .base_backbone import BaseBackbone
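The weight-initialization helpers moved from `mmengine.model.utils` to `mmengine.model.weight_init`; their behaviour is unchanged. A minimal usage sketch of the relocated helper:

import torch
from mmengine.model.weight_init import trunc_normal_

weight = torch.empty(768, 384)
trunc_normal_(weight, std=0.02)  # fills `weight` in place from a truncated normal distribution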

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmengine.model.utils import trunc_normal_
from mmengine.model.weight_init import trunc_normal_
from mmcls.registry import MODELS
from .vision_transformer import VisionTransformer

View File

@ -4,7 +4,7 @@ import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
from mmengine.model import BaseModule, Sequential
from mmengine.utils.parrots_wrapper import _BatchNorm
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmcls.registry import MODELS
from ..utils.se_layer import SELayer

View File

@ -6,8 +6,8 @@ from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
build_norm_layer)
from mmcv.cnn.bricks import DropPath
from mmengine.model import BaseModule
from mmengine.model.utils import constant_init
from mmengine.utils.parrots_wrapper import _BatchNorm
from mmengine.model.weight_init import constant_init
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmcls.registry import MODELS
from .base_backbone import BaseBackbone

View File

@ -4,7 +4,7 @@ import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule, build_activation_layer
from mmengine.model import BaseModule
from mmengine.model.utils import constant_init, normal_init
from mmengine.model.weight_init import constant_init, normal_init
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.models.utils import channel_shuffle, make_divisible

View File

@ -4,7 +4,7 @@ import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from mmengine.model.utils import constant_init, normal_init
from mmengine.model.weight_init import constant_init, normal_init
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.models.utils import channel_shuffle

View File

@ -9,8 +9,8 @@ import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed, PatchMerging
from mmengine.model import BaseModule, ModuleList
from mmengine.model.utils import trunc_normal_
from mmengine.utils.parrots_wrapper import _BatchNorm
from mmengine.model.weight_init import trunc_normal_
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmcls.registry import MODELS
from ..utils import (ShiftWindowMSA, resize_pos_embed,

View File

@ -8,7 +8,7 @@ import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN
from mmengine.model import BaseModule, ModuleList
from mmengine.model.utils import trunc_normal_
from mmengine.model.weight_init import trunc_normal_
from mmcls.registry import MODELS
from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple

View File

@ -6,7 +6,7 @@ import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmengine.model import BaseModule, ModuleList
from mmengine.model.utils import trunc_normal_
from mmengine.model.weight_init import trunc_normal_
from mmcls.registry import MODELS
from ..utils import to_2tuple

View File

@ -8,7 +8,8 @@ from mmcv.cnn import Conv2d, build_norm_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmengine.model import BaseModule, ModuleList
from mmengine.model.utils import constant_init, normal_init, trunc_normal_init
from mmengine.model.weight_init import (constant_init, normal_init,
trunc_normal_init)
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.models.utils.attention import MultiheadAttention

View File

@ -5,7 +5,7 @@ from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
from mmcv.cnn.bricks import DropPath
from mmcv.cnn.bricks.transformer import PatchEmbed
from mmengine.model import BaseModule, ModuleList
from mmengine.utils.parrots_wrapper import _BatchNorm
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmcls.registry import MODELS
from .base_backbone import BaseBackbone

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.utils.parrots_wrapper import _BatchNorm
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmcls.registry import MODELS
from .base_backbone import BaseBackbone

View File

@ -8,7 +8,7 @@ from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmengine.model import BaseModule, ModuleList
from mmengine.model.utils import trunc_normal_
from mmengine.model.weight_init import trunc_normal_
from mmcls.registry import MODELS
from ..utils import (BEiTAttention, MultiheadAttention, resize_pos_embed,

View File

@ -3,8 +3,8 @@ from abc import ABCMeta, abstractmethod
from typing import List, Optional, Sequence
import torch
from mmengine import BaseDataElement
from mmengine.model import BaseModel
from mmengine.structures import BaseDataElement
class BaseClassifier(BaseModel, metaclass=ABCMeta):
@ -43,7 +43,7 @@ class BaseClassifier(BaseModel, metaclass=ABCMeta):
@abstractmethod
def forward(self,
batch_inputs: torch.Tensor,
inputs: torch.Tensor,
data_samples: Optional[List[BaseDataElement]] = None,
mode: str = 'tensor'):
"""The unified entry for a forward process in both training and test.
@ -61,8 +61,8 @@ class BaseClassifier(BaseModel, metaclass=ABCMeta):
optimizer updating, which are done in the :meth:`train_step`.
Args:
batch_inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
inputs (torch.Tensor): The input tensor with shape (N, C, ...)
in general.
data_samples (List[BaseDataElement], optional): The annotation
data of every sample. It's required if ``mode="loss"``.
Defaults to None.
@ -78,34 +78,31 @@ class BaseClassifier(BaseModel, metaclass=ABCMeta):
"""
pass
def extract_feat(self, batch_inputs: torch.Tensor):
def extract_feat(self, inputs: torch.Tensor):
"""Extract features from the input tensor with shape (N, C, ...).
The sub-classes are recommended to implement this method to extract
features from backbone and neck.
Args:
batch_inputs (Tensor): A batch of inputs. The shape of it should be
inputs (Tensor): A batch of inputs. The shape of it should be
``(num_samples, num_channels, *img_shape)``.
"""
raise NotImplementedError
def extract_feats(self, multi_batch_inputs: Sequence[torch.Tensor],
def extract_feats(self, multi_inputs: Sequence[torch.Tensor],
**kwargs) -> list:
"""Extract features from a sequence of input tensor.
Args:
multi_batch_inputs (Sequence[torch.Tensor]): A sequence of input
multi_inputs (Sequence[torch.Tensor]): A sequence of input
tensor. It can be used in augmented inference.
**kwargs: Other keyword arguments accepted by :meth:`extract_feat`.
Returns:
list: Features of every input tensor.
"""
assert isinstance(multi_batch_inputs, Sequence), \
assert isinstance(multi_inputs, Sequence), \
'`extract_feats` is used for a sequence of input tensors. If you '\
'want to extract features from a single tensor, use `extract_feat`.'
return [
self.extract_feat(batch_input, **kwargs)
for batch_input in multi_batch_inputs
]
return [self.extract_feat(inputs, **kwargs) for inputs in multi_inputs]

View File

@ -70,7 +70,7 @@ class ImageClassifier(BaseClassifier):
self.head = MODELS.build(head)
def forward(self,
batch_inputs: torch.Tensor,
inputs: torch.Tensor,
data_samples: Optional[List[ClsDataSample]] = None,
mode: str = 'tensor'):
"""The unified entry for a forward process in both training and test.
@ -88,7 +88,7 @@ class ImageClassifier(BaseClassifier):
optimizer updating, which are done in the :meth:`train_step`.
Args:
batch_inputs (torch.Tensor): The input tensor with shape
inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
data_samples (List[ClsDataSample], optional): The annotation
data of every sample. It's required if ``mode="loss"``.
@ -104,20 +104,20 @@ class ImageClassifier(BaseClassifier):
- If ``mode="loss"``, return a dict of tensor.
"""
if mode == 'tensor':
feats = self.extract_feat(batch_inputs)
feats = self.extract_feat(inputs)
return self.head(feats) if self.with_head else feats
elif mode == 'loss':
return self.loss(batch_inputs, data_samples)
return self.loss(inputs, data_samples)
elif mode == 'predict':
return self.predict(batch_inputs, data_samples)
return self.predict(inputs, data_samples)
else:
raise RuntimeError(f'Invalid mode "{mode}".')
def extract_feat(self, batch_inputs, stage='neck'):
def extract_feat(self, inputs, stage='neck'):
"""Extract features from the input tensor with shape (N, C, ...).
Args:
batch_inputs (Tensor): A batch of inputs. The shape of it should be
inputs (Tensor): A batch of inputs. The shape of it should be
``(num_samples, num_channels, *img_shape)``.
stage (str): Which stage to output the feature. Choose from:
@ -189,7 +189,7 @@ class ImageClassifier(BaseClassifier):
(f'Invalid output stage "{stage}", please choose from "backbone", '
'"neck" and "pre_logits"')
x = self.backbone(batch_inputs)
x = self.backbone(inputs)
if stage == 'backbone':
return x
@ -203,12 +203,12 @@ class ImageClassifier(BaseClassifier):
"No head or the head doesn't implement `pre_logits` method."
return self.head.pre_logits(x)
def loss(self, batch_inputs: torch.Tensor,
def loss(self, inputs: torch.Tensor,
data_samples: List[ClsDataSample]) -> dict:
"""Calculate losses from a batch of inputs and data samples.
Args:
batch_inputs (torch.Tensor): The input tensor with shape
inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
data_samples (List[ClsDataSample]): The annotation data of
every sample.
@ -216,21 +216,21 @@ class ImageClassifier(BaseClassifier):
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
feats = self.extract_feat(batch_inputs)
feats = self.extract_feat(inputs)
return self.head.loss(feats, data_samples)
def predict(self,
batch_inputs: tuple,
inputs: tuple,
data_samples: Optional[List[ClsDataSample]] = None,
**kwargs) -> List[ClsDataSample]:
"""Predict results from the extracted features.
Args:
batch_inputs (tuple): The features extracted from the backbone.
inputs (tuple): The features extracted from the backbone.
data_samples (List[ClsDataSample], optional): The annotation
data of every sample. Defaults to None.
**kwargs: Other keyword arguments accepted by the ``predict``
method of :attr:`head`.
"""
feats = self.extract_feat(batch_inputs)
feats = self.extract_feat(inputs)
return self.head.predict(feats, data_samples, **kwargs)
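With the rename, the classifier's public entry points take `inputs` throughout. A minimal sketch of the three forward modes (the config below is illustrative; a standard ResNet-18 plus linear-head setup is assumed):

import torch
from mmcls.models import ImageClassifier
from mmcls.structures import ClsDataSample

model = ImageClassifier(
    backbone=dict(type='ResNet', depth=18),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(type='LinearClsHead', num_classes=10, in_channels=512,
              loss=dict(type='CrossEntropyLoss')))

inputs = torch.rand(2, 3, 224, 224)
data_samples = [ClsDataSample().set_gt_label(i) for i in range(2)]

logits = model(inputs, mode='tensor')                      # raw head output
losses = model(inputs, data_samples, mode='loss')          # dict of loss tensors
predictions = model(inputs, data_samples, mode='predict')  # list of ClsDataSample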

View File

@ -2,8 +2,8 @@
from abc import ABCMeta, abstractmethod
from typing import List, Optional, Tuple
from mmengine import BaseDataElement
from mmengine.model import BaseModule
from mmengine.structures import BaseDataElement
class BaseHead(BaseModule, metaclass=ABCMeta):

View File

@ -2,7 +2,7 @@
from typing import Dict, List, Optional, Tuple
import torch
from mmengine.data import LabelData
from mmengine.structures import LabelData
from mmcls.registry import MODELS
from mmcls.structures import ClsDataSample

View File

@ -7,7 +7,7 @@ import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer
from mmengine.model import Sequential
from mmengine.model.utils import trunc_normal_
from mmengine.model.weight_init import trunc_normal_
from mmcls.registry import MODELS
from .cls_head import ClsHead

View File

@ -6,7 +6,7 @@ import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks.drop import build_dropout
from mmengine.model import BaseModule
from mmengine.model.utils import trunc_normal_
from mmengine.model.weight_init import trunc_normal_
from mmcls.registry import MODELS
from .helpers import to_2tuple

View File

@ -3,7 +3,7 @@ from typing import List, Optional, Tuple
import numpy as np
import torch
from mmengine.data import LabelData
from mmengine.structures import LabelData
from mmcls.registry import BATCH_AUGMENTS
from mmcls.structures import ClsDataSample

View File

@ -1,8 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from numbers import Number
from typing import List, Optional, Sequence, Tuple
from typing import List, Optional, Sequence
import torch
import torch.nn.functional as F
from mmengine.model import BaseDataPreprocessor, stack_batch
from mmcls.registry import MODELS
@ -74,39 +76,71 @@ class ClsDataPreprocessor(BaseDataPreprocessor):
else:
self.batch_augments = None
def forward(self,
data: Sequence[dict],
training: bool = False) -> Tuple[torch.Tensor, list]:
def forward(self, data: dict, training: bool = False) -> dict:
"""Perform normalization, padding, bgr2rgb conversion and batch
augmentation based on ``BaseDataPreprocessor``.
Args:
data (Sequence[dict]): data sampled from dataloader.
data (dict): data sampled from dataloader.
training (bool): Whether to enable training time augmentation.
Returns:
Tuple[torch.Tensor, list]: Data in the same format as the model
input.
dict: Data in the same format as the model input.
"""
inputs, batch_data_samples = self.collate_data(data)
data = self.cast_data(data)
inputs = data['inputs']
# --- Pad and stack --
batch_inputs = stack_batch(inputs, self.pad_size_divisor,
self.pad_value)
if isinstance(inputs, torch.Tensor):
# This branch is used when `default_collate` is the collate_fn of the
# dataloader.
# ------ To RGB ------
if self.to_rgb and batch_inputs.size(1) == 3:
batch_inputs = batch_inputs[:, [2, 1, 0], ...]
# ------ To RGB ------
if self.to_rgb and inputs.size(1) == 3:
inputs = inputs.flip(1)
# -- Normalization ---
if self._enable_normalize:
batch_inputs = (batch_inputs - self.mean) / self.std
# -- Normalization ---
inputs = inputs.float()
if self._enable_normalize:
inputs = (inputs - self.mean) / self.std
# ------ Padding -----
if self.pad_size_divisor > 1:
h, w = inputs.shape[-2:]
target_h = math.ceil(
h / self.pad_size_divisor) * self.pad_size_divisor
target_w = math.ceil(
w / self.pad_size_divisor) * self.pad_size_divisor
pad_h = target_h - h
pad_w = target_w - w
inputs = F.pad(inputs, (0, pad_w, 0, pad_h), 'constant',
self.pad_value)
else:
batch_inputs = batch_inputs.to(torch.float32)
# This branch is used when `pseudo_collate` is the collate_fn of the
# dataloader.
processed_inputs = []
for input_ in inputs:
# ------ To RGB ------
if self.to_rgb and input_.size(0) == 3:
input_ = input_.flip(0)
# -- Normalization ---
input_ = input_.float()
if self._enable_normalize:
input_ = (input_ - self.mean) / self.std
processed_inputs.append(input_)
# Pad (if needed) and stack into a batched tensor.
inputs = stack_batch(processed_inputs, self.pad_size_divisor,
self.pad_value)
# ----- Batch Aug ----
if training and self.batch_augments is not None:
batch_inputs, batch_data_samples = self.batch_augments(
batch_inputs, batch_data_samples)
data_samples = data['data_samples']
inputs, data_samples = self.batch_augments(inputs, data_samples)
data['data_samples'] = data_samples
return batch_inputs, batch_data_samples
data['inputs'] = inputs
return data
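The forward pass now branches on the collate style: a stacked tensor (from `default_collate`) is converted, normalized and padded directly, while a list of per-sample tensors (from `pseudo_collate`) is normalized per sample and then stacked with `stack_batch`. A minimal sketch exercising both paths, mirroring the updated tests below:

import torch
from mmcls.registry import MODELS
from mmcls.structures import ClsDataSample
from mmcls.utils import register_all_modules

register_all_modules()
processor = MODELS.build(
    dict(type='ClsDataPreprocessor', mean=[127.5] * 3, std=[127.5] * 3, to_rgb=True))

# pseudo_collate style: a list of CHW tensors, stacked inside forward().
data = {
    'inputs': [torch.randint(0, 256, (3, 224, 224)) for _ in range(2)],
    'data_samples': [ClsDataSample().set_gt_label(i) for i in range(2)],
}
out = processor(data)
assert out['inputs'].shape == (2, 3, 224, 224)

# default_collate style: an already-stacked NCHW tensor is handled directly.
data = {'inputs': torch.randint(0, 256, (2, 3, 224, 224))}
out = processor(data)
assert out['inputs'].shape == (2, 3, 224, 224)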

View File

@ -5,7 +5,7 @@ from typing import Sequence, Union
import numpy as np
import torch
from mmengine.data import BaseDataElement, LabelData
from mmengine.structures import BaseDataElement, LabelData
from mmengine.utils import is_str

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.utils import collect_env as collect_base_env
from mmengine.utils import get_git_hash
from mmengine.utils.dl_utils import collect_env as collect_base_env
import mmcls

View File

@ -3,8 +3,8 @@ from typing import Optional, Tuple
import mmcv
import numpy as np
from mmengine import Visualizer
from mmengine.dist import master_only
from mmengine.visualization import Visualizer
from mmcls.registry import VISUALIZERS
from mmcls.structures import ClsDataSample

View File

@ -6,7 +6,7 @@ import unittest
import mmcv
import numpy as np
import torch
from mmengine.data import LabelData
from mmengine.structures import LabelData
from PIL import Image
from mmcls.registry import TRANSFORMS
@ -36,10 +36,10 @@ class TestPackClsInputs(unittest.TestCase):
results = transform(copy.deepcopy(data))
self.assertIn('inputs', results)
self.assertIsInstance(results['inputs'], torch.Tensor)
self.assertIn('data_sample', results)
self.assertIsInstance(results['data_sample'], ClsDataSample)
self.assertIn('flip', results['data_sample'].metainfo_keys())
self.assertIsInstance(results['data_sample'].gt_label, LabelData)
self.assertIn('data_samples', results)
self.assertIsInstance(results['data_samples'], ClsDataSample)
self.assertIn('flip', results['data_samples'].metainfo_keys())
self.assertIsInstance(results['data_samples'].gt_label, LabelData)
# Test grayscale image
data['img'] = data['img'].mean(-1)
@ -53,7 +53,7 @@ class TestPackClsInputs(unittest.TestCase):
del data['gt_label']
with self.assertWarnsRegex(Warning, 'Cannot get "img"'):
results = transform(copy.deepcopy(data))
self.assertNotIn('gt_label', results['data_sample'])
self.assertNotIn('gt_label', results['data_samples'])
def test_repr(self):
cfg = dict(type='PackClsInputs', meta_keys=['flip', 'img_shape'])

View File

@ -23,10 +23,10 @@ class TestVisualizationHook(TestCase):
data_sample = ClsDataSample().set_gt_label(1).set_pred_label(2)
data_sample.set_metainfo({'img_path': 'tests/data/color.jpg'})
self.data_batch = [{
'inputs': torch.randint(0, 256, (3, 224, 224)),
'data_sample': data_sample
}] * 10
self.data_batch = {
'inputs': torch.randint(0, 256, (10, 3, 224, 224)),
'data_sample': [data_sample] * 10
}
self.outputs = [data_sample] * 10

View File

@ -137,11 +137,6 @@ class TestMultiLabel(TestCase):
MultiLabelMetric.calculate(y_pred_binary, 5)
def test_evaluate(self):
fake_data_batch = [{
'inputs': None,
'data_sample': ClsDataSample()
} for _ in range(4)]
y_true = [[0], [1, 3], [0, 1, 2], [3]]
y_true_binary = torch.tensor([
[1, 0, 0, 0],
@ -163,7 +158,7 @@ class TestMultiLabel(TestCase):
# Test with default argument
evaluator = Evaluator(dict(type='MultiLabelMetric'))
evaluator.process(fake_data_batch, pred)
evaluator.process(pred)
res = evaluator.evaluate(4)
self.assertIsInstance(res, dict)
thr05_y_pred = np.array([
@ -184,7 +179,7 @@ class TestMultiLabel(TestCase):
# Test with topk argument
evaluator = Evaluator(dict(type='MultiLabelMetric', topk=1))
evaluator.process(fake_data_batch, pred)
evaluator.process(pred)
res = evaluator.evaluate(4)
self.assertIsInstance(res, dict)
top1_y_pred = np.array([
@ -205,7 +200,7 @@ class TestMultiLabel(TestCase):
# Test with both argument
evaluator = Evaluator(dict(type='MultiLabelMetric', thr=0.25, topk=1))
evaluator.process(fake_data_batch, pred)
evaluator.process(pred)
res = evaluator.evaluate(4)
self.assertIsInstance(res, dict)
# Expected values come from sklearn
@ -228,7 +223,7 @@ class TestMultiLabel(TestCase):
# Test with average micro
evaluator = Evaluator(dict(type='MultiLabelMetric', average='micro'))
evaluator.process(fake_data_batch, pred)
evaluator.process(pred)
res = evaluator.evaluate(4)
self.assertIsInstance(res, dict)
# Expected values come from sklearn
@ -247,7 +242,7 @@ class TestMultiLabel(TestCase):
# Test with average None
evaluator = Evaluator(dict(type='MultiLabelMetric', average=None))
evaluator.process(fake_data_batch, pred)
evaluator.process(pred)
res = evaluator.evaluate(4)
self.assertIsInstance(res, dict)
# Expected values come from sklearn
@ -271,7 +266,7 @@ class TestMultiLabel(TestCase):
]
evaluator = Evaluator(dict(type='MultiLabelMetric', items=['support']))
evaluator.process(fake_data_batch, pred)
evaluator.process(pred)
res = evaluator.evaluate(4)
self.assertIsInstance(res, dict)
self.assertEqual(res['multi-label/support'], 7)
@ -308,11 +303,6 @@ class TestAveragePrecision(TestCase):
[1, 0, 0, 0],
])
fake_data_batch = [{
'inputs': None,
'data_sample': ClsDataSample()
} for _ in range(4)]
pred = [
ClsDataSample(num_classes=4).set_pred_score(i).set_gt_score(j)
for i, j in zip(y_pred, y_true)
@ -320,14 +310,14 @@ class TestAveragePrecision(TestCase):
# Test with default macro average
evaluator = Evaluator(dict(type='AveragePrecision'))
evaluator.process(fake_data_batch, pred)
evaluator.process(pred)
res = evaluator.evaluate(5)
self.assertIsInstance(res, dict)
self.assertAlmostEqual(res['multi-label/mAP'], 70.83333, places=4)
# Test with average mode None
evaluator = Evaluator(dict(type='AveragePrecision', average=None))
evaluator.process(fake_data_batch, pred)
evaluator.process(pred)
res = evaluator.evaluate(5)
self.assertIsInstance(res, dict)
aps = res['multi-label/AP_classwise']
@ -342,7 +332,7 @@ class TestAveragePrecision(TestCase):
for i, j in zip(y_pred, [[0, 1], [1], [2], [0]])
]
evaluator = Evaluator(dict(type='AveragePrecision'))
evaluator.process(fake_data_batch, pred)
evaluator.process(pred)
res = evaluator.evaluate(5)
self.assertAlmostEqual(res['multi-label/mAP'], 70.83333, places=4)

View File

@ -14,31 +14,28 @@ class TestAccuracy(TestCase):
def test_evaluate(self):
"""Test using the metric in the same way as Evalutor."""
data_batch = [{
'data_sample': ClsDataSample().set_gt_label(i).to_dict()
} for i in [0, 0, 1, 2, 1, 0]]
pred = [
ClsDataSample().set_pred_score(i).set_pred_label(j).to_dict()
for i, j in zip([
torch.tensor([0.7, 0.0, 0.3]),
torch.tensor([0.5, 0.2, 0.3]),
torch.tensor([0.4, 0.5, 0.1]),
torch.tensor([0.0, 0.0, 1.0]),
torch.tensor([0.0, 0.0, 1.0]),
torch.tensor([0.0, 0.0, 1.0]),
], [0, 0, 1, 2, 2, 2])
ClsDataSample().set_pred_score(i).set_pred_label(j).set_gt_label(
k).to_dict() for i, j, k in zip([
torch.tensor([0.7, 0.0, 0.3]),
torch.tensor([0.5, 0.2, 0.3]),
torch.tensor([0.4, 0.5, 0.1]),
torch.tensor([0.0, 0.0, 1.0]),
torch.tensor([0.0, 0.0, 1.0]),
torch.tensor([0.0, 0.0, 1.0]),
], [0, 0, 1, 2, 2, 2], [0, 0, 1, 2, 1, 0])
]
# Test with score (use score instead of label if score exists)
metric = METRICS.build(dict(type='Accuracy', thrs=0.6))
metric.process(data_batch, pred)
metric.process(None, pred)
acc = metric.evaluate(6)
self.assertIsInstance(acc, dict)
self.assertAlmostEqual(acc['accuracy/top1'], 2 / 6 * 100, places=4)
# Test with multiple thrs
metric = METRICS.build(dict(type='Accuracy', thrs=(0., 0.6, None)))
metric.process(data_batch, pred)
metric.process(None, pred)
acc = metric.evaluate(6)
self.assertSetEqual(
set(acc.keys()), {
@ -49,14 +46,14 @@ class TestAccuracy(TestCase):
# Test with invalid topk
with self.assertRaisesRegex(ValueError, 'check the `val_evaluator`'):
metric = METRICS.build(dict(type='Accuracy', topk=(1, 5)))
metric.process(data_batch, pred)
metric.process(None, pred)
metric.evaluate(6)
# Test with label
for sample in pred:
del sample['pred_label']['score']
metric = METRICS.build(dict(type='Accuracy', thrs=(0., 0.6, None)))
metric.process(data_batch, pred)
metric.process(None, pred)
acc = metric.evaluate(6)
self.assertIsInstance(acc, dict)
self.assertAlmostEqual(acc['accuracy/top1'], 4 / 6 * 100, places=4)
@ -124,19 +121,16 @@ class TestSingleLabel(TestCase):
def test_evaluate(self):
"""Test using the metric in the same way as Evalutor."""
data_batch = [{
'data_sample': ClsDataSample().set_gt_label(i).to_dict()
} for i in [0, 0, 1, 2, 1, 0]]
pred = [
ClsDataSample().set_pred_score(i).set_pred_label(j).to_dict()
for i, j in zip([
torch.tensor([0.7, 0.0, 0.3]),
torch.tensor([0.5, 0.2, 0.3]),
torch.tensor([0.4, 0.5, 0.1]),
torch.tensor([0.0, 0.0, 1.0]),
torch.tensor([0.0, 0.0, 1.0]),
torch.tensor([0.0, 0.0, 1.0]),
], [0, 0, 1, 2, 2, 2])
ClsDataSample().set_pred_score(i).set_pred_label(j).set_gt_label(
k).to_dict() for i, j, k in zip([
torch.tensor([0.7, 0.0, 0.3]),
torch.tensor([0.5, 0.2, 0.3]),
torch.tensor([0.4, 0.5, 0.1]),
torch.tensor([0.0, 0.0, 1.0]),
torch.tensor([0.0, 0.0, 1.0]),
torch.tensor([0.0, 0.0, 1.0]),
], [0, 0, 1, 2, 2, 2], [0, 0, 1, 2, 1, 0])
]
# Test with score (use score instead of label if score exists)
@ -145,7 +139,7 @@ class TestSingleLabel(TestCase):
type='SingleLabelMetric',
thrs=0.6,
items=('precision', 'recall', 'f1-score', 'support')))
metric.process(data_batch, pred)
metric.process(None, pred)
res = metric.evaluate(6)
self.assertIsInstance(res, dict)
self.assertAlmostEqual(
@ -160,7 +154,7 @@ class TestSingleLabel(TestCase):
# Test with multiple thrs
metric = METRICS.build(
dict(type='SingleLabelMetric', thrs=(0., 0.6, None)))
metric.process(data_batch, pred)
metric.process(None, pred)
res = metric.evaluate(6)
self.assertSetEqual(
set(res.keys()), {
@ -180,7 +174,7 @@ class TestSingleLabel(TestCase):
type='SingleLabelMetric',
average='micro',
items=('precision', 'recall', 'f1-score', 'support')))
metric.process(data_batch, pred)
metric.process(None, pred)
res = metric.evaluate(6)
self.assertIsInstance(res, dict)
self.assertAlmostEqual(
@ -197,7 +191,7 @@ class TestSingleLabel(TestCase):
type='SingleLabelMetric',
average=None,
items=('precision', 'recall', 'f1-score', 'support')))
metric.process(data_batch, pred)
metric.process(None, pred)
res = metric.evaluate(6)
self.assertIsInstance(res, dict)
precision = res['single-label/precision_classwise']
@ -219,7 +213,7 @@ class TestSingleLabel(TestCase):
for sample in pred_no_score:
del sample['pred_label']['score']
metric = METRICS.build(dict(type='SingleLabelMetric', thrs=(0., 0.6)))
metric.process(data_batch, pred_no_score)
metric.process(None, pred_no_score)
res = metric.evaluate(6)
self.assertIsInstance(res, dict)
# Expected values come from sklearn
@ -231,16 +225,16 @@ class TestSingleLabel(TestCase):
for sample in pred_no_num_classes:
del sample['pred_label']['num_classes']
with self.assertRaisesRegex(ValueError, 'neither `score` nor'):
metric.process(data_batch, pred_no_num_classes)
metric.process(None, pred_no_num_classes)
# Test with empty items
metric = METRICS.build(dict(type='SingleLabelMetric', items=tuple()))
metric.process(data_batch, pred)
metric.process(None, pred)
res = metric.evaluate(6)
self.assertIsInstance(res, dict)
self.assertEqual(len(res), 0)
metric.process(data_batch, pred_no_score)
metric.process(None, pred_no_score)
res = metric.evaluate(6)
self.assertIsInstance(res, dict)
self.assertEqual(len(res), 0)

View File

@ -5,7 +5,7 @@ from unittest import TestCase
import torch
from mmcv.cnn import ConvModule
from mmengine.utils.parrots_wrapper import _BatchNorm
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmcls.models.backbones import CSPDarkNet, CSPResNet, CSPResNeXt
from mmcls.models.backbones.cspnet import (CSPNet, DarknetBottleneck,

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmengine.utils.parrots_wrapper import _BatchNorm
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmcls.models.backbones import Res2Net

View File

@ -3,7 +3,7 @@ import pytest
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.utils.parrots_wrapper import _BatchNorm
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmcls.models.backbones import ResNet, ResNetV1c, ResNetV1d
from mmcls.models.backbones.resnet import (BasicBlock, Bottleneck, ResLayer,

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmengine.utils.parrots_wrapper import _BatchNorm
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmcls.models.backbones import ResNet_CIFAR

View File

@ -8,7 +8,7 @@ from unittest import TestCase
import torch
from mmengine.runner import load_checkpoint, save_checkpoint
from mmengine.utils.parrots_wrapper import _BatchNorm
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmcls.models.backbones import SwinTransformer
from mmcls.models.backbones.swin_transformer import SwinBlock

View File

@ -5,7 +5,7 @@ from itertools import chain
from unittest import TestCase
import torch
from mmengine.utils.parrots_wrapper import _BatchNorm
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from torch import nn
from mmcls.models.backbones import VAN

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmengine.utils.parrots_wrapper import _BatchNorm
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmcls.models.backbones import VGG

View File

@ -173,10 +173,10 @@ class TestImageClassifier(TestCase):
}
model: ImageClassifier = MODELS.build(cfg)
data = [{
'inputs': torch.randint(0, 256, (3, 224, 224)),
'data_sample': ClsDataSample().set_gt_label(1)
}]
data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [ClsDataSample().set_gt_label(1)]
}
optim_wrapper = MagicMock()
log_vars = model.train_step(data, optim_wrapper)
@ -190,10 +190,10 @@ class TestImageClassifier(TestCase):
}
model: ImageClassifier = MODELS.build(cfg)
data = [{
'inputs': torch.randint(0, 256, (3, 224, 224)),
'data_sample': ClsDataSample().set_gt_label(1)
}]
data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [ClsDataSample().set_gt_label(1)]
}
predictions = model.val_step(data)
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
@ -205,10 +205,10 @@ class TestImageClassifier(TestCase):
}
model: ImageClassifier = MODELS.build(cfg)
data = [{
'inputs': torch.randint(0, 256, (3, 224, 224)),
'data_sample': ClsDataSample().set_gt_label(1)
}]
data = {
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
'data_samples': [ClsDataSample().set_gt_label(1)]
}
predictions = model.test_step(data)
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))

View File

@ -17,11 +17,13 @@ class TestClsDataPreprocessor(TestCase):
cfg = dict(type='ClsDataPreprocessor')
processor: ClsDataPreprocessor = MODELS.build(cfg)
data = [{
'inputs': torch.randint(0, 256, (3, 224, 224)),
'data_sample': ClsDataSample().set_gt_label(1)
}]
inputs, data_samples = processor(data)
data = {
'inputs': [torch.randint(0, 256, (3, 224, 224))],
'data_samples': [ClsDataSample().set_gt_label(1)]
}
processed_data = processor(data)
inputs = processed_data['inputs']
data_samples = processed_data['data_samples']
self.assertEqual(inputs.shape, (1, 3, 224, 224))
self.assertEqual(len(data_samples), 1)
self.assertTrue(
@ -31,22 +33,31 @@ class TestClsDataPreprocessor(TestCase):
cfg = dict(type='ClsDataPreprocessor', pad_size_divisor=16)
processor: ClsDataPreprocessor = MODELS.build(cfg)
data = [{
'inputs': torch.randint(0, 256, (3, 255, 255))
}, {
'inputs': torch.randint(0, 256, (3, 224, 224))
}]
inputs, _ = processor(data)
data = {
'inputs': [
torch.randint(0, 256, (3, 255, 255)),
torch.randint(0, 256, (3, 224, 224))
]
}
inputs = processor(data)['inputs']
self.assertEqual(inputs.shape, (2, 3, 256, 256))
data = {'inputs': torch.randint(0, 256, (2, 3, 255, 255))}
inputs = processor(data)['inputs']
self.assertEqual(inputs.shape, (2, 3, 256, 256))
def test_to_rgb(self):
cfg = dict(type='ClsDataPreprocessor', to_rgb=True)
processor: ClsDataPreprocessor = MODELS.build(cfg)
data = [{'inputs': torch.randint(0, 256, (3, 224, 224))}]
inputs, _ = processor(data)
torch.testing.assert_allclose(
data[0]['inputs'].flip(dims=(0, )).to(torch.float32), inputs[0])
data = {'inputs': [torch.randint(0, 256, (3, 224, 224))]}
inputs = processor(data)['inputs']
torch.testing.assert_allclose(data['inputs'][0].flip(0).float(),
inputs[0])
data = {'inputs': torch.randint(0, 256, (1, 3, 224, 224))}
inputs = processor(data)['inputs']
torch.testing.assert_allclose(data['inputs'].flip(1).float(), inputs)
def test_normalization(self):
cfg = dict(
@ -55,11 +66,17 @@ class TestClsDataPreprocessor(TestCase):
std=[127.5, 127.5, 127.5])
processor: ClsDataPreprocessor = MODELS.build(cfg)
data = [{'inputs': torch.randint(0, 256, (3, 224, 224))}]
inputs, data_samples = processor(data)
data = {'inputs': [torch.randint(0, 256, (3, 224, 224))]}
processed_data = processor(data)
inputs = processed_data['inputs']
self.assertTrue((inputs >= -1).all())
self.assertTrue((inputs <= 1).all())
self.assertNotIn('data_samples', processed_data)
data = {'inputs': torch.randint(0, 256, (1, 3, 224, 224))}
inputs = processor(data)['inputs']
self.assertTrue((inputs >= -1).all())
self.assertTrue((inputs <= 1).all())
self.assertIsNone(data_samples)
def test_batch_augmentation(self):
cfg = dict(
@ -70,17 +87,18 @@ class TestClsDataPreprocessor(TestCase):
])
processor: ClsDataPreprocessor = MODELS.build(cfg)
self.assertIsInstance(processor.batch_augments, RandomBatchAugment)
data = [{
'inputs': torch.randint(0, 256, (3, 224, 224)),
'data_sample': ClsDataSample().set_gt_label(1)
}]
_, data_samples = processor(data, training=True)
data = {
'inputs': [torch.randint(0, 256, (3, 224, 224))],
'data_samples': [ClsDataSample().set_gt_label(1)]
}
processed_data = processor(data, training=True)
self.assertIn('inputs', processed_data)
self.assertIn('data_samples', processed_data)
cfg['batch_augments'] = None
processor: ClsDataPreprocessor = MODELS.build(cfg)
self.assertIsNone(processor.batch_augments)
data = [{
'inputs': torch.randint(0, 256, (3, 224, 224)),
}]
_, data_samples = processor(data, training=True)
self.assertIsNone(data_samples)
data = {'inputs': [torch.randint(0, 256, (3, 224, 224))]}
processed_data = processor(data, training=True)
self.assertIn('inputs', processed_data)
self.assertNotIn('data_samples', processed_data)

View File

@ -3,7 +3,7 @@ from unittest import TestCase
import numpy as np
import torch
from mmengine.data import LabelData
from mmengine.structures import LabelData
from mmcls.structures import ClsDataSample

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from mmcv import Config, DictAction
from mmengine import Config, DictAction
def parse_args():

View File

@ -4,7 +4,8 @@ import fcntl
import os
from pathlib import Path
from mmcv import Config, DictAction, track_parallel_progress, track_progress
from mmengine import (Config, DictAction, track_parallel_progress,
track_progress)
from mmcls.datasets import PIPELINES, build_dataset

View File

@ -5,6 +5,7 @@ import os.path as osp
import mmengine
from mmengine.config import Config, DictAction
from mmengine.hooks import Hook
from mmengine.runner import Runner
from mmcls.utils import register_all_modules
@ -122,7 +123,7 @@ def main():
if args.out:
class SaveMetricHook(mmengine.Hook):
class SaveMetricHook(Hook):
def after_test_epoch(self, _, metrics=None):
if metrics is not None:

View File

@ -9,8 +9,11 @@ from unittest.mock import MagicMock
import matplotlib.pyplot as plt
import rich
import torch.nn as nn
from mmengine import Config, DictAction, Hook, Runner, Visualizer
from mmengine.config import Config, DictAction
from mmengine.hooks import Hook
from mmengine.model import BaseModel
from mmengine.runner import Runner
from mmengine.visualization import Visualizer
from rich.progress import BarColumn, MofNCompleteColumn, Progress, TextColumn
from mmcls.utils import register_all_modules
@ -24,7 +27,7 @@ class SimpleModel(BaseModel):
self.data_preprocessor = nn.Identity()
self.conv = nn.Conv2d(1, 1, 1)
def forward(self, batch_inputs, data_samples, mode='tensor'):
def forward(self, inputs, data_samples, mode='tensor'):
pass
def train_step(self, data, optim_wrapper):