# mmfewshot/tests/test_classification_runtime/test_meta_test_parallel.py
# Copyright (c) OpenMMLab. All rights reserved.
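"""Tests for MetaTestParallel.

Wraps the Baseline, BaselinePlus and NegMargin classifiers and checks the
``extract_feat``, ``support`` and ``query`` forward modes, plus the error
path for an unsupported mode.
"""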
import copy

import pytest
import torch

from mmfewshot.classification.models import CLASSIFIERS
from mmfewshot.classification.utils import MetaTestParallel


@pytest.mark.parametrize('classifier',
                         ['Baseline', 'BaselinePlus', 'NegMargin'])
def test_image_classifier(classifier):
    model_cfg = dict(type=classifier, backbone=dict(type='Conv4'))
    imgs = torch.randn(4, 3, 84, 84)
    feats = torch.randn(4, 1600)
    label = torch.LongTensor([0, 1, 2, 3])

    model_cfg_ = copy.deepcopy(model_cfg)
    model = CLASSIFIERS.build(model_cfg_)
    # test properties
    assert not model.with_neck
    assert model.with_head
    assert model.device
    assert model.get_device()

    model = MetaTestParallel(copy.deepcopy(model))
    # test feature extraction
    outputs = model(**{'img': imgs, 'gt_label': label, 'mode': 'extract_feat'})
    assert outputs.size(0) == 4

    model.before_meta_test(dict())
    model.before_forward_support()
    # test support step with images
    outputs = model(**{'img': imgs, 'gt_label': label, 'mode': 'support'})
    assert outputs['loss'].item() > 0
    # test support step with pre-computed features
    outputs = model(**{'feats': feats, 'gt_label': label, 'mode': 'support'})
    assert outputs['loss'].item() > 0

    model.before_forward_query()
    # test query step with images
    outputs = model(**{'img': imgs, 'mode': 'query'})
    assert isinstance(outputs, list)
    assert len(outputs) == 4
    assert outputs[0].shape[0] == 5
    # test query step with pre-computed features
    outputs = model(**{'feats': feats, 'mode': 'query'})
    assert isinstance(outputs, list)
    assert len(outputs) == 4
    assert outputs[0].shape[0] == 5

    # an unsupported forward mode should raise ValueError
    with pytest.raises(ValueError):
        model(**{'img': imgs, 'gt_label': label, 'mode': 'test'})
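
# A minimal way to run this test locally (assuming pytest, torch and
# mmfewshot are installed; adjust the path to match your checkout):
#   pytest mmfewshot/tests/test_classification_runtime/test_meta_test_parallel.py -v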