[Refactor]: refactor swav algorithm
parent
5f778aa552
commit
dfa4d180df
|
@ -1,4 +1,5 @@
|
||||||
# dataset settings
|
# dataset settings
|
||||||
|
custom_imports = dict(imports='mmcls.datasets', allow_failed_imports=False)
|
||||||
dataset_type = 'mmcls.ImageNet'
|
dataset_type = 'mmcls.ImageNet'
|
||||||
data_root = 'data/imagenet/'
|
data_root = 'data/imagenet/'
|
||||||
file_client_args = dict(backend='disk')
|
file_client_args = dict(backend='disk')
|
||||||
|
@ -6,7 +7,9 @@ file_client_args = dict(backend='disk')
|
||||||
num_crops = [2, 6]
|
num_crops = [2, 6]
|
||||||
color_distort_strength = 1.0
|
color_distort_strength = 1.0
|
||||||
view_pipeline1 = [
|
view_pipeline1 = [
|
||||||
dict(type='RandomResizedCrop', size=224, scale=(0.14, 1.)),
|
dict(
|
||||||
|
type='RandomResizedCrop', size=224, scale=(0.14, 1.),
|
||||||
|
backend='pillow'),
|
||||||
dict(
|
dict(
|
||||||
type='RandomApply',
|
type='RandomApply',
|
||||||
transforms=[
|
transforms=[
|
||||||
|
@ -18,12 +21,20 @@ view_pipeline1 = [
|
||||||
hue=0.2 * color_distort_strength)
|
hue=0.2 * color_distort_strength)
|
||||||
],
|
],
|
||||||
prob=0.8),
|
prob=0.8),
|
||||||
dict(type='RandomGrayscale', prob=0.2, keep_channels=True),
|
dict(
|
||||||
|
type='RandomGrayscale',
|
||||||
|
prob=0.2,
|
||||||
|
keep_channels=True,
|
||||||
|
channel_weights=(0.114, 0.587, 0.2989)),
|
||||||
dict(type='RandomGaussianBlur', sigma_min=0.1, sigma_max=2.0, prob=0.5),
|
dict(type='RandomGaussianBlur', sigma_min=0.1, sigma_max=2.0, prob=0.5),
|
||||||
dict(type='RandomFlip', prob=0.5),
|
dict(type='RandomFlip', prob=0.5),
|
||||||
]
|
]
|
||||||
view_pipeline2 = [
|
view_pipeline2 = [
|
||||||
dict(type='RandomResizedCrop', size=96, scale=(0.05, 0.14)),
|
dict(
|
||||||
|
type='RandomResizedCrop',
|
||||||
|
size=96,
|
||||||
|
scale=(0.05, 0.14),
|
||||||
|
backend='pillow'),
|
||||||
dict(
|
dict(
|
||||||
type='RandomApply',
|
type='RandomApply',
|
||||||
transforms=[
|
transforms=[
|
||||||
|
@ -35,7 +46,11 @@ view_pipeline2 = [
|
||||||
hue=0.2 * color_distort_strength)
|
hue=0.2 * color_distort_strength)
|
||||||
],
|
],
|
||||||
prob=0.8),
|
prob=0.8),
|
||||||
dict(type='RandomGrayscale', prob=0.2, keep_channels=True),
|
dict(
|
||||||
|
type='RandomGrayscale',
|
||||||
|
prob=0.2,
|
||||||
|
keep_channels=True,
|
||||||
|
channel_weights=(0.114, 0.587, 0.2989)),
|
||||||
dict(type='RandomGaussianBlur', sigma_min=0.1, sigma_max=2.0, prob=0.5),
|
dict(type='RandomGaussianBlur', sigma_min=0.1, sigma_max=2.0, prob=0.5),
|
||||||
dict(type='RandomFlip', prob=0.5),
|
dict(type='RandomFlip', prob=0.5),
|
||||||
]
|
]
|
||||||
|
@ -51,7 +66,8 @@ train_pipeline = [
|
||||||
|
|
||||||
train_dataloader = dict(
|
train_dataloader = dict(
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
num_workers=4,
|
num_workers=8,
|
||||||
|
drop_last=True,
|
||||||
persistent_workers=True,
|
persistent_workers=True,
|
||||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||||
dataset=dict(
|
dataset=dict(
|
||||||
|
|
|
@ -1,6 +1,10 @@
|
||||||
# model settings
|
# model settings
|
||||||
model = dict(
|
model = dict(
|
||||||
type='SwAV',
|
type='SwAV',
|
||||||
|
data_preprocessor=dict(
|
||||||
|
mean=(123.675, 116.28, 103.53),
|
||||||
|
std=(58.395, 57.12, 57.375),
|
||||||
|
bgr_to_rgb=True),
|
||||||
backbone=dict(
|
backbone=dict(
|
||||||
type='ResNet',
|
type='ResNet',
|
||||||
depth=50,
|
depth=50,
|
||||||
|
@ -14,10 +18,12 @@ model = dict(
|
||||||
hid_channels=2048,
|
hid_channels=2048,
|
||||||
out_channels=128,
|
out_channels=128,
|
||||||
with_avg_pool=True),
|
with_avg_pool=True),
|
||||||
loss=dict(
|
head=dict(
|
||||||
type='SwAVLoss',
|
type='SwAVHead',
|
||||||
feat_dim=128, # equal to neck['out_channels']
|
loss=dict(
|
||||||
epsilon=0.05,
|
type='SwAVLoss',
|
||||||
temperature=0.1,
|
feat_dim=128, # equal to neck['out_channels']
|
||||||
num_crops=[2, 6],
|
epsilon=0.05,
|
||||||
))
|
temperature=0.1,
|
||||||
|
num_crops=[2, 6],
|
||||||
|
)))
|
||||||
|
|
|
@ -6,7 +6,7 @@ _base_ = [
|
||||||
]
|
]
|
||||||
|
|
||||||
# model settings
|
# model settings
|
||||||
model = dict(head=dict(num_crops={{_base_.num_crops}}))
|
model = dict(head=dict(loss=dict(num_crops={{_base_.num_crops}})))
|
||||||
|
|
||||||
# additional hooks
|
# additional hooks
|
||||||
custom_hooks = [
|
custom_hooks = [
|
||||||
|
@ -17,7 +17,8 @@ custom_hooks = [
|
||||||
epoch_queue_starts=15,
|
epoch_queue_starts=15,
|
||||||
crops_for_assign=[0, 1],
|
crops_for_assign=[0, 1],
|
||||||
feat_dim=128,
|
feat_dim=128,
|
||||||
queue_length=3840)
|
queue_length=3840,
|
||||||
|
frozen_layers_cfg=dict(prototypes=5005))
|
||||||
]
|
]
|
||||||
|
|
||||||
# dataset summary
|
# dataset summary
|
||||||
|
@ -25,7 +26,7 @@ data = dict(num_views={{_base_.num_crops}})
|
||||||
|
|
||||||
# optimizer
|
# optimizer
|
||||||
optimizer = dict(type='LARS', lr=0.6)
|
optimizer = dict(type='LARS', lr=0.6)
|
||||||
optimizer_config = dict(frozen_layers_cfg=dict(prototypes=5005))
|
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer)
|
||||||
|
|
||||||
# learning policy
|
# learning policy
|
||||||
param_scheduler = [
|
param_scheduler = [
|
||||||
|
@ -35,11 +36,14 @@ param_scheduler = [
|
||||||
eta_min=6e-4,
|
eta_min=6e-4,
|
||||||
by_epoch=True,
|
by_epoch=True,
|
||||||
begin=0,
|
begin=0,
|
||||||
end=200)
|
end=200,
|
||||||
|
convert_to_iter_based=True)
|
||||||
]
|
]
|
||||||
|
|
||||||
# runtime settings
|
# runtime settings
|
||||||
# the max_keep_ckpts controls the max number of ckpt file in your work_dirs
|
default_hooks = dict(
|
||||||
# if it is 3, when CheckpointHook (in mmcv) saves the 4th ckpt
|
logger=dict(type='LoggerHook', interval=50),
|
||||||
# it will remove the oldest one to keep the number of total ckpts as 3
|
# only keeps the latest 3 checkpoints
|
||||||
checkpoint_config = dict(interval=10, max_keep_ckpts=3)
|
checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=3))
|
||||||
|
|
||||||
|
find_unused_parameters = True
|
||||||
|
|
|
@ -5,6 +5,7 @@ from typing import Dict, List, Optional, Sequence
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from mmengine.hooks import Hook
|
from mmengine.hooks import Hook
|
||||||
|
from mmengine.logging import MMLogger
|
||||||
|
|
||||||
from mmselfsup.registry import HOOKS
|
from mmselfsup.registry import HOOKS
|
||||||
|
|
||||||
|
@ -64,7 +65,10 @@ class SwAVHook(Hook):
|
||||||
# build the queue
|
# build the queue
|
||||||
if osp.isfile(self.queue_path):
|
if osp.isfile(self.queue_path):
|
||||||
self.queue = torch.load(self.queue_path)['queue']
|
self.queue = torch.load(self.queue_path)['queue']
|
||||||
runner.model.module.head.queue = self.queue
|
runner.model.module.head.loss.queue = self.queue
|
||||||
|
MMLogger.get_current_instance().info(
|
||||||
|
f'Load queue from file: {self.queue_path}')
|
||||||
|
|
||||||
# the queue needs to be divisible by the batch size
|
# the queue needs to be divisible by the batch size
|
||||||
self.queue_length -= self.queue_length % self.batch_size
|
self.queue_length -= self.queue_length % self.batch_size
|
||||||
|
|
||||||
|
@ -96,11 +100,11 @@ class SwAVHook(Hook):
|
||||||
).cuda()
|
).cuda()
|
||||||
|
|
||||||
# set the boolean type of use_the_queue
|
# set the boolean type of use_the_queue
|
||||||
runner.model.module.head.queue = self.queue
|
runner.model.module.head.loss.queue = self.queue
|
||||||
runner.model.module.head.use_queue = False
|
runner.model.module.head.loss.use_queue = False
|
||||||
|
|
||||||
def after_train_epoch(self, runner) -> None:
|
def after_train_epoch(self, runner) -> None:
|
||||||
self.queue = runner.model.module.head.queue
|
self.queue = runner.model.module.head.loss.queue
|
||||||
|
|
||||||
if self.queue is not None and self.every_n_epochs(
|
if self.queue is not None and self.every_n_epochs(
|
||||||
runner, self.interval):
|
runner, self.interval):
|
||||||
|
|
|
@ -1,92 +1,62 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from mmselfsup.core import SelfSupDataSample
|
from mmselfsup.core import SelfSupDataSample
|
||||||
from ..builder import ALGORITHMS, build_backbone, build_loss, build_neck
|
from mmselfsup.registry import MODELS
|
||||||
from .base import BaseModel
|
from .base import BaseModel
|
||||||
|
|
||||||
|
|
||||||
@ALGORITHMS.register_module()
|
@MODELS.register_module()
|
||||||
class SwAV(BaseModel):
|
class SwAV(BaseModel):
|
||||||
"""SwAV.
|
"""SwAV.
|
||||||
|
|
||||||
Implementation of `Unsupervised Learning of Visual Features by Contrasting
|
Implementation of `Unsupervised Learning of Visual Features by Contrasting
|
||||||
Cluster Assignments <https://arxiv.org/abs/2006.09882>`_.
|
Cluster Assignments <https://arxiv.org/abs/2006.09882>`_. The queue is
|
||||||
The queue is built in `core/hooks/swav_hook.py`.
|
built in `core/hooks/swav_hook.py`.
|
||||||
|
|
||||||
Args:
|
|
||||||
backbone (Dict, optional): Config dict for module of backbone.
|
|
||||||
neck (Dict, optional): Config dict for module of deep features
|
|
||||||
to compact
|
|
||||||
feature vectors. Defaults to None.
|
|
||||||
loss (Dict, optional): Config dict for module of loss functions.
|
|
||||||
Defaults to None.
|
|
||||||
preprocess_cfg (Dict, optional): Config dict to preprocess images.
|
|
||||||
Defaults to None.
|
|
||||||
init_cfg (Dict or List[Dict], optional): Config dict for weight
|
|
||||||
initialization. Defaults to None.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def extract_feat(self, batch_inputs: List[torch.Tensor],
|
||||||
backbone: Optional[Dict] = None,
|
|
||||||
neck: Optional[Dict] = None,
|
|
||||||
loss: Optional[Dict] = None,
|
|
||||||
preprocess_cfg: Optional[Dict] = None,
|
|
||||||
init_cfg: Optional[Union[Dict, List[Dict]]] = None,
|
|
||||||
**kwargs) -> None:
|
|
||||||
super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
|
|
||||||
assert backbone is not None
|
|
||||||
self.backbone = build_backbone(backbone)
|
|
||||||
assert neck is not None
|
|
||||||
self.neck = build_neck(neck)
|
|
||||||
assert loss is not None
|
|
||||||
self.loss = build_loss(loss)
|
|
||||||
|
|
||||||
def extract_feat(self, inputs: List[torch.Tensor],
|
|
||||||
data_samples: List[SelfSupDataSample],
|
|
||||||
**kwargs) -> Tuple[torch.Tensor]:
|
**kwargs) -> Tuple[torch.Tensor]:
|
||||||
"""Function to extract features from backbone.
|
"""Function to extract features from backbone.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
inputs (List[torch.Tensor]): The input images.
|
batch_inputs (List[torch.Tensor]): The input images.
|
||||||
data_samples (List[SelfSupDataSample]): All elements required
|
|
||||||
during the forward function.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[torch.Tensor]: backbone outputs.
|
Tuple[torch.Tensor]: backbone outputs.
|
||||||
"""
|
"""
|
||||||
x = self.backbone(inputs[0])
|
x = self.backbone(batch_inputs[0])
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward_train(self, inputs: List[torch.Tensor],
|
def loss(self, batch_inputs: List[torch.Tensor],
|
||||||
data_samples: List[SelfSupDataSample],
|
data_samples: List[SelfSupDataSample],
|
||||||
**kwargs) -> Dict[str, torch.Tensor]:
|
**kwargs) -> Dict[str, torch.Tensor]:
|
||||||
"""Forward computation during training.
|
"""Forward computation during training.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
inputs (List[torch.Tensor]): The input images.
|
batch_inputs (List[torch.Tensor]): The input images.
|
||||||
data_samples (List[SelfSupDataSample]): All elements required
|
data_samples (List[SelfSupDataSample]): All elements required
|
||||||
during the forward function.
|
during the forward function.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict[str, torch.Tensor]: A dictionary of loss components.
|
Dict[str, torch.Tensor]: A dictionary of loss components.
|
||||||
"""
|
"""
|
||||||
assert isinstance(inputs, list)
|
assert isinstance(batch_inputs, list)
|
||||||
# multi-res forward passes
|
# multi-res forward passes
|
||||||
idx_crops = torch.cumsum(
|
idx_crops = torch.cumsum(
|
||||||
torch.unique_consecutive(
|
torch.unique_consecutive(
|
||||||
torch.tensor([input.shape[-1] for input in inputs]),
|
torch.tensor([input.shape[-1] for input in batch_inputs]),
|
||||||
return_counts=True)[1], 0)
|
return_counts=True)[1], 0)
|
||||||
start_idx = 0
|
start_idx = 0
|
||||||
output = []
|
output = []
|
||||||
for end_idx in idx_crops:
|
for end_idx in idx_crops:
|
||||||
_out = self.backbone(torch.cat(inputs[start_idx:end_idx]))
|
_out = self.backbone(torch.cat(batch_inputs[start_idx:end_idx]))
|
||||||
output.append(_out)
|
output.append(_out)
|
||||||
start_idx = end_idx
|
start_idx = end_idx
|
||||||
output = self.neck(output)[0]
|
output = self.neck(output)[0]
|
||||||
|
|
||||||
loss = self.loss(output)
|
loss = self.head(output)
|
||||||
losses = dict(loss=loss)
|
losses = dict(loss=loss)
|
||||||
return losses
|
return losses
|
||||||
|
|
|
@ -7,9 +7,11 @@ from .mae_head import MAEFinetuneHead, MAELinprobeHead, MAEPretrainHead
|
||||||
from .mocov3_head import MoCoV3Head
|
from .mocov3_head import MoCoV3Head
|
||||||
from .multi_cls_head import MultiClsHead
|
from .multi_cls_head import MultiClsHead
|
||||||
from .simmim_head import SimMIMHead
|
from .simmim_head import SimMIMHead
|
||||||
|
from .swav_head import SwAVHead
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'ContrastiveHead', 'ClsHead', 'LatentPredictHead',
|
'ContrastiveHead', 'ClsHead', 'LatentPredictHead',
|
||||||
'LatentCrossCorrelationHead', 'MultiClsHead', 'MAEFinetuneHead',
|
'LatentCrossCorrelationHead', 'MultiClsHead', 'MAEFinetuneHead',
|
||||||
'MAEPretrainHead', 'MoCoV3Head', 'SimMIMHead', 'CAEHead', 'MAELinprobeHead'
|
'MAEPretrainHead', 'MoCoV3Head', 'SimMIMHead', 'CAEHead',
|
||||||
|
'MAELinprobeHead', 'SwAVHead'
|
||||||
]
|
]
|
||||||
|
|
|
@ -0,0 +1,31 @@
|
||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import torch
|
||||||
|
from mmengine.model import BaseModule
|
||||||
|
|
||||||
|
from ..builder import MODELS
|
||||||
|
|
||||||
|
|
||||||
|
@MODELS.register_module()
|
||||||
|
class SwAVHead(BaseModule):
|
||||||
|
"""Head for SwAV.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loss (dict): Config dict for module of loss functions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, loss: dict) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.loss = MODELS.build(loss)
|
||||||
|
|
||||||
|
def forward(self, pred: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Forward function of SwAV head.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pred (torch.Tensor): NxC input features.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The SwAV loss.
|
||||||
|
"""
|
||||||
|
loss = self.loss(pred)
|
||||||
|
|
||||||
|
return loss
|
|
@ -8,11 +8,11 @@ import torch.nn as nn
|
||||||
from mmcv.runner import BaseModule
|
from mmcv.runner import BaseModule
|
||||||
|
|
||||||
from mmselfsup.utils import distributed_sinkhorn
|
from mmselfsup.utils import distributed_sinkhorn
|
||||||
from ..builder import LOSSES
|
from ..builder import MODELS
|
||||||
from ..utils import MultiPrototypes
|
from ..utils import MultiPrototypes
|
||||||
|
|
||||||
|
|
||||||
@LOSSES.register_module()
|
@MODELS.register_module()
|
||||||
class SwAVLoss(BaseModule):
|
class SwAVLoss(BaseModule):
|
||||||
"""The Loss for SwAV.
|
"""The Loss for SwAV.
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@ from torch.utils.data import Dataset
|
||||||
from mmselfsup.core.data_structures import SelfSupDataSample
|
from mmselfsup.core.data_structures import SelfSupDataSample
|
||||||
from mmselfsup.core.hooks import SwAVHook
|
from mmselfsup.core.hooks import SwAVHook
|
||||||
from mmselfsup.models.algorithms import BaseModel
|
from mmselfsup.models.algorithms import BaseModel
|
||||||
from mmselfsup.models.losses import SwAVLoss
|
from mmselfsup.models.heads import SwAVHead
|
||||||
from mmselfsup.registry import MODELS
|
from mmselfsup.registry import MODELS
|
||||||
|
|
||||||
|
|
||||||
|
@ -53,7 +53,12 @@ class ToyModel(BaseModel):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(backbone=dict(type='SwAVDummyLayer'))
|
super().__init__(backbone=dict(type='SwAVDummyLayer'))
|
||||||
self.prototypes_test = nn.Linear(1, 1)
|
self.prototypes_test = nn.Linear(1, 1)
|
||||||
self.head = SwAVLoss(feat_dim=2, num_crops=[2, 6], num_prototypes=3)
|
self.head = SwAVHead(
|
||||||
|
loss=dict(
|
||||||
|
type='SwAVLoss',
|
||||||
|
feat_dim=2,
|
||||||
|
num_crops=[2, 6],
|
||||||
|
num_prototypes=3))
|
||||||
|
|
||||||
def loss(self, batch_inputs, data_samples):
|
def loss(self, batch_inputs, data_samples):
|
||||||
labels = []
|
labels = []
|
||||||
|
@ -119,4 +124,4 @@ class TestSwAVHook(TestCase):
|
||||||
if isinstance(hook, SwAVHook):
|
if isinstance(hook, SwAVHook):
|
||||||
assert hook.queue_length == 300
|
assert hook.queue_length == 300
|
||||||
|
|
||||||
assert runner.model.module.head.use_queue is False
|
assert runner.model.module.head.loss.use_queue is False
|
||||||
|
|
|
@ -24,44 +24,29 @@ neck = dict(
|
||||||
out_channels=2,
|
out_channels=2,
|
||||||
norm_cfg=dict(type='BN1d'),
|
norm_cfg=dict(type='BN1d'),
|
||||||
with_avg_pool=True)
|
with_avg_pool=True)
|
||||||
loss = dict(
|
head = dict(
|
||||||
type='SwAVLoss',
|
type='SwAVHead',
|
||||||
feat_dim=2, # equal to neck['out_channels']
|
loss=dict(
|
||||||
epsilon=0.05,
|
type='SwAVLoss',
|
||||||
temperature=0.1,
|
feat_dim=2, # equal to neck['out_channels']
|
||||||
num_crops=nmb_crops)
|
epsilon=0.05,
|
||||||
|
temperature=0.1,
|
||||||
|
num_crops=nmb_crops))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
||||||
def test_swav():
|
def test_swav():
|
||||||
preprocess_cfg = {
|
data_preprocessor = {
|
||||||
'mean': [0.5, 0.5, 0.5],
|
'mean': (123.675, 116.28, 103.53),
|
||||||
'std': [0.5, 0.5, 0.5],
|
'std': (58.395, 57.12, 57.375),
|
||||||
'to_rgb': True
|
'bgr_to_rgb': True
|
||||||
}
|
}
|
||||||
with pytest.raises(AssertionError):
|
|
||||||
alg = SwAV(
|
|
||||||
backbone=backbone,
|
|
||||||
neck=neck,
|
|
||||||
loss=None,
|
|
||||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
|
||||||
with pytest.raises(AssertionError):
|
|
||||||
alg = SwAV(
|
|
||||||
backbone=backbone,
|
|
||||||
neck=None,
|
|
||||||
loss=loss,
|
|
||||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
|
||||||
with pytest.raises(AssertionError):
|
|
||||||
alg = SwAV(
|
|
||||||
backbone=None,
|
|
||||||
neck=neck,
|
|
||||||
loss=loss,
|
|
||||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
|
||||||
alg = SwAV(
|
alg = SwAV(
|
||||||
backbone=backbone,
|
backbone=backbone,
|
||||||
neck=neck,
|
neck=neck,
|
||||||
loss=loss,
|
head=head,
|
||||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
data_preprocessor=copy.deepcopy(data_preprocessor))
|
||||||
|
|
||||||
fake_data = [{
|
fake_data = [{
|
||||||
'inputs': [
|
'inputs': [
|
||||||
|
@ -78,10 +63,9 @@ def test_swav():
|
||||||
SelfSupDataSample()
|
SelfSupDataSample()
|
||||||
} for _ in range(2)]
|
} for _ in range(2)]
|
||||||
|
|
||||||
fake_outputs = alg(fake_data, return_loss=True)
|
fake_batch_inputs, fake_data_samples = alg.data_preprocessor(fake_data)
|
||||||
|
fake_outputs = alg(fake_batch_inputs, fake_data_samples, mode='loss')
|
||||||
assert isinstance(fake_outputs['loss'].item(), float)
|
assert isinstance(fake_outputs['loss'].item(), float)
|
||||||
|
|
||||||
fake_inputs, fake_data_samples = alg.preprocss_data(fake_data)
|
fake_feat = alg(fake_batch_inputs, fake_data_samples, mode='tensor')
|
||||||
fake_feat = alg.extract_feat(
|
|
||||||
inputs=fake_inputs, data_samples=fake_data_samples)
|
|
||||||
assert list(fake_feat[0].shape) == [2, 512, 7, 7]
|
assert list(fake_feat[0].shape) == [2, 512, 7, 7]
|
||||||
|
|
Loading…
Reference in New Issue