New `LetterBox(size)` `CenterCrop(size)`, `ToTensor()` transforms (#9213)
* New LetterBox transform YOLOv5 LetterBox class for image preprocessing, i.e. T.Compose([T.ToTensor(), LetterBox(size)]) Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update augmentations.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update augmentations.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update augmentations.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * cleanup * cleanup * cleanup * cleanup * cleanup Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/9226/head
parent
5fb267f3e5
commit
6e7a7ae7ed
|
@ -8,6 +8,7 @@ import random
|
|||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
import torchvision.transforms.functional as TF
|
||||
|
||||
|
@ -345,4 +346,51 @@ def classify_albumentations(augment=True,
|
|||
def classify_transforms(size=224):
|
||||
# Transforms to apply if albumentations not installed
|
||||
assert isinstance(size, int), f'ERROR: classify_transforms size {size} must be integer, not (list, tuple)'
|
||||
return T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
|
||||
# T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
|
||||
return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
|
||||
|
||||
|
||||
class LetterBox:
|
||||
# YOLOv5 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
|
||||
def __init__(self, size=(640, 640), auto=False, stride=32):
|
||||
super().__init__()
|
||||
self.h, self.w = (size, size) if isinstance(size, int) else size
|
||||
self.auto = auto # pass max size integer, automatically solve for short side using stride
|
||||
self.stride = stride # used with auto
|
||||
|
||||
def __call__(self, im): # im = np.array HWC
|
||||
imh, imw = im.shape[:2]
|
||||
r = min(self.h / imh, self.w / imw) # ratio of new/old
|
||||
h, w = round(imh * r), round(imw * r) # resized image
|
||||
hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else self.h, self.w
|
||||
top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1)
|
||||
im_out = np.full((self.h, self.w, 3), 114, dtype=im.dtype)
|
||||
im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
|
||||
return im_out
|
||||
|
||||
|
||||
class CenterCrop:
|
||||
# YOLOv5 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])
|
||||
def __init__(self, size=640):
|
||||
super().__init__()
|
||||
self.h, self.w = (size, size) if isinstance(size, int) else size
|
||||
|
||||
def __call__(self, im): # im = np.array HWC
|
||||
imh, imw = im.shape[:2]
|
||||
m = min(imh, imw) # min dimension
|
||||
top, left = (imh - m) // 2, (imw - m) // 2
|
||||
return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
|
||||
class ToTensor:
|
||||
# YOLOv5 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
|
||||
def __init__(self, half=False):
|
||||
super().__init__()
|
||||
self.half = half
|
||||
|
||||
def __call__(self, im): # im = np.array HWC in BGR order
|
||||
im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1]) # HWC to CHW -> BGR to RGB -> contiguous
|
||||
im = torch.from_numpy(im) # to torch
|
||||
im = im.half() if self.half else im.float() # uint8 to fp16/32
|
||||
im /= 255.0 # 0-255 to 0.0-1.0
|
||||
return im
|
||||
|
|
|
@ -251,7 +251,7 @@ class LoadImages:
|
|||
s = f'image {self.count}/{self.nf} {path}: '
|
||||
|
||||
if self.transforms:
|
||||
im = self.transforms(cv2.cvtColor(im0, cv2.COLOR_BGR2RGB)) # transforms
|
||||
im = self.transforms(im0) # transforms
|
||||
else:
|
||||
im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0] # padded resize
|
||||
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
|
||||
|
@ -386,7 +386,7 @@ class LoadStreams:
|
|||
|
||||
im0 = self.imgs.copy()
|
||||
if self.transforms:
|
||||
im = np.stack([self.transforms(cv2.cvtColor(x, cv2.COLOR_BGR2RGB)) for x in im0]) # transforms
|
||||
im = np.stack([self.transforms(x) for x in im0]) # transforms
|
||||
else:
|
||||
im = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0] for x in im0]) # resize
|
||||
im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
|
||||
|
@ -1113,18 +1113,18 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
|||
|
||||
def __getitem__(self, i):
|
||||
f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
|
||||
if self.cache_ram and im is None:
|
||||
im = self.samples[i][3] = cv2.imread(f)
|
||||
elif self.cache_disk:
|
||||
if not fn.exists(): # load npy
|
||||
np.save(fn.as_posix(), cv2.imread(f))
|
||||
im = np.load(fn)
|
||||
else: # read image
|
||||
im = cv2.imread(f) # BGR
|
||||
if self.album_transforms:
|
||||
if self.cache_ram and im is None:
|
||||
im = self.samples[i][3] = cv2.imread(f)
|
||||
elif self.cache_disk:
|
||||
if not fn.exists(): # load npy
|
||||
np.save(fn.as_posix(), cv2.imread(f))
|
||||
im = np.load(fn)
|
||||
else: # read image
|
||||
im = cv2.imread(f) # BGR
|
||||
sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
|
||||
else:
|
||||
sample = self.torch_transforms(self.loader(f))
|
||||
sample = self.torch_transforms(im)
|
||||
return sample, j
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue