Add UT and other code

pull/23/head
hha 2022-09-18 13:17:03 +08:00
parent 7e90343d85
commit 3d0be5c868
25 changed files with 674 additions and 0 deletions

39
mmyolo/__init__.py 100644
View File

@ -0,0 +1,39 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import mmengine
from mmengine.utils import digit_version
import mmdet
from .version import __version__, version_info
mmcv_minimum_version = '2.0.0rc0'
mmcv_maximum_version = '2.1.0'
mmcv_version = digit_version(mmcv.__version__)
mmengine_minimum_version = '0.0.0'
mmengine_maximum_version = '0.2.0'
mmengine_version = digit_version(mmengine.__version__)
mmdet_minimum_version = '3.0.0rc0'
mmdet_maximum_version = '4.0.0'
mmdet_version = digit_version(mmdet.__version__)
assert (mmcv_version >= digit_version(mmcv_minimum_version)
and mmcv_version < digit_version(mmcv_maximum_version)), \
f'MMCV=={mmcv.__version__} is used but incompatible. ' \
f'Please install mmcv>={mmcv_minimum_version}, <{mmcv_maximum_version}.'
assert (mmengine_version >= digit_version(mmengine_minimum_version)
and mmengine_version < digit_version(mmengine_maximum_version)), \
f'MMEngine=={mmengine.__version__} is used but incompatible. ' \
f'Please install mmengine>={mmengine_minimum_version}, ' \
f'<{mmengine_maximum_version}.'
assert (mmdet_version >= digit_version(mmdet_minimum_version)
and mmdet_version < digit_version(mmdet_maximum_version)), \
f'MMDetection=={mmdet.__version__} is used but incompatible. ' \
f'Please install mmdet>={mmdet_minimum_version}, ' \
f'<{mmdet_maximum_version}.'
__all__ = ['__version__', 'version_info', 'digit_version']

View File

@ -0,0 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .backbones import * # noqa: F401,F403
from .data_preprocessors import * # noqa: F401,F403
from .dense_heads import * # noqa: F401,F403
from .detectors import * # noqa: F401,F403
from .layers import * # noqa: F401,F403
from .losses import * # noqa: F401,F403
from .necks import * # noqa: F401,F403
from .task_modules import * # noqa: F401,F403

View File

@ -0,0 +1,91 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Optional
import torch
import torch.nn as nn
from torch import Tensor
from mmdet.models.layers import ExpMomentumEMA as MMDET_ExpMomentumEMA
from mmyolo.registry import MODELS
@MODELS.register_module()
class ExpMomentumEMA(MMDET_ExpMomentumEMA):
"""Exponential moving average (EMA) with exponential momentum strategy,
which is used in YOLO.
Args:
model (nn.Module): The model to be averaged.
momentum (float): The momentum used for updating ema parameter.
Ema's parameters are updated with the formula:
`averaged_param = (1-momentum) * averaged_param + momentum *
source_param`. Defaults to 0.0002.
gamma (int): Use a larger momentum early in training and gradually
annealing to a smaller value to update the ema model smoothly. The
momentum is calculated as
`(1 - momentum) * exp(-(1 + steps) / gamma) + momentum`.
Defaults to 2000.
interval (int): Interval between two updates. Defaults to 1.
device (torch.device, optional): If provided, the averaged model will
be stored on the :attr:`device`. Defaults to None.
update_buffers (bool): if True, it will compute running averages for
both the parameters and the buffers of the model. Defaults to
False.
"""
def __init__(self,
model: nn.Module,
momentum: float = 0.0002,
gamma: int = 2000,
interval=1,
device: Optional[torch.device] = None,
update_buffers: bool = False) -> None:
super().__init__(
model=model,
momentum=momentum,
interval=interval,
device=device,
update_buffers=update_buffers)
assert gamma > 0, f'gamma must be greater than 0, but got {gamma}'
self.gamma = gamma
# Note: There is no need to re-fetch every update,
# as most models do not change their structure
# during the training process.
self.src_parameters = (
model.state_dict()
if self.update_buffers else dict(model.named_parameters()))
if not self.update_buffers:
self.src_buffers = model.buffers()
def avg_func(self, averaged_param: Tensor, source_param: Tensor,
steps: int):
"""Compute the moving average of the parameters using the exponential
momentum strategy.
Args:
averaged_param (Tensor): The averaged parameters.
source_param (Tensor): The source parameters.
steps (int): The number of times the parameters have been
updated.
"""
momentum = (1 - self.momentum) * math.exp(
-float(1 + steps) / self.gamma) + self.momentum
averaged_param.lerp_(source_param, momentum)
def update_parameters(self, model: nn.Module):
if self.steps == 0:
for k, p_avg in self.avg_parameters.items():
p_avg.data.copy_(self.src_parameters[k].data)
elif self.steps % self.interval == 0:
for k, p_avg in self.avg_parameters.items():
if p_avg.dtype.is_floating_point:
self.avg_func(p_avg.data, self.src_parameters[k].data,
self.steps)
if not self.update_buffers:
# If not update the buffers,
# keep the buffers in sync with the source model.
for b_avg, b_src in zip(self.module.buffers(), self.src_buffers):
b_avg.data.copy_(b_src.data)
self.steps += 1

