82 lines
2.5 KiB
Python

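[Feature] Support LeViT backbone. (#1238) * Network built; inference runs normally * Network built; inference runs normally * Network built; inference runs normally * Added model conversion (unverified) and config files, but they cannot run yet * Model conversion and structure verification done; inference gives the correct answers * Inference accuracy matches the original paper; conversion complete * Changed three methods to classes (stash) * Inference accuracy aligned, error 0.04 * Temporary levit2mmcls * Training runs; training-related parameters not yet aligned * Align training-related parameters * Fix model structure being changed irreversibly by validation during training * Fix model structure being changed irreversibly by validation during training * Add mixup and label smoothing * Complete the config files * Add model conversion * Add meta file * Add meta file * Delete the demo.py test file * Add model README file * Roll back docs files * Remove trailing whitespace on the last line of model-index * Update model metafile * Update metafile * Update metafile * Update README and metafile * Update model README * Update model metafile * Delete the model class and get_LeViT_model methods in the mmcls.models.backone.levit file * Change the class name to Google Code Style * use arch to provide default architectures * use nn.Conv2d * mmcv.cnn.fuse_conv_bn * modify some details * remove down_ops from the architectures. * remove init_weight function * Modify ambiguous variable names * Change the drop_path in config to drop_path_rate * Add unit test * remove train function * add unit test * modify nn.norm1d to build_norm_layer * update metafile and readme * Update configs and LeViT implementations. * Update README. * Add docstring and update unit tests. * Revert irrelative modification. * Fix unit tests * minor fix Co-authored-by: mzr1996 <mzr1996@163.com>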
2023-01-17 17:43:42 +08:00
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmengine.model import BaseModule
from mmpretrain.models.heads import ClsHead
from mmpretrain.registry import MODELS
from ..utils import build_norm_layer


class BatchNormLinear(BaseModule):
    """Batch norm followed by a linear layer, fusable into one ``nn.Linear``."""

    def __init__(self, in_channels, out_channels, norm_cfg=dict(type='BN1d')):
        super(BatchNormLinear, self).__init__()
        self.bn = build_norm_layer(norm_cfg, in_channels)
        self.linear = nn.Linear(in_channels, out_channels)

    @torch.no_grad()
    def fuse(self):
        # With fixed running statistics, BN is a per-channel affine map:
        # bn(x) = w * x + b.
        w = self.bn.weight / (self.bn.running_var + self.bn.eps)**0.5
        b = self.bn.bias - self.bn.running_mean * \
            self.bn.weight / (self.bn.running_var + self.bn.eps) ** 0.5
        # Fold that affine map into the linear layer's weight and bias.
        w = self.linear.weight * w[None, :]
        b = (self.linear.weight @ b[:, None]).view(-1) + self.linear.bias

        self.linear.weight.data.copy_(w)
        self.linear.bias.data.copy_(b)
        return self.linear

    def forward(self, x):
        x = self.bn(x)
        x = self.linear(x)
        return x


def fuse_parameters(module):
    for child_name, child in module.named_children():
        if hasattr(child, 'fuse'):
            setattr(module, child_name, child.fuse())
        else:
            fuse_parameters(child)


@MODELS.register_module()
class LeViTClsHead(ClsHead):

    def __init__(self,
                 num_classes=1000,
                 distillation=True,
                 in_channels=None,
                 deploy=False,
                 **kwargs):
        super(LeViTClsHead, self).__init__(**kwargs)
        self.num_classes = num_classes
        self.distillation = distillation
        # Stays False until the heads are actually fused; otherwise
        # switch_to_deploy() would return early and fuse nothing.
        self.deploy = False
        self.head = BatchNormLinear(in_channels, num_classes)
        if distillation:
            self.head_dist = BatchNormLinear(in_channels, num_classes)

        if deploy:
            self.switch_to_deploy()

    def switch_to_deploy(self):
        if self.deploy:
            return
        fuse_parameters(self)
        self.deploy = True

    def forward(self, x):
        x = self.pre_logits(x)
        if self.distillation:
            x = self.head(x), self.head_dist(x)  # 2 16 384 -> 2 1000
            if not self.training:
                x = (x[0] + x[1]) / 2
            else:
                raise NotImplementedError("MMPretrain doesn't support "
                                          'training in distillation mode.')
        else:
            x = self.head(x)
        return x
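
The arithmetic behind `BatchNormLinear.fuse` can be checked by hand. The following is a minimal sketch in plain Python (no PyTorch), assuming batch norm uses fixed running statistics as it does in eval mode. The helper names `fold_bn_into_linear`, `bn_then_linear`, and `linear`, and all numeric values, are hypothetical and chosen only for illustration; they are not part of MMPretrain.

```python
import math


def fold_bn_into_linear(gamma, beta, mean, var, eps, weight, bias):
    """Fold y = W @ BN(x) + c into a single affine map y = W' @ x + c'."""
    # With fixed statistics, BN is per-feature affine: BN(x) = scale * x + shift.
    scale = [g / math.sqrt(v + eps) for g, v in zip(gamma, var)]
    shift = [b - m * s for b, m, s in zip(beta, mean, scale)]
    # W' scales each input column of W; c' absorbs W @ shift.
    fused_weight = [[w * s for w, s in zip(row, scale)] for row in weight]
    fused_bias = [
        sum(w * t for w, t in zip(row, shift)) + c
        for row, c in zip(weight, bias)
    ]
    return fused_weight, fused_bias


def linear(x, weight, bias):
    """Plain linear layer on a single sample: W @ x + c."""
    return [sum(w * xi for w, xi in zip(row, x)) + c
            for row, c in zip(weight, bias)]


def bn_then_linear(x, gamma, beta, mean, var, eps, weight, bias):
    """Unfused reference: normalize each feature, then apply the linear layer."""
    normed = [
        g * (xi - m) / math.sqrt(v + eps) + b
        for xi, g, b, m, v in zip(x, gamma, beta, mean, var)
    ]
    return linear(normed, weight, bias)


# Hypothetical BN parameters and statistics for 2 input features.
gamma, beta = [1.5, 0.5], [0.1, -0.2]
mean, var, eps = [0.3, -1.0], [4.0, 0.25], 1e-5
# Hypothetical linear layer mapping 2 features to 3 outputs.
weight = [[1.0, 2.0], [-1.0, 0.5], [0.0, 3.0]]
bias = [0.5, 0.0, -1.0]

x = [2.0, -3.0]
reference = bn_then_linear(x, gamma, beta, mean, var, eps, weight, bias)
fused_weight, fused_bias = fold_bn_into_linear(
    gamma, beta, mean, var, eps, weight, bias)
fused = linear(x, fused_weight, fused_bias)

# The fused layer should match BN followed by linear up to rounding error.
assert all(
    math.isclose(r, f, rel_tol=1e-9, abs_tol=1e-9)
    for r, f in zip(reference, fused))
```

The equivalence only holds when the batch-norm statistics are frozen. In training mode the statistics depend on the current batch, so fusion is a deploy-time optimization.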