Lazy loader for TF, more LAB fiddling

This commit is contained in:
Ross Wightman 2024-12-23 13:24:11 -08:00
parent 3fbbd511e6
commit d285526dc9
3 changed files with 57 additions and 51 deletions

View File

@ -15,25 +15,36 @@ import torch
import torch.distributed as dist import torch.distributed as dist
from PIL import Image from PIL import Image
try: import importlib
import tensorflow as tf
tf.config.set_visible_devices([], 'GPU') # Hands off my GPU! (or pip install tensorflow-cpu) class LazyTfLoader:
import tensorflow_datasets as tfds def __init__(self):
try: self._tf = None
tfds.even_splits('', 1, drop_remainder=False) # non-buggy even_splits has drop_remainder arg
has_buggy_even_splits = False def __getattr__(self, name):
except TypeError: if self._tf is None:
print("Warning: This version of tfds doesn't have the latest even_splits impl. " self._tf = importlib.import_module('tensorflow')
"Please update or use tfds-nightly for better fine-grained split behaviour.") self._tf.config.set_visible_devices([], 'GPU') # Hands off my GPU! (or pip install tensorflow-cpu)
has_buggy_even_splits = True return getattr(self._tf, name)
# NOTE uncomment below if having file limit issues on dataset build (or alter your OS defaults)
# import resource class LazyTfdsLoader:
# low, high = resource.getrlimit(resource.RLIMIT_NOFILE) def __init__(self):
# resource.setrlimit(resource.RLIMIT_NOFILE, (high, high)) self._tfds = None
except ImportError as e: self.has_buggy_even_splits = False
print(e)
print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.") def __getattr__(self, name):
raise e if self._tfds is None:
self._tfds = importlib.import_module('tensorflow_datasets')
try:
self._tfds.even_splits('', 1, drop_remainder=False) # non-buggy even_splits has drop_remainder arg
except TypeError:
print("Warning: This version of tfds doesn't have the latest even_splits impl. "
"Please update or use tfds-nightly for better fine-grained split behaviour.")
self.has_buggy_even_splits = True
return getattr(self._tfds, name)
tf = LazyTfLoader()
tfds = LazyTfdsLoader()
from .class_map import load_class_map from .class_map import load_class_map
from .reader import Reader from .reader import Reader
@ -45,7 +56,6 @@ SHUFFLE_SIZE = int(os.environ.get('TFDS_SHUFFLE_SIZE', 8192)) # samples to shuf
PREFETCH_SIZE = int(os.environ.get('TFDS_PREFETCH_SIZE', 2048)) # samples to prefetch PREFETCH_SIZE = int(os.environ.get('TFDS_PREFETCH_SIZE', 2048)) # samples to prefetch
@tfds.decode.make_decoder()
def decode_example(serialized_image, feature, dct_method='INTEGER_ACCURATE', channels=3): def decode_example(serialized_image, feature, dct_method='INTEGER_ACCURATE', channels=3):
return tf.image.decode_jpeg( return tf.image.decode_jpeg(
serialized_image, serialized_image,
@ -231,7 +241,7 @@ class ReaderTfds(Reader):
if should_subsplit: if should_subsplit:
# split the dataset w/o using sharding for more even samples / worker, can result in less optimal # split the dataset w/o using sharding for more even samples / worker, can result in less optimal
# read patterns for distributed training (overlap across shards) so better to use InputContext there # read patterns for distributed training (overlap across shards) so better to use InputContext there
if has_buggy_even_splits: if tfds.has_buggy_even_splits:
# my even_split workaround doesn't work on subsplits, upgrade tfds! # my even_split workaround doesn't work on subsplits, upgrade tfds!
if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo): if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo):
subsplits = even_split_indices(self.split, self.global_num_workers, self.num_samples) subsplits = even_split_indices(self.split, self.global_num_workers, self.num_samples)
@ -253,10 +263,11 @@ class ReaderTfds(Reader):
shuffle_reshuffle_each_iteration=True, shuffle_reshuffle_each_iteration=True,
input_context=input_context, input_context=input_context,
) )
decode_fn = tfds.decode.make_decoder()(decode_example)
ds = self.builder.as_dataset( ds = self.builder.as_dataset(
split=self.subsplit or self.split, split=self.subsplit or self.split,
shuffle_files=self.is_training, shuffle_files=self.is_training,
decoders=dict(image=decode_example(channels=1 if self.input_img_mode == 'L' else 3)), decoders=dict(image=decode_fn(channels=1 if self.input_img_mode == 'L' else 3)),
read_config=read_config, read_config=read_config,
) )
# avoid overloading threading w/ combo of TF ds threads + PyTorch workers # avoid overloading threading w/ combo of TF ds threads + PyTorch workers

View File

