EasyCV/tests/models/backbones/test_resnext.py

27 lines
687 B
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
import torch
from easycv.models.backbones import ResNeXt
class ResNeXtTest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
def test_resnext(self):
model = ResNeXt(
depth=50,
groups=32,
base_width=4,
out_indices=[4],
norm_cfg=dict(type='BN'))
model.init_weights()
model.train()
imgs = torch.randn(2, 3, 224, 224)
feat = model(imgs)
self.assertEqual(len(feat), 1)
self.assertEqual(feat[0].shape, torch.Size([2, 2048, 7, 7]))