View File

@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .coders import YOLOv5BBoxCoder, YOLOXBBoxCoder
__all__ = ['YOLOv5BBoxCoder', 'YOLOXBBoxCoder']

View File

@ -0,0 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .distance_point_bbox_coder import DistancePointBBoxCoder
from .yolov5_bbox_coder import YOLOv5BBoxCoder
from .yolox_bbox_coder import YOLOXBBoxCoder
__all__ = ['YOLOv5BBoxCoder', 'YOLOXBBoxCoder', 'DistancePointBBoxCoder']

View File

@ -0,0 +1,53 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Union
import torch
from mmdet.models.task_modules.coders import \
DistancePointBBoxCoder as MMDET_DistancePointBBoxCoder
from mmdet.structures.bbox import distance2bbox
from mmyolo.registry import TASK_UTILS
@TASK_UTILS.register_module()
class DistancePointBBoxCoder(MMDET_DistancePointBBoxCoder):
"""Distance Point BBox coder.
This coder encodes gt bboxes (x1, y1, x2, y2) into (top, bottom, left,
right) and decode it back to the original.
"""
def decode(
self,
points: torch.Tensor,
pred_bboxes: torch.Tensor,
stride: torch.Tensor,
max_shape: Optional[Union[Sequence[int], torch.Tensor,
Sequence[Sequence[int]]]] = None
) -> torch.Tensor:
"""Decode distance prediction to bounding box.
Args:
points (Tensor): Shape (B, N, 2) or (N, 2).
pred_bboxes (Tensor): Distance from the given point to 4
boundaries (left, top, right, bottom). Shape (B, N, 4)
or (N, 4)
stride (Tensor): Featmap stride.
max_shape (Sequence[int] or torch.Tensor or Sequence[
Sequence[int]],optional): Maximum bounds for boxes, specifies
(H, W, C) or (H, W). If priors shape is (B, N, 4), then
the max_shape should be a Sequence[Sequence[int]],
and the length of max_shape should also be B.
Default None.
Returns:
Tensor: Boxes with shape (N, 4) or (B, N, 4)
"""
assert points.size(-2) == pred_bboxes.size(-2)
assert points.size(-1) == 2
assert pred_bboxes.size(-1) == 4
if self.clip_border is False:
max_shape = None
pred_bboxes = pred_bboxes * stride[None, :, None]
return distance2bbox(points, pred_bboxes, max_shape)

View File

