63 lines
1.8 KiB
Python
63 lines
1.8 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import copy
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from mmfewshot.classification.models import CLASSIFIERS
|
|
from mmfewshot.classification.utils import MetaTestParallel
|
|
|
|
|
|
@pytest.mark.parametrize('classifier',
|
|
['Baseline', 'BaselinePlus', 'NegMargin'])
|
|
def test_image_classifier(classifier):
|
|
model_cfg = dict(type=classifier, backbone=dict(type='Conv4'))
|
|
|
|
imgs = torch.randn(4, 3, 84, 84)
|
|
feats = torch.randn(4, 1600)
|
|
label = torch.LongTensor([0, 1, 2, 3])
|
|
|
|
model_cfg_ = copy.deepcopy(model_cfg)
|
|
model = CLASSIFIERS.build(model_cfg_)
|
|
|
|
# test property
|
|
assert not model.with_neck
|
|
assert model.with_head
|
|
|
|
assert model.device
|
|
assert model.get_device()
|
|
|
|
model = MetaTestParallel(copy.deepcopy(model))
|
|
|
|
# test extract features
|
|
outputs = model(**{'img': imgs, 'gt_label': label, 'mode': 'extract_feat'})
|
|
assert outputs.size(0) == 4
|
|
|
|
model.before_meta_test(dict())
|
|
model.before_forward_support()
|
|
|
|
# test support step
|
|
outputs = model(**{'img': imgs, 'gt_label': label, 'mode': 'support'})
|
|
assert outputs['loss'].item() > 0
|
|
|
|
# test support step
|
|
outputs = model(**{'feats': feats, 'gt_label': label, 'mode': 'support'})
|
|
assert outputs['loss'].item() > 0
|
|
|
|
model.before_forward_query()
|
|
# test query step
|
|
outputs = model(**{'img': imgs, 'mode': 'query'})
|
|
assert isinstance(outputs, list)
|
|
assert len(outputs) == 4
|
|
assert outputs[0].shape[0] == 5
|
|
|
|
# test query step
|
|
outputs = model(**{'feats': feats, 'mode': 'query'})
|
|
assert isinstance(outputs, list)
|
|
assert len(outputs) == 4
|
|
assert outputs[0].shape[0] == 5
|
|
|
|
with pytest.raises(ValueError):
|
|
# test extract features
|
|
outputs = model(**{'img': imgs, 'gt_label': label, 'mode': 'test'})
|