[Improvement] Add docstring and update RTMDet graph (#317)

* add docstring

* Update yolov5_coco.py

* Update rtmdet_description.md

* Update batch_yolov7_assigner.py
pull/278/merge
Range King 2022-11-28 17:13:20 +08:00 committed by GitHub
parent 55da55eb09
commit 2354302c94
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 17 additions and 6 deletions

View File

@ -5,7 +5,7 @@
高性能,低延时的单阶段目标检测器
<div align=center>
<img alt="RTMDet_structure_v1.2" src="https://user-images.githubusercontent.com/27466624/200001002-008ac696-e74d-4da1-9c6d-07149e2ad752.jpg"/>
<img alt="RTMDet_structure_v1.3" src="https://user-images.githubusercontent.com/27466624/204126145-cb4ff4f1-fb16-455e-96b5-17620081023a.jpg"/>
</div>
以上结构图由 RangeKing@github 绘制。

View File

@ -7,6 +7,9 @@ from ..registry import DATASETS, TASK_UTILS
class BatchShapePolicyDataset(BaseDetDataset):
"""Dataset with the batch shape policy that makes paddings with least
pixels during batch inference process, which does not require the image
scales of all batches to be the same throughout validation."""
def __init__(self,
*args,
@ -17,7 +20,7 @@ class BatchShapePolicyDataset(BaseDetDataset):
def full_init(self):
"""rewrite full_init() to be compatible with serialize_data in
BatchShapesPolicy."""
BatchShapePolicy."""
if self._fully_initialized:
return
# load data information

View File

@ -18,13 +18,14 @@ class MMYOLO(MMCodebase):
@classmethod
def register_deploy_modules(cls):
"""register all rewriters for mmcls."""
"""register all rewriters for mmdet."""
import mmdeploy.codebase.mmdet.models # noqa: F401
import mmdeploy.codebase.mmdet.ops # noqa: F401
import mmdeploy.codebase.mmdet.structures # noqa: F401
@classmethod
def register_all_modules(cls):
"""register all modules."""
from mmdet.utils.setup_env import \
register_all_modules as register_all_modules_mmdet

View File

@ -17,4 +17,5 @@ class SwitchToDeployHook(Hook):
"""
def before_test_epoch(self, runner: Runner):
"""Switch to deploy mode before testing."""
switch_to_deploy(runner.model)

View File

@ -146,8 +146,8 @@ class YOLOv5CSPDarknet(BaseBackbone):
return stage
def init_weights(self):
"""Initialize the parameters."""
if self.init_cfg is None:
"""Initialize the parameters."""
for m in self.modules():
if isinstance(m, torch.nn.Conv2d):
# In order to be consistent with the source code,

View File

@ -48,6 +48,7 @@ class ChannelAttention(BaseModule):
self.sigmoid = nn.Sigmoid()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward function."""
avgpool_out = self.fc(self.avg_pool(x))
maxpool_out = self.fc(self.max_pool(x))
out = self.sigmoid(avgpool_out + maxpool_out)
@ -74,6 +75,7 @@ class SpatialAttention(BaseModule):
act_cfg=dict(type='Sigmoid'))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward function."""
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out, _ = torch.max(x, dim=1, keepdim=True)
out = torch.cat([avg_out, max_out], dim=1)
@ -111,6 +113,7 @@ class CBAM(BaseModule):
self.spatial_attention = SpatialAttention(kernel_size)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward function."""
out = self.channel_attention(x) * x
out = self.spatial_attention(out) * out
return out

View File

@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -6,6 +8,7 @@ from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, bbox_overlaps
def _cat_multi_level_tensor_in_place(*multi_level_tensor, place_hold_var):
"""concat multi-level tensor in place."""
for level_tensor in multi_level_tensor:
for i, var in enumerate(level_tensor):
if len(var) > 0:
@ -28,8 +31,8 @@ class BatchYOLOv7Assigner(nn.Module):
def __init__(self,
num_classes: int,
num_base_priors,
featmap_strides,
num_base_priors: int,
featmap_strides: Sequence[int],
prior_match_thr: float = 4.0,
candidate_topk: int = 10,
iou_weight: float = 3.0,