Support blade model export for STGCN (#299)

* Support jit/blade export for STGCN and add unit tests
Authored by Cathy0908 on 2023-03-06 10:19:39 +08:00, committed by GitHub
commit 5c33d9e2f9 (parent c062b01b3f)
10 changed files with 294 additions and 137 deletions


@@ -197,7 +197,8 @@ eval_pipelines = [
         evaluators=[dict(type='CoCoPoseTopDownEvaluator', **evaluator_args)])
 ]
 checkpoint_sync_export = True
-export = dict(use_jit=False)
+export = dict(type='raw')
+# export = dict(type='jit')
 # export = dict(
 #     type='blade',
 #     blade_config=dict(


@@ -118,3 +118,15 @@ eval_pipelines = [
 log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])
 checkpoint_config = dict(interval=1)
+
+export = dict(type='raw')
+# export = dict(type='jit')
+# export = dict(
+#     type='blade',
+#     blade_config=dict(
+#         enable_fp16=True,
+#         fp16_fallback_op_ratio=0.0,
+#         customize_op_black_list=[
+#             'aten::select', 'aten::index', 'aten::slice', 'aten::view',
+#             'aten::upsample', 'aten::clamp', 'aten::clone'
+#         ]))
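The new export key is consumed by the export() API changed in the next diff. A minimal sketch of the programmatic flow, mirroring the unit tests added at the end of this commit (the checkpoint path is a placeholder, and the import locations are assumed):

# Minimal sketch: drive the export from Python, mirroring the new unit tests.
from easycv.apis.export import export
from easycv.utils.config_tools import mmcv_config_fromfile

cfg = mmcv_config_fromfile(
    'configs/video_recognition/stgcn/stgcn_80e_ntu60_xsub_keypoint.py')
cfg.export.type = 'jit'  # one of 'raw' (default), 'jit', 'blade'
export(cfg, 'stgcn_80e_ntu60_xsub.pth', 'export_model.pth', fp16=False)
# 'jit'   -> writes export_model.pth.jit
# 'blade' -> writes export_model.pth.blade (requires torch_blade)
# 'raw'   -> writes a plain checkpoint via _export_common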


@@ -14,7 +14,7 @@ from mmcv.utils import Config
 from easycv.file import io
 from easycv.framework.errors import NotImplementedError, ValueError
 from easycv.models import (DINO, MOCO, SWAV, YOLOX, BEVFormer, Classification,
-                           MoBY, TopDown, build_model)
+                           MoBY, SkeletonGCN, TopDown, build_model)
 from easycv.utils.checkpoint import load_checkpoint
 from easycv.utils.misc import encode_str_to_tensor
@@ -68,6 +68,8 @@ def export(cfg, ckpt_path, filename, model=None, **kwargs):
         _export_bevformer(model, cfg, filename, **kwargs)
     elif isinstance(model, TopDown):
         _export_pose_topdown(model, cfg, filename, **kwargs)
+    elif isinstance(model, SkeletonGCN):
+        _export_stgcn(model, cfg, filename, **kwargs)
     elif hasattr(cfg, 'export') and getattr(cfg.export, 'use_jit', False):
         export_jit_model(model, cfg, filename, **kwargs)
         return
@@ -98,6 +100,63 @@ def _export_common(model, cfg, filename):
         torch.save(checkpoint, ofile)


+def _export_jit_and_blade(model, cfg, filename, dummy_inputs, fp16=False):
+
+    def _trace_model():
+        with torch.no_grad():
+            if hasattr(model, 'forward_export'):
+                model.forward = model.forward_export
+            else:
+                model.forward = model.forward_test
+            trace_model = torch.jit.trace(
+                model,
+                copy.deepcopy(dummy_inputs),
+                strict=False,
+                check_trace=False)
+        return trace_model
+
+    export_type = cfg.export.get('type')
+    if export_type in ['jit', 'blade']:
+        if fp16:
+            with torch.cuda.amp.autocast():
+                trace_model = _trace_model()
+        else:
+            trace_model = _trace_model()
+        torch.jit.save(trace_model, filename + '.jit')
+    else:
+        raise NotImplementedError(f'Not support export type {export_type}!')
+
+    if export_type == 'jit':
+        return
+
+    blade_config = cfg.export.get('blade_config')
+
+    from easycv.toolkit.blade import blade_env_assert, blade_optimize
+    assert blade_env_assert()
+
+    def _get_blade_model():
+        blade_model = blade_optimize(
+            speed_test_model=model,
+            model=trace_model,
+            inputs=copy.deepcopy(dummy_inputs),
+            blade_config=blade_config,
+            static_opt=False,
+            min_num_nodes=None,
+            check_inputs=False,
+            fp16=fp16)
+        return blade_model
+
+    # optimize model with blade
+    if fp16:
+        with torch.cuda.amp.autocast():
+            blade_model = _get_blade_model()
+    else:
+        blade_model = _get_blade_model()
+
+    with io.open(filename + '.blade', 'wb') as ofile:
+        torch.jit.save(blade_model, ofile)
+
+
 def _export_cls(model, cfg, filename):
     """ export cls (cls & metric learning)model and preprocess config
@@ -540,7 +599,7 @@ def export_jit_model(model, cfg, filename):
         torch.jit.save(model_jit, ofile)


-def _export_bevformer(model, cfg, filename, fp16=False):
+def _export_bevformer(model, cfg, filename, fp16=False, dummy_inputs=None):
     if not cfg.adapt_jit:
         raise ValueError(
             '"cfg.adapt_jit" must be True when export jit trace or blade model.'
@@ -578,60 +637,10 @@ def _export_bevformer(model, cfg, filename, fp16=False):
         }
         return img, img_metas

-    dummy_inputs = _dummy_inputs()
-
-    def _trace_model():
-        with torch.no_grad():
-            model.forward = model.forward_export
-            trace_model = torch.jit.trace(
-                model, copy.deepcopy(dummy_inputs), check_trace=False)
-        return trace_model
-
-    export_type = cfg.export.get('type')
-    if export_type in ['jit', 'blade']:
-        if fp16:
-            with torch.cuda.amp.autocast():
-                trace_model = _trace_model()
-        else:
-            trace_model = _trace_model()
-        torch.jit.save(trace_model, filename + '.jit')
-    else:
-        raise NotImplementedError(f'Not support export type {export_type}!')
-
-    if export_type == 'jit':
-        return
-
-    blade_config = cfg.export.get('blade_config')
-
-    from easycv.toolkit.blade import blade_env_assert, blade_optimize
-    assert blade_env_assert()
-
-    def _get_blade_model():
-        blade_model = blade_optimize(
-            speed_test_model=model,
-            model=trace_model,
-            inputs=copy.deepcopy(dummy_inputs),
-            blade_config=blade_config,
-            static_opt=False,
-            min_num_nodes=None,  # 50
-            check_inputs=False,
-            fp16=fp16)
-        return blade_model
-
-    # optimize model with blade
-    if fp16:
-        with torch.cuda.amp.autocast():
-            blade_model = _get_blade_model()
-    else:
-        blade_model = _get_blade_model()
-
-    # save blade code and graph
-    # with io.open(filename + '.blade.code.py', 'w') as ofile:
-    #     ofile.write(blade_model.forward.code)
-    # with io.open(filename + '.blade.graph.txt', 'w') as ofile:
-    #     ofile.write(blade_model.forward.graph)
-    with io.open(filename + '.blade', 'wb') as ofile:
-        torch.jit.save(blade_model, ofile)
+    if dummy_inputs is None:
+        dummy_inputs = _dummy_inputs()
+
+    _export_jit_and_blade(model, cfg, filename, dummy_inputs, fp16=fp16)


 def _export_pose_topdown(model, cfg, filename, fp16=False, dummy_inputs=None):
@@ -672,53 +681,26 @@ def _export_pose_topdown(model, cfg, filename, fp16=False, dummy_inputs=None):
     if dummy_inputs is None:
         dummy_inputs = _dummy_inputs(cfg)

-    def _trace_model():
-        with torch.no_grad():
-            model.forward = model.forward_export
-            trace_model = torch.jit.trace(
-                model, copy.deepcopy(dummy_inputs), strict=False)
-        return trace_model
-
-    export_type = cfg.export.get('type')
-    if export_type in ['jit', 'blade']:
-        if fp16:
-            with torch.cuda.amp.autocast():
-                trace_model = _trace_model()
-        else:
-            trace_model = _trace_model()
-        torch.jit.save(trace_model, filename + '.jit')
-    else:
-        raise NotImplementedError(f'Not support export type {export_type}!')
-
-    if export_type == 'jit':
-        return
-
-    blade_config = cfg.export.get('blade_config')
-
-    from easycv.toolkit.blade import blade_env_assert, blade_optimize
-    assert blade_env_assert()
-
-    def _get_blade_model():
-        blade_model = blade_optimize(
-            speed_test_model=model,
-            model=trace_model,
-            inputs=copy.deepcopy(dummy_inputs),
-            blade_config=blade_config,
-            static_opt=False,
-            min_num_nodes=None,
-            check_inputs=False,
-            fp16=fp16)
-        return blade_model
-
-    # optimize model with blade
-    if fp16:
-        with torch.cuda.amp.autocast():
-            blade_model = _get_blade_model()
-    else:
-        blade_model = _get_blade_model()
-
-    with io.open(filename + '.blade', 'wb') as ofile:
-        torch.jit.save(blade_model, ofile)
+    _export_jit_and_blade(model, cfg, filename, dummy_inputs, fp16=fp16)
+
+
+def _export_stgcn(model, cfg, filename, fp16=False, dummy_inputs=None):
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    model = copy.deepcopy(model)
+    model.eval()
+    model.to(device)
+
+    if hasattr(cfg, 'export') and getattr(cfg.export, 'type', 'raw') == 'raw':
+        return _export_common(model, cfg, filename)
+
+    def _dummy_inputs(device):
+        keypoints = torch.randn([1, 3, 300, 17, 2]).to(device)
+        return (keypoints, )
+
+    if dummy_inputs is None:
+        dummy_inputs = _dummy_inputs(device)
+
+    _export_jit_and_blade(model, cfg, filename, dummy_inputs, fp16=fp16)


 def replace_syncbn(backbone_cfg):
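The dummy tensor in _export_stgcn presumably follows the usual ST-GCN input layout (N, C, T, V, M): batch 1, 3 channels (x, y, confidence), 300 frames, 17 COCO joints, and up to 2 persons. Under that assumption, a quick smoke test of the traced export looks like:

import torch

model = torch.jit.load('export_model.pth.jit')  # loads onto the export device
# (N, C, T, V, M) layout is an assumption based on the dummy input above.
dummy = torch.randn(1, 3, 300, 17, 2, device=next(model.parameters()).device)
with torch.no_grad():
    scores = model(dummy)
print(scores.shape)  # class scores, e.g. (1, 60) for NTU60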


@@ -339,16 +339,19 @@ class YoloXPredictor(DetectionPredictor):
                  nms_thresh=None,
                  test_conf=None,
                  input_processor_threads=8,
-                 mode='BGR'):
+                 mode='BGR',
+                 model_type=None):
         self.max_det = max_det
         self.use_trt_efficientnms = use_trt_efficientnms
+        self.model_type = model_type

-        if model_path.endswith('jit'):
-            self.model_type = 'jit'
-        elif model_path.endswith('blade'):
-            self.model_type = 'blade'
-        else:
-            self.model_type = 'raw'
+        if self.model_type is None:
+            if model_path.endswith('jit'):
+                self.model_type = 'jit'
+            elif model_path.endswith('blade'):
+                self.model_type = 'blade'
+            else:
+                self.model_type = 'raw'
+        assert self.model_type in ['raw', 'jit', 'blade']

         if self.model_type == 'blade' or self.use_trt_efficientnms:
             import torch_blade
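The new model_type argument (mirrored in the other predictors below) lets callers override the file-suffix heuristic, which matters when an exported TorchScript file is not named with a .jit or .blade suffix. A sketch with a hypothetical path:

# Hypothetical path: a TorchScript export without the '.jit' suffix, which
# the suffix heuristic would otherwise classify as model_type='raw'.
predictor = YoloXPredictor(
    model_path='models/yolox_s_traced.pt',
    model_type='jit')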


@@ -306,6 +306,9 @@ class PoseTopDownOutputProcessor(OutputProcessor):
         return output


+# TODO: Fix when multi people are detected in each sample,
+# all the people results will be passed to the pose model,
+# resulting in a dynamic batch_size, which is not supported by jit script model.
 @PREDICTORS.register_module()
 class PoseTopDownPredictor(PredictorV2):
     """Pose topdown predictor.
@@ -336,6 +339,7 @@ class PoseTopDownPredictor(PredictorV2):
                  save_results=False,
                  save_path=None,
                  mode='BGR',
+                 model_type=None,
                  *args,
                  **kwargs):
         assert batch_size == 1, 'Only support batch_size=1 now!'
@@ -343,15 +347,18 @@ class PoseTopDownPredictor(PredictorV2):
         self.bbox_thr = bbox_thr
         self.detection_predictor_config = detection_predictor_config

-        if model_path.endswith('jit'):
-            assert config_file is not None
-            self.model_type = 'jit'
-        elif model_path.endswith('blade'):
-            import torch_blade
-            assert config_file is not None
-            self.model_type = 'blade'
-        else:
-            self.model_type = 'raw'
+        self.model_type = model_type
+        if self.model_type is None:
+            if model_path.endswith('jit'):
+                assert config_file is not None
+                self.model_type = 'jit'
+            elif model_path.endswith('blade'):
+                import torch_blade
+                assert config_file is not None
+                self.model_type = 'blade'
+            else:
+                self.model_type = 'raw'
+        assert self.model_type in ['raw', 'jit', 'blade']

         super(PoseTopDownPredictor, self).__init__(
             model_path,


@@ -6,6 +6,7 @@ import torch

 from easycv.datasets.registry import PIPELINES
 from easycv.file import io
 from easycv.models.builder import build_model
+from easycv.utils.checkpoint import load_checkpoint
 from easycv.utils.mmlab_utils import (dynamic_adapt_for_mmlab,
                                       remove_adapt_for_mmlab)
 from easycv.utils.registry import build_from_cfg
@@ -262,8 +263,22 @@ class STGCNPredictor(PredictorV2):
                  pipelines=None,
                  input_processor_threads=8,
                  mode='RGB',
+                 model_type=None,
                  *args,
                  **kwargs):
+        self.model_type = model_type
+        if self.model_type is None:
+            if model_path.endswith('jit'):
+                assert config_file is not None
+                self.model_type = 'jit'
+            elif model_path.endswith('blade'):
+                import torch_blade
+                assert config_file is not None
+                self.model_type = 'blade'
+            else:
+                self.model_type = 'raw'
+        assert self.model_type in ['raw', 'jit', 'blade']
+
         super(STGCNPredictor, self).__init__(
             model_path,
             config_file=config_file,
@@ -301,6 +316,35 @@ class STGCNPredictor(PredictorV2):
             self.label_map = [i.strip() for i in class_list]

+    def _build_model(self):
+        if self.model_type != 'raw':
+            with io.open(self.model_path, 'rb') as infile:
+                model = torch.jit.load(infile, self.device)
+        else:
+            model = super()._build_model()
+        return model
+
+    def prepare_model(self):
+        """Build model from config file by default.
+        If the model is not loaded from a configuration file, e.g. torch jit model, you need to reimplement it.
+        """
+        model = self._build_model()
+        model.to(self.device)
+        model.eval()
+        if self.model_type == 'raw':
+            load_checkpoint(model, self.model_path, map_location='cpu')
+        return model
+
+    def model_forward(self, inputs):
+        if self.model_type == 'raw':
+            return super().model_forward(inputs)
+        else:
+            with torch.no_grad():
+                keypoint = inputs['keypoint'].to(self.device)
+                result = self.model(keypoint)
+        return result
+
     def get_input_processor(self):
         return STGCNInputProcessor(
             self.cfg,


@@ -187,6 +187,9 @@ class WholeBodyKptsOutputProcessor(OutputProcessor):
         return output


+# TODO: Fix when multi people are detected in each sample,
+# all the people results will be passed to the pose model,
+# resulting in a dynamic batch_size, which is not supported by jit script model.
 @PREDICTORS.register_module()
 class WholeBodyKeypointsPredictor(PredictorV2):
     """WholeBodyKeypointsPredictor
@@ -213,17 +216,21 @@ class WholeBodyKeypointsPredictor(PredictorV2):
                  save_path=None,
                  bbox_thr=None,
                  mode='BGR',
+                 model_type=None,
                  *args,
                  **kwargs):
-        if model_path.endswith('jit'):
-            assert config_file is not None
-            self.model_type = 'jit'
-        elif model_path.endswith('blade'):
-            import torch_blade
-            assert config_file is not None
-            self.model_type = 'blade'
-        else:
-            self.model_type = 'raw'
+        self.model_type = model_type
+        if self.model_type is None:
+            if model_path.endswith('jit'):
+                assert config_file is not None
+                self.model_type = 'jit'
+            elif model_path.endswith('blade'):
+                import torch_blade
+                assert config_file is not None
+                self.model_type = 'blade'
+            else:
+                self.model_type = 'raw'
+        assert self.model_type in ['raw', 'jit', 'blade']

         super(WholeBodyKeypointsPredictor, self).__init__(
             model_path,


@@ -8,7 +8,7 @@ import unittest

 import numpy as np
 import torch
-from tests.ut_config import (IMAGENET_LABEL_TXT,
+from tests.ut_config import (BASE_LOCAL_PATH, IMAGENET_LABEL_TXT,
                              PRETRAINED_MODEL_BEVFORMER_BASE,
                              PRETRAINED_MODEL_MOCO, PRETRAINED_MODEL_RESNET50,
                              PRETRAINED_MODEL_YOLOXS_EXPORT)
@@ -165,6 +165,51 @@ class ModelExportTest(unittest.TestCase):
         self.assertTrue(os.path.exists(filename + '.jit'))

+    def test_export_topdown_jit(self):
+        ckpt_path = os.path.join(
+            BASE_LOCAL_PATH,
+            'pretrained_models/pose/hrnet/pose_hrnet_epoch_210_export.pt')
+
+        easycv_dir = os.path.dirname(easycv.__file__)
+        if os.path.exists(os.path.join(easycv_dir, 'configs')):
+            config_dir = os.path.join(easycv_dir, 'configs')
+        else:
+            config_dir = os.path.join(os.path.dirname(easycv_dir), 'configs')
+        config_file = os.path.join(config_dir,
+                                   'pose/hrnet_w48_coco_256x192_udp.py')
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            cfg = mmcv_config_fromfile(config_file)
+            cfg.export.type = 'jit'
+            filename = os.path.join(tmpdir, 'model.pth')
+            export(cfg, ckpt_path, filename, fp16=False)
+            self.assertTrue(os.path.exists(filename + '.jit'))
+
+    def test_export_stgcn_jit(self):
+        ckpt_path = os.path.join(
+            BASE_LOCAL_PATH,
+            'pretrained_models/video/stgcn/stgcn_80e_ntu60_xsub.pth')
+
+        easycv_dir = os.path.dirname(easycv.__file__)
+        if os.path.exists(os.path.join(easycv_dir, 'configs')):
+            config_dir = os.path.join(easycv_dir, 'configs')
+        else:
+            config_dir = os.path.join(os.path.dirname(easycv_dir), 'configs')
+        config_file = os.path.join(
+            config_dir,
+            'video_recognition/stgcn/stgcn_80e_ntu60_xsub_keypoint.py')
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            cfg = mmcv_config_fromfile(config_file)
+            cfg.export.type = 'jit'
+            filename = os.path.join(tmpdir, 'model.pth')
+            export(cfg, ckpt_path, filename, fp16=False)
+            self.assertTrue(os.path.exists(filename + '.jit'))
+

 if __name__ == '__main__':
     unittest.main()


@@ -7,7 +7,7 @@ import cv2

 import numpy as np
 from numpy.testing import assert_array_almost_equal
 from PIL import Image
-from tests.ut_config import (POSE_DATA_SMALL_COCO_LOCAL,
+from tests.ut_config import (BASE_LOCAL_PATH, POSE_DATA_SMALL_COCO_LOCAL,
                              PRETRAINED_MODEL_POSE_HRNET_EXPORT,
                              PRETRAINED_MODEL_YOLOXS_EXPORT, TEST_IMAGES_DIR)
@@ -20,23 +20,12 @@ class PoseTopDownPredictorTest(unittest.TestCase):
     def setUp(self):
         print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

-    def test_pose_topdown(self):
-        detection_model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
-        pose_model_path = PRETRAINED_MODEL_POSE_HRNET_EXPORT
+    def _base_test(self, predictor):
         img1 = os.path.join(POSE_DATA_SMALL_COCO_LOCAL,
                             'images/000000067078.jpg')
         img2 = os.path.join(TEST_IMAGES_DIR, 'crowdpose_100024.jpg')
         input_data_list = [img1, img2]
-        predictor = PoseTopDownPredictor(
-            model_path=pose_model_path,
-            detection_predictor_config=dict(
-                type='YoloXPredictor',
-                model_path=detection_model_path,
-                score_thresh=0.5),
-            cat_id=0,
-            batch_size=1)

         results = predictor(input_data_list)
         self.assertEqual(len(results), 2)
@@ -117,6 +106,44 @@ class PoseTopDownPredictorTest(unittest.TestCase):
             cv2.imwrite(tmp_save_path, vis_result)
             assert os.path.exists(tmp_save_path)

+    def test_pose_topdown(self):
+        detection_model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
+        pose_model_path = PRETRAINED_MODEL_POSE_HRNET_EXPORT
+        predictor = PoseTopDownPredictor(
+            model_path=pose_model_path,
+            detection_predictor_config=dict(
+                type='YoloXPredictor',
+                model_path=detection_model_path,
+                score_thresh=0.5),
+            cat_id=0,
+            batch_size=1)
+        self._base_test(predictor)
+
+    def test_pose_topdown_jit(self):
+        detection_model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
+        pose_model_path = os.path.join(
+            BASE_LOCAL_PATH,
+            'pretrained_models/pose/hrnet/pose_hrnet_epoch_210_export.pth.jit')
+        config_file = 'configs/pose/hrnet_w48_coco_256x192_udp.py'
+        predictor = PoseTopDownPredictor(
+            model_path=pose_model_path,
+            config_file=config_file,
+            detection_predictor_config=dict(
+                type='YoloXPredictor',
+                model_path=detection_model_path,
+                score_thresh=0.5),
+            cat_id=0,
+            batch_size=1)
+        img = os.path.join(TEST_IMAGES_DIR, 'im00025.png')
+        input_data_list = [img, img]
+        results = predictor(input_data_list)
+        self.assertEqual(len(results), 2)
+

 class TorchPoseTopDownPredictorWithDetectorTest(unittest.TestCase):


@@ -84,6 +84,35 @@ class STGCNPredictorTest(unittest.TestCase):
         self.assertIn('class_name', results)
         self.assertEqual(len(results['class_probs']), 60)

+    def test_jit(self):
+        checkpoint = os.path.join(
+            BASE_LOCAL_PATH,
+            'pretrained_models/video/stgcn/stgcn_80e_ntu60_xsub.pth.jit')
+        config_file = 'configs/video_recognition/stgcn/stgcn_80e_ntu60_xsub_keypoint.py'
+        predict_op = STGCNPredictor(
+            model_path=checkpoint, config_file=config_file)
+
+        h, w = 480, 853
+        total_frames = 20
+        num_person = 2
+        inp = dict(
+            frame_dir='',
+            label=-1,
+            img_shape=(h, w),
+            original_shape=(h, w),
+            start_index=0,
+            modality='Pose',
+            total_frames=total_frames,
+            keypoint=np.random.random((num_person, total_frames, 17, 2)),
+            keypoint_score=np.random.random((num_person, total_frames, 17)),
+        )
+        results = predict_op([inp])[0]
+
+        self.assertIn('class', results)
+        self.assertIn('class_name', results)
+        self.assertEqual(len(results['class_probs']), 60)
+

 if __name__ == '__main__':
     unittest.main()
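One shape relationship worth noting: the predictor takes raw keypoints of shape (num_person, total_frames, 17, 2) plus scores of shape (num_person, total_frames, 17), while the model was traced on (1, 3, 300, 17, 2). The input pipeline presumably fuses coordinates and scores into a channel axis and pads the temporal axis; a rough numpy illustration of that reshaping (not the actual EasyCV pipeline):

import numpy as np

num_person, total_frames = 2, 20
kpt = np.random.random((num_person, total_frames, 17, 2))  # (M, T, V, 2)
score = np.random.random((num_person, total_frames, 17))   # (M, T, V)

# Fuse (x, y, score) into a channel axis: (M, T, V, 3).
fused = np.concatenate([kpt, score[..., None]], axis=-1)
# Pad the temporal axis to the 300 frames the model was traced with.
padded = np.zeros((num_person, 300, 17, 3))
padded[:, :total_frames] = fused
# Reorder to (N, C, T, V, M) = (1, 3, 300, 17, 2) to match the dummy input.
model_input = padded.transpose(3, 1, 2, 0)[None]
print(model_input.shape)  # (1, 3, 300, 17, 2)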