[Refactor] refactor byol

pull/352/head
fangyixiao 2022-07-06 07:50:01 +00:00 committed by fangyixiao18
parent 29c6b26ee0
commit 1e16016b27
12 changed files with 197 additions and 197 deletions

View File

@@ -5,7 +5,11 @@ data_root = 'data/imagenet/'
file_client_args = dict(backend='disk')
view_pipeline1 = [
dict(type='RandomResizedCrop', size=224, interpolation='bicubic'),
dict(
type='RandomResizedCrop',
size=224,
interpolation='bicubic',
backend='pillow'),
dict(type='RandomFlip', prob=0.5),
dict(
type='RandomApply',
@@ -22,12 +26,16 @@ view_pipeline1 = [
type='RandomGrayscale',
prob=0.2,
keep_channels=True,
channel_weights=(0.114, 0.587, 0.299)),
channel_weights=(0.114, 0.587, 0.2989)),
dict(type='RandomGaussianBlur', sigma_min=0.1, sigma_max=2.0, prob=1.),
dict(type='RandomSolarize', prob=0.),
]
view_pipeline2 = [
dict(type='RandomResizedCrop', size=224, interpolation='bicubic'),
dict(
type='RandomResizedCrop',
size=224,
interpolation='bicubic',
backend='pillow'),
dict(type='RandomFlip', prob=0.5),
dict(
type='RandomApply',
@@ -44,7 +52,7 @@ view_pipeline2 = [
type='RandomGrayscale',
prob=0.2,
keep_channels=True,
channel_weights=(0.114, 0.587, 0.299)),
channel_weights=(0.114, 0.587, 0.2989)),
dict(type='RandomGaussianBlur', sigma_min=0.1, sigma_max=2.0, prob=0.1),
dict(type='RandomSolarize', prob=0.2)
]

View File

@@ -2,6 +2,10 @@
model = dict(
type='BYOL',
base_momentum=0.99,
data_preprocessor=dict(
mean=(123.675, 116.28, 103.53),
std=(58.395, 57.12, 57.375),
bgr_to_rgb=True),
backbone=dict(
type='ResNet',
depth=50,
@@ -27,5 +31,6 @@ model = dict(
num_layers=2,
with_bias=True,
with_last_bn=False,
with_avg_pool=False)),
loss=dict(type='CosineSimilarityLoss'))
with_avg_pool=False),
loss=dict(type='CosineSimilarityLoss')),
)
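Note on the new data_preprocessor block: normalization now happens in the model wrapper instead of the data pipeline. A minimal sketch of what these mean/std/bgr_to_rgb values imply for one image tensor, as an illustration of the intent rather than the mmengine implementation (the helper name preprocess is made up here):

import torch

MEAN = torch.tensor([123.675, 116.28, 103.53]).view(3, 1, 1)
STD = torch.tensor([58.395, 57.12, 57.375]).view(3, 1, 1)

def preprocess(img_bgr: torch.Tensor) -> torch.Tensor:
    # img_bgr: (3, H, W) float tensor in BGR channel order, values in [0, 255]
    img_rgb = img_bgr[[2, 1, 0], ...]   # bgr_to_rgb=True flips the channel order
    return (img_rgb - MEAN) / STD       # per-channel normalization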

View File

@@ -5,31 +5,21 @@ _base_ = [
'../_base_/default_runtime.py',
]
# dataset summary
data = dict(samples_per_gpu=256, workers_per_gpu=8)
# additional hooks
# interval for accumulate gradient, total 16*256*1(interval)=4096
update_interval = 1
custom_hooks = [
dict(type='BYOLHook', end_momentum=1., update_interval=update_interval)
]
train_dataloader = dict(batch_size=256)
# optimizer
optimizer = dict(
type='LARS',
lr=4.8,
momentum=0.9,
weight_decay=1e-6,
paramwise_options={
'(bn|gn)(\\d+)?.(weight|bias)':
dict(weight_decay=0., lars_exclude=True),
'bias': dict(weight_decay=0., lars_exclude=True)
})
optimizer_config = dict(update_interval=update_interval)
optimizer = dict(type='LARS', lr=4.8, momentum=0.9, weight_decay=1e-6)
optim_wrapper = dict(
type='OptimWrapper',
optimizer=optimizer,
paramwise_cfg=dict(
custom_keys={
'bn': dict(decay_mult=0, lars_exclude=True),
'bias': dict(decay_mult=0, lars_exclude=True),
# bn layer in ResNet block downsample module
'downsample.1': dict(decay_mult=0, lars_exclude=True),
}),
)
# runtime settings
# the max_keep_ckpts controls the max number of ckpt file in your work_dirs
# if it is 3, when CheckpointHook (in mmcv) saves the 4th ckpt
# it will remove the oldest one to keep the number of total ckpts as 3
checkpoint_config = dict(interval=10, max_keep_ckpts=3)
default_hooks = dict(checkpoint=dict(max_keep_ckpts=3))

View File

@@ -2,6 +2,19 @@ _base_ = 'byol_resnet50_8xb256-fp16-accum2-coslr-200e_in1k.py'
# optimizer
optimizer = dict(lr=7.2)
optim_wrapper = dict(optimizer=optimizer)
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=10,
convert_to_iter_based=True),
dict(type='CosineAnnealingLR', T_max=90, by_epoch=True, begin=10, end=100)
]
# runtime settings
runner = dict(type='EpochBasedRunner', max_epochs=100)
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100)
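For reference, a rough sketch of the learning-rate curve the LinearLR warmup plus CosineAnnealingLR schedule above should produce, assuming start_factor scales the base lr of 7.2 set above, eta_min defaults to 0, and ignoring that the warmup is converted to per-iteration granularity (the helper name lr_at is illustrative):

import math

base_lr = 7.2        # from the optimizer override in this config
warmup_epochs = 10   # LinearLR: begin=0, end=10
total_epochs = 100   # CosineAnnealingLR: begin=10, end=100, T_max=90

def lr_at(epoch: float) -> float:
    if epoch < warmup_epochs:
        # linear warmup from base_lr * 1e-4 up to base_lr
        start = base_lr * 1e-4
        return start + (base_lr - start) * epoch / warmup_epochs
    # cosine annealing from base_lr down to ~0 over the remaining 90 epochs
    t = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return 0.5 * base_lr * (1 + math.cos(math.pi * t))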

View File

@@ -6,33 +6,23 @@ _base_ = [
]
# dataset summary
data = dict(samples_per_gpu=256)
# additional hooks
# interval for accumulate gradient, total 8*256*2(interval)=4096
update_interval = 2
custom_hooks = [
dict(type='BYOLHook', end_momentum=1., update_interval=update_interval)
]
train_dataloader = dict(batch_size=256)
# optimizer
optimizer = dict(
type='LARS',
lr=4.8,
momentum=0.9,
weight_decay=1e-6,
paramwise_options={
'(bn|gn)(\\d+)?.(weight|bias)':
dict(weight_decay=0., lars_exclude=True),
'bias': dict(weight_decay=0., lars_exclude=True)
})
optimizer_config = dict(update_interval=update_interval)
# fp16
fp16 = dict(loss_scale=512.)
optimizer = dict(type='LARS', lr=4.8, momentum=0.9, weight_decay=1e-6)
optim_wrapper = dict(
type='AmpOptimWrapper',
loss_scale=512.,
optimizer=optimizer,
accumulative_iters=2,
paramwise_cfg=dict(
custom_keys={
'bn': dict(decay_mult=0, lars_exclude=True),
'bias': dict(decay_mult=0, lars_exclude=True),
# bn layer in ResNet block downsample module
'downsample.1': dict(decay_mult=0, lars_exclude=True),
}),
)
# runtime settings
# the max_keep_ckpts controls the max number of ckpt file in your work_dirs
# if it is 3, when CheckpointHook (in mmcv) saves the 4th ckpt
# it will remove the oldest one to keep the number of total ckpts as 3
checkpoint_config = dict(interval=10, max_keep_ckpts=3)
default_hooks = dict(checkpoint=dict(max_keep_ckpts=3))
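A note on the accumulative_iters=2 setting above: the effective optimizer-step batch size stays at 4096 across the refactored BYOL configs, exactly as the removed inline comments stated (the helper below just restates that arithmetic):

def effective_batch(num_gpus: int, batch_per_gpu: int, accumulative_iters: int) -> int:
    return num_gpus * batch_per_gpu * accumulative_iters

assert effective_batch(16, 256, 1) == 4096   # the 16-GPU config above
assert effective_batch(8, 256, 2) == 4096    # this fp16-accum2 config
assert effective_batch(8, 32, 16) == 4096    # the 8xb32-accum16 config below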

View File

@@ -1,4 +1,17 @@
_base_ = 'byol_resnet50_8xb256-fp16-accum2-coslr-200e_in1k.py'
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=10,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR', T_max=290, by_epoch=True, begin=10, end=300)
]
# runtime settings
runner = dict(type='EpochBasedRunner', max_epochs=300)
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=300)

View File

@@ -2,6 +2,19 @@ _base_ = 'byol_resnet50_8xb32-accum16-coslr-200e_in1k.py'
# optimizer
optimizer = dict(lr=7.2)
optim_wrapper = dict(optimizer=optimizer)
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=10,
convert_to_iter_based=True),
dict(type='CosineAnnealingLR', T_max=90, by_epoch=True, begin=10, end=100)
]
# runtime settings
runner = dict(type='EpochBasedRunner', max_epochs=100)
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100)

View File

@@ -5,28 +5,20 @@ _base_ = [
'../_base_/default_runtime.py',
]
# additional hooks
# interval for accumulate gradient, total 8*32*16(interval)=4096
update_interval = 16
custom_hooks = [
dict(type='BYOLHook', end_momentum=1., update_interval=update_interval)
]
# optimizer
optimizer = dict(
type='LARS',
lr=4.8,
momentum=0.9,
weight_decay=1e-6,
paramwise_options={
'(bn|gn)(\\d+)?.(weight|bias)':
dict(weight_decay=0., lars_exclude=True),
'bias': dict(weight_decay=0., lars_exclude=True)
})
optimizer_config = dict(update_interval=update_interval)
optimizer = dict(type='LARS', lr=4.8, momentum=0.9, weight_decay=1e-6)
optim_wrapper = dict(
type='OptimWrapper',
optimizer=optimizer,
accumulative_iters=16,
paramwise_cfg=dict(
custom_keys={
'bn': dict(decay_mult=0, lars_exclude=True),
'bias': dict(decay_mult=0, lars_exclude=True),
# bn layer in ResNet block downsample module
'downsample.1': dict(decay_mult=0, lars_exclude=True),
}),
)
# runtime settings
# the max_keep_ckpts controls the max number of ckpt file in your work_dirs
# if it is 3, when CheckpointHook (in mmcv) saves the 4th ckpt
# it will remove the oldest one to keep the number of total ckpts as 3
checkpoint_config = dict(interval=10, max_keep_ckpts=3)
default_hooks = dict(checkpoint=dict(max_keep_ckpts=3))

View File

@@ -1,4 +1,17 @@
_base_ = 'byol_resnet50_8xb32-accum16-coslr-200e_in1k.py'
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=10,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR', T_max=290, by_epoch=True, begin=10, end=300)
]
# runtime settings
runner = dict(type='EpochBasedRunner', max_epochs=300)
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=300)

View File

@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Iterable
import torch
from torch.optim.optimizer import Optimizer
@@ -9,24 +11,21 @@ from mmselfsup.registry import OPTIMIZERS
class LARS(Optimizer):
"""Implements layer-wise adaptive rate scaling for SGD.
Args:
params (iterable): Iterable of parameters to optimize or dicts defining
parameter groups.
lr (float): Base learning rate.
momentum (float, optional): Momentum factor. Defaults to 0 ('m')
weight_decay (float, optional): Weight decay (L2 penalty).
Defaults to 0. ('beta')
dampening (float, optional): Dampening for momentum. Defaults to 0.
eta (float, optional): LARS coefficient. Defaults to 0.001.
nesterov (bool, optional): Enables Nesterov momentum.
Defaults to False.
eps (float, optional): A small number to avoid dividing by zero.
Defaults to 1e-8.
Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg.
`Large Batch Training of Convolutional Networks:
<https://arxiv.org/abs/1708.03888>`_.
Args:
params (Iterable): Iterable of parameters to optimize or dicts defining
parameter groups.
lr (float): Base learning rate.
momentum (float): Momentum factor. Defaults to 0.
weight_decay (float): Weight decay (L2 penalty). Defaults to 0.
dampening (float): Dampening for momentum. Defaults to 0.
eta (float): LARS coefficient. Defaults to 0.001.
nesterov (bool): Enables Nesterov momentum. Defaults to False.
eps (float): A small number to avoid dividing by zero. Defaults to 1e-8.
Example:
>>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9,
>>> weight_decay=1e-4, eta=1e-3)
@@ -36,14 +35,14 @@ class LARS(Optimizer):
"""
def __init__(self,
params,
lr=float,
momentum=0,
weight_decay=0,
dampening=0,
eta=0.001,
nesterov=False,
eps=1e-8):
params: Iterable,
lr: float,
momentum: float = 0,
weight_decay: float = 0,
dampening: float = 0,
eta: float = 0.001,
nesterov: bool = False,
eps: float = 1e-8) -> None:
if not isinstance(lr, float) or lr < 0.0:
raise ValueError(f'Invalid learning rate: {lr}')
if momentum < 0.0:
@@ -65,15 +64,15 @@ class LARS(Optimizer):
'Nesterov momentum requires a momentum and zero dampening')
self.eps = eps
super(LARS, self).__init__(params, defaults)
super().__init__(params, defaults)
def __setstate__(self, state):
super(LARS, self).__setstate__(state)
def __setstate__(self, state) -> None:
super().__setstate__(state)
for group in self.param_groups:
group.setdefault('nesterov', False)
@torch.no_grad()
def step(self, closure=None):
def step(self, closure=None) -> torch.Tensor:
"""Performs a single optimization step.
Args:

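Stepping back from the diff, the docstring above describes layer-wise adaptive rate scaling; below is a minimal sketch of the trust ratio LARS applies on top of the global learning rate, paraphrased from Algorithm 1 of the cited paper rather than copied from this optimizer (the helper name lars_local_lr is illustrative):

import torch

def lars_local_lr(param: torch.Tensor, grad: torch.Tensor, weight_decay: float,
                  eta: float = 0.001, eps: float = 1e-8) -> float:
    # trust ratio: eta * ||w|| / (||g|| + weight_decay * ||w|| + eps)
    w_norm = param.norm()
    g_norm = grad.norm()
    if w_norm == 0 or g_norm == 0:
        # parameters excluded from adaptation (cf. lars_exclude in the configs)
        # simply keep the global learning rate
        return 1.0
    return float(eta * w_norm / (g_norm + weight_decay * w_norm + eps))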
View File

@@ -5,13 +5,12 @@ import torch
import torch.nn as nn
from mmselfsup.core import SelfSupDataSample
from ..builder import (ALGORITHMS, build_backbone, build_head, build_loss,
build_neck)
from mmselfsup.registry import MODELS
from ..utils import CosineEMA
from .base import BaseModel
@ALGORITHMS.register_module()
@MODELS.register_module()
class BYOL(BaseModel):
"""BYOL.
@@ -20,95 +19,83 @@ class BYOL(BaseModel):
The momentum adjustment is in `core/hooks/byol_hook.py`.
Args:
backbone (Dict, optional): Config dict for module of backbone.
neck (Dict, optional): Config dict for module of deep features
to compact feature vectors. Defaults to None.
head (Dict, optional): Config dict for module of head functions.
Defaults to None.
loss (Dict, optional): Config dict for module of loss functions.
Defaults to None.
backbone (dict): Config dict for module of backbone.
neck (dict): Config dict for module of deep features
to compact feature vectors.
head (dict): Config dict for module of head functions.
base_momentum (float): The base momentum coefficient for the target
network. Defaults to 0.996.
preprocess_cfg (Dict, optional): Config dict to preprocess images.
pretrained (str, optional): The pretrained checkpoint path, support
local path and remote path. Defaults to None.
data_preprocessor (dict, optional): Config dict to preprocess images.
Defaults to None.
init_cfg (Dict or List[Dict], optional): Config dict for weight
init_cfg (dict or List[dict], optional): Config dict for weight
initialization. Defaults to None.
"""
def __init__(self,
backbone: Optional[Dict] = None,
neck: Optional[Dict] = None,
head: Optional[Dict] = None,
loss: Optional[Dict] = None,
backbone: dict,
neck: dict,
head: dict,
base_momentum: float = 0.996,
preprocess_cfg: Optional[Dict] = None,
init_cfg: Optional[Union[Dict, List[Dict]]] = None,
**kwargs) -> None:
super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
assert backbone is not None
assert neck is not None
self.online_net = nn.Sequential(
build_backbone(backbone), build_neck(neck))
self.backbone = self.online_net[0]
self.neck = self.online_net[1]
assert head is not None
self.head = build_head(head)
assert loss is not None
self.loss = build_loss(loss)
pretrained: Optional[str] = None,
data_preprocessor: Optional[dict] = None,
init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
super().__init__(
backbone=backbone,
neck=neck,
head=head,
pretrained=pretrained,
data_preprocessor=data_preprocessor,
init_cfg=init_cfg)
# create momentum model
self.target_net = CosineEMA(self.online_net, momentum=base_momentum)
for param_tgt in self.target_net.module.parameters():
param_tgt.requires_grad = False
self.target_net = CosineEMA(
nn.Sequential(self.backbone, self.neck), momentum=base_momentum)
def extract_feat(self, inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
def extract_feat(self, batch_inputs: List[torch.Tensor],
**kwargs) -> Tuple[torch.Tensor]:
"""Function to extract features from backbone.
Args:
inputs (List[torch.Tensor]): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
Returns:
Tuple[torch.Tensor]: backbone outputs.
"""
x = self.backbone(inputs[0])
x = self.backbone(batch_inputs[0])
return x
def forward_train(self, inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwargs) -> Dict[str, torch.Tensor]:
def loss(self, batch_inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwargs) -> Dict[str, torch.Tensor]:
"""Forward computation during training.
Args:
inputs (List[torch.Tensor]): The input images.
batch_inputs (List[torch.Tensor]): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
Returns:
Dict[str, torch.Tensor]: A dictionary of loss components.
"""
assert isinstance(inputs, list)
img_v1 = inputs[0]
img_v2 = inputs[1]
assert isinstance(batch_inputs, list)
img_v1 = batch_inputs[0]
img_v2 = batch_inputs[1]
# compute online features
proj_online_v1 = self.online_net(img_v1)[0]
proj_online_v2 = self.online_net(img_v2)[0]
proj_online_v1 = self.neck(self.backbone(img_v1))[0]
proj_online_v2 = self.neck(self.backbone(img_v2))[0]
# compute target features
with torch.no_grad():
# update the target net
self.target_net.update_parameters(self.online_net)
self.target_net.update_parameters(
nn.Sequential(self.backbone, self.neck))
proj_target_v1 = self.target_net(img_v1)[0]
proj_target_v2 = self.target_net(img_v2)[0]
pred_1, target_1 = self.head(proj_online_v1, proj_target_v2)
pred_2, target_2 = self.head(proj_online_v2, proj_target_v1)
loss_1 = self.loss(pred_1, target_1)
loss_2 = self.loss(pred_2, target_2)
loss_1 = self.head(proj_online_v1, proj_target_v2)
loss_2 = self.head(proj_online_v2, proj_target_v1)
losses = dict(loss=2. * (loss_1 + loss_2))
return losses
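For context on the CosineEMA target network used above, here is a small sketch of the cosine momentum schedule BYOL uses to move the momentum from base_momentum=0.996 towards 1.0 over training; this follows the formula in the BYOL paper and is an assumption about CosineEMA's behavior, not code copied from it:

import math

def byol_momentum(step: int, max_steps: int, base_momentum: float = 0.996) -> float:
    # momentum grows from base_momentum at step 0 to 1.0 at the final step
    return 1. - (1. - base_momentum) * (math.cos(math.pi * step / max_steps) + 1.) / 2.

assert abs(byol_momentum(0, 1000) - 0.996) < 1e-9
assert abs(byol_momentum(1000, 1000) - 1.0) < 1e-9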

View File

@@ -26,6 +26,7 @@ neck = dict(
norm_cfg=dict(type='BN1d'))
head = dict(
type='LatentPredictHead',
loss=dict(type='CosineSimilarityLoss'),
predictor=dict(
type='NonLinearNeck',
in_channels=2,
@@ -35,43 +36,20 @@ head = dict(
with_last_bn=False,
with_avg_pool=False,
norm_cfg=dict(type='BN1d')))
loss = dict(type='CosineSimilarityLoss')
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_byol():
preprocess_cfg = {
'mean': [0.5, 0.5, 0.5],
'std': [0.5, 0.5, 0.5],
'to_rgb': True
}
with pytest.raises(AssertionError):
alg = BYOL(
backbone=backbone,
neck=None,
head=head,
loss=loss,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
with pytest.raises(AssertionError):
alg = BYOL(
backbone=backbone,
neck=neck,
head=None,
loss=loss,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
with pytest.raises(AssertionError):
alg = BYOL(
backbone=backbone,
neck=neck,
head=head,
loss=None,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
data_preprocessor = dict(
mean=(123.675, 116.28, 103.53),
std=(58.395, 57.12, 57.375),
bgr_to_rgb=True)
alg = BYOL(
backbone=backbone,
neck=neck,
head=head,
loss=loss,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
data_preprocessor=copy.deepcopy(data_preprocessor))
fake_data = [{
'inputs': [torch.randn((3, 224, 224)),
@@ -79,12 +57,11 @@ def test_byol():
'data_sample':
SelfSupDataSample()
} for _ in range(2)]
fake_inputs, fake_data_samples = alg.data_preprocessor(fake_data)
fake_outputs = alg(fake_data, return_loss=True)
assert isinstance(fake_outputs['loss'].item(), float)
assert fake_outputs['loss'].item() > -4
fake_loss = alg(fake_inputs, fake_data_samples, mode='loss')
assert isinstance(fake_loss['loss'].item(), float)
assert fake_loss['loss'].item() > -4
fake_inputs, fake_data_samples = alg.preprocss_data(fake_data)
fake_feat = alg.extract_feat(
inputs=fake_inputs, data_samples=fake_data_samples)
assert list(fake_feat[0].shape) == [2, 512, 7, 7]
fake_feats = alg(fake_inputs, fake_data_samples, mode='tensor')
assert list(fake_feats[0].shape) == [2, 512, 7, 7]