import torch
import torch.nn as nn

from openselfsup.utils import print_log

from . import builder
from .registry import MODELS


@MODELS.register_module
class BYOL(nn.Module):
    """BYOL unofficial implementation. Paper: https://arxiv.org/abs/2006.07733

    Args:
        backbone (dict): Config dict for the backbone ConvNet.
        neck (dict): Config dict for the projection neck mapping deep features
            to compact feature vectors. Default: None.
        head (dict): Config dict for the loss head (predictor + loss).
            Default: None.
        pretrained (str, optional): Path to pre-trained weights for the
            backbone. Default: None.
        base_momentum (float): Base momentum coefficient for the EMA update
            of the target network. Default: 0.996.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 head=None,
                 pretrained=None,
                 base_momentum=0.996,
                 **kwargs):
        super(BYOL, self).__init__()
        # Online and target networks share the same architecture:
        # a backbone followed by a projection neck.
        self.online_net = nn.Sequential(
            builder.build_backbone(backbone), builder.build_neck(neck))
        self.target_net = nn.Sequential(
            builder.build_backbone(backbone), builder.build_neck(neck))
        self.backbone = self.online_net[0]
        # The target network is updated only as an EMA of the online
        # network, never by gradient descent.
        for param in self.target_net.parameters():
            param.requires_grad = False
        self.head = builder.build_head(head)
        self.init_weights(pretrained=pretrained)

        self.base_momentum = base_momentum
        self.momentum = base_momentum

    def init_weights(self, pretrained=None):
        """Initialize the online network, then copy its weights to the target.

        Args:
            pretrained (str, optional): Path to pre-trained weights for the
                backbone. Default: None.
        """
        if pretrained is not None:
            print_log('load model from: {}'.format(pretrained), logger='root')
        self.online_net[0].init_weights(pretrained=pretrained)  # backbone
        self.online_net[1].init_weights(init_linear='kaiming')  # projection
        # Start the target network from the same weights as the online network.
        for param_ol, param_tgt in zip(self.online_net.parameters(),
                                       self.target_net.parameters()):
            param_tgt.data.copy_(param_ol.data)
        # Initialize the predictor inside the head.
        self.head.init_weights()

    @torch.no_grad()
    def _momentum_update(self):
        """Momentum (EMA) update of the target network:
        theta_target = m * theta_target + (1 - m) * theta_online.
        """
        for param_ol, param_tgt in zip(self.online_net.parameters(),
                                       self.target_net.parameters()):
            param_tgt.data = param_tgt.data * self.momentum + \
                param_ol.data * (1. - self.momentum)
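
    # NOTE: in the BYOL paper the momentum coefficient is annealed from
    # base_momentum toward 1.0 over training with a cosine schedule; this
    # module keeps self.momentum fixed unless an external training hook
    # overwrites it each iteration (how OpenSelfSup-style pipelines
    # typically handle the schedule).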

    def forward_train(self, img, **kwargs):
        """Forward pass during training.

        Args:
            img (Tensor): Input of shape (N, 2, C, H, W), where dim 1 holds
                two augmented views of each image.
        """
        assert img.dim() == 5, \
            "Input must have 5 dims, got: {}".format(img.dim())
        img_v1 = img[:, 0, ...].contiguous()
        img_v2 = img[:, 1, ...].contiguous()
        # Stack the two views in opposite orders so a single pass through each
        # network yields both directions of the symmetrized BYOL loss.
        img_cat1 = torch.cat([img_v1, img_v2], dim=0)
        img_cat2 = torch.cat([img_v2, img_v1], dim=0)
        # Online projections (with gradients) and target projections (without).
        proj_online = self.online_net(img_cat1)[0]
        with torch.no_grad():
            proj_target = self.target_net(img_cat2)[0].clone().detach()

        losses = self.head(proj_online, proj_target)
        self._momentum_update()
        return losses
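
    # The configured head is expected to run the online projection through a
    # predictor q and compute the paper's regression loss on L2-normalized
    # vectors,
    #   L = 2 - 2 * <q(z_online), z_target> / (||q(z_online)|| * ||z_target||),
    # though the exact form depends on the head module built from the config.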

    def forward_test(self, img, **kwargs):
        # Not implemented: BYOL defines no dedicated test-time forward.
        pass

    def forward(self, img, mode='train', **kwargs):
        if mode == 'train':
            return self.forward_train(img, **kwargs)
        elif mode == 'test':
            return self.forward_test(img, **kwargs)
        elif mode == 'extract':
            # Return backbone features only.
            return self.backbone(img)
        else:
            raise Exception("No such mode: {}".format(mode))
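

# ----------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original file).
# The registered module names and config keys below ('ResNet',
# 'NonLinearNeckV2', 'LatentPredictHead', 'depth', 'in_channels', ...)
# are assumptions and must match what this codebase actually registers.
#
# model = BYOL(
#     backbone=dict(type='ResNet', depth=50),
#     neck=dict(type='NonLinearNeckV2', in_channels=2048,
#               hid_channels=4096, out_channels=256),
#     head=dict(type='LatentPredictHead',
#               predictor=dict(type='NonLinearNeckV2', in_channels=256,
#                              hid_channels=4096, out_channels=256)))
# img = torch.randn(8, 2, 3, 224, 224)  # (N, 2 views, C, H, W)
# losses = model(img, mode='train')     # training loss(es) from the head
# ----------------------------------------------------------------------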