[Feature] Add yolo detector (#17)

pull/23/head
HinGwenWoong 2022-09-18 13:04:51 +08:00 committed by GitHub
parent a7a4f16d69
commit 9002daf374
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 57 additions and 0 deletions

View File

@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Package entry point: re-exports :class:`YOLODetector` as the public API."""
from .yolo_detector import YOLODetector

# Explicitly declare the public API of this package.
__all__ = ['YOLODetector']

View File

@@ -0,0 +1,53 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.dist import get_world_size
from mmengine.logging import print_log
from mmdet.models.detectors.single_stage import SingleStageDetector
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
from mmyolo.registry import MODELS
@MODELS.register_module()
class YOLODetector(SingleStageDetector):
    r"""Implementation of YOLO-series single-stage detectors.

    Thin wrapper around :class:`SingleStageDetector` that wires the
    backbone/neck/head configs through and optionally converts all
    BatchNorm layers to SyncBatchNorm for distributed training.

    Args:
        backbone (:obj:`ConfigDict` or dict): The backbone config.
        neck (:obj:`ConfigDict` or dict): The neck config.
        bbox_head (:obj:`ConfigDict` or dict): The bbox head config.
        train_cfg (:obj:`ConfigDict` or dict, optional): The training config
            of the detector. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): The testing config
            of the detector. Defaults to None.
        data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of
            :class:`DetDataPreprocessor` to process the input data.
            Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
        use_syncbn (bool): Whether to convert all ``BatchNorm`` layers in the
            model to ``SyncBatchNorm``. Only takes effect in distributed
            runs (world size > 1). Defaults to True.
    """

    def __init__(self,
                 backbone: ConfigType,
                 neck: ConfigType,
                 bbox_head: ConfigType,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None,
                 use_syncbn: bool = True) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

        # TODO: Waiting for mmengine to support SyncBatchNorm natively;
        # until then, convert the whole model in place here. Skipped in
        # single-process runs, where SyncBatchNorm adds no value.
        if use_syncbn and get_world_size() > 1:
            torch.nn.SyncBatchNorm.convert_sync_batchnorm(self)
            print_log('Using SyncBatchNorm()', 'current')