Linyiqi 72129d989a
Fix 2nd round comments (#49)
* fix base class name & sampler length & comments

* add test code
2021-11-12 23:28:00 +08:00

86 lines
2.2 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmfewshot.classification.models.backbones.resnet12 import (BasicBlock,
ResNet12)
def test_basicblock():
    """Test BasicBlock output shapes with and without downsample/dropout.

    Every case runs a stride-2 block on a 56x56 input and expects a
    64-channel, 28x28 output.
    """

    def make_downsample():
        # 1x1 projection that maps a 3-channel input to the block's 64
        # output channels. A fresh module is built per case so blocks do not
        # share parameters.
        return nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(64),
        )

    # Each case: (label, in_channels, downsample, drop_rate, drop_block,
    #             block_size).
    # NOTE(review): the positional BasicBlock arguments used below are assumed
    # to be (inplanes, planes, stride, downsample, drop_rate, drop_block,
    # block_size) -- confirm against BasicBlock's signature.
    cases = [
        ('w/o downsample', 64, None, 0, False, 1),
        ('w/ downsample', 3, make_downsample(), 0, False, 1),
        ('w/ downsample and dropout', 3, make_downsample(), 0.1, False, 1),
        ('w/ downsample and drop block', 3, make_downsample(), 0.1, True, 5),
    ]
    for (label, in_channels, downsample, drop_rate, drop_block,
         block_size) in cases:
        block = BasicBlock(in_channels, 64, 2, downsample, drop_rate,
                           drop_block, block_size)
        x = torch.randn(1, in_channels, 56, 56)
        x_out = block(x)
        assert x_out.shape == (1, 64, 28, 28), label
def test_resnet():
    """Test ResNet12 backbone output shapes for each pooling/flatten mode."""
    # (constructor kwargs, expected output shape) for a 1x3x224x224 input.
    # The flattened un-pooled size 125440 equals 640 * 14 * 14, i.e. a
    # 640-channel 14x14 feature map before pooling.
    cases = [
        (dict(), (1, 640)),  # default: avg pool + flatten
        (dict(with_avgpool=False), (1, 125440)),  # w/o avg pool
        (dict(flatten=False), (1, 640, 1, 1)),  # w/o flatten
        # w/o avg pool and w/o flatten: the raw feature map. This replaces a
        # verbatim duplicate of the "w/o flatten" case.
        (dict(with_avgpool=False, flatten=False), (1, 640, 14, 14)),
    ]
    for kwargs, expected_shape in cases:
        model = ResNet12(**kwargs)
        model.init_weights()
        model.train()
        img = torch.randn(1, 3, 224, 224)
        feat = model(img)
        assert feat.shape == expected_shape, kwargs
def test_drop_block():
    """Test ResNet12 with a non-zero ``drop_rate`` in train and eval mode."""
    model = ResNet12(drop_rate=0.5)
    model.init_weights()
    img = torch.randn(1, 3, 224, 224)

    # Train mode: drop_rate presumably enables dropout / drop block inside
    # the residual blocks, so only the output shape is checked here.
    model.train()
    feat = model(img)
    assert feat.shape == (1, 640)

    # Eval mode: stochastic regularisation is expected to be disabled, so two
    # forward passes on the same input should produce the same features.
    model.eval()
    with torch.no_grad():
        feat_a = model(img)
        feat_b = model(img)
    assert feat_a.shape == (1, 640)
    assert torch.allclose(feat_a, feat_b)