mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add device arg for patch embed resize, fix #2024
This commit is contained in:
parent
cd8d9d9ff3
commit
df7ae11eb2
@ -196,7 +196,7 @@ def resample_patch_embed(
|
||||
return np.stack(mat).T
|
||||
|
||||
resize_mat = get_resize_mat(old_size, new_size)
|
||||
resize_mat_pinv = torch.Tensor(np.linalg.pinv(resize_mat.T))
|
||||
resize_mat_pinv = torch.tensor(np.linalg.pinv(resize_mat.T), device=patch_embed.device)
|
||||
|
||||
def resample_kernel(kernel):
|
||||
resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user