EasyCV/tests/models/backbones/test_benchmark_mlp.py

26 lines
586 B
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
import torch
from easycv.models.backbones import BenchMarkMLP
class BenchMarkMLPTest(unittest.TestCase):
    """Forward-pass shape check for the ``BenchMarkMLP`` backbone."""

    def setUp(self):
        """Print the class and method name of the test about to run."""
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))

    def test_benchmark_mlp(self):
        """Run a random batch through the model and check the output shape.

        The model is built with ``feature_num=512``, has its weights
        initialised, is switched to training mode, and is then called on
        a ``(2, 512)`` tensor. The first element of the result must have
        shape ``(2, 512)``.
        """
        model = BenchMarkMLP(feature_num=512)
        model.init_weights()
        model.train()

        # Random input: 2 rows of 512 values, the same 512 passed as
        # ``feature_num`` above.
        feas = torch.randn(2, 512)
        output = model(feas)

        # Only the first element of the model output is inspected.
        self.assertEqual(output[0].shape, torch.Size([2, 512]))
# Run the tests in this module when the file is executed directly
# (e.g. ``python test_benchmark_mlp.py``); importing it does nothing.
if __name__ == '__main__':
    unittest.main()