@ -0,0 +1,48 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union
import torch
from mmdet.models.task_modules.coders.base_bbox_coder import BaseBBoxCoder
from mmyolo.registry import TASK_UTILS
@TASK_UTILS.register_module()
class YOLOv5BBoxCoder(BaseBBoxCoder):
def encode(self, **kwargs):
pass
def decode(self, priors: torch.Tensor, pred_bboxes: torch.Tensor,
stride: Union[torch.Tensor, int]) -> torch.Tensor:
"""Apply transformation `pred_bboxes` to `decoded_bboxes`.
Args:
priors (torch.Tensor): Basic boxes or points, e.g. anchors.
pred_bboxes (torch.Tensor): Encoded boxes with shape
stride (torch.Tensor | int): Strides of bboxes.
Returns:
torch.Tensor: Decoded boxes.
"""
assert pred_bboxes.size(-1) == priors.size(-1) == 4
pred_bboxes = pred_bboxes.sigmoid()
x_center = (priors[..., 0] + priors[..., 2]) * 0.5
y_center = (priors[..., 1] + priors[..., 3]) * 0.5
w = priors[..., 2] - priors[..., 0]
h = priors[..., 3] - priors[..., 1]
# The anchor of mmdet has been offset by 0.5
x_center_pred = (pred_bboxes[..., 0] - 0.5) * 2 * stride + x_center
y_center_pred = (pred_bboxes[..., 1] - 0.5) * 2 * stride + y_center
w_pred = (pred_bboxes[..., 2] * 2)**2 * w
h_pred = (pred_bboxes[..., 3] * 2)**2 * h
decoded_bboxes = torch.stack(
(x_center_pred - w_pred / 2, y_center_pred - h_pred / 2,
x_center_pred + w_pred / 2, y_center_pred + h_pred / 2),
dim=-1)
return decoded_bboxes

View File

@ -0,0 +1,38 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union
import torch
from mmdet.models.task_modules.coders.base_bbox_coder import BaseBBoxCoder
from mmyolo.registry import TASK_UTILS
@TASK_UTILS.register_module()
class YOLOXBBoxCoder(BaseBBoxCoder):
def encode(self, **kwargs):
pass
def decode(self, priors: torch.Tensor, pred_bboxes: torch.Tensor,
stride: Union[torch.Tensor, int]) -> torch.Tensor:
"""Apply transformation `pred_bboxes` to `decoded_bboxes`.
Args:
priors (torch.Tensor): Basic boxes or points, e.g. anchors.
pred_bboxes (torch.Tensor): Encoded boxes with shape
stride (torch.Tensor | int): Strides of bboxes.
Returns:
torch.Tensor: Decoded boxes.
"""
stride = stride[None, :, None]
xys = (pred_bboxes[..., :2] * stride) + priors
whs = pred_bboxes[..., 2:].exp() * stride
tl_x = (xys[..., 0] - whs[..., 0] / 2)
tl_y = (xys[..., 1] - whs[..., 1] / 2)
br_x = (xys[..., 0] + whs[..., 0] / 2)
br_y = (xys[..., 1] + whs[..., 1] / 2)
decoded_bboxes = torch.stack([tl_x, tl_y, br_x, br_y], -1)
return decoded_bboxes

View File

@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .misc import make_divisible, make_round
__all__ = ['make_divisible', 'make_round']

View File

@ -0,0 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
def make_divisible(x: float,
widen_factor: float = 1.0,
divisor: int = 8) -> int:
"""Make sure that x*widen_factor is divisible by divisor."""
return math.ceil(x * widen_factor / divisor) * divisor
def make_round(x: float, deepen_factor: float = 1.0) -> int:
"""Make sure that x*deepen_factor becomes an integer not less than 1."""
return max(round(x * deepen_factor), 1) if x > 1 else x

73
mmyolo/registry.py 100644
View File

