mirror of
https://github.com/CaoGang2018/SCDA_pytorch.git
synced 2025-06-03 14:59:31 +08:00
GMP
This commit is contained in:
parent
b38eccc3b2
commit
db37310f79
@ -6,16 +6,39 @@ import torch.nn.functional as F
|
|||||||
class pool_model(nn.Module):
|
class pool_model(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(pool_model, self).__init__()
|
super(pool_model, self).__init__()
|
||||||
self.pool1 = nn.AdaptiveAvgPool2d(1)
|
# self.pool1 = nn.AdaptiveAvgPool2d(1)
|
||||||
self.pool2 = nn.AdaptiveMaxPool2d(1)
|
self.pool2 = nn.AdaptiveMaxPool2d(1)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
def ave_pool(self, x, cc):
|
||||||
tmp1 = self.pool1(x)
|
b, c, h, w = x.shape
|
||||||
|
men_pool = t.zeros(b, c)
|
||||||
|
for i in range(b):
|
||||||
|
count = 0
|
||||||
|
tmp = x[i,:, 0, 0]
|
||||||
|
# exit()
|
||||||
|
# print(tmp.shape)
|
||||||
|
for m in range(h):
|
||||||
|
for n in range(w):
|
||||||
|
if cc[i][m][n]:
|
||||||
|
tmp += x[i,:, m, n]
|
||||||
|
count += 1
|
||||||
|
if count == 0:
|
||||||
|
men_pool[i] = tmp
|
||||||
|
else:
|
||||||
|
men_pool[i] = tmp / count
|
||||||
|
# print(count)
|
||||||
|
return men_pool
|
||||||
|
|
||||||
|
def forward(self, x, cc):
|
||||||
|
|
||||||
|
tmp1 = self.ave_pool(x, cc)
|
||||||
tmp2 = self.pool2(x)
|
tmp2 = self.pool2(x)
|
||||||
b, c, _, _ = tmp1.shape
|
b, c, _, _ = tmp2.shape
|
||||||
# print(tmp1.shape)
|
# print(tmp1.shape)
|
||||||
|
# print(tmp2.shape)
|
||||||
|
|
||||||
return t.cat((tmp1, tmp2), dim=1).reshape(b, -1)
|
return t.cat((tmp1, tmp2.reshape(b, c)), dim=1).reshape(b, -1)
|
||||||
|
|
||||||
# x = t.ones(2, 512, 7, 7)
|
# x = t.ones(2, 512, 7, 7)
|
||||||
# model = pool_model()
|
# model = pool_model()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user