[Refactor] add UT codes
parent
f11be17d58
commit
8f9ff74736
Binary file not shown.
After Width: | Height: | Size: 35 KiB |
|
@ -0,0 +1,2 @@
|
|||
color.jpg 0
|
||||
gray.jpg 1
|
Binary file not shown.
After Width: | Height: | Size: 38 KiB |
|
@ -0,0 +1,46 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from mmselfsup.datasets import DATASOURCES
|
||||
|
||||
|
||||
@pytest.mark.parametrize('dataset_name',
|
||||
['CIFAR10', 'CIFAR100', 'ImageNet', 'ImageList'])
|
||||
def test_data_sources_override_default(dataset_name):
|
||||
dataset_class = DATASOURCES.get(dataset_name)
|
||||
load_annotations_f = dataset_class.load_annotations
|
||||
dataset_class.load_annotations = MagicMock()
|
||||
|
||||
original_classes = dataset_class.CLASSES
|
||||
|
||||
# Test setting classes as a tuple
|
||||
dataset = dataset_class(data_prefix='', classes=('bus', 'car'))
|
||||
assert dataset.CLASSES == ('bus', 'car')
|
||||
|
||||
# Test setting classes as a list
|
||||
dataset = dataset_class(data_prefix='', classes=['bus', 'car'])
|
||||
assert dataset.CLASSES == ['bus', 'car']
|
||||
|
||||
# Test setting classes through a file
|
||||
tmp_file = tempfile.NamedTemporaryFile()
|
||||
with open(tmp_file.name, 'w') as f:
|
||||
f.write('bus\ncar\n')
|
||||
dataset = dataset_class(data_prefix='', classes=tmp_file.name)
|
||||
tmp_file.close()
|
||||
|
||||
assert dataset.CLASSES == ['bus', 'car']
|
||||
|
||||
# Test overriding not a subset
|
||||
dataset = dataset_class(data_prefix='', classes=['foo'])
|
||||
assert dataset.CLASSES == ['foo']
|
||||
|
||||
# Test default behavior
|
||||
dataset = dataset_class(data_prefix='')
|
||||
assert dataset.data_prefix == ''
|
||||
assert dataset.ann_file is None
|
||||
assert dataset.CLASSES == original_classes
|
||||
|
||||
dataset_class.load_annotations = load_annotations_f
|
|
@ -0,0 +1,19 @@
|
|||
import os.path as osp
|
||||
|
||||
import pytest
|
||||
|
||||
from mmselfsup.datasets.data_sources import ImageList
|
||||
|
||||
|
||||
def test_image_list():
|
||||
data_source = dict(
|
||||
data_prefix=osp.join(osp.dirname(__file__), '../../data'),
|
||||
ann_file=osp.join(osp.dirname(__file__), '../../data/data_list.txt'),
|
||||
)
|
||||
|
||||
dataset = ImageList(**data_source)
|
||||
assert len(dataset) == 2
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
dataset = ImageList(
|
||||
data_prefix=osp.join(osp.dirname(__file__), '../../data'), )
|
|
@ -0,0 +1,18 @@
|
|||
import os.path as osp
|
||||
|
||||
import pytest
|
||||
|
||||
from mmselfsup.datasets.data_sources import ImageNet
|
||||
|
||||
|
||||
def test_imagenet():
|
||||
data_source = dict(data_prefix=osp.join(osp.dirname(__file__), '../../'))
|
||||
|
||||
dataset = ImageNet(**data_source)
|
||||
assert len(dataset) == 2
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
dataset = ImageNet(ann_file=1, **data_source)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
dataset = ImageNet(data_prefix=osp.join(osp.dirname(__file__)))
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from mmselfsup.datasets import BaseDataset, ConcatDataset, RepeatDataset
|
||||
|
||||
|
||||
@patch.multiple(BaseDataset, __abstractmethods__=set())
|
||||
def construct_toy_dataset():
|
||||
BaseDataset.CLASSES = ('foo', 'bar')
|
||||
BaseDataset.__getitem__ = MagicMock(side_effect=lambda idx: idx)
|
||||
data = dict(
|
||||
data_source=dict(
|
||||
type='ImageNet',
|
||||
data_prefix=osp.join(osp.dirname(__file__), '../../data'),
|
||||
ann_file=osp.join(
|
||||
osp.dirname(__file__), '../../data/data_list.txt'),
|
||||
),
|
||||
pipeline=[])
|
||||
dataset = BaseDataset(**data)
|
||||
dataset.data_infos = MagicMock()
|
||||
return dataset
|
||||
|
||||
|
||||
def test_concat_dataset():
|
||||
dataset_a = construct_toy_dataset()
|
||||
dataset_b = construct_toy_dataset()
|
||||
|
||||
concat_dataset = ConcatDataset([dataset_a, dataset_b])
|
||||
assert concat_dataset[0] == 0
|
||||
assert concat_dataset[3] == 1
|
||||
assert len(concat_dataset) == len(dataset_a) + len(dataset_b)
|
||||
|
||||
|
||||
def test_repeat_dataset():
|
||||
dataset = construct_toy_dataset()
|
||||
|
||||
repeat_dataset = RepeatDataset(dataset, 10)
|
||||
assert repeat_dataset[5] == 1
|
||||
assert repeat_dataset[10] == 0
|
||||
assert len(repeat_dataset) == 10 * len(dataset)
|
|
@ -0,0 +1,45 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
|
||||
import pytest
|
||||
|
||||
from mmselfsup.datasets import DeepClusterDataset
|
||||
|
||||
# dataset settings
|
||||
data_source = 'ImageNet'
|
||||
dataset_type = 'DeepClusterDataset'
|
||||
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
train_pipeline = [dict(type='RandomResizedCrop', size=4)]
|
||||
# prefetch
|
||||
prefetch = False
|
||||
if not prefetch:
|
||||
train_pipeline.extend(
|
||||
[dict(type='ToTensor'),
|
||||
dict(type='Normalize', **img_norm_cfg)])
|
||||
|
||||
|
||||
def test_deepcluster_dataset():
|
||||
data = dict(
|
||||
data_source=dict(
|
||||
type=data_source,
|
||||
data_prefix=osp.join(osp.dirname(__file__), '../../data'),
|
||||
ann_file=osp.join(
|
||||
osp.dirname(__file__), '../../data/data_list.txt'),
|
||||
),
|
||||
pipeline=train_pipeline,
|
||||
prefetch=prefetch)
|
||||
dataset = DeepClusterDataset(**data)
|
||||
x = dataset[0]
|
||||
assert x['img'].size() == (3, 4, 4)
|
||||
assert x['pseudo_label'] == -1
|
||||
assert x['idx'] == 0
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
dataset.assign_labels([1])
|
||||
|
||||
dataset.assign_labels([1, 0])
|
||||
assert dataset.clustering_labels[0] == 1
|
||||
assert dataset.clustering_labels[1] == 0
|
||||
|
||||
x = dataset[0]
|
||||
assert x['pseudo_label'] == 1
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
|
||||
import pytest
|
||||
|
||||
from mmselfsup.datasets import MultiViewDataset
|
||||
|
||||
# dataset settings
|
||||
data_source = 'ImageNet'
|
||||
dataset_type = 'MultiViewDataset'
|
||||
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
train_pipeline = [dict(type='RandomResizedCrop', size=4)]
|
||||
# prefetch
|
||||
prefetch = False
|
||||
if not prefetch:
|
||||
train_pipeline.extend(
|
||||
[dict(type='ToTensor'),
|
||||
dict(type='Normalize', **img_norm_cfg)])
|
||||
|
||||
|
||||
def test_multi_views_dataste():
|
||||
data = dict(
|
||||
data_source=dict(
|
||||
type=data_source,
|
||||
data_prefix=osp.join(osp.dirname(__file__), '../../data'),
|
||||
ann_file=osp.join(
|
||||
osp.dirname(__file__), '../../data/data_list.txt'),
|
||||
),
|
||||
num_views=[2],
|
||||
pipelines=[train_pipeline, train_pipeline],
|
||||
prefetch=prefetch)
|
||||
with pytest.raises(AssertionError):
|
||||
dataset = MultiViewDataset(**data)
|
||||
|
||||
# test dataset
|
||||
data = dict(
|
||||
data_source=dict(
|
||||
type=data_source,
|
||||
data_prefix=osp.join(osp.dirname(__file__), '../../data'),
|
||||
ann_file=osp.join(
|
||||
osp.dirname(__file__), '../../data/data_list.txt'),
|
||||
),
|
||||
num_views=[2, 6],
|
||||
pipelines=[train_pipeline, train_pipeline],
|
||||
prefetch=prefetch)
|
||||
dataset = MultiViewDataset(**data)
|
||||
x = dataset[0]
|
||||
assert isinstance(x['img'], list)
|
||||
assert len(x['img']) == 8
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mmselfsup.datasets import RelativeLocDataset
|
||||
|
||||
# dataset settings
|
||||
data_source = 'ImageNet'
|
||||
dataset_type = 'RelativeLocDataset'
|
||||
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
train_pipeline = [dict(type='RandomResizedCrop', size=224)]
|
||||
# prefetch
|
||||
format_pipeline = [
|
||||
dict(type='ToTensor'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
]
|
||||
|
||||
|
||||
def test_relative_loc_dataset():
|
||||
# prefetch False
|
||||
data = dict(
|
||||
data_source=dict(
|
||||
type=data_source,
|
||||
data_prefix=osp.join(osp.dirname(__file__), '../../data'),
|
||||
ann_file=osp.join(
|
||||
osp.dirname(__file__), '../../data/data_list.txt'),
|
||||
),
|
||||
pipeline=train_pipeline,
|
||||
format_pipeline=format_pipeline)
|
||||
dataset = RelativeLocDataset(**data)
|
||||
x = dataset[0]
|
||||
split_per_side = 3
|
||||
patch_jitter = 21
|
||||
h_grid = 224 // split_per_side
|
||||
w_grid = 224 // split_per_side
|
||||
h_patch = h_grid - patch_jitter
|
||||
w_patch = w_grid - patch_jitter
|
||||
assert x['img'].size() == (8, 6, h_patch, w_patch)
|
||||
assert (x['patch_label'].numpy() == np.array([0, 1, 2, 3, 4, 5, 6,
|
||||
7])).all()
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mmselfsup.datasets import RotationPredDataset
|
||||
|
||||
# dataset settings
|
||||
data_source = 'ImageNet'
|
||||
dataset_type = 'RotationPredDataset'
|
||||
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
train_pipeline = [dict(type='RandomResizedCrop', size=4)]
|
||||
# prefetch
|
||||
prefetch = False
|
||||
if not prefetch:
|
||||
train_pipeline.extend(
|
||||
[dict(type='ToTensor'),
|
||||
dict(type='Normalize', **img_norm_cfg)])
|
||||
|
||||
|
||||
def test_rotation_pred_dataset():
|
||||
# prefetch False
|
||||
data = dict(
|
||||
data_source=dict(
|
||||
type=data_source,
|
||||
data_prefix=osp.join(osp.dirname(__file__), '../../data'),
|
||||
ann_file=osp.join(
|
||||
osp.dirname(__file__), '../../data/data_list.txt'),
|
||||
),
|
||||
pipeline=train_pipeline,
|
||||
prefetch=prefetch)
|
||||
dataset = RotationPredDataset(**data)
|
||||
x = dataset[0]
|
||||
assert x['img'].size() == (4, 3, 4, 4)
|
||||
assert (x['rot_label'].numpy() == np.array([0, 1, 2, 3])).all()
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mmselfsup.datasets import SingleViewDataset
|
||||
|
||||
# dataset settings
|
||||
data_source = 'ImageNet'
|
||||
dataset_type = 'MultiViewDataset'
|
||||
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
train_pipeline = [dict(type='RandomResizedCrop', size=4)]
|
||||
# prefetch
|
||||
prefetch = False
|
||||
if not prefetch:
|
||||
train_pipeline.extend(
|
||||
[dict(type='ToTensor'),
|
||||
dict(type='Normalize', **img_norm_cfg)])
|
||||
|
||||
|
||||
def test_one_view_dataset():
|
||||
data = dict(
|
||||
data_source=dict(
|
||||
type=data_source,
|
||||
data_prefix=osp.join(osp.dirname(__file__), '../../data'),
|
||||
ann_file=osp.join(
|
||||
osp.dirname(__file__), '../../data/data_list.txt'),
|
||||
),
|
||||
pipeline=train_pipeline,
|
||||
prefetch=prefetch)
|
||||
dataset = SingleViewDataset(**data)
|
||||
fake_results = {'test': np.array([[0.7, 0, 0.3], [0.5, 0.3, 0.2]])}
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
eval_res = dataset.evaluate({'test': np.array([[0.7, 0, 0.3]])},
|
||||
topk=(1))
|
||||
|
||||
eval_res = dataset.evaluate(fake_results, topk=(1, 2))
|
||||
assert eval_res['test_top1'] == 1 * 100.0 / 2
|
||||
assert eval_res['test_top2'] == 2 * 100.0 / 2
|
|
@ -0,0 +1,99 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from mmcv.utils import build_from_cfg
|
||||
from PIL import Image
|
||||
|
||||
from mmselfsup.datasets.builder import PIPELINES
|
||||
|
||||
|
||||
def test_random_applied_trans():
|
||||
img = Image.open(osp.join(osp.dirname(__file__), '../data/color.jpg'))
|
||||
|
||||
# p=0.5
|
||||
transform = dict(
|
||||
type='RandomAppliedTrans', transforms=[dict(type='Solarization')])
|
||||
module = build_from_cfg(transform, PIPELINES)
|
||||
res = module(img)
|
||||
assert img.size == res.size
|
||||
|
||||
transform = dict(
|
||||
type='RandomAppliedTrans',
|
||||
transforms=[dict(type='Solarization')],
|
||||
p=0.)
|
||||
module = build_from_cfg(transform, PIPELINES)
|
||||
res = module(img)
|
||||
assert img.size == res.size
|
||||
|
||||
# p=1.
|
||||
transform = dict(
|
||||
type='RandomAppliedTrans',
|
||||
transforms=[dict(type='Solarization')],
|
||||
p=1.)
|
||||
module = build_from_cfg(transform, PIPELINES)
|
||||
res = module(img)
|
||||
assert img.size == res.size
|
||||
|
||||
|
||||
def test_lighting():
|
||||
transform = dict(type='Lighting')
|
||||
module = build_from_cfg(transform, PIPELINES)
|
||||
img = np.array(
|
||||
Image.open(osp.join(osp.dirname(__file__), '../data/color.jpg')))
|
||||
with pytest.raises(AssertionError):
|
||||
res = module(img)
|
||||
|
||||
img = torch.from_numpy(img).float().permute(2, 1, 0)
|
||||
res = module(img)
|
||||
|
||||
assert img.size() == res.size()
|
||||
|
||||
|
||||
def test_gaussianblur():
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='GaussianBlur', sigma_min=0.1, sigma_max=1.0, p=-1)
|
||||
module = build_from_cfg(transform, PIPELINES)
|
||||
|
||||
img = Image.open(osp.join(osp.dirname(__file__), '../data/color.jpg'))
|
||||
|
||||
# p=0.5
|
||||
transform = dict(type='GaussianBlur', sigma_min=0.1, sigma_max=1.0)
|
||||
module = build_from_cfg(transform, PIPELINES)
|
||||
res = module(img)
|
||||
|
||||
transform = dict(type='GaussianBlur', sigma_min=0.1, sigma_max=1.0, p=0.)
|
||||
module = build_from_cfg(transform, PIPELINES)
|
||||
res = module(img)
|
||||
|
||||
transform = dict(type='GaussianBlur', sigma_min=0.1, sigma_max=1.0, p=1.)
|
||||
module = build_from_cfg(transform, PIPELINES)
|
||||
res = module(img)
|
||||
|
||||
assert img.size == res.size
|
||||
|
||||
|
||||
def test_solarization():
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='Solarization', p=-1)
|
||||
module = build_from_cfg(transform, PIPELINES)
|
||||
|
||||
img = Image.open(osp.join(osp.dirname(__file__), '../data/color.jpg'))
|
||||
|
||||
# p=0.5
|
||||
transform = dict(type='Solarization')
|
||||
module = build_from_cfg(transform, PIPELINES)
|
||||
res = module(img)
|
||||
|
||||
transform = dict(type='Solarization', p=0.)
|
||||
module = build_from_cfg(transform, PIPELINES)
|
||||
res = module(img)
|
||||
|
||||
transform = dict(type='Solarization', p=1.)
|
||||
module = build_from_cfg(transform, PIPELINES)
|
||||
res = module(img)
|
||||
|
||||
assert img.size == res.size
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import random
|
||||
import string
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from mmselfsup.datasets.utils import check_integrity, rm_suffix, to_numpy
|
||||
|
||||
|
||||
def test_to_numpy():
|
||||
pil_img = Image.open(osp.join(osp.dirname(__file__), '../data/color.jpg'))
|
||||
np_img = to_numpy(pil_img)
|
||||
assert type(np_img) == np.ndarray
|
||||
if np_img.ndim < 3:
|
||||
assert np_img.shape[0] == 1
|
||||
elif np_img.ndim == 3:
|
||||
assert np_img.shape[0] == 3
|
||||
|
||||
|
||||
def test_dataset_utils():
|
||||
# test rm_suffix
|
||||
assert rm_suffix('a.jpg') == 'a'
|
||||
assert rm_suffix('a.bak.jpg') == 'a.bak'
|
||||
assert rm_suffix('a.bak.jpg', suffix='.jpg') == 'a.bak'
|
||||
assert rm_suffix('a.bak.jpg', suffix='.bak.jpg') == 'a'
|
||||
|
||||
# test check_integrity
|
||||
rand_file = ''.join(random.sample(string.ascii_letters, 10))
|
||||
assert not check_integrity(rand_file, md5=None)
|
||||
assert not check_integrity(rand_file, md5=2333)
|
||||
tmp_file = tempfile.NamedTemporaryFile()
|
||||
assert check_integrity(tmp_file.name, md5=None)
|
||||
assert not check_integrity(tmp_file.name, md5=2333)
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.utils import Accuracy
|
||||
|
||||
|
||||
def test_accuracy():
|
||||
pred = torch.Tensor([[0.2, 0.3, 0.5], [0.25, 0.15, 0.6], [0.9, 0.05, 0.05],
|
||||
[0.8, 0.1, 0.1], [0.55, 0.15, 0.3]])
|
||||
target = torch.zeros(5)
|
||||
|
||||
acc = Accuracy((1, 2))
|
||||
res = acc.forward(pred, target)
|
||||
assert res[0].item() == 60.
|
||||
assert res[1].item() == 80.
|
|
@ -0,0 +1,53 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.algorithms import BYOL
|
||||
|
||||
backbone = dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
in_channels=3,
|
||||
out_indices=[4], # 0: conv-1, x: stage-x
|
||||
norm_cfg=dict(type='BN'))
|
||||
neck = dict(
|
||||
type='NonLinearNeck',
|
||||
in_channels=2048,
|
||||
hid_channels=4,
|
||||
out_channels=4,
|
||||
with_bias=True,
|
||||
with_last_bn=False,
|
||||
with_avg_pool=True,
|
||||
norm_cfg=dict(type='BN1d'))
|
||||
head = dict(
|
||||
type='LatentPredictHead',
|
||||
predictor=dict(
|
||||
type='NonLinearNeck',
|
||||
in_channels=4,
|
||||
hid_channels=4,
|
||||
out_channels=4,
|
||||
with_bias=True,
|
||||
with_last_bn=False,
|
||||
with_avg_pool=False,
|
||||
norm_cfg=dict(type='BN1d')))
|
||||
|
||||
|
||||
def test_byol():
|
||||
with pytest.raises(AssertionError):
|
||||
alg = BYOL(backbone=backbone, neck=None, head=head)
|
||||
with pytest.raises(AssertionError):
|
||||
alg = BYOL(backbone=backbone, neck=neck, head=None)
|
||||
|
||||
alg = BYOL(backbone=backbone, neck=neck, head=head)
|
||||
fake_input = torch.randn((16, 3, 224, 224))
|
||||
fake_backbone_out = alg.extract_feat(fake_input)
|
||||
assert fake_backbone_out[0].size() == torch.Size([16, 2048, 7, 7])
|
||||
with pytest.raises(AssertionError):
|
||||
fake_out = alg.forward_train(fake_input)
|
||||
|
||||
fake_input = [
|
||||
torch.randn((16, 3, 224, 224)),
|
||||
torch.randn((16, 3, 224, 224))
|
||||
]
|
||||
fake_out = alg.forward_train(fake_input)
|
||||
assert fake_out['loss'].item() > -4
|
|
@ -0,0 +1,28 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.algorithms import Classification
|
||||
|
||||
with_sobel = True,
|
||||
backbone = dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
in_channels=2,
|
||||
out_indices=[4], # 0: conv-1, x: stage-x
|
||||
norm_cfg=dict(type='BN'),
|
||||
frozen_stages=4)
|
||||
head = dict(
|
||||
type='ClsHead', with_avg_pool=True, in_channels=2048, num_classes=4)
|
||||
|
||||
|
||||
def test_classification():
|
||||
alg = Classification(backbone=backbone, with_sobel=with_sobel, head=head)
|
||||
assert hasattr(alg, 'sobel_layer')
|
||||
assert hasattr(alg, 'head')
|
||||
|
||||
fake_input = torch.randn((16, 3, 224, 224))
|
||||
fake_labels = torch.ones(16, dtype=torch.long)
|
||||
fake_backbone_out = alg.extract_feat(fake_input)
|
||||
assert fake_backbone_out[0].size() == torch.Size([16, 2048, 7, 7])
|
||||
fake_out = alg.forward_train(fake_input, fake_labels)
|
||||
assert fake_out['loss'].item() > 0
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.algorithms import DeepCluster
|
||||
|
||||
num_classes = 5
|
||||
with_sobel = True,
|
||||
backbone = dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
in_channels=2,
|
||||
out_indices=[4], # 0: conv-1, x: stage-x
|
||||
norm_cfg=dict(type='BN'))
|
||||
neck = dict(type='AvgPool2dNeck')
|
||||
head = dict(
|
||||
type='ClsHead',
|
||||
with_avg_pool=False, # already has avgpool in the neck
|
||||
in_channels=2048,
|
||||
num_classes=num_classes)
|
||||
|
||||
|
||||
def test_deepcluster():
|
||||
with pytest.raises(AssertionError):
|
||||
alg = DeepCluster(
|
||||
backbone=backbone, with_sobel=with_sobel, neck=neck, head=None)
|
||||
alg = DeepCluster(
|
||||
backbone=backbone, with_sobel=with_sobel, neck=neck, head=head)
|
||||
assert alg.num_classes == num_classes
|
||||
assert hasattr(alg, 'sobel_layer')
|
||||
assert hasattr(alg, 'neck')
|
||||
assert hasattr(alg, 'head')
|
||||
|
||||
fake_input = torch.randn((16, 3, 224, 224))
|
||||
fake_labels = torch.ones(16, dtype=torch.long)
|
||||
fake_backbone_out = alg.extract_feat(fake_input)
|
||||
assert fake_backbone_out[0].size() == torch.Size([16, 2048, 7, 7])
|
||||
fake_out = alg.forward_train(fake_input, fake_labels)
|
||||
assert fake_out['loss'].item() > 0
|
|
@ -0,0 +1,47 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.algorithms import DenseCL
|
||||
|
||||
queue_len = 65536
|
||||
feat_dim = 128
|
||||
momentum = 0.999
|
||||
loss_lambda = 0.5
|
||||
backbone = dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
in_channels=3,
|
||||
out_indices=[4], # 0: conv-1, x: stage-x
|
||||
norm_cfg=dict(type='BN'))
|
||||
neck = dict(
|
||||
type='DenseCLNeck',
|
||||
in_channels=2048,
|
||||
hid_channels=4,
|
||||
out_channels=4,
|
||||
num_grid=None)
|
||||
head = dict(type='ContrastiveHead', temperature=0.2)
|
||||
|
||||
|
||||
def test_densecl():
|
||||
with pytest.raises(AssertionError):
|
||||
alg = DenseCL(backbone=backbone, neck=None, head=head)
|
||||
with pytest.raises(AssertionError):
|
||||
alg = DenseCL(backbone=backbone, neck=neck, head=None)
|
||||
|
||||
alg = DenseCL(
|
||||
backbone=backbone,
|
||||
neck=neck,
|
||||
head=head,
|
||||
queue_len=queue_len,
|
||||
feat_dim=feat_dim,
|
||||
momentum=momentum,
|
||||
loss_lambda=loss_lambda)
|
||||
assert alg.queue.size() == torch.Size([feat_dim, queue_len])
|
||||
assert alg.queue2.size() == torch.Size([feat_dim, queue_len])
|
||||
|
||||
fake_input = torch.randn((16, 3, 224, 224))
|
||||
fake_backbone_out = alg.extract_feat(fake_input)
|
||||
assert fake_backbone_out[0].size() == torch.Size([16, 2048, 7, 7])
|
||||
with pytest.raises(AssertionError):
|
||||
fake_backbone_out = alg.forward_train(fake_input)
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.algorithms import MoCo
|
||||
|
||||
queue_len = 8
|
||||
feat_dim = 4
|
||||
momentum = 0.999
|
||||
backbone = dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
in_channels=3,
|
||||
out_indices=[4], # 0: conv-1, x: stage-x
|
||||
norm_cfg=dict(type='BN'))
|
||||
neck = dict(
|
||||
type='MoCoV2Neck',
|
||||
in_channels=2048,
|
||||
hid_channels=4,
|
||||
out_channels=4,
|
||||
with_avg_pool=True)
|
||||
head = dict(type='ContrastiveHead', temperature=0.2)
|
||||
|
||||
|
||||
def test_moco():
|
||||
with pytest.raises(AssertionError):
|
||||
alg = MoCo(backbone=backbone, neck=None, head=head)
|
||||
with pytest.raises(AssertionError):
|
||||
alg = MoCo(backbone=backbone, neck=neck, head=None)
|
||||
|
||||
alg = MoCo(
|
||||
backbone=backbone,
|
||||
neck=neck,
|
||||
head=head,
|
||||
queue_len=queue_len,
|
||||
feat_dim=feat_dim,
|
||||
momentum=momentum)
|
||||
assert alg.queue.size() == torch.Size([feat_dim, queue_len])
|
||||
|
||||
fake_input = torch.randn((16, 3, 224, 224))
|
||||
fake_backbone_out = alg.extract_feat(fake_input)
|
||||
assert fake_backbone_out[0].size() == torch.Size([16, 2048, 7, 7])
|
||||
with pytest.raises(AssertionError):
|
||||
fake_backbone_out = alg.forward_train(fake_input)
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.algorithms import NPID
|
||||
|
||||
backbone = dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
in_channels=3,
|
||||
out_indices=[4], # 0: conv-1, x: stage-x
|
||||
norm_cfg=dict(type='BN'))
|
||||
neck = dict(
|
||||
type='LinearNeck', in_channels=2048, out_channels=4, with_avg_pool=True)
|
||||
head = dict(type='ContrastiveHead', temperature=0.07)
|
||||
memory_bank = dict(type='SimpleMemory', length=8, feat_dim=4, momentum=0.5)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason='CUDA is not available.')
|
||||
def test_npid():
|
||||
with pytest.raises(AssertionError):
|
||||
alg = NPID(backbone=backbone, neck=neck, head=head, memory_bank=None)
|
||||
with pytest.raises(AssertionError):
|
||||
alg = NPID(
|
||||
backbone=backbone, neck=neck, head=None, memory_bank=memory_bank)
|
||||
|
||||
alg = NPID(
|
||||
backbone=backbone, neck=neck, head=head, memory_bank=memory_bank)
|
||||
fake_input = torch.randn((16, 3, 224, 224))
|
||||
fake_backbone_out = alg.extract_feat(fake_input)
|
||||
assert fake_backbone_out[0].size() == torch.Size([16, 2048, 7, 7])
|
|
@ -0,0 +1,47 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.algorithms import ODC
|
||||
|
||||
num_classes = 5
|
||||
backbone = dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
in_channels=3,
|
||||
out_indices=[4], # 0: conv-1, x: stage-x
|
||||
norm_cfg=dict(type='BN'))
|
||||
neck = dict(
|
||||
type='ODCNeck',
|
||||
in_channels=2048,
|
||||
hid_channels=4,
|
||||
out_channels=4,
|
||||
with_avg_pool=True)
|
||||
head = dict(
|
||||
type='ClsHead',
|
||||
with_avg_pool=False,
|
||||
in_channels=4,
|
||||
num_classes=num_classes)
|
||||
memory_bank = dict(
|
||||
type='ODCMemory',
|
||||
length=8,
|
||||
feat_dim=4,
|
||||
momentum=0.5,
|
||||
num_classes=num_classes,
|
||||
min_cluster=2,
|
||||
debug=False)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason='CUDA is not available.')
|
||||
def test_odc():
|
||||
with pytest.raises(AssertionError):
|
||||
alg = ODC(backbone=backbone, neck=neck, head=head, memory_bank=None)
|
||||
with pytest.raises(AssertionError):
|
||||
alg = ODC(
|
||||
backbone=backbone, neck=neck, head=None, memory_bank=memory_bank)
|
||||
|
||||
alg = ODC(backbone=backbone, neck=neck, head=head, memory_bank=memory_bank)
|
||||
fake_input = torch.randn((16, 3, 224, 224))
|
||||
fake_backbone_out = alg.extract_feat(fake_input)
|
||||
assert fake_backbone_out[0].size() == torch.Size([16, 2048, 7, 7])
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.algorithms import RelativeLoc
|
||||
|
||||
backbone = dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
in_channels=3,
|
||||
out_indices=[4], # 0: conv-1, x: stage-x
|
||||
norm_cfg=dict(type='BN'))
|
||||
neck = dict(
|
||||
type='RelativeLocNeck',
|
||||
in_channels=2048,
|
||||
out_channels=4,
|
||||
with_avg_pool=True)
|
||||
head = dict(type='ClsHead', with_avg_pool=False, in_channels=4, num_classes=8)
|
||||
|
||||
|
||||
def test_relative_loc():
|
||||
with pytest.raises(AssertionError):
|
||||
alg = RelativeLoc(backbone=backbone, neck=None, head=head)
|
||||
with pytest.raises(AssertionError):
|
||||
alg = RelativeLoc(backbone=backbone, neck=neck, head=None)
|
||||
|
||||
alg = RelativeLoc(backbone=backbone, neck=neck, head=head)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
fake_input = torch.randn((2, 8, 6, 224, 224))
|
||||
patch_labels = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
alg.forward(fake_input, patch_labels)
|
||||
|
||||
# train
|
||||
fake_input = torch.randn((2, 8, 6, 224, 224))
|
||||
patch_labels = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7],
|
||||
[0, 1, 2, 3, 4, 5, 6, 7]])
|
||||
fake_out = alg.forward(fake_input, patch_labels)
|
||||
assert fake_out['loss'].item() > 0
|
||||
|
||||
# test
|
||||
fake_input = torch.randn((2, 8, 6, 224, 224))
|
||||
patch_labels = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7],
|
||||
[0, 1, 2, 3, 4, 5, 6, 7]])
|
||||
fake_out = alg.forward(fake_input, patch_labels, mode='test')
|
||||
assert 'head4' in fake_out
|
||||
|
||||
# extract
|
||||
fake_input = torch.randn((16, 3, 224, 224))
|
||||
fake_backbone_out = alg.forward(fake_input, mode='extract')
|
||||
assert fake_backbone_out[0].size() == torch.Size([16, 2048, 7, 7])
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.algorithms import RotationPred
|
||||
|
||||
backbone = dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
in_channels=3,
|
||||
out_indices=[4], # 0: conv-1, x: stage-x
|
||||
norm_cfg=dict(type='BN'))
|
||||
head = dict(
|
||||
type='ClsHead', with_avg_pool=True, in_channels=2048, num_classes=4)
|
||||
|
||||
|
||||
def test_rotation_pred():
|
||||
with pytest.raises(AssertionError):
|
||||
alg = RotationPred(backbone=backbone, head=None)
|
||||
|
||||
alg = RotationPred(backbone=backbone, head=head)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
fake_input = torch.randn((2, 4, 3, 224, 224))
|
||||
rotation_labels = torch.LongTensor([0, 1, 2, 3])
|
||||
alg.forward(fake_input, rotation_labels)
|
||||
|
||||
# train
|
||||
fake_input = torch.randn((2, 4, 3, 224, 224))
|
||||
rotation_labels = torch.LongTensor([[0, 1, 2, 3], [0, 1, 2, 3]])
|
||||
fake_out = alg.forward(fake_input, rotation_labels)
|
||||
assert fake_out['loss'].item() > 0
|
||||
|
||||
# test
|
||||
fake_input = torch.randn((2, 4, 3, 224, 224))
|
||||
rotation_labels = torch.LongTensor([[0, 1, 2, 3], [0, 1, 2, 3]])
|
||||
fake_out = alg.forward(fake_input, rotation_labels, mode='test')
|
||||
assert 'head4' in fake_out
|
||||
|
||||
# extract
|
||||
fake_input = torch.randn((16, 3, 224, 224))
|
||||
fake_backbone_out = alg.forward(fake_input, mode='extract')
|
||||
assert fake_backbone_out[0].size() == torch.Size([16, 2048, 7, 7])
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.algorithms import SimCLR
|
||||
|
||||
backbone = dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
in_channels=3,
|
||||
out_indices=[4], # 0: conv-1, x: stage-x
|
||||
norm_cfg=dict(type='BN'))
|
||||
neck = dict(
|
||||
type='NonLinearNeck', # SimCLR non-linear neck
|
||||
in_channels=2048,
|
||||
hid_channels=4,
|
||||
out_channels=4,
|
||||
num_layers=2,
|
||||
with_avg_pool=True)
|
||||
head = dict(type='ContrastiveHead', temperature=0.1)
|
||||
|
||||
|
||||
def test_simclr():
|
||||
with pytest.raises(AssertionError):
|
||||
alg = SimCLR(backbone=backbone, neck=None, head=head)
|
||||
with pytest.raises(AssertionError):
|
||||
alg = SimCLR(backbone=backbone, neck=neck, head=None)
|
||||
|
||||
alg = SimCLR(backbone=backbone, neck=neck, head=head)
|
||||
with pytest.raises(AssertionError):
|
||||
fake_input = torch.randn((16, 3, 224, 224))
|
||||
alg.forward_train(fake_input)
|
||||
|
||||
fake_input = torch.randn((16, 3, 224, 224))
|
||||
fake_backbone_out = alg.extract_feat(fake_input)
|
||||
assert fake_backbone_out[0].size() == torch.Size([16, 2048, 7, 7])
|
|
@ -0,0 +1,50 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.algorithms import SimSiam
|
||||
|
||||
backbone = dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
in_channels=3,
|
||||
out_indices=[4], # 0: conv-1, x: stage-x
|
||||
norm_cfg=dict(type='BN'),
|
||||
zero_init_residual=True)
|
||||
neck = dict(
|
||||
type='NonLinearNeck',
|
||||
in_channels=2048,
|
||||
hid_channels=4,
|
||||
out_channels=4,
|
||||
num_layers=3,
|
||||
with_last_bn_affine=False,
|
||||
with_avg_pool=True,
|
||||
norm_cfg=dict(type='BN1d'))
|
||||
head = dict(
|
||||
type='LatentPredictHead',
|
||||
predictor=dict(
|
||||
type='NonLinearNeck',
|
||||
in_channels=4,
|
||||
hid_channels=4,
|
||||
out_channels=4,
|
||||
with_avg_pool=False,
|
||||
with_last_bn=False,
|
||||
with_last_bias=True,
|
||||
norm_cfg=dict(type='BN1d')))
|
||||
|
||||
|
||||
def test_simsiam():
|
||||
with pytest.raises(AssertionError):
|
||||
alg = SimSiam(backbone=backbone, neck=neck, head=None)
|
||||
|
||||
alg = SimSiam(backbone=backbone, neck=neck, head=head)
|
||||
with pytest.raises(AssertionError):
|
||||
fake_input = torch.randn((16, 3, 224, 224))
|
||||
alg.forward_train(fake_input)
|
||||
|
||||
fake_input = [
|
||||
torch.randn((16, 3, 224, 224)),
|
||||
torch.randn((16, 3, 224, 224))
|
||||
]
|
||||
fake_out = alg.forward(fake_input)
|
||||
assert fake_out['loss'].item() > -1
|
|
@ -0,0 +1,52 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.algorithms import SwAV
|
||||
|
||||
nmb_crops = [2, 6]
|
||||
backbone = dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
in_channels=3,
|
||||
out_indices=[4], # 0: conv-1, x: stage-x
|
||||
norm_cfg=dict(type='BN'),
|
||||
zero_init_residual=True)
|
||||
neck = dict(
|
||||
type='SwAVNeck',
|
||||
in_channels=2048,
|
||||
hid_channels=4,
|
||||
out_channels=4,
|
||||
norm_cfg=dict(type='BN1d'),
|
||||
with_avg_pool=True)
|
||||
head = dict(
|
||||
type='SwAVHead',
|
||||
feat_dim=4, # equal to neck['out_channels']
|
||||
epsilon=0.05,
|
||||
temperature=0.1,
|
||||
num_crops=nmb_crops)
|
||||
|
||||
|
||||
def test_swav():
|
||||
with pytest.raises(AssertionError):
|
||||
alg = SwAV(backbone=backbone, neck=neck, head=None)
|
||||
with pytest.raises(AssertionError):
|
||||
alg = SwAV(backbone=backbone, neck=None, head=head)
|
||||
|
||||
alg = SwAV(backbone=backbone, neck=neck, head=head)
|
||||
fake_input = torch.randn((16, 3, 224, 224))
|
||||
fake_backbone_out = alg.extract_feat(fake_input)
|
||||
assert fake_backbone_out[0].size() == torch.Size([16, 2048, 7, 7])
|
||||
|
||||
fake_input = [
|
||||
torch.randn((16, 3, 224, 224)),
|
||||
torch.randn((16, 3, 224, 224)),
|
||||
torch.randn((16, 3, 96, 96)),
|
||||
torch.randn((16, 3, 96, 96)),
|
||||
torch.randn((16, 3, 96, 96)),
|
||||
torch.randn((16, 3, 96, 96)),
|
||||
torch.randn((16, 3, 96, 96)),
|
||||
torch.randn((16, 3, 96, 96)),
|
||||
]
|
||||
fake_out = alg.forward_train(fake_input)
|
||||
assert fake_out['loss'].item() > 0
|
|
@ -0,0 +1,245 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmselfsup.models.backbones import ResNet
|
||||
from mmselfsup.models.backbones.resnet import BasicBlock, Bottleneck
|
||||
|
||||
|
||||
def is_block(modules):
|
||||
"""Check if is ResNet building block."""
|
||||
if isinstance(modules, (BasicBlock, Bottleneck)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def all_zeros(modules):
|
||||
"""Check if the weight(and bias) is all zero."""
|
||||
weight_zero = torch.equal(modules.weight.data,
|
||||
torch.zeros_like(modules.weight.data))
|
||||
if hasattr(modules, 'bias'):
|
||||
bias_zero = torch.equal(modules.bias.data,
|
||||
torch.zeros_like(modules.bias.data))
|
||||
else:
|
||||
bias_zero = True
|
||||
|
||||
return weight_zero and bias_zero
|
||||
|
||||
|
||||
def check_norm_state(modules, train_state):
|
||||
"""Check if norm layer is in correct train state."""
|
||||
for mod in modules:
|
||||
if isinstance(mod, _BatchNorm):
|
||||
if mod.training != train_state:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def test_basic_block():
|
||||
# BasicBlock with stride 1, out_channels == in_channels
|
||||
block = BasicBlock(64, 64)
|
||||
assert block.conv1.in_channels == 64
|
||||
assert block.conv1.out_channels == 64
|
||||
assert block.conv1.kernel_size == (3, 3)
|
||||
assert block.conv1.stride == (1, 1)
|
||||
assert block.conv2.in_channels == 64
|
||||
assert block.conv2.out_channels == 64
|
||||
assert block.conv2.kernel_size == (3, 3)
|
||||
x = torch.randn(1, 64, 56, 56)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 64, 56, 56])
|
||||
|
||||
# BasicBlock with stride 1 and downsample
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(64, 128, kernel_size=1, bias=False), nn.BatchNorm2d(128))
|
||||
block = BasicBlock(64, 128, downsample=downsample)
|
||||
assert block.conv1.in_channels == 64
|
||||
assert block.conv1.out_channels == 128
|
||||
assert block.conv1.kernel_size == (3, 3)
|
||||
assert block.conv1.stride == (1, 1)
|
||||
assert block.conv2.in_channels == 128
|
||||
assert block.conv2.out_channels == 128
|
||||
assert block.conv2.kernel_size == (3, 3)
|
||||
x = torch.randn(1, 64, 56, 56)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 128, 56, 56])
|
||||
|
||||
# BasicBlock with stride 2 and downsample
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
|
||||
nn.BatchNorm2d(128))
|
||||
block = BasicBlock(64, 128, stride=2, downsample=downsample)
|
||||
assert block.conv1.in_channels == 64
|
||||
assert block.conv1.out_channels == 128
|
||||
assert block.conv1.kernel_size == (3, 3)
|
||||
assert block.conv1.stride == (2, 2)
|
||||
assert block.conv2.in_channels == 128
|
||||
assert block.conv2.out_channels == 128
|
||||
assert block.conv2.kernel_size == (3, 3)
|
||||
x = torch.randn(1, 64, 56, 56)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 128, 28, 28])
|
||||
|
||||
|
||||
def test_bottleneck():
|
||||
# Test Bottleneck style
|
||||
block = Bottleneck(64, 64, stride=2, style='pytorch')
|
||||
assert block.conv1.stride == (1, 1)
|
||||
assert block.conv2.stride == (2, 2)
|
||||
block = Bottleneck(64, 64, stride=2, style='caffe')
|
||||
assert block.conv1.stride == (2, 2)
|
||||
assert block.conv2.stride == (1, 1)
|
||||
|
||||
# Bottleneck with stride 1
|
||||
block = Bottleneck(64, 16, style='pytorch')
|
||||
assert block.conv1.in_channels == 64
|
||||
assert block.conv1.out_channels == 16
|
||||
assert block.conv1.kernel_size == (1, 1)
|
||||
assert block.conv2.in_channels == 16
|
||||
assert block.conv2.out_channels == 16
|
||||
assert block.conv2.kernel_size == (3, 3)
|
||||
assert block.conv3.in_channels == 16
|
||||
assert block.conv3.out_channels == 64
|
||||
assert block.conv3.kernel_size == (1, 1)
|
||||
x = torch.randn(1, 64, 56, 56)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == (1, 64, 56, 56)
|
||||
|
||||
# Bottleneck with stride 1 and downsample
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(64, 256, kernel_size=1), nn.BatchNorm2d(256))
|
||||
block = Bottleneck(64, 64, style='pytorch', downsample=downsample)
|
||||
assert block.conv1.in_channels == 64
|
||||
assert block.conv1.out_channels == 64
|
||||
assert block.conv1.kernel_size == (1, 1)
|
||||
assert block.conv2.in_channels == 64
|
||||
assert block.conv2.out_channels == 64
|
||||
assert block.conv2.kernel_size == (3, 3)
|
||||
assert block.conv3.in_channels == 64
|
||||
assert block.conv3.out_channels == 256
|
||||
assert block.conv3.kernel_size == (1, 1)
|
||||
x = torch.randn(1, 64, 56, 56)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == (1, 256, 56, 56)
|
||||
|
||||
# Bottleneck with stride 2 and downsample
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(64, 256, kernel_size=1, stride=2), nn.BatchNorm2d(256))
|
||||
block = Bottleneck(
|
||||
64, 64, stride=2, style='pytorch', downsample=downsample)
|
||||
x = torch.randn(1, 64, 56, 56)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == (1, 256, 28, 28)
|
||||
|
||||
# Test Bottleneck with checkpointing
|
||||
block = Bottleneck(64, 16, with_cp=True)
|
||||
block.train()
|
||||
assert block.with_cp
|
||||
x = torch.randn(1, 64, 56, 56, requires_grad=True)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 64, 56, 56])
|
||||
|
||||
|
||||
def test_resnet():
|
||||
"""Test resnet backbone."""
|
||||
# Test ResNet50 norm_eval=True
|
||||
model = ResNet(50, norm_eval=True)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
assert check_norm_state(model.modules(), False)
|
||||
|
||||
# Test ResNet50 with torchvision pretrained weight
|
||||
model = ResNet(depth=50, norm_eval=True)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
assert check_norm_state(model.modules(), False)
|
||||
|
||||
# Test ResNet50 with first stage frozen
|
||||
frozen_stages = 1
|
||||
model = ResNet(50, frozen_stages=frozen_stages)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
assert model.norm1.training is False
|
||||
for layer in [model.conv1, model.norm1]:
|
||||
for param in layer.parameters():
|
||||
assert param.requires_grad is False
|
||||
for i in range(1, frozen_stages + 1):
|
||||
layer = getattr(model, f'layer{i}')
|
||||
for mod in layer.modules():
|
||||
if isinstance(mod, _BatchNorm):
|
||||
assert mod.training is False
|
||||
for param in layer.parameters():
|
||||
assert param.requires_grad is False
|
||||
|
||||
# Test ResNet18 forward
|
||||
model = ResNet(18, out_indices=(0, 1, 2, 3, 4))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 5
|
||||
assert feat[0].shape == (1, 64, 112, 112)
|
||||
assert feat[1].shape == (1, 64, 56, 56)
|
||||
assert feat[2].shape == (1, 128, 28, 28)
|
||||
assert feat[3].shape == (1, 256, 14, 14)
|
||||
assert feat[4].shape == (1, 512, 7, 7)
|
||||
|
||||
# Test ResNet50 with BatchNorm forward
|
||||
model = ResNet(50, out_indices=(0, 1, 2, 3, 4))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 5
|
||||
assert feat[0].shape == (1, 64, 112, 112)
|
||||
assert feat[1].shape == (1, 256, 56, 56)
|
||||
assert feat[2].shape == (1, 512, 28, 28)
|
||||
assert feat[3].shape == (1, 1024, 14, 14)
|
||||
assert feat[4].shape == (1, 2048, 7, 7)
|
||||
|
||||
# Test ResNet50 with layers 3 (top feature maps) out forward
|
||||
model = ResNet(50, out_indices=(4, ))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert feat[0].shape == (1, 2048, 7, 7)
|
||||
|
||||
# Test ResNet50 with checkpoint forward
|
||||
model = ResNet(50, out_indices=(0, 1, 2, 3, 4), with_cp=True)
|
||||
for m in model.modules():
|
||||
if is_block(m):
|
||||
assert m.with_cp
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 5
|
||||
assert feat[0].shape == (1, 64, 112, 112)
|
||||
assert feat[1].shape == (1, 256, 56, 56)
|
||||
assert feat[2].shape == (1, 512, 28, 28)
|
||||
assert feat[3].shape == (1, 1024, 14, 14)
|
||||
assert feat[4].shape == (1, 2048, 7, 7)
|
||||
|
||||
# zero initialization of residual blocks
|
||||
model = ResNet(50, zero_init_residual=True)
|
||||
model.init_weights()
|
||||
for m in model.modules():
|
||||
if isinstance(m, Bottleneck):
|
||||
assert all_zeros(m.norm3)
|
||||
elif isinstance(m, BasicBlock):
|
||||
assert all_zeros(m.norm2)
|
||||
|
||||
# non-zero initialization of residual blocks
|
||||
model = ResNet(50, zero_init_residual=False)
|
||||
model.init_weights()
|
||||
for m in model.modules():
|
||||
if isinstance(m, Bottleneck):
|
||||
assert not all_zeros(m.norm3)
|
||||
elif isinstance(m, BasicBlock):
|
||||
assert not all_zeros(m.norm2)
|
|
@ -0,0 +1,61 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.backbones import ResNeXt
|
||||
from mmselfsup.models.backbones.resnext import Bottleneck as BottleneckX
|
||||
|
||||
|
||||
def test_bottleneck():
|
||||
with pytest.raises(AssertionError):
|
||||
# Style must be in ['pytorch', 'caffe']
|
||||
BottleneckX(64, 64, groups=32, width_per_group=4, style='tensorflow')
|
||||
|
||||
# Test ResNeXt Bottleneck structure
|
||||
block = BottleneckX(
|
||||
64, 64, stride=2, groups=32, width_per_group=4, style='pytorch')
|
||||
assert block.conv2.stride == (2, 2)
|
||||
assert block.conv2.groups == 32
|
||||
assert block.conv2.out_channels == 128
|
||||
|
||||
# Test ResNeXt Bottleneck forward
|
||||
block = BottleneckX(64, 16, stride=1, groups=32, width_per_group=4)
|
||||
x = torch.randn(1, 64, 56, 56)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 64, 56, 56])
|
||||
|
||||
|
||||
def test_resnext():
|
||||
with pytest.raises(KeyError):
|
||||
# ResNeXt depth should be in [50, 101, 152]
|
||||
ResNeXt(depth=18)
|
||||
|
||||
# Test ResNeXt with group 32, width_per_group 4
|
||||
model = ResNeXt(
|
||||
depth=50, groups=32, width_per_group=4, out_indices=(0, 1, 2, 3, 4))
|
||||
for m in model.modules():
|
||||
if isinstance(m, BottleneckX):
|
||||
assert m.conv2.groups == 32
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 5
|
||||
assert feat[0].shape == torch.Size([1, 64, 112, 112])
|
||||
assert feat[1].shape == torch.Size([1, 256, 56, 56])
|
||||
assert feat[2].shape == torch.Size([1, 512, 28, 28])
|
||||
assert feat[3].shape == torch.Size([1, 1024, 14, 14])
|
||||
assert feat[4].shape == torch.Size([1, 2048, 7, 7])
|
||||
|
||||
# Test ResNeXt with group 32, width_per_group 4 and layers 3 out forward
|
||||
model = ResNeXt(depth=50, groups=32, width_per_group=4, out_indices=(4, ))
|
||||
for m in model.modules():
|
||||
if isinstance(m, BottleneckX):
|
||||
assert m.conv2.groups == 32
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert feat[0].shape == torch.Size([1, 2048, 7, 7])
|
|
@ -0,0 +1,75 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.heads import (ClsHead, ContrastiveHead, LatentClsHead,
|
||||
LatentPredictHead, MultiClsHead, SwAVHead)
|
||||
|
||||
|
||||
def test_cls_head():
|
||||
# test ClsHead
|
||||
head = ClsHead()
|
||||
fake_cls_score = [torch.rand(4, 3)]
|
||||
fake_gt_label = torch.randint(0, 2, (4, ))
|
||||
|
||||
loss = head.loss(fake_cls_score, fake_gt_label)
|
||||
assert loss['loss'].item() > 0
|
||||
|
||||
|
||||
def test_contrastive_head():
|
||||
head = ContrastiveHead()
|
||||
fake_pos = torch.rand(32, 1) # N, 1
|
||||
fake_neg = torch.rand(32, 100) # N, k
|
||||
|
||||
loss = head.forward(fake_pos, fake_neg)
|
||||
assert loss['loss'].item() > 0
|
||||
|
||||
|
||||
def test_latent_predict_head():
|
||||
predictor = dict(
|
||||
type='NonLinearNeck',
|
||||
in_channels=64,
|
||||
hid_channels=128,
|
||||
out_channels=64,
|
||||
with_bias=True,
|
||||
with_last_bn=True,
|
||||
with_avg_pool=False,
|
||||
norm_cfg=dict(type='BN1d'))
|
||||
head = LatentPredictHead(predictor=predictor)
|
||||
fake_input = torch.rand(32, 64) # N, C
|
||||
fake_traget = torch.rand(32, 64) # N, C
|
||||
|
||||
loss = head.forward(fake_input, fake_traget)
|
||||
assert loss['loss'].item() > -1
|
||||
|
||||
|
||||
def test_latent_cls_head():
|
||||
head = LatentClsHead(64, 10)
|
||||
fake_input = torch.rand(32, 64) # N, C
|
||||
fake_traget = torch.rand(32, 64) # N, C
|
||||
|
||||
loss = head.forward(fake_input, fake_traget)
|
||||
assert loss['loss'].item() > 0
|
||||
|
||||
|
||||
def test_multi_cls_head():
|
||||
head = MultiClsHead(in_indices=(0, 1))
|
||||
fake_input = [torch.rand(8, 64, 5, 5), torch.rand(8, 256, 14, 14)]
|
||||
out = head.forward(fake_input)
|
||||
assert isinstance(out, list)
|
||||
|
||||
fake_cls_score = [torch.rand(4, 3)]
|
||||
fake_gt_label = torch.randint(0, 2, (4, ))
|
||||
|
||||
loss = head.loss(fake_cls_score, fake_gt_label)
|
||||
print(loss.keys())
|
||||
for k in loss.keys():
|
||||
if 'loss' in k:
|
||||
assert loss[k].item() > 0
|
||||
|
||||
|
||||
def test_swav_head():
|
||||
head = SwAVHead(feat_dim=128, num_crops=[2, 6])
|
||||
fake_input = torch.rand(32, 128) # N, C
|
||||
|
||||
loss = head.forward(fake_input)
|
||||
assert loss['loss'].item() > 0
|
|
@ -0,0 +1,23 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.necks import AvgPool2dNeck
|
||||
|
||||
|
||||
def test_avgpool2d_neck():
|
||||
fake_in = [torch.randn((2, 3, 8, 8))]
|
||||
|
||||
# test default
|
||||
neck = AvgPool2dNeck()
|
||||
fake_out = neck(fake_in)
|
||||
assert fake_out[0].shape == (2, 3, 1, 1)
|
||||
|
||||
# test custom
|
||||
neck = AvgPool2dNeck(2)
|
||||
fake_out = neck(fake_in)
|
||||
assert fake_out[0].shape == (2, 3, 2, 2)
|
||||
|
||||
# test custom
|
||||
neck = AvgPool2dNeck((1, 2))
|
||||
fake_out = neck(fake_in)
|
||||
assert fake_out[0].shape == (2, 3, 1, 2)
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmselfsup.models.necks import DenseCLNeck
|
||||
|
||||
|
||||
def test_densecl_neck():
|
||||
neck = DenseCLNeck(16, 32, 16)
|
||||
assert isinstance(neck.mlp, nn.Sequential)
|
||||
assert isinstance(neck.mlp2, nn.Sequential)
|
||||
assert neck.mlp[0].in_features == 16
|
||||
assert neck.mlp[2].in_features == 32
|
||||
assert neck.mlp[2].out_features == 16
|
||||
assert neck.mlp2[0].in_channels == 16
|
||||
assert neck.mlp2[2].in_channels == 32
|
||||
assert neck.mlp2[2].out_channels == 16
|
||||
|
||||
# test neck when num_grid is None
|
||||
fake_in = torch.rand((32, 16, 5, 5))
|
||||
fake_out = neck.forward([fake_in])
|
||||
assert fake_out[0].shape == torch.Size([32, 16])
|
||||
assert fake_out[1].shape == torch.Size([32, 16, 25])
|
||||
assert fake_out[2].shape == torch.Size([32, 16])
|
||||
|
||||
# test neck when num_grid is not None
|
||||
neck = DenseCLNeck(16, 32, 16, num_grid=3)
|
||||
fake_in = torch.rand((32, 16, 5, 5))
|
||||
fake_out = neck.forward([fake_in])
|
||||
assert fake_out[0].shape == torch.Size([32, 16])
|
||||
assert fake_out[1].shape == torch.Size([32, 16, 9])
|
||||
assert fake_out[2].shape == torch.Size([32, 16])
|
|
@ -0,0 +1,23 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmselfsup.models.necks import LinearNeck
|
||||
|
||||
|
||||
def test_linear_neck():
|
||||
neck = LinearNeck(16, 32, with_avg_pool=True)
|
||||
assert isinstance(neck.avgpool, nn.Module)
|
||||
assert neck.fc.in_features == 16
|
||||
assert neck.fc.out_features == 32
|
||||
|
||||
# test neck with avgpool
|
||||
fake_in = torch.rand((32, 16, 5, 5))
|
||||
fake_out = neck.forward([fake_in])
|
||||
assert fake_out[0].shape == torch.Size([32, 32])
|
||||
|
||||
# test neck without avgpool
|
||||
neck = LinearNeck(16, 32, with_avg_pool=False)
|
||||
fake_in = torch.rand((32, 16))
|
||||
fake_out = neck.forward([fake_in])
|
||||
assert fake_out[0].shape == torch.Size([32, 32])
|
|
@ -0,0 +1,24 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmselfsup.models.necks import MoCoV2Neck
|
||||
|
||||
|
||||
def test_mocov2_neck():
|
||||
neck = MoCoV2Neck(16, 32, 16)
|
||||
assert isinstance(neck.mlp, nn.Sequential)
|
||||
assert neck.mlp[0].in_features == 16
|
||||
assert neck.mlp[2].in_features == 32
|
||||
assert neck.mlp[2].out_features == 16
|
||||
|
||||
# test neck with avgpool
|
||||
fake_in = torch.rand((32, 16, 5, 5))
|
||||
fake_out = neck.forward([fake_in])
|
||||
assert fake_out[0].shape == torch.Size([32, 16])
|
||||
|
||||
# test neck without avgpool
|
||||
neck = MoCoV2Neck(16, 32, 16, with_avg_pool=False)
|
||||
fake_in = torch.rand((32, 16))
|
||||
fake_out = neck.forward([fake_in])
|
||||
assert fake_out[0].shape == torch.Size([32, 16])
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.necks import NonLinearNeck
|
||||
|
||||
|
||||
def test_nonlinear_neck():
|
||||
# test neck arch
|
||||
neck = NonLinearNeck(16, 32, 16, norm_cfg=dict(type='BN1d'))
|
||||
assert neck.fc0.in_features == 16
|
||||
assert neck.fc0.out_features == 32
|
||||
assert neck.bn0.num_features == 32
|
||||
fc = getattr(neck, neck.fc_names[-1])
|
||||
assert fc.out_features == 16
|
||||
|
||||
# test neck with avgpool
|
||||
fake_in = torch.rand((32, 16, 5, 5))
|
||||
fake_out = neck.forward([fake_in])
|
||||
assert fake_out[0].shape == torch.Size([32, 16])
|
||||
|
||||
# test neck without avgpool
|
||||
neck = NonLinearNeck(
|
||||
16, 32, 16, with_avg_pool=False, norm_cfg=dict(type='BN1d'))
|
||||
fake_in = torch.rand((32, 16))
|
||||
fake_out = neck.forward([fake_in])
|
||||
assert fake_out[0].shape == torch.Size([32, 16])
|
|
@ -0,0 +1,24 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.necks import ODCNeck
|
||||
|
||||
|
||||
def test_odc_neck():
|
||||
neck = ODCNeck(16, 32, 16, norm_cfg=dict(type='BN1d'))
|
||||
assert neck.fc0.in_features == 16
|
||||
assert neck.fc0.out_features == 32
|
||||
assert neck.bn0.num_features == 32
|
||||
assert neck.fc1.in_features == 32
|
||||
assert neck.fc1.out_features == 16
|
||||
|
||||
# test neck with avgpool
|
||||
fake_in = torch.rand((32, 16, 5, 5))
|
||||
fake_out = neck.forward([fake_in])
|
||||
assert fake_out[0].shape == torch.Size([32, 16])
|
||||
|
||||
# test neck without avgpool
|
||||
neck = ODCNeck(16, 32, 16, with_avg_pool=False, norm_cfg=dict(type='BN1d'))
|
||||
fake_in = torch.rand((32, 16))
|
||||
fake_out = neck.forward([fake_in])
|
||||
assert fake_out[0].shape == torch.Size([32, 16])
|
|
@ -0,0 +1,22 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.necks import RelativeLocNeck
|
||||
|
||||
|
||||
def test_relative_loc_neck():
|
||||
neck = RelativeLocNeck(16, 32)
|
||||
assert neck.fc.in_features == 32
|
||||
assert neck.fc.out_features == 32
|
||||
assert neck.bn.num_features == 32
|
||||
|
||||
# test neck with avgpool
|
||||
fake_in = torch.rand((32, 32, 5, 5))
|
||||
fake_out = neck.forward([fake_in])
|
||||
assert fake_out[0].shape == torch.Size([32, 32])
|
||||
|
||||
# test neck without avgpool
|
||||
neck = RelativeLocNeck(16, 32, with_avg_pool=False)
|
||||
fake_in = torch.rand((32, 32))
|
||||
fake_out = neck.forward([fake_in])
|
||||
assert fake_out[0].shape == torch.Size([32, 32])
|
|
@ -0,0 +1,16 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmselfsup.models.necks import SwAVNeck
|
||||
|
||||
|
||||
def test_swav_neck():
|
||||
neck = SwAVNeck(16, 32, 16, norm_cfg=dict(type='BN1d'))
|
||||
assert isinstance(neck.projection_neck, (nn.Module, nn.Sequential))
|
||||
|
||||
# test neck with avgpool
|
||||
fake_in = [[torch.rand((32, 16, 5, 5))], [torch.rand((32, 16, 5, 5))],
|
||||
[torch.rand((32, 16, 3, 3))]]
|
||||
fake_out = neck.forward(fake_in)
|
||||
assert fake_out[0].shape == torch.Size([32 * len(fake_in), 16])
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.utils import MultiPooling
|
||||
|
||||
|
||||
def test_multi_pooling():
|
||||
# adaptive
|
||||
layer = MultiPooling(pool_type='adaptive', in_indices=(0, 1, 2))
|
||||
fake_in = [
|
||||
torch.rand((1, 32, 112, 112)),
|
||||
torch.rand((1, 64, 56, 56)),
|
||||
torch.rand((1, 128, 28, 28)),
|
||||
]
|
||||
res = layer.forward(fake_in)
|
||||
assert res[0].shape == (1, 32, 12, 12)
|
||||
assert res[1].shape == (1, 64, 6, 6)
|
||||
assert res[2].shape == (1, 128, 4, 4)
|
||||
|
||||
# specified
|
||||
layer = MultiPooling(pool_type='specified', in_indices=(0, 1, 2))
|
||||
fake_in = [
|
||||
torch.rand((1, 32, 112, 112)),
|
||||
torch.rand((1, 64, 56, 56)),
|
||||
torch.rand((1, 128, 28, 28)),
|
||||
]
|
||||
res = layer.forward(fake_in)
|
||||
assert res[0].shape == (1, 32, 12, 12)
|
||||
assert res[1].shape == (1, 64, 6, 6)
|
||||
assert res[2].shape == (1, 128, 4, 4)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
layer = MultiPooling(pool_type='other')
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
layer = MultiPooling(backbone='resnet101')
|
|
@ -0,0 +1,23 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmselfsup.models.utils import MultiPrototypes
|
||||
|
||||
|
||||
def test_multi_prototypes():
|
||||
with pytest.raises(AssertionError):
|
||||
layer = MultiPrototypes(output_dim=16, num_prototypes=2)
|
||||
|
||||
layer = MultiPrototypes(output_dim=16, num_prototypes=[3, 4, 5])
|
||||
assert isinstance(getattr(layer, 'prototypes0'), nn.Module)
|
||||
assert isinstance(getattr(layer, 'prototypes1'), nn.Module)
|
||||
assert isinstance(getattr(layer, 'prototypes2'), nn.Module)
|
||||
|
||||
fake_in = torch.rand((32, 16))
|
||||
res = layer.forward(fake_in)
|
||||
assert len(res) == 3
|
||||
assert res[0].shape == (32, 3)
|
||||
assert res[1].shape == (32, 4)
|
||||
assert res[2].shape == (32, 5)
|
|
@ -0,0 +1,14 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.utils import Sobel
|
||||
|
||||
|
||||
def test_sobel():
|
||||
sobel_layer = Sobel()
|
||||
fake_input = torch.rand((1, 3, 224, 224))
|
||||
fake_res = sobel_layer(fake_input)
|
||||
assert fake_res.shape == (1, 2, 224, 224)
|
||||
|
||||
for p in sobel_layer.sobel.parameters():
|
||||
assert p.requires_grad is False
|
|
@ -0,0 +1,61 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.parallel import MMDataParallel
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from mmselfsup.models.utils import ExtractProcess
|
||||
|
||||
|
||||
class ExampleDataset(Dataset):
|
||||
|
||||
def __getitem__(self, idx):
|
||||
results = dict(img=torch.tensor([1]), img_metas=dict())
|
||||
return results
|
||||
|
||||
def __len__(self):
|
||||
return 1
|
||||
|
||||
|
||||
class ExampleModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(ExampleModel, self).__init__()
|
||||
self.conv = nn.Conv2d(3, 3, 3)
|
||||
|
||||
def forward(self, img, test_mode=False, **kwargs):
|
||||
return [
|
||||
torch.rand((1, 32, 112, 112)),
|
||||
torch.rand((1, 64, 56, 56)),
|
||||
torch.rand((1, 128, 28, 28)),
|
||||
]
|
||||
|
||||
def train_step(self, data_batch, optimizer):
|
||||
loss = self.forward(**data_batch)
|
||||
return dict(loss=loss)
|
||||
|
||||
|
||||
def test_extract_process():
|
||||
with pytest.raises(AssertionError):
|
||||
process = ExtractProcess(
|
||||
pool_type='specified', backbone='resnet50', layer_indices=(-1, ))
|
||||
|
||||
test_dataset = ExampleDataset()
|
||||
test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
|
||||
data_loader = DataLoader(
|
||||
test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
|
||||
model = MMDataParallel(ExampleModel())
|
||||
|
||||
process = ExtractProcess(
|
||||
pool_type='specified', backbone='resnet50', layer_indices=(0, 1, 2))
|
||||
|
||||
results = process.extract(model, data_loader)
|
||||
assert 'feat1' in results
|
||||
assert 'feat2' in results
|
||||
assert 'feat3' in results
|
||||
assert results['feat1'].shape == (1, 32 * 12 * 12)
|
||||
assert results['feat2'].shape == (1, 64 * 6 * 6)
|
||||
assert results['feat3'].shape == (1, 128 * 4 * 4)
|
|
@ -0,0 +1,63 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.parallel import MMDataParallel
|
||||
from mmcv.runner import build_runner
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from mmselfsup.core.optimizer import build_optimizer
|
||||
from mmselfsup.utils import Extractor
|
||||
|
||||
|
||||
class ExampleDataset(Dataset):
|
||||
|
||||
def __getitem__(self, idx):
|
||||
results = dict(img=torch.tensor([1]), img_metas=dict())
|
||||
return results
|
||||
|
||||
def __len__(self):
|
||||
return 1
|
||||
|
||||
|
||||
class ExampleModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(ExampleModel, self).__init__()
|
||||
self.test_cfg = None
|
||||
self.conv = nn.Conv2d(3, 3, 3)
|
||||
self.neck = nn.Identity()
|
||||
|
||||
def forward(self, img, test_mode=False, **kwargs):
|
||||
return img
|
||||
|
||||
def train_step(self, data_batch, optimizer):
|
||||
loss = self.forward(**data_batch)
|
||||
return dict(loss=loss)
|
||||
|
||||
|
||||
def test_extractor():
|
||||
test_dataset = ExampleDataset()
|
||||
test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
|
||||
|
||||
runner_cfg = dict(type='EpochBasedRunner', max_epochs=2)
|
||||
optim_cfg = dict(type='SGD', lr=0.05, momentum=0.9, weight_decay=0.0005)
|
||||
extractor = Extractor(
|
||||
test_dataset, 1, 0, dist_mode=False, persistent_workers=False)
|
||||
|
||||
# test extractor
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model = MMDataParallel(ExampleModel())
|
||||
optimizer = build_optimizer(model, optim_cfg)
|
||||
runner = build_runner(
|
||||
runner_cfg,
|
||||
default_args=dict(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
work_dir=tmpdir,
|
||||
logger=logging.getLogger()))
|
||||
features = extractor(runner)
|
||||
assert features.shape == (1, 1)
|
|
@ -0,0 +1,79 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.parallel import MMDataParallel
|
||||
from mmcv.runner import build_runner, obj_from_dict
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from mmselfsup.core.hooks import BYOLHook
|
||||
|
||||
|
||||
class ExampleDataset(Dataset):
|
||||
|
||||
def __getitem__(self, idx):
|
||||
results = dict(img=torch.tensor([1]), img_metas=dict())
|
||||
return results
|
||||
|
||||
def __len__(self):
|
||||
return 1
|
||||
|
||||
|
||||
class ExampleModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(ExampleModel, self).__init__()
|
||||
self.test_cfg = None
|
||||
self.online_net = nn.Conv2d(3, 3, 3)
|
||||
self.target_net = nn.Conv2d(3, 3, 3)
|
||||
self.base_momentum = 0.96
|
||||
self.momentum = self.base_momentum
|
||||
|
||||
def forward(self, img, img_metas, test_mode=False, **kwargs):
|
||||
return img
|
||||
|
||||
def train_step(self, data_batch, optimizer):
|
||||
loss = self.forward(**data_batch)
|
||||
return dict(loss=loss)
|
||||
|
||||
@torch.no_grad()
|
||||
def _momentum_update(self):
|
||||
"""Momentum update of the target network."""
|
||||
for param_ol, param_tgt in zip(self.online_net.parameters(),
|
||||
self.target_net.parameters()):
|
||||
param_tgt.data = param_tgt.data * self.momentum + \
|
||||
param_ol.data * (1. - self.momentum)
|
||||
|
||||
@torch.no_grad()
|
||||
def momentum_update(self):
|
||||
self._momentum_update()
|
||||
|
||||
|
||||
def test_byol_hook():
|
||||
test_dataset = ExampleDataset()
|
||||
test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
|
||||
data_loader = DataLoader(
|
||||
test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
|
||||
|
||||
runner_cfg = dict(type='EpochBasedRunner', max_epochs=2)
|
||||
optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
|
||||
|
||||
# test BYOLHook
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model = MMDataParallel(ExampleModel())
|
||||
optimizer = obj_from_dict(optim_cfg, torch.optim,
|
||||
dict(params=model.parameters()))
|
||||
byol_hook = BYOLHook()
|
||||
runner = build_runner(
|
||||
runner_cfg,
|
||||
default_args=dict(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
work_dir=tmpdir,
|
||||
logger=logging.getLogger()))
|
||||
runner.register_hook(byol_hook)
|
||||
runner.run([data_loader], [('train', 1)])
|
||||
assert runner.model.module.momentum == 0.98
|
|
@ -0,0 +1,79 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
from mmcv.parallel import MMDataParallel
|
||||
from mmcv.runner import build_runner
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from mmselfsup.core.hooks import DeepClusterHook
|
||||
from mmselfsup.core.optimizer import build_optimizer
|
||||
from mmselfsup.models.algorithms import DeepCluster
|
||||
|
||||
num_classes = 10
|
||||
with_sobel = True,
|
||||
backbone = dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
in_channels=2,
|
||||
out_indices=[4], # 0: conv-1, x: stage-x
|
||||
norm_cfg=dict(type='BN'))
|
||||
neck = dict(type='AvgPool2dNeck')
|
||||
head = dict(
|
||||
type='ClsHead',
|
||||
with_avg_pool=False, # already has avgpool in the neck
|
||||
in_channels=2048,
|
||||
num_classes=num_classes)
|
||||
|
||||
|
||||
class ExampleDataset(Dataset):
|
||||
|
||||
def __getitem__(self, idx):
|
||||
results = dict(img=torch.randn((3, 224, 224)), img_metas=dict())
|
||||
return results
|
||||
|
||||
def __len__(self):
|
||||
return 10
|
||||
|
||||
|
||||
def test_deepcluster_hook():
|
||||
test_dataset = ExampleDataset()
|
||||
test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
|
||||
|
||||
alg = DeepCluster(
|
||||
backbone=backbone, with_sobel=with_sobel, neck=neck, head=head)
|
||||
extractor = dict(
|
||||
dataset=test_dataset,
|
||||
imgs_per_gpu=1,
|
||||
workers_per_gpu=0,
|
||||
persistent_workers=False)
|
||||
|
||||
runner_cfg = dict(type='EpochBasedRunner', max_epochs=3)
|
||||
optim_cfg = dict(type='SGD', lr=0.05, momentum=0.9, weight_decay=0.0005)
|
||||
lr_config = dict(policy='CosineAnnealing', min_lr=0.)
|
||||
|
||||
# test DeepClusterHook
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model = MMDataParallel(alg)
|
||||
optimizer = build_optimizer(model, optim_cfg)
|
||||
deepcluster_hook = DeepClusterHook(
|
||||
extractor=extractor,
|
||||
clustering=dict(type='Kmeans', k=num_classes, pca_dim=16),
|
||||
unif_sampling=True,
|
||||
reweight=False,
|
||||
reweight_pow=0.5,
|
||||
initial=True,
|
||||
interval=1,
|
||||
dist_mode=False)
|
||||
runner = build_runner(
|
||||
runner_cfg,
|
||||
default_args=dict(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
work_dir=tmpdir,
|
||||
logger=logging.getLogger()))
|
||||
runner.register_training_hooks(lr_config)
|
||||
runner.register_hook(deepcluster_hook)
|
||||
assert deepcluster_hook.clustering_type == 'Kmeans'
|
|
@ -0,0 +1,70 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.parallel import MMDataParallel
|
||||
from mmcv.runner import build_runner, obj_from_dict
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from mmselfsup.core.hooks import DenseCLHook
|
||||
|
||||
|
||||
class ExampleDataset(Dataset):
|
||||
|
||||
def __getitem__(self, idx):
|
||||
results = dict(img=torch.tensor([1]), img_metas=dict())
|
||||
return results
|
||||
|
||||
def __len__(self):
|
||||
return 1
|
||||
|
||||
|
||||
class ExampleModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(ExampleModel, self).__init__()
|
||||
self.test_cfg = None
|
||||
self.loss_lambda = 0.5
|
||||
self.conv = nn.Conv2d(3, 3, 3)
|
||||
|
||||
def forward(self, img, img_metas, test_mode=False, **kwargs):
|
||||
return img
|
||||
|
||||
def train_step(self, data_batch, optimizer):
|
||||
loss = self.forward(**data_batch)
|
||||
return dict(loss=loss)
|
||||
|
||||
|
||||
def test_densecl_hook():
|
||||
test_dataset = ExampleDataset()
|
||||
test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
|
||||
data_loader = DataLoader(
|
||||
test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
|
||||
|
||||
runner_cfg = dict(type='EpochBasedRunner', max_epochs=2)
|
||||
optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
|
||||
|
||||
# test DenseCLHook
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model = MMDataParallel(ExampleModel())
|
||||
optimizer = obj_from_dict(optim_cfg, torch.optim,
|
||||
dict(params=model.parameters()))
|
||||
|
||||
densecl_hook = DenseCLHook(start_iters=1)
|
||||
runner = build_runner(
|
||||
runner_cfg,
|
||||
default_args=dict(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
work_dir=tmpdir,
|
||||
logger=logging.getLogger()))
|
||||
runner.register_hook(densecl_hook)
|
||||
runner.run([data_loader], [('train', 1)])
|
||||
cur_iter = runner.iter
|
||||
if cur_iter >= 1:
|
||||
assert runner.model.module.loss_lambda == 0.5
|
||||
else:
|
||||
assert runner.model.module.loss_lambda == 0.
|
|
@ -0,0 +1,125 @@
|
|||
import logging
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.parallel import MMDataParallel
|
||||
from mmcv.runner import build_runner, obj_from_dict
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from mmselfsup.core.hooks import DistOptimizerHook, GradAccumFp16OptimizerHook
|
||||
|
||||
|
||||
class ExampleDataset(Dataset):
|
||||
|
||||
def __getitem__(self, idx):
|
||||
results = dict(img=torch.tensor([1.]), img_metas=dict())
|
||||
return results
|
||||
|
||||
def __len__(self):
|
||||
return 1
|
||||
|
||||
|
||||
class ExampleModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(ExampleModel, self).__init__()
|
||||
self.test_cfg = None
|
||||
self.linear = nn.Linear(1, 1)
|
||||
self.prototypes_test = nn.Linear(1, 1)
|
||||
|
||||
def forward(self, img, img_metas, test_mode=False, **kwargs):
|
||||
out = self.linear(img)
|
||||
out = self.prototypes_test(out)
|
||||
return out
|
||||
|
||||
def train_step(self, data_batch, optimizer):
|
||||
loss = self.forward(**data_batch)
|
||||
return dict(loss=loss, num_samples=len(data_batch))
|
||||
|
||||
|
||||
def test_optimizer_hook():
|
||||
test_dataset = ExampleDataset()
|
||||
test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
|
||||
data_loader = DataLoader(
|
||||
test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
|
||||
|
||||
runner_cfg = dict(type='EpochBasedRunner', max_epochs=5)
|
||||
optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
|
||||
optim_hook_cfg = dict(
|
||||
grad_clip=dict(max_norm=10), frozen_layers_cfg=dict(prototypes=5005))
|
||||
|
||||
optimizer_hook = DistOptimizerHook(**optim_hook_cfg)
|
||||
|
||||
# test DistOptimizerHook
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model = MMDataParallel(ExampleModel())
|
||||
optimizer = obj_from_dict(optim_cfg, torch.optim,
|
||||
dict(params=model.parameters()))
|
||||
|
||||
runner = build_runner(
|
||||
runner_cfg,
|
||||
default_args=dict(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
work_dir=tmpdir,
|
||||
logger=logging.getLogger()))
|
||||
runner.register_training_hooks(optimizer_hook)
|
||||
|
||||
prototypes_start = []
|
||||
for name, p in runner.model.module.named_parameters():
|
||||
if 'prototypes_test' in name:
|
||||
prototypes_start.append(p)
|
||||
|
||||
# run training
|
||||
runner.run([data_loader], [('train', 1)])
|
||||
|
||||
prototypes_end = []
|
||||
for name, p in runner.model.module.named_parameters():
|
||||
if 'prototypes_test' in name:
|
||||
prototypes_end.append(p)
|
||||
|
||||
assert len(prototypes_start) == len(prototypes_end)
|
||||
for i in range(len(prototypes_start)):
|
||||
p_start = prototypes_start[i]
|
||||
p_end = prototypes_end[i]
|
||||
assert p_start == p_end
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason='CUDA is not available.')
|
||||
def test_fp16optimizer_hook():
|
||||
test_dataset = ExampleDataset()
|
||||
test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
|
||||
data_loader = DataLoader(
|
||||
test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
|
||||
|
||||
runner_cfg = dict(type='EpochBasedRunner', max_epochs=5)
|
||||
optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
|
||||
optim_hook_cfg = dict(
|
||||
grad_clip=dict(max_norm=10),
|
||||
loss_scale=16.,
|
||||
frozen_layers_cfg=dict(prototypes=5005))
|
||||
|
||||
optimizer_hook = GradAccumFp16OptimizerHook(**optim_hook_cfg)
|
||||
|
||||
# test GradAccumFp16OptimizerHook
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model = MMDataParallel(ExampleModel())
|
||||
optimizer = obj_from_dict(optim_cfg, torch.optim,
|
||||
dict(params=model.parameters()))
|
||||
|
||||
runner = build_runner(
|
||||
runner_cfg,
|
||||
default_args=dict(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
work_dir=tmpdir,
|
||||
logger=logging.getLogger(),
|
||||
meta=dict()))
|
||||
runner.register_training_hooks(optimizer_hook)
|
||||
# run training
|
||||
runner.run([data_loader], [('train', 1)])
|
||||
assert runner.meta['fp16']['loss_scaler']['scale'] == 16.
|
|
@ -0,0 +1,77 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.parallel import MMDataParallel
|
||||
from mmcv.runner import build_runner
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from mmselfsup.core.hooks import SimSiamHook
|
||||
from mmselfsup.core.optimizer import build_optimizer
|
||||
|
||||
|
||||
class ExampleDataset(Dataset):
|
||||
|
||||
def __getitem__(self, idx):
|
||||
results = dict(img=torch.tensor([1]), img_metas=dict())
|
||||
return results
|
||||
|
||||
def __len__(self):
|
||||
return 1
|
||||
|
||||
|
||||
class ExampleModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(ExampleModel, self).__init__()
|
||||
self.test_cfg = None
|
||||
self.conv = nn.Conv2d(3, 3, 3)
|
||||
self.predictor = nn.Linear(2, 1)
|
||||
|
||||
def forward(self, img, img_metas, test_mode=False, **kwargs):
|
||||
return img
|
||||
|
||||
def train_step(self, data_batch, optimizer):
|
||||
loss = self.forward(**data_batch)
|
||||
return dict(loss=loss)
|
||||
|
||||
|
||||
def test_simsiam_hook():
|
||||
test_dataset = ExampleDataset()
|
||||
test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
|
||||
|
||||
data_loader = DataLoader(
|
||||
test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
|
||||
runner_cfg = dict(type='EpochBasedRunner', max_epochs=2)
|
||||
optim_cfg = dict(
|
||||
type='SGD',
|
||||
lr=0.05,
|
||||
momentum=0.9,
|
||||
weight_decay=0.0005,
|
||||
paramwise_options={'predictor': dict(fix_lr=True)})
|
||||
lr_config = dict(policy='CosineAnnealing', min_lr=0.)
|
||||
|
||||
# test SimSiamHook
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model = MMDataParallel(ExampleModel())
|
||||
optimizer = build_optimizer(model, optim_cfg)
|
||||
simsiam_hook = SimSiamHook(True, 0.05)
|
||||
runner = build_runner(
|
||||
runner_cfg,
|
||||
default_args=dict(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
work_dir=tmpdir,
|
||||
logger=logging.getLogger()))
|
||||
runner.register_training_hooks(lr_config)
|
||||
runner.register_hook(simsiam_hook)
|
||||
runner.run([data_loader], [('train', 1)])
|
||||
|
||||
for param_group in runner.optimizer.param_groups:
|
||||
if 'fix_lr' in param_group and param_group['fix_lr']:
|
||||
assert param_group['lr'] == 0.05
|
||||
else:
|
||||
assert param_group['lr'] != 0.05
|
|
@ -0,0 +1,76 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.parallel import MMDataParallel
|
||||
from mmcv.runner import build_runner, obj_from_dict
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from mmselfsup.core.hooks import SwAVHook
|
||||
from mmselfsup.models.heads import SwAVHead
|
||||
|
||||
|
||||
class ExampleDataset(Dataset):
|
||||
|
||||
def __getitem__(self, idx):
|
||||
results = dict(img=torch.tensor([1.]), img_metas=dict())
|
||||
return results
|
||||
|
||||
def __len__(self):
|
||||
return 1
|
||||
|
||||
|
||||
class ExampleModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(ExampleModel, self).__init__()
|
||||
self.test_cfg = None
|
||||
self.linear = nn.Linear(1, 1)
|
||||
self.prototypes_test = nn.Linear(1, 1)
|
||||
self.head = SwAVHead(feat_dim=2, num_crops=[2, 6], num_prototypes=3)
|
||||
|
||||
def forward(self, img, img_metas, test_mode=False, **kwargs):
|
||||
out = self.linear(img)
|
||||
out = self.prototypes_test(out)
|
||||
return out
|
||||
|
||||
def train_step(self, data_batch, optimizer):
|
||||
loss = self.forward(**data_batch)
|
||||
return dict(loss=loss)
|
||||
|
||||
|
||||
def test_swav_hook():
|
||||
test_dataset = ExampleDataset()
|
||||
test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
|
||||
data_loader = DataLoader(
|
||||
test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
|
||||
|
||||
runner_cfg = dict(type='EpochBasedRunner', max_epochs=2)
|
||||
optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
|
||||
|
||||
# test SwAVHook
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model = MMDataParallel(ExampleModel())
|
||||
optimizer = obj_from_dict(optim_cfg, torch.optim,
|
||||
dict(params=model.parameters()))
|
||||
|
||||
swav_hook = SwAVHook(
|
||||
batch_size=1,
|
||||
epoch_queue_starts=15,
|
||||
crops_for_assign=[0, 1],
|
||||
feat_dim=128,
|
||||
queue_length=300)
|
||||
runner = build_runner(
|
||||
runner_cfg,
|
||||
default_args=dict(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
work_dir=tmpdir,
|
||||
logger=logging.getLogger()))
|
||||
runner.register_hook(swav_hook)
|
||||
runner.run([data_loader], [('train', 1)])
|
||||
assert swav_hook.queue_length == 300
|
||||
assert runner.model.module.head.use_queue is False
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmselfsup.utils import AliasMethod
|
||||
|
||||
|
||||
def test_alias_multinomial():
|
||||
example_in = torch.Tensor([1, 2, 3, 4])
|
||||
example_alias_method = AliasMethod(example_in)
|
||||
assert (example_alias_method.prob.numpy() <= 1).all()
|
||||
assert len(example_in) == len(example_alias_method.alias)
|
||||
|
||||
# test assertion if N is smaller than 0
|
||||
with pytest.raises(AssertionError):
|
||||
example_alias_method.draw(-1)
|
||||
with pytest.raises(AssertionError):
|
||||
example_alias_method.draw(0)
|
||||
|
||||
example_res = example_alias_method.draw(5)
|
||||
assert len(example_res) == 5
|
|
@ -0,0 +1,28 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmselfsup.utils.clustering import PIC, Kmeans
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason='CUDA is not available.')
|
||||
def test_kmeans():
|
||||
fake_input = np.random.rand(10, 8).astype(np.float32)
|
||||
pca_dim = 2
|
||||
|
||||
kmeans = Kmeans(2, pca_dim)
|
||||
loss = kmeans.cluster(fake_input)
|
||||
assert loss is not None
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
loss = kmeans.cluster(np.random.rand(10, 8))
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason='CUDA is not available.')
|
||||
def test_pic():
|
||||
fake_input = np.random.rand(1000, 16).astype(np.float32)
|
||||
pic = PIC(pca_dim=8)
|
||||
res = pic.cluster(fake_input)
|
||||
assert res == 0
|
|
@ -0,0 +1,14 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmselfsup.utils.misc import tensor2imgs
|
||||
|
||||
|
||||
def test_tensor2imgs():
|
||||
with pytest.raises(AssertionError):
|
||||
tensor2imgs(torch.rand((3, 16, 16)))
|
||||
fake_tensor = torch.rand((3, 3, 16, 16))
|
||||
fake_imgs = tensor2imgs(fake_tensor)
|
||||
assert len(fake_imgs) == 3
|
||||
assert fake_imgs[0].shape == (16, 16, 3)
|
|
@ -0,0 +1,44 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from mmselfsup.utils.test_helper import single_gpu_test
|
||||
|
||||
|
||||
class ExampleDataset(Dataset):
|
||||
|
||||
def __getitem__(self, idx):
|
||||
results = dict(img=torch.tensor([1]), img_metas=dict())
|
||||
return results
|
||||
|
||||
def __len__(self):
|
||||
return 1
|
||||
|
||||
|
||||
class ExampleModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(ExampleModel, self).__init__()
|
||||
self.test_cfg = None
|
||||
self.conv = nn.Conv2d(3, 3, 3)
|
||||
|
||||
def forward(self, img, mode='test', **kwargs):
|
||||
return dict(img=img)
|
||||
|
||||
def train_step(self, data_batch, optimizer):
|
||||
loss = self.forward(**data_batch)
|
||||
return dict(loss=loss)
|
||||
|
||||
|
||||
def test_test_helper():
|
||||
test_dataset = ExampleDataset()
|
||||
test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
|
||||
data_loader = DataLoader(
|
||||
test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
|
||||
model = ExampleModel()
|
||||
|
||||
res = single_gpu_test(model, data_loader)
|
||||
assert res['img'] == np.array([[1]])
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmselfsup import digit_version
|
||||
|
||||
|
||||
def test_digit_version():
|
||||
assert digit_version('0.2.16') == (0, 2, 16, 0, 0, 0)
|
||||
assert digit_version('1.2.3') == (1, 2, 3, 0, 0, 0)
|
||||
assert digit_version('1.2.3rc0') == (1, 2, 3, 0, -1, 0)
|
||||
assert digit_version('1.2.3rc1') == (1, 2, 3, 0, -1, 1)
|
||||
assert digit_version('1.0rc0') == (1, 0, 0, 0, -1, 0)
|
||||
assert digit_version('1.0') == digit_version('1.0.0')
|
||||
assert digit_version('1.5.0+cuda90_cudnn7.6.3_lms') == digit_version('1.5')
|
||||
assert digit_version('1.0.0dev') < digit_version('1.0.0a')
|
||||
assert digit_version('1.0.0a') < digit_version('1.0.0a1')
|
||||
assert digit_version('1.0.0a') < digit_version('1.0.0b')
|
||||
assert digit_version('1.0.0b') < digit_version('1.0.0rc')
|
||||
assert digit_version('1.0.0rc1') < digit_version('1.0.0')
|
||||
assert digit_version('1.0.0') < digit_version('1.0.0post')
|
||||
assert digit_version('1.0.0post') < digit_version('1.0.0post1')
|
||||
assert digit_version('v1') == (1, 0, 0, 0, 0, 0)
|
||||
assert digit_version('v1.1.5') == (1, 1, 5, 0, 0, 0)
|
Loading…
Reference in New Issue