Add device arg for patch embed resize, fix #2024

This commit is contained in:
Ross Wightman 2023-12-04 11:42:13 -08:00
parent cd8d9d9ff3
commit df7ae11eb2

View File

@ -196,7 +196,7 @@ def resample_patch_embed(
return np.stack(mat).T
resize_mat = get_resize_mat(old_size, new_size)
resize_mat_pinv = torch.Tensor(np.linalg.pinv(resize_mat.T))
resize_mat_pinv = torch.tensor(np.linalg.pinv(resize_mat.T), device=patch_embed.device)
def resample_kernel(kernel):
resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)