mirror of https://github.com/alibaba/EasyCV.git
55 lines
1.6 KiB
Python
55 lines
1.6 KiB
Python
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||
|
import unittest
|
||
|
|
||
|
import torch
|
||
|
|
||
|
from easycv.models.builder import build_model
|
||
|
|
||
|
|
||
|
class STGCNTest(unittest.TestCase):
|
||
|
|
||
|
def setUp(self):
|
||
|
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||
|
|
||
|
def _get_model(self):
|
||
|
model_cfg = dict(
|
||
|
type='SkeletonGCN',
|
||
|
backbone=dict(
|
||
|
type='STGCN',
|
||
|
in_channels=3,
|
||
|
edge_importance_weighting=True,
|
||
|
graph_cfg=dict(layout='coco', strategy='spatial')),
|
||
|
cls_head=dict(
|
||
|
type='STGCNHead',
|
||
|
num_classes=60,
|
||
|
in_channels=256,
|
||
|
loss_cls=dict(type='CrossEntropyLoss')),
|
||
|
train_cfg=None,
|
||
|
test_cfg=None)
|
||
|
model = build_model(model_cfg)
|
||
|
return model
|
||
|
|
||
|
def test_train(self):
|
||
|
model = self._get_model()
|
||
|
model.train()
|
||
|
batch_size = 2
|
||
|
keypoints = torch.randn([batch_size, 3, 300, 17, 2])
|
||
|
label = torch.randint(0, 60, (batch_size, ))
|
||
|
output = model(keypoint=keypoints, label=label)
|
||
|
self.assertIn('loss_cls', output)
|
||
|
self.assertIn('top1_acc', output)
|
||
|
self.assertIn('top5_acc', output)
|
||
|
|
||
|
def test_infer(self):
|
||
|
model = self._get_model()
|
||
|
model.eval()
|
||
|
|
||
|
with torch.no_grad():
|
||
|
keypoints = torch.randn([1, 3, 300, 17, 2])
|
||
|
output = model(keypoint=keypoints, mode='test')
|
||
|
self.assertEqual(output['prob'].shape, (1, 60))
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
|
||
|
unittest.main()
|