mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add device arg for patch embed resize, fix #2024
This commit is contained in:
parent
cd8d9d9ff3
commit
df7ae11eb2
@ -196,7 +196,7 @@ def resample_patch_embed(
|
|||||||
return np.stack(mat).T
|
return np.stack(mat).T
|
||||||
|
|
||||||
resize_mat = get_resize_mat(old_size, new_size)
|
resize_mat = get_resize_mat(old_size, new_size)
|
||||||
resize_mat_pinv = torch.Tensor(np.linalg.pinv(resize_mat.T))
|
resize_mat_pinv = torch.tensor(np.linalg.pinv(resize_mat.T), device=patch_embed.device)
|
||||||
|
|
||||||
def resample_kernel(kernel):
|
def resample_kernel(kernel):
|
||||||
resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)
|
resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user