mmselfsup/openselfsup/models/heads/contrastive_head.py
2020-09-02 18:49:39 +08:00

39 lines
1.0 KiB
Python

import torch
import torch.nn as nn
from ..registry import HEADS
@HEADS.register_module
class ContrastiveHead(nn.Module):
    """Head for contrastive learning.

    The positive similarity is placed in column 0 of the logits, followed by
    the negative similarities. The logits are divided by ``temperature`` and
    scored with cross-entropy against target class 0 for every sample.

    Args:
        temperature (float): The temperature hyper-parameter that
            controls the concentration level of the distribution.
            Default: 0.1.
    """

    def __init__(self, temperature=0.1):
        super(ContrastiveHead, self).__init__()
        self.criterion = nn.CrossEntropyLoss()
        self.temperature = temperature

    def forward(self, pos, neg):
        """Forward head.

        Args:
            pos (Tensor): Nx1 positive similarity.
            neg (Tensor): Nxk negative similarity.

        Returns:
            dict[str, Tensor]: A dictionary of loss components. It holds a
                single key, ``'loss'``, with the cross-entropy value.
        """
        N = pos.size(0)
        # torch.cat allocates a new tensor, so the in-place division below
        # never modifies the caller's ``pos`` / ``neg``.
        logits = torch.cat((pos, neg), dim=1)
        logits /= self.temperature
        # The positive sits in column 0, so the target class is 0 for every
        # row. Create the labels directly on the logits' device instead of
        # calling ``.cuda()``: this keeps the head usable on CPU and on
        # non-default GPUs, and avoids a host-to-device copy.
        labels = torch.zeros((N, ), dtype=torch.long, device=logits.device)
        losses = dict()
        losses['loss'] = self.criterion(logits, labels)
        return losses