Small tweak of timm ToTensor for clarity

This commit is contained in:
Ross Wightman 2024-02-10 14:57:40 -08:00
parent 5a58f4d3dc
commit 7d121ac2ef

View File

@ -32,16 +32,12 @@ class ToNumpy:
class ToTensor:
""" ToTensor with no rescaling of values"""
def __init__(self, dtype=torch.float32):
self.dtype = dtype
def __call__(self, pil_img):
np_img = np.array(pil_img, dtype=np.uint8)
if np_img.ndim < 3:
np_img = np.expand_dims(np_img, axis=-1)
np_img = np.rollaxis(np_img, 2) # HWC to CHW
return torch.from_numpy(np_img).to(dtype=self.dtype)
return F.pil_to_tensor(pil_img).to(dtype=self.dtype)
# Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in