from __future__ import absolute_import

import torch
from torch import nn
from torch.nn import functional as F
import torchvision

__all__ = ['DaRe']
class DaRe(nn.Module):
    """Multi-stage feature-fusion network built on a pretrained ResNet-50.

    The ResNet-50 backbone (minus its avgpool and fc head) is spliced into
    four sequential convolutional blocks. After each block, the feature map
    is globally average-pooled and projected to a 128-d embedding by a small
    MLP; the four embeddings are combined into one fused 128-d feature via
    four learnable scalar weights, which feeds the final classifier.

    Args:
        num_classes (int): size of the final classification layer.
        loss (set): loss configuration controlling what ``forward`` returns
            in training mode; one of ``{'xent'}``, ``{'xent', 'htri'}`` or
            ``{'cent'}``.
        w_init (float): initial value for each of the four fusion weights.
        **kwargs: ignored; accepted for interface compatibility.
    """

    def __init__(self, num_classes=0, loss={'xent'}, w_init=0.1, **kwargs):
        # NOTE(review): mutable default ``loss={'xent'}`` is shared across
        # instances; benign here because it is only read, never mutated.
        super(DaRe, self).__init__()
        self.loss = loss
        # ``pretrained=True`` is the legacy torchvision API (superseded by
        # ``weights=``); kept as-is to preserve behavior on older versions.
        resnet50 = torchvision.models.resnet50(pretrained=True)
        # Drop the trailing avgpool and fc; keep stem (conv/bn/relu/maxpool,
        # indices 0-3) plus the four residual stages (indices 4-7).
        base = nn.Sequential(*list(resnet50.children())[:-2])

        # construct four convolutional blocks
        # Each block ends mid-Bottleneck (after conv2/bn2 of a stage's last
        # Bottleneck) and the next block resumes with that Bottleneck's
        # conv3/bn3/relu. NOTE(review): this splicing discards the residual
        # (skip) connections and the ReLUs inside each split Bottleneck —
        # presumably intentional for stage-wise features; confirm against the
        # reference implementation.
        self.conv_block1 = nn.Sequential(
            base[0], base[1], base[2], base[3],
            base[4][0], base[4][1],
            base[4][2].conv1, base[4][2].bn1, base[4][2].conv2, base[4][2].bn2,
        )
        # Embedding head for stage 1 (64 channels after the split conv2).
        # NOTE(review): hidden size 1204 looks like a typo for 1024 (repeated
        # in all four MLPs); left unchanged because fixing it would invalidate
        # any saved checkpoints. Also note BatchNorm1d requires batch size > 1
        # in training mode.
        self.linears1 = nn.Sequential(nn.Linear(64, 1204), nn.BatchNorm1d(1204), nn.ReLU(), nn.Linear(1204, 128))

        self.conv_block2 = nn.Sequential(
            base[4][2].conv3, base[4][2].bn3, base[4][2].relu,
            base[5][0], base[5][1], base[5][2],
            base[5][3].conv1, base[5][3].bn1, base[5][3].conv2, base[5][3].bn2,
        )
        # Embedding head for stage 2 (128 channels after the split conv2).
        self.linears2 = nn.Sequential(nn.Linear(128, 1204), nn.BatchNorm1d(1204), nn.ReLU(), nn.Linear(1204, 128))

        self.conv_block3 = nn.Sequential(
            base[5][3].conv3, base[5][3].bn3, base[5][3].relu,
            base[6][0], base[6][1], base[6][2], base[6][3], base[6][4],
            base[6][5].conv1, base[6][5].bn1, base[6][5].conv2, base[6][5].bn2,
        )
        # Embedding head for stage 3 (256 channels after the split conv2).
        self.linears3 = nn.Sequential(nn.Linear(256, 1204), nn.BatchNorm1d(1204), nn.ReLU(), nn.Linear(1204, 128))

        self.conv_block4 = nn.Sequential(
            base[6][5].conv3, base[6][5].bn3, base[6][5].relu,
            base[7][0], base[7][1],
            base[7][2].conv1, base[7][2].bn1, base[7][2].conv2, base[7][2].bn2,
        )
        # Embedding head for stage 4 (512 channels after the split conv2).
        self.linears4 = nn.Sequential(nn.Linear(512, 1204), nn.BatchNorm1d(1204), nn.ReLU(), nn.Linear(1204, 128))

        # fusion weights for four stages
        # Learnable scalars, all initialized to w_init; they weight the
        # stage embeddings in the fused feature.
        self.w1 = nn.Parameter(torch.ones(1) * w_init)
        self.w2 = nn.Parameter(torch.ones(1) * w_init)
        self.w3 = nn.Parameter(torch.ones(1) * w_init)
        self.w4 = nn.Parameter(torch.ones(1) * w_init)

        self.classifier = nn.Linear(128, num_classes)
        self.feat_dim = 128  # feature dimension

    def forward(self, x):
        """Run the four-stage backbone and fuse the stage embeddings.

        Args:
            x: input image batch — presumably (N, 3, H, W); TODO confirm.

        Returns:
            Eval mode: the fused 128-d feature (N, 128).
            Training mode, depending on ``self.loss``:
                ``{'xent'}``          -> logits (N, num_classes)
                ``{'xent', 'htri'}``  -> (logits, (fused, f1, f2, f3, f4))
                ``{'cent'}``          -> (logits, fused)

        Raises:
            KeyError: if ``self.loss`` is not one of the supported sets.
        """
        # Stage 1: feature map -> global average pool -> 128-d embedding.
        x1 = self.conv_block1(x)
        x1_feat = F.avg_pool2d(x1, x1.size()[2:]).view(x1.size(0), x1.size(1))
        x1_feat = self.linears1(x1_feat)

        # Stage 2 continues from the stage-1 feature map (not the embedding).
        x2 = self.conv_block2(x1)
        x2_feat = F.avg_pool2d(x2, x2.size()[2:]).view(x2.size(0), x2.size(1))
        x2_feat = self.linears2(x2_feat)

        # Stage 3.
        x3 = self.conv_block3(x2)
        x3_feat = F.avg_pool2d(x3, x3.size()[2:]).view(x3.size(0), x3.size(1))
        x3_feat = self.linears3(x3_feat)

        # Stage 4.
        x4 = self.conv_block4(x3)
        x4_feat = F.avg_pool2d(x4, x4.size()[2:]).view(x4.size(0), x4.size(1))
        x4_feat = self.linears4(x4_feat)

        # Learnable weighted sum of the four 128-d stage embeddings.
        fusion_feat = x1_feat * self.w1 + x2_feat * self.w2 + x3_feat * self.w3 + x4_feat * self.w4

        # Inference returns the fused embedding directly (no classifier).
        if not self.training:
            return fusion_feat

        prelogits = self.classifier(fusion_feat)

        # Training-time return shape is dictated by the configured loss set.
        if self.loss == {'xent'}:
            return prelogits
        elif self.loss == {'xent', 'htri'}:
            return prelogits, (fusion_feat, x1_feat, x2_feat, x3_feat, x4_feat)
        elif self.loss == {'cent'}:
            return prelogits, fusion_feat
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))