add profile for pred (#476)

pull/478/head
littletomatodonkey 2020-12-15 14:32:07 +08:00 committed by GitHub
parent d860430b55
commit 29b305d228
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 1 deletions

View File

@ -113,7 +113,8 @@ class VGGNet(nn.Layer):
x = self._conv_block_4(x)
x = self._conv_block_5(x)
x = paddle.reshape(x, [-1, x.shape[1]*x.shape[2]*x.shape[3]])
_, c, h, w = list(x.shape)
x = paddle.reshape(x, [-1, c * h * w])
x = self._fc1(x)
x = F.relu(x)
x = self._drop(x)

View File

@ -41,6 +41,7 @@ def parse_args():
parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--gpu_mem", type=int, default=8000)
parser.add_argument("--enable_profile", type=str2bool, default=False)
parser.add_argument("--enable_benchmark", type=str2bool, default=False)
parser.add_argument("--top_k", type=int, default=1)
parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
@ -81,6 +82,8 @@ def create_paddle_predictor(args):
config.enable_mkldnn()
config.set_cpu_math_library_num_threads(args.cpu_num_threads)
if args.enable_profile:
config.enable_profile()
config.disable_glog_info()
config.switch_ir_optim(args.ir_optim) # default true
if args.use_tensorrt: