mirror of
https://github.com/open-mmlab/mmfewshot.git
synced 2025-06-03 14:49:43 +08:00
* fix init * fix test api fix test api bug * add metarcnn fsdetview config * add pr * add metatestparallel comments * add test code and fix typos * add test code of model frozen * update test det forward code * update pr * update doc str
62 lines
1.6 KiB
Python
62 lines
1.6 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch
|
|
|
|
from mmfewshot.classification.models.backbones.conv4 import (Conv4,
|
|
Conv4NoPool,
|
|
ConvBlock,
|
|
ConvNet)
|
|
|
|
|
|
def test_convblock():
    """Check the output shape of ConvBlock without and with pooling."""
    channels_in, channels_out, side = 3, 64, 56

    # Each case pairs the third ConvBlock argument with the spatial size the
    # output is expected to have: unchanged w/o pool, halved w/ pool.
    cases = (
        (False, side),
        (True, side // 2),
    )
    for with_pool, expected_side in cases:
        layer = ConvBlock(channels_in, channels_out, with_pool)
        inputs = torch.randn(1, channels_in, side, side)
        outputs = layer(inputs)
        assert outputs.shape == (1, channels_out, expected_side,
                                 expected_side)
|
|
|
|
|
|
def _check_forward_shape(model, expected_shape):
    """Run one training-mode forward pass and check the output shape.

    The model's weights are initialised via ``init_weights()`` and the model
    is switched to training mode before a random ``(1, 3, 84, 84)`` input is
    fed through it.

    Args:
        model: Backbone under test; must provide ``init_weights()``,
            ``train()`` and be callable on a tensor.
        expected_shape (tuple[int]): Shape the forward output must have.
    """
    model.init_weights()
    model.train()

    img = torch.randn(1, 3, 84, 84)
    feat = model(img)
    assert feat.shape == expected_shape


def test_convnet():
    """Test the forward output shapes of ConvNet, Conv4 and Conv4NoPool."""
    # ConvNet w/ flatten forward: 1600 == 64 * 5 * 5, i.e. the flattened
    # form of the un-flattened output checked in the next case.
    _check_forward_shape(
        ConvNet(4, (0, 1, 2, 3), (0, 1, 2, 3), True), (1, 1600))

    # ConvNet w/o flatten forward
    _check_forward_shape(
        ConvNet(4, (0, 1, 2, 3), (0, 1, 2, 3), False), (1, 64, 5, 5))

    # Conv4 forward
    _check_forward_shape(Conv4(), (1, 1600))

    # Conv4NoPool forward
    _check_forward_shape(Conv4NoPool(), (1, 64, 19, 19))
|