Merge pull request #571 from normster/augmix-fix
Enable uniform augmentation magnitude sampling and set AugMix defaultpull/581/head
commit
9a1bd358c7
|
@ -332,14 +332,18 @@ class AugmentOp:
|
|||
# in the usually fixed policy and sample magnitude from a normal distribution
|
||||
# with mean `magnitude` and std-dev of `magnitude_std`.
|
||||
# NOTE This is my own hack, being tested, not in papers or reference impls.
|
||||
# If magnitude_std is inf, we sample magnitude from a uniform distribution
|
||||
self.magnitude_std = self.hparams.get('magnitude_std', 0)
|
||||
|
||||
def __call__(self, img):
|
||||
if self.prob < 1.0 and random.random() > self.prob:
|
||||
return img
|
||||
magnitude = self.magnitude
|
||||
if self.magnitude_std and self.magnitude_std > 0:
|
||||
magnitude = random.gauss(magnitude, self.magnitude_std)
|
||||
if self.magnitude_std:
|
||||
if self.magnitude_std == float('inf'):
|
||||
magnitude = random.uniform(0, magnitude)
|
||||
elif self.magnitude_std > 0:
|
||||
magnitude = random.gauss(magnitude, self.magnitude_std)
|
||||
magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
|
||||
level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
|
||||
return self.aug_fn(img, *level_args, **self.kwargs)
|
||||
|
@ -790,6 +794,7 @@ def augment_and_mix_transform(config_str, hparams):
|
|||
depth = -1
|
||||
alpha = 1.
|
||||
blended = False
|
||||
hparams['magnitude_std'] = float('inf')
|
||||
config = config_str.split('-')
|
||||
assert config[0] == 'augmix'
|
||||
config = config[1:]
|
||||
|
|
Loading…
Reference in New Issue