mirror of https://github.com/hero-y/BHRL
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmcv.cnn import xavier_init


class IHR(nn.Module):
    """IHR module from BHRL.

    Fuses contrastive, salient, and attention relations between a target
    feature map ``x`` and a reference feature map ``ref`` into a single
    metric feature.
    """

    def __init__(self, metric_module):
        super(IHR, self).__init__()

        in_channel = metric_module['metric_module_in_channel']
        out_channel = metric_module['metric_module_out_channel']
        self.in_channel = in_channel

        # Key projections for the cross-attention branch.
        self.key_t = nn.Conv2d(in_channel, in_channel // 8,
                               kernel_size=(3, 3), padding=(1, 1), stride=1)
        self.key_q = nn.Conv2d(in_channel, in_channel // 8,
                               kernel_size=(3, 3), padding=(1, 1), stride=1)
        # Global average pooling applied to the reference feature map.
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        # Final projection of the fused relation features.
        self.metric = nn.Conv2d(in_channel * 2, out_channel, 1, 1)
        # Shared 1x1 reduction applied to each relation feature.
        self.conv1 = nn.Conv2d(in_channel, in_channel // 2, 1, 1)

        xavier_init(self.key_t, distribution='uniform')
        xavier_init(self.key_q, distribution='uniform')
        xavier_init(self.metric, distribution='uniform')
        xavier_init(self.conv1, distribution='uniform')

    def forward(self, x, ref):
        # x and ref are assumed to share the same (bs, C, H, W) shape,
        # e.g. RoI-aligned target and reference features.
        bs = ref.shape[0]
        ref_h, ref_w = ref.shape[2:]
        x_h, x_w = x.shape[2:]

        # Contrastive relation: absolute difference between the target
        # feature and the globally pooled reference vector.
        contrastive_feat = (self.avg(ref) - x).abs()
        contrastive_feat = self.conv1(contrastive_feat)

        # Salient relation: depth-wise (grouped) correlation of the target
        # feature with the pooled reference vector, one 1x1 kernel per
        # (sample, channel) pair.
        ref_vec = self.avg(ref)  # (bs, C, 1, 1)
        salient_feat = F.conv2d(x.reshape(1, -1, x_h, x_w),
                                ref_vec.reshape(1, -1, 1, 1).permute(1, 0, 2, 3),
                                groups=bs * self.in_channel)
        salient_feat = salient_feat.reshape(bs, self.in_channel, x_h, x_w)
        salient_feat = self.conv1(salient_feat)

        # Attention relation: target keys attend over reference keys, the
        # attended reference values are compared with the target feature.
        # Requires x and ref to have the same spatial size.
        key_q = self.key_q(ref)
        key_t = self.key_t(x)
        c_key = self.in_channel // 8
        p = torch.matmul(key_t.view(bs, c_key, -1).permute(0, 2, 1),
                         key_q.view(bs, c_key, -1))
        p = F.softmax(p, dim=1)
        val_t_out = torch.matmul(ref.view(bs, self.in_channel, -1),
                                 p).view(bs, self.in_channel, ref_h, ref_w)
        attention_feat = (val_t_out - x).abs()
        attention_feat = self.conv1(attention_feat)

        # Fuse the three relation features with the original target feature
        # and project to the final metric feature.
        contrastive_and_salient_feat = self.conv1(
            torch.cat([salient_feat, contrastive_feat], 1))
        output = torch.cat([contrastive_and_salient_feat, attention_feat, x], 1)
        output = self.metric(output)

        return output
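

# --- Usage sketch (not part of the original file) ----------------------------
# A minimal smoke test of IHR, assuming hypothetical channel settings
# (in_channel=256, out_channel=2); the real values come from the BHRL config.
# x and ref stand in for RoI-aligned target/reference features, which this
# module assumes share the same (bs, C, H, W) shape.
if __name__ == '__main__':
    metric_module = dict(metric_module_in_channel=256,
                         metric_module_out_channel=2)
    ihr = IHR(metric_module)
    x = torch.randn(4, 256, 7, 7)    # target RoI features
    ref = torch.randn(4, 256, 7, 7)  # reference (query) features
    out = ihr(x, ref)
    print(out.shape)                 # torch.Size([4, 2, 7, 7])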