From db37310f79cada379d40c6e157c2ae5c7f5c7720 Mon Sep 17 00:00:00 2001 From: CaoGang2018 <996389570@qq.com> Date: Sun, 31 May 2020 10:27:29 +0800 Subject: [PATCH] GMP --- util/model.py | 33 ++++++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/util/model.py b/util/model.py index 71502d0..8316080 100644 --- a/util/model.py +++ b/util/model.py @@ -6,16 +6,39 @@ import torch.nn.functional as F class pool_model(nn.Module): def __init__(self): super(pool_model, self).__init__() - self.pool1 = nn.AdaptiveAvgPool2d(1) + # self.pool1 = nn.AdaptiveAvgPool2d(1) self.pool2 = nn.AdaptiveMaxPool2d(1) + - def forward(self, x): - tmp1 = self.pool1(x) + def ave_pool(self, x, cc): + b, c, h, w = x.shape + men_pool = t.zeros(b, c) + for i in range(b): + count = 0 + tmp = x[i,:, 0, 0] + # exit() + # print(tmp.shape) + for m in range(h): + for n in range(w): + if cc[i][m][n]: + tmp += x[i,:, m, n] + count += 1 + if count == 0: + men_pool[i] = tmp + else: + men_pool[i] = tmp / count + # print(count) + return men_pool + + def forward(self, x, cc): + + tmp1 = self.ave_pool(x, cc) tmp2 = self.pool2(x) - b, c, _, _ = tmp1.shape + b, c, _, _ = tmp2.shape # print(tmp1.shape) + # print(tmp2.shape) - return t.cat((tmp1, tmp2), dim=1).reshape(b, -1) + return t.cat((tmp1, tmp2.reshape(b, c)), dim=1).reshape(b, -1) # x = t.ones(2, 512, 7, 7) # model = pool_model()