221 lines
7.2 KiB
Python
221 lines
7.2 KiB
Python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
# This code is based on https://github.com/ecs-vlc/FMix
|
|
# reference: https://arxiv.org/abs/2002.12047
|
|
|
|
import math
|
|
import random
|
|
|
|
import numpy as np
|
|
from scipy.stats import beta
|
|
|
|
|
|
def fftfreqnd(h, w=None, z=None):
    """ Get the frequency magnitude of every bin of a discrete fourier transform of size (h, w, z)

    The per-axis frequency vectors are combined by broadcasting into
    ``sqrt(fx**2 + fy**2 + fz**2)``.

    :param h: Required, first dimension size
    :param w: Optional, second dimension size. Only the leading ``w // 2 + 1``
        bins are kept (``w // 2 + 2`` when ``w`` is odd).
    :param z: Optional, third dimension size. All ``z`` bins are kept.
    :return: Array of frequency magnitudes.
    """
    freq_h = np.fft.fftfreq(h)
    freq_w = 0
    freq_z = 0

    if w is not None:
        # Turn the h frequencies into a column so they broadcast against w.
        freq_h = np.expand_dims(freq_h, -1)
        # Odd widths keep one more leading bin than even widths.
        keep = w // 2 + (2 if w % 2 == 1 else 1)
        freq_w = np.fft.fftfreq(w)[:keep]

    if z is not None:
        freq_h = np.expand_dims(freq_h, -1)
        # Odd and even depths are handled identically: every bin is kept.
        freq_z = np.fft.fftfreq(z)[:, None]

    return np.sqrt(freq_w * freq_w + freq_h * freq_h + freq_z * freq_z)
|
|
|
|
|
|
def get_spectrum(freqs, decay_power, ch, h, w=0, z=0):
    """ Samples a random fourier image whose amplitude decays with frequency

    Standard normal noise is drawn for every frequency bin and scaled by
    ``1 / f**decay_power``.

    :param freqs: Bin values for the discrete fourier transform
    :param decay_power: Decay power for frequency decay prop 1/f**d
    :param ch: Number of channels for the resulting mask
    :param h: Required, first dimension size
    :param w: Optional, second dimension size
    :param z: Optional, third dimension size
    :return: Array of shape ``(ch, *freqs.shape, 2)``
    """
    # Clamp the frequencies from below by 1 / (largest dimension) so that the
    # zero-frequency bin does not produce a division by zero.
    lowest = np.array([1. / max(w, h, z)])
    clamped = np.maximum(freqs, lowest)
    scale = np.ones(1) / (clamped**decay_power)

    noise = np.random.randn(ch, *freqs.shape, 2)

    # Add a leading channel axis and a trailing axis so the per-bin scale
    # broadcasts against the noise.
    scale = scale[np.newaxis, ..., np.newaxis]
    return scale * noise
|
|
|
|
|
|
def make_low_freq_image(decay, shape, ch=1):
    """ Sample a low frequency image from fourier space

    :param decay: Decay power for frequency decay prop 1/f**d
    :param shape: Shape of desired mask, list up to 3 dims
    :param ch: Number of channels sampled in fourier space
    :return: Array of shape ``(1, *shape)`` rescaled to the range [0, 1].
        Only the first sampled channel is kept.
    """
    freqs = fftfreqnd(*shape)
    spectrum = get_spectrum(freqs, decay, ch, *shape)

    # get_spectrum puts the two random components of every bin on the
    # trailing axis of size 2; combine them into one complex spectrum.
    # (Indexing axis 1 instead would pick two rows of the first spatial axis
    # and discard the rest of the sampled spectrum.)
    spectrum = spectrum[..., 0] + 1j * spectrum[..., 1]

    # Transform over the trailing spatial axes only; the channel axis is kept.
    # The axes are passed explicitly because giving `s` without `axes` is
    # deprecated in recent NumPy releases.
    spatial_axes = tuple(range(-len(shape), 0))
    mask = np.fft.irfftn(spectrum, s=tuple(shape), axes=spatial_axes)

    # Keep the first channel and crop every spatial axis to the requested size.
    crop = (slice(None, 1), ) + tuple(slice(None, n) for n in shape)
    mask = mask[crop]

    # Rescale to [0, 1]. A constant image has zero range after the shift, so
    # it is returned as all zeros instead of dividing by zero.
    mask = mask - mask.min()
    peak = mask.max()
    if peak > 0:
        mask = mask / peak
    return mask
|
|
|
|
|
|
def sample_lam(alpha, reformulate=False):
    """ Draw a lambda from a beta distribution parametrised by alpha

    :param alpha: Alpha value for beta distribution
    :param reformulate: If True, uses the reformulation of [1], i.e. draws
        from Beta(alpha + 1, alpha) instead of the symmetric Beta(alpha, alpha).
    :return: The sampled lambda
    """
    first_shape = alpha + 1 if reformulate else alpha
    return beta.rvs(first_shape, alpha)
