BHRL/mmdet/models/plugins/ihr.py

57 lines
2.1 KiB
Python

import torch
import torch.nn as nn
from mmcv.cnn import xavier_init
import mmcv
import torch.nn.functional as F
import numpy as np
import cv2
class IHR(nn.Module):
    """Relate features ``x`` to reference features ``ref`` through three branches.

    The branches are:

    * contrastive: ``|avg(ref) - x|``
    * salient: ``x`` scaled channel-wise by ``avg(ref)``
    * attention: ``ref`` re-weighted by key similarity to ``x``, then
      ``|. - x|``

    Their outputs are concatenated with ``x`` and projected by a 1x1 conv
    (``metric``) to ``metric_module_out_channel`` channels. A single ``conv1``
    (C -> C // 2) is shared by every branch.

    Args:
        metric_module (dict): Must contain ``'metric_module_in_channel'``
            (C, the channel count of ``x`` and ``ref``) and
            ``'metric_module_out_channel'`` (channels of the output).
    """

    def __init__(self,
                 metric_module):
        super().__init__()
        in_channel = metric_module['metric_module_in_channel']
        out_channel = metric_module['metric_module_out_channel']
        # Key projections for the attention branch (C -> C // 8).
        self.key_t = nn.Conv2d(in_channel, in_channel // 8, kernel_size=(3, 3),
                               padding=(1, 1), stride=1)
        self.key_q = nn.Conv2d(in_channel, in_channel // 8, kernel_size=(3, 3),
                               padding=(1, 1), stride=1)
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        # Input is cat([C // 2, C // 2, C]) = 2C channels (see forward).
        self.metric = nn.Conv2d(in_channel * 2, out_channel, 1, 1)
        self.conv1 = nn.Conv2d(in_channel, in_channel // 2, 1, 1)
        # Order kept identical so parameter init consumes the RNG the same way.
        xavier_init(self.key_t, distribution='uniform')
        xavier_init(self.key_q, distribution='uniform')
        xavier_init(self.metric, distribution='uniform')
        xavier_init(self.conv1, distribution='uniform')

    def forward(self, x, ref):
        """Fuse ``x`` with ``ref``.

        Args:
            x (Tensor): Features of shape (N, C, H, W), where C is
                ``metric_module_in_channel``.
            ref (Tensor): Reference features. They must have exactly the same
                shape as ``x``, because the attention branch pairs positions
                of ``x`` and ``ref`` one-to-one.

        Returns:
            Tensor: Fused features of shape (N, metric_module_out_channel, H, W).

        Raises:
            ValueError: If ``ref`` and ``x`` differ in shape (previously this
                failed later with an opaque matmul/cat error).
        """
        if ref.shape != x.shape:
            raise ValueError(
                f'IHR expects ref and x to have the same shape, got '
                f'{tuple(ref.shape)} and {tuple(x.shape)}')
        bs, channels, ref_h, ref_w = ref.shape

        # Global descriptor of the reference, shape (N, C, 1, 1); computed once.
        ref_pooled = self.avg(ref)

        # Contrastive branch.
        contrastive_feat = self.conv1((ref_pooled - x).abs())

        # Salient branch. This scales each channel of x by the matching pooled
        # reference channel. It equals a bias-free grouped 1x1 conv2d with one
        # group per (sample, channel) pair, written as a broadcast multiply.
        salient_feat = self.conv1(x * ref_pooled)

        # Attention branch. Read the key channel count (C // 8) and the value
        # channel count (C) from the tensors instead of assuming C == 256.
        key_q = self.key_q(ref)
        key_t = self.key_t(x)
        key_channels = key_t.shape[1]
        # p[b, i, j] = <key_t at x-position i, key_q at ref-position j>;
        # shape (N, H*W, H*W).
        p = torch.matmul(
            key_t.reshape(bs, key_channels, -1).permute(0, 2, 1),
            key_q.reshape(bs, key_channels, -1))
        # NOTE(review): softmax normalises over dim=1 (x positions), and the
        # matmul below contracts that same axis against ref positions. This is
        # only well-defined because x and ref share a spatial size. It is kept
        # as-is so existing checkpoints reproduce their results.
        p = F.softmax(p, dim=1)
        val_t_out = torch.matmul(ref.reshape(bs, channels, -1), p).view(
            bs, channels, ref_h, ref_w)
        attention_feat = self.conv1((val_t_out - x).abs())

        # cat(C // 2, C // 2) -> C channels, so the shared conv1 applies again.
        contrastive_and_salient_feat = self.conv1(
            torch.cat([salient_feat, contrastive_feat], 1))
        output = torch.cat([contrastive_and_salient_feat, attention_feat, x], 1)
        return self.metric(output)