From 9002daf3743967c6b415b7a3e5b407c777b00bd1 Mon Sep 17 00:00:00 2001 From: HinGwenWoong Date: Sun, 18 Sep 2022 13:04:51 +0800 Subject: [PATCH] [Feature] Add yolo detector (#17) --- mmyolo/models/detectors/__init__.py | 4 ++ mmyolo/models/detectors/yolo_detector.py | 53 ++++++++++++++++++++++++ 2 files changed, 57 insertions(+) create mode 100644 mmyolo/models/detectors/__init__.py create mode 100644 mmyolo/models/detectors/yolo_detector.py diff --git a/mmyolo/models/detectors/__init__.py b/mmyolo/models/detectors/__init__.py new file mode 100644 index 00000000..74fb1c6c --- /dev/null +++ b/mmyolo/models/detectors/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .yolo_detector import YOLODetector + +__all__ = ['YOLODetector'] diff --git a/mmyolo/models/detectors/yolo_detector.py b/mmyolo/models/detectors/yolo_detector.py new file mode 100644 index 00000000..8f4182cc --- /dev/null +++ b/mmyolo/models/detectors/yolo_detector.py @@ -0,0 +1,53 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.dist import get_world_size +from mmengine.logging import print_log + +from mmdet.models.detectors.single_stage import SingleStageDetector +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from mmyolo.registry import MODELS + + +@MODELS.register_module() +class YOLODetector(SingleStageDetector): + r"""Implementation of YOLO Series + + Args: + backbone (:obj:`ConfigDict` or dict): The backbone config. + neck (:obj:`ConfigDict` or dict): The neck config. + bbox_head (:obj:`ConfigDict` or dict): The bbox head config. + train_cfg (:obj:`ConfigDict` or dict, optional): The training config + of YOLOX. Defaults to None. + test_cfg (:obj:`ConfigDict` or dict, optional): The testing config + of YOLOX. Defaults to None. + data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of + :class:`DetDataPreprocessor` to process the input data. + Defaults to None. + init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or + list[dict], optional): Initialization config dict. + Defaults to None. + use_syncbn (bool): whether to use SyncBatchNorm. Defaults to True. + """ + + def __init__(self, + backbone: ConfigType, + neck: ConfigType, + bbox_head: ConfigType, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None, + use_syncbn: bool = True) -> None: + super().__init__( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + + # TODO: Waiting for mmengine support + if use_syncbn and get_world_size() > 1: + torch.nn.SyncBatchNorm.convert_sync_batchnorm(self) + print_log('Using SyncBatchNorm()', 'current')