|
|
|
|
|
|
def binarise_mask(mask, lam, in_shape, max_soft=0.0):
    """ Binarises a given low frequency image such that it has mean lambda.

    The input array is not modified; a new array is returned.

    :param mask: Low frequency image, usually the result of `make_low_freq_image`
    :param lam: Mean value of final mask
    :param in_shape: Shape of inputs; its product must equal ``mask.size``
    :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
    :return: Mask of shape ``(1, 1, *in_shape)`` holding 1 for the highest
        valued positions, 0 for the lowest, and a linear ramp in between
        when softening is active.
    """
    # flatten() always copies, so the caller's array is left untouched.
    flat = mask.flatten()
    # Positions ordered from the largest value to the smallest.
    order = flat.argsort()[::-1]

    # lam * size is generally not an integer: round up or down at random.
    target = lam * flat.size
    if random.random() > 0.5:
        num = math.ceil(target)
    else:
        num = math.floor(target)

    # The soft band cannot extend beyond either end of the ordering.
    eff_soft = max_soft
    if max_soft > lam or max_soft > (1 - lam):
        eff_soft = min(lam, 1 - lam)

    soft = int(flat.size * eff_soft)
    num_low = int(num - soft)
    num_high = int(num + soft)

    flat[order[:num_high]] = 1
    flat[order[num_low:]] = 0
    # Linear ramp from 1 down to 0 across the soft band.
    flat[order[num_low:num_high]] = np.linspace(1, 0, num_high - num_low)

    # Two leading singleton axes followed by the spatial dims (1-D, 2-D or 3-D).
    return flat.reshape((1, 1) + tuple(in_shape))
|
|
|
|
|
|
def sample_mask(alpha, decay_power, shape, max_soft=0.0, reformulate=False):
    """ Samples a mean lambda from a beta distribution parametrised by alpha,
    creates a low frequency image and binarises it based on this lambda

    :param alpha: Alpha value for beta distribution from which to sample mean of mask
    :param decay_power: Decay power for frequency decay prop 1/f**d
    :param shape: Shape of desired mask, list up to 3 dims. A bare int is
        treated as a one-element shape.
    :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
    :param reformulate: If True, uses the reformulation of [1].
    :return: Tuple of (lambda as a float, binarised mask)
    """
    dims = (shape, ) if isinstance(shape, int) else shape

    # Mean of the final mask.
    lam = sample_lam(alpha, reformulate)

    # Low frequency image, thresholded so that its mean matches lam.
    low_freq = make_low_freq_image(decay_power, dims)
    binary = binarise_mask(low_freq, lam, dims, max_soft)

    return float(lam), binary
|
|
|
|
|
|
def sample_and_apply(x,
                     alpha,
                     decay_power,
                     shape,
                     max_soft=0.0,
                     reformulate=False):
    """ Sample an fmix mask and use it to mix a batch with a shuffled copy of itself

    :param x: Image batch on which to apply fmix of shape [b, c, shape*]
    :param alpha: Alpha value for beta distribution from which to sample mean of mask
    :param decay_power: Decay power for frequency decay prop 1/f**d
    :param shape: Shape of desired mask, list up to 3 dims
    :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
    :param reformulate: If True, uses the reformulation of [1].
    :return: mixed input, permutation indices, lambda value of mix
    """
    lam, mask = sample_mask(alpha, decay_power, shape, max_soft, reformulate)

    # Random pairing of every sample with another sample of the same batch.
    index = np.random.permutation(x.shape[0])

    # Masked region comes from the original sample, the rest from its partner.
    mixed = x * mask + x[index] * (1 - mask)
    return mixed, index, lam
|
|
|
|
|
|
class FMixBase:
    """ Base class holding the configuration of the FMix augmentation

    Subclasses are expected to override ``__call__`` and ``loss``; both raise
    ``NotImplementedError`` here.

    Args:
        decay_power (float): Decay power for frequency decay prop 1/f**d
        alpha (float): Alpha value for beta distribution from which to sample mean of mask
        size ([int] | [int, int] | [int, int, int]): Shape of desired mask, list up to 3 dims
        max_soft (float): Softening value between 0 and 0.5 which smooths hard edges in the mask.
        reformulate (bool): If True, uses the reformulation of [1].
    """

    def __init__(self,
                 decay_power=3,
                 alpha=1,
                 size=(32, 32),
                 max_soft=0.0,
                 reformulate=False):
        super().__init__()

        # Mask sampling configuration.
        self.alpha = alpha
        self.decay_power = decay_power
        self.size = size
        self.max_soft = max_soft
        self.reformulate = reformulate

        # Initialised empty; nothing in this class assigns them afterwards.
        self.index = None
        self.lam = None

    def __call__(self, x):
        """ Not implemented in the base class. """
        raise NotImplementedError

    def loss(self, *args, **kwargs):
        """ Not implemented in the base class. """
        raise NotImplementedError
|