mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00

Testing some LAB stuff

This commit is contained in:
  parent 3b181b78d1
  commit 3fbbd511e6
@@ -90,7 +90,7 @@ class MaybePILToTensor:
         return f"{self.__class__.__name__}()"


-class ToLab(transforms.ToTensor):
+class ToLabPIL:

     def __init__(self) -> None:
         super().__init__()
@@ -115,6 +115,121 @@ class ToLab(transforms.ToTensor):
         return f"{self.__class__.__name__}()"


+def srgb_to_linear(srgb_image: torch.Tensor) -> torch.Tensor:
+    return torch.where(
+        srgb_image <= 0.04045,
+        srgb_image / 12.92,
+        ((srgb_image + 0.055) / 1.055) ** 2.4
+    )
+
+
+def rgb_to_lab_tensor(
+        rgb_img: torch.Tensor,
+        normalized: bool = True,
+        srgb_input: bool = True,
+) -> torch.Tensor:
+    """
+    Convert RGB image to LAB color space using tensor operations.
+
+    Args:
+        rgb_img: Tensor of shape (..., 3) with values in range [0, 1]
+        normalized: If True, outputs L,a,b in [0, 1] range instead of native LAB ranges
+
+    Returns:
+        lab_img: Tensor of same shape with either:
+            - normalized=False: L in [0, 100] and a,b in [-128, 127]
+            - normalized=True: L,a,b in [0, 1]
+    """
+    # Constants
+    epsilon = 216 / 24389
+    kappa = 24389 / 27
+    xn = 0.95047
+    yn = 1.0
+    zn = 1.08883
+
+    # Convert sRGB to linear RGB
+    if srgb_input:
+        rgb_img = srgb_to_linear(rgb_img)
+
+    # FIXME transforms before this are causing -ve values, can have a large impact on this conversion
+    rgb_img.clamp_(0, 1.0)
+
+    # Convert to XYZ using matrix multiplication
+    rgb_to_xyz = torch.tensor([
+        [0.412453, 0.357580, 0.180423],
+        [0.212671, 0.715160, 0.072169],
+        [0.019334, 0.119193, 0.950227]
+    ], device=rgb_img.device)
+
+    # Reshape input for matrix multiplication if needed
+    original_shape = rgb_img.shape
+    if len(original_shape) > 2:
+        rgb_img = rgb_img.reshape(-1, 3)
+
+    # Perform matrix multiplication
+    xyz = torch.matmul(rgb_img, rgb_to_xyz.T)
+
+    # Adjust XYZ values
+    xyz[..., 0].div_(xn)
+    xyz[..., 1].div_(yn)
+    xyz[..., 2].div_(zn)
+
+    # Step 4: XYZ to LAB
+    lab = torch.where(
+        xyz > epsilon,
+        torch.pow(xyz, 1 / 3),
+        (kappa * xyz + 16) / 116
+    )
+
+    if normalized:
+        # Calculate normalized [0,1] L,a,b values directly
+        # L: map [0,100] to [0,1] : (116y - 16)/100 = 1.16y - 0.16
+        # a: map [-128,127] to [0,1] : (500(x-y) + 128)/255 ≈ 1.96(x-y) + 0.502
+        # b: map [-128,127] to [0,1] : (200(y-z) + 128)/255 ≈ 0.784(y-z) + 0.502
+        shift_128 = 128 / 255
+        a_scale = 500 / 255
+        b_scale = 200 / 255
+        L = 1.16 * lab[..., 1] - 0.16
+        a = a_scale * (lab[..., 0] - lab[..., 1]) + shift_128
+        b = b_scale * (lab[..., 1] - lab[..., 2]) + shift_128
+    else:
+        # Calculate native range L,a,b values
+        L = 116 * lab[..., 1] - 16
+        a = 500 * (lab[..., 0] - lab[..., 1])
+        b = 200 * (lab[..., 1] - lab[..., 2])
+
+    # Stack the results
+    lab = torch.stack([L, a, b], dim=-1)
+
+    # Restore original shape if needed
+    if len(original_shape) > 2:
+        lab = lab.reshape(original_shape)
+
+    return lab
+
+
+class ToLabTensor:
+    def __init__(self, srgb_input=False, normalized=True) -> None:
+        self.srgb_input = srgb_input
+        self.normalized = normalized
+
+    def __call__(self, pic) -> torch.Tensor:
+        return rgb_to_lab_tensor(
+            pic,
+            normalized=self.normalized,
+            srgb_input=self.srgb_input,
+        )
+
+
+class ToLinearRgb:
+    def __init__(self):
+        pass
+
+    def __call__(self, pic) -> torch.Tensor:
+        assert isinstance(pic, torch.Tensor)
+        return srgb_to_linear(pic)
+
+
 # Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in
 # favor of the Image.Resampling enum. The top-level resampling attributes will be
 # removed in Pillow 10.
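A minimal sketch (not part of the commit) of exercising the new helpers above, assuming a channels-last float image in [0, 1] as the rgb_to_lab_tensor docstring's (..., 3) shape implies. Note that this layout is channels-last, while MaybeToTensor in the pipeline hunks below yields channel-first CHW tensors, which may be part of what the FIXME markers are flagging.

    import torch

    # Hypothetical input: H x W x 3 float image with values in [0, 1]
    rgb = torch.rand(32, 32, 3)

    # Direct call to the added conversion function
    lab = rgb_to_lab_tensor(rgb.clone(), normalized=True, srgb_input=True)
    print(lab.shape)                           # torch.Size([32, 32, 3])
    print(lab.min().item(), lab.max().item())  # stays within [0, 1] for in-gamut input

    # Same conversion via the transform wrapper (srgb_input is False by default,
    # since ToLinearRgb is expected to have run earlier in the pipeline)
    to_lab = ToLabTensor(srgb_input=True, normalized=True)
    assert torch.allclose(to_lab(rgb.clone()), lab)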
@@ -14,6 +14,7 @@ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEF
 from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
 from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation, \
     ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, ToNumpy, MaybeToTensor, MaybePILToTensor
+from timm.data.transforms import ToLabTensor, ToLinearRgb
 from timm.data.random_erasing import RandomErasing


@@ -123,7 +124,10 @@ def transforms_imagenet_train(
     * normalizes and converts the branches above with the third, final transform
     """
     if use_tensor:
-        primary_tfl = [MaybeToTensor()]
+        primary_tfl = [
+            MaybeToTensor(),
+            ToLinearRgb(),  # FIXME
+        ]
     else:
         primary_tfl = []

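The ToLinearRgb step inserted ahead of the augmentation stages converts sRGB-encoded values to linear light, so any averaging or interpolation done by later transforms mixes physical intensities rather than gamma-encoded codes. A small illustration (not from the commit; linear_to_srgb below is a hypothetical helper, the inverse of the srgb_to_linear added above):

    import torch

    def linear_to_srgb(x: torch.Tensor) -> torch.Tensor:
        # hypothetical inverse of srgb_to_linear, for illustration only
        return torch.where(x <= 0.0031308, x * 12.92, 1.055 * x ** (1 / 2.4) - 0.055)

    a, b = torch.tensor(0.1), torch.tensor(0.9)
    naive = (a + b) / 2                                                   # average of sRGB codes
    linear = linear_to_srgb((srgb_to_linear(a) + srgb_to_linear(b)) / 2)  # average of linear light
    print(naive.item(), linear.item())  # ~0.50 vs ~0.66: blending sRGB codes darkens the result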
@@ -236,6 +240,7 @@ def transforms_imagenet_train(
         if not use_tensor:
             final_tfl += [MaybeToTensor()]
         final_tfl += [
+            ToLabTensor(),  # FIXME
             transforms.Normalize(
                 mean=torch.tensor(mean),
                 std=torch.tensor(std),
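Taken together with the primary-transform change above, the use_tensor training pipeline now looks roughly like the sketch below (augmentation stages elided; not part of the diff). The mean/std shown are the usual ImageNet RGB statistics from timm.data.constants and are only placeholders here; they would presumably need to be re-estimated for normalized LAB inputs.

    import torch
    from torchvision import transforms

    train_tfl = transforms.Compose([
        MaybeToTensor(),    # PIL.Image -> float tensor in [0, 1]
        ToLinearRgb(),      # sRGB -> linear RGB (FIXME in the commit)
        # ... geometric / color augmentations would run here on linear RGB ...
        ToLabTensor(),      # linear RGB -> normalized LAB (srgb_input defaults to False)
        transforms.Normalize(
            mean=torch.tensor((0.485, 0.456, 0.406)),
            std=torch.tensor((0.229, 0.224, 0.225)),
        ),
    ])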
@@ -268,6 +273,7 @@ def transforms_imagenet_eval(
         std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
         use_prefetcher: bool = False,
         normalize: bool = True,
+        use_tensor: bool = True,
 ):
     """ ImageNet-oriented image transform for evaluation and inference.

@@ -294,7 +300,13 @@ def transforms_imagenet_eval(
         scale_size = math.floor(img_size / crop_pct)
         scale_size = (scale_size, scale_size)

-    tfl = []
+    if use_tensor:
+        tfl = [
+            MaybeToTensor(),
+            ToLinearRgb(),  # FIXME
+        ]
+    else:
+        tfl = []

     if crop_border_pixels:
         tfl += [TrimBorder(crop_border_pixels)]
@@ -332,10 +344,13 @@ def transforms_imagenet_eval(
         tfl += [ToNumpy()]
     elif not normalize:
         # when normalize disabled, converted to tensor without scaling, keeps original dtype
-        tfl += [MaybePILToTensor()]
+        if not use_tensor:
+            tfl += [MaybePILToTensor()]
     else:
+        if not use_tensor:
+            tfl += [MaybeToTensor()]
         tfl += [
-            MaybeToTensor(),
+            ToLabTensor(),  # FIXME
             transforms.Normalize(
                 mean=torch.tensor(mean),
                 std=torch.tensor(std),
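As a quick sanity check of the sRGB decode used throughout this commit, srgb_to_linear should reproduce the standard sRGB reference points (a sketch; expected values rounded):

    import torch

    x = torch.tensor([0.0, 0.04045, 0.5, 1.0])
    expected = torch.tensor([0.0, 0.0031, 0.2140, 1.0])
    assert torch.allclose(srgb_to_linear(x), expected, atol=1e-3)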