fast-reid/tests/model_test.py

39 lines
1.2 KiB
Python
Raw Normal View History

2020-03-25 10:58:26 +08:00
import unittest
2020-02-10 07:38:56 +08:00
import torch
import sys
2020-03-25 10:58:26 +08:00
sys.path.append('.')
from fastreid.config import cfg
from fastreid.modeling.backbones import build_resnet_backbone
from fastreid.modeling.backbones.resnet_ibn_a import se_resnet101_ibn_a
from torch import nn
2020-02-10 07:38:56 +08:00
2020-03-25 10:58:26 +08:00
class MyTestCase(unittest.TestCase):
    """Compare fastreid's SE-ResNet101-IBN-a backbone with the reference model."""

    # Developer-specific default location of the reference checkpoint; set the
    # SE_RESNET101_IBN_A_PATH environment variable to point elsewhere.
    DEFAULT_PRETRAIN_PATH = '/export/home/lxy/.cache/torch/checkpoints/se_resnet101_ibn_a.pth.tar'

    def _assert_same_output(self, net1, net2, x, msg):
        """Run both networks on ``x`` and assert their outputs agree element-wise.

        Args:
            net1: First network (the fastreid backbone).
            net2: Second network (the reference model).
            x: Input batch fed to both networks.
            msg: Failure message identifying which mode is being checked.
        """
        # no_grad only skips autograd bookkeeping; BatchNorm still follows the
        # module's train/eval flag, so outputs are unchanged.
        with torch.no_grad():
            y1 = net1(x)
            y2 = net2(x)
        self.assertEqual(y1.shape, y2.shape, msg)
        # Compare with a tolerance instead of exact equality of sums. Equivalent
        # graphs need not produce bit-identical floats (e.g. different cuDNN
        # algorithm choices). Also, unequal outputs can still share the same sum.
        self.assertTrue(torch.allclose(y1, y2, rtol=1e-4, atol=1e-5), msg)

    def test_se_resnet101(self):
        """The fastreid SE-ResNet101-IBN-a must match the reference in train and eval mode."""
        import os

        pretrain_path = os.environ.get('SE_RESNET101_IBN_A_PATH', self.DEFAULT_PRETRAIN_PATH)
        # Skip instead of crashing on machines without a GPU or the checkpoint.
        if not torch.cuda.is_available():
            self.skipTest('CUDA is not available')
        if not os.path.isfile(pretrain_path):
            self.skipTest('pretrained checkpoint not found: ' + pretrain_path)

        # Work on a copy so the module-level default config is not mutated
        # for other tests.
        local_cfg = cfg.clone()
        local_cfg.MODEL.BACKBONE.NAME = 'resnet101'
        local_cfg.MODEL.BACKBONE.DEPTH = 101
        local_cfg.MODEL.BACKBONE.WITH_IBN = True
        local_cfg.MODEL.BACKBONE.WITH_SE = True
        local_cfg.MODEL.BACKBONE.PRETRAIN_PATH = pretrain_path

        net1 = build_resnet_backbone(local_cfg)
        net1.cuda()

        # The reference model is wrapped in DataParallel before loading,
        # presumably because the checkpoint keys carry a DataParallel
        # ('module.') prefix — TODO confirm. strict=False tolerates key
        # mismatches between the two.
        net2 = nn.DataParallel(se_resnet101_ibn_a())
        state_dict = torch.load(pretrain_path, map_location='cpu')['state_dict']
        net2.load_state_dict(state_dict, strict=False)
        net2.cuda()

        x = torch.randn(10, 3, 256, 128).cuda()
        self._assert_same_output(net1, net2, x, 'train mode problem')

        net1.eval()
        net2.eval()
        self._assert_same_output(net1, net2, x, 'eval mode problem')
2020-02-10 07:38:56 +08:00
2020-03-25 10:58:26 +08:00
if __name__ == '__main__':
    # Running this file directly executes every TestCase defined above.
    unittest.main()