diff --git a/timm/data/transforms.py b/timm/data/transforms.py index 822983fe..02a069bd 100644 --- a/timm/data/transforms.py +++ b/timm/data/transforms.py @@ -32,16 +32,12 @@ class ToNumpy: class ToTensor: - + """ ToTensor with no rescaling of values""" def __init__(self, dtype=torch.float32): self.dtype = dtype def __call__(self, pil_img): - np_img = np.array(pil_img, dtype=np.uint8) - if np_img.ndim < 3: - np_img = np.expand_dims(np_img, axis=-1) - np_img = np.rollaxis(np_img, 2) # HWC to CHW - return torch.from_numpy(np_img).to(dtype=self.dtype) + return F.pil_to_tensor(pil_img).to(dtype=self.dtype) # Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in