mmselfsup/openselfsup/models/heads/latent_pred_head.py

59 lines
1.7 KiB
Python

import torch
import torch.nn as nn
from mmcv.cnn import normal_init
from ..registry import HEADS
from .. import builder
@HEADS.register_module
class LatentPredictHead(nn.Module):
    """Head that predicts target latents and scores the prediction.

    A predictor module is built from the ``predictor`` config through
    ``builder.build_neck``; its output is compared with the target
    features after both have been normalized along ``dim=1``.

    Args:
        predictor: Config passed to ``builder.build_neck`` to construct
            the predictor module.
    """

    def __init__(self, predictor):
        super().__init__()
        self.predictor = builder.build_neck(predictor)

    def init_weights(self, init_linear='normal'):
        """Delegate weight initialization to the predictor module.

        Args:
            init_linear (str): Forwarded unchanged to
                ``self.predictor.init_weights``.
        """
        self.predictor.init_weights(init_linear=init_linear)

    def forward(self, input, target):
        """Compute the prediction loss between input and target features.

        Args:
            input (Tensor): NxC input features.
            target (Tensor): NxC target features.

        Returns:
            dict: ``{'loss': loss}`` where ``loss`` is
                ``2 - 2 * sum(pred_norm * target_norm) / N``.
        """
        batch_size = input.size(0)
        # The predictor is called with a one-element list and its first
        # output is used.
        prediction = self.predictor([input])[0]
        unit_pred = nn.functional.normalize(prediction, dim=1)
        unit_target = nn.functional.normalize(target, dim=1)
        # Sum of element-wise products over the whole batch; divided by
        # the batch size below.
        similarity = (unit_pred * unit_target).sum()
        loss = 2 - 2 * similarity / batch_size
        return {'loss': loss}
@HEADS.register_module
class LatentClsHead(nn.Module):
    """Linear classification head trained on labels derived from targets.

    A single linear layer maps features to class logits. The labels for
    the cross-entropy loss are the argmax of the same layer applied to
    the target features, computed without gradient tracking.

    Args:
        predictor: Config object exposing ``in_channels`` and
            ``num_classes`` attributes, used to size the linear layer.
    """

    def __init__(self, predictor):
        super().__init__()
        in_features = predictor.in_channels
        out_features = predictor.num_classes
        self.predictor = nn.Linear(in_features, out_features)
        self.criterion = nn.CrossEntropyLoss()

    def init_weights(self, init_linear='normal'):
        """Initialize the linear layer with ``normal_init`` (std=0.01).

        Args:
            init_linear (str): Accepted but not used by this head.
        """
        normal_init(self.predictor, std=0.01)

    def forward(self, input, target):
        """Compute cross-entropy between input logits and target labels.

        Args:
            input (Tensor): NxC input features.
            target (Tensor): NxC target features.

        Returns:
            dict: ``{'loss': loss}`` with the cross-entropy loss.
        """
        logits = self.predictor(input)
        # Labels come from the target branch and must not carry gradients.
        with torch.no_grad():
            target_logits = self.predictor(target)
            pseudo_label = torch.argmax(target_logits, dim=1).detach()
        loss = self.criterion(logits, pseudo_label)
        return {'loss': loss}