@ -0,0 +1,73 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""MMYOLO provides 17 registry nodes to support using modules across projects.
Each node is a child of the root registry in MMEngine.
More details can be found at
https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
"""
from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS
from mmengine.registry import DATASETS as MMENGINE_DATASETS
from mmengine.registry import HOOKS as MMENGINE_HOOKS
from mmengine.registry import LOOPS as MMENGINE_LOOPS
from mmengine.registry import METRICS as MMENGINE_METRICS
from mmengine.registry import MODEL_WRAPPERS as MMENGINE_MODEL_WRAPPERS
from mmengine.registry import MODELS as MMENGINE_MODELS
from mmengine.registry import \
OPTIM_WRAPPER_CONSTRUCTORS as MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS
from mmengine.registry import OPTIM_WRAPPERS as MMENGINE_OPTIM_WRAPPERS
from mmengine.registry import OPTIMIZERS as MMENGINE_OPTIMIZERS
from mmengine.registry import PARAM_SCHEDULERS as MMENGINE_PARAM_SCHEDULERS
from mmengine.registry import \
RUNNER_CONSTRUCTORS as MMENGINE_RUNNER_CONSTRUCTORS
from mmengine.registry import RUNNERS as MMENGINE_RUNNERS
from mmengine.registry import TASK_UTILS as MMENGINE_TASK_UTILS
from mmengine.registry import TRANSFORMS as MMENGINE_TRANSFORMS
from mmengine.registry import VISBACKENDS as MMENGINE_VISBACKENDS
from mmengine.registry import VISUALIZERS as MMENGINE_VISUALIZERS
from mmengine.registry import \
WEIGHT_INITIALIZERS as MMENGINE_WEIGHT_INITIALIZERS
from mmengine.registry import Registry
# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
RUNNERS = Registry('runner', parent=MMENGINE_RUNNERS)
# manage runner constructors that define how to initialize runners
RUNNER_CONSTRUCTORS = Registry(
'runner constructor', parent=MMENGINE_RUNNER_CONSTRUCTORS)
# manage all kinds of loops like `EpochBasedTrainLoop`
LOOPS = Registry('loop', parent=MMENGINE_LOOPS)
# manage all kinds of hooks like `CheckpointHook`
HOOKS = Registry('hook', parent=MMENGINE_HOOKS)
# manage data-related modules
DATASETS = Registry('dataset', parent=MMENGINE_DATASETS)
DATA_SAMPLERS = Registry('data sampler', parent=MMENGINE_DATA_SAMPLERS)
TRANSFORMS = Registry('transform', parent=MMENGINE_TRANSFORMS)
# manage all kinds of modules inheriting `nn.Module`
MODELS = Registry('model', parent=MMENGINE_MODELS)
# manage all kinds of model wrappers like 'MMDistributedDataParallel'
MODEL_WRAPPERS = Registry('model_wrapper', parent=MMENGINE_MODEL_WRAPPERS)
# manage all kinds of weight initialization modules like `Uniform`
WEIGHT_INITIALIZERS = Registry(
'weight initializer', parent=MMENGINE_WEIGHT_INITIALIZERS)
# manage all kinds of optimizers like `SGD` and `Adam`
OPTIMIZERS = Registry('optimizer', parent=MMENGINE_OPTIMIZERS)
OPTIM_WRAPPERS = Registry('optim_wrapper', parent=MMENGINE_OPTIM_WRAPPERS)
# manage constructors that customize the optimization hyperparameters.
OPTIM_WRAPPER_CONSTRUCTORS = Registry(
'optimizer constructor', parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS)
# manage all kinds of parameter schedulers like `MultiStepLR`
PARAM_SCHEDULERS = Registry(
'parameter scheduler', parent=MMENGINE_PARAM_SCHEDULERS)
# manage all kinds of metrics
METRICS = Registry('metric', parent=MMENGINE_METRICS)
# manage task-specific modules like anchor generators and box coders
TASK_UTILS = Registry('task util', parent=MMENGINE_TASK_UTILS)
# manage visualizer
VISUALIZERS = Registry('visualizer', parent=MMENGINE_VISUALIZERS)
# manage visualizer backend
VISBACKENDS = Registry('vis_backend', parent=MMENGINE_VISBACKENDS)

22
mmyolo/version.py 100644
View File

@ -0,0 +1,22 @@
# Copyright (c) OpenMMLab. All rights reserved.
__version__ = '0.0.1'
from typing import Tuple
short_version = __version__
def parse_version_info(version_str: str) -> Tuple:
version_info = []
for x in version_str.split('.'):
if x.isdigit():
version_info.append(int(x))
elif x.find('rc') != -1:
patch_version = x.split('rc')
version_info.append(int(patch_version[0]))
version_info.append(f'rc{patch_version[1]}')
return tuple(version_info)
version_info = parse_version_info(__version__)

View File

@ -0,0 +1,71 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmdet.structures import DetDataSample
from mmyolo.models.data_preprocessors import YOLOv5DetDataPreprocessor
class TestYOLOv5DetDataPreprocessor(TestCase):
def test_forward(self):
processor = YOLOv5DetDataPreprocessor(mean=[0, 0, 0], std=[1, 1, 1])
data = {
'inputs': [torch.randint(0, 256, (3, 11, 10))],
'data_samples': [DetDataSample()]
}
out_data = processor(data, training=False)
batch_inputs, batch_data_samples = out_data['inputs'], out_data[
'data_samples']
self.assertEqual(batch_inputs.shape, (1, 3, 11, 10))
self.assertEqual(len(batch_data_samples), 1)
# test channel_conversion
processor = YOLOv5DetDataPreprocessor(
mean=[0., 0., 0.], std=[1., 1., 1.], bgr_to_rgb=True)
out_data = processor(data, training=False)
batch_inputs, batch_data_samples = out_data['inputs'], out_data[
'data_samples']
self.assertEqual(batch_inputs.shape, (1, 3, 11, 10))
self.assertEqual(len(batch_data_samples), 1)
# test padding, training=False
data = {
'inputs': [
torch.randint(0, 256, (3, 10, 11)),
torch.randint(0, 256, (3, 9, 14))
]
}
processor = YOLOv5DetDataPreprocessor(
mean=[0., 0., 0.], std=[1., 1., 1.], bgr_to_rgb=True)
out_data = processor(data, training=False)
batch_inputs, batch_data_samples = out_data['inputs'], out_data[
'data_samples']
self.assertEqual(batch_inputs.shape, (2, 3, 10, 14))
self.assertIsNone(batch_data_samples)
# test training
data = {
'inputs': torch.randint(0, 256, (2, 3, 10, 11)),
'data_samples': torch.randint(0, 11, (18, 6)),
}
out_data = processor(data, training=True)
batch_inputs, batch_data_samples = out_data['inputs'], out_data[
'data_samples']
self.assertIn('img_metas', batch_data_samples)
self.assertIn('bboxes_labels', batch_data_samples)
self.assertEqual(batch_inputs.shape, (2, 3, 10, 11))
self.assertIsInstance(batch_data_samples['bboxes_labels'],
torch.Tensor)
self.assertIsInstance(batch_data_samples['img_metas'], list)
data = {
'inputs': [torch.randint(0, 256, (3, 11, 10))],
'data_samples': [DetDataSample()]
}
# data_samples must be tensor
with self.assertRaises(AssertionError):
processor(data, training=True)

View File

@ -0,0 +1,94 @@
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
import math
from unittest import TestCase
import torch
import torch.nn as nn
from mmengine.testing import assert_allclose
from mmyolo.models.layers import ExpMomentumEMA
class TestEMA(TestCase):
def test_exp_momentum_ema(self):
model = nn.Sequential(nn.Conv2d(1, 5, kernel_size=3), nn.Linear(5, 10))
# Test invalid gamma
with self.assertRaisesRegex(AssertionError,
'gamma must be greater than 0'):
ExpMomentumEMA(model, gamma=-1)
# Test EMA
model = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10))
momentum = 0.1
gamma = 4
ema_model = ExpMomentumEMA(model, momentum=momentum, gamma=gamma)
averaged_params = [
torch.zeros_like(param) for param in model.parameters()
]
n_updates = 10
for i in range(n_updates):
updated_averaged_params = []
for p, p_avg in zip(model.parameters(), averaged_params):
p.detach().add_(torch.randn_like(p))
if i == 0:
updated_averaged_params.append(p.clone())
else:
m = (1 - momentum) * math.exp(-(1 + i) / gamma) + momentum
updated_averaged_params.append(
(p_avg * (1 - m) + p * m).clone())
ema_model.update_parameters(model)
averaged_params = updated_averaged_params
for p_target, p_ema in zip(averaged_params, ema_model.parameters()):
assert_allclose(p_target, p_ema)
def test_exp_momentum_ema_update_buffer(self):
model = nn.Sequential(
nn.Conv2d(1, 5, kernel_size=3), nn.BatchNorm2d(5, momentum=0.3),
nn.Linear(5, 10))
# Test invalid gamma
with self.assertRaisesRegex(AssertionError,
'gamma must be greater than 0'):
ExpMomentumEMA(model, gamma=-1)
# Test EMA with momentum annealing.
momentum = 0.1
gamma = 4
ema_model = ExpMomentumEMA(
model, gamma=gamma, momentum=momentum, update_buffers=True)
averaged_params = [
torch.zeros_like(param)
for param in itertools.chain(model.parameters(), model.buffers())
if param.size() != torch.Size([])
]
n_updates = 10
for i in range(n_updates):
updated_averaged_params = []
params = [
param for param in itertools.chain(model.parameters(),
model.buffers())
if param.size() != torch.Size([])
]
for p, p_avg in zip(params, averaged_params):
p.detach().add_(torch.randn_like(p))
if i == 0:
updated_averaged_params.append(p.clone())
else:
m = (1 - momentum) * math.exp(-(1 + i) / gamma) + momentum
updated_averaged_params.append(
(p_avg * (1 - m) + p * m).clone())
ema_model.update_parameters(model)
averaged_params = updated_averaged_params
ema_params = [
param for param in itertools.chain(ema_model.module.parameters(),
ema_model.module.buffers())
if param.size() != torch.Size([])
]
for p_target, p_ema in zip(averaged_params, ema_params):
assert_allclose(p_target, p_ema)

View File

@ -0,0 +1,23 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmyolo.models.layers import SPPFBottleneck
from mmyolo.utils import register_all_modules
register_all_modules()
class TestSPPFBottleneck(TestCase):
def test_forward(self):
input_tensor = torch.randn((1, 3, 20, 20))
bottleneck = SPPFBottleneck(3, 16)
out_tensor = bottleneck(input_tensor)
self.assertEqual(out_tensor.shape, (1, 16, 20, 20))
bottleneck = SPPFBottleneck(3, 16, kernel_sizes=[3, 5, 7])
out_tensor = bottleneck(input_tensor)
self.assertEqual(out_tensor.shape, (1, 16, 20, 20))

View File

@ -0,0 +1,28 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmyolo.models.necks import YOLOv5PAFPN
from mmyolo.utils import register_all_modules
register_all_modules()
class TestYOLOv5PAFPN(TestCase):
def test_forward(self):
s = 64
in_channels = [8, 16, 32]
feat_sizes = [s // 2**i for i in range(4)] # [32, 16, 8]
out_channels = [8, 16, 32]
feats = [
torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
for i in range(len(in_channels))
]
neck = YOLOv5PAFPN(in_channels=in_channels, out_channels=out_channels)
outs = neck(feats)
assert len(outs) == len(feats)
for i in range(len(feats)):
assert outs[i].shape[1] == out_channels[i]
assert outs[i].shape[2] == outs[i].shape[3] == s // (2**i)

View File

@ -0,0 +1,29 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmyolo.models.necks import YOLOv6RepPAFPN
from mmyolo.utils import register_all_modules
register_all_modules()
class TestYOLOv6RepPAFPN(TestCase):
def test_forward(self):
s = 64
in_channels = [8, 16, 32]
feat_sizes = [s // 2**i for i in range(4)] # [32, 16, 8]
out_channels = [8, 16, 32]
feats = [
torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
for i in range(len(in_channels))
]
neck = YOLOv6RepPAFPN(
in_channels=in_channels, out_channels=out_channels)
outs = neck(feats)
assert len(outs) == len(feats)
for i in range(len(feats)):
assert outs[i].shape[1] == out_channels[i]
assert outs[i].shape[2] == outs[i].shape[3] == s // (2**i)

View File

@ -0,0 +1,28 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmyolo.models.necks import YOLOXPAFPN
from mmyolo.utils import register_all_modules
register_all_modules()
class TestYOLOXPAFPN(TestCase):
def test_forward(self):
s = 64
in_channels = [8, 16, 32]
feat_sizes = [s // 2**i for i in range(4)] # [32, 16, 8]
out_channels = 24
feats = [
torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
for i in range(len(in_channels))
]
neck = YOLOXPAFPN(in_channels=in_channels, out_channels=out_channels)
outs = neck(feats)
assert len(outs) == len(feats)
for i in range(len(feats)):
assert outs[i].shape[1] == out_channels
assert outs[i].shape[2] == outs[i].shape[3] == s // (2**i)