104 lines
3.5 KiB
Python
104 lines
3.5 KiB
Python
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from openselfsup.utils import print_log
|
|
from . import builder
|
|
from .registry import MODELS
|
|
from .utils import Sobel
|
|
|
|
|
|
@MODELS.register_module
class ODC(nn.Module):
    """ODC clustering-based self-supervised model.

    Pipeline: optional Sobel filter -> backbone -> neck -> head. During
    training the head is supervised by the pseudo-labels stored in
    ``memory_bank.label_bank``, and the neck features of each batch are
    written back into the memory bank.

    Args:
        backbone (dict): Backbone config, built with
            ``builder.build_backbone``.
        with_sobel (bool): If True, images pass through a ``Sobel`` layer
            before the backbone. Default: False.
        neck (dict): Neck config, built with ``builder.build_neck``. The
            neck is used unconditionally, so a valid config is required.
        head (dict): Head config, built with ``builder.build_head``. It must
            contain ``num_classes``. Although it defaults to None, the model
            cannot be constructed without it (``init_weights`` and the
            reweighting setup both use it).
        memory_bank (dict): Memory-bank config, built with
            ``builder.build_memory``. Required by ``forward_train`` and
            ``set_reweight``.
        pretrained (str, optional): Pretrained backbone weights, forwarded
            to ``backbone.init_weights``.
    """

    def __init__(self,
                 backbone,
                 with_sobel=False,
                 neck=None,
                 head=None,
                 memory_bank=None,
                 pretrained=None):
        super(ODC, self).__init__()
        self.with_sobel = with_sobel
        if with_sobel:
            self.sobel_layer = Sobel()
        self.backbone = builder.build_backbone(backbone)
        self.neck = builder.build_neck(neck)
        if head is not None:
            self.head = builder.build_head(head)
        if memory_bank is not None:
            self.memory_bank = builder.build_memory(memory_bank)
        self.init_weights(pretrained=pretrained)

        # Set up the reweight tensor: start from uniform class weights.
        # Item access works for plain dicts as well as dict-subclass configs;
        # fall back to attribute access for other config objects.
        if isinstance(head, dict):
            self.num_classes = head['num_classes']
        else:
            self.num_classes = head.num_classes
        self.loss_weight = torch.ones((self.num_classes, ),
                                      dtype=torch.float32).cuda()
        self.loss_weight /= self.loss_weight.sum()

    def init_weights(self, pretrained=None):
        """Initialize the backbone, neck and head.

        Args:
            pretrained (str, optional): Forwarded to
                ``backbone.init_weights``; logged when not None.
        """
        if pretrained is not None:
            print_log('load model from: {}'.format(pretrained), logger='root')
        self.backbone.init_weights(pretrained=pretrained)
        self.neck.init_weights(init_linear='kaiming')
        self.head.init_weights(init_linear='normal')

    def forward_backbone(self, img):
        """Forward backbone

        Args:
            img (Tensor): Input images. Presumably (N, C, H, W); confirm
                against the data pipeline.

        Returns:
            x (tuple): backbone outputs
        """
        if self.with_sobel:
            img = self.sobel_layer(img)
        x = self.backbone(img)
        return x

    def forward_train(self, img, idx, **kwargs):
        """Compute training losses and update the memory bank.

        Args:
            img (Tensor): Input images.
            idx (Tensor): Indices used to look up pseudo-labels in
                ``memory_bank.label_bank`` and to update the stored
                samples. These are presumably dataset sample indices;
                verify against the dataset.
            kwargs: Ignored.

        Returns:
            dict: The losses returned by ``head.loss``, plus
                ``'change_ratio'`` from
                ``memory_bank.update_samples_memory``.
        """
        # forward & backward
        x = self.forward_backbone(img)
        feature = self.neck(x)
        outs = self.head(feature)
        if self.memory_bank.label_bank.is_cuda:
            loss_inputs = (outs, self.memory_bank.label_bank[idx])
        else:
            # A CPU label bank must be indexed with CPU indices; the
            # gathered labels are then moved to the GPU for the loss.
            loss_inputs = (outs, self.memory_bank.label_bank[idx.cpu()].cuda())
        losses = self.head.loss(*loss_inputs)

        # update samples memory
        change_ratio = self.memory_bank.update_samples_memory(
            idx, feature[0].detach())
        losses['change_ratio'] = change_ratio

        return losses

    def forward_test(self, img, **kwargs):
        """Compute head outputs for evaluation.

        Args:
            img (Tensor): Input images.
            kwargs: Ignored.

        Returns:
            dict: ``{'head0': out0, 'head1': out1, ...}``, one CPU tensor
                per head output.
        """
        x = self.forward_backbone(img)  # tuple
        # The head is trained on neck outputs (see forward_train), so the
        # neck must be applied here too; the head never sees raw backbone
        # features during training.
        feature = self.neck(x)
        outs = self.head(feature)
        keys = ['head{}'.format(i) for i in range(len(outs))]
        out_tensors = [out.cpu() for out in outs]  # NxC
        return dict(zip(keys, out_tensors))

    def forward(self, img, mode='train', **kwargs):
        """Dispatch to the train / test / extract forward pass.

        Args:
            img (Tensor): Input images.
            mode (str): One of ``'train'``, ``'test'`` or ``'extract'``.
            kwargs: Forwarded to ``forward_train`` / ``forward_test``.

        Raises:
            ValueError: If ``mode`` is not recognized. This is a subclass of
                ``Exception``, so existing handlers still catch it.
        """
        if mode == 'train':
            return self.forward_train(img, **kwargs)
        elif mode == 'test':
            return self.forward_test(img, **kwargs)
        elif mode == 'extract':
            return self.forward_backbone(img)
        else:
            raise ValueError("No such mode: {}".format(mode))

    def set_reweight(self, labels=None, reweight_pow=0.5):
        """Replace the head's criterion with class-reweighted cross-entropy.

        Each class weight is ``(1 / (count + 1e-5)) ** reweight_pow``. The
        weights are normalized to sum to 1, where ``count`` is the number
        of occurrences of that class in ``labels``.

        Args:
            labels (array-like of int, optional): Labels to count. Defaults
                to the current ``memory_bank.label_bank``.
            reweight_pow (float): Exponent applied to the inverse class
                frequency. Default: 0.5.
        """
        if labels is None:
            # .cpu() is a no-op for tensors already on the CPU.
            labels = self.memory_bank.label_bank.cpu().numpy()
        hist = np.bincount(
            labels, minlength=self.num_classes).astype(np.float32)
        # The small epsilon keeps empty classes from dividing by zero.
        inv_hist = (1. / (hist + 1e-5))**reweight_pow
        weight = inv_hist / inv_hist.sum()
        # In-place copy keeps loss_weight on the device chosen in __init__.
        self.loss_weight.copy_(torch.from_numpy(weight))
        self.head.criterion = nn.CrossEntropyLoss(weight=self.loss_weight)
|