diff --git a/tools/deploy/caffe_export.py b/tools/deploy/caffe_export.py index db89612..e6651b7 100644 --- a/tools/deploy/caffe_export.py +++ b/tools/deploy/caffe_export.py @@ -69,7 +69,7 @@ if __name__ == '__main__': model.eval() logger.info(model) - inputs = torch.randn(1, 3, cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1]) + inputs = torch.randn(1, 3, cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1]).to(torch.device(cfg.MODEL.DEVICE)) PathManager.mkdirs(args.output) pytorch_to_caffe.trans_net(model, inputs, args.name) pytorch_to_caffe.save_prototxt(f"{args.output}/{args.name}.prototxt")