mmselfsup/openselfsup/models/deepcluster.py


import numpy as np
import torch
import torch.nn as nn

from openselfsup.utils import print_log

from . import builder
from .registry import MODELS
from .utils import Sobel


@MODELS.register_module
class DeepCluster(nn.Module):
    """DeepCluster.

    Implementation of "Deep Clustering for Unsupervised Learning
    of Visual Features (https://arxiv.org/abs/1807.05520)".

    Args:
        backbone (dict): Config dict for the backbone ConvNet.
        with_sobel (bool): Whether to apply a Sobel filter on images.
            Default: False.
        neck (dict): Config dict for the neck that maps deep features to
            compact feature vectors. Default: None.
        head (dict): Config dict for the head that computes the loss.
            Default: None.
        pretrained (str, optional): Path to pre-trained weights.
            Default: None.
    """

    def __init__(self,
                 backbone,
                 with_sobel=False,
                 neck=None,
                 head=None,
                 pretrained=None):
        super(DeepCluster, self).__init__()
        self.with_sobel = with_sobel
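        # The Sobel layer below follows the original DeepCluster recipe: it
        # replaces the RGB input with edge maps, removing color information so
        # that the clustering cannot exploit trivial color statistics.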
        if with_sobel:
            self.sobel_layer = Sobel()
        self.backbone = builder.build_backbone(backbone)
        self.neck = builder.build_neck(neck)
        if head is not None:
            self.head = builder.build_head(head)
        self.init_weights(pretrained=pretrained)

        # reweight
        self.num_classes = head.num_classes
        self.loss_weight = torch.ones((self.num_classes, ),
                                      dtype=torch.float32).cuda()
        self.loss_weight /= self.loss_weight.sum()
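        # Note: the weights above are uniform placeholders; set_reweight()
        # below replaces them with count-based weights, typically triggered by
        # the training hook each time the pseudo-labels are re-clustered
        # (an assumption about the caller, which is not made in this file).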

    def init_weights(self, pretrained=None):
        """Initialize the weights of the model.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Default: None.
        """
        if pretrained is not None:
            print_log('load model from: {}'.format(pretrained), logger='root')
        self.backbone.init_weights(pretrained=pretrained)
        self.neck.init_weights(init_linear='kaiming')
        self.head.init_weights(init_linear='normal')

    def forward_backbone(self, img):
        """Forward backbone.

        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.

        Returns:
            tuple[Tensor]: backbone outputs.
        """
        if self.with_sobel:
            img = self.sobel_layer(img)
        x = self.backbone(img)
        return x

    def forward_train(self, img, pseudo_label, **kwargs):
        """Forward computation during training.

        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            pseudo_label (Tensor): Label assignments.
            kwargs: Any keyword arguments to be used to forward.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        x = self.forward_backbone(img)
        assert len(x) == 1
        feature = self.neck(x)
        outs = self.head(feature)
        loss_inputs = (outs, pseudo_label)
        losses = self.head.loss(*loss_inputs)
        return losses

    def forward_test(self, img, **kwargs):
        x = self.forward_backbone(img)  # tuple
        outs = self.head(x)
        keys = ['head{}'.format(i) for i in range(len(outs))]
        out_tensors = [out.cpu() for out in outs]  # NxC
        return dict(zip(keys, out_tensors))

    def forward(self, img, mode='train', **kwargs):
        if mode == 'train':
            return self.forward_train(img, **kwargs)
        elif mode == 'test':
            return self.forward_test(img, **kwargs)
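        # 'extract' below returns raw backbone features; this is the mode the
        # clustering stage would typically use to collect features for k-means
        # before pseudo-labels are re-assigned.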
        elif mode == 'extract':
            return self.forward_backbone(img)
        else:
            raise Exception("No such mode: {}".format(mode))

    def set_reweight(self, labels, reweight_pow=0.5):
        """Loss re-weighting.

        Re-weighting the loss according to the number of samples in each class.

        Args:
            labels (numpy.ndarray): Label assignments.
            reweight_pow (float): The power of re-weighting. Default: 0.5.
        """
        hist = np.bincount(
            labels, minlength=self.num_classes).astype(np.float32)
        inv_hist = (1. / (hist + 1e-10))**reweight_pow
        weight = inv_hist / inv_hist.sum()
        self.loss_weight.copy_(torch.from_numpy(weight))
        self.head.criterion = nn.CrossEntropyLoss(weight=self.loss_weight)
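

# A minimal usage sketch (illustrative, not part of the upstream file). The
# config dicts below are assumptions modeled on typical OpenSelfSup configs;
# the exact type names, channel sizes, and number of clusters may differ, and
# in practice the configs are mmcv ConfigDicts (head.num_classes above relies
# on attribute-style access):
#
#   model = DeepCluster(
#       backbone=dict(type='ResNet', depth=50, out_indices=[4]),
#       with_sobel=True,
#       neck=dict(type='AvgPoolNeck'),
#       head=dict(type='ClsHead', with_avg_pool=False,
#                 in_channels=2048, num_classes=10000))
#   losses = model(img, mode='train', pseudo_label=pseudo_label)
#   feats = model(img, mode='extract')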