pull/150/head
Lingxiao He 2020-06-10 17:43:56 +08:00 committed by GitHub
parent 27c48c8f02
commit fa0728a0c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 2 deletions

View File

@ -35,8 +35,9 @@ class OcclusionUnit(nn.Module):
SpatialFeatAll = SpatialFeatAll.transpose(1, 2) # shape: [n, c, m]
y = self.mask_layer(SpatialFeatAll)
mask_weight = torch.sigmoid(y[:, :, 0])
mask_score = F.normalize(mask_weight[:, :48], p=1, dim=1)
feat_dim = SpaFeat1.size(2) * SpaFeat1.size(3)
mask_score = F.normalize(mask_weight[:, :feat_dim], p=1, dim=1)
mask_weight_norm = F.normalize(mask_weight, p=1, dim=1)
mask_score = mask_score.unsqueeze(1)