fast-reid/tests/model_test.py

30 lines
768 B
Python
Raw Normal View History

2019-08-20 09:36:47 +08:00
import sys
import unittest
import torch
from torch import nn
import sys
sys.path.append('.')
2019-08-23 07:49:03 +08:00
from modeling import *
2019-08-20 09:36:47 +08:00
from config import cfg
class MyTestCase(unittest.TestCase):
    """Smoke test: build the model from the global config and run one forward pass."""

    # Checkpoint location used by the original author; override it with the
    # RESNET50_IBN_A_PATH environment variable on other machines.
    DEFAULT_PRETRAIN_PATH = '/home/user01/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar'

    def test_model(self):
        """Build the IBN variant with 100 classes and push a dummy batch through it.

        The test is skipped (rather than erroring) when the pretrained
        checkpoint file cannot be found locally.
        """
        import os  # local import keeps this block self-contained

        pretrain_path = os.environ.get('RESNET50_IBN_A_PATH', self.DEFAULT_PRETRAIN_PATH)
        if not os.path.isfile(pretrain_path):
            # NOTE(review): presumably build_model loads weights from
            # cfg.MODEL.PRETRAIN_PATH; without the file there is nothing to test.
            self.skipTest('pretrained checkpoint not found: %s' % pretrain_path)

        # NOTE(review): this mutates the shared global `cfg`, so later tests in
        # the same process will see these values — consider cloning the config.
        cfg.MODEL.WITH_IBN = True
        cfg.MODEL.PRETRAIN_PATH = pretrain_path

        net = build_model(cfg, 100)
        # Random dummy batch of two 3x256x128 tensors; only checks that the
        # forward pass runs and returns something.
        y = net(torch.randn(2, 3, 256, 128))
        self.assertIsNotNone(y)
2019-08-20 09:36:47 +08:00
if __name__ == '__main__':
    # Running the file directly discovers and runs the TestCase classes
    # defined in this module.
    unittest.main(module='__main__')