mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add repr to auto_augment and random_erasing impl
This commit is contained in:
parent
135a48d024
commit
ed41d32637
@ -316,6 +316,7 @@ class AugmentOp:
|
|||||||
|
|
||||||
def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
|
def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
|
||||||
hparams = hparams or _HPARAMS_DEFAULT
|
hparams = hparams or _HPARAMS_DEFAULT
|
||||||
|
self.name = name
|
||||||
self.aug_fn = NAME_TO_OP[name]
|
self.aug_fn = NAME_TO_OP[name]
|
||||||
self.level_fn = LEVEL_TO_ARG[name]
|
self.level_fn = LEVEL_TO_ARG[name]
|
||||||
self.prob = prob
|
self.prob = prob
|
||||||
@ -351,6 +352,14 @@ class AugmentOp:
|
|||||||
level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
|
level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
|
||||||
return self.aug_fn(img, *level_args, **self.kwargs)
|
return self.aug_fn(img, *level_args, **self.kwargs)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
fs = self.__class__.__name__ + f'(name={self.name}, p={self.prob}'
|
||||||
|
fs += f', m={self.magnitude}, mstd={self.magnitude_std}'
|
||||||
|
if self.magnitude_max is not None:
|
||||||
|
fs += f', mmax={self.magnitude_max}'
|
||||||
|
fs += ')'
|
||||||
|
return fs
|
||||||
|
|
||||||
|
|
||||||
def auto_augment_policy_v0(hparams):
|
def auto_augment_policy_v0(hparams):
|
||||||
# ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference.
|
# ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference.
|
||||||
@ -510,6 +519,15 @@ class AutoAugment:
|
|||||||
img = op(img)
|
img = op(img)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
fs = self.__class__.__name__ + f'(policy='
|
||||||
|
for p in self.policy:
|
||||||
|
fs += '\n\t['
|
||||||
|
fs += ', '.join([str(op) for op in p])
|
||||||
|
fs += ']'
|
||||||
|
fs += ')'
|
||||||
|
return fs
|
||||||
|
|
||||||
|
|
||||||
def auto_augment_transform(config_str, hparams):
|
def auto_augment_transform(config_str, hparams):
|
||||||
"""
|
"""
|
||||||
@ -634,6 +652,13 @@ class RandAugment:
|
|||||||
img = op(img)
|
img = op(img)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
fs = self.__class__.__name__ + f'(n={self.num_layers}, ops='
|
||||||
|
for op in self.ops:
|
||||||
|
fs += f'\n\t{op}'
|
||||||
|
fs += ')'
|
||||||
|
return fs
|
||||||
|
|
||||||
|
|
||||||
def rand_augment_transform(config_str, hparams):
|
def rand_augment_transform(config_str, hparams):
|
||||||
"""
|
"""
|
||||||
@ -782,6 +807,13 @@ class AugMixAugment:
|
|||||||
mixed = self._apply_basic(img, mixing_weights, m)
|
mixed = self._apply_basic(img, mixing_weights, m)
|
||||||
return mixed
|
return mixed
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
fs = self.__class__.__name__ + f'(alpha={self.alpha}, width={self.width}, depth={self.depth}, ops='
|
||||||
|
for op in self.ops:
|
||||||
|
fs += f'\n\t{op}'
|
||||||
|
fs += ')'
|
||||||
|
return fs
|
||||||
|
|
||||||
|
|
||||||
def augment_and_mix_transform(config_str, hparams):
|
def augment_and_mix_transform(config_str, hparams):
|
||||||
""" Create AugMix PyTorch transform
|
""" Create AugMix PyTorch transform
|
||||||
|
@ -54,15 +54,15 @@ class RandomErasing:
|
|||||||
self.min_count = min_count
|
self.min_count = min_count
|
||||||
self.max_count = max_count or min_count
|
self.max_count = max_count or min_count
|
||||||
self.num_splits = num_splits
|
self.num_splits = num_splits
|
||||||
mode = mode.lower()
|
self.mode = mode.lower()
|
||||||
self.rand_color = False
|
self.rand_color = False
|
||||||
self.per_pixel = False
|
self.per_pixel = False
|
||||||
if mode == 'rand':
|
if self.mode == 'rand':
|
||||||
self.rand_color = True # per block random normal
|
self.rand_color = True # per block random normal
|
||||||
elif mode == 'pixel':
|
elif self.mode == 'pixel':
|
||||||
self.per_pixel = True # per pixel random normal
|
self.per_pixel = True # per pixel random normal
|
||||||
else:
|
else:
|
||||||
assert not mode or mode == 'const'
|
assert not self.mode or self.mode == 'const'
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
def _erase(self, img, chan, img_h, img_w, dtype):
|
def _erase(self, img, chan, img_h, img_w, dtype):
|
||||||
@ -95,3 +95,9 @@ class RandomErasing:
|
|||||||
for i in range(batch_start, batch_size):
|
for i in range(batch_start, batch_size):
|
||||||
self._erase(input[i], chan, img_h, img_w, input.dtype)
|
self._erase(input[i], chan, img_h, img_w, input.dtype)
|
||||||
return input
|
return input
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
# NOTE simplified state for repr
|
||||||
|
fs = self.__class__.__name__ + f'(p={self.probability}, mode={self.mode}'
|
||||||
|
fs += f', count=({self.min_count}, {self.max_count}))'
|
||||||
|
return fs
|
||||||
|
Loading…
x
Reference in New Issue
Block a user