[Refactor] refactor byol
parent 29c6b26ee0
commit 1e16016b27

@@ -5,7 +5,11 @@ data_root = 'data/imagenet/'
 file_client_args = dict(backend='disk')

 view_pipeline1 = [
-    dict(type='RandomResizedCrop', size=224, interpolation='bicubic'),
+    dict(
+        type='RandomResizedCrop',
+        size=224,
+        interpolation='bicubic',
+        backend='pillow'),
     dict(type='RandomFlip', prob=0.5),
     dict(
         type='RandomApply',
@@ -22,12 +26,16 @@ view_pipeline1 = [
         type='RandomGrayscale',
         prob=0.2,
         keep_channels=True,
-        channel_weights=(0.114, 0.587, 0.299)),
+        channel_weights=(0.114, 0.587, 0.2989)),
     dict(type='RandomGaussianBlur', sigma_min=0.1, sigma_max=2.0, prob=1.),
     dict(type='RandomSolarize', prob=0.),
 ]
 view_pipeline2 = [
-    dict(type='RandomResizedCrop', size=224, interpolation='bicubic'),
+    dict(
+        type='RandomResizedCrop',
+        size=224,
+        interpolation='bicubic',
+        backend='pillow'),
     dict(type='RandomFlip', prob=0.5),
     dict(
         type='RandomApply',
@@ -44,7 +52,7 @@ view_pipeline2 = [
         type='RandomGrayscale',
         prob=0.2,
         keep_channels=True,
-        channel_weights=(0.114, 0.587, 0.299)),
+        channel_weights=(0.114, 0.587, 0.2989)),
     dict(type='RandomGaussianBlur', sigma_min=0.1, sigma_max=2.0, prob=0.1),
     dict(type='RandomSolarize', prob=0.2)
 ]

@@ -2,6 +2,10 @@
 model = dict(
     type='BYOL',
     base_momentum=0.99,
+    data_preprocessor=dict(
+        mean=(123.675, 116.28, 103.53),
+        std=(58.395, 57.12, 57.375),
+        bgr_to_rgb=True),
     backbone=dict(
         type='ResNet',
         depth=50,
@@ -27,5 +31,6 @@ model = dict(
             num_layers=2,
             with_bias=True,
             with_last_bn=False,
-            with_avg_pool=False)),
-    loss=dict(type='CosineSimilarityLoss'))
+            with_avg_pool=False),
+        loss=dict(type='CosineSimilarityLoss')),
+)

@@ -5,31 +5,21 @@ _base_ = [
     '../_base_/default_runtime.py',
 ]

-# dataset summary
-data = dict(samples_per_gpu=256, workers_per_gpu=8)
-
-# additional hooks
-# interval for accumulate gradient, total 16*256*1(interval)=4096
-update_interval = 1
-custom_hooks = [
-    dict(type='BYOLHook', end_momentum=1., update_interval=update_interval)
-]
+train_dataloader = dict(batch_size=256)

 # optimizer
-optimizer = dict(
-    type='LARS',
-    lr=4.8,
-    momentum=0.9,
-    weight_decay=1e-6,
-    paramwise_options={
-        '(bn|gn)(\\d+)?.(weight|bias)':
-        dict(weight_decay=0., lars_exclude=True),
-        'bias': dict(weight_decay=0., lars_exclude=True)
-    })
-optimizer_config = dict(update_interval=update_interval)
+optimizer = dict(type='LARS', lr=4.8, momentum=0.9, weight_decay=1e-6)
+optim_wrapper = dict(
+    type='OptimWrapper',
+    optimizer=optimizer,
+    paramwise_cfg=dict(
+        custom_keys={
+            'bn': dict(decay_mult=0, lars_exclude=True),
+            'bias': dict(decay_mult=0, lars_exclude=True),
+            # bn layer in ResNet block downsample module
+            'downsample.1': dict(decay_mult=0, lars_exclude=True),
+        }),
+)

 # runtime settings
-# the max_keep_ckpts controls the max number of ckpt file in your work_dirs
-# if it is 3, when CheckpointHook (in mmcv) saves the 4th ckpt
-# it will remove the oldest one to keep the number of total ckpts as 3
-checkpoint_config = dict(interval=10, max_keep_ckpts=3)
+default_hooks = dict(checkpoint=dict(max_keep_ckpts=3))

@@ -2,6 +2,19 @@ _base_ = 'byol_resnet50_8xb256-fp16-accum2-coslr-200e_in1k.py'

 # optimizer
 optimizer = dict(lr=7.2)
+optim_wrapper = dict(optimizer=optimizer)
+
+# learning rate scheduler
+param_scheduler = [
+    dict(
+        type='LinearLR',
+        start_factor=1e-4,
+        by_epoch=True,
+        begin=0,
+        end=10,
+        convert_to_iter_based=True),
+    dict(type='CosineAnnealingLR', T_max=90, by_epoch=True, begin=10, end=100)
+]

 # runtime settings
-runner = dict(type='EpochBasedRunner', max_epochs=100)
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100)

@@ -6,33 +6,23 @@ _base_ = [
 ]

-# dataset summary
-data = dict(samples_per_gpu=256)
-
-# additional hooks
-# interval for accumulate gradient, total 8*256*2(interval)=4096
-update_interval = 2
-custom_hooks = [
-    dict(type='BYOLHook', end_momentum=1., update_interval=update_interval)
-]
+train_dataloader = dict(batch_size=256)

 # optimizer
-optimizer = dict(
-    type='LARS',
-    lr=4.8,
-    momentum=0.9,
-    weight_decay=1e-6,
-    paramwise_options={
-        '(bn|gn)(\\d+)?.(weight|bias)':
-        dict(weight_decay=0., lars_exclude=True),
-        'bias': dict(weight_decay=0., lars_exclude=True)
-    })
-optimizer_config = dict(update_interval=update_interval)
-
-# fp16
-fp16 = dict(loss_scale=512.)
+optimizer = dict(type='LARS', lr=4.8, momentum=0.9, weight_decay=1e-6)
+optim_wrapper = dict(
+    type='AmpOptimWrapper',
+    loss_scale=512.,
+    optimizer=optimizer,
+    accumulative_iters=2,
+    paramwise_cfg=dict(
+        custom_keys={
+            'bn': dict(decay_mult=0, lars_exclude=True),
+            'bias': dict(decay_mult=0, lars_exclude=True),
+            # bn layer in ResNet block downsample module
+            'downsample.1': dict(decay_mult=0, lars_exclude=True),
+        }),
+)

 # runtime settings
-# the max_keep_ckpts controls the max number of ckpt file in your work_dirs
-# if it is 3, when CheckpointHook (in mmcv) saves the 4th ckpt
-# it will remove the oldest one to keep the number of total ckpts as 3
-checkpoint_config = dict(interval=10, max_keep_ckpts=3)
+default_hooks = dict(checkpoint=dict(max_keep_ckpts=3))

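For reference, the removed comment "total 8*256*2(interval)=4096" carries over into the new config as `accumulative_iters=2` on the `AmpOptimWrapper`: gradients from two iterations are accumulated before each LARS update, so the effective batch size is unchanged. A minimal sketch of that arithmetic (values taken from the hunk and the old comment, nothing else assumed):

# Effective batch size when gradients are accumulated before each update.
num_gpus = 8                 # from the old comment "8*256*2(interval)=4096"
batch_size_per_gpu = 256     # train_dataloader = dict(batch_size=256)
accumulative_iters = 2       # optim_wrapper accumulates two iterations
effective_batch_size = num_gpus * batch_size_per_gpu * accumulative_iters
assert effective_batch_size == 4096
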
@@ -1,4 +1,17 @@
 _base_ = 'byol_resnet50_8xb256-fp16-accum2-coslr-200e_in1k.py'

+# learning rate scheduler
+param_scheduler = [
+    dict(
+        type='LinearLR',
+        start_factor=1e-4,
+        by_epoch=True,
+        begin=0,
+        end=10,
+        convert_to_iter_based=True),
+    dict(
+        type='CosineAnnealingLR', T_max=290, by_epoch=True, begin=10, end=300)
+]
+
 # runtime settings
-runner = dict(type='EpochBasedRunner', max_epochs=300)
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=300)

@@ -2,6 +2,19 @@ _base_ = 'byol_resnet50_8xb32-accum16-coslr-200e_in1k.py'

 # optimizer
 optimizer = dict(lr=7.2)
+optim_wrapper = dict(optimizer=optimizer)
+
+# learning rate scheduler
+param_scheduler = [
+    dict(
+        type='LinearLR',
+        start_factor=1e-4,
+        by_epoch=True,
+        begin=0,
+        end=10,
+        convert_to_iter_based=True),
+    dict(type='CosineAnnealingLR', T_max=90, by_epoch=True, begin=10, end=100)
+]

 # runtime settings
-runner = dict(type='EpochBasedRunner', max_epochs=100)
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100)

@@ -5,28 +5,20 @@ _base_ = [
     '../_base_/default_runtime.py',
 ]

-# additional hooks
-# interval for accumulate gradient, total 8*32*16(interval)=4096
-update_interval = 16
-custom_hooks = [
-    dict(type='BYOLHook', end_momentum=1., update_interval=update_interval)
-]
-
 # optimizer
-optimizer = dict(
-    type='LARS',
-    lr=4.8,
-    momentum=0.9,
-    weight_decay=1e-6,
-    paramwise_options={
-        '(bn|gn)(\\d+)?.(weight|bias)':
-        dict(weight_decay=0., lars_exclude=True),
-        'bias': dict(weight_decay=0., lars_exclude=True)
-    })
-optimizer_config = dict(update_interval=update_interval)
+optimizer = dict(type='LARS', lr=4.8, momentum=0.9, weight_decay=1e-6)
+optim_wrapper = dict(
+    type='OptimWrapper',
+    optimizer=optimizer,
+    accumulative_iters=16,
+    paramwise_cfg=dict(
+        custom_keys={
+            'bn': dict(decay_mult=0, lars_exclude=True),
+            'bias': dict(decay_mult=0, lars_exclude=True),
+            # bn layer in ResNet block downsample module
+            'downsample.1': dict(decay_mult=0, lars_exclude=True),
+        }),
+)

 # runtime settings
-# the max_keep_ckpts controls the max number of ckpt file in your work_dirs
-# if it is 3, when CheckpointHook (in mmcv) saves the 4th ckpt
-# it will remove the oldest one to keep the number of total ckpts as 3
-checkpoint_config = dict(interval=10, max_keep_ckpts=3)
+default_hooks = dict(checkpoint=dict(max_keep_ckpts=3))

@@ -1,4 +1,17 @@
 _base_ = 'byol_resnet50_8xb32-accum16-coslr-200e_in1k.py'

+# learning rate scheduler
+param_scheduler = [
+    dict(
+        type='LinearLR',
+        start_factor=1e-4,
+        by_epoch=True,
+        begin=0,
+        end=10,
+        convert_to_iter_based=True),
+    dict(
+        type='CosineAnnealingLR', T_max=290, by_epoch=True, begin=10, end=300)
+]
+
 # runtime settings
-runner = dict(type='EpochBasedRunner', max_epochs=300)
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=300)

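For reference, the two `param_scheduler` entries above describe a 10-epoch linear warmup from `start_factor=1e-4` followed by cosine annealing over the remaining 290 epochs. The sketch below only illustrates the shape of that curve at epoch granularity, using the base `lr=4.8` from the parent config; it is not the exact iteration-based interpolation MMEngine performs once `convert_to_iter_based=True` takes effect:

import math

def byol_300e_lr(epoch: float, base_lr: float = 4.8) -> float:
    """Approximate learning rate at a given epoch for the 300-epoch schedule."""
    if epoch < 10:
        # LinearLR: ramp from base_lr * 1e-4 up to base_lr over the first 10 epochs
        start = base_lr * 1e-4
        return start + (base_lr - start) * epoch / 10
    # CosineAnnealingLR: decay from base_lr towards 0 over T_max=290 epochs
    t = (epoch - 10) / 290
    return 0.5 * base_lr * (1 + math.cos(math.pi * t))

print(byol_300e_lr(0), byol_300e_lr(10), byol_300e_lr(300))  # ~0.00048, 4.8, ~0.0
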
@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Iterable
+
 import torch
 from torch.optim.optimizer import Optimizer

@@ -9,24 +11,21 @@ from mmselfsup.registry import OPTIMIZERS
 class LARS(Optimizer):
     """Implements layer-wise adaptive rate scaling for SGD.

-    Args:
-        params (iterable): Iterable of parameters to optimize or dicts defining
-            parameter groups.
-        lr (float): Base learning rate.
-        momentum (float, optional): Momentum factor. Defaults to 0 ('m')
-        weight_decay (float, optional): Weight decay (L2 penalty).
-            Defaults to 0. ('beta')
-        dampening (float, optional): Dampening for momentum. Defaults to 0.
-        eta (float, optional): LARS coefficient. Defaults to 0.001.
-        nesterov (bool, optional): Enables Nesterov momentum.
-            Defaults to False.
-        eps (float, optional): A small number to avoid dviding zero.
-            Defaults to 1e-8.
-
     Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg.
     `Large Batch Training of Convolutional Networks:
     <https://arxiv.org/abs/1708.03888>`_.

+    Args:
+        params (Iterable): Iterable of parameters to optimize or dicts defining
+            parameter groups.
+        lr (float): Base learning rate.
+        momentum (float): Momentum factor. Defaults to 0.
+        weight_decay (float): Weight decay (L2 penalty). Defaults to 0.
+        dampening (float): Dampening for momentum. Defaults to 0.
+        eta (float): LARS coefficient. Defaults to 0.001.
+        nesterov (bool): Enables Nesterov momentum. Defaults to False.
+        eps (float): A small number to avoid dviding zero. Defaults to 1e-8.
+
     Example:
         >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9,
         >>>                  weight_decay=1e-4, eta=1e-3)
@@ -36,14 +35,14 @@ class LARS(Optimizer):
     """

     def __init__(self,
-                 params,
-                 lr=float,
-                 momentum=0,
-                 weight_decay=0,
-                 dampening=0,
-                 eta=0.001,
-                 nesterov=False,
-                 eps=1e-8):
+                 params: Iterable,
+                 lr: float,
+                 momentum: float = 0,
+                 weight_decay: float = 0,
+                 dampening: float = 0,
+                 eta: float = 0.001,
+                 nesterov: bool = False,
+                 eps: float = 1e-8) -> None:
         if not isinstance(lr, float) and lr < 0.0:
             raise ValueError(f'Invalid learning rate: {lr}')
         if momentum < 0.0:
@@ -65,15 +64,15 @@ class LARS(Optimizer):
                 'Nesterov momentum requires a momentum and zero dampening')

         self.eps = eps
-        super(LARS, self).__init__(params, defaults)
+        super().__init__(params, defaults)

-    def __setstate__(self, state):
-        super(LARS, self).__setstate__(state)
+    def __setstate__(self, state) -> None:
+        super().__setstate__(state)
         for group in self.param_groups:
             group.setdefault('nesterov', False)

     @torch.no_grad()
-    def step(self, closure=None):
+    def step(self, closure=None) -> torch.Tensor:
         """Performs a single optimization step.

         Args:

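As background for these hunks, the layer-wise scaling that the docstring's `eta` and `eps` arguments feed into (Algorithm 1 of the You, Gitman and Ginsburg paper cited above) computes a per-layer trust ratio that rescales the global learning rate. The snippet below is a simplified illustration of that ratio only, not the actual `step()` implementation in this file:

import torch

def lars_local_lr(param: torch.Tensor, grad: torch.Tensor,
                  weight_decay: float, eta: float = 0.001,
                  eps: float = 1e-8) -> float:
    """Layer-wise trust ratio used to scale the global LR for one tensor."""
    w_norm = torch.norm(param)
    g_norm = torch.norm(grad)
    if w_norm == 0 or g_norm == 0:
        return 1.0  # fall back to the plain global learning rate
    # eta * ||w|| / (||g|| + weight_decay * ||w|| + eps)
    return float(eta * w_norm / (g_norm + weight_decay * w_norm + eps))
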
@@ -5,13 +5,12 @@ import torch
 import torch.nn as nn

 from mmselfsup.core import SelfSupDataSample
-from ..builder import (ALGORITHMS, build_backbone, build_head, build_loss,
-                       build_neck)
+from mmselfsup.registry import MODELS
 from ..utils import CosineEMA
 from .base import BaseModel


-@ALGORITHMS.register_module()
+@MODELS.register_module()
 class BYOL(BaseModel):
     """BYOL.

@@ -20,95 +19,83 @@ class BYOL(BaseModel):
     The momentum adjustment is in `core/hooks/byol_hook.py`.

     Args:
-        backbone (Dict, optional): Config dict for module of backbone.
-        neck (Dict, optional): Config dict for module of deep features
-            to compact feature vectors. Defaults to None.
-        head (Dict, optional): Config dict for module of head functions.
-            Defaults to None.
-        loss (Dict, optional): Config dict for module of loss functions.
-            Defaults to None.
+        backbone (dict): Config dict for module of backbone.
+        neck (dict): Config dict for module of deep features
+            to compact feature vectors.
+        head (dict): Config dict for module of head functions.
         base_momentum (float): The base momentum coefficient for the target
             network. Defaults to 0.996.
-        preprocess_cfg (Dict, optional): Config dict to preprocess images.
+        pretrained (str, optional): The pretrained checkpoint path, support
+            local path and remote path. Defaults to None.
+        data_preprocessor (dict, optional): Config dict to preprocess images.
             Defaults to None.
-        init_cfg (Dict or List[Dict], optional): Config dict for weight
+        init_cfg (dict or List[dict], optional): Config dict for weight
            initialization. Defaults to None.
     """

     def __init__(self,
-                 backbone: Optional[Dict] = None,
-                 neck: Optional[Dict] = None,
-                 head: Optional[Dict] = None,
-                 loss: Optional[Dict] = None,
+                 backbone: dict,
+                 neck: dict,
+                 head: dict,
                  base_momentum: float = 0.996,
-                 preprocess_cfg: Optional[Dict] = None,
-                 init_cfg: Optional[Union[Dict, List[Dict]]] = None,
-                 **kwargs) -> None:
-        super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
-        assert backbone is not None
-        assert neck is not None
-        self.online_net = nn.Sequential(
-            build_backbone(backbone), build_neck(neck))
-        self.backbone = self.online_net[0]
-        self.neck = self.online_net[1]
-        assert head is not None
-        self.head = build_head(head)
-        assert loss is not None
-        self.loss = build_loss(loss)
+                 pretrained: Optional[str] = None,
+                 data_preprocessor: Optional[dict] = None,
+                 init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
+        super().__init__(
+            backbone=backbone,
+            neck=neck,
+            head=head,
+            pretrained=pretrained,
+            data_preprocessor=data_preprocessor,
+            init_cfg=init_cfg)

         # create momentum model
-        self.target_net = CosineEMA(self.online_net, momentum=base_momentum)
-        for param_tgt in self.target_net.module.parameters():
-            param_tgt.requires_grad = False
+        self.target_net = CosineEMA(
+            nn.Sequential(self.backbone, self.neck), momentum=base_momentum)

-    def extract_feat(self, inputs: List[torch.Tensor],
-                     data_samples: List[SelfSupDataSample],
+    def extract_feat(self, batch_inputs: List[torch.Tensor],
                      **kwargs) -> Tuple[torch.Tensor]:
         """Function to extract features from backbone.

         Args:
             inputs (List[torch.Tensor]): The input images.
-            data_samples (List[SelfSupDataSample]): All elements required
-                during the forward function.

         Returns:
             Tuple[torch.Tensor]: backbone outputs.
         """
-        x = self.backbone(inputs[0])
+        x = self.backbone(batch_inputs[0])
         return x

-    def forward_train(self, inputs: List[torch.Tensor],
-                      data_samples: List[SelfSupDataSample],
-                      **kwargs) -> Dict[str, torch.Tensor]:
+    def loss(self, batch_inputs: List[torch.Tensor],
+             data_samples: List[SelfSupDataSample],
+             **kwargs) -> Dict[str, torch.Tensor]:
         """Forward computation during training.

         Args:
-            inputs (List[torch.Tensor]): The input images.
+            batch_inputs (List[torch.Tensor]): The input images.
             data_samples (List[SelfSupDataSample]): All elements required
                 during the forward function.

         Returns:
             Dict[str, torch.Tensor]: A dictionary of loss components.
         """
-        assert isinstance(inputs, list)
-        img_v1 = inputs[0]
-        img_v2 = inputs[1]
+        assert isinstance(batch_inputs, list)
+        img_v1 = batch_inputs[0]
+        img_v2 = batch_inputs[1]
         # compute online features
-        proj_online_v1 = self.online_net(img_v1)[0]
-        proj_online_v2 = self.online_net(img_v2)[0]
+        proj_online_v1 = self.neck(self.backbone(img_v1))[0]
+        proj_online_v2 = self.neck(self.backbone(img_v2))[0]
         # compute target features
         with torch.no_grad():
             # update the target net
-            self.target_net.update_parameters(self.online_net)
+            self.target_net.update_parameters(
+                nn.Sequential(self.backbone, self.neck))

             proj_target_v1 = self.target_net(img_v1)[0]
             proj_target_v2 = self.target_net(img_v2)[0]

-        pred_1, target_1 = self.head(proj_online_v1, proj_target_v2)
-        pred_2, target_2 = self.head(proj_online_v2, proj_target_v1)
-
-        loss_1 = self.loss(pred_1, target_1)
-        loss_2 = self.loss(pred_2, target_2)
+        loss_1 = self.head(proj_online_v1, proj_target_v2)
+        loss_2 = self.head(proj_online_v2, proj_target_v1)

         losses = dict(loss=2. * (loss_1 + loss_2))
         return losses

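To make the refactored `loss()` easier to follow: two augmented views pass through the online backbone and neck, the target network is an exponential moving average of those modules (the `CosineEMA` momentum is annealed from `base_momentum=0.996` towards 1 over training), and the head produces a negative-cosine-similarity style term for each view pairing, summed symmetrically. The sketch below is a rough illustration of one such step; the module names `online`, `predictor` and `target` are placeholders, not the classes used in this repository:

import torch
import torch.nn.functional as F

def byol_step(online, predictor, target, img_v1, img_v2, momentum):
    # EMA update of the target network from the online network
    with torch.no_grad():
        for p_t, p_o in zip(target.parameters(), online.parameters()):
            p_t.mul_(momentum).add_(p_o, alpha=1 - momentum)

    def regression_loss(pred, proj):
        # negative cosine similarity between prediction and detached projection
        return -F.cosine_similarity(pred, proj.detach(), dim=-1).mean()

    loss_1 = regression_loss(predictor(online(img_v1)), target(img_v2))
    loss_2 = regression_loss(predictor(online(img_v2)), target(img_v1))
    return 2. * (loss_1 + loss_2)  # mirrors losses = dict(loss=2. * (loss_1 + loss_2))
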
@@ -26,6 +26,7 @@ neck = dict(
     norm_cfg=dict(type='BN1d'))
 head = dict(
     type='LatentPredictHead',
+    loss=dict(type='CosineSimilarityLoss'),
     predictor=dict(
         type='NonLinearNeck',
         in_channels=2,
@@ -35,43 +36,20 @@ head = dict(
         with_last_bn=False,
         with_avg_pool=False,
         norm_cfg=dict(type='BN1d')))
-loss = dict(type='CosineSimilarityLoss')


 @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
 def test_byol():
-    preprocess_cfg = {
-        'mean': [0.5, 0.5, 0.5],
-        'std': [0.5, 0.5, 0.5],
-        'to_rgb': True
-    }
-    with pytest.raises(AssertionError):
-        alg = BYOL(
-            backbone=backbone,
-            neck=None,
-            head=head,
-            loss=loss,
-            preprocess_cfg=copy.deepcopy(preprocess_cfg))
-    with pytest.raises(AssertionError):
-        alg = BYOL(
-            backbone=backbone,
-            neck=neck,
-            head=None,
-            loss=loss,
-            preprocess_cfg=copy.deepcopy(preprocess_cfg))
-    with pytest.raises(AssertionError):
-        alg = BYOL(
-            backbone=backbone,
-            neck=neck,
-            head=head,
-            loss=None,
-            preprocess_cfg=copy.deepcopy(preprocess_cfg))
+    data_preprocessor = dict(
+        mean=(123.675, 116.28, 103.53),
+        std=(58.395, 57.12, 57.375),
+        bgr_to_rgb=True)

     alg = BYOL(
         backbone=backbone,
         neck=neck,
         head=head,
-        loss=loss,
-        preprocess_cfg=copy.deepcopy(preprocess_cfg))
+        data_preprocessor=copy.deepcopy(data_preprocessor))
+
     fake_data = [{
         'inputs': [torch.randn((3, 224, 224)),
@@ -79,12 +57,11 @@ def test_byol():
         'data_sample':
         SelfSupDataSample()
     } for _ in range(2)]
+    fake_inputs, fake_data_samples = alg.data_preprocessor(fake_data)

-    fake_outputs = alg(fake_data, return_loss=True)
-    assert isinstance(fake_outputs['loss'].item(), float)
-    assert fake_outputs['loss'].item() > -4
+    fake_loss = alg(fake_inputs, fake_data_samples, mode='loss')
+    assert isinstance(fake_loss['loss'].item(), float)
+    assert fake_loss['loss'].item() > -4

-    fake_inputs, fake_data_samples = alg.preprocss_data(fake_data)
-    fake_feat = alg.extract_feat(
-        inputs=fake_inputs, data_samples=fake_data_samples)
-    assert list(fake_feat[0].shape) == [2, 512, 7, 7]
+    fake_feats = alg(fake_inputs, fake_data_samples, mode='tensor')
+    assert list(fake_feats[0].shape) == [2, 512, 7, 7]