SCDA_pytorch/util/model.py
CaoGang2018 db37310f79 GMP
2020-05-31 10:27:29 +08:00

61 lines
1.7 KiB
Python

import torch as t
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
class pool_model(nn.Module):
def __init__(self):
super(pool_model, self).__init__()
# self.pool1 = nn.AdaptiveAvgPool2d(1)
self.pool2 = nn.AdaptiveMaxPool2d(1)
def ave_pool(self, x, cc):
b, c, h, w = x.shape
men_pool = t.zeros(b, c)
for i in range(b):
count = 0
tmp = x[i,:, 0, 0]
# exit()
# print(tmp.shape)
for m in range(h):
for n in range(w):
if cc[i][m][n]:
tmp += x[i,:, m, n]
count += 1
if count == 0:
men_pool[i] = tmp
else:
men_pool[i] = tmp / count
# print(count)
return men_pool
def forward(self, x, cc):
tmp1 = self.ave_pool(x, cc)
tmp2 = self.pool2(x)
b, c, _, _ = tmp2.shape
# print(tmp1.shape)
# print(tmp2.shape)
return t.cat((tmp1, tmp2.reshape(b, c)), dim=1).reshape(b, -1)
# x = t.ones(2, 512, 7, 7)
# model = pool_model()
# print(model(x).shape)
# class Net(nn.modules):
# def __init__(self):
# super(Net, self).__init__()
# self.basenet = models.vgg16(pretrained=True).features
# # print(self.basenet)
# def forward(self, x):
# x = self.basenet(x)
# return x
# net1 = models.vgg16(pretrained=False).features[:-3]
# net2 = models.vgg16(pretrained=False).features
# x = t.ones(2, 3, 224, 224)
# print(net1(x).shape)
# x1 = F.interpolate(net2(x), size=net1(x).shape[-2:], mode='nearest')
# print(x1.shape)