support export blade model for Stgcn (#299)

* support blade for stgcn and add unittest
This commit is contained in:
Cathy0908 2023-03-06 10:19:39 +08:00 committed by GitHub
parent c062b01b3f
commit 5c33d9e2f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 294 additions and 137 deletions

View File

@ -197,7 +197,8 @@ eval_pipelines = [
evaluators=[dict(type='CoCoPoseTopDownEvaluator', **evaluator_args)])
]
checkpoint_sync_export = True
export = dict(use_jit=False)
export = dict(type='raw')
# export = dict(type='jit')
# export = dict(
# type='blade',
# blade_config=dict(

View File

@ -118,3 +118,15 @@ eval_pipelines = [
log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])
checkpoint_config = dict(interval=1)
export = dict(type='raw')
# export = dict(type='jit')
# export = dict(
# type='blade',
# blade_config=dict(
# enable_fp16=True,
# fp16_fallback_op_ratio=0.0,
# customize_op_black_list=[
# 'aten::select', 'aten::index', 'aten::slice', 'aten::view',
# 'aten::upsample', 'aten::clamp', 'aten::clone'
# ]))

View File

@ -14,7 +14,7 @@ from mmcv.utils import Config
from easycv.file import io
from easycv.framework.errors import NotImplementedError, ValueError
from easycv.models import (DINO, MOCO, SWAV, YOLOX, BEVFormer, Classification,
MoBY, TopDown, build_model)
MoBY, SkeletonGCN, TopDown, build_model)
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.misc import encode_str_to_tensor
@ -68,6 +68,8 @@ def export(cfg, ckpt_path, filename, model=None, **kwargs):
_export_bevformer(model, cfg, filename, **kwargs)
elif isinstance(model, TopDown):
_export_pose_topdown(model, cfg, filename, **kwargs)
elif isinstance(model, SkeletonGCN):
_export_stgcn(model, cfg, filename, **kwargs)
elif hasattr(cfg, 'export') and getattr(cfg.export, 'use_jit', False):
export_jit_model(model, cfg, filename, **kwargs)
return
@ -98,6 +100,63 @@ def _export_common(model, cfg, filename):
torch.save(checkpoint, ofile)
def _export_jit_and_blade(model, cfg, filename, dummy_inputs, fp16=False):
def _trace_model():
with torch.no_grad():
if hasattr(model, 'forward_export'):
model.forward = model.forward_export
else:
model.forward = model.forward_test
trace_model = torch.jit.trace(
model,
copy.deepcopy(dummy_inputs),
strict=False,
check_trace=False)
return trace_model
export_type = cfg.export.get('type')
if export_type in ['jit', 'blade']:
if fp16:
with torch.cuda.amp.autocast():
trace_model = _trace_model()
else:
trace_model = _trace_model()
torch.jit.save(trace_model, filename + '.jit')
else:
raise NotImplementedError(f'Not support export type {export_type}!')
if export_type == 'jit':
return
blade_config = cfg.export.get('blade_config')
from easycv.toolkit.blade import blade_env_assert, blade_optimize
assert blade_env_assert()
def _get_blade_model():
blade_model = blade_optimize(
speed_test_model=model,
model=trace_model,
inputs=copy.deepcopy(dummy_inputs),
blade_config=blade_config,
static_opt=False,
min_num_nodes=None,
check_inputs=False,
fp16=fp16)
return blade_model
# optimize model with blade
if fp16:
with torch.cuda.amp.autocast():
blade_model = _get_blade_model()
else:
blade_model = _get_blade_model()
with io.open(filename + '.blade', 'wb') as ofile:
torch.jit.save(blade_model, ofile)
def _export_cls(model, cfg, filename):
""" export cls (cls & metric learning)model and preprocess config
@ -540,7 +599,7 @@ def export_jit_model(model, cfg, filename):
torch.jit.save(model_jit, ofile)
def _export_bevformer(model, cfg, filename, fp16=False):
def _export_bevformer(model, cfg, filename, fp16=False, dummy_inputs=None):
if not cfg.adapt_jit:
raise ValueError(
'"cfg.adapt_jit" must be True when export jit trace or blade model.'
@ -578,60 +637,10 @@ def _export_bevformer(model, cfg, filename, fp16=False):
}
return img, img_metas
dummy_inputs = _dummy_inputs()
if dummy_inputs is None:
dummy_inputs = _dummy_inputs()
def _trace_model():
with torch.no_grad():
model.forward = model.forward_export
trace_model = torch.jit.trace(
model, copy.deepcopy(dummy_inputs), check_trace=False)
return trace_model
export_type = cfg.export.get('type')
if export_type in ['jit', 'blade']:
if fp16:
with torch.cuda.amp.autocast():
trace_model = _trace_model()
else:
trace_model = _trace_model()
torch.jit.save(trace_model, filename + '.jit')
else:
raise NotImplementedError(f'Not support export type {export_type}!')
if export_type == 'jit':
return
blade_config = cfg.export.get('blade_config')
from easycv.toolkit.blade import blade_env_assert, blade_optimize
assert blade_env_assert()
def _get_blade_model():
blade_model = blade_optimize(
speed_test_model=model,
model=trace_model,
inputs=copy.deepcopy(dummy_inputs),
blade_config=blade_config,
static_opt=False,
min_num_nodes=None, # 50
check_inputs=False,
fp16=fp16)
return blade_model
# optimize model with blade
if fp16:
with torch.cuda.amp.autocast():
blade_model = _get_blade_model()
else:
blade_model = _get_blade_model()
# save blade code and graph
# with io.open(filename + '.blade.code.py', 'w') as ofile:
# ofile.write(blade_model.forward.code)
# with io.open(filename + '.blade.graph.txt', 'w') as ofile:
# ofile.write(blade_model.forward.graph)
with io.open(filename + '.blade', 'wb') as ofile:
torch.jit.save(blade_model, ofile)
_export_jit_and_blade(model, cfg, filename, dummy_inputs, fp16=fp16)
def _export_pose_topdown(model, cfg, filename, fp16=False, dummy_inputs=None):
@ -672,53 +681,26 @@ def _export_pose_topdown(model, cfg, filename, fp16=False, dummy_inputs=None):
if dummy_inputs is None:
dummy_inputs = _dummy_inputs(cfg)
def _trace_model():
with torch.no_grad():
model.forward = model.forward_export
trace_model = torch.jit.trace(
model, copy.deepcopy(dummy_inputs), strict=False)
return trace_model
_export_jit_and_blade(model, cfg, filename, dummy_inputs, fp16=fp16)
export_type = cfg.export.get('type')
if export_type in ['jit', 'blade']:
if fp16:
with torch.cuda.amp.autocast():
trace_model = _trace_model()
else:
trace_model = _trace_model()
torch.jit.save(trace_model, filename + '.jit')
else:
raise NotImplementedError(f'Not support export type {export_type}!')
if export_type == 'jit':
return
def _export_stgcn(model, cfg, filename, fp16=False, dummy_inputs=None):
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = copy.deepcopy(model)
model.eval()
model.to(device)
blade_config = cfg.export.get('blade_config')
if hasattr(cfg, 'export') and getattr(cfg.export, 'type', 'raw') == 'raw':
return _export_common(model, cfg, filename)
from easycv.toolkit.blade import blade_env_assert, blade_optimize
assert blade_env_assert()
def _dummy_inputs(device):
keypoints = torch.randn([1, 3, 300, 17, 2]).to(device)
return (keypoints, )
def _get_blade_model():
blade_model = blade_optimize(
speed_test_model=model,
model=trace_model,
inputs=copy.deepcopy(dummy_inputs),
blade_config=blade_config,
static_opt=False,
min_num_nodes=None,
check_inputs=False,
fp16=fp16)
return blade_model
if dummy_inputs is None:
dummy_inputs = _dummy_inputs(device)
# optimize model with blade
if fp16:
with torch.cuda.amp.autocast():
blade_model = _get_blade_model()
else:
blade_model = _get_blade_model()
with io.open(filename + '.blade', 'wb') as ofile:
torch.jit.save(blade_model, ofile)
_export_jit_and_blade(model, cfg, filename, dummy_inputs, fp16=fp16)
def replace_syncbn(backbone_cfg):

View File

@ -339,16 +339,19 @@ class YoloXPredictor(DetectionPredictor):
nms_thresh=None,
test_conf=None,
input_processor_threads=8,
mode='BGR'):
mode='BGR',
model_type=None):
self.max_det = max_det
self.use_trt_efficientnms = use_trt_efficientnms
if model_path.endswith('jit'):
self.model_type = 'jit'
elif model_path.endswith('blade'):
self.model_type = 'blade'
else:
self.model_type = 'raw'
self.model_type = model_type
if self.model_type is None:
if model_path.endswith('jit'):
self.model_type = 'jit'
elif model_path.endswith('blade'):
self.model_type = 'blade'
else:
self.model_type = 'raw'
assert self.model_type in ['raw', 'jit', 'blade']
if self.model_type == 'blade' or self.use_trt_efficientnms:
import torch_blade

View File

@ -306,6 +306,9 @@ class PoseTopDownOutputProcessor(OutputProcessor):
return output
# TODO: Fix when multi people are detected in each sample,
# all the people results will be passed to the pose model,
# resulting in a dynamic batch_size, which is not supported by jit script model.
@PREDICTORS.register_module()
class PoseTopDownPredictor(PredictorV2):
"""Pose topdown predictor.
@ -336,6 +339,7 @@ class PoseTopDownPredictor(PredictorV2):
save_results=False,
save_path=None,
mode='BGR',
model_type=None,
*args,
**kwargs):
assert batch_size == 1, 'Only support batch_size=1 now!'
@ -343,15 +347,18 @@ class PoseTopDownPredictor(PredictorV2):
self.bbox_thr = bbox_thr
self.detection_predictor_config = detection_predictor_config
if model_path.endswith('jit'):
assert config_file is not None
self.model_type = 'jit'
elif model_path.endswith('blade'):
import torch_blade
assert config_file is not None
self.model_type = 'blade'
else:
self.model_type = 'raw'
self.model_type = model_type
if self.model_type is None:
if model_path.endswith('jit'):
assert config_file is not None
self.model_type = 'jit'
elif model_path.endswith('blade'):
import torch_blade
assert config_file is not None
self.model_type = 'blade'
else:
self.model_type = 'raw'
assert self.model_type in ['raw', 'jit', 'blade']
super(PoseTopDownPredictor, self).__init__(
model_path,

View File

@ -6,6 +6,7 @@ import torch
from easycv.datasets.registry import PIPELINES
from easycv.file import io
from easycv.models.builder import build_model
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.mmlab_utils import (dynamic_adapt_for_mmlab,
remove_adapt_for_mmlab)
from easycv.utils.registry import build_from_cfg
@ -262,8 +263,22 @@ class STGCNPredictor(PredictorV2):
pipelines=None,
input_processor_threads=8,
mode='RGB',
model_type=None,
*args,
**kwargs):
self.model_type = model_type
if self.model_type is None:
if model_path.endswith('jit'):
assert config_file is not None
self.model_type = 'jit'
elif model_path.endswith('blade'):
import torch_blade
assert config_file is not None
self.model_type = 'blade'
else:
self.model_type = 'raw'
assert self.model_type in ['raw', 'jit', 'blade']
super(STGCNPredictor, self).__init__(
model_path,
config_file=config_file,
@ -301,6 +316,35 @@ class STGCNPredictor(PredictorV2):
self.label_map = [i.strip() for i in class_list]
def _build_model(self):
if self.model_type != 'raw':
with io.open(self.model_path, 'rb') as infile:
model = torch.jit.load(infile, self.device)
else:
model = super()._build_model()
return model
def prepare_model(self):
"""Build model from config file by default.
If the model is not loaded from a configuration file, e.g. torch jit model, you need to reimplement it.
"""
model = self._build_model()
model.to(self.device)
model.eval()
if self.model_type == 'raw':
load_checkpoint(model, self.model_path, map_location='cpu')
return model
def model_forward(self, inputs):
if self.model_type == 'raw':
return super().model_forward(inputs)
else:
with torch.no_grad():
keypoint = inputs['keypoint'].to(self.device)
result = self.model(keypoint)
return result
def get_input_processor(self):
return STGCNInputProcessor(
self.cfg,

View File

@ -187,6 +187,9 @@ class WholeBodyKptsOutputProcessor(OutputProcessor):
return output
# TODO: Fix when multi people are detected in each sample,
# all the people results will be passed to the pose model,
# resulting in a dynamic batch_size, which is not supported by jit script model.
@PREDICTORS.register_module()
class WholeBodyKeypointsPredictor(PredictorV2):
"""WholeBodyKeypointsPredictor
@ -213,17 +216,21 @@ class WholeBodyKeypointsPredictor(PredictorV2):
save_path=None,
bbox_thr=None,
mode='BGR',
model_type=None,
*args,
**kwargs):
if model_path.endswith('jit'):
assert config_file is not None
self.model_type = 'jit'
elif model_path.endswith('blade'):
import torch_blade
assert config_file is not None
self.model_type = 'blade'
else:
self.model_type = 'raw'
self.model_type = model_type
if self.model_type is None:
if model_path.endswith('jit'):
assert config_file is not None
self.model_type = 'jit'
elif model_path.endswith('blade'):
import torch_blade
assert config_file is not None
self.model_type = 'blade'
else:
self.model_type = 'raw'
assert self.model_type in ['raw', 'jit', 'blade']
super(WholeBodyKeypointsPredictor, self).__init__(
model_path,

View File

@ -8,7 +8,7 @@ import unittest
import numpy as np
import torch
from tests.ut_config import (IMAGENET_LABEL_TXT,
from tests.ut_config import (BASE_LOCAL_PATH, IMAGENET_LABEL_TXT,
PRETRAINED_MODEL_BEVFORMER_BASE,
PRETRAINED_MODEL_MOCO, PRETRAINED_MODEL_RESNET50,
PRETRAINED_MODEL_YOLOXS_EXPORT)
@ -165,6 +165,51 @@ class ModelExportTest(unittest.TestCase):
self.assertTrue(os.path.exists(filename + '.jit'))
def test_export_topdown_jit(self):
ckpt_path = os.path.join(
BASE_LOCAL_PATH,
'pretrained_models/pose/hrnet/pose_hrnet_epoch_210_export.pt')
easycv_dir = os.path.dirname(easycv.__file__)
if os.path.exists(os.path.join(easycv_dir, 'configs')):
config_dir = os.path.join(easycv_dir, 'configs')
else:
config_dir = os.path.join(os.path.dirname(easycv_dir), 'configs')
config_file = os.path.join(config_dir,
'pose/hrnet_w48_coco_256x192_udp.py')
with tempfile.TemporaryDirectory() as tmpdir:
cfg = mmcv_config_fromfile(config_file)
cfg.export.type = 'jit'
filename = os.path.join(tmpdir, 'model.pth')
export(cfg, ckpt_path, filename, fp16=False)
self.assertTrue(os.path.exists(filename + '.jit'))
def test_export_stgcn_jit(self):
ckpt_path = os.path.join(
BASE_LOCAL_PATH,
'pretrained_models/video/stgcn/stgcn_80e_ntu60_xsub.pth')
easycv_dir = os.path.dirname(easycv.__file__)
if os.path.exists(os.path.join(easycv_dir, 'configs')):
config_dir = os.path.join(easycv_dir, 'configs')
else:
config_dir = os.path.join(os.path.dirname(easycv_dir), 'configs')
config_file = os.path.join(
config_dir,
'video_recognition/stgcn/stgcn_80e_ntu60_xsub_keypoint.py')
with tempfile.TemporaryDirectory() as tmpdir:
cfg = mmcv_config_fromfile(config_file)
cfg.export.type = 'jit'
filename = os.path.join(tmpdir, 'model.pth')
export(cfg, ckpt_path, filename, fp16=False)
self.assertTrue(os.path.exists(filename + '.jit'))
if __name__ == '__main__':
unittest.main()

View File

@ -7,7 +7,7 @@ import cv2
import numpy as np
from numpy.testing import assert_array_almost_equal
from PIL import Image
from tests.ut_config import (POSE_DATA_SMALL_COCO_LOCAL,
from tests.ut_config import (BASE_LOCAL_PATH, POSE_DATA_SMALL_COCO_LOCAL,
PRETRAINED_MODEL_POSE_HRNET_EXPORT,
PRETRAINED_MODEL_YOLOXS_EXPORT, TEST_IMAGES_DIR)
@ -20,23 +20,12 @@ class PoseTopDownPredictorTest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
def test_pose_topdown(self):
detection_model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
pose_model_path = PRETRAINED_MODEL_POSE_HRNET_EXPORT
def _base_test(self, predictor):
img1 = os.path.join(POSE_DATA_SMALL_COCO_LOCAL,
'images/000000067078.jpg')
img2 = os.path.join(TEST_IMAGES_DIR, 'crowdpose_100024.jpg')
input_data_list = [img1, img2]
predictor = PoseTopDownPredictor(
model_path=pose_model_path,
detection_predictor_config=dict(
type='YoloXPredictor',
model_path=detection_model_path,
score_thresh=0.5),
cat_id=0,
batch_size=1)
results = predictor(input_data_list)
self.assertEqual(len(results), 2)
@ -117,6 +106,44 @@ class PoseTopDownPredictorTest(unittest.TestCase):
cv2.imwrite(tmp_save_path, vis_result)
assert os.path.exists(tmp_save_path)
def test_pose_topdown(self):
detection_model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
pose_model_path = PRETRAINED_MODEL_POSE_HRNET_EXPORT
predictor = PoseTopDownPredictor(
model_path=pose_model_path,
detection_predictor_config=dict(
type='YoloXPredictor',
model_path=detection_model_path,
score_thresh=0.5),
cat_id=0,
batch_size=1)
self._base_test(predictor)
def test_pose_topdown_jit(self):
detection_model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
pose_model_path = os.path.join(
BASE_LOCAL_PATH,
'pretrained_models/pose/hrnet/pose_hrnet_epoch_210_export.pth.jit')
config_file = 'configs/pose/hrnet_w48_coco_256x192_udp.py'
predictor = PoseTopDownPredictor(
model_path=pose_model_path,
config_file=config_file,
detection_predictor_config=dict(
type='YoloXPredictor',
model_path=detection_model_path,
score_thresh=0.5),
cat_id=0,
batch_size=1)
img = os.path.join(TEST_IMAGES_DIR, 'im00025.png')
input_data_list = [img, img]
results = predictor(input_data_list)
self.assertEqual(len(results), 2)
class TorchPoseTopDownPredictorWithDetectorTest(unittest.TestCase):

View File

@ -84,6 +84,35 @@ class STGCNPredictorTest(unittest.TestCase):
self.assertIn('class_name', results)
self.assertEqual(len(results['class_probs']), 60)
def test_jit(self):
checkpoint = os.path.join(
BASE_LOCAL_PATH,
'pretrained_models/video/stgcn/stgcn_80e_ntu60_xsub.pth.jit')
config_file = 'configs/video_recognition/stgcn/stgcn_80e_ntu60_xsub_keypoint.py'
predict_op = STGCNPredictor(
model_path=checkpoint, config_file=config_file)
h, w = 480, 853
total_frames = 20
num_person = 2
inp = dict(
frame_dir='',
label=-1,
img_shape=(h, w),
original_shape=(h, w),
start_index=0,
modality='Pose',
total_frames=total_frames,
keypoint=np.random.random((num_person, total_frames, 17, 2)),
keypoint_score=np.random.random((num_person, total_frames, 17)),
)
results = predict_op([inp])[0]
self.assertIn('class', results)
self.assertIn('class_name', results)
self.assertEqual(len(results['class_probs']), 60)
if __name__ == '__main__':
unittest.main()