diff --git a/tools/keypoint.ipynb b/tools/keypoint.ipynb index 3812733..4f5e936 100644 --- a/tools/keypoint.ipynb +++ b/tools/keypoint.ipynb @@ -25,10 +25,12 @@ "outputs": [], "source": [ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", - "weigths = torch.load('yolov7-w6-pose.pt')\n", + "weigths = torch.load('yolov7-w6-pose.pt', map_location=device)\n", "model = weigths['model']\n", - "model = model.half().to(device)\n", - "_ = model.eval()" + "_ = model.float().eval()\n", + "\n", + "if torch.cuda.is_available():\n", + " model.half().to(device)" ] }, { @@ -43,9 +45,9 @@ "image_ = image.copy()\n", "image = transforms.ToTensor()(image)\n", "image = torch.tensor(np.array([image.numpy()]))\n", - "image = image.to(device)\n", - "image = image.half()\n", "\n", + "if torch.cuda.is_available():\n", + " image = image.half().to(device) \n", "output, _ = model(image)" ] }, @@ -118,7 +120,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.9.12" } }, "nbformat": 4,