From 1e16016b2709ac4db13116d0152890f2e0acbbf4 Mon Sep 17 00:00:00 2001 From: "fangyixiao.vendor" Date: Wed, 6 Jul 2022 07:50:01 +0000 Subject: [PATCH] [Refactor] refactor byol --- .../selfsup/_base_/datasets/imagenet_byol.py | 16 +++- configs/selfsup/_base_/models/byol.py | 9 +- .../byol_resnet50_16xb256-coslr-200e_in1k.py | 38 +++----- ...et50_8xb256-fp16-accum2-coslr-100e_in1k.py | 15 ++- ...et50_8xb256-fp16-accum2-coslr-200e_in1k.py | 42 ++++----- ...et50_8xb256-fp16-accum2-coslr-300e_in1k.py | 15 ++- ..._resnet50_8xb32-accum16-coslr-100e_in1k.py | 15 ++- ..._resnet50_8xb32-accum16-coslr-200e_in1k.py | 36 +++---- ..._resnet50_8xb32-accum16-coslr-300e_in1k.py | 15 ++- mmselfsup/core/optimizer/optimizers.py | 51 +++++----- mmselfsup/models/algorithms/byol.py | 93 ++++++++----------- .../test_models/test_algorithms/test_byol.py | 49 +++------- 12 files changed, 197 insertions(+), 197 deletions(-) diff --git a/configs/selfsup/_base_/datasets/imagenet_byol.py b/configs/selfsup/_base_/datasets/imagenet_byol.py index 1d7227e8..18d7ec05 100644 --- a/configs/selfsup/_base_/datasets/imagenet_byol.py +++ b/configs/selfsup/_base_/datasets/imagenet_byol.py @@ -5,7 +5,11 @@ data_root = 'data/imagenet/' file_client_args = dict(backend='disk') view_pipeline1 = [ - dict(type='RandomResizedCrop', size=224, interpolation='bicubic'), + dict( + type='RandomResizedCrop', + size=224, + interpolation='bicubic', + backend='pillow'), dict(type='RandomFlip', prob=0.5), dict( type='RandomApply', @@ -22,12 +26,16 @@ view_pipeline1 = [ type='RandomGrayscale', prob=0.2, keep_channels=True, - channel_weights=(0.114, 0.587, 0.299)), + channel_weights=(0.114, 0.587, 0.2989)), dict(type='RandomGaussianBlur', sigma_min=0.1, sigma_max=2.0, prob=1.), dict(type='RandomSolarize', prob=0.), ] view_pipeline2 = [ - dict(type='RandomResizedCrop', size=224, interpolation='bicubic'), + dict( + type='RandomResizedCrop', + size=224, + interpolation='bicubic', + backend='pillow'), dict(type='RandomFlip', prob=0.5), dict( type='RandomApply', @@ -44,7 +52,7 @@ view_pipeline2 = [ type='RandomGrayscale', prob=0.2, keep_channels=True, - channel_weights=(0.114, 0.587, 0.299)), + channel_weights=(0.114, 0.587, 0.2989)), dict(type='RandomGaussianBlur', sigma_min=0.1, sigma_max=2.0, prob=0.1), dict(type='RandomSolarize', prob=0.2) ] diff --git a/configs/selfsup/_base_/models/byol.py b/configs/selfsup/_base_/models/byol.py index 5a4e846b..67770ef3 100644 --- a/configs/selfsup/_base_/models/byol.py +++ b/configs/selfsup/_base_/models/byol.py @@ -2,6 +2,10 @@ model = dict( type='BYOL', base_momentum=0.99, + data_preprocessor=dict( + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + bgr_to_rgb=True), backbone=dict( type='ResNet', depth=50, @@ -27,5 +31,6 @@ model = dict( num_layers=2, with_bias=True, with_last_bn=False, - with_avg_pool=False)), - loss=dict(type='CosineSimilarityLoss')) + with_avg_pool=False), + loss=dict(type='CosineSimilarityLoss')), +) diff --git a/configs/selfsup/byol/byol_resnet50_16xb256-coslr-200e_in1k.py b/configs/selfsup/byol/byol_resnet50_16xb256-coslr-200e_in1k.py index df3b475e..1ef3f127 100644 --- a/configs/selfsup/byol/byol_resnet50_16xb256-coslr-200e_in1k.py +++ b/configs/selfsup/byol/byol_resnet50_16xb256-coslr-200e_in1k.py @@ -5,31 +5,21 @@ _base_ = [ '../_base_/default_runtime.py', ] -# dataset summary -data = dict(samples_per_gpu=256, workers_per_gpu=8) - -# additional hooks -# interval for accumulate gradient, total 16*256*1(interval)=4096 -update_interval = 1 -custom_hooks = [ - 
dict(type='BYOLHook', end_momentum=1., update_interval=update_interval) -] +train_dataloader = dict(batch_size=256) # optimizer -optimizer = dict( - type='LARS', - lr=4.8, - momentum=0.9, - weight_decay=1e-6, - paramwise_options={ - '(bn|gn)(\\d+)?.(weight|bias)': - dict(weight_decay=0., lars_exclude=True), - 'bias': dict(weight_decay=0., lars_exclude=True) - }) -optimizer_config = dict(update_interval=update_interval) +optimizer = dict(type='LARS', lr=4.8, momentum=0.9, weight_decay=1e-6) +optim_wrapper = dict( + type='OptimWrapper', + optimizer=optimizer, + paramwise_cfg=dict( + custom_keys={ + 'bn': dict(decay_mult=0, lars_exclude=True), + 'bias': dict(decay_mult=0, lars_exclude=True), + # bn layer in ResNet block downsample module + 'downsample.1': dict(decay_mult=0, lars_exclude=True), + }), +) # runtime settings -# the max_keep_ckpts controls the max number of ckpt file in your work_dirs -# if it is 3, when CheckpointHook (in mmcv) saves the 4th ckpt -# it will remove the oldest one to keep the number of total ckpts as 3 -checkpoint_config = dict(interval=10, max_keep_ckpts=3) +default_hooks = dict(checkpoint=dict(max_keep_ckpts=3)) diff --git a/configs/selfsup/byol/byol_resnet50_8xb256-fp16-accum2-coslr-100e_in1k.py b/configs/selfsup/byol/byol_resnet50_8xb256-fp16-accum2-coslr-100e_in1k.py index 27ccf438..9ed9ae8b 100644 --- a/configs/selfsup/byol/byol_resnet50_8xb256-fp16-accum2-coslr-100e_in1k.py +++ b/configs/selfsup/byol/byol_resnet50_8xb256-fp16-accum2-coslr-100e_in1k.py @@ -2,6 +2,19 @@ _base_ = 'byol_resnet50_8xb256-fp16-accum2-coslr-200e_in1k.py' # optimizer optimizer = dict(lr=7.2) +optim_wrapper = dict(optimizer=optimizer) + +# learning rate scheduler +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-4, + by_epoch=True, + begin=0, + end=10, + convert_to_iter_based=True), + dict(type='CosineAnnealingLR', T_max=90, by_epoch=True, begin=10, end=100) +] # runtime settings -runner = dict(type='EpochBasedRunner', max_epochs=100) +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100) diff --git a/configs/selfsup/byol/byol_resnet50_8xb256-fp16-accum2-coslr-200e_in1k.py b/configs/selfsup/byol/byol_resnet50_8xb256-fp16-accum2-coslr-200e_in1k.py index b3a6a1a8..a93d3148 100644 --- a/configs/selfsup/byol/byol_resnet50_8xb256-fp16-accum2-coslr-200e_in1k.py +++ b/configs/selfsup/byol/byol_resnet50_8xb256-fp16-accum2-coslr-200e_in1k.py @@ -6,33 +6,23 @@ _base_ = [ ] # dataset summary -data = dict(samples_per_gpu=256) - -# additional hooks -# interval for accumulate gradient, total 8*256*2(interval)=4096 -update_interval = 2 -custom_hooks = [ - dict(type='BYOLHook', end_momentum=1., update_interval=update_interval) -] +train_dataloader = dict(batch_size=256) # optimizer -optimizer = dict( - type='LARS', - lr=4.8, - momentum=0.9, - weight_decay=1e-6, - paramwise_options={ - '(bn|gn)(\\d+)?.(weight|bias)': - dict(weight_decay=0., lars_exclude=True), - 'bias': dict(weight_decay=0., lars_exclude=True) - }) -optimizer_config = dict(update_interval=update_interval) - -# fp16 -fp16 = dict(loss_scale=512.) 
+optimizer = dict(type='LARS', lr=4.8, momentum=0.9, weight_decay=1e-6) +optim_wrapper = dict( + type='AmpOptimWrapper', + loss_scale=512., + optimizer=optimizer, + accumulative_iters=2, + paramwise_cfg=dict( + custom_keys={ + 'bn': dict(decay_mult=0, lars_exclude=True), + 'bias': dict(decay_mult=0, lars_exclude=True), + # bn layer in ResNet block downsample module + 'downsample.1': dict(decay_mult=0, lars_exclude=True), + }), +) # runtime settings -# the max_keep_ckpts controls the max number of ckpt file in your work_dirs -# if it is 3, when CheckpointHook (in mmcv) saves the 4th ckpt -# it will remove the oldest one to keep the number of total ckpts as 3 -checkpoint_config = dict(interval=10, max_keep_ckpts=3) +default_hooks = dict(checkpoint=dict(max_keep_ckpts=3)) diff --git a/configs/selfsup/byol/byol_resnet50_8xb256-fp16-accum2-coslr-300e_in1k.py b/configs/selfsup/byol/byol_resnet50_8xb256-fp16-accum2-coslr-300e_in1k.py index ea166f58..77c61d20 100644 --- a/configs/selfsup/byol/byol_resnet50_8xb256-fp16-accum2-coslr-300e_in1k.py +++ b/configs/selfsup/byol/byol_resnet50_8xb256-fp16-accum2-coslr-300e_in1k.py @@ -1,4 +1,17 @@ _base_ = 'byol_resnet50_8xb256-fp16-accum2-coslr-200e_in1k.py' +# learning rate scheduler +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-4, + by_epoch=True, + begin=0, + end=10, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', T_max=290, by_epoch=True, begin=10, end=300) +] + # runtime settings -runner = dict(type='EpochBasedRunner', max_epochs=300) +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=300) diff --git a/configs/selfsup/byol/byol_resnet50_8xb32-accum16-coslr-100e_in1k.py b/configs/selfsup/byol/byol_resnet50_8xb32-accum16-coslr-100e_in1k.py index 2b742ad4..b7271a8d 100644 --- a/configs/selfsup/byol/byol_resnet50_8xb32-accum16-coslr-100e_in1k.py +++ b/configs/selfsup/byol/byol_resnet50_8xb32-accum16-coslr-100e_in1k.py @@ -2,6 +2,19 @@ _base_ = 'byol_resnet50_8xb32-accum16-coslr-200e_in1k.py' # optimizer optimizer = dict(lr=7.2) +optim_wrapper = dict(optimizer=optimizer) + +# learning rate scheduler +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-4, + by_epoch=True, + begin=0, + end=10, + convert_to_iter_based=True), + dict(type='CosineAnnealingLR', T_max=90, by_epoch=True, begin=10, end=100) +] # runtime settings -runner = dict(type='EpochBasedRunner', max_epochs=100) +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100) diff --git a/configs/selfsup/byol/byol_resnet50_8xb32-accum16-coslr-200e_in1k.py b/configs/selfsup/byol/byol_resnet50_8xb32-accum16-coslr-200e_in1k.py index ef57e741..63296e89 100644 --- a/configs/selfsup/byol/byol_resnet50_8xb32-accum16-coslr-200e_in1k.py +++ b/configs/selfsup/byol/byol_resnet50_8xb32-accum16-coslr-200e_in1k.py @@ -5,28 +5,20 @@ _base_ = [ '../_base_/default_runtime.py', ] -# additional hooks -# interval for accumulate gradient, total 8*32*16(interval)=4096 -update_interval = 16 -custom_hooks = [ - dict(type='BYOLHook', end_momentum=1., update_interval=update_interval) -] - # optimizer -optimizer = dict( - type='LARS', - lr=4.8, - momentum=0.9, - weight_decay=1e-6, - paramwise_options={ - '(bn|gn)(\\d+)?.(weight|bias)': - dict(weight_decay=0., lars_exclude=True), - 'bias': dict(weight_decay=0., lars_exclude=True) - }) -optimizer_config = dict(update_interval=update_interval) +optimizer = dict(type='LARS', lr=4.8, momentum=0.9, weight_decay=1e-6) +optim_wrapper = dict( + type='OptimWrapper', + optimizer=optimizer, + accumulative_iters=16, + 
paramwise_cfg=dict(
+        custom_keys={
+            'bn': dict(decay_mult=0, lars_exclude=True),
+            'bias': dict(decay_mult=0, lars_exclude=True),
+            # bn layer in ResNet block downsample module
+            'downsample.1': dict(decay_mult=0, lars_exclude=True),
+        }),
+)
 
 # runtime settings
-# the max_keep_ckpts controls the max number of ckpt file in your work_dirs
-# if it is 3, when CheckpointHook (in mmcv) saves the 4th ckpt
-# it will remove the oldest one to keep the number of total ckpts as 3
-checkpoint_config = dict(interval=10, max_keep_ckpts=3)
+default_hooks = dict(checkpoint=dict(max_keep_ckpts=3))
diff --git a/configs/selfsup/byol/byol_resnet50_8xb32-accum16-coslr-300e_in1k.py b/configs/selfsup/byol/byol_resnet50_8xb32-accum16-coslr-300e_in1k.py
index 39556b16..8f3aeabc 100644
--- a/configs/selfsup/byol/byol_resnet50_8xb32-accum16-coslr-300e_in1k.py
+++ b/configs/selfsup/byol/byol_resnet50_8xb32-accum16-coslr-300e_in1k.py
@@ -1,4 +1,17 @@
 _base_ = 'byol_resnet50_8xb32-accum16-coslr-200e_in1k.py'
 
+# learning rate scheduler
+param_scheduler = [
+    dict(
+        type='LinearLR',
+        start_factor=1e-4,
+        by_epoch=True,
+        begin=0,
+        end=10,
+        convert_to_iter_based=True),
+    dict(
+        type='CosineAnnealingLR', T_max=290, by_epoch=True, begin=10, end=300)
+]
+
 # runtime settings
-runner = dict(type='EpochBasedRunner', max_epochs=300)
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=300)
diff --git a/mmselfsup/core/optimizer/optimizers.py b/mmselfsup/core/optimizer/optimizers.py
index a46b4c4f..522d7373 100644
--- a/mmselfsup/core/optimizer/optimizers.py
+++ b/mmselfsup/core/optimizer/optimizers.py
@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Iterable
+
 import torch
 from torch.optim.optimizer import Optimizer
 
@@ -9,24 +11,21 @@ from mmselfsup.registry import OPTIMIZERS
 class LARS(Optimizer):
     """Implements layer-wise adaptive rate scaling for SGD.
 
-    Args:
-        params (iterable): Iterable of parameters to optimize or dicts defining
-            parameter groups.
-        lr (float): Base learning rate.
-        momentum (float, optional): Momentum factor. Defaults to 0 ('m')
-        weight_decay (float, optional): Weight decay (L2 penalty).
-            Defaults to 0. ('beta')
-        dampening (float, optional): Dampening for momentum. Defaults to 0.
-        eta (float, optional): LARS coefficient. Defaults to 0.001.
-        nesterov (bool, optional): Enables Nesterov momentum.
-            Defaults to False.
-        eps (float, optional): A small number to avoid dviding zero.
-            Defaults to 1e-8.
-
     Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg.
     `Large Batch Training of Convolutional Networks:
     <https://arxiv.org/abs/1708.03888>`_.
 
+    Args:
+        params (Iterable): Iterable of parameters to optimize or dicts defining
+            parameter groups.
+        lr (float): Base learning rate.
+        momentum (float): Momentum factor. Defaults to 0.
+        weight_decay (float): Weight decay (L2 penalty). Defaults to 0.
+        dampening (float): Dampening for momentum. Defaults to 0.
+        eta (float): LARS coefficient. Defaults to 0.001.
+        nesterov (bool): Enables Nesterov momentum. Defaults to False.
+        eps (float): A small number to avoid dividing by zero. Defaults to 1e-8.
+ Example: >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9, >>> weight_decay=1e-4, eta=1e-3) @@ -36,14 +35,14 @@ class LARS(Optimizer): """ def __init__(self, - params, - lr=float, - momentum=0, - weight_decay=0, - dampening=0, - eta=0.001, - nesterov=False, - eps=1e-8): + params: Iterable, + lr: float, + momentum: float = 0, + weight_decay: float = 0, + dampening: float = 0, + eta: float = 0.001, + nesterov: bool = False, + eps: float = 1e-8) -> None: if not isinstance(lr, float) and lr < 0.0: raise ValueError(f'Invalid learning rate: {lr}') if momentum < 0.0: @@ -65,15 +64,15 @@ class LARS(Optimizer): 'Nesterov momentum requires a momentum and zero dampening') self.eps = eps - super(LARS, self).__init__(params, defaults) + super().__init__(params, defaults) - def __setstate__(self, state): - super(LARS, self).__setstate__(state) + def __setstate__(self, state) -> None: + super().__setstate__(state) for group in self.param_groups: group.setdefault('nesterov', False) @torch.no_grad() - def step(self, closure=None): + def step(self, closure=None) -> torch.Tensor: """Performs a single optimization step. Args: diff --git a/mmselfsup/models/algorithms/byol.py b/mmselfsup/models/algorithms/byol.py index 609ff797..d01d9acf 100644 --- a/mmselfsup/models/algorithms/byol.py +++ b/mmselfsup/models/algorithms/byol.py @@ -5,13 +5,12 @@ import torch import torch.nn as nn from mmselfsup.core import SelfSupDataSample -from ..builder import (ALGORITHMS, build_backbone, build_head, build_loss, - build_neck) +from mmselfsup.registry import MODELS from ..utils import CosineEMA from .base import BaseModel -@ALGORITHMS.register_module() +@MODELS.register_module() class BYOL(BaseModel): """BYOL. @@ -20,95 +19,83 @@ class BYOL(BaseModel): The momentum adjustment is in `core/hooks/byol_hook.py`. Args: - backbone (Dict, optional): Config dict for module of backbone. - neck (Dict, optional): Config dict for module of deep features - to compact feature vectors. Defaults to None. - head (Dict, optional): Config dict for module of head functions. - Defaults to None. - loss (Dict, optional): Config dict for module of loss functions. - Defaults to None. + backbone (dict): Config dict for module of backbone. + neck (dict): Config dict for module of deep features + to compact feature vectors. + head (dict): Config dict for module of head functions. base_momentum (float): The base momentum coefficient for the target network. Defaults to 0.996. - preprocess_cfg (Dict, optional): Config dict to preprocess images. + pretrained (str, optional): The pretrained checkpoint path, support + local path and remote path. Defaults to None. + data_preprocessor (dict, optional): Config dict to preprocess images. Defaults to None. - init_cfg (Dict or List[Dict], optional): Config dict for weight + init_cfg (dict or List[dict], optional): Config dict for weight initialization. Defaults to None. 
""" def __init__(self, - backbone: Optional[Dict] = None, - neck: Optional[Dict] = None, - head: Optional[Dict] = None, - loss: Optional[Dict] = None, + backbone: dict, + neck: dict, + head: dict, base_momentum: float = 0.996, - preprocess_cfg: Optional[Dict] = None, - init_cfg: Optional[Union[Dict, List[Dict]]] = None, - **kwargs) -> None: - super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg) - assert backbone is not None - assert neck is not None - self.online_net = nn.Sequential( - build_backbone(backbone), build_neck(neck)) - self.backbone = self.online_net[0] - self.neck = self.online_net[1] - assert head is not None - self.head = build_head(head) - assert loss is not None - self.loss = build_loss(loss) + pretrained: Optional[str] = None, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[Union[dict, List[dict]]] = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + head=head, + pretrained=pretrained, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) # create momentum model - self.target_net = CosineEMA(self.online_net, momentum=base_momentum) - for param_tgt in self.target_net.module.parameters(): - param_tgt.requires_grad = False + self.target_net = CosineEMA( + nn.Sequential(self.backbone, self.neck), momentum=base_momentum) - def extract_feat(self, inputs: List[torch.Tensor], - data_samples: List[SelfSupDataSample], + def extract_feat(self, batch_inputs: List[torch.Tensor], **kwargs) -> Tuple[torch.Tensor]: """Function to extract features from backbone. Args: inputs (List[torch.Tensor]): The input images. - data_samples (List[SelfSupDataSample]): All elements required - during the forward function. Returns: Tuple[torch.Tensor]: backbone outputs. """ - x = self.backbone(inputs[0]) + x = self.backbone(batch_inputs[0]) return x - def forward_train(self, inputs: List[torch.Tensor], - data_samples: List[SelfSupDataSample], - **kwargs) -> Dict[str, torch.Tensor]: + def loss(self, batch_inputs: List[torch.Tensor], + data_samples: List[SelfSupDataSample], + **kwargs) -> Dict[str, torch.Tensor]: """Forward computation during training. Args: - inputs (List[torch.Tensor]): The input images. + batch_inputs (List[torch.Tensor]): The input images. data_samples (List[SelfSupDataSample]): All elements required during the forward function. Returns: Dict[str, torch.Tensor]: A dictionary of loss components. """ - assert isinstance(inputs, list) - img_v1 = inputs[0] - img_v2 = inputs[1] + assert isinstance(batch_inputs, list) + img_v1 = batch_inputs[0] + img_v2 = batch_inputs[1] # compute online features - proj_online_v1 = self.online_net(img_v1)[0] - proj_online_v2 = self.online_net(img_v2)[0] + proj_online_v1 = self.neck(self.backbone(img_v1))[0] + proj_online_v2 = self.neck(self.backbone(img_v2))[0] # compute target features with torch.no_grad(): # update the target net - self.target_net.update_parameters(self.online_net) + self.target_net.update_parameters( + nn.Sequential(self.backbone, self.neck)) proj_target_v1 = self.target_net(img_v1)[0] proj_target_v2 = self.target_net(img_v2)[0] - pred_1, target_1 = self.head(proj_online_v1, proj_target_v2) - pred_2, target_2 = self.head(proj_online_v2, proj_target_v1) - - loss_1 = self.loss(pred_1, target_1) - loss_2 = self.loss(pred_2, target_2) + loss_1 = self.head(proj_online_v1, proj_target_v2) + loss_2 = self.head(proj_online_v2, proj_target_v1) losses = dict(loss=2. 
* (loss_1 + loss_2)) return losses diff --git a/tests/test_models/test_algorithms/test_byol.py b/tests/test_models/test_algorithms/test_byol.py index ae4ed0ea..a287b8d3 100644 --- a/tests/test_models/test_algorithms/test_byol.py +++ b/tests/test_models/test_algorithms/test_byol.py @@ -26,6 +26,7 @@ neck = dict( norm_cfg=dict(type='BN1d')) head = dict( type='LatentPredictHead', + loss=dict(type='CosineSimilarityLoss'), predictor=dict( type='NonLinearNeck', in_channels=2, @@ -35,43 +36,20 @@ head = dict( with_last_bn=False, with_avg_pool=False, norm_cfg=dict(type='BN1d'))) -loss = dict(type='CosineSimilarityLoss') @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit') def test_byol(): - preprocess_cfg = { - 'mean': [0.5, 0.5, 0.5], - 'std': [0.5, 0.5, 0.5], - 'to_rgb': True - } - with pytest.raises(AssertionError): - alg = BYOL( - backbone=backbone, - neck=None, - head=head, - loss=loss, - preprocess_cfg=copy.deepcopy(preprocess_cfg)) - with pytest.raises(AssertionError): - alg = BYOL( - backbone=backbone, - neck=neck, - head=None, - loss=loss, - preprocess_cfg=copy.deepcopy(preprocess_cfg)) - with pytest.raises(AssertionError): - alg = BYOL( - backbone=backbone, - neck=neck, - head=head, - loss=None, - preprocess_cfg=copy.deepcopy(preprocess_cfg)) + data_preprocessor = dict( + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + bgr_to_rgb=True) + alg = BYOL( backbone=backbone, neck=neck, head=head, - loss=loss, - preprocess_cfg=copy.deepcopy(preprocess_cfg)) + data_preprocessor=copy.deepcopy(data_preprocessor)) fake_data = [{ 'inputs': [torch.randn((3, 224, 224)), @@ -79,12 +57,11 @@ def test_byol(): 'data_sample': SelfSupDataSample() } for _ in range(2)] + fake_inputs, fake_data_samples = alg.data_preprocessor(fake_data) - fake_outputs = alg(fake_data, return_loss=True) - assert isinstance(fake_outputs['loss'].item(), float) - assert fake_outputs['loss'].item() > -4 + fake_loss = alg(fake_inputs, fake_data_samples, mode='loss') + assert isinstance(fake_loss['loss'].item(), float) + assert fake_loss['loss'].item() > -4 - fake_inputs, fake_data_samples = alg.preprocss_data(fake_data) - fake_feat = alg.extract_feat( - inputs=fake_inputs, data_samples=fake_data_samples) - assert list(fake_feat[0].shape) == [2, 512, 7, 7] + fake_feats = alg(fake_inputs, fake_data_samples, mode='tensor') + assert list(fake_feats[0].shape) == [2, 512, 7, 7]
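Note on the refactored loss flow: after this patch the symmetric BYOL objective lives in `BYOL.loss()`, which runs both augmented views through the online backbone and neck, predicts the EMA target's projections via `LatentPredictHead` (which now owns `CosineSimilarityLoss`), and returns `2. * (loss_1 + loss_2)`, consistent with the `> -4` assertion in the updated test. The sketch below is a minimal, self-contained illustration of that objective under stated assumptions, not the mmselfsup implementation: `ToyBYOL`, its small MLP modules and the plain EMA update are hypothetical stand-ins for the ResNet backbone, `NonLinearNeck`, `LatentPredictHead` and `CosineEMA` (which additionally anneals the momentum towards 1 on a cosine schedule).

import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyBYOL(nn.Module):
    """Toy stand-in for the refactored BYOL algorithm (illustrative only)."""

    def __init__(self, in_dim: int = 8, dim: int = 32,
                 base_momentum: float = 0.996) -> None:
        super().__init__()
        # online branch: "backbone + neck" collapsed into one small MLP
        self.online = nn.Sequential(
            nn.Linear(in_dim, dim), nn.ReLU(), nn.Linear(dim, dim))
        # predictor, i.e. the role played by LatentPredictHead's predictor
        self.predictor = nn.Linear(dim, dim)
        # target branch: EMA copy of the online branch, never back-propagated
        self.target = copy.deepcopy(self.online)
        for p in self.target.parameters():
            p.requires_grad = False
        self.momentum = base_momentum

    @torch.no_grad()
    def update_target(self) -> None:
        # plain EMA; mmselfsup's CosineEMA also anneals momentum towards 1.
        for p_t, p_o in zip(self.target.parameters(),
                            self.online.parameters()):
            p_t.mul_(self.momentum).add_(p_o, alpha=1 - self.momentum)

    def loss(self, view1: torch.Tensor, view2: torch.Tensor) -> torch.Tensor:
        # online projections + predictions for both views
        pred_1 = self.predictor(self.online(view1))
        pred_2 = self.predictor(self.online(view2))
        # target projections, computed without gradients
        with torch.no_grad():
            self.update_target()
            target_1 = self.target(view1)
            target_2 = self.target(view2)
        # negative cosine similarity, symmetrised across the two views
        loss_1 = -F.cosine_similarity(pred_1, target_2, dim=-1).mean()
        loss_2 = -F.cosine_similarity(pred_2, target_1, dim=-1).mean()
        return 2. * (loss_1 + loss_2)


if __name__ == '__main__':
    model = ToyBYOL()
    v1, v2 = torch.randn(4, 8), torch.randn(4, 8)
    print(model.loss(v1, v2).item())  # scalar in (-4, 4)

Folding the loss into the head is what lets the algorithm return a loss dict directly instead of the 1.x pattern of producing (pred, target) pairs and applying a separate `self.loss` module.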