EasyCV/tests/models/video_recognition/test_stdcn.py

55 lines
1.6 KiB
Python
Raw Normal View History

2023-03-02 19:13:10 +08:00
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
import torch
from easycv.models.builder import build_model
class STGCNTest(unittest.TestCase):
    """Smoke tests for a ``SkeletonGCN`` recognizer with an ``STGCN`` backbone.

    The model is built from a config dict via ``build_model``. The tests then
    check the output keys in training mode and the output probability shape
    in test mode.
    """

    # Number of action classes configured on the head. It is shared by the
    # head config, the label range in ``test_train`` and the expected output
    # width in ``test_infer`` so the three cannot drift apart.
    NUM_CLASSES = 60
    # Dummy keypoint dims after the batch axis. Presumably (C, T, V, M) =
    # (coordinate channels, frames, joints, persons); C is fed to the
    # backbone's ``in_channels``, and V=17 is assumed to match the 'coco'
    # graph layout -- TODO confirm against the STGCN backbone.
    KEYPOINT_SHAPE = (3, 300, 17, 2)

    def setUp(self):
        print(f'Testing {type(self).__name__}.{self._testMethodName}')
        # Fix the RNG so the randomly generated inputs are reproducible.
        torch.manual_seed(0)

    def _get_model(self):
        """Build and return a SkeletonGCN model from a minimal test config."""
        model_cfg = dict(
            type='SkeletonGCN',
            backbone=dict(
                type='STGCN',
                in_channels=self.KEYPOINT_SHAPE[0],
                edge_importance_weighting=True,
                graph_cfg=dict(layout='coco', strategy='spatial')),
            cls_head=dict(
                type='STGCNHead',
                num_classes=self.NUM_CLASSES,
                # presumably the backbone's output feature width -- verify
                in_channels=256,
                loss_cls=dict(type='CrossEntropyLoss')),
            train_cfg=None,
            test_cfg=None)
        return build_model(model_cfg)

    def test_train(self):
        """Training-mode forward returns the loss and accuracy entries."""
        model = self._get_model()
        model.train()

        batch_size = 2
        keypoints = torch.randn(batch_size, *self.KEYPOINT_SHAPE)
        label = torch.randint(0, self.NUM_CLASSES, (batch_size, ))

        output = model(keypoint=keypoints, label=label)

        for key in ('loss_cls', 'top1_acc', 'top5_acc'):
            self.assertIn(key, output)

    def test_infer(self):
        """Test-mode forward returns one probability row per sample."""
        model = self._get_model()
        model.eval()

        with torch.no_grad():
            keypoints = torch.randn(1, *self.KEYPOINT_SHAPE)
            output = model(keypoint=keypoints, mode='test')

        self.assertEqual(
            tuple(output['prob'].shape), (1, self.NUM_CLASSES))
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()