mirror of https://github.com/open-mmlab/mmyolo.git
[Improvement] Add docstring and update RTMDet graph (#317)
* add docstring * Update yolov5_coco.py * Update rtmdet_description.md * Update batch_yolov7_assigner.pypull/278/merge
parent
55da55eb09
commit
2354302c94
|
@ -5,7 +5,7 @@
|
|||
高性能,低延时的单阶段目标检测器
|
||||
|
||||
<div align=center>
|
||||
<img alt="RTMDet_structure_v1.2" src="https://user-images.githubusercontent.com/27466624/200001002-008ac696-e74d-4da1-9c6d-07149e2ad752.jpg"/>
|
||||
<img alt="RTMDet_structure_v1.3" src="https://user-images.githubusercontent.com/27466624/204126145-cb4ff4f1-fb16-455e-96b5-17620081023a.jpg"/>
|
||||
</div>
|
||||
|
||||
以上结构图由 RangeKing@github 绘制。
|
||||
|
|
|
@ -7,6 +7,9 @@ from ..registry import DATASETS, TASK_UTILS
|
|||
|
||||
|
||||
class BatchShapePolicyDataset(BaseDetDataset):
|
||||
"""Dataset with the batch shape policy that makes paddings with least
|
||||
pixels during batch inference process, which does not require the image
|
||||
scales of all batches to be the same throughout validation."""
|
||||
|
||||
def __init__(self,
|
||||
*args,
|
||||
|
@ -17,7 +20,7 @@ class BatchShapePolicyDataset(BaseDetDataset):
|
|||
|
||||
def full_init(self):
|
||||
"""rewrite full_init() to be compatible with serialize_data in
|
||||
BatchShapesPolicy."""
|
||||
BatchShapePolicy."""
|
||||
if self._fully_initialized:
|
||||
return
|
||||
# load data information
|
||||
|
|
|
@ -18,13 +18,14 @@ class MMYOLO(MMCodebase):
|
|||
|
||||
@classmethod
|
||||
def register_deploy_modules(cls):
|
||||
"""register all rewriters for mmcls."""
|
||||
"""register all rewriters for mmdet."""
|
||||
import mmdeploy.codebase.mmdet.models # noqa: F401
|
||||
import mmdeploy.codebase.mmdet.ops # noqa: F401
|
||||
import mmdeploy.codebase.mmdet.structures # noqa: F401
|
||||
|
||||
@classmethod
|
||||
def register_all_modules(cls):
|
||||
"""register all modules."""
|
||||
from mmdet.utils.setup_env import \
|
||||
register_all_modules as register_all_modules_mmdet
|
||||
|
||||
|
|
|
@ -17,4 +17,5 @@ class SwitchToDeployHook(Hook):
|
|||
"""
|
||||
|
||||
def before_test_epoch(self, runner: Runner):
|
||||
"""Switch to deploy mode before testing."""
|
||||
switch_to_deploy(runner.model)
|
||||
|
|
|
@ -146,8 +146,8 @@ class YOLOv5CSPDarknet(BaseBackbone):
|
|||
return stage
|
||||
|
||||
def init_weights(self):
|
||||
"""Initialize the parameters."""
|
||||
if self.init_cfg is None:
|
||||
"""Initialize the parameters."""
|
||||
for m in self.modules():
|
||||
if isinstance(m, torch.nn.Conv2d):
|
||||
# In order to be consistent with the source code,
|
||||
|
|
|
@ -48,6 +48,7 @@ class ChannelAttention(BaseModule):
|
|||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward function."""
|
||||
avgpool_out = self.fc(self.avg_pool(x))
|
||||
maxpool_out = self.fc(self.max_pool(x))
|
||||
out = self.sigmoid(avgpool_out + maxpool_out)
|
||||
|
@ -74,6 +75,7 @@ class SpatialAttention(BaseModule):
|
|||
act_cfg=dict(type='Sigmoid'))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward function."""
|
||||
avg_out = torch.mean(x, dim=1, keepdim=True)
|
||||
max_out, _ = torch.max(x, dim=1, keepdim=True)
|
||||
out = torch.cat([avg_out, max_out], dim=1)
|
||||
|
@ -111,6 +113,7 @@ class CBAM(BaseModule):
|
|||
self.spatial_attention = SpatialAttention(kernel_size)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward function."""
|
||||
out = self.channel_attention(x) * x
|
||||
out = self.spatial_attention(out) * out
|
||||
return out
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
@ -6,6 +8,7 @@ from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, bbox_overlaps
|
|||
|
||||
|
||||
def _cat_multi_level_tensor_in_place(*multi_level_tensor, place_hold_var):
|
||||
"""concat multi-level tensor in place."""
|
||||
for level_tensor in multi_level_tensor:
|
||||
for i, var in enumerate(level_tensor):
|
||||
if len(var) > 0:
|
||||
|
@ -28,8 +31,8 @@ class BatchYOLOv7Assigner(nn.Module):
|
|||
|
||||
def __init__(self,
|
||||
num_classes: int,
|
||||
num_base_priors,
|
||||
featmap_strides,
|
||||
num_base_priors: int,
|
||||
featmap_strides: Sequence[int],
|
||||
prior_match_thr: float = 4.0,
|
||||
candidate_topk: int = 10,
|
||||
iou_weight: float = 3.0,
|
||||
|
|
Loading…
Reference in New Issue