mirror of https://github.com/alibaba/EasyCV.git
63 lines
1.9 KiB
Python
63 lines
1.9 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import torch
|
|
import torch.nn as nn
|
|
from mmcv.cnn import normal_init
|
|
|
|
from .. import builder
|
|
from ..registry import HEADS
|
|
|
|
|
|
@HEADS.register_module
class LatentPredictHead(nn.Module):
    """Regression head that predicts target latents from input latents.

    The input features are passed through a predictor neck. The prediction
    and the target features are each L2-normalised along dim 1. The loss is
    ``-2`` times the sum of their element-wise products, i.e. of the
    per-sample cosine similarities. It is optionally divided by the batch
    size.

    Args:
        predictor: Neck config forwarded to ``builder.build_neck`` to build
            the predictor module.
        size_average (bool): Whether to divide the summed loss by
            ``input.size(0)``. Defaults to True.
    """

    def __init__(self, predictor, size_average=True):
        super().__init__()
        self.predictor = builder.build_neck(predictor)
        self.size_average = size_average

    def init_weights(self, init_linear='normal'):
        """Delegate weight initialisation to the predictor neck.

        Args:
            init_linear (str): Initialisation scheme name, passed through
                unchanged to ``self.predictor.init_weights``.
        """
        self.predictor.init_weights(init_linear=init_linear)

    def forward(self, input, target):
        """Compute the latent-prediction loss.

        Args:
            input (Tensor): NxC input features.
            target (Tensor): NxC target features.

        Returns:
            dict: ``{'loss': Tensor}`` holding the scalar loss.
        """
        normalize = nn.functional.normalize
        # The neck is called with a list of inputs and returns a list;
        # only its first output is used.
        outputs = self.predictor([input])
        prediction = normalize(outputs[0], dim=1)
        reference = normalize(target, dim=1)

        similarity_sum = (prediction * reference).sum()
        loss = -2 * similarity_sum
        if not self.size_average:
            return {'loss': loss}
        return {'loss': loss / input.size(0)}
|
|
|
|
|
|
@HEADS.register_module
class LatentClsHead(nn.Module):
    """Classification head trained on pseudo-labels derived from targets.

    A single linear layer maps features to ``num_classes`` logits. In
    ``forward`` the same layer is applied to the target features without
    gradient tracking. The arg-max of those logits serves as the label for a
    cross-entropy loss on the logits of the input features.

    Args:
        predictor: Config providing ``in_channels`` and ``num_classes`` for
            the linear layer. Attribute-style configs (e.g. mmcv
            ``ConfigDict``) and plain mappings such as ``dict`` are both
            accepted.
    """

    def __init__(self, predictor):
        super().__init__()
        self.predictor = nn.Linear(
            self._cfg_value(predictor, 'in_channels'),
            self._cfg_value(predictor, 'num_classes'))
        self.criterion = nn.CrossEntropyLoss()

    @staticmethod
    def _cfg_value(cfg, key):
        """Return ``cfg.<key>``, falling back to ``cfg[key]`` for mappings."""
        # Attribute access first keeps the original behaviour for
        # attribute-style configs; the fallback lets plain dicts work too.
        try:
            return getattr(cfg, key)
        except AttributeError:
            return cfg[key]

    def init_weights(self, init_linear='normal'):
        """Initialise the linear layer with ``normal_init(std=0.01)``.

        Args:
            init_linear (str): Accepted for signature compatibility with
                ``LatentPredictHead.init_weights``; unused here.
        """
        normal_init(self.predictor, std=0.01)

    def forward(self, input, target):
        """Compute cross-entropy of input logits against target pseudo-labels.

        Args:
            input (Tensor): NxC input features.
            target (Tensor): NxC target features.

        Returns:
            dict: ``{'loss': Tensor}`` holding the cross-entropy loss.
        """
        pred = self.predictor(input)
        # Labels come from the current classifier applied to the target
        # features. no_grad keeps this branch out of the autograd graph,
        # so no explicit detach() is needed.
        with torch.no_grad():
            label = torch.argmax(self.predictor(target), dim=1)
        loss = self.criterion(pred, label)
        return dict(loss=loss)
|