Add repr to auto_augment and random_erasing impl
parent
135a48d024
commit
ed41d32637
|
@ -316,6 +316,7 @@ class AugmentOp:
|
|||
|
||||
def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
|
||||
hparams = hparams or _HPARAMS_DEFAULT
|
||||
self.name = name
|
||||
self.aug_fn = NAME_TO_OP[name]
|
||||
self.level_fn = LEVEL_TO_ARG[name]
|
||||
self.prob = prob
|
||||
|
@ -351,6 +352,14 @@ class AugmentOp:
|
|||
level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
|
||||
return self.aug_fn(img, *level_args, **self.kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
fs = self.__class__.__name__ + f'(name={self.name}, p={self.prob}'
|
||||
fs += f', m={self.magnitude}, mstd={self.magnitude_std}'
|
||||
if self.magnitude_max is not None:
|
||||
fs += f', mmax={self.magnitude_max}'
|
||||
fs += ')'
|
||||
return fs
|
||||
|
||||
|
||||
def auto_augment_policy_v0(hparams):
|
||||
# ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference.
|
||||
|
@ -510,6 +519,15 @@ class AutoAugment:
|
|||
img = op(img)
|
||||
return img
|
||||
|
||||
def __repr__(self):
|
||||
fs = self.__class__.__name__ + f'(policy='
|
||||
for p in self.policy:
|
||||
fs += '\n\t['
|
||||
fs += ', '.join([str(op) for op in p])
|
||||
fs += ']'
|
||||
fs += ')'
|
||||
return fs
|
||||
|
||||
|
||||
def auto_augment_transform(config_str, hparams):
|
||||
"""
|
||||
|
@ -634,6 +652,13 @@ class RandAugment:
|
|||
img = op(img)
|
||||
return img
|
||||
|
||||
def __repr__(self):
|
||||
fs = self.__class__.__name__ + f'(n={self.num_layers}, ops='
|
||||
for op in self.ops:
|
||||
fs += f'\n\t{op}'
|
||||
fs += ')'
|
||||
return fs
|
||||
|
||||
|
||||
def rand_augment_transform(config_str, hparams):
|
||||
"""
|
||||
|
@ -782,6 +807,13 @@ class AugMixAugment:
|
|||
mixed = self._apply_basic(img, mixing_weights, m)
|
||||
return mixed
|
||||
|
||||
def __repr__(self):
|
||||
fs = self.__class__.__name__ + f'(alpha={self.alpha}, width={self.width}, depth={self.depth}, ops='
|
||||
for op in self.ops:
|
||||
fs += f'\n\t{op}'
|
||||
fs += ')'
|
||||
return fs
|
||||
|
||||
|
||||
def augment_and_mix_transform(config_str, hparams):
|
||||
""" Create AugMix PyTorch transform
|
||||
|
|
|
@ -54,15 +54,15 @@ class RandomErasing:
|
|||
self.min_count = min_count
|
||||
self.max_count = max_count or min_count
|
||||
self.num_splits = num_splits
|
||||
mode = mode.lower()
|
||||
self.mode = mode.lower()
|
||||
self.rand_color = False
|
||||
self.per_pixel = False
|
||||
if mode == 'rand':
|
||||
if self.mode == 'rand':
|
||||
self.rand_color = True # per block random normal
|
||||
elif mode == 'pixel':
|
||||
elif self.mode == 'pixel':
|
||||
self.per_pixel = True # per pixel random normal
|
||||
else:
|
||||
assert not mode or mode == 'const'
|
||||
assert not self.mode or self.mode == 'const'
|
||||
self.device = device
|
||||
|
||||
def _erase(self, img, chan, img_h, img_w, dtype):
|
||||
|
@ -95,3 +95,9 @@ class RandomErasing:
|
|||
for i in range(batch_start, batch_size):
|
||||
self._erase(input[i], chan, img_h, img_w, input.dtype)
|
||||
return input
|
||||
|
||||
def __repr__(self):
|
||||
# NOTE simplified state for repr
|
||||
fs = self.__class__.__name__ + f'(p={self.probability}, mode={self.mode}'
|
||||
fs += f', count=({self.min_count}, {self.max_count}))'
|
||||
return fs
|
||||
|
|
Loading…
Reference in New Issue