dinov2/distillation/models/stdc.py

273 lines
8.7 KiB
Python

import math
import torch
import torch.nn as nn
from torch.nn import init
class ConvX(nn.Module):
def __init__(self, in_planes, out_planes, kernel=3, stride=1):
super().__init__()
self.conv = nn.Conv2d(
in_planes,
out_planes,
kernel_size=kernel,
stride=stride,
padding=kernel // 2,
bias=False,
)
self.bn = nn.BatchNorm2d(out_planes)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
out = self.relu(self.bn(self.conv(x)))
return out
class AddBottleneck(nn.Module):
def __init__(self, in_planes, out_planes, block_num=3, stride=1):
super().__init__()
assert block_num > 1, print("block number should be larger than 1.")
self.conv_list = nn.ModuleList()
self.stride = stride
if stride == 2:
self.avd_layer = nn.Sequential(
nn.Conv2d(
out_planes // 2,
out_planes // 2,
kernel_size=3,
stride=2,
padding=1,
groups=out_planes // 2,
bias=False,
),
nn.BatchNorm2d(out_planes // 2),
)
self.skip = nn.Sequential(
nn.Conv2d(
in_planes,
in_planes,
kernel_size=3,
stride=2,
padding=1,
groups=in_planes,
bias=False,
),
nn.BatchNorm2d(in_planes),
nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False),
nn.BatchNorm2d(out_planes),
)
stride = 1
for idx in range(block_num):
if idx == 0:
self.conv_list.append(ConvX(in_planes, out_planes // 2, kernel=1))
elif idx == 1 and block_num == 2:
self.conv_list.append(
ConvX(out_planes // 2, out_planes // 2, stride=stride)
)
elif idx == 1 and block_num > 2:
self.conv_list.append(
ConvX(out_planes // 2, out_planes // 4, stride=stride)
)
elif idx < block_num - 1:
self.conv_list.append(
ConvX(
out_planes // int(math.pow(2, idx)),
out_planes // int(math.pow(2, idx + 1)),
)
)
else:
self.conv_list.append(
ConvX(
out_planes // int(math.pow(2, idx)),
out_planes // int(math.pow(2, idx)),
)
)
def forward(self, x):
out_list = []
out = x
for idx, conv in enumerate(self.conv_list):
if idx == 0 and self.stride == 2:
out = self.avd_layer(conv(out))
else:
out = conv(out)
out_list.append(out)
if self.stride == 2:
x = self.skip(x)
return torch.cat(out_list, dim=1) + x
class CatBottleneck(nn.Module):
def __init__(self, in_planes, out_planes, block_num=3, stride=1):
super().__init__()
assert block_num > 1, print("block number should be larger than 1.")
self.conv_list = nn.ModuleList()
self.stride = stride
if stride == 2:
self.avd_layer = nn.Sequential(
nn.Conv2d(
out_planes // 2,
out_planes // 2,
kernel_size=3,
stride=2,
padding=1,
groups=out_planes // 2,
bias=False,
),
nn.BatchNorm2d(out_planes // 2),
)
self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
stride = 1
for idx in range(block_num):
if idx == 0:
self.conv_list.append(ConvX(in_planes, out_planes // 2, kernel=1))
elif idx == 1 and block_num == 2:
self.conv_list.append(
ConvX(out_planes // 2, out_planes // 2, stride=stride)
)
elif idx == 1 and block_num > 2:
self.conv_list.append(
ConvX(out_planes // 2, out_planes // 4, stride=stride)
)
elif idx < block_num - 1:
self.conv_list.append(
ConvX(
out_planes // int(math.pow(2, idx)),
out_planes // int(math.pow(2, idx + 1)),
)
)
else:
self.conv_list.append(
ConvX(
out_planes // int(math.pow(2, idx)),
out_planes // int(math.pow(2, idx)),
)
)
def forward(self, x):
out_list = []
out1 = self.conv_list[0](x)
for idx, conv in enumerate(self.conv_list[1:]):
if idx == 0:
if self.stride == 2:
out = conv(self.avd_layer(out1))
else:
out = conv(out1)
else:
out = conv(out)
out_list.append(out)
if self.stride == 2:
out1 = self.skip(out1)
out_list.insert(0, out1)
out = torch.cat(out_list, dim=1)
return out
class STDCNet(nn.Module):
def __init__(
self,
base=64,
layers=[2, 2, 2],
block_num=4,
block_type="cat",
use_conv_last=False,
):
super().__init__()
if block_type == "cat":
block = CatBottleneck
elif block_type == "add":
block = AddBottleneck
self.use_conv_last = use_conv_last
self.features = self._make_layers(base, layers, block_num, block)
if layers != [2, 2, 2] and layers != [4, 5, 3]:
layers = [4, 5, 3]
if layers == [2, 2, 2]:
self.x2 = nn.Sequential(self.features[:1])
self.x4 = nn.Sequential(self.features[1:2])
self.x8 = nn.Sequential(self.features[2:4])
self.x16 = nn.Sequential(self.features[4:6])
self.x32 = nn.Sequential(self.features[6:])
elif layers == [4, 5, 3]:
self.x2 = nn.Sequential(self.features[:1])
self.x4 = nn.Sequential(self.features[1:2])
self.x8 = nn.Sequential(self.features[2:6])
self.x16 = nn.Sequential(self.features[6:11])
self.x32 = nn.Sequential(self.features[11:])
if self.use_conv_last:
self.conv_last = ConvX(base * 16, max(1024, base * 16), 1, 1)
def init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, mode="fan_out")
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal_(m.weight, std=0.001)
if m.bias is not None:
init.constant_(m.bias, 0)
def _make_layers(self, base, layers, block_num, block):
features = []
features += [ConvX(3, base // 2, 3, 2)]
features += [ConvX(base // 2, base, 3, 2)]
for i, layer in enumerate(layers):
for j in range(layer):
if i == 0 and j == 0:
features.append(block(base, base * 4, block_num, 2))
elif j == 0:
features.append(
block(
base * int(math.pow(2, i + 1)),
base * int(math.pow(2, i + 2)),
block_num,
2,
)
)
else:
features.append(
block(
base * int(math.pow(2, i + 2)),
base * int(math.pow(2, i + 2)),
block_num,
1,
)
)
return nn.Sequential(*features)
def forward(self, x):
outs = {}
feat2 = self.x2(x)
feat4 = self.x4(feat2)
outs["res2"] = feat4
feat8 = self.x8(feat4)
outs["res3"] = feat8
feat16 = self.x16(feat8)
outs["res4"] = feat16
feat32 = self.x32(feat16)
outs["res5"] = feat32
if self.use_conv_last:
feat32 = self.conv_last(feat32)
outs["res5"] = feat32
return outs