Update onnx.py

Fix break in computation graph due to typecast to python `int` instead of torch type.
pull/58/head
William Woof 2023-04-06 18:13:24 +01:00 committed by GitHub
parent f2557f7780
commit e9f2d58094
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -81,8 +81,8 @@ class SamOnnxModel(nn.Module):
align_corners=False,
)
prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size)
masks = masks[..., : int(prepadded_size[0]), : int(prepadded_size[1])]
prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
masks = masks[..., : prepadded_size[0], : prepadded_size[1]]
orig_im_size = orig_im_size.to(torch.int64)
h, w = orig_im_size[0], orig_im_size[1]