mmselfsup/openselfsup/models/byol.py

85 lines
3.0 KiB
Python

import torch
import torch.nn as nn
from openselfsup.utils import print_log
from . import builder
from .registry import MODELS
@MODELS.register_module
class BYOL(nn.Module):
'''BYOL unofficial implementation. Paper: https://arxiv.org/abs/2006.07733
'''
def __init__(self,
backbone,
neck=None,
head=None,
pretrained=None,
base_momentum=0.996,
**kwargs):
super(BYOL, self).__init__()
self.online_net = nn.Sequential(
builder.build_backbone(backbone), builder.build_neck(neck))
self.target_net = nn.Sequential(
builder.build_backbone(backbone), builder.build_neck(neck))
self.backbone = self.online_net[0]
for param in self.target_net.parameters():
param.requires_grad = False
self.head = builder.build_head(head)
self.init_weights(pretrained=pretrained)
self.base_momentum = base_momentum
self.momentum = base_momentum
def init_weights(self, pretrained=None):
if pretrained is not None:
print_log('load model from: {}'.format(pretrained), logger='root')
self.online_net[0].init_weights(pretrained=pretrained) # backbone
self.online_net[1].init_weights(init_linear='kaiming') # projection
for param_ol, param_tgt in zip(self.online_net.parameters(),
self.target_net.parameters()):
param_tgt.data.copy_(param_ol.data)
# init the predictor in the head
self.head.init_weights()
@torch.no_grad()
def _momentum_update(self):
"""
Momentum update of the target network.
"""
for param_ol, param_tgt in zip(self.online_net.parameters(),
self.target_net.parameters()):
param_tgt.data = param_tgt.data * self.momentum + \
param_ol.data * (1. - self.momentum)
def forward_train(self, img, **kwargs):
assert img.dim() == 5, \
"Input must have 5 dims, got: {}".format(img.dim())
img_v1 = img[:, 0, ...].contiguous()
img_v2 = img[:, 1, ...].contiguous()
img_cat1 = torch.cat([img_v1, img_v2], dim=0)
img_cat2 = torch.cat([img_v2, img_v1], dim=0)
# compute query features
proj_online = self.online_net(img_cat1)[0]
with torch.no_grad():
proj_target = self.target_net(img_cat2)[0].clone().detach()
losses = self.head(proj_online, proj_target)
self._momentum_update()
return losses
def forward_test(self, img, **kwargs):
pass
def forward(self, img, mode='train', **kwargs):
if mode == 'train':
return self.forward_train(img, **kwargs)
elif mode == 'test':
return self.forward_test(img, **kwargs)
elif mode == 'extract':
return self.backbone(img)
else:
raise Exception("No such mode: {}".format(mode))