mmfewshot/tests/test_classification_models/test_classification_model_utils.py
Linyiqi 37cd8a1f7d
Fix docs and add test code (#45)
* fix init

* fix test api

fix test api bug

* add Meta R-CNN and FSDetView configs

* add pr

* add MetaTestParallel comments

* add test code and fix typos

* add test code of model frozen

* update test det forward code

* update pr

* update docstrings
2021-11-01 16:33:21 +08:00

20 lines
577 B
Python

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmfewshot.classification.models.backbones import Conv4
from mmfewshot.classification.models.utils import convert_maml_module
def test_maml_module():
    """Sanity-check a backbone after ``convert_maml_module``.

    The check runs in two stages:

    * Immediately after conversion, every parameter must expose a ``fast``
      attribute that is ``None``.
    * When each parameter's ``fast`` is pointed back at the parameter itself,
      the forward output must match the output obtained before any ``fast``
      was assigned.
    """
    net = Conv4()
    maml_net = convert_maml_module(net)
    inputs = torch.randn(1, 3, 32, 32)

    # Freshly converted modules must not carry any fast weights yet.
    assert all(param.fast is None for param in maml_net.parameters())

    baseline_out = maml_net(inputs)

    # Alias every fast weight to its slow weight; the forward pass should be
    # numerically unchanged.
    for param in maml_net.parameters():
        param.fast = param

    fast_out = maml_net(inputs)
    assert torch.allclose(baseline_out, fast_out)