EasyCV/tests/models/heads/test_cls_head.py
2022-04-02 20:01:06 +08:00

30 lines
734 B
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
import numpy as np
import torch
from easycv.models.heads import ClsHead
class ClsHeadTest(unittest.TestCase):
    """Unit tests for ``ClsHead``."""

    def setUp(self):
        """Print the qualified name of the test that is about to run."""
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))

    def test_jit_script(self):
        """Check that a TorchScript-compiled ``ClsHead`` matches eager mode.

        The head is scripted with ``torch.jit.script`` and both the eager and
        the scripted module are fed the same random input; their first outputs
        must agree within ``np.allclose`` default tolerances.
        """
        # Prefer CUDA, but fall back to CPU so the test still runs (instead of
        # raising) on hosts without a usable GPU.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        head = ClsHead(
            with_avg_pool=True, in_channels=2048, num_classes=1000).to(device)
        head_jit = torch.jit.script(head)

        batch_size = 1
        # The head is called with a list of tensors; only the first output
        # of each call is compared below.
        a = [torch.rand(batch_size, 2048, device=device)]

        # detach() drops the autograd graph so the tensors can be converted
        # to NumPy arrays for comparison.
        eager_out = head(a)[0].detach().cpu().numpy()
        scripted_out = head_jit(a)[0].detach().cpu().numpy()

        self.assertTrue(
            np.allclose(eager_out, scripted_out),
            'scripted ClsHead output differs from eager output')
# Run the tests in this module when the file is executed as a script.
if '__main__' == __name__:
    unittest.main()