mirror of https://github.com/JDAI-CV/DCL.git
43 lines
1.4 KiB
Python
43 lines
1.4 KiB
Python
|
from torch import nn
|
||
|
import torch
|
||
|
from torchvision import models, transforms, datasets
|
||
|
import torch.nn.functional as F
|
||
|
|
||
|
|
||
|
class resnet_swap_2loss_add(nn.Module):
    """ResNet-50 backbone with two linear heads and a 1x1-conv mask branch.

    The children of a torchvision ``resnet50`` are regrouped into four
    sequential stages.  The feature map produced by the last stage feeds
    three outputs:

    * ``classifier``: a linear layer with ``num_classes`` outputs,
    * ``classifier_swap``: a linear layer with ``2 * num_classes`` outputs,
    * a single-channel mask obtained from a bias-free 1x1 convolution,
      2x2 average pooling and ``tanh``, flattened per sample.

    Args:
        num_classes: Output size of ``classifier``; ``classifier_swap`` has
            twice as many outputs.
    """

    def __init__(self, num_classes):
        super(resnet_swap_2loss_add, self).__init__()

        # NOTE(review): recent torchvision releases prefer the ``weights=``
        # argument; ``pretrained=True`` is kept so older versions still work.
        resnet50 = models.resnet50(pretrained=True)

        # Materialise the child list once instead of once per stage.
        children = list(resnet50.children())
        self.stage1_img = nn.Sequential(*children[:5])
        self.stage2_img = nn.Sequential(*children[5:6])
        self.stage3_img = nn.Sequential(*children[6:7])
        # Unlike stages 2 and 3, child 7 is unpacked directly (its own
        # sub-modules become the Sequential's entries).  Keep it that way so
        # parameter names in saved checkpoints do not change.
        self.stage4_img = nn.Sequential(*children[7])

        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
        self.classifier = nn.Linear(2048, num_classes)
        # The original source also carried a commented-out 2-output variant
        # of this head: nn.Linear(2048, 2).
        self.classifier_swap = nn.Linear(2048, 2 * num_classes)
        self.Convmask = nn.Conv2d(2048, 1, 1, stride=1, padding=0, bias=False)
        self.avgpool2 = nn.AvgPool2d(2, stride=2)

    def forward(self, x):
        """Run the backbone and return ``[logits, swap_logits, mask]``.

        Args:
            x: Input batch passed to the first backbone stage
                (presumably an image tensor of shape (N, 3, H, W) -- confirm
                against the caller).

        Returns:
            A list of three tensors, in this order:
            the ``classifier`` output, the ``classifier_swap`` output, and
            the tanh-activated mask flattened to ``(N, -1)``.
        """
        feat = self.stage1_img(x)
        feat = self.stage2_img(feat)
        feat = self.stage3_img(feat)
        feat = self.stage4_img(feat)

        # Mask branch: 1x1 conv -> 2x2 average pool -> tanh -> flatten.
        mask = self.Convmask(feat)
        mask = self.avgpool2(mask)
        # torch.tanh replaces the deprecated torch.nn.functional.tanh alias.
        mask = torch.tanh(mask)
        mask = mask.view(mask.size(0), -1)

        # Global average pooling, flattened to one feature vector per sample.
        pooled = self.avgpool(feat)
        pooled = pooled.view(pooled.size(0), -1)

        return [self.classifier(pooled), self.classifier_swap(pooled), mask]
|