2021-11-01 16:33:21 +08:00
|
|
|
# Copyright (c) OpenMMLab. All rights reserved.
|
|
|
|
import copy
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
import torch
|
|
|
|
|
|
|
|
from mmfewshot.classification.models import CLASSIFIERS
|
|
|
|
|
|
|
|
|
2021-11-12 23:28:00 +08:00
|
|
|
def _random_episode_inputs():
    """Draw fresh random support/query images and the shared label tensor.

    Returns:
        tuple: ``(support_imgs, query_imgs, labels)`` where both image
        tensors are ``torch.randn(4, 3, 84, 84)`` and ``labels`` is
        ``torch.LongTensor([0, 1, 2, 3])``.
    """
    # Support images are drawn before query images so the RNG is consumed
    # in the same order on every call.
    support_imgs = torch.randn(4, 3, 84, 84)
    query_imgs = torch.randn(4, 3, 84, 84)
    labels = torch.LongTensor([0, 1, 2, 3])
    return support_imgs, query_imgs, labels


def _make_episode(support_imgs, query_imgs, labels):
    """Assemble the batch dict handed to ``train_step`` / ``val_step``.

    Args:
        support_imgs (Tensor): Images stored under ``support_data``.
        query_imgs (Tensor): Images stored under ``query_data``.
        labels (Tensor): Labels shared by both halves of the episode.

    Returns:
        dict: ``{'support_data': {...}, 'query_data': {...}}``; each half
        holds ``img``, ``gt_label``, ``mode`` and ``img_metas``.
    """

    def _half(imgs):
        return {
            'img': imgs,
            'gt_label': labels,
            'mode': 'train',
            # One placeholder meta entry per image: 0 .. batch_size - 1.
            'img_metas': list(range(imgs.size(0)))
        }

    return {
        'support_data': _half(support_imgs),
        'query_data': _half(query_imgs)
    }


def _build_and_check_model(model_cfg):
    """Build a classifier from a copy of ``model_cfg`` and check its properties.

    Args:
        model_cfg (dict): Config passed to ``CLASSIFIERS.build``. It is deep
            copied first, so the caller's dict is never handed to the
            registry.

    Returns:
        The built model.
    """
    model = CLASSIFIERS.build(copy.deepcopy(model_cfg))

    # test property
    assert not model.with_neck
    assert model.with_head

    # NOTE(review): these two only check truthiness of the returned values;
    # confirm that is the intended level of checking.
    assert model.device
    assert model.get_device()
    return model


def _assert_step_outputs(model, outputs, num_samples):
    """Check the result of a ``train_step`` / ``val_step`` call.

    Asserts a positive loss, the expected ``num_samples``, and that every
    parameter's ``fast`` attribute is ``None`` once the step has returned
    (presumably the inner-loop fast weights being cleared -- confirm against
    the MAML implementation).

    Args:
        model: Model whose parameters are inspected.
        outputs (dict): Return value of the step call.
        num_samples (int): Expected value of ``outputs['num_samples']``.
    """
    assert outputs['loss'].item() > 0
    assert outputs['num_samples'] == num_samples
    for weight in model.parameters():
        assert weight.fast is None


@pytest.mark.parametrize('classifier', ['MAML'])
def test_image_classifier(classifier):
    """Exercise a meta-learning classifier with a Conv4 + LinearHead config.

    Two models are built: one with the default config and one with
    ``first_order=True``. ``train_step`` is checked on both; ``val_step``
    and the support/query forward passes are checked on the first-order
    model only.

    Args:
        classifier (str): Registry name used as the model ``type``.
    """
    model_cfg = dict(
        type=classifier,
        backbone=dict(type='Conv4'),
        head=dict(type='LinearHead', num_classes=5, in_channels=1600))

    # Inputs are drawn before the model is built, so the random stream is
    # consumed in a fixed order: images first, then model construction.
    imgs_a, imgs_b, label = _random_episode_inputs()
    model = _build_and_check_model(model_cfg)

    # test train_step
    outputs = model.train_step(_make_episode(imgs_a, imgs_b, label), None)
    _assert_step_outputs(model, outputs, num_samples=4)

    # test first order
    model_cfg = dict(
        type=classifier,
        first_order=True,
        backbone=dict(type='Conv4'),
        head=dict(type='LinearHead', num_classes=5, in_channels=1600))

    imgs_a, imgs_b, label = _random_episode_inputs()
    model = _build_and_check_model(model_cfg)

    # test train_step
    outputs = model.train_step(_make_episode(imgs_a, imgs_b, label), None)
    _assert_step_outputs(model, outputs, num_samples=4)

    # test val_step
    outputs = model.val_step(_make_episode(imgs_a, imgs_b, label), None)
    _assert_step_outputs(model, outputs, num_samples=4)

    model.before_meta_test(dict(support={'num_inner_steps': 2}))
    model.before_forward_support()

    # test support step
    model(img=imgs_a, gt_label=label, mode='support')
    # After the support forward every parameter must carry a ``fast`` value.
    for weight in model.parameters():
        assert weight.fast is not None

    model.before_forward_query()
    # test query step
    outputs = model(img=imgs_b, gt_label=label, mode='query')
    # One result per query image, each with 5 entries along dim 0
    # (matching ``num_classes=5`` in the head config).
    assert isinstance(outputs, list)
    assert len(outputs) == 4
    assert outputs[0].shape[0] == 5
|