xiaohangzhan 64e27eb9b7 fix BYOL
2020-08-04 13:33:16 +08:00

86 lines
3.1 KiB
Python

import torch
import torch.nn as nn
from openselfsup.utils import print_log
from . import builder
from .registry import MODELS
@MODELS.register_module
class BYOL(nn.Module):
    """BYOL unofficial implementation.

    Paper: https://arxiv.org/abs/2006.07733

    Two networks with the same architecture (backbone followed by neck) are
    built: an online network that receives gradients, and a target network
    whose parameters are frozen and are instead moved towards the online
    parameters by an exponential moving average after every training
    forward pass. The head consumes an online projection and a target
    projection and returns the loss.

    Args:
        backbone: Config handed to ``builder.build_backbone``. It is used
            twice, once for the online and once for the target network.
        neck: Config handed to ``builder.build_neck`` (also used twice).
            NOTE(review): the default is None, but this class always
            builds and indexes the neck -- confirm that
            ``builder.build_neck(None)`` is supported before relying on it.
        head: Config handed to ``builder.build_head``.
        pretrained: Forwarded to the online backbone's ``init_weights``.
            Default: None.
        base_momentum (float): Initial EMA coefficient. It is stored both
            as ``self.base_momentum`` and as ``self.momentum``; only
            ``self.momentum`` is read by the update rule. Presumably an
            external schedule rewrites ``self.momentum`` during training
            (not visible in this file). Default: 0.996.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 head=None,
                 pretrained=None,
                 base_momentum=0.996,
                 **kwargs):
        super(BYOL, self).__init__()
        self.online_net = nn.Sequential(
            builder.build_backbone(backbone), builder.build_neck(neck))
        self.target_net = nn.Sequential(
            builder.build_backbone(backbone), builder.build_neck(neck))
        # Alias to the online backbone; this is what 'extract' mode runs.
        self.backbone = self.online_net[0]
        # The target network is never optimized directly; it only changes
        # through _momentum_update().
        for param in self.target_net.parameters():
            param.requires_grad = False
        self.head = builder.build_head(head)
        self.init_weights(pretrained=pretrained)

        self.base_momentum = base_momentum
        self.momentum = base_momentum

    def init_weights(self, pretrained=None):
        """Initialize the online network, mirror it into the target, init head.

        Args:
            pretrained: Forwarded to the online backbone's ``init_weights``.
                When not None it is also written to the root logger.
        """
        if pretrained is not None:
            print_log('load model from: {}'.format(pretrained), logger='root')
        self.online_net[0].init_weights(pretrained=pretrained)  # backbone
        self.online_net[1].init_weights(init_linear='kaiming')  # projection
        # The target network starts as a copy of the online network.
        for param_ol, param_tgt in zip(self.online_net.parameters(),
                                       self.target_net.parameters()):
            param_tgt.data.copy_(param_ol.data)
        # Buffers (for example the running statistics of normalization
        # layers) are not yielded by parameters(). Copy them as well so that
        # state loaded into the online network is mirrored in the target
        # network instead of being left at its construction-time defaults.
        for buf_ol, buf_tgt in zip(self.online_net.buffers(),
                                   self.target_net.buffers()):
            buf_tgt.data.copy_(buf_ol.data)
        # init the predictor in the head
        self.head.init_weights()

    @torch.no_grad()
    def _momentum_update(self):
        """Momentum update of the target network.

        Every target parameter becomes
        ``target * momentum + online * (1 - momentum)``,
        using the current value of ``self.momentum``.
        """
        for param_ol, param_tgt in zip(self.online_net.parameters(),
                                       self.target_net.parameters()):
            param_tgt.data = param_tgt.data * self.momentum + \
                             param_ol.data * (1. - self.momentum)

    def forward_train(self, img, **kwargs):
        """Compute the symmetric BYOL loss and update the target network.

        Args:
            img (Tensor): 5-dimensional input whose dimension 1 holds the
                two augmented views (index 0 and index 1). Presumably
                shaped (N, 2, C, H, W) -- confirm against the data pipeline.

        Returns:
            dict: ``{'loss': loss}``, where ``loss`` is the sum of the two
            head losses (view 1 predicting view 2 and vice versa).
        """
        assert img.dim() == 5, \
            "Input must have 5 dims, got: {}".format(img.dim())
        img_v1 = img[:, 0, ...].contiguous()
        img_v2 = img[:, 1, ...].contiguous()
        # Online projections; the network output is indexed with [0], so
        # only its first element is used.
        proj_online_v1 = self.online_net(img_v1)[0]
        proj_online_v2 = self.online_net(img_v2)[0]
        # Target projections are computed without building an autograd
        # graph, so no gradient can flow into the target network.
        with torch.no_grad():
            proj_target_v1 = self.target_net(img_v1)[0].clone().detach()
            proj_target_v2 = self.target_net(img_v2)[0].clone().detach()

        # Symmetric loss: each online view is matched against the target
        # projection of the other view.
        loss = self.head(proj_online_v1, proj_target_v2)['loss'] + \
               self.head(proj_online_v2, proj_target_v1)['loss']
        # The EMA step runs inside the forward pass, using the online
        # parameters that produced the projections above.
        self._momentum_update()
        return dict(loss=loss)

    def forward_test(self, img, **kwargs):
        """Test-time forward; not implemented, always returns None."""
        pass

    def forward(self, img, mode='train', **kwargs):
        """Dispatch to the routine selected by ``mode``.

        Args:
            img (Tensor): Input passed on to the selected routine.
            mode (str): 'train' calls ``forward_train``, 'test' calls
                ``forward_test``, 'extract' runs the online backbone only.

        Returns:
            Whatever the selected routine returns.

        Raises:
            ValueError: If ``mode`` is not one of the values listed above.
        """
        if mode == 'train':
            return self.forward_train(img, **kwargs)
        elif mode == 'test':
            return self.forward_test(img, **kwargs)
        elif mode == 'extract':
            return self.backbone(img)
        else:
            # ValueError derives from Exception, so handlers written for the
            # previous bare Exception still catch it.
            raise ValueError("No such mode: {}".format(mode))