# 教程 5:如何增加新模块 ## 开发新组件 我们基本上将模型组件分为 3 种类型。 - 主干网络:通常是一个特征提取网络,例如 ResNet、MobileNet - 颈部:用于连接主干网络和头部的组件,例如 GlobalAveragePooling - 头部:用于执行特定任务的组件,例如分类和回归 ### 添加新的主干网络 这里,我们以 ResNet_CIFAR 为例,展示了如何开发一个新的主干网络组件。 ResNet_CIFAR 针对 CIFAR 32x32 的图像输入,将 ResNet 中 `kernel_size=7, stride=2` 的设置替换为 `kernel_size=3, stride=1`,并移除了 stem 层之后的 `MaxPooling`,以避免传递过小的特征图到残差块中。 它继承自 `ResNet` 并只修改了 stem 层。 1. 创建一个新文件 `mmcls/models/backbones/resnet_cifar.py`。 ```python import torch.nn as nn from ..builder import BACKBONES from .resnet import ResNet @BACKBONES.register_module() class ResNet_CIFAR(ResNet): """ResNet backbone for CIFAR. (对这个主干网络的简短描述) Args: depth(int): Network depth, from {18, 34, 50, 101, 152}. ... (参数文档) """ def __init__(self, depth, deep_stem=False, **kwargs): # 调用基类 ResNet 的初始化函数 super(ResNet_CIFAR, self).__init__(depth, deep_stem=deep_stem **kwargs) # 其他特殊的初始化流程 assert not self.deep_stem, 'ResNet_CIFAR do not support deep_stem' def _make_stem_layer(self, in_channels, base_channels): # 重载基类的方法,以实现对网络结构的修改 self.conv1 = build_conv_layer( self.conv_cfg, in_channels, base_channels, kernel_size=3, stride=1, padding=1, bias=False) self.norm1_name, norm1 = build_norm_layer( self.norm_cfg, base_channels, postfix=1) self.add_module(self.norm1_name, norm1) self.relu = nn.ReLU(inplace=True) def forward(self, x): # 需要返回一个元组 pass # 此处省略了网络的前向实现 def init_weights(self, pretrained=None): pass # 如果有必要的话,重载基类 ResNet 的参数初始化函数 def train(self, mode=True): pass # 如果有必要的话,重载基类 ResNet 的训练状态函数 ``` 2. 在 `mmcls/models/backbones/__init__.py` 中导入新模块 ```python ... from .resnet_cifar import ResNet_CIFAR __all__ = [ ..., 'ResNet_CIFAR' ] ``` 3. 在配置文件中使用新的主干网络 ```python model = dict( ... backbone=dict( type='ResNet_CIFAR', depth=18, other_arg=xxx), ... ``` ### 添加新的颈部组件 这里我们以 `GlobalAveragePooling` 为例。这是一个非常简单的颈部组件,没有任何参数。 要添加新的颈部组件,我们主要需要实现 `forward` 函数,该函数对主干网络的输出进行 一些操作并将结果传递到头部。 1. 创建一个新文件 `mmcls/models/necks/gap.py` ```python import torch.nn as nn from ..builder import NECKS @NECKS.register_module() class GlobalAveragePooling(nn.Module): def __init__(self): self.gap = nn.AdaptiveAvgPool2d((1, 1)) def forward(self, inputs): # 简单起见,我们默认输入是一个张量 outs = self.gap(inputs) outs = outs.view(inputs.size(0), -1) return outs ``` 2. 在 `mmcls/models/necks/__init__.py` 中导入新模块 ```python ... from .gap import GlobalAveragePooling __all__ = [ ..., 'GlobalAveragePooling' ] ``` 3. 修改配置文件以使用新的颈部组件 ```python model = dict( neck=dict(type='GlobalAveragePooling'), ) ``` ### 添加新的头部组件 在此,我们以 `LinearClsHead` 为例,说明如何开发新的头部组件。 要添加一个新的头部组件,基本上我们需要实现 `forward_train` 函数,它接受来自颈部 或主干网络的特征图作为输入,并基于真实标签计算。 1. 创建一个文件 `mmcls/models/heads/linear_head.py`. ```python from ..builder import HEADS from .cls_head import ClsHead @HEADS.register_module() class LinearClsHead(ClsHead): def __init__(self, num_classes, in_channels, loss=dict(type='CrossEntropyLoss', loss_weight=1.0), topk=(1, )): super(LinearClsHead, self).__init__(loss=loss, topk=topk) self.in_channels = in_channels self.num_classes = num_classes if self.num_classes <= 0: raise ValueError( f'num_classes={num_classes} must be a positive integer') self._init_layers() def _init_layers(self): self.fc = nn.Linear(self.in_channels, self.num_classes) def init_weights(self): normal_init(self.fc, mean=0, std=0.01, bias=0) def forward_train(self, x, gt_label): cls_score = self.fc(x) losses = self.loss(cls_score, gt_label) return losses ``` 2. 在 `mmcls/models/heads/__init__.py` 中导入这个模块 ```python ... from .linear_head import LinearClsHead __all__ = [ ..., 'LinearClsHead' ] ``` 3. 修改配置文件以使用新的头部组件。 连同 `GlobalAveragePooling` 颈部组件,完整的模型配置如下: ```python model = dict( type='ImageClassifier', backbone=dict( type='ResNet', depth=50, num_stages=4, out_indices=(3, ), style='pytorch'), neck=dict(type='GlobalAveragePooling'), head=dict( type='LinearClsHead', num_classes=1000, in_channels=2048, loss=dict(type='CrossEntropyLoss', loss_weight=1.0), topk=(1, 5), )) ``` ### 添加新的损失函数 要添加新的损失函数,我们主要需要在损失函数模块中 `forward` 函数。另外,利用装饰器 `weighted_loss` 可以方便的实现对每个元素的损失进行加权平均。 假设我们要模拟从另一个分类模型生成的概率分布,需要添加 `L1loss` 来实现该目的。 1. 创建一个新文件 `mmcls/models/losses/l1_loss.py` ```python import torch import torch.nn as nn from ..builder import LOSSES from .utils import weighted_loss @weighted_loss def l1_loss(pred, target): assert pred.size() == target.size() and target.numel() > 0 loss = torch.abs(pred - target) return loss @LOSSES.register_module() class L1Loss(nn.Module): def __init__(self, reduction='mean', loss_weight=1.0): super(L1Loss, self).__init__() self.reduction = reduction self.loss_weight = loss_weight def forward(self, pred, target, weight=None, avg_factor=None, reduction_override=None): assert reduction_override in (None, 'none', 'mean', 'sum') reduction = ( reduction_override if reduction_override else self.reduction) loss = self.loss_weight * l1_loss( pred, target, weight, reduction=reduction, avg_factor=avg_factor) return loss ``` 2. 在文件 `mmcls/models/losses/__init__.py` 中导入这个模块 ```python ... from .l1_loss import L1Loss, l1_loss __all__ = [ ..., 'L1Loss', 'l1_loss' ] ``` 3. 修改配置文件中的 `loss` 字段以使用新的损失函数 ```python loss=dict(type='L1Loss', loss_weight=1.0)) ```