From 3d0be5c86820faf0741912ab659ebe0d1e83cbbe Mon Sep 17 00:00:00 2001 From: hha <1286304229@qq.com> Date: Sun, 18 Sep 2022 13:17:03 +0800 Subject: [PATCH] Add UT and other code --- mmyolo/__init__.py | 39 ++++++++ mmyolo/models/__init__.py | 9 ++ .../__init__.py | 0 .../data_preprocessor.py | 0 mmyolo/models/layers/ema.py | 91 ++++++++++++++++++ mmyolo/{ => models}/losses/__init__.py | 0 mmyolo/{ => models}/losses/iou_loss.py | 0 mmyolo/models/task_modules/__init__.py | 4 + mmyolo/models/task_modules/coders/__init__.py | 6 ++ .../coders/distance_point_bbox_coder.py | 53 +++++++++++ .../task_modules/coders/yolov5_bbox_coder.py | 48 ++++++++++ .../task_modules/coders/yolox_bbox_coder.py | 38 ++++++++ mmyolo/models/utils/__init__.py | 4 + mmyolo/models/utils/misc.py | 14 +++ mmyolo/registry.py | 73 ++++++++++++++ mmyolo/version.py | 22 +++++ .../test_data_preprocessor/__init__.py | 0 .../test_data_preprocessor.py | 71 ++++++++++++++ tests/test_models/test_layers/__init__.py | 0 tests/test_models/test_layers/test_ema.py | 94 +++++++++++++++++++ .../test_layers/test_yolo_bricks.py | 23 +++++ tests/test_models/test_necks/__init__.py | 0 .../test_necks/test_yolov5_pafpn.py | 28 ++++++ .../test_necks/test_yolov6_pafpn.py | 29 ++++++ .../test_necks/test_yolox_pafpn.py | 28 ++++++ 25 files changed, 674 insertions(+) create mode 100644 mmyolo/__init__.py create mode 100644 mmyolo/models/__init__.py rename mmyolo/models/{data_preprocessor => data_preprocessors}/__init__.py (100%) rename mmyolo/models/{data_preprocessor => data_preprocessors}/data_preprocessor.py (100%) create mode 100644 mmyolo/models/layers/ema.py rename mmyolo/{ => models}/losses/__init__.py (100%) rename mmyolo/{ => models}/losses/iou_loss.py (100%) create mode 100644 mmyolo/models/task_modules/__init__.py create mode 100644 mmyolo/models/task_modules/coders/__init__.py create mode 100644 mmyolo/models/task_modules/coders/distance_point_bbox_coder.py create mode 100644 mmyolo/models/task_modules/coders/yolov5_bbox_coder.py create mode 100644 mmyolo/models/task_modules/coders/yolox_bbox_coder.py create mode 100644 mmyolo/models/utils/__init__.py create mode 100644 mmyolo/models/utils/misc.py create mode 100644 mmyolo/registry.py create mode 100644 mmyolo/version.py create mode 100644 tests/test_models/test_data_preprocessor/__init__.py create mode 100644 tests/test_models/test_data_preprocessor/test_data_preprocessor.py create mode 100644 tests/test_models/test_layers/__init__.py create mode 100644 tests/test_models/test_layers/test_ema.py create mode 100644 tests/test_models/test_layers/test_yolo_bricks.py create mode 100644 tests/test_models/test_necks/__init__.py create mode 100644 tests/test_models/test_necks/test_yolov5_pafpn.py create mode 100644 tests/test_models/test_necks/test_yolov6_pafpn.py create mode 100644 tests/test_models/test_necks/test_yolox_pafpn.py diff --git a/mmyolo/__init__.py b/mmyolo/__init__.py new file mode 100644 index 00000000..fe25528b --- /dev/null +++ b/mmyolo/__init__.py @@ -0,0 +1,39 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmcv +import mmengine +from mmengine.utils import digit_version + +import mmdet +from .version import __version__, version_info + +mmcv_minimum_version = '2.0.0rc0' +mmcv_maximum_version = '2.1.0' +mmcv_version = digit_version(mmcv.__version__) + +mmengine_minimum_version = '0.0.0' +mmengine_maximum_version = '0.2.0' +mmengine_version = digit_version(mmengine.__version__) + +mmdet_minimum_version = '3.0.0rc0' +mmdet_maximum_version = '4.0.0' +mmdet_version = digit_version(mmdet.__version__) + + +assert (mmcv_version >= digit_version(mmcv_minimum_version) + and mmcv_version < digit_version(mmcv_maximum_version)), \ + f'MMCV=={mmcv.__version__} is used but incompatible. ' \ + f'Please install mmcv>={mmcv_minimum_version}, <{mmcv_maximum_version}.' + +assert (mmengine_version >= digit_version(mmengine_minimum_version) + and mmengine_version < digit_version(mmengine_maximum_version)), \ + f'MMEngine=={mmengine.__version__} is used but incompatible. ' \ + f'Please install mmengine>={mmengine_minimum_version}, ' \ + f'<{mmengine_maximum_version}.' + +assert (mmdet_version >= digit_version(mmdet_minimum_version) + and mmdet_version < digit_version(mmdet_maximum_version)), \ + f'MMDetection=={mmdet.__version__} is used but incompatible. ' \ + f'Please install mmdet>={mmdet_minimum_version}, ' \ + f'<{mmdet_maximum_version}.' + +__all__ = ['__version__', 'version_info', 'digit_version'] diff --git a/mmyolo/models/__init__.py b/mmyolo/models/__init__.py new file mode 100644 index 00000000..b290017a --- /dev/null +++ b/mmyolo/models/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .backbones import * # noqa: F401,F403 +from .data_preprocessors import * # noqa: F401,F403 +from .dense_heads import * # noqa: F401,F403 +from .detectors import * # noqa: F401,F403 +from .layers import * # noqa: F401,F403 +from .losses import * # noqa: F401,F403 +from .necks import * # noqa: F401,F403 +from .task_modules import * # noqa: F401,F403 diff --git a/mmyolo/models/data_preprocessor/__init__.py b/mmyolo/models/data_preprocessors/__init__.py similarity index 100% rename from mmyolo/models/data_preprocessor/__init__.py rename to mmyolo/models/data_preprocessors/__init__.py diff --git a/mmyolo/models/data_preprocessor/data_preprocessor.py b/mmyolo/models/data_preprocessors/data_preprocessor.py similarity index 100% rename from mmyolo/models/data_preprocessor/data_preprocessor.py rename to mmyolo/models/data_preprocessors/data_preprocessor.py diff --git a/mmyolo/models/layers/ema.py b/mmyolo/models/layers/ema.py new file mode 100644 index 00000000..2f14dbc5 --- /dev/null +++ b/mmyolo/models/layers/ema.py @@ -0,0 +1,91 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Optional + +import torch +import torch.nn as nn +from torch import Tensor + +from mmdet.models.layers import ExpMomentumEMA as MMDET_ExpMomentumEMA +from mmyolo.registry import MODELS + + +@MODELS.register_module() +class ExpMomentumEMA(MMDET_ExpMomentumEMA): + """Exponential moving average (EMA) with exponential momentum strategy, + which is used in YOLO. + + Args: + model (nn.Module): The model to be averaged. + momentum (float): The momentum used for updating ema parameter. + Ema's parameters are updated with the formula: + `averaged_param = (1-momentum) * averaged_param + momentum * + source_param`. Defaults to 0.0002. + gamma (int): Use a larger momentum early in training and gradually + annealing to a smaller value to update the ema model smoothly. The + momentum is calculated as + `(1 - momentum) * exp(-(1 + steps) / gamma) + momentum`. + Defaults to 2000. + interval (int): Interval between two updates. Defaults to 1. + device (torch.device, optional): If provided, the averaged model will + be stored on the :attr:`device`. Defaults to None. + update_buffers (bool): if True, it will compute running averages for + both the parameters and the buffers of the model. Defaults to + False. + """ + + def __init__(self, + model: nn.Module, + momentum: float = 0.0002, + gamma: int = 2000, + interval=1, + device: Optional[torch.device] = None, + update_buffers: bool = False) -> None: + super().__init__( + model=model, + momentum=momentum, + interval=interval, + device=device, + update_buffers=update_buffers) + assert gamma > 0, f'gamma must be greater than 0, but got {gamma}' + self.gamma = gamma + + # Note: There is no need to re-fetch every update, + # as most models do not change their structure + # during the training process. + self.src_parameters = ( + model.state_dict() + if self.update_buffers else dict(model.named_parameters())) + if not self.update_buffers: + self.src_buffers = model.buffers() + + def avg_func(self, averaged_param: Tensor, source_param: Tensor, + steps: int): + """Compute the moving average of the parameters using the exponential + momentum strategy. + + Args: + averaged_param (Tensor): The averaged parameters. + source_param (Tensor): The source parameters. + steps (int): The number of times the parameters have been + updated. + """ + momentum = (1 - self.momentum) * math.exp( + -float(1 + steps) / self.gamma) + self.momentum + averaged_param.lerp_(source_param, momentum) + + def update_parameters(self, model: nn.Module): + if self.steps == 0: + for k, p_avg in self.avg_parameters.items(): + p_avg.data.copy_(self.src_parameters[k].data) + elif self.steps % self.interval == 0: + for k, p_avg in self.avg_parameters.items(): + if p_avg.dtype.is_floating_point: + self.avg_func(p_avg.data, self.src_parameters[k].data, + self.steps) + if not self.update_buffers: + # If not update the buffers, + # keep the buffers in sync with the source model. + for b_avg, b_src in zip(self.module.buffers(), self.src_buffers): + b_avg.data.copy_(b_src.data) + self.steps += 1 diff --git a/mmyolo/losses/__init__.py b/mmyolo/models/losses/__init__.py similarity index 100% rename from mmyolo/losses/__init__.py rename to mmyolo/models/losses/__init__.py diff --git a/mmyolo/losses/iou_loss.py b/mmyolo/models/losses/iou_loss.py similarity index 100% rename from mmyolo/losses/iou_loss.py rename to mmyolo/models/losses/iou_loss.py diff --git a/mmyolo/models/task_modules/__init__.py b/mmyolo/models/task_modules/__init__.py new file mode 100644 index 00000000..20cc2ad1 --- /dev/null +++ b/mmyolo/models/task_modules/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .coders import YOLOv5BBoxCoder, YOLOXBBoxCoder + +__all__ = ['YOLOv5BBoxCoder', 'YOLOXBBoxCoder'] diff --git a/mmyolo/models/task_modules/coders/__init__.py b/mmyolo/models/task_modules/coders/__init__.py new file mode 100644 index 00000000..6346387c --- /dev/null +++ b/mmyolo/models/task_modules/coders/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .distance_point_bbox_coder import DistancePointBBoxCoder +from .yolov5_bbox_coder import YOLOv5BBoxCoder +from .yolox_bbox_coder import YOLOXBBoxCoder + +__all__ = ['YOLOv5BBoxCoder', 'YOLOXBBoxCoder', 'DistancePointBBoxCoder'] diff --git a/mmyolo/models/task_modules/coders/distance_point_bbox_coder.py b/mmyolo/models/task_modules/coders/distance_point_bbox_coder.py new file mode 100644 index 00000000..44cae094 --- /dev/null +++ b/mmyolo/models/task_modules/coders/distance_point_bbox_coder.py @@ -0,0 +1,53 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence, Union + +import torch + +from mmdet.models.task_modules.coders import \ + DistancePointBBoxCoder as MMDET_DistancePointBBoxCoder +from mmdet.structures.bbox import distance2bbox +from mmyolo.registry import TASK_UTILS + + +@TASK_UTILS.register_module() +class DistancePointBBoxCoder(MMDET_DistancePointBBoxCoder): + """Distance Point BBox coder. + + This coder encodes gt bboxes (x1, y1, x2, y2) into (top, bottom, left, + right) and decode it back to the original. + """ + + def decode( + self, + points: torch.Tensor, + pred_bboxes: torch.Tensor, + stride: torch.Tensor, + max_shape: Optional[Union[Sequence[int], torch.Tensor, + Sequence[Sequence[int]]]] = None + ) -> torch.Tensor: + """Decode distance prediction to bounding box. + + Args: + points (Tensor): Shape (B, N, 2) or (N, 2). + pred_bboxes (Tensor): Distance from the given point to 4 + boundaries (left, top, right, bottom). Shape (B, N, 4) + or (N, 4) + stride (Tensor): Featmap stride. + max_shape (Sequence[int] or torch.Tensor or Sequence[ + Sequence[int]],optional): Maximum bounds for boxes, specifies + (H, W, C) or (H, W). If priors shape is (B, N, 4), then + the max_shape should be a Sequence[Sequence[int]], + and the length of max_shape should also be B. + Default None. + Returns: + Tensor: Boxes with shape (N, 4) or (B, N, 4) + """ + assert points.size(-2) == pred_bboxes.size(-2) + assert points.size(-1) == 2 + assert pred_bboxes.size(-1) == 4 + if self.clip_border is False: + max_shape = None + + pred_bboxes = pred_bboxes * stride[None, :, None] + + return distance2bbox(points, pred_bboxes, max_shape) diff --git a/mmyolo/models/task_modules/coders/yolov5_bbox_coder.py b/mmyolo/models/task_modules/coders/yolov5_bbox_coder.py new file mode 100644 index 00000000..0046df9c --- /dev/null +++ b/mmyolo/models/task_modules/coders/yolov5_bbox_coder.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Union + +import torch + +from mmdet.models.task_modules.coders.base_bbox_coder import BaseBBoxCoder +from mmyolo.registry import TASK_UTILS + + +@TASK_UTILS.register_module() +class YOLOv5BBoxCoder(BaseBBoxCoder): + + def encode(self, **kwargs): + pass + + def decode(self, priors: torch.Tensor, pred_bboxes: torch.Tensor, + stride: Union[torch.Tensor, int]) -> torch.Tensor: + """Apply transformation `pred_bboxes` to `decoded_bboxes`. + + Args: + priors (torch.Tensor): Basic boxes or points, e.g. anchors. + pred_bboxes (torch.Tensor): Encoded boxes with shape + stride (torch.Tensor | int): Strides of bboxes. + + Returns: + torch.Tensor: Decoded boxes. + """ + assert pred_bboxes.size(-1) == priors.size(-1) == 4 + + pred_bboxes = pred_bboxes.sigmoid() + + x_center = (priors[..., 0] + priors[..., 2]) * 0.5 + y_center = (priors[..., 1] + priors[..., 3]) * 0.5 + w = priors[..., 2] - priors[..., 0] + h = priors[..., 3] - priors[..., 1] + + # The anchor of mmdet has been offset by 0.5 + x_center_pred = (pred_bboxes[..., 0] - 0.5) * 2 * stride + x_center + y_center_pred = (pred_bboxes[..., 1] - 0.5) * 2 * stride + y_center + w_pred = (pred_bboxes[..., 2] * 2)**2 * w + h_pred = (pred_bboxes[..., 3] * 2)**2 * h + + decoded_bboxes = torch.stack( + (x_center_pred - w_pred / 2, y_center_pred - h_pred / 2, + x_center_pred + w_pred / 2, y_center_pred + h_pred / 2), + dim=-1) + + return decoded_bboxes diff --git a/mmyolo/models/task_modules/coders/yolox_bbox_coder.py b/mmyolo/models/task_modules/coders/yolox_bbox_coder.py new file mode 100644 index 00000000..620ef4d8 --- /dev/null +++ b/mmyolo/models/task_modules/coders/yolox_bbox_coder.py @@ -0,0 +1,38 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Union + +import torch + +from mmdet.models.task_modules.coders.base_bbox_coder import BaseBBoxCoder +from mmyolo.registry import TASK_UTILS + + +@TASK_UTILS.register_module() +class YOLOXBBoxCoder(BaseBBoxCoder): + + def encode(self, **kwargs): + pass + + def decode(self, priors: torch.Tensor, pred_bboxes: torch.Tensor, + stride: Union[torch.Tensor, int]) -> torch.Tensor: + """Apply transformation `pred_bboxes` to `decoded_bboxes`. + + Args: + priors (torch.Tensor): Basic boxes or points, e.g. anchors. + pred_bboxes (torch.Tensor): Encoded boxes with shape + stride (torch.Tensor | int): Strides of bboxes. + + Returns: + torch.Tensor: Decoded boxes. + """ + stride = stride[None, :, None] + xys = (pred_bboxes[..., :2] * stride) + priors + whs = pred_bboxes[..., 2:].exp() * stride + + tl_x = (xys[..., 0] - whs[..., 0] / 2) + tl_y = (xys[..., 1] - whs[..., 1] / 2) + br_x = (xys[..., 0] + whs[..., 0] / 2) + br_y = (xys[..., 1] + whs[..., 1] / 2) + + decoded_bboxes = torch.stack([tl_x, tl_y, br_x, br_y], -1) + return decoded_bboxes diff --git a/mmyolo/models/utils/__init__.py b/mmyolo/models/utils/__init__.py new file mode 100644 index 00000000..89118283 --- /dev/null +++ b/mmyolo/models/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .misc import make_divisible, make_round + +__all__ = ['make_divisible', 'make_round'] diff --git a/mmyolo/models/utils/misc.py b/mmyolo/models/utils/misc.py new file mode 100644 index 00000000..6844ad37 --- /dev/null +++ b/mmyolo/models/utils/misc.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + + +def make_divisible(x: float, + widen_factor: float = 1.0, + divisor: int = 8) -> int: + """Make sure that x*widen_factor is divisible by divisor.""" + return math.ceil(x * widen_factor / divisor) * divisor + + +def make_round(x: float, deepen_factor: float = 1.0) -> int: + """Make sure that x*deepen_factor becomes an integer not less than 1.""" + return max(round(x * deepen_factor), 1) if x > 1 else x diff --git a/mmyolo/registry.py b/mmyolo/registry.py new file mode 100644 index 00000000..63967d86 --- /dev/null +++ b/mmyolo/registry.py @@ -0,0 +1,73 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""MMYOLO provides 17 registry nodes to support using modules across projects. +Each node is a child of the root registry in MMEngine. + +More details can be found at +https://mmengine.readthedocs.io/en/latest/tutorials/registry.html. +""" + +from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS +from mmengine.registry import DATASETS as MMENGINE_DATASETS +from mmengine.registry import HOOKS as MMENGINE_HOOKS +from mmengine.registry import LOOPS as MMENGINE_LOOPS +from mmengine.registry import METRICS as MMENGINE_METRICS +from mmengine.registry import MODEL_WRAPPERS as MMENGINE_MODEL_WRAPPERS +from mmengine.registry import MODELS as MMENGINE_MODELS +from mmengine.registry import \ + OPTIM_WRAPPER_CONSTRUCTORS as MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS +from mmengine.registry import OPTIM_WRAPPERS as MMENGINE_OPTIM_WRAPPERS +from mmengine.registry import OPTIMIZERS as MMENGINE_OPTIMIZERS +from mmengine.registry import PARAM_SCHEDULERS as MMENGINE_PARAM_SCHEDULERS +from mmengine.registry import \ + RUNNER_CONSTRUCTORS as MMENGINE_RUNNER_CONSTRUCTORS +from mmengine.registry import RUNNERS as MMENGINE_RUNNERS +from mmengine.registry import TASK_UTILS as MMENGINE_TASK_UTILS +from mmengine.registry import TRANSFORMS as MMENGINE_TRANSFORMS +from mmengine.registry import VISBACKENDS as MMENGINE_VISBACKENDS +from mmengine.registry import VISUALIZERS as MMENGINE_VISUALIZERS +from mmengine.registry import \ + WEIGHT_INITIALIZERS as MMENGINE_WEIGHT_INITIALIZERS +from mmengine.registry import Registry + +# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner` +RUNNERS = Registry('runner', parent=MMENGINE_RUNNERS) +# manage runner constructors that define how to initialize runners +RUNNER_CONSTRUCTORS = Registry( + 'runner constructor', parent=MMENGINE_RUNNER_CONSTRUCTORS) +# manage all kinds of loops like `EpochBasedTrainLoop` +LOOPS = Registry('loop', parent=MMENGINE_LOOPS) +# manage all kinds of hooks like `CheckpointHook` +HOOKS = Registry('hook', parent=MMENGINE_HOOKS) + +# manage data-related modules +DATASETS = Registry('dataset', parent=MMENGINE_DATASETS) +DATA_SAMPLERS = Registry('data sampler', parent=MMENGINE_DATA_SAMPLERS) +TRANSFORMS = Registry('transform', parent=MMENGINE_TRANSFORMS) + +# manage all kinds of modules inheriting `nn.Module` +MODELS = Registry('model', parent=MMENGINE_MODELS) +# manage all kinds of model wrappers like 'MMDistributedDataParallel' +MODEL_WRAPPERS = Registry('model_wrapper', parent=MMENGINE_MODEL_WRAPPERS) +# manage all kinds of weight initialization modules like `Uniform` +WEIGHT_INITIALIZERS = Registry( + 'weight initializer', parent=MMENGINE_WEIGHT_INITIALIZERS) + +# manage all kinds of optimizers like `SGD` and `Adam` +OPTIMIZERS = Registry('optimizer', parent=MMENGINE_OPTIMIZERS) +OPTIM_WRAPPERS = Registry('optim_wrapper', parent=MMENGINE_OPTIM_WRAPPERS) +# manage constructors that customize the optimization hyperparameters. +OPTIM_WRAPPER_CONSTRUCTORS = Registry( + 'optimizer constructor', parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS) +# manage all kinds of parameter schedulers like `MultiStepLR` +PARAM_SCHEDULERS = Registry( + 'parameter scheduler', parent=MMENGINE_PARAM_SCHEDULERS) +# manage all kinds of metrics +METRICS = Registry('metric', parent=MMENGINE_METRICS) + +# manage task-specific modules like anchor generators and box coders +TASK_UTILS = Registry('task util', parent=MMENGINE_TASK_UTILS) + +# manage visualizer +VISUALIZERS = Registry('visualizer', parent=MMENGINE_VISUALIZERS) +# manage visualizer backend +VISBACKENDS = Registry('vis_backend', parent=MMENGINE_VISBACKENDS) diff --git a/mmyolo/version.py b/mmyolo/version.py new file mode 100644 index 00000000..62851b55 --- /dev/null +++ b/mmyolo/version.py @@ -0,0 +1,22 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +__version__ = '0.0.1' + +from typing import Tuple + +short_version = __version__ + + +def parse_version_info(version_str: str) -> Tuple: + version_info = [] + for x in version_str.split('.'): + if x.isdigit(): + version_info.append(int(x)) + elif x.find('rc') != -1: + patch_version = x.split('rc') + version_info.append(int(patch_version[0])) + version_info.append(f'rc{patch_version[1]}') + return tuple(version_info) + + +version_info = parse_version_info(__version__) diff --git a/tests/test_models/test_data_preprocessor/__init__.py b/tests/test_models/test_data_preprocessor/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_models/test_data_preprocessor/test_data_preprocessor.py b/tests/test_models/test_data_preprocessor/test_data_preprocessor.py new file mode 100644 index 00000000..85cdb742 --- /dev/null +++ b/tests/test_models/test_data_preprocessor/test_data_preprocessor.py @@ -0,0 +1,71 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch +from mmdet.structures import DetDataSample + +from mmyolo.models.data_preprocessors import YOLOv5DetDataPreprocessor + + +class TestYOLOv5DetDataPreprocessor(TestCase): + + def test_forward(self): + processor = YOLOv5DetDataPreprocessor(mean=[0, 0, 0], std=[1, 1, 1]) + + data = { + 'inputs': [torch.randint(0, 256, (3, 11, 10))], + 'data_samples': [DetDataSample()] + } + out_data = processor(data, training=False) + batch_inputs, batch_data_samples = out_data['inputs'], out_data[ + 'data_samples'] + + self.assertEqual(batch_inputs.shape, (1, 3, 11, 10)) + self.assertEqual(len(batch_data_samples), 1) + + # test channel_conversion + processor = YOLOv5DetDataPreprocessor( + mean=[0., 0., 0.], std=[1., 1., 1.], bgr_to_rgb=True) + out_data = processor(data, training=False) + batch_inputs, batch_data_samples = out_data['inputs'], out_data[ + 'data_samples'] + self.assertEqual(batch_inputs.shape, (1, 3, 11, 10)) + self.assertEqual(len(batch_data_samples), 1) + + # test padding, training=False + data = { + 'inputs': [ + torch.randint(0, 256, (3, 10, 11)), + torch.randint(0, 256, (3, 9, 14)) + ] + } + processor = YOLOv5DetDataPreprocessor( + mean=[0., 0., 0.], std=[1., 1., 1.], bgr_to_rgb=True) + out_data = processor(data, training=False) + batch_inputs, batch_data_samples = out_data['inputs'], out_data[ + 'data_samples'] + self.assertEqual(batch_inputs.shape, (2, 3, 10, 14)) + self.assertIsNone(batch_data_samples) + + # test training + data = { + 'inputs': torch.randint(0, 256, (2, 3, 10, 11)), + 'data_samples': torch.randint(0, 11, (18, 6)), + } + out_data = processor(data, training=True) + batch_inputs, batch_data_samples = out_data['inputs'], out_data[ + 'data_samples'] + self.assertIn('img_metas', batch_data_samples) + self.assertIn('bboxes_labels', batch_data_samples) + self.assertEqual(batch_inputs.shape, (2, 3, 10, 11)) + self.assertIsInstance(batch_data_samples['bboxes_labels'], + torch.Tensor) + self.assertIsInstance(batch_data_samples['img_metas'], list) + + data = { + 'inputs': [torch.randint(0, 256, (3, 11, 10))], + 'data_samples': [DetDataSample()] + } + # data_samples must be tensor + with self.assertRaises(AssertionError): + processor(data, training=True) diff --git a/tests/test_models/test_layers/__init__.py b/tests/test_models/test_layers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_models/test_layers/test_ema.py b/tests/test_models/test_layers/test_ema.py new file mode 100644 index 00000000..b3583828 --- /dev/null +++ b/tests/test_models/test_layers/test_ema.py @@ -0,0 +1,94 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import itertools +import math +from unittest import TestCase + +import torch +import torch.nn as nn +from mmengine.testing import assert_allclose + +from mmyolo.models.layers import ExpMomentumEMA + + +class TestEMA(TestCase): + + def test_exp_momentum_ema(self): + model = nn.Sequential(nn.Conv2d(1, 5, kernel_size=3), nn.Linear(5, 10)) + # Test invalid gamma + with self.assertRaisesRegex(AssertionError, + 'gamma must be greater than 0'): + ExpMomentumEMA(model, gamma=-1) + + # Test EMA + model = torch.nn.Sequential( + torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)) + momentum = 0.1 + gamma = 4 + + ema_model = ExpMomentumEMA(model, momentum=momentum, gamma=gamma) + averaged_params = [ + torch.zeros_like(param) for param in model.parameters() + ] + n_updates = 10 + for i in range(n_updates): + updated_averaged_params = [] + for p, p_avg in zip(model.parameters(), averaged_params): + p.detach().add_(torch.randn_like(p)) + if i == 0: + updated_averaged_params.append(p.clone()) + else: + m = (1 - momentum) * math.exp(-(1 + i) / gamma) + momentum + updated_averaged_params.append( + (p_avg * (1 - m) + p * m).clone()) + ema_model.update_parameters(model) + averaged_params = updated_averaged_params + + for p_target, p_ema in zip(averaged_params, ema_model.parameters()): + assert_allclose(p_target, p_ema) + + def test_exp_momentum_ema_update_buffer(self): + model = nn.Sequential( + nn.Conv2d(1, 5, kernel_size=3), nn.BatchNorm2d(5, momentum=0.3), + nn.Linear(5, 10)) + # Test invalid gamma + with self.assertRaisesRegex(AssertionError, + 'gamma must be greater than 0'): + ExpMomentumEMA(model, gamma=-1) + + # Test EMA with momentum annealing. + momentum = 0.1 + gamma = 4 + + ema_model = ExpMomentumEMA( + model, gamma=gamma, momentum=momentum, update_buffers=True) + averaged_params = [ + torch.zeros_like(param) + for param in itertools.chain(model.parameters(), model.buffers()) + if param.size() != torch.Size([]) + ] + n_updates = 10 + for i in range(n_updates): + updated_averaged_params = [] + params = [ + param for param in itertools.chain(model.parameters(), + model.buffers()) + if param.size() != torch.Size([]) + ] + for p, p_avg in zip(params, averaged_params): + p.detach().add_(torch.randn_like(p)) + if i == 0: + updated_averaged_params.append(p.clone()) + else: + m = (1 - momentum) * math.exp(-(1 + i) / gamma) + momentum + updated_averaged_params.append( + (p_avg * (1 - m) + p * m).clone()) + ema_model.update_parameters(model) + averaged_params = updated_averaged_params + + ema_params = [ + param for param in itertools.chain(ema_model.module.parameters(), + ema_model.module.buffers()) + if param.size() != torch.Size([]) + ] + for p_target, p_ema in zip(averaged_params, ema_params): + assert_allclose(p_target, p_ema) diff --git a/tests/test_models/test_layers/test_yolo_bricks.py b/tests/test_models/test_layers/test_yolo_bricks.py new file mode 100644 index 00000000..0bc35a11 --- /dev/null +++ b/tests/test_models/test_layers/test_yolo_bricks.py @@ -0,0 +1,23 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from unittest import TestCase + +import torch + +from mmyolo.models.layers import SPPFBottleneck +from mmyolo.utils import register_all_modules + +register_all_modules() + + +class TestSPPFBottleneck(TestCase): + + def test_forward(self): + input_tensor = torch.randn((1, 3, 20, 20)) + bottleneck = SPPFBottleneck(3, 16) + out_tensor = bottleneck(input_tensor) + self.assertEqual(out_tensor.shape, (1, 16, 20, 20)) + + bottleneck = SPPFBottleneck(3, 16, kernel_sizes=[3, 5, 7]) + out_tensor = bottleneck(input_tensor) + self.assertEqual(out_tensor.shape, (1, 16, 20, 20)) diff --git a/tests/test_models/test_necks/__init__.py b/tests/test_models/test_necks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_models/test_necks/test_yolov5_pafpn.py b/tests/test_models/test_necks/test_yolov5_pafpn.py new file mode 100644 index 00000000..339621ec --- /dev/null +++ b/tests/test_models/test_necks/test_yolov5_pafpn.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch + +from mmyolo.models.necks import YOLOv5PAFPN +from mmyolo.utils import register_all_modules + +register_all_modules() + + +class TestYOLOv5PAFPN(TestCase): + + def test_forward(self): + s = 64 + in_channels = [8, 16, 32] + feat_sizes = [s // 2**i for i in range(4)] # [32, 16, 8] + out_channels = [8, 16, 32] + feats = [ + torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i]) + for i in range(len(in_channels)) + ] + neck = YOLOv5PAFPN(in_channels=in_channels, out_channels=out_channels) + outs = neck(feats) + assert len(outs) == len(feats) + for i in range(len(feats)): + assert outs[i].shape[1] == out_channels[i] + assert outs[i].shape[2] == outs[i].shape[3] == s // (2**i) diff --git a/tests/test_models/test_necks/test_yolov6_pafpn.py b/tests/test_models/test_necks/test_yolov6_pafpn.py new file mode 100644 index 00000000..ae09f6ac --- /dev/null +++ b/tests/test_models/test_necks/test_yolov6_pafpn.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch + +from mmyolo.models.necks import YOLOv6RepPAFPN +from mmyolo.utils import register_all_modules + +register_all_modules() + + +class TestYOLOv6RepPAFPN(TestCase): + + def test_forward(self): + s = 64 + in_channels = [8, 16, 32] + feat_sizes = [s // 2**i for i in range(4)] # [32, 16, 8] + out_channels = [8, 16, 32] + feats = [ + torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i]) + for i in range(len(in_channels)) + ] + neck = YOLOv6RepPAFPN( + in_channels=in_channels, out_channels=out_channels) + outs = neck(feats) + assert len(outs) == len(feats) + for i in range(len(feats)): + assert outs[i].shape[1] == out_channels[i] + assert outs[i].shape[2] == outs[i].shape[3] == s // (2**i) diff --git a/tests/test_models/test_necks/test_yolox_pafpn.py b/tests/test_models/test_necks/test_yolox_pafpn.py new file mode 100644 index 00000000..25fe67a1 --- /dev/null +++ b/tests/test_models/test_necks/test_yolox_pafpn.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch + +from mmyolo.models.necks import YOLOXPAFPN +from mmyolo.utils import register_all_modules + +register_all_modules() + + +class TestYOLOXPAFPN(TestCase): + + def test_forward(self): + s = 64 + in_channels = [8, 16, 32] + feat_sizes = [s // 2**i for i in range(4)] # [32, 16, 8] + out_channels = 24 + feats = [ + torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i]) + for i in range(len(in_channels)) + ] + neck = YOLOXPAFPN(in_channels=in_channels, out_channels=out_channels) + outs = neck(feats) + assert len(outs) == len(feats) + for i in range(len(feats)): + assert outs[i].shape[1] == out_channels + assert outs[i].shape[2] == outs[i].shape[3] == s // (2**i)