# Mirror of https://github.com/JDAI-CV/DCL.git
# 81 lines | 3.0 KiB | Python
import numpy as np
|
|
from torch import nn
|
|
import torch
|
|
from torchvision import models, transforms, datasets
|
|
import torch.nn.functional as F
|
|
import pretrainedmodels
|
|
|
|
from config import pretrained_model
|
|
|
|
import pdb
|
|
|
|
class MainModel(nn.Module):
    """Backbone network with a classification head and optional extra heads.

    The backbone is looked up by name, first in ``torchvision.models`` and
    otherwise in ``pretrainedmodels``. For the architectures listed in
    ``_DROPPED_TAIL_MODULES`` the trailing child modules are removed so the
    network returns a feature map instead of logits. That feature map is
    globally average-pooled and fed to a bias-free linear classifier.

    Optional heads, enabled from ``config``:

    * ``use_dcl``: adds ``classifier_swap`` (a second linear head on the
      pooled features) and ``Convmask`` (a 1x1 convolution on the feature
      map, followed by 2x2 average pooling and ``tanh``).
    * ``use_Asoftmax``: adds ``Aclassifier``, an ``AngleLinear`` head.

    Args:
        config: object providing the attributes ``use_dcl``, ``numcls``,
            ``backbone`` and ``use_Asoftmax``; when ``use_dcl`` is true,
            ``cls_2`` and ``cls_2xmul`` are read as well.
    """

    # Number of trailing child modules stripped from each backbone that is
    # truncated. Architectures not listed here are used as constructed.
    _DROPPED_TAIL_MODULES = {
        'resnet50': 2,
        'se_resnet50': 2,
        'senet154': 3,
        'se_resnext101_32x4d': 2,
        'se_resnet101': 2,
    }

    # Input width of every head; the backbone feature map must have this
    # many channels.
    _FEATURE_DIM = 2048

    def __init__(self, config):
        super(MainModel, self).__init__()
        self.use_dcl = config.use_dcl
        self.num_classes = config.numcls
        self.backbone_arch = config.backbone
        self.use_Asoftmax = config.use_Asoftmax
        print(self.backbone_arch)

        arch = self.backbone_arch
        has_local_weights = arch in pretrained_model

        if arch in dir(models):
            self.model = getattr(models, arch)()
            if has_local_weights:
                # Load onto the CPU: the model was just built there, and a
                # checkpoint saved from GPU tensors would otherwise fail to
                # deserialize on a machine without CUDA. load_state_dict
                # copies the values into the existing parameters.
                state = torch.load(pretrained_model[arch], map_location='cpu')
                self.model.load_state_dict(state)
        elif has_local_weights:
            # NOTE(review): the entry in pretrained_model is not loaded on
            # this branch; the network is only built with pretrained=None.
            # Confirm the weights are restored elsewhere.
            self.model = pretrainedmodels.__dict__[arch](
                num_classes=1000, pretrained=None)
        else:
            self.model = pretrainedmodels.__dict__[arch](num_classes=1000)

        dropped = self._DROPPED_TAIL_MODULES.get(arch)
        if dropped is not None:
            layers = list(self.model.children())[:-dropped]
            self.model = nn.Sequential(*layers)

        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
        self.classifier = nn.Linear(
            self._FEATURE_DIM, self.num_classes, bias=False)

        if self.use_dcl:
            # NOTE(review): if neither cls_2 nor cls_2xmul is set,
            # classifier_swap is never created and forward() will raise
            # AttributeError. If both are set, the cls_2xmul head wins.
            if config.cls_2:
                self.classifier_swap = nn.Linear(
                    self._FEATURE_DIM, 2, bias=False)
            if config.cls_2xmul:
                self.classifier_swap = nn.Linear(
                    self._FEATURE_DIM, 2 * self.num_classes, bias=False)
            self.Convmask = nn.Conv2d(
                self._FEATURE_DIM, 1, 1, stride=1, padding=0, bias=True)
            self.avgpool2 = nn.AvgPool2d(2, stride=2)

        if self.use_Asoftmax:
            # NOTE(review): AngleLinear is neither defined nor imported in
            # the visible part of this file; confirm it is in scope.
            self.Aclassifier = AngleLinear(
                self._FEATURE_DIM, self.num_classes, bias=False)

    def _global_pool(self, feature_map):
        """Average-pool a feature map to 1x1 and flatten it per sample."""
        pooled = self.avgpool(feature_map)
        return pooled.view(pooled.size(0), -1)

    def forward(self, x, last_cont=None):
        """Run the backbone and every enabled head.

        Args:
            x: input batch passed to the backbone.
            last_cont: optional second batch. Only read when
                ``use_Asoftmax`` is true; it is then passed through the
                backbone and pooling and used as the input of
                ``Aclassifier``.

        Returns:
            A list of tensors, in this order:

            1. ``classifier`` output for the pooled features;
            2. ``classifier_swap`` output (only with ``use_dcl``);
            3. the flattened ``tanh`` mask (only with ``use_dcl``);
            4. ``Aclassifier`` output (only with ``use_Asoftmax``).
        """
        x = self.model(x)

        if self.use_dcl:
            # The mask is computed from the feature map, before pooling.
            mask = torch.tanh(self.avgpool2(self.Convmask(x)))
            mask = mask.view(mask.size(0), -1)

        x = self._global_pool(x)
        out = [self.classifier(x)]

        if self.use_dcl:
            out.append(self.classifier_swap(x))
            out.append(mask)

        if self.use_Asoftmax:
            if last_cont is None:
                # Every second sample, starting at index 0. Presumably these
                # are the unshuffled images of an interleaved batch; confirm
                # against the data loader's batch layout.
                out.append(self.Aclassifier(x[::2]))
            else:
                last_x = self._global_pool(self.model(last_cont))
                out.append(self.Aclassifier(last_x))

        return out