mirror of
https://github.com/alibaba/EasyCV.git
synced 2025-06-03 14:49:00 +08:00
support export blade model for Stgcn (#299)
* support blade for stgcn and add unittest
This commit is contained in:
parent
c062b01b3f
commit
5c33d9e2f9
@ -197,7 +197,8 @@ eval_pipelines = [
|
||||
evaluators=[dict(type='CoCoPoseTopDownEvaluator', **evaluator_args)])
|
||||
]
|
||||
checkpoint_sync_export = True
|
||||
export = dict(use_jit=False)
|
||||
export = dict(type='raw')
|
||||
# export = dict(type='jit')
|
||||
# export = dict(
|
||||
# type='blade',
|
||||
# blade_config=dict(
|
||||
|
@ -118,3 +118,15 @@ eval_pipelines = [
|
||||
|
||||
log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])
|
||||
checkpoint_config = dict(interval=1)
|
||||
|
||||
export = dict(type='raw')
|
||||
# export = dict(type='jit')
|
||||
# export = dict(
|
||||
# type='blade',
|
||||
# blade_config=dict(
|
||||
# enable_fp16=True,
|
||||
# fp16_fallback_op_ratio=0.0,
|
||||
# customize_op_black_list=[
|
||||
# 'aten::select', 'aten::index', 'aten::slice', 'aten::view',
|
||||
# 'aten::upsample', 'aten::clamp', 'aten::clone'
|
||||
# ]))
|
||||
|
@ -14,7 +14,7 @@ from mmcv.utils import Config
|
||||
from easycv.file import io
|
||||
from easycv.framework.errors import NotImplementedError, ValueError
|
||||
from easycv.models import (DINO, MOCO, SWAV, YOLOX, BEVFormer, Classification,
|
||||
MoBY, TopDown, build_model)
|
||||
MoBY, SkeletonGCN, TopDown, build_model)
|
||||
from easycv.utils.checkpoint import load_checkpoint
|
||||
from easycv.utils.misc import encode_str_to_tensor
|
||||
|
||||
@ -68,6 +68,8 @@ def export(cfg, ckpt_path, filename, model=None, **kwargs):
|
||||
_export_bevformer(model, cfg, filename, **kwargs)
|
||||
elif isinstance(model, TopDown):
|
||||
_export_pose_topdown(model, cfg, filename, **kwargs)
|
||||
elif isinstance(model, SkeletonGCN):
|
||||
_export_stgcn(model, cfg, filename, **kwargs)
|
||||
elif hasattr(cfg, 'export') and getattr(cfg.export, 'use_jit', False):
|
||||
export_jit_model(model, cfg, filename, **kwargs)
|
||||
return
|
||||
@ -98,6 +100,63 @@ def _export_common(model, cfg, filename):
|
||||
torch.save(checkpoint, ofile)
|
||||
|
||||
|
||||
def _export_jit_and_blade(model, cfg, filename, dummy_inputs, fp16=False):
|
||||
|
||||
def _trace_model():
|
||||
with torch.no_grad():
|
||||
if hasattr(model, 'forward_export'):
|
||||
model.forward = model.forward_export
|
||||
else:
|
||||
model.forward = model.forward_test
|
||||
trace_model = torch.jit.trace(
|
||||
model,
|
||||
copy.deepcopy(dummy_inputs),
|
||||
strict=False,
|
||||
check_trace=False)
|
||||
return trace_model
|
||||
|
||||
export_type = cfg.export.get('type')
|
||||
if export_type in ['jit', 'blade']:
|
||||
if fp16:
|
||||
with torch.cuda.amp.autocast():
|
||||
trace_model = _trace_model()
|
||||
else:
|
||||
trace_model = _trace_model()
|
||||
torch.jit.save(trace_model, filename + '.jit')
|
||||
else:
|
||||
raise NotImplementedError(f'Not support export type {export_type}!')
|
||||
|
||||
if export_type == 'jit':
|
||||
return
|
||||
|
||||
blade_config = cfg.export.get('blade_config')
|
||||
|
||||
from easycv.toolkit.blade import blade_env_assert, blade_optimize
|
||||
assert blade_env_assert()
|
||||
|
||||
def _get_blade_model():
|
||||
blade_model = blade_optimize(
|
||||
speed_test_model=model,
|
||||
model=trace_model,
|
||||
inputs=copy.deepcopy(dummy_inputs),
|
||||
blade_config=blade_config,
|
||||
static_opt=False,
|
||||
min_num_nodes=None,
|
||||
check_inputs=False,
|
||||
fp16=fp16)
|
||||
return blade_model
|
||||
|
||||
# optimize model with blade
|
||||
if fp16:
|
||||
with torch.cuda.amp.autocast():
|
||||
blade_model = _get_blade_model()
|
||||
else:
|
||||
blade_model = _get_blade_model()
|
||||
|
||||
with io.open(filename + '.blade', 'wb') as ofile:
|
||||
torch.jit.save(blade_model, ofile)
|
||||
|
||||
|
||||
def _export_cls(model, cfg, filename):
|
||||
""" export cls (cls & metric learning)model and preprocess config
|
||||
|
||||
@ -540,7 +599,7 @@ def export_jit_model(model, cfg, filename):
|
||||
torch.jit.save(model_jit, ofile)
|
||||
|
||||
|
||||
def _export_bevformer(model, cfg, filename, fp16=False):
|
||||
def _export_bevformer(model, cfg, filename, fp16=False, dummy_inputs=None):
|
||||
if not cfg.adapt_jit:
|
||||
raise ValueError(
|
||||
'"cfg.adapt_jit" must be True when export jit trace or blade model.'
|
||||
@ -578,60 +637,10 @@ def _export_bevformer(model, cfg, filename, fp16=False):
|
||||
}
|
||||
return img, img_metas
|
||||
|
||||
dummy_inputs = _dummy_inputs()
|
||||
if dummy_inputs is None:
|
||||
dummy_inputs = _dummy_inputs()
|
||||
|
||||
def _trace_model():
|
||||
with torch.no_grad():
|
||||
model.forward = model.forward_export
|
||||
trace_model = torch.jit.trace(
|
||||
model, copy.deepcopy(dummy_inputs), check_trace=False)
|
||||
return trace_model
|
||||
|
||||
export_type = cfg.export.get('type')
|
||||
if export_type in ['jit', 'blade']:
|
||||
if fp16:
|
||||
with torch.cuda.amp.autocast():
|
||||
trace_model = _trace_model()
|
||||
else:
|
||||
trace_model = _trace_model()
|
||||
torch.jit.save(trace_model, filename + '.jit')
|
||||
else:
|
||||
raise NotImplementedError(f'Not support export type {export_type}!')
|
||||
|
||||
if export_type == 'jit':
|
||||
return
|
||||
|
||||
blade_config = cfg.export.get('blade_config')
|
||||
|
||||
from easycv.toolkit.blade import blade_env_assert, blade_optimize
|
||||
assert blade_env_assert()
|
||||
|
||||
def _get_blade_model():
|
||||
blade_model = blade_optimize(
|
||||
speed_test_model=model,
|
||||
model=trace_model,
|
||||
inputs=copy.deepcopy(dummy_inputs),
|
||||
blade_config=blade_config,
|
||||
static_opt=False,
|
||||
min_num_nodes=None, # 50
|
||||
check_inputs=False,
|
||||
fp16=fp16)
|
||||
return blade_model
|
||||
|
||||
# optimize model with blade
|
||||
if fp16:
|
||||
with torch.cuda.amp.autocast():
|
||||
blade_model = _get_blade_model()
|
||||
else:
|
||||
blade_model = _get_blade_model()
|
||||
|
||||
# save blade code and graph
|
||||
# with io.open(filename + '.blade.code.py', 'w') as ofile:
|
||||
# ofile.write(blade_model.forward.code)
|
||||
# with io.open(filename + '.blade.graph.txt', 'w') as ofile:
|
||||
# ofile.write(blade_model.forward.graph)
|
||||
with io.open(filename + '.blade', 'wb') as ofile:
|
||||
torch.jit.save(blade_model, ofile)
|
||||
_export_jit_and_blade(model, cfg, filename, dummy_inputs, fp16=fp16)
|
||||
|
||||
|
||||
def _export_pose_topdown(model, cfg, filename, fp16=False, dummy_inputs=None):
|
||||
@ -672,53 +681,26 @@ def _export_pose_topdown(model, cfg, filename, fp16=False, dummy_inputs=None):
|
||||
if dummy_inputs is None:
|
||||
dummy_inputs = _dummy_inputs(cfg)
|
||||
|
||||
def _trace_model():
|
||||
with torch.no_grad():
|
||||
model.forward = model.forward_export
|
||||
trace_model = torch.jit.trace(
|
||||
model, copy.deepcopy(dummy_inputs), strict=False)
|
||||
return trace_model
|
||||
_export_jit_and_blade(model, cfg, filename, dummy_inputs, fp16=fp16)
|
||||
|
||||
export_type = cfg.export.get('type')
|
||||
if export_type in ['jit', 'blade']:
|
||||
if fp16:
|
||||
with torch.cuda.amp.autocast():
|
||||
trace_model = _trace_model()
|
||||
else:
|
||||
trace_model = _trace_model()
|
||||
torch.jit.save(trace_model, filename + '.jit')
|
||||
else:
|
||||
raise NotImplementedError(f'Not support export type {export_type}!')
|
||||
|
||||
if export_type == 'jit':
|
||||
return
|
||||
def _export_stgcn(model, cfg, filename, fp16=False, dummy_inputs=None):
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
model = copy.deepcopy(model)
|
||||
model.eval()
|
||||
model.to(device)
|
||||
|
||||
blade_config = cfg.export.get('blade_config')
|
||||
if hasattr(cfg, 'export') and getattr(cfg.export, 'type', 'raw') == 'raw':
|
||||
return _export_common(model, cfg, filename)
|
||||
|
||||
from easycv.toolkit.blade import blade_env_assert, blade_optimize
|
||||
assert blade_env_assert()
|
||||
def _dummy_inputs(device):
|
||||
keypoints = torch.randn([1, 3, 300, 17, 2]).to(device)
|
||||
return (keypoints, )
|
||||
|
||||
def _get_blade_model():
|
||||
blade_model = blade_optimize(
|
||||
speed_test_model=model,
|
||||
model=trace_model,
|
||||
inputs=copy.deepcopy(dummy_inputs),
|
||||
blade_config=blade_config,
|
||||
static_opt=False,
|
||||
min_num_nodes=None,
|
||||
check_inputs=False,
|
||||
fp16=fp16)
|
||||
return blade_model
|
||||
if dummy_inputs is None:
|
||||
dummy_inputs = _dummy_inputs(device)
|
||||
|
||||
# optimize model with blade
|
||||
if fp16:
|
||||
with torch.cuda.amp.autocast():
|
||||
blade_model = _get_blade_model()
|
||||
else:
|
||||
blade_model = _get_blade_model()
|
||||
|
||||
with io.open(filename + '.blade', 'wb') as ofile:
|
||||
torch.jit.save(blade_model, ofile)
|
||||
_export_jit_and_blade(model, cfg, filename, dummy_inputs, fp16=fp16)
|
||||
|
||||
|
||||
def replace_syncbn(backbone_cfg):
|
||||
|
@ -339,16 +339,19 @@ class YoloXPredictor(DetectionPredictor):
|
||||
nms_thresh=None,
|
||||
test_conf=None,
|
||||
input_processor_threads=8,
|
||||
mode='BGR'):
|
||||
mode='BGR',
|
||||
model_type=None):
|
||||
self.max_det = max_det
|
||||
self.use_trt_efficientnms = use_trt_efficientnms
|
||||
|
||||
if model_path.endswith('jit'):
|
||||
self.model_type = 'jit'
|
||||
elif model_path.endswith('blade'):
|
||||
self.model_type = 'blade'
|
||||
else:
|
||||
self.model_type = 'raw'
|
||||
self.model_type = model_type
|
||||
if self.model_type is None:
|
||||
if model_path.endswith('jit'):
|
||||
self.model_type = 'jit'
|
||||
elif model_path.endswith('blade'):
|
||||
self.model_type = 'blade'
|
||||
else:
|
||||
self.model_type = 'raw'
|
||||
assert self.model_type in ['raw', 'jit', 'blade']
|
||||
|
||||
if self.model_type == 'blade' or self.use_trt_efficientnms:
|
||||
import torch_blade
|
||||
|
@ -306,6 +306,9 @@ class PoseTopDownOutputProcessor(OutputProcessor):
|
||||
return output
|
||||
|
||||
|
||||
# TODO: Fix when multi people are detected in each sample,
|
||||
# all the people results will be passed to the pose model,
|
||||
# resulting in a dynamic batch_size, which is not supported by jit script model.
|
||||
@PREDICTORS.register_module()
|
||||
class PoseTopDownPredictor(PredictorV2):
|
||||
"""Pose topdown predictor.
|
||||
@ -336,6 +339,7 @@ class PoseTopDownPredictor(PredictorV2):
|
||||
save_results=False,
|
||||
save_path=None,
|
||||
mode='BGR',
|
||||
model_type=None,
|
||||
*args,
|
||||
**kwargs):
|
||||
assert batch_size == 1, 'Only support batch_size=1 now!'
|
||||
@ -343,15 +347,18 @@ class PoseTopDownPredictor(PredictorV2):
|
||||
self.bbox_thr = bbox_thr
|
||||
self.detection_predictor_config = detection_predictor_config
|
||||
|
||||
if model_path.endswith('jit'):
|
||||
assert config_file is not None
|
||||
self.model_type = 'jit'
|
||||
elif model_path.endswith('blade'):
|
||||
import torch_blade
|
||||
assert config_file is not None
|
||||
self.model_type = 'blade'
|
||||
else:
|
||||
self.model_type = 'raw'
|
||||
self.model_type = model_type
|
||||
if self.model_type is None:
|
||||
if model_path.endswith('jit'):
|
||||
assert config_file is not None
|
||||
self.model_type = 'jit'
|
||||
elif model_path.endswith('blade'):
|
||||
import torch_blade
|
||||
assert config_file is not None
|
||||
self.model_type = 'blade'
|
||||
else:
|
||||
self.model_type = 'raw'
|
||||
assert self.model_type in ['raw', 'jit', 'blade']
|
||||
|
||||
super(PoseTopDownPredictor, self).__init__(
|
||||
model_path,
|
||||
|
@ -6,6 +6,7 @@ import torch
|
||||
from easycv.datasets.registry import PIPELINES
|
||||
from easycv.file import io
|
||||
from easycv.models.builder import build_model
|
||||
from easycv.utils.checkpoint import load_checkpoint
|
||||
from easycv.utils.mmlab_utils import (dynamic_adapt_for_mmlab,
|
||||
remove_adapt_for_mmlab)
|
||||
from easycv.utils.registry import build_from_cfg
|
||||
@ -262,8 +263,22 @@ class STGCNPredictor(PredictorV2):
|
||||
pipelines=None,
|
||||
input_processor_threads=8,
|
||||
mode='RGB',
|
||||
model_type=None,
|
||||
*args,
|
||||
**kwargs):
|
||||
self.model_type = model_type
|
||||
if self.model_type is None:
|
||||
if model_path.endswith('jit'):
|
||||
assert config_file is not None
|
||||
self.model_type = 'jit'
|
||||
elif model_path.endswith('blade'):
|
||||
import torch_blade
|
||||
assert config_file is not None
|
||||
self.model_type = 'blade'
|
||||
else:
|
||||
self.model_type = 'raw'
|
||||
assert self.model_type in ['raw', 'jit', 'blade']
|
||||
|
||||
super(STGCNPredictor, self).__init__(
|
||||
model_path,
|
||||
config_file=config_file,
|
||||
@ -301,6 +316,35 @@ class STGCNPredictor(PredictorV2):
|
||||
|
||||
self.label_map = [i.strip() for i in class_list]
|
||||
|
||||
def _build_model(self):
|
||||
if self.model_type != 'raw':
|
||||
with io.open(self.model_path, 'rb') as infile:
|
||||
model = torch.jit.load(infile, self.device)
|
||||
else:
|
||||
model = super()._build_model()
|
||||
return model
|
||||
|
||||
def prepare_model(self):
|
||||
"""Build model from config file by default.
|
||||
If the model is not loaded from a configuration file, e.g. torch jit model, you need to reimplement it.
|
||||
"""
|
||||
model = self._build_model()
|
||||
model.to(self.device)
|
||||
model.eval()
|
||||
if self.model_type == 'raw':
|
||||
load_checkpoint(model, self.model_path, map_location='cpu')
|
||||
return model
|
||||
|
||||
def model_forward(self, inputs):
|
||||
if self.model_type == 'raw':
|
||||
return super().model_forward(inputs)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
keypoint = inputs['keypoint'].to(self.device)
|
||||
result = self.model(keypoint)
|
||||
|
||||
return result
|
||||
|
||||
def get_input_processor(self):
|
||||
return STGCNInputProcessor(
|
||||
self.cfg,
|
||||
|
@ -187,6 +187,9 @@ class WholeBodyKptsOutputProcessor(OutputProcessor):
|
||||
return output
|
||||
|
||||
|
||||
# TODO: Fix when multi people are detected in each sample,
|
||||
# all the people results will be passed to the pose model,
|
||||
# resulting in a dynamic batch_size, which is not supported by jit script model.
|
||||
@PREDICTORS.register_module()
|
||||
class WholeBodyKeypointsPredictor(PredictorV2):
|
||||
"""WholeBodyKeypointsPredictor
|
||||
@ -213,17 +216,21 @@ class WholeBodyKeypointsPredictor(PredictorV2):
|
||||
save_path=None,
|
||||
bbox_thr=None,
|
||||
mode='BGR',
|
||||
model_type=None,
|
||||
*args,
|
||||
**kwargs):
|
||||
if model_path.endswith('jit'):
|
||||
assert config_file is not None
|
||||
self.model_type = 'jit'
|
||||
elif model_path.endswith('blade'):
|
||||
import torch_blade
|
||||
assert config_file is not None
|
||||
self.model_type = 'blade'
|
||||
else:
|
||||
self.model_type = 'raw'
|
||||
self.model_type = model_type
|
||||
if self.model_type is None:
|
||||
if model_path.endswith('jit'):
|
||||
assert config_file is not None
|
||||
self.model_type = 'jit'
|
||||
elif model_path.endswith('blade'):
|
||||
import torch_blade
|
||||
assert config_file is not None
|
||||
self.model_type = 'blade'
|
||||
else:
|
||||
self.model_type = 'raw'
|
||||
assert self.model_type in ['raw', 'jit', 'blade']
|
||||
|
||||
super(WholeBodyKeypointsPredictor, self).__init__(
|
||||
model_path,
|
||||
|
@ -8,7 +8,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tests.ut_config import (IMAGENET_LABEL_TXT,
|
||||
from tests.ut_config import (BASE_LOCAL_PATH, IMAGENET_LABEL_TXT,
|
||||
PRETRAINED_MODEL_BEVFORMER_BASE,
|
||||
PRETRAINED_MODEL_MOCO, PRETRAINED_MODEL_RESNET50,
|
||||
PRETRAINED_MODEL_YOLOXS_EXPORT)
|
||||
@ -165,6 +165,51 @@ class ModelExportTest(unittest.TestCase):
|
||||
|
||||
self.assertTrue(os.path.exists(filename + '.jit'))
|
||||
|
||||
def test_export_topdown_jit(self):
|
||||
ckpt_path = os.path.join(
|
||||
BASE_LOCAL_PATH,
|
||||
'pretrained_models/pose/hrnet/pose_hrnet_epoch_210_export.pt')
|
||||
|
||||
easycv_dir = os.path.dirname(easycv.__file__)
|
||||
if os.path.exists(os.path.join(easycv_dir, 'configs')):
|
||||
config_dir = os.path.join(easycv_dir, 'configs')
|
||||
else:
|
||||
config_dir = os.path.join(os.path.dirname(easycv_dir), 'configs')
|
||||
config_file = os.path.join(config_dir,
|
||||
'pose/hrnet_w48_coco_256x192_udp.py')
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cfg = mmcv_config_fromfile(config_file)
|
||||
cfg.export.type = 'jit'
|
||||
|
||||
filename = os.path.join(tmpdir, 'model.pth')
|
||||
export(cfg, ckpt_path, filename, fp16=False)
|
||||
|
||||
self.assertTrue(os.path.exists(filename + '.jit'))
|
||||
|
||||
def test_export_stgcn_jit(self):
|
||||
ckpt_path = os.path.join(
|
||||
BASE_LOCAL_PATH,
|
||||
'pretrained_models/video/stgcn/stgcn_80e_ntu60_xsub.pth')
|
||||
|
||||
easycv_dir = os.path.dirname(easycv.__file__)
|
||||
if os.path.exists(os.path.join(easycv_dir, 'configs')):
|
||||
config_dir = os.path.join(easycv_dir, 'configs')
|
||||
else:
|
||||
config_dir = os.path.join(os.path.dirname(easycv_dir), 'configs')
|
||||
config_file = os.path.join(
|
||||
config_dir,
|
||||
'video_recognition/stgcn/stgcn_80e_ntu60_xsub_keypoint.py')
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cfg = mmcv_config_fromfile(config_file)
|
||||
cfg.export.type = 'jit'
|
||||
|
||||
filename = os.path.join(tmpdir, 'model.pth')
|
||||
export(cfg, ckpt_path, filename, fp16=False)
|
||||
|
||||
self.assertTrue(os.path.exists(filename + '.jit'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
@ -7,7 +7,7 @@ import cv2
|
||||
import numpy as np
|
||||
from numpy.testing import assert_array_almost_equal
|
||||
from PIL import Image
|
||||
from tests.ut_config import (POSE_DATA_SMALL_COCO_LOCAL,
|
||||
from tests.ut_config import (BASE_LOCAL_PATH, POSE_DATA_SMALL_COCO_LOCAL,
|
||||
PRETRAINED_MODEL_POSE_HRNET_EXPORT,
|
||||
PRETRAINED_MODEL_YOLOXS_EXPORT, TEST_IMAGES_DIR)
|
||||
|
||||
@ -20,23 +20,12 @@ class PoseTopDownPredictorTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||
|
||||
def test_pose_topdown(self):
|
||||
detection_model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
|
||||
pose_model_path = PRETRAINED_MODEL_POSE_HRNET_EXPORT
|
||||
def _base_test(self, predictor):
|
||||
img1 = os.path.join(POSE_DATA_SMALL_COCO_LOCAL,
|
||||
'images/000000067078.jpg')
|
||||
img2 = os.path.join(TEST_IMAGES_DIR, 'crowdpose_100024.jpg')
|
||||
input_data_list = [img1, img2]
|
||||
|
||||
predictor = PoseTopDownPredictor(
|
||||
model_path=pose_model_path,
|
||||
detection_predictor_config=dict(
|
||||
type='YoloXPredictor',
|
||||
model_path=detection_model_path,
|
||||
score_thresh=0.5),
|
||||
cat_id=0,
|
||||
batch_size=1)
|
||||
|
||||
results = predictor(input_data_list)
|
||||
self.assertEqual(len(results), 2)
|
||||
|
||||
@ -117,6 +106,44 @@ class PoseTopDownPredictorTest(unittest.TestCase):
|
||||
cv2.imwrite(tmp_save_path, vis_result)
|
||||
assert os.path.exists(tmp_save_path)
|
||||
|
||||
def test_pose_topdown(self):
|
||||
detection_model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
|
||||
pose_model_path = PRETRAINED_MODEL_POSE_HRNET_EXPORT
|
||||
|
||||
predictor = PoseTopDownPredictor(
|
||||
model_path=pose_model_path,
|
||||
detection_predictor_config=dict(
|
||||
type='YoloXPredictor',
|
||||
model_path=detection_model_path,
|
||||
score_thresh=0.5),
|
||||
cat_id=0,
|
||||
batch_size=1)
|
||||
|
||||
self._base_test(predictor)
|
||||
|
||||
def test_pose_topdown_jit(self):
|
||||
detection_model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
|
||||
pose_model_path = os.path.join(
|
||||
BASE_LOCAL_PATH,
|
||||
'pretrained_models/pose/hrnet/pose_hrnet_epoch_210_export.pth.jit')
|
||||
|
||||
config_file = 'configs/pose/hrnet_w48_coco_256x192_udp.py'
|
||||
|
||||
predictor = PoseTopDownPredictor(
|
||||
model_path=pose_model_path,
|
||||
config_file=config_file,
|
||||
detection_predictor_config=dict(
|
||||
type='YoloXPredictor',
|
||||
model_path=detection_model_path,
|
||||
score_thresh=0.5),
|
||||
cat_id=0,
|
||||
batch_size=1)
|
||||
|
||||
img = os.path.join(TEST_IMAGES_DIR, 'im00025.png')
|
||||
input_data_list = [img, img]
|
||||
results = predictor(input_data_list)
|
||||
self.assertEqual(len(results), 2)
|
||||
|
||||
|
||||
class TorchPoseTopDownPredictorWithDetectorTest(unittest.TestCase):
|
||||
|
||||
|
@ -84,6 +84,35 @@ class STGCNPredictorTest(unittest.TestCase):
|
||||
self.assertIn('class_name', results)
|
||||
self.assertEqual(len(results['class_probs']), 60)
|
||||
|
||||
def test_jit(self):
|
||||
checkpoint = os.path.join(
|
||||
BASE_LOCAL_PATH,
|
||||
'pretrained_models/video/stgcn/stgcn_80e_ntu60_xsub.pth.jit')
|
||||
|
||||
config_file = 'configs/video_recognition/stgcn/stgcn_80e_ntu60_xsub_keypoint.py'
|
||||
predict_op = STGCNPredictor(
|
||||
model_path=checkpoint, config_file=config_file)
|
||||
|
||||
h, w = 480, 853
|
||||
total_frames = 20
|
||||
num_person = 2
|
||||
inp = dict(
|
||||
frame_dir='',
|
||||
label=-1,
|
||||
img_shape=(h, w),
|
||||
original_shape=(h, w),
|
||||
start_index=0,
|
||||
modality='Pose',
|
||||
total_frames=total_frames,
|
||||
keypoint=np.random.random((num_person, total_frames, 17, 2)),
|
||||
keypoint_score=np.random.random((num_person, total_frames, 17)),
|
||||
)
|
||||
|
||||
results = predict_op([inp])[0]
|
||||
self.assertIn('class', results)
|
||||
self.assertIn('class_name', results)
|
||||
self.assertEqual(len(results['class_probs']), 60)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user