mirror of https://github.com/open-mmlab/mmyolo.git
[Feature] Add yolo detector (#17)
parent
a7a4f16d69
commit
9002daf374
|
@ -0,0 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .yolo_detector import YOLODetector
|
||||
|
||||
__all__ = ['YOLODetector']
|
|
@ -0,0 +1,53 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from mmengine.dist import get_world_size
|
||||
from mmengine.logging import print_log
|
||||
|
||||
from mmdet.models.detectors.single_stage import SingleStageDetector
|
||||
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
|
||||
from mmyolo.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class YOLODetector(SingleStageDetector):
|
||||
r"""Implementation of YOLO Series
|
||||
|
||||
Args:
|
||||
backbone (:obj:`ConfigDict` or dict): The backbone config.
|
||||
neck (:obj:`ConfigDict` or dict): The neck config.
|
||||
bbox_head (:obj:`ConfigDict` or dict): The bbox head config.
|
||||
train_cfg (:obj:`ConfigDict` or dict, optional): The training config
|
||||
of YOLOX. Defaults to None.
|
||||
test_cfg (:obj:`ConfigDict` or dict, optional): The testing config
|
||||
of YOLOX. Defaults to None.
|
||||
data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of
|
||||
:class:`DetDataPreprocessor` to process the input data.
|
||||
Defaults to None.
|
||||
init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
|
||||
list[dict], optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
use_syncbn (bool): whether to use SyncBatchNorm. Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
backbone: ConfigType,
|
||||
neck: ConfigType,
|
||||
bbox_head: ConfigType,
|
||||
train_cfg: OptConfigType = None,
|
||||
test_cfg: OptConfigType = None,
|
||||
data_preprocessor: OptConfigType = None,
|
||||
init_cfg: OptMultiConfig = None,
|
||||
use_syncbn: bool = True) -> None:
|
||||
super().__init__(
|
||||
backbone=backbone,
|
||||
neck=neck,
|
||||
bbox_head=bbox_head,
|
||||
train_cfg=train_cfg,
|
||||
test_cfg=test_cfg,
|
||||
data_preprocessor=data_preprocessor,
|
||||
init_cfg=init_cfg)
|
||||
|
||||
# TODO: Waiting for mmengine support
|
||||
if use_syncbn and get_world_size() > 1:
|
||||
torch.nn.SyncBatchNorm.convert_sync_batchnorm(self)
|
||||
print_log('Using SyncBatchNorm()', 'current')
|
Loading…
Reference in New Issue