@ -127,14 +127,16 @@ def rgb_to_lab_tensor(
rgb_img: torch.Tensor, rgb_img: torch.Tensor,
normalized: bool = True, normalized: bool = True,
srgb_input: bool = True, srgb_input: bool = True,
) -> torch.Tensor: split_channels: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
""" """
Convert RGB image to LAB color space using tensor operations. Convert RGB image to LAB color space using tensor operations.
Args: Args:
rgb_img: Tensor of shape (..., 3) with values in range [0, 255] rgb_img: Tensor of shape (..., 3) with values in range [0, 255]
normalized: If True, outputs L,a,b in [0, 1] range instead of native LAB ranges normalized: If True, outputs L,a,b in [0, 1] range instead of native LAB ranges
srgb_input: Input is gamma corrected sRGB, otherwise linear RGB is assumed (rare unless part of a pipeline)
split_channels: If True, outputs a tuple of flattened colour channels instead of stacked image
Returns: Returns:
lab_img: Tensor of same shape with either: lab_img: Tensor of same shape with either:
- normalized=False: L in [0, 100] and a,b in [-128, 127] - normalized=False: L in [0, 100] and a,b in [-128, 127]
@ -152,13 +154,14 @@ def rgb_to_lab_tensor(
rgb_img = srgb_to_linear(rgb_img) rgb_img = srgb_to_linear(rgb_img)
# FIXME transforms before this are causing -ve values, can have a large impact on this conversion # FIXME transforms before this are causing -ve values, can have a large impact on this conversion
rgb_img.clamp_(0, 1.0) rgb_img = rgb_img.clamp(0, 1.0)
# Convert to XYZ using matrix multiplication # Convert to XYZ using matrix multiplication
rgb_to_xyz = torch.tensor([ rgb_to_xyz = torch.tensor([
[0.412453, 0.357580, 0.180423], # X Y Z
[0.212671, 0.715160, 0.072169], [0.412453, 0.212671, 0.019334], # R
[0.019334, 0.119193, 0.950227] [0.357580, 0.715160, 0.119193], # G
[0.180423, 0.072169, 0.950227], # B
], device=rgb_img.device) ], device=rgb_img.device)
# Reshape input for matrix multiplication if needed # Reshape input for matrix multiplication if needed
@ -167,38 +170,30 @@ def rgb_to_lab_tensor(
rgb_img = rgb_img.reshape(-1, 3) rgb_img = rgb_img.reshape(-1, 3)
# Perform matrix multiplication # Perform matrix multiplication
xyz = torch.matmul(rgb_img, rgb_to_xyz.T) xyz = rgb_img @ rgb_to_xyz
# Adjust XYZ values # Adjust XYZ values
xyz[..., 0].div_(xn) xyz.div_(torch.tensor([xn, yn, zn], device=xyz.device))
xyz[..., 1].div_(yn)
xyz[..., 2].div_(zn)
# Step 4: XYZ to LAB # Step 4: XYZ to LAB
lab = torch.where( fxfyfz = torch.where(
xyz > epsilon, xyz > epsilon,
torch.pow(xyz, 1 / 3), torch.pow(xyz, 1 / 3),
(kappa * xyz + 16) / 116 (kappa * xyz + 16) / 116
) )
L = 116 * fxfyfz[..., 1] - 16
a = 500 * (fxfyfz[..., 0] - fxfyfz[..., 1])
b = 200 * (fxfyfz[..., 1] - fxfyfz[..., 2])
if normalized: if normalized:
# Calculate normalized [0,1] L,a,b values directly # output in rage [0, 1] for each channel
# L: map [0,100] to [0,1] : (116y - 16)/100 = 1.16y - 0.16 L.div_(100)
# a: map [-128,127] to [0,1] : (500(x-y) + 128)/255 ≈ 1.96(x-y) + 0.502 a.add_(128).div_(255)
# b: map [-128,127] to [0,1] : (200(y-z) + 128)/255 ≈ 0.784(y-z) + 0.502 b.add_(128).div_(255)
shift_128 = 128 / 255
a_scale = 500 / 255 if split_channels:
b_scale = 200 / 255 return L, a, b
L = 1.16 * lab[..., 1] - 0.16
a = a_scale * (lab[..., 0] - lab[..., 1]) + shift_128
b = b_scale * (lab[..., 1] - lab[..., 2]) + shift_128
else:
# Calculate native range L,a,b values
L = 116 * lab[..., 1] - 16
a = 500 * (lab[..., 0] - lab[..., 1])
b = 200 * (lab[..., 1] - lab[..., 2])
# Stack the results
lab = torch.stack([L, a, b], dim=-1) lab = torch.stack([L, a, b], dim=-1)
# Restore original shape if needed # Restore original shape if needed

View File

@ -86,7 +86,7 @@ def transforms_imagenet_train(
use_prefetcher: bool = False, use_prefetcher: bool = False,
normalize: bool = True, normalize: bool = True,
separate: bool = False, separate: bool = False,
use_tensor: Optional[bool] = True, # FIXME forced True for testing use_tensor: Optional[bool] = False,
): ):
""" ImageNet-oriented image transforms for training. """ ImageNet-oriented image transforms for training.
@ -273,7 +273,7 @@ def transforms_imagenet_eval(
std: Tuple[float, ...] = IMAGENET_DEFAULT_STD, std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
use_prefetcher: bool = False, use_prefetcher: bool = False,
normalize: bool = True, normalize: bool = True,
use_tensor: bool = True, use_tensor: bool = False,
): ):
""" ImageNet-oriented image transform for evaluation and inference. """ ImageNet-oriented image transform for evaluation and inference.