Mirror of https://github.com/huggingface/pytorch-image-models.git

Lazy loader for TF, more LAB fiddling

commit d285526dc9
parent 3fbbd511e6
@@ -15,25 +15,36 @@ import torch
 import torch.distributed as dist
 from PIL import Image
 
-try:
-    import tensorflow as tf
-    tf.config.set_visible_devices([], 'GPU')  # Hands off my GPU! (or pip install tensorflow-cpu)
-    import tensorflow_datasets as tfds
-    try:
-        tfds.even_splits('', 1, drop_remainder=False)  # non-buggy even_splits has drop_remainder arg
-        has_buggy_even_splits = False
-    except TypeError:
-        print("Warning: This version of tfds doesn't have the latest even_splits impl. "
-              "Please update or use tfds-nightly for better fine-grained split behaviour.")
-        has_buggy_even_splits = True
-    # NOTE uncomment below if having file limit issues on dataset build (or alter your OS defaults)
-    # import resource
-    # low, high = resource.getrlimit(resource.RLIMIT_NOFILE)
-    # resource.setrlimit(resource.RLIMIT_NOFILE, (high, high))
-except ImportError as e:
-    print(e)
-    print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
-    raise e
+import importlib
+
+class LazyTfLoader:
+    def __init__(self):
+        self._tf = None
+
+    def __getattr__(self, name):
+        if self._tf is None:
+            self._tf = importlib.import_module('tensorflow')
+            self._tf.config.set_visible_devices([], 'GPU')  # Hands off my GPU! (or pip install tensorflow-cpu)
+        return getattr(self._tf, name)
+
+class LazyTfdsLoader:
+    def __init__(self):
+        self._tfds = None
+        self.has_buggy_even_splits = False
+
+    def __getattr__(self, name):
+        if self._tfds is None:
+            self._tfds = importlib.import_module('tensorflow_datasets')
+            try:
+                self._tfds.even_splits('', 1, drop_remainder=False)  # non-buggy even_splits has drop_remainder arg
+            except TypeError:
+                print("Warning: This version of tfds doesn't have the latest even_splits impl. "
+                      "Please update or use tfds-nightly for better fine-grained split behaviour.")
+                self.has_buggy_even_splits = True
+        return getattr(self._tfds, name)
+
+tf = LazyTfLoader()
+tfds = LazyTfdsLoader()
 
 from .class_map import load_class_map
 from .reader import Reader
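For context on the hunk above: the proxies defer the actual `tensorflow` / `tensorflow_datasets` imports until the first attribute access, so importing this reader module stays cheap even when TF isn't installed. A minimal standalone sketch of the same pattern (the generic `LazyLoader` name and the numpy stand-in are illustrative, not timm code):

import importlib


class LazyLoader:
    # Generic version of LazyTfLoader/LazyTfdsLoader: defer a heavy import
    # until an attribute of the module is actually needed.
    def __init__(self, module_name):
        self._module_name = module_name
        self._module = None

    def __getattr__(self, name):
        # __getattr__ only fires for names not found via normal lookup, so
        # instance attributes like _module never recurse through here.
        if self._module is None:
            self._module = importlib.import_module(self._module_name)
        return getattr(self._module, name)


np = LazyLoader('numpy')   # no import happens here
print(np.pi)               # first attribute access imports numpy, then delegates

Note that `has_buggy_even_splits` is set in `__init__` as a plain instance attribute, so reading `tfds.has_buggy_even_splits` (as the ReaderTfds hunk further down does) never goes through `__getattr__`; it only becomes `True` after some real `tfds` access has imported the module and the `even_splits` probe has raised `TypeError`.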
@@ -45,7 +56,6 @@ SHUFFLE_SIZE = int(os.environ.get('TFDS_SHUFFLE_SIZE', 8192)) # samples to shuf
 PREFETCH_SIZE = int(os.environ.get('TFDS_PREFETCH_SIZE', 2048))  # samples to prefetch
 
 
-@tfds.decode.make_decoder()
 def decode_example(serialized_image, feature, dct_method='INTEGER_ACCURATE', channels=3):
     return tf.image.decode_jpeg(
         serialized_image,
@@ -231,7 +241,7 @@ class ReaderTfds(Reader):
         if should_subsplit:
             # split the dataset w/o using sharding for more even samples / worker, can result in less optimal
             # read patterns for distributed training (overlap across shards) so better to use InputContext there
-            if has_buggy_even_splits:
+            if tfds.has_buggy_even_splits:
                 # my even_split workaround doesn't work on subsplits, upgrade tfds!
                 if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo):
                     subsplits = even_split_indices(self.split, self.global_num_workers, self.num_samples)
@@ -253,10 +263,11 @@ class ReaderTfds(Reader):
             shuffle_reshuffle_each_iteration=True,
             input_context=input_context,
         )
+        decode_fn = tfds.decode.make_decoder()(decode_example)
         ds = self.builder.as_dataset(
             split=self.subsplit or self.split,
             shuffle_files=self.is_training,
-            decoders=dict(image=decode_example(channels=1 if self.input_img_mode == 'L' else 3)),
+            decoders=dict(image=decode_fn(channels=1 if self.input_img_mode == 'L' else 3)),
             read_config=read_config,
         )
         # avoid overloading threading w/ combo of TF ds threads + PyTorch workers
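The module-level `@tfds.decode.make_decoder()` decorator had to go for the lazy loading to pay off: evaluating it at import time would touch the `tfds` proxy and pull in `tensorflow_datasets` immediately. Applying it at dataset-build time, the `decode_fn = tfds.decode.make_decoder()(decode_example)` line above, keeps the import deferred. A toy illustration of the difference, with a hypothetical `make_decoder` standing in for the real tfds factory:

def make_decoder():
    # Stand-in for tfds.decode.make_decoder; pretend calling it requires
    # importing a heavy module.
    print('heavy module imported')

    def wrap(fn):
        return fn
    return wrap


def decode_example(serialized_image, channels=3):
    # Placeholder decode body; undecorated, so defining it costs nothing.
    return serialized_image


# At module scope, `@make_decoder()` above decode_example would print (i.e.
# import) at import time. Applying it inside the build path defers the cost:
def build_dataset():
    decode_fn = make_decoder()(decode_example)
    return decode_fn


build_dataset()  # only now does 'heavy module imported' appear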
@@ -127,14 +127,16 @@ def rgb_to_lab_tensor(
         rgb_img: torch.Tensor,
         normalized: bool = True,
         srgb_input: bool = True,
-) -> torch.Tensor:
+        split_channels: bool = False,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
     """
     Convert RGB image to LAB color space using tensor operations.
 
     Args:
         rgb_img: Tensor of shape (..., 3) with values in range [0, 255]
         normalized: If True, outputs L,a,b in [0, 1] range instead of native LAB ranges
 
         srgb_input: Input is gamma corrected sRGB, otherwise linear RGB is assumed (rare unless part of a pipeline)
+        split_channels: If True, outputs a tuple of flattened colour channels instead of stacked image
     Returns:
         lab_img: Tensor of same shape with either:
             - normalized=False: L in [0, 100] and a,b in [-128, 127]
@@ -152,13 +154,14 @@ def rgb_to_lab_tensor(
         rgb_img = srgb_to_linear(rgb_img)
 
     # FIXME transforms before this are causing -ve values, can have a large impact on this conversion
-    rgb_img.clamp_(0, 1.0)
+    rgb_img = rgb_img.clamp(0, 1.0)
 
     # Convert to XYZ using matrix multiplication
     rgb_to_xyz = torch.tensor([
-        [0.412453, 0.357580, 0.180423],
-        [0.212671, 0.715160, 0.072169],
-        [0.019334, 0.119193, 0.950227]
+        # X         Y         Z
+        [0.412453, 0.212671, 0.019334],  # R
+        [0.357580, 0.715160, 0.119193],  # G
+        [0.180423, 0.072169, 0.950227],  # B
     ], device=rgb_img.device)
 
     # Reshape input for matrix multiplication if needed
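Two things to note in the hunk above: dropping the in-place `clamp_` means the caller's tensor is no longer mutated as a side effect, and the new constant is simply the transpose of the old matrix with its axes labelled, so `rgb_img @ rgb_to_xyz` computes exactly what `torch.matmul(rgb_img, rgb_to_xyz.T)` did, minus the runtime transpose. A standalone check of that equivalence (variable names are illustrative):

import torch

# m_old: the previous constant (XYZ-from-RGB rows), used as matmul(rgb, m_old.T)
m_old = torch.tensor([
    [0.412453, 0.357580, 0.180423],
    [0.212671, 0.715160, 0.072169],
    [0.019334, 0.119193, 0.950227],
])
# m_new: the new pre-transposed layout (rows R,G,B / columns X,Y,Z)
m_new = m_old.T.contiguous()

rgb = torch.rand(16, 3)
assert torch.allclose(torch.matmul(rgb, m_old.T), rgb @ m_new)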
@@ -167,38 +170,30 @@ def rgb_to_lab_tensor(
         rgb_img = rgb_img.reshape(-1, 3)
 
     # Perform matrix multiplication
-    xyz = torch.matmul(rgb_img, rgb_to_xyz.T)
+    xyz = rgb_img @ rgb_to_xyz
 
     # Adjust XYZ values
-    xyz[..., 0].div_(xn)
-    xyz[..., 1].div_(yn)
-    xyz[..., 2].div_(zn)
+    xyz.div_(torch.tensor([xn, yn, zn], device=xyz.device))
 
     # Step 4: XYZ to LAB
-    lab = torch.where(
+    fxfyfz = torch.where(
         xyz > epsilon,
         torch.pow(xyz, 1 / 3),
         (kappa * xyz + 16) / 116
     )
-
+    L = 116 * fxfyfz[..., 1] - 16
+    a = 500 * (fxfyfz[..., 0] - fxfyfz[..., 1])
+    b = 200 * (fxfyfz[..., 1] - fxfyfz[..., 2])
     if normalized:
-        # Calculate normalized [0,1] L,a,b values directly
-        # L: map [0,100] to [0,1] : (116y - 16)/100 = 1.16y - 0.16
-        # a: map [-128,127] to [0,1] : (500(x-y) + 128)/255 ≈ 1.96(x-y) + 0.502
-        # b: map [-128,127] to [0,1] : (200(y-z) + 128)/255 ≈ 0.784(y-z) + 0.502
-        shift_128 = 128 / 255
-        a_scale = 500 / 255
-        b_scale = 200 / 255
-        L = 1.16 * lab[..., 1] - 0.16
-        a = a_scale * (lab[..., 0] - lab[..., 1]) + shift_128
-        b = b_scale * (lab[..., 1] - lab[..., 2]) + shift_128
-    else:
-        # Calculate native range L,a,b values
-        L = 116 * lab[..., 1] - 16
-        a = 500 * (lab[..., 0] - lab[..., 1])
-        b = 200 * (lab[..., 1] - lab[..., 2])
+        # output in range [0, 1] for each channel
+        L.div_(100)
+        a.add_(128).div_(255)
+        b.add_(128).div_(255)
+
+    if split_channels:
+        return L, a, b
 
     # Stack the results
     lab = torch.stack([L, a, b], dim=-1)
 
     # Restore original shape if needed
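The refactor above dedupes the L,a,b math: native-range values are computed once from `fxfyfz`, and the normalized path rescales them in place (L from [0, 100] to [0, 1], a and b from [-128, 127] to [0, 1]) instead of using the old fused constants; `split_channels=True` then returns the three channels before the final stack. The two formulations agree, which a quick standalone check confirms (random stand-ins for the f(X/Xn), f(Y/Yn), f(Z/Zn) terms):

import torch

fx, fy, fz = torch.rand(3, 1000, dtype=torch.float64)

# old path: normalized L,a,b via fused constants
L_old = 1.16 * fy - 0.16
a_old = (500 / 255) * (fx - fy) + 128 / 255
b_old = (200 / 255) * (fy - fz) + 128 / 255

# new path: native ranges first, then rescale to [0, 1]
L_new = (116 * fy - 16) / 100
a_new = (500 * (fx - fy) + 128) / 255
b_new = (200 * (fy - fz) + 128) / 255

assert torch.allclose(L_old, L_new)
assert torch.allclose(a_old, a_new)
assert torch.allclose(b_old, b_new)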
@@ -86,7 +86,7 @@ def transforms_imagenet_train(
         use_prefetcher: bool = False,
         normalize: bool = True,
         separate: bool = False,
-        use_tensor: Optional[bool] = True,  # FIXME forced True for testing
+        use_tensor: Optional[bool] = False,
 ):
     """ ImageNet-oriented image transforms for training.
 
@@ -273,7 +273,7 @@ def transforms_imagenet_eval(
         std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
         use_prefetcher: bool = False,
         normalize: bool = True,
-        use_tensor: bool = True,
+        use_tensor: bool = False,
 ):
     """ ImageNet-oriented image transform for evaluation and inference.
 