[Refactor] Refactor the data flow. (#989)
* [Refactor] Refactor the data flow. * Add comments about data preprocessor. * Fix after mmengine folder structure refactoring.pull/987/head
parent
b4e39d51d6
commit
2b88df4484
|
@ -3,8 +3,7 @@ import warnings
|
|||
|
||||
import torch
|
||||
from mmengine.config import Config
|
||||
from mmengine.data import pseudo_collate
|
||||
from mmengine.dataset import Compose
|
||||
from mmengine.dataset import Compose, pseudo_collate
|
||||
from mmengine.runner import load_checkpoint
|
||||
|
||||
from mmcls.models import build_classifier
|
||||
|
|
|
@ -52,12 +52,12 @@ class PackClsInputs(BaseTransform):
|
|||
**Added Keys:**
|
||||
|
||||
- inputs (:obj:`torch.Tensor`): The forward data of models.
|
||||
- data_sample (:obj:`~mmcls.structures.ClsDataSample`): The annotation info
|
||||
of the sample.
|
||||
- data_samples (:obj:`~mmcls.structures.ClsDataSample`): The annotation
|
||||
info of the sample.
|
||||
|
||||
Args:
|
||||
meta_keys (Sequence[str]): The meta keys to be saved in the
|
||||
``metainfo`` of the packed ``data_sample``.
|
||||
``metainfo`` of the packed ``data_samples``.
|
||||
Defaults to a tuple includes keys:
|
||||
|
||||
- ``sample_idx``: The id of the image sample.
|
||||
|
@ -99,7 +99,7 @@ class PackClsInputs(BaseTransform):
|
|||
|
||||
img_meta = {k: results[k] for k in self.meta_keys if k in results}
|
||||
data_sample.set_metainfo(img_meta)
|
||||
packed_results['data_sample'] = data_sample
|
||||
packed_results['data_samples'] = data_sample
|
||||
|
||||
return packed_results
|
||||
|
||||
|
|
|
@ -54,21 +54,22 @@ class VisualizationHook(Hook):
|
|||
|
||||
def _draw_samples(self,
|
||||
batch_idx: int,
|
||||
data_batch: Sequence[dict],
|
||||
outputs: Sequence[ClsDataSample],
|
||||
data_batch: dict,
|
||||
data_samples: Sequence[ClsDataSample],
|
||||
step: int = 0) -> None:
|
||||
"""Visualize every ``self.interval`` samples from a data batch.
|
||||
|
||||
Args:
|
||||
batch_idx (int): The index of the current batch in the val loop.
|
||||
data_batch (Sequence[dict]): Data from dataloader.
|
||||
outputs (Sequence[:obj:`DetDataSample`]): Outputs from model.
|
||||
data_batch (dict): Data from dataloader.
|
||||
outputs (Sequence[:obj:`ClsDataSample`]): Outputs from model.
|
||||
step (int): Global step value to record. Defaults to 0.
|
||||
"""
|
||||
if self.enable is False:
|
||||
return
|
||||
|
||||
batch_size = len(outputs)
|
||||
batch_size = len(data_samples)
|
||||
images = data_batch['inputs']
|
||||
start_idx = batch_size * batch_idx
|
||||
end_idx = start_idx + batch_size
|
||||
|
||||
|
@ -76,10 +77,10 @@ class VisualizationHook(Hook):
|
|||
first_sample_id = math.ceil(start_idx / self.interval) * self.interval
|
||||
|
||||
for sample_id in range(first_sample_id, end_idx, self.interval):
|
||||
image = data_batch[sample_id - start_idx]['inputs']
|
||||
image = images[sample_id - start_idx]
|
||||
image = image.permute(1, 2, 0).numpy().astype('uint8')
|
||||
|
||||
data_sample = outputs[sample_id - start_idx]
|
||||
data_sample = data_samples[sample_id - start_idx]
|
||||
if 'img_path' in data_sample:
|
||||
# osp.basename works on different platforms even file clients.
|
||||
sample_name = osp.basename(data_sample.get('img_path'))
|
||||
|
@ -99,15 +100,14 @@ class VisualizationHook(Hook):
|
|||
**self.draw_args,
|
||||
)
|
||||
|
||||
def after_val_iter(self, runner: Runner, batch_idx: int,
|
||||
data_batch: Sequence[dict],
|
||||
def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
|
||||
outputs: Sequence[ClsDataSample]) -> None:
|
||||
"""Visualize every ``self.interval`` samples during validation.
|
||||
|
||||
Args:
|
||||
runner (:obj:`Runner`): The runner of the validation process.
|
||||
batch_idx (int): The index of the current batch in the val loop.
|
||||
data_batch (Sequence[dict]): Data from dataloader.
|
||||
data_batch (dict): Data from dataloader.
|
||||
outputs (Sequence[:obj:`ClsDataSample`]): Outputs from model.
|
||||
"""
|
||||
if isinstance(runner.train_loop, EpochBasedTrainLoop):
|
||||
|
@ -117,15 +117,14 @@ class VisualizationHook(Hook):
|
|||
|
||||
self._draw_samples(batch_idx, data_batch, outputs, step=step)
|
||||
|
||||
def after_test_iter(self, runner: Runner, batch_idx: int,
|
||||
data_batch: Sequence[dict],
|
||||
def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
|
||||
outputs: Sequence[ClsDataSample]) -> None:
|
||||
"""Visualize every ``self.interval`` samples during test.
|
||||
|
||||
Args:
|
||||
runner (:obj:`Runner`): The runner of the testing process.
|
||||
batch_idx (int): The index of the current batch in the test loop.
|
||||
data_batch (Sequence[dict]): Data from dataloader.
|
||||
data_batch (dict): Data from dataloader.
|
||||
outputs (Sequence[:obj:`DetDataSample`]): Outputs from model.
|
||||
"""
|
||||
self._draw_samples(batch_idx, data_batch, outputs, step=0)
|
||||
|
|
|
@ -3,8 +3,9 @@ from typing import List, Optional, Sequence, Union
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine import LabelData, MMLogger
|
||||
from mmengine.evaluator import BaseMetric
|
||||
from mmengine.logging import MMLogger
|
||||
from mmengine.structures import LabelData
|
||||
|
||||
from mmcls.registry import METRICS
|
||||
from .single_label import _precision_recall_f1_support, to_tensor
|
||||
|
@ -95,15 +96,11 @@ class MultiLabelMetric(BaseMetric):
|
|||
>>> # ------------------- Use with Evalutor -------------------
|
||||
>>> from mmcls.structures import ClsDataSample
|
||||
>>> from mmengine.evaluator import Evaluator
|
||||
>>> # The `data_batch` won't be used in this case, just use a fake.
|
||||
>>> data_batch = [
|
||||
... {'inputs': None, 'data_sample': ClsDataSample()}
|
||||
... for i in range(1000)]
|
||||
>>> pred = [
|
||||
... ClsDataSample().set_pred_score(torch.rand((5, ))).set_gt_score(torch.randint(2, size=(5, )))
|
||||
... for i in range(1000)]
|
||||
>>> evaluator = Evaluator(metrics=MultiLabelMetric(thrs=0.5))
|
||||
>>> evaluator.process(data_batch, pred)
|
||||
>>> data_sampels = [
|
||||
... ClsDataSample().set_pred_score(pred).set_gt_score(gt)
|
||||
... for pred, gt in zip(torch.rand(1000, 5), torch.randint(0, 2, (1000, 5)))]
|
||||
>>> evaluator = Evaluator(metrics=MultiLabelMetric(thr=0.5))
|
||||
>>> evaluator.process(data_sampels)
|
||||
>>> evaluator.evaluate(1000)
|
||||
{
|
||||
'multi-label/precision': 50.72898037055408,
|
||||
|
@ -112,26 +109,13 @@ class MultiLabelMetric(BaseMetric):
|
|||
}
|
||||
>>> # Evaluate on each class by using topk strategy
|
||||
>>> evaluator = Evaluator(metrics=MultiLabelMetric(topk=1, average=None))
|
||||
>>> evaluator.process(data_batch, pred)
|
||||
>>> evaluator.process(data_sampels)
|
||||
>>> evaluator.evaluate(1000)
|
||||
{
|
||||
'multi-label/precision_top1_classwise': [48.22, 50.54, 50.99, 44.18, 52.5],
|
||||
'multi-label/recall_top1_classwise': [18.92, 19.22, 19.92, 20.0, 20.27],
|
||||
'multi-label/f1-score_top1_classwise': [27.18, 27.85, 28.65, 27.54, 29.25]
|
||||
}
|
||||
>>> # Evaluate by label data got from head
|
||||
>>> pred = [
|
||||
... ClsDataSample().set_pred_score(torch.rand((5, ))).set_pred_label(
|
||||
... torch.randint(2, size=(5, ))).set_gt_score(torch.randint(2, size=(5, )))
|
||||
... for i in range(1000)]
|
||||
>>> evaluator = Evaluator(metrics=MultiLabelMetric())
|
||||
>>> evaluator.process(data_batch, pred)
|
||||
>>> evaluator.evaluate(1000)
|
||||
{
|
||||
'multi-label/precision': 20.28921606216292,
|
||||
'multi-label/recall': 38.628095855722314,
|
||||
'multi-label/f1-score': 26.603530359627918
|
||||
}
|
||||
""" # noqa: E501
|
||||
default_prefix: Optional[str] = 'multi-label'
|
||||
|
||||
|
@ -165,20 +149,20 @@ class MultiLabelMetric(BaseMetric):
|
|||
|
||||
super().__init__(collect_device=collect_device, prefix=prefix)
|
||||
|
||||
def process(self, data_batch: Sequence[dict], predictions: Sequence[dict]):
|
||||
"""Process one batch of data and predictions.
|
||||
def process(self, data_batch, data_samples: Sequence[dict]):
|
||||
"""Process one batch of data samples.
|
||||
|
||||
The processed results should be stored in ``self.results``, which will
|
||||
be used to computed the metrics when all batches have been processed.
|
||||
|
||||
Args:
|
||||
data_batch (Sequence[dict]): A batch of data from the dataloader.
|
||||
predictions (Sequence[dict]): A batch of outputs from the model.
|
||||
data_batch: A batch of data from the dataloader.
|
||||
data_samples (Sequence[dict]): A batch of outputs from the model.
|
||||
"""
|
||||
for pred in predictions:
|
||||
for data_sample in data_samples:
|
||||
result = dict()
|
||||
pred_label = pred['pred_label']
|
||||
gt_label = pred['gt_label']
|
||||
pred_label = data_sample['pred_label']
|
||||
gt_label = data_sample['gt_label']
|
||||
|
||||
result['pred_score'] = pred_label['score'].clone()
|
||||
num_classes = result['pred_score'].size()[-1]
|
||||
|
@ -459,21 +443,17 @@ class AveragePrecision(BaseMetric):
|
|||
>>> # ------------------- Use with Evalutor -------------------
|
||||
>>> from mmcls.structures import ClsDataSample
|
||||
>>> from mmengine.evaluator import Evaluator
|
||||
>>> # The `data_batch` won't be used in this case, just use a fake.
|
||||
>>> data_batch = [
|
||||
... {'inputs': None, 'data_sample': ClsDataSample()}
|
||||
... for i in range(4)]
|
||||
>>> pred = [
|
||||
>>> data_samples = [
|
||||
... ClsDataSample().set_pred_score(i).set_gt_score(j)
|
||||
... for i, j in zip(y_pred, y_true)
|
||||
... ]
|
||||
>>> evaluator = Evaluator(metrics=AveragePrecision())
|
||||
>>> evaluator.process(data_batch, pred)
|
||||
>>> evaluator.process(data_samples)
|
||||
>>> evaluator.evaluate(5)
|
||||
{'multi-label/mAP': 70.83333587646484}
|
||||
>>> # Evaluate on each class
|
||||
>>> evaluator = Evaluator(metrics=AveragePrecision(average=None))
|
||||
>>> evaluator.process(data_batch, pred)
|
||||
>>> evaluator.process(data_samples)
|
||||
>>> evaluator.evaluate(5)
|
||||
{'multi-label/AP_classwise': [100., 83.33, 100., 0.]}
|
||||
"""
|
||||
|
@ -486,21 +466,21 @@ class AveragePrecision(BaseMetric):
|
|||
super().__init__(collect_device=collect_device, prefix=prefix)
|
||||
self.average = average
|
||||
|
||||
def process(self, data_batch: Sequence[dict], predictions: Sequence[dict]):
|
||||
"""Process one batch of data and predictions.
|
||||
def process(self, data_batch, data_samples: Sequence[dict]):
|
||||
"""Process one batch of data samples.
|
||||
|
||||
The processed results should be stored in ``self.results``, which will
|
||||
be used to computed the metrics when all batches have been processed.
|
||||
|
||||
Args:
|
||||
data_batch (Sequence[dict]): A batch of data from the dataloader.
|
||||
predictions (Sequence[dict]): A batch of outputs from the model.
|
||||
data_batch: A batch of data from the dataloader.
|
||||
data_samples (Sequence[dict]): A batch of outputs from the model.
|
||||
"""
|
||||
|
||||
for pred in predictions:
|
||||
for data_sample in data_samples:
|
||||
result = dict()
|
||||
pred_label = pred['pred_label']
|
||||
gt_label = pred['gt_label']
|
||||
pred_label = data_sample['pred_label']
|
||||
gt_label = data_sample['gt_label']
|
||||
|
||||
result['pred_score'] = pred_label['score']
|
||||
num_classes = result['pred_score'].size()[-1]
|
||||
|
|
|
@ -88,16 +88,12 @@ class Accuracy(BaseMetric):
|
|||
>>> # ------------------- Use with Evalutor -------------------
|
||||
>>> from mmcls.structures import ClsDataSample
|
||||
>>> from mmengine.evaluator import Evaluator
|
||||
>>> data_batch = [{
|
||||
... 'inputs': None, # In this example, the `inputs` is not used.
|
||||
... 'data_sample': ClsDataSample().set_gt_label(0)
|
||||
... } for i in range(1000)]
|
||||
>>> pred = [
|
||||
... ClsDataSample().set_pred_score(torch.rand(10))
|
||||
>>> data_samples = [
|
||||
... ClsDataSample().set_gt_label(0).set_pred_score(torch.rand(10))
|
||||
... for i in range(1000)
|
||||
... ]
|
||||
>>> evaluator = Evaluator(metrics=Accuracy(topk=(1, 5)))
|
||||
>>> evaluator.process(data_batch, pred)
|
||||
>>> evaluator.process(data_samples)
|
||||
>>> evaluator.evaluate(1000)
|
||||
{
|
||||
'accuracy/top1': 9.300000190734863,
|
||||
|
@ -123,22 +119,21 @@ class Accuracy(BaseMetric):
|
|||
else:
|
||||
self.thrs = tuple(thrs)
|
||||
|
||||
def process(self, data_batch: Sequence[dict], predictions: Sequence[dict]):
|
||||
"""Process one batch of data and predictions.
|
||||
def process(self, data_batch, data_samples: Sequence[dict]):
|
||||
"""Process one batch of data samples.
|
||||
|
||||
The processed results should be stored in ``self.results``, which will
|
||||
be used to computed the metrics when all batches have been processed.
|
||||
|
||||
Args:
|
||||
data_batch (Sequence[dict]): A batch of data from the dataloader.
|
||||
predictions (Sequence[dict]): A batch of outputs from the model.
|
||||
data_batch: A batch of data from the dataloader.
|
||||
data_samples (Sequence[dict]): A batch of outputs from the model.
|
||||
"""
|
||||
|
||||
for data, pred in zip(data_batch, predictions):
|
||||
for data_sample in data_samples:
|
||||
result = dict()
|
||||
pred_label = pred['pred_label']
|
||||
# Use gt_label in the pred dict preferentially.
|
||||
gt_label = pred.get('gt_label', data['data_sample']['gt_label'])
|
||||
pred_label = data_sample['pred_label']
|
||||
gt_label = data_sample['gt_label']
|
||||
if 'score' in pred_label:
|
||||
result['pred_score'] = pred_label['score'].cpu()
|
||||
else:
|
||||
|
@ -317,45 +312,30 @@ class SingleLabelMetric(BaseMetric):
|
|||
>>> y_true = [0, 2, 1, 3]
|
||||
>>> # Output precision, recall, f1-score and support.
|
||||
>>> SingleLabelMetric.calculate(y_pred, y_true, num_classes=4)
|
||||
(tensor(62.5000, dtype=torch.float64),
|
||||
tensor(75., dtype=torch.float64),
|
||||
tensor(66.6667, dtype=torch.float64),
|
||||
tensor(4))
|
||||
(tensor(62.5000), tensor(75.), tensor(66.6667), tensor(4))
|
||||
>>> # Calculate with different thresholds.
|
||||
>>> y_score = torch.rand((1000, 10))
|
||||
>>> y_true = torch.zeros((1000, ))
|
||||
>>> SingleLabelMetric.calculate(y_score, y_true, thrs=(0., 0.9))
|
||||
[(tensor(10., dtype=torch.float64),
|
||||
tensor(1.2100, dtype=torch.float64),
|
||||
tensor(2.1588, dtype=torch.float64),
|
||||
tensor(1000)),
|
||||
(tensor(10., dtype=torch.float64),
|
||||
tensor(0.8200, dtype=torch.float64),
|
||||
tensor(1.5157, dtype=torch.float64),
|
||||
tensor(1000))]
|
||||
[(tensor(10.), tensor(0.9500), tensor(1.7352), tensor(1000)),
|
||||
(tensor(10.), tensor(0.5500), tensor(1.0427), tensor(1000))]
|
||||
>>>
|
||||
>>> # ------------------- Use with Evalutor -------------------
|
||||
>>> from mmcls.structures import ClsDataSample
|
||||
>>> from mmengine.evaluator import Evaluator
|
||||
>>> data_batch = [{
|
||||
... 'inputs': None, # In this example, the `inputs` is not used.
|
||||
... 'data_sample': ClsDataSample().set_gt_label(i%5)
|
||||
... } for i in range(1000)]
|
||||
>>> pred = [
|
||||
... ClsDataSample().set_pred_score(torch.rand(5))
|
||||
>>> data_samples = [
|
||||
... ClsDataSample().set_gt_label(i%5).set_pred_score(torch.rand(5))
|
||||
... for i in range(1000)
|
||||
... ]
|
||||
>>> evaluator = Evaluator(metrics=SingleLabelMetric())
|
||||
>>> evaluator.process(data_batch, pred)
|
||||
>>> evaluator.process(data_samples)
|
||||
>>> evaluator.evaluate(1000)
|
||||
{
|
||||
'single-label/precision': 10.0,
|
||||
'single-label/recall': 0.96,
|
||||
'single-label/f1-score': 1.7518248175182483
|
||||
}
|
||||
{'single-label/precision': 19.650691986083984,
|
||||
'single-label/recall': 19.600000381469727,
|
||||
'single-label/f1-score': 19.619548797607422}
|
||||
>>> # Evaluate on each class
|
||||
>>> evaluator = Evaluator(metrics=SingleLabelMetric(average=None))
|
||||
>>> evaluator.process(data_batch, pred)
|
||||
>>> evaluator.process(data_samples)
|
||||
>>> evaluator.evaluate(1000)
|
||||
{
|
||||
'single-label/precision_classwise': [21.1, 18.7, 17.8, 19.4, 16.1],
|
||||
|
@ -386,31 +366,28 @@ class SingleLabelMetric(BaseMetric):
|
|||
self.items = tuple(items)
|
||||
self.average = average
|
||||
|
||||
def process(self, data_batch: Sequence[dict], predictions: Sequence[dict]):
|
||||
"""Process one batch of data and predictions.
|
||||
def process(self, data_batch, data_samples: Sequence[dict]):
|
||||
"""Process one batch of data samples.
|
||||
|
||||
The processed results should be stored in ``self.results``, which will
|
||||
be used to computed the metrics when all batches have been processed.
|
||||
|
||||
Args:
|
||||
data_batch (Sequence[dict]): A batch of data from the dataloader.
|
||||
predictions (Sequence[dict]): A batch of outputs from the model.
|
||||
data_batch: A batch of data from the dataloader.
|
||||
data_samples (Sequence[dict]): A batch of outputs from the model.
|
||||
"""
|
||||
|
||||
for data, pred in zip(data_batch, predictions):
|
||||
for data_sample in data_samples:
|
||||
result = dict()
|
||||
pred_label = pred['pred_label']
|
||||
# Use gt_label in the pred dict preferentially.
|
||||
gt_label = pred.get('gt_label', data['data_sample']['gt_label'])
|
||||
pred_label = data_sample['pred_label']
|
||||
gt_label = data_sample['gt_label']
|
||||
if 'score' in pred_label:
|
||||
result['pred_score'] = pred_label['score'].cpu()
|
||||
elif ('num_classes' in pred_label
|
||||
or 'num_classes' in data['data_sample']):
|
||||
elif ('num_classes' in pred_label):
|
||||
result['pred_label'] = pred_label['label'].cpu()
|
||||
result['num_classes'] = pred_label.get(
|
||||
'num_classes', None) or data['data_sample']['num_classes']
|
||||
result['num_classes'] = pred_label['num_classes']
|
||||
else:
|
||||
raise ValueError('The `pred_label` in predictions do not '
|
||||
raise ValueError('The `pred_label` in data_samples do not '
|
||||
'have neither `score` nor `num_classes`.')
|
||||
result['gt_label'] = gt_label['label'].cpu()
|
||||
# Save the result to `self.results`.
|
||||
|
|
|
@ -8,7 +8,7 @@ from mmcv.cnn import build_activation_layer, build_norm_layer
|
|||
from mmcv.cnn.bricks.drop import DropPath
|
||||
from mmcv.cnn.bricks.transformer import AdaptivePadding
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.model.utils import trunc_normal_
|
||||
from mmengine.model.weight_init import trunc_normal_
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from .base_backbone import BaseBackbone
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.model.utils import trunc_normal_
|
||||
from mmengine.model.weight_init import trunc_normal_
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from .vision_transformer import VisionTransformer
|
||||
|
|
|
@ -4,7 +4,7 @@ import torch.nn.functional as F
|
|||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
|
||||
from mmengine.model import BaseModule, Sequential
|
||||
from mmengine.utils.parrots_wrapper import _BatchNorm
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from ..utils.se_layer import SELayer
|
||||
|
|
|
@ -6,8 +6,8 @@ from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
|
|||
build_norm_layer)
|
||||
from mmcv.cnn.bricks import DropPath
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.model.utils import constant_init
|
||||
from mmengine.utils.parrots_wrapper import _BatchNorm
|
||||
from mmengine.model.weight_init import constant_init
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from .base_backbone import BaseBackbone
|
||||
|
|
|
@ -4,7 +4,7 @@ import torch.nn as nn
|
|||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import ConvModule, build_activation_layer
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.model.utils import constant_init, normal_init
|
||||
from mmengine.model.weight_init import constant_init, normal_init
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmcls.models.utils import channel_shuffle, make_divisible
|
||||
|
|
|
@ -4,7 +4,7 @@ import torch.nn as nn
|
|||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.model.utils import constant_init, normal_init
|
||||
from mmengine.model.weight_init import constant_init, normal_init
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmcls.models.utils import channel_shuffle
|
||||
|
|
|
@ -9,8 +9,8 @@ import torch.utils.checkpoint as cp
|
|||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed, PatchMerging
|
||||
from mmengine.model import BaseModule, ModuleList
|
||||
from mmengine.model.utils import trunc_normal_
|
||||
from mmengine.utils.parrots_wrapper import _BatchNorm
|
||||
from mmengine.model.weight_init import trunc_normal_
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from ..utils import (ShiftWindowMSA, resize_pos_embed,
|
||||
|
|
|
@ -8,7 +8,7 @@ import torch.nn as nn
|
|||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.transformer import FFN
|
||||
from mmengine.model import BaseModule, ModuleList
|
||||
from mmengine.model.utils import trunc_normal_
|
||||
from mmengine.model.weight_init import trunc_normal_
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple
|
||||
|
|
|
@ -6,7 +6,7 @@ import torch.nn as nn
|
|||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
|
||||
from mmengine.model import BaseModule, ModuleList
|
||||
from mmengine.model.utils import trunc_normal_
|
||||
from mmengine.model.weight_init import trunc_normal_
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from ..utils import to_2tuple
|
||||
|
|
|
@ -8,7 +8,8 @@ from mmcv.cnn import Conv2d, build_norm_layer
|
|||
from mmcv.cnn.bricks.drop import build_dropout
|
||||
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
|
||||
from mmengine.model import BaseModule, ModuleList
|
||||
from mmengine.model.utils import constant_init, normal_init, trunc_normal_init
|
||||
from mmengine.model.weight_init import (constant_init, normal_init,
|
||||
trunc_normal_init)
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmcls.models.utils.attention import MultiheadAttention
|
||||
|
|
|
@ -5,7 +5,7 @@ from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
|
|||
from mmcv.cnn.bricks import DropPath
|
||||
from mmcv.cnn.bricks.transformer import PatchEmbed
|
||||
from mmengine.model import BaseModule, ModuleList
|
||||
from mmengine.utils.parrots_wrapper import _BatchNorm
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from .base_backbone import BaseBackbone
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.utils.parrots_wrapper import _BatchNorm
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from .base_backbone import BaseBackbone
|
||||
|
|
|
@ -8,7 +8,7 @@ from mmcv.cnn import build_norm_layer
|
|||
from mmcv.cnn.bricks.drop import build_dropout
|
||||
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
|
||||
from mmengine.model import BaseModule, ModuleList
|
||||
from mmengine.model.utils import trunc_normal_
|
||||
from mmengine.model.weight_init import trunc_normal_
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from ..utils import (BEiTAttention, MultiheadAttention, resize_pos_embed,
|
||||
|
|
|
@ -3,8 +3,8 @@ from abc import ABCMeta, abstractmethod
|
|||
from typing import List, Optional, Sequence
|
||||
|
||||
import torch
|
||||
from mmengine import BaseDataElement
|
||||
from mmengine.model import BaseModel
|
||||
from mmengine.structures import BaseDataElement
|
||||
|
||||
|
||||
class BaseClassifier(BaseModel, metaclass=ABCMeta):
|
||||
|
@ -43,7 +43,7 @@ class BaseClassifier(BaseModel, metaclass=ABCMeta):
|
|||
|
||||
@abstractmethod
|
||||
def forward(self,
|
||||
batch_inputs: torch.Tensor,
|
||||
inputs: torch.Tensor,
|
||||
data_samples: Optional[List[BaseDataElement]] = None,
|
||||
mode: str = 'tensor'):
|
||||
"""The unified entry for a forward process in both training and test.
|
||||
|
@ -61,8 +61,8 @@ class BaseClassifier(BaseModel, metaclass=ABCMeta):
|
|||
optimizer updating, which are done in the :meth:`train_step`.
|
||||
|
||||
Args:
|
||||
batch_inputs (torch.Tensor): The input tensor with shape
|
||||
(N, C, ...) in general.
|
||||
inputs (torch.Tensor): The input tensor with shape (N, C, ...)
|
||||
in general.
|
||||
data_samples (List[BaseDataElement], optional): The annotation
|
||||
data of every samples. It's required if ``mode="loss"``.
|
||||
Defaults to None.
|
||||
|
@ -78,34 +78,31 @@ class BaseClassifier(BaseModel, metaclass=ABCMeta):
|
|||
"""
|
||||
pass
|
||||
|
||||
def extract_feat(self, batch_inputs: torch.Tensor):
|
||||
def extract_feat(self, inputs: torch.Tensor):
|
||||
"""Extract features from the input tensor with shape (N, C, ...).
|
||||
|
||||
The sub-classes are recommended to implement this method to extract
|
||||
features from backbone and neck.
|
||||
|
||||
Args:
|
||||
batch_inputs (Tensor): A batch of inputs. The shape of it should be
|
||||
inputs (Tensor): A batch of inputs. The shape of it should be
|
||||
``(num_samples, num_channels, *img_shape)``.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def extract_feats(self, multi_batch_inputs: Sequence[torch.Tensor],
|
||||
def extract_feats(self, multi_inputs: Sequence[torch.Tensor],
|
||||
**kwargs) -> list:
|
||||
"""Extract features from a sequence of input tensor.
|
||||
|
||||
Args:
|
||||
multi_batch_inputs (Sequence[torch.Tensor]): A sequence of input
|
||||
multi_inputs (Sequence[torch.Tensor]): A sequence of input
|
||||
tensor. It can be used in augmented inference.
|
||||
**kwargs: Other keyword arguments accepted by :meth:`extract_feat`.
|
||||
|
||||
Returns:
|
||||
list: Features of every input tensor.
|
||||
"""
|
||||
assert isinstance(multi_batch_inputs, Sequence), \
|
||||
assert isinstance(multi_inputs, Sequence), \
|
||||
'`extract_feats` is used for a sequence of inputs tensor. If you '\
|
||||
'want to extract on single inputs tensor, use `extract_feat`.'
|
||||
return [
|
||||
self.extract_feat(batch_input, **kwargs)
|
||||
for batch_input in multi_batch_inputs
|
||||
]
|
||||
return [self.extract_feat(inputs, **kwargs) for inputs in multi_inputs]
|
||||
|
|
|
@ -70,7 +70,7 @@ class ImageClassifier(BaseClassifier):
|
|||
self.head = MODELS.build(head)
|
||||
|
||||
def forward(self,
|
||||
batch_inputs: torch.Tensor,
|
||||
inputs: torch.Tensor,
|
||||
data_samples: Optional[List[ClsDataSample]] = None,
|
||||
mode: str = 'tensor'):
|
||||
"""The unified entry for a forward process in both training and test.
|
||||
|
@ -88,7 +88,7 @@ class ImageClassifier(BaseClassifier):
|
|||
optimizer updating, which are done in the :meth:`train_step`.
|
||||
|
||||
Args:
|
||||
batch_inputs (torch.Tensor): The input tensor with shape
|
||||
inputs (torch.Tensor): The input tensor with shape
|
||||
(N, C, ...) in general.
|
||||
data_samples (List[ClsDataSample], optional): The annotation
|
||||
data of every samples. It's required if ``mode="loss"``.
|
||||
|
@ -104,20 +104,20 @@ class ImageClassifier(BaseClassifier):
|
|||
- If ``mode="loss"``, return a dict of tensor.
|
||||
"""
|
||||
if mode == 'tensor':
|
||||
feats = self.extract_feat(batch_inputs)
|
||||
feats = self.extract_feat(inputs)
|
||||
return self.head(feats) if self.with_head else feats
|
||||
elif mode == 'loss':
|
||||
return self.loss(batch_inputs, data_samples)
|
||||
return self.loss(inputs, data_samples)
|
||||
elif mode == 'predict':
|
||||
return self.predict(batch_inputs, data_samples)
|
||||
return self.predict(inputs, data_samples)
|
||||
else:
|
||||
raise RuntimeError(f'Invalid mode "{mode}".')
|
||||
|
||||
def extract_feat(self, batch_inputs, stage='neck'):
|
||||
def extract_feat(self, inputs, stage='neck'):
|
||||
"""Extract features from the input tensor with shape (N, C, ...).
|
||||
|
||||
Args:
|
||||
batch_inputs (Tensor): A batch of inputs. The shape of it should be
|
||||
inputs (Tensor): A batch of inputs. The shape of it should be
|
||||
``(num_samples, num_channels, *img_shape)``.
|
||||
stage (str): Which stage to output the feature. Choose from:
|
||||
|
||||
|
@ -189,7 +189,7 @@ class ImageClassifier(BaseClassifier):
|
|||
(f'Invalid output stage "{stage}", please choose from "backbone", '
|
||||
'"neck" and "pre_logits"')
|
||||
|
||||
x = self.backbone(batch_inputs)
|
||||
x = self.backbone(inputs)
|
||||
|
||||
if stage == 'backbone':
|
||||
return x
|
||||
|
@ -203,12 +203,12 @@ class ImageClassifier(BaseClassifier):
|
|||
"No head or the head doesn't implement `pre_logits` method."
|
||||
return self.head.pre_logits(x)
|
||||
|
||||
def loss(self, batch_inputs: torch.Tensor,
|
||||
def loss(self, inputs: torch.Tensor,
|
||||
data_samples: List[ClsDataSample]) -> dict:
|
||||
"""Calculate losses from a batch of inputs and data samples.
|
||||
|
||||
Args:
|
||||
batch_inputs (torch.Tensor): The input tensor with shape
|
||||
inputs (torch.Tensor): The input tensor with shape
|
||||
(N, C, ...) in general.
|
||||
data_samples (List[ClsDataSample]): The annotation data of
|
||||
every samples.
|
||||
|
@ -216,21 +216,21 @@ class ImageClassifier(BaseClassifier):
|
|||
Returns:
|
||||
dict[str, Tensor]: a dictionary of loss components
|
||||
"""
|
||||
feats = self.extract_feat(batch_inputs)
|
||||
feats = self.extract_feat(inputs)
|
||||
return self.head.loss(feats, data_samples)
|
||||
|
||||
def predict(self,
|
||||
batch_inputs: tuple,
|
||||
inputs: tuple,
|
||||
data_samples: Optional[List[ClsDataSample]] = None,
|
||||
**kwargs) -> List[ClsDataSample]:
|
||||
"""Predict results from the extracted features.
|
||||
|
||||
Args:
|
||||
batch_inputs (tuple): The features extracted from the backbone.
|
||||
inputs (tuple): The features extracted from the backbone.
|
||||
data_samples (List[ClsDataSample], optional): The annotation
|
||||
data of every samples. Defaults to None.
|
||||
**kwargs: Other keyword arguments accepted by the ``predict``
|
||||
method of :attr:`head`.
|
||||
"""
|
||||
feats = self.extract_feat(batch_inputs)
|
||||
feats = self.extract_feat(inputs)
|
||||
return self.head.predict(feats, data_samples, **kwargs)
|
||||
|
|
|
@ -2,8 +2,8 @@
|
|||
from abc import ABCMeta, abstractmethod
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from mmengine import BaseDataElement
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.structures import BaseDataElement
|
||||
|
||||
|
||||
class BaseHead(BaseModule, metaclass=ABCMeta):
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from mmengine.data import LabelData
|
||||
from mmengine.structures import LabelData
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from mmcls.structures import ClsDataSample
|
||||
|
|
|
@ -7,7 +7,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from mmcv.cnn import build_activation_layer
|
||||
from mmengine.model import Sequential
|
||||
from mmengine.model.utils import trunc_normal_
|
||||
from mmengine.model.weight_init import trunc_normal_
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from .cls_head import ClsHead
|
||||
|
|
|
@ -6,7 +6,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
from mmcv.cnn.bricks.drop import build_dropout
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.model.utils import trunc_normal_
|
||||
from mmengine.model.weight_init import trunc_normal_
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from .helpers import to_2tuple
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import List, Optional, Tuple
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine.data import LabelData
|
||||
from mmengine.structures import LabelData
|
||||
|
||||
from mmcls.registry import BATCH_AUGMENTS
|
||||
from mmcls.structures import ClsDataSample
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
from numbers import Number
|
||||
from typing import List, Optional, Sequence, Tuple
|
||||
from typing import List, Optional, Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from mmengine.model import BaseDataPreprocessor, stack_batch
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
|
@ -74,39 +76,71 @@ class ClsDataPreprocessor(BaseDataPreprocessor):
|
|||
else:
|
||||
self.batch_augments = None
|
||||
|
||||
def forward(self,
|
||||
data: Sequence[dict],
|
||||
training: bool = False) -> Tuple[torch.Tensor, list]:
|
||||
def forward(self, data: dict, training: bool = False) -> dict:
|
||||
"""Perform normalization, padding, bgr2rgb conversion and batch
|
||||
augmentation based on ``BaseDataPreprocessor``.
|
||||
|
||||
Args:
|
||||
data (Sequence[dict]): data sampled from dataloader.
|
||||
data (dict): data sampled from dataloader.
|
||||
training (bool): Whether to enable training time augmentation.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, list]: Data in the same format as the model
|
||||
input.
|
||||
dict: Data in the same format as the model input.
|
||||
"""
|
||||
inputs, batch_data_samples = self.collate_data(data)
|
||||
data = self.cast_data(data)
|
||||
inputs = data['inputs']
|
||||
|
||||
# --- Pad and stack --
|
||||
batch_inputs = stack_batch(inputs, self.pad_size_divisor,
|
||||
self.pad_value)
|
||||
if isinstance(inputs, torch.Tensor):
|
||||
# The branch if use `default_collate` as the collate_fn in the
|
||||
# dataloader.
|
||||
|
||||
# ------ To RGB ------
|
||||
if self.to_rgb and batch_inputs.size(1) == 3:
|
||||
batch_inputs = batch_inputs[:, [2, 1, 0], ...]
|
||||
# ------ To RGB ------
|
||||
if self.to_rgb and inputs.size(1) == 3:
|
||||
inputs = inputs.flip(1)
|
||||
|
||||
# -- Normalization ---
|
||||
if self._enable_normalize:
|
||||
batch_inputs = (batch_inputs - self.mean) / self.std
|
||||
# -- Normalization ---
|
||||
inputs = inputs.float()
|
||||
if self._enable_normalize:
|
||||
inputs = (inputs - self.mean) / self.std
|
||||
|
||||
# ------ Padding -----
|
||||
if self.pad_size_divisor > 1:
|
||||
h, w = inputs.shape[-2:]
|
||||
|
||||
target_h = math.ceil(
|
||||
h / self.pad_size_divisor) * self.pad_size_divisor
|
||||
target_w = math.ceil(
|
||||
w / self.pad_size_divisor) * self.pad_size_divisor
|
||||
pad_h = target_h - h
|
||||
pad_w = target_w - w
|
||||
inputs = F.pad(inputs, (0, pad_w, 0, pad_h), 'constant',
|
||||
self.pad_value)
|
||||
else:
|
||||
batch_inputs = batch_inputs.to(torch.float32)
|
||||
# The branch if use `pseudo_collate` as the collate_fn in the
|
||||
# dataloader.
|
||||
|
||||
processed_inputs = []
|
||||
for input_ in inputs:
|
||||
# ------ To RGB ------
|
||||
if self.to_rgb and input_.size(0) == 3:
|
||||
input_ = input_.flip(0)
|
||||
|
||||
# -- Normalization ---
|
||||
input_ = input_.float()
|
||||
if self._enable_normalize:
|
||||
input_ = (input_ - self.mean) / self.std
|
||||
|
||||
processed_inputs.append(input_)
|
||||
# Combine padding and stack
|
||||
inputs = stack_batch(processed_inputs, self.pad_size_divisor,
|
||||
self.pad_value)
|
||||
|
||||
# ----- Batch Aug ----
|
||||
if training and self.batch_augments is not None:
|
||||
batch_inputs, batch_data_samples = self.batch_augments(
|
||||
batch_inputs, batch_data_samples)
|
||||
data_samples = data['data_samples']
|
||||
inputs, data_samples = self.batch_augments(inputs, data_samples)
|
||||
data['data_samples'] = data_samples
|
||||
|
||||
return batch_inputs, batch_data_samples
|
||||
data['inputs'] = inputs
|
||||
|
||||
return data
|
||||
|
|
|
@ -5,7 +5,7 @@ from typing import Sequence, Union
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine.data import BaseDataElement, LabelData
|
||||
from mmengine.structures import BaseDataElement, LabelData
|
||||
from mmengine.utils import is_str
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmengine.utils import collect_env as collect_base_env
|
||||
from mmengine.utils import get_git_hash
|
||||
from mmengine.utils.dl_utils import collect_env as collect_base_env
|
||||
|
||||
import mmcls
|
||||
|
||||
|
|
|
@ -3,8 +3,8 @@ from typing import Optional, Tuple
|
|||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmengine import Visualizer
|
||||
from mmengine.dist import master_only
|
||||
from mmengine.visualization import Visualizer
|
||||
|
||||
from mmcls.registry import VISUALIZERS
|
||||
from mmcls.structures import ClsDataSample
|
||||
|
|
|
@ -6,7 +6,7 @@ import unittest
|
|||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine.data import LabelData
|
||||
from mmengine.structures import LabelData
|
||||
from PIL import Image
|
||||
|
||||
from mmcls.registry import TRANSFORMS
|
||||
|
@ -36,10 +36,10 @@ class TestPackClsInputs(unittest.TestCase):
|
|||
results = transform(copy.deepcopy(data))
|
||||
self.assertIn('inputs', results)
|
||||
self.assertIsInstance(results['inputs'], torch.Tensor)
|
||||
self.assertIn('data_sample', results)
|
||||
self.assertIsInstance(results['data_sample'], ClsDataSample)
|
||||
self.assertIn('flip', results['data_sample'].metainfo_keys())
|
||||
self.assertIsInstance(results['data_sample'].gt_label, LabelData)
|
||||
self.assertIn('data_samples', results)
|
||||
self.assertIsInstance(results['data_samples'], ClsDataSample)
|
||||
self.assertIn('flip', results['data_samples'].metainfo_keys())
|
||||
self.assertIsInstance(results['data_samples'].gt_label, LabelData)
|
||||
|
||||
# Test grayscale image
|
||||
data['img'] = data['img'].mean(-1)
|
||||
|
@ -53,7 +53,7 @@ class TestPackClsInputs(unittest.TestCase):
|
|||
del data['gt_label']
|
||||
with self.assertWarnsRegex(Warning, 'Cannot get "img"'):
|
||||
results = transform(copy.deepcopy(data))
|
||||
self.assertNotIn('gt_label', results['data_sample'])
|
||||
self.assertNotIn('gt_label', results['data_samples'])
|
||||
|
||||
def test_repr(self):
|
||||
cfg = dict(type='PackClsInputs', meta_keys=['flip', 'img_shape'])
|
||||
|
|
|
@ -23,10 +23,10 @@ class TestVisualizationHook(TestCase):
|
|||
|
||||
data_sample = ClsDataSample().set_gt_label(1).set_pred_label(2)
|
||||
data_sample.set_metainfo({'img_path': 'tests/data/color.jpg'})
|
||||
self.data_batch = [{
|
||||
'inputs': torch.randint(0, 256, (3, 224, 224)),
|
||||
'data_sample': data_sample
|
||||
}] * 10
|
||||
self.data_batch = {
|
||||
'inputs': torch.randint(0, 256, (10, 3, 224, 224)),
|
||||
'data_sample': [data_sample] * 10
|
||||
}
|
||||
|
||||
self.outputs = [data_sample] * 10
|
||||
|
||||
|
|
|
@ -137,11 +137,6 @@ class TestMultiLabel(TestCase):
|
|||
MultiLabelMetric.calculate(y_pred_binary, 5)
|
||||
|
||||
def test_evaluate(self):
|
||||
fake_data_batch = [{
|
||||
'inputs': None,
|
||||
'data_sample': ClsDataSample()
|
||||
} for _ in range(4)]
|
||||
|
||||
y_true = [[0], [1, 3], [0, 1, 2], [3]]
|
||||
y_true_binary = torch.tensor([
|
||||
[1, 0, 0, 0],
|
||||
|
@ -163,7 +158,7 @@ class TestMultiLabel(TestCase):
|
|||
|
||||
# Test with default argument
|
||||
evaluator = Evaluator(dict(type='MultiLabelMetric'))
|
||||
evaluator.process(fake_data_batch, pred)
|
||||
evaluator.process(pred)
|
||||
res = evaluator.evaluate(4)
|
||||
self.assertIsInstance(res, dict)
|
||||
thr05_y_pred = np.array([
|
||||
|
@ -184,7 +179,7 @@ class TestMultiLabel(TestCase):
|
|||
|
||||
# Test with topk argument
|
||||
evaluator = Evaluator(dict(type='MultiLabelMetric', topk=1))
|
||||
evaluator.process(fake_data_batch, pred)
|
||||
evaluator.process(pred)
|
||||
res = evaluator.evaluate(4)
|
||||
self.assertIsInstance(res, dict)
|
||||
top1_y_pred = np.array([
|
||||
|
@ -205,7 +200,7 @@ class TestMultiLabel(TestCase):
|
|||
|
||||
# Test with both argument
|
||||
evaluator = Evaluator(dict(type='MultiLabelMetric', thr=0.25, topk=1))
|
||||
evaluator.process(fake_data_batch, pred)
|
||||
evaluator.process(pred)
|
||||
res = evaluator.evaluate(4)
|
||||
self.assertIsInstance(res, dict)
|
||||
# Expected values come from sklearn
|
||||
|
@ -228,7 +223,7 @@ class TestMultiLabel(TestCase):
|
|||
|
||||
# Test with average micro
|
||||
evaluator = Evaluator(dict(type='MultiLabelMetric', average='micro'))
|
||||
evaluator.process(fake_data_batch, pred)
|
||||
evaluator.process(pred)
|
||||
res = evaluator.evaluate(4)
|
||||
self.assertIsInstance(res, dict)
|
||||
# Expected values come from sklearn
|
||||
|
@ -247,7 +242,7 @@ class TestMultiLabel(TestCase):
|
|||
|
||||
# Test with average None
|
||||
evaluator = Evaluator(dict(type='MultiLabelMetric', average=None))
|
||||
evaluator.process(fake_data_batch, pred)
|
||||
evaluator.process(pred)
|
||||
res = evaluator.evaluate(4)
|
||||
self.assertIsInstance(res, dict)
|
||||
# Expected values come from sklearn
|
||||
|
@ -271,7 +266,7 @@ class TestMultiLabel(TestCase):
|
|||
]
|
||||
|
||||
evaluator = Evaluator(dict(type='MultiLabelMetric', items=['support']))
|
||||
evaluator.process(fake_data_batch, pred)
|
||||
evaluator.process(pred)
|
||||
res = evaluator.evaluate(4)
|
||||
self.assertIsInstance(res, dict)
|
||||
self.assertEqual(res['multi-label/support'], 7)
|
||||
|
@ -308,11 +303,6 @@ class TestAveragePrecision(TestCase):
|
|||
[1, 0, 0, 0],
|
||||
])
|
||||
|
||||
fake_data_batch = [{
|
||||
'inputs': None,
|
||||
'data_sample': ClsDataSample()
|
||||
} for _ in range(4)]
|
||||
|
||||
pred = [
|
||||
ClsDataSample(num_classes=4).set_pred_score(i).set_gt_score(j)
|
||||
for i, j in zip(y_pred, y_true)
|
||||
|
@ -320,14 +310,14 @@ class TestAveragePrecision(TestCase):
|
|||
|
||||
# Test with default macro avergae
|
||||
evaluator = Evaluator(dict(type='AveragePrecision'))
|
||||
evaluator.process(fake_data_batch, pred)
|
||||
evaluator.process(pred)
|
||||
res = evaluator.evaluate(5)
|
||||
self.assertIsInstance(res, dict)
|
||||
self.assertAlmostEqual(res['multi-label/mAP'], 70.83333, places=4)
|
||||
|
||||
# Test with average mode None
|
||||
evaluator = Evaluator(dict(type='AveragePrecision', average=None))
|
||||
evaluator.process(fake_data_batch, pred)
|
||||
evaluator.process(pred)
|
||||
res = evaluator.evaluate(5)
|
||||
self.assertIsInstance(res, dict)
|
||||
aps = res['multi-label/AP_classwise']
|
||||
|
@ -342,7 +332,7 @@ class TestAveragePrecision(TestCase):
|
|||
for i, j in zip(y_pred, [[0, 1], [1], [2], [0]])
|
||||
]
|
||||
evaluator = Evaluator(dict(type='AveragePrecision'))
|
||||
evaluator.process(fake_data_batch, pred)
|
||||
evaluator.process(pred)
|
||||
res = evaluator.evaluate(5)
|
||||
self.assertAlmostEqual(res['multi-label/mAP'], 70.83333, places=4)
|
||||
|
||||
|
|
|
@ -14,31 +14,28 @@ class TestAccuracy(TestCase):
|
|||
|
||||
def test_evaluate(self):
|
||||
"""Test using the metric in the same way as Evalutor."""
|
||||
data_batch = [{
|
||||
'data_sample': ClsDataSample().set_gt_label(i).to_dict()
|
||||
} for i in [0, 0, 1, 2, 1, 0]]
|
||||
pred = [
|
||||
ClsDataSample().set_pred_score(i).set_pred_label(j).to_dict()
|
||||
for i, j in zip([
|
||||
torch.tensor([0.7, 0.0, 0.3]),
|
||||
torch.tensor([0.5, 0.2, 0.3]),
|
||||
torch.tensor([0.4, 0.5, 0.1]),
|
||||
torch.tensor([0.0, 0.0, 1.0]),
|
||||
torch.tensor([0.0, 0.0, 1.0]),
|
||||
torch.tensor([0.0, 0.0, 1.0]),
|
||||
], [0, 0, 1, 2, 2, 2])
|
||||
ClsDataSample().set_pred_score(i).set_pred_label(j).set_gt_label(
|
||||
k).to_dict() for i, j, k in zip([
|
||||
torch.tensor([0.7, 0.0, 0.3]),
|
||||
torch.tensor([0.5, 0.2, 0.3]),
|
||||
torch.tensor([0.4, 0.5, 0.1]),
|
||||
torch.tensor([0.0, 0.0, 1.0]),
|
||||
torch.tensor([0.0, 0.0, 1.0]),
|
||||
torch.tensor([0.0, 0.0, 1.0]),
|
||||
], [0, 0, 1, 2, 2, 2], [0, 0, 1, 2, 1, 0])
|
||||
]
|
||||
|
||||
# Test with score (use score instead of label if score exists)
|
||||
metric = METRICS.build(dict(type='Accuracy', thrs=0.6))
|
||||
metric.process(data_batch, pred)
|
||||
metric.process(None, pred)
|
||||
acc = metric.evaluate(6)
|
||||
self.assertIsInstance(acc, dict)
|
||||
self.assertAlmostEqual(acc['accuracy/top1'], 2 / 6 * 100, places=4)
|
||||
|
||||
# Test with multiple thrs
|
||||
metric = METRICS.build(dict(type='Accuracy', thrs=(0., 0.6, None)))
|
||||
metric.process(data_batch, pred)
|
||||
metric.process(None, pred)
|
||||
acc = metric.evaluate(6)
|
||||
self.assertSetEqual(
|
||||
set(acc.keys()), {
|
||||
|
@ -49,14 +46,14 @@ class TestAccuracy(TestCase):
|
|||
# Test with invalid topk
|
||||
with self.assertRaisesRegex(ValueError, 'check the `val_evaluator`'):
|
||||
metric = METRICS.build(dict(type='Accuracy', topk=(1, 5)))
|
||||
metric.process(data_batch, pred)
|
||||
metric.process(None, pred)
|
||||
metric.evaluate(6)
|
||||
|
||||
# Test with label
|
||||
for sample in pred:
|
||||
del sample['pred_label']['score']
|
||||
metric = METRICS.build(dict(type='Accuracy', thrs=(0., 0.6, None)))
|
||||
metric.process(data_batch, pred)
|
||||
metric.process(None, pred)
|
||||
acc = metric.evaluate(6)
|
||||
self.assertIsInstance(acc, dict)
|
||||
self.assertAlmostEqual(acc['accuracy/top1'], 4 / 6 * 100, places=4)
|
||||
|
@ -124,19 +121,16 @@ class TestSingleLabel(TestCase):
|
|||
|
||||
def test_evaluate(self):
|
||||
"""Test using the metric in the same way as Evalutor."""
|
||||
data_batch = [{
|
||||
'data_sample': ClsDataSample().set_gt_label(i).to_dict()
|
||||
} for i in [0, 0, 1, 2, 1, 0]]
|
||||
pred = [
|
||||
ClsDataSample().set_pred_score(i).set_pred_label(j).to_dict()
|
||||
for i, j in zip([
|
||||
torch.tensor([0.7, 0.0, 0.3]),
|
||||
torch.tensor([0.5, 0.2, 0.3]),
|
||||
torch.tensor([0.4, 0.5, 0.1]),
|
||||
torch.tensor([0.0, 0.0, 1.0]),
|
||||
torch.tensor([0.0, 0.0, 1.0]),
|
||||
torch.tensor([0.0, 0.0, 1.0]),
|
||||
], [0, 0, 1, 2, 2, 2])
|
||||
ClsDataSample().set_pred_score(i).set_pred_label(j).set_gt_label(
|
||||
k).to_dict() for i, j, k in zip([
|
||||
torch.tensor([0.7, 0.0, 0.3]),
|
||||
torch.tensor([0.5, 0.2, 0.3]),
|
||||
torch.tensor([0.4, 0.5, 0.1]),
|
||||
torch.tensor([0.0, 0.0, 1.0]),
|
||||
torch.tensor([0.0, 0.0, 1.0]),
|
||||
torch.tensor([0.0, 0.0, 1.0]),
|
||||
], [0, 0, 1, 2, 2, 2], [0, 0, 1, 2, 1, 0])
|
||||
]
|
||||
|
||||
# Test with score (use score instead of label if score exists)
|
||||
|
@ -145,7 +139,7 @@ class TestSingleLabel(TestCase):
|
|||
type='SingleLabelMetric',
|
||||
thrs=0.6,
|
||||
items=('precision', 'recall', 'f1-score', 'support')))
|
||||
metric.process(data_batch, pred)
|
||||
metric.process(None, pred)
|
||||
res = metric.evaluate(6)
|
||||
self.assertIsInstance(res, dict)
|
||||
self.assertAlmostEqual(
|
||||
|
@ -160,7 +154,7 @@ class TestSingleLabel(TestCase):
|
|||
# Test with multiple thrs
|
||||
metric = METRICS.build(
|
||||
dict(type='SingleLabelMetric', thrs=(0., 0.6, None)))
|
||||
metric.process(data_batch, pred)
|
||||
metric.process(None, pred)
|
||||
res = metric.evaluate(6)
|
||||
self.assertSetEqual(
|
||||
set(res.keys()), {
|
||||
|
@ -180,7 +174,7 @@ class TestSingleLabel(TestCase):
|
|||
type='SingleLabelMetric',
|
||||
average='micro',
|
||||
items=('precision', 'recall', 'f1-score', 'support')))
|
||||
metric.process(data_batch, pred)
|
||||
metric.process(None, pred)
|
||||
res = metric.evaluate(6)
|
||||
self.assertIsInstance(res, dict)
|
||||
self.assertAlmostEqual(
|
||||
|
@ -197,7 +191,7 @@ class TestSingleLabel(TestCase):
|
|||
type='SingleLabelMetric',
|
||||
average=None,
|
||||
items=('precision', 'recall', 'f1-score', 'support')))
|
||||
metric.process(data_batch, pred)
|
||||
metric.process(None, pred)
|
||||
res = metric.evaluate(6)
|
||||
self.assertIsInstance(res, dict)
|
||||
precision = res['single-label/precision_classwise']
|
||||
|
@ -219,7 +213,7 @@ class TestSingleLabel(TestCase):
|
|||
for sample in pred_no_score:
|
||||
del sample['pred_label']['score']
|
||||
metric = METRICS.build(dict(type='SingleLabelMetric', thrs=(0., 0.6)))
|
||||
metric.process(data_batch, pred_no_score)
|
||||
metric.process(None, pred_no_score)
|
||||
res = metric.evaluate(6)
|
||||
self.assertIsInstance(res, dict)
|
||||
# Expected values come from sklearn
|
||||
|
@ -231,16 +225,16 @@ class TestSingleLabel(TestCase):
|
|||
for sample in pred_no_num_classes:
|
||||
del sample['pred_label']['num_classes']
|
||||
with self.assertRaisesRegex(ValueError, 'neither `score` nor'):
|
||||
metric.process(data_batch, pred_no_num_classes)
|
||||
metric.process(None, pred_no_num_classes)
|
||||
|
||||
# Test with empty items
|
||||
metric = METRICS.build(dict(type='SingleLabelMetric', items=tuple()))
|
||||
metric.process(data_batch, pred)
|
||||
metric.process(None, pred)
|
||||
res = metric.evaluate(6)
|
||||
self.assertIsInstance(res, dict)
|
||||
self.assertEqual(len(res), 0)
|
||||
|
||||
metric.process(data_batch, pred_no_score)
|
||||
metric.process(None, pred_no_score)
|
||||
res = metric.evaluate(6)
|
||||
self.assertIsInstance(res, dict)
|
||||
self.assertEqual(len(res), 0)
|
||||
|
|
|
@ -5,7 +5,7 @@ from unittest import TestCase
|
|||
|
||||
import torch
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.utils.parrots_wrapper import _BatchNorm
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.models.backbones import CSPDarkNet, CSPResNet, CSPResNeXt
|
||||
from mmcls.models.backbones.cspnet import (CSPNet, DarknetBottleneck,
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
from mmengine.utils.parrots_wrapper import _BatchNorm
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.models.backbones import Res2Net
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ import pytest
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.utils.parrots_wrapper import _BatchNorm
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.models.backbones import ResNet, ResNetV1c, ResNetV1d
|
||||
from mmcls.models.backbones.resnet import (BasicBlock, Bottleneck, ResLayer,
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
from mmengine.utils.parrots_wrapper import _BatchNorm
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.models.backbones import ResNet_CIFAR
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ from unittest import TestCase
|
|||
|
||||
import torch
|
||||
from mmengine.runner import load_checkpoint, save_checkpoint
|
||||
from mmengine.utils.parrots_wrapper import _BatchNorm
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.models.backbones import SwinTransformer
|
||||
from mmcls.models.backbones.swin_transformer import SwinBlock
|
||||
|
|
|
@ -5,7 +5,7 @@ from itertools import chain
|
|||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
from mmengine.utils.parrots_wrapper import _BatchNorm
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
from torch import nn
|
||||
|
||||
from mmcls.models.backbones import VAN
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
from mmengine.utils.parrots_wrapper import _BatchNorm
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.models.backbones import VGG
|
||||
|
||||
|
|
|
@ -173,10 +173,10 @@ class TestImageClassifier(TestCase):
|
|||
}
|
||||
model: ImageClassifier = MODELS.build(cfg)
|
||||
|
||||
data = [{
|
||||
'inputs': torch.randint(0, 256, (3, 224, 224)),
|
||||
'data_sample': ClsDataSample().set_gt_label(1)
|
||||
}]
|
||||
data = {
|
||||
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
|
||||
'data_samples': [ClsDataSample().set_gt_label(1)]
|
||||
}
|
||||
|
||||
optim_wrapper = MagicMock()
|
||||
log_vars = model.train_step(data, optim_wrapper)
|
||||
|
@ -190,10 +190,10 @@ class TestImageClassifier(TestCase):
|
|||
}
|
||||
model: ImageClassifier = MODELS.build(cfg)
|
||||
|
||||
data = [{
|
||||
'inputs': torch.randint(0, 256, (3, 224, 224)),
|
||||
'data_sample': ClsDataSample().set_gt_label(1)
|
||||
}]
|
||||
data = {
|
||||
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
|
||||
'data_samples': [ClsDataSample().set_gt_label(1)]
|
||||
}
|
||||
|
||||
predictions = model.val_step(data)
|
||||
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
|
||||
|
@ -205,10 +205,10 @@ class TestImageClassifier(TestCase):
|
|||
}
|
||||
model: ImageClassifier = MODELS.build(cfg)
|
||||
|
||||
data = [{
|
||||
'inputs': torch.randint(0, 256, (3, 224, 224)),
|
||||
'data_sample': ClsDataSample().set_gt_label(1)
|
||||
}]
|
||||
data = {
|
||||
'inputs': torch.randint(0, 256, (1, 3, 224, 224)),
|
||||
'data_samples': [ClsDataSample().set_gt_label(1)]
|
||||
}
|
||||
|
||||
predictions = model.test_step(data)
|
||||
self.assertEqual(predictions[0].pred_label.score.shape, (10, ))
|
||||
|
|
|
@ -17,11 +17,13 @@ class TestClsDataPreprocessor(TestCase):
|
|||
cfg = dict(type='ClsDataPreprocessor')
|
||||
processor: ClsDataPreprocessor = MODELS.build(cfg)
|
||||
|
||||
data = [{
|
||||
'inputs': torch.randint(0, 256, (3, 224, 224)),
|
||||
'data_sample': ClsDataSample().set_gt_label(1)
|
||||
}]
|
||||
inputs, data_samples = processor(data)
|
||||
data = {
|
||||
'inputs': [torch.randint(0, 256, (3, 224, 224))],
|
||||
'data_samples': [ClsDataSample().set_gt_label(1)]
|
||||
}
|
||||
processed_data = processor(data)
|
||||
inputs = processed_data['inputs']
|
||||
data_samples = processed_data['data_samples']
|
||||
self.assertEqual(inputs.shape, (1, 3, 224, 224))
|
||||
self.assertEqual(len(data_samples), 1)
|
||||
self.assertTrue(
|
||||
|
@ -31,22 +33,31 @@ class TestClsDataPreprocessor(TestCase):
|
|||
cfg = dict(type='ClsDataPreprocessor', pad_size_divisor=16)
|
||||
processor: ClsDataPreprocessor = MODELS.build(cfg)
|
||||
|
||||
data = [{
|
||||
'inputs': torch.randint(0, 256, (3, 255, 255))
|
||||
}, {
|
||||
'inputs': torch.randint(0, 256, (3, 224, 224))
|
||||
}]
|
||||
inputs, _ = processor(data)
|
||||
data = {
|
||||
'inputs': [
|
||||
torch.randint(0, 256, (3, 255, 255)),
|
||||
torch.randint(0, 256, (3, 224, 224))
|
||||
]
|
||||
}
|
||||
inputs = processor(data)['inputs']
|
||||
self.assertEqual(inputs.shape, (2, 3, 256, 256))
|
||||
|
||||
data = {'inputs': torch.randint(0, 256, (2, 3, 255, 255))}
|
||||
inputs = processor(data)['inputs']
|
||||
self.assertEqual(inputs.shape, (2, 3, 256, 256))
|
||||
|
||||
def test_to_rgb(self):
|
||||
cfg = dict(type='ClsDataPreprocessor', to_rgb=True)
|
||||
processor: ClsDataPreprocessor = MODELS.build(cfg)
|
||||
|
||||
data = [{'inputs': torch.randint(0, 256, (3, 224, 224))}]
|
||||
inputs, _ = processor(data)
|
||||
torch.testing.assert_allclose(
|
||||
data[0]['inputs'].flip(dims=(0, )).to(torch.float32), inputs[0])
|
||||
data = {'inputs': [torch.randint(0, 256, (3, 224, 224))]}
|
||||
inputs = processor(data)['inputs']
|
||||
torch.testing.assert_allclose(data['inputs'][0].flip(0).float(),
|
||||
inputs[0])
|
||||
|
||||
data = {'inputs': torch.randint(0, 256, (1, 3, 224, 224))}
|
||||
inputs = processor(data)['inputs']
|
||||
torch.testing.assert_allclose(data['inputs'].flip(1).float(), inputs)
|
||||
|
||||
def test_normalization(self):
|
||||
cfg = dict(
|
||||
|
@ -55,11 +66,17 @@ class TestClsDataPreprocessor(TestCase):
|
|||
std=[127.5, 127.5, 127.5])
|
||||
processor: ClsDataPreprocessor = MODELS.build(cfg)
|
||||
|
||||
data = [{'inputs': torch.randint(0, 256, (3, 224, 224))}]
|
||||
inputs, data_samples = processor(data)
|
||||
data = {'inputs': [torch.randint(0, 256, (3, 224, 224))]}
|
||||
processed_data = processor(data)
|
||||
inputs = processed_data['inputs']
|
||||
self.assertTrue((inputs >= -1).all())
|
||||
self.assertTrue((inputs <= 1).all())
|
||||
self.assertNotIn('data_samples', processed_data)
|
||||
|
||||
data = {'inputs': torch.randint(0, 256, (1, 3, 224, 224))}
|
||||
inputs = processor(data)['inputs']
|
||||
self.assertTrue((inputs >= -1).all())
|
||||
self.assertTrue((inputs <= 1).all())
|
||||
self.assertIsNone(data_samples)
|
||||
|
||||
def test_batch_augmentation(self):
|
||||
cfg = dict(
|
||||
|
@ -70,17 +87,18 @@ class TestClsDataPreprocessor(TestCase):
|
|||
])
|
||||
processor: ClsDataPreprocessor = MODELS.build(cfg)
|
||||
self.assertIsInstance(processor.batch_augments, RandomBatchAugment)
|
||||
data = [{
|
||||
'inputs': torch.randint(0, 256, (3, 224, 224)),
|
||||
'data_sample': ClsDataSample().set_gt_label(1)
|
||||
}]
|
||||
_, data_samples = processor(data, training=True)
|
||||
data = {
|
||||
'inputs': [torch.randint(0, 256, (3, 224, 224))],
|
||||
'data_samples': [ClsDataSample().set_gt_label(1)]
|
||||
}
|
||||
processed_data = processor(data, training=True)
|
||||
self.assertIn('inputs', processed_data)
|
||||
self.assertIn('data_samples', processed_data)
|
||||
|
||||
cfg['batch_augments'] = None
|
||||
processor: ClsDataPreprocessor = MODELS.build(cfg)
|
||||
self.assertIsNone(processor.batch_augments)
|
||||
data = [{
|
||||
'inputs': torch.randint(0, 256, (3, 224, 224)),
|
||||
}]
|
||||
_, data_samples = processor(data, training=True)
|
||||
self.assertIsNone(data_samples)
|
||||
data = {'inputs': [torch.randint(0, 256, (3, 224, 224))]}
|
||||
processed_data = processor(data, training=True)
|
||||
self.assertIn('inputs', processed_data)
|
||||
self.assertNotIn('data_samples', processed_data)
|
||||
|
|
|
@ -3,7 +3,7 @@ from unittest import TestCase
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine.data import LabelData
|
||||
from mmengine.structures import LabelData
|
||||
|
||||
from mmcls.structures import ClsDataSample
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
|
||||
from mmcv import Config, DictAction
|
||||
from mmengine import Config, DictAction
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
|
|
@ -4,7 +4,8 @@ import fcntl
|
|||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from mmcv import Config, DictAction, track_parallel_progress, track_progress
|
||||
from mmengine import (Config, DictAction, track_parallel_progress,
|
||||
track_progress)
|
||||
|
||||
from mmcls.datasets import PIPELINES, build_dataset
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import os.path as osp
|
|||
|
||||
import mmengine
|
||||
from mmengine.config import Config, DictAction
|
||||
from mmengine.hooks import Hook
|
||||
from mmengine.runner import Runner
|
||||
|
||||
from mmcls.utils import register_all_modules
|
||||
|
@ -122,7 +123,7 @@ def main():
|
|||
|
||||
if args.out:
|
||||
|
||||
class SaveMetricHook(mmengine.Hook):
|
||||
class SaveMetricHook(Hook):
|
||||
|
||||
def after_test_epoch(self, _, metrics=None):
|
||||
if metrics is not None:
|
||||
|
|
|
@ -9,8 +9,11 @@ from unittest.mock import MagicMock
|
|||
import matplotlib.pyplot as plt
|
||||
import rich
|
||||
import torch.nn as nn
|
||||
from mmengine import Config, DictAction, Hook, Runner, Visualizer
|
||||
from mmengine.config import Config, DictAction
|
||||
from mmengine.hooks import Hook
|
||||
from mmengine.model import BaseModel
|
||||
from mmengine.runner import Runner
|
||||
from mmengine.visualization import Visualizer
|
||||
from rich.progress import BarColumn, MofNCompleteColumn, Progress, TextColumn
|
||||
|
||||
from mmcls.utils import register_all_modules
|
||||
|
@ -24,7 +27,7 @@ class SimpleModel(BaseModel):
|
|||
self.data_preprocessor = nn.Identity()
|
||||
self.conv = nn.Conv2d(1, 1, 1)
|
||||
|
||||
def forward(self, batch_inputs, data_samples, mode='tensor'):
|
||||
def forward(self, inputs, data_samples, mode='tensor'):
|
||||
pass
|
||||
|
||||
def train_step(self, data, optim_wrapper):
|
||||
|
|
Loading…
Reference in New Issue