Better vmap compat across recent torch versions
parent
130458988a
commit
7c846d9970
|
@ -84,6 +84,14 @@ def resample_patch_embed(
|
|||
Resized patch embedding kernel.
|
||||
"""
|
||||
import numpy as np
|
||||
try:
|
||||
import functorch
|
||||
vmap = functorch.vmap
|
||||
except ImportError:
|
||||
if hasattr(torch, 'vmap'):
|
||||
vmap = torch.vmap
|
||||
else:
|
||||
assert False, "functorch or a version of torch with vmap is required for FlexiViT resizing."
|
||||
|
||||
assert len(patch_embed.shape) == 4, "Four dimensions expected"
|
||||
assert len(new_size) == 2, "New shape should only be hw"
|
||||
|
@ -115,7 +123,7 @@ def resample_patch_embed(
|
|||
resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)
|
||||
return resampled_kernel.reshape(new_size)
|
||||
|
||||
v_resample_kernel = torch.vmap(torch.vmap(resample_kernel, 0, 0), 1, 1)
|
||||
v_resample_kernel = vmap(vmap(resample_kernel, 0, 0), 1, 1)
|
||||
return v_resample_kernel(patch_embed)
|
||||
|
||||
|
||||
|
|
|
@ -1 +1 @@
|
|||
__version__ = '0.8.2dev0'
|
||||
__version__ = '0.8.3dev0'
|
||||
|
|
Loading…
Reference in New Issue