diff --git a/timm/layers/patch_embed.py b/timm/layers/patch_embed.py index ec8986d3..59708285 100644 --- a/timm/layers/patch_embed.py +++ b/timm/layers/patch_embed.py @@ -196,7 +196,7 @@ def resample_patch_embed( return np.stack(mat).T resize_mat = get_resize_mat(old_size, new_size) - resize_mat_pinv = torch.Tensor(np.linalg.pinv(resize_mat.T)) + resize_mat_pinv = torch.tensor(np.linalg.pinv(resize_mat.T), device=patch_embed.device) def resample_kernel(kernel): resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)