# Source: EasyCV/tests/models/pose/test_top_down.py
# Snapshot: 2022-04-02 20:01:06 +08:00 (120 lines, 3.7 KiB, Python)

# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
import numpy as np
import torch
from easycv.models import TopDown
def gen_fake_data(input_shape, num_joints=17):
    """Create a synthetic batch suitable for a top-down keypoint model.

    Args:
        input_shape: image batch shape, unpacked as ``(N, C, H, W)``.
        num_joints: number of keypoints per sample. Defaults to 17.

    Returns:
        dict with, in this key order:
            ``target``: float32 zero tensor of shape
                ``(N, num_joints, H // 4, W // 4)``;
            ``target_weight``: float32 tensor of ones, ``(N, num_joints, 1)``;
            ``img_metas``: list of N independent per-image meta dicts;
            ``imgs``: uniform random tensor of ``input_shape`` with
                ``requires_grad`` enabled.
    """
    batch, channels, height, width = input_shape

    def _one_meta():
        # Built fresh per sample so no two entries share dicts or arrays.
        return {
            'img_shape': (height, width, channels),
            'center': np.array([width / 2, height / 2]),
            'scale': np.array([0.5, 0.5]),
            'flip_pairs': [],
            'image_file': 'demo.jpg',
            'image_id': 1,
            'inference_channel': np.arange(num_joints),
        }

    heatmaps = np.zeros([batch, num_joints, height // 4, width // 4],
                        dtype=np.float32)
    weights = np.ones([batch, num_joints, 1], dtype=np.float32)

    return {
        'target': torch.FloatTensor(heatmaps),
        'target_weight': torch.FloatTensor(weights),
        'img_metas': [_one_meta() for _ in range(batch)],
        'imgs': torch.rand(input_shape).requires_grad_(True),
    }
class TopDownTest(unittest.TestCase):
    """Smoke tests running ``TopDown`` forward passes in train and test mode."""

    def setUp(self):
        print(f'Testing {type(self).__name__}.{self._testMethodName}')

    def _test_topdown(self, input_cfg, model_cfg):
        """Build a ``TopDown`` model and run one train and one test forward.

        Args:
            input_cfg: dict as produced by ``gen_fake_data`` with keys
                ``imgs``, ``target``, ``target_weight`` and ``img_metas``.
                It is only read, never modified.
            model_cfg: dict with keys ``backbone``, ``keypoint_head``,
                ``train_cfg``, ``test_cfg`` and ``pretrained``.

        Returns:
            Tuple ``(train_output, test_output)`` holding what
            ``detector.forward`` returned in the default (train) call and in
            the ``mode='test'`` call.
        """
        # NOTE(review): the second positional argument is passed as None;
        # presumably TopDown's neck slot — confirm against easycv.models.
        detector = TopDown(model_cfg['backbone'], None,
                           model_cfg['keypoint_head'], model_cfg['train_cfg'],
                           model_cfg['test_cfg'], model_cfg['pretrained'])
        detector.init_weights()

        # Read (rather than pop) the inputs so the caller's dict stays intact.
        imgs = input_cfg['imgs']
        img_metas = input_cfg['img_metas']

        # Forward train: no explicit mode, targets supplied.
        train_output = detector.forward(
            img=imgs,
            target=input_cfg['target'],
            target_weight=input_cfg['target_weight'],
            img_metas=img_metas)

        # Forward test: gradient tracking is disabled for the inference pass.
        with torch.no_grad():
            test_output = detector.forward(
                imgs, mode='test', img_metas=img_metas)
        return train_output, test_output

    def test_litehrnet30(self):
        """LiteHRNet-30 backbone + simple heatmap head on one 256x256 image."""
        # 'type' is kept for config completeness; _test_topdown does not read it.
        model_cfg = dict(
            type='TopDown',
            pretrained=None,
            backbone=dict(
                type='LiteHRNet',
                in_channels=3,
                extra=dict(
                    stem=dict(
                        stem_channels=32, out_channels=32, expand_ratio=1),
                    num_stages=3,
                    stages_spec=dict(
                        num_modules=(3, 8, 3),
                        num_branches=(2, 3, 4),
                        num_blocks=(2, 2, 2),
                        module_type=('LITE', 'LITE', 'LITE'),
                        with_fuse=(True, True, True),
                        reduce_ratios=(8, 8, 8),
                        num_channels=(
                            (40, 80),
                            (40, 80, 160),
                            (40, 80, 160, 320),
                        )),
                    with_head=True,
                )),
            keypoint_head=dict(
                type='TopdownHeatmapSimpleHead',
                # matches the first branch width (40) in num_channels above
                in_channels=40,
                out_channels=17,
                num_deconv_layers=0,
                extra=dict(final_conv_kernel=1, ),
                loss_keypoint=dict(
                    type='JointsMSELoss', use_target_weight=True)),
            train_cfg=dict(),
            test_cfg=dict(
                flip_test=True,
                post_process='default',
                shift_heatmap=True,
                modulate_kernel=11))
        fake_inputs = gen_fake_data(input_shape=(1, 3, 256, 256))
        train_output, test_output = self._test_topdown(fake_inputs, model_cfg)

        self.assertIsInstance(train_output, dict)
        self.assertIn('mse_loss', train_output)
        self.assertIsInstance(test_output, dict)
        self.assertIn('preds', test_output)
if __name__ == '__main__':
    # Run this module's test cases when the file is executed directly.
    unittest.main()