fix caffe export bug

Summary: put tensor to gpu device
pull/240/head
liaoxingyu 2020-08-20 16:28:52 +08:00
parent f4305b0964
commit 07872d4cdb
1 changed files with 1 additions and 1 deletions

View File

@ -69,7 +69,7 @@ if __name__ == '__main__':
model.eval()
logger.info(model)
inputs = torch.randn(1, 3, cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1])
inputs = torch.randn(1, 3, cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1]).to(torch.device(cfg.MODEL.DEVICE))
PathManager.mkdirs(args.output)
pytorch_to_caffe.trans_net(model, inputs, args.name)
pytorch_to_caffe.save_prototxt(f"{args.output}/{args.name}.prototxt")