mirror of https://github.com/open-mmlab/mmyolo.git
[Fix] training errors of yolox nano (#285)
* Update yolox_pafpn.py
* Update csp_darknet.py
pull/280/head^2
parent
ab4e7c5158
commit
c043181149
|
@@ -3,7 +3,7 @@ from typing import List, Tuple, Union
|
|||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
|
||||
from mmdet.models.backbones.csp_darknet import CSPLayer, Focus
|
||||
from mmdet.utils import ConfigType, OptMultiConfig
|
||||
|
||||
|
@@ -178,6 +178,8 @@ class YOLOXCSPDarknet(BaseBackbone):
|
|||
Defaults to (2, 3, 4).
|
||||
frozen_stages (int): Stages to be frozen (stop grad and set eval
|
||||
mode). -1 means not freezing any parameters. Defaults to -1.
|
||||
use_depthwise (bool): Whether to use depthwise separable convolution.
|
||||
Defaults to False.
|
||||
spp_kernal_sizes: (tuple[int]): Sequential of kernel sizes of SPP
|
||||
layers. Defaults to (5, 9, 13).
|
||||
norm_cfg (dict): Dictionary to construct and config norm layer.
|
||||
|
@@ -218,12 +220,14 @@ class YOLOXCSPDarknet(BaseBackbone):
|
|||
input_channels: int = 3,
|
||||
out_indices: Tuple[int] = (2, 3, 4),
|
||||
frozen_stages: int = -1,
|
||||
use_depthwise: bool = False,
|
||||
spp_kernal_sizes: Tuple[int] = (5, 9, 13),
|
||||
norm_cfg: ConfigType = dict(
|
||||
type='BN', momentum=0.03, eps=0.001),
|
||||
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
|
||||
norm_eval: bool = False,
|
||||
init_cfg: OptMultiConfig = None):
|
||||
self.use_depthwise = use_depthwise
|
||||
self.spp_kernal_sizes = spp_kernal_sizes
|
||||
super().__init__(self.arch_settings[arch], deepen_factor, widen_factor,
|
||||
input_channels, out_indices, frozen_stages, plugins,
|
||||
|
@@ -251,7 +255,9 @@ class YOLOXCSPDarknet(BaseBackbone):
|
|||
out_channels = make_divisible(out_channels, self.widen_factor)
|
||||
num_blocks = make_round(num_blocks, self.deepen_factor)
|
||||
stage = []
|
||||
conv_layer = ConvModule(
|
||||
conv = DepthwiseSeparableConvModule \
|
||||
if self.use_depthwise else ConvModule
|
||||
conv_layer = conv(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
|
|
|
@@ -2,7 +2,7 @@
|
|||
from typing import List
|
||||
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
|
||||
from mmdet.models.backbones.csp_darknet import CSPLayer
|
||||
from mmdet.utils import ConfigType, OptMultiConfig
|
||||
|
||||
|
@@ -22,6 +22,8 @@ class YOLOXPAFPN(BaseYOLONeck):
|
|||
widen_factor (float): Width multiplier, multiply number of
|
||||
channels in each layer by this amount. Defaults to 1.0.
|
||||
num_csp_blocks (int): Number of bottlenecks in CSPLayer. Defaults to 1.
|
||||
use_depthwise (bool): Whether to use depthwise separable convolution.
|
||||
Defaults to False.
|
||||
freeze_all(bool): Whether to freeze the model. Defaults to False.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
|
||||
|
@@ -37,12 +39,14 @@ class YOLOXPAFPN(BaseYOLONeck):
|
|||
deepen_factor: float = 1.0,
|
||||
widen_factor: float = 1.0,
|
||||
num_csp_blocks: int = 3,
|
||||
use_depthwise: bool = False,
|
||||
freeze_all: bool = False,
|
||||
norm_cfg: ConfigType = dict(
|
||||
type='BN', momentum=0.03, eps=0.001),
|
||||
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
|
||||
init_cfg: OptMultiConfig = None):
|
||||
self.num_csp_blocks = round(num_csp_blocks * deepen_factor)
|
||||
self.use_depthwise = use_depthwise
|
||||
|
||||
super().__init__(
|
||||
in_channels=[
|
||||
|
@@ -123,7 +127,9 @@ class YOLOXPAFPN(BaseYOLONeck):
|
|||
Returns:
|
||||
nn.Module: The downsample layer.
|
||||
"""
|
||||
return ConvModule(
|
||||
conv = DepthwiseSeparableConvModule \
|
||||
if self.use_depthwise else ConvModule
|
||||
return conv(
|
||||
self.in_channels[idx],
|
||||
self.in_channels[idx],
|
||||
kernel_size=3,
|
||||
|
|
Loading…
Reference in New Issue