refactor: deprecate MixCELoss
parent
69d9a477e0
commit
ba2dd01a13
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
Eval:
|
||||
- CELoss:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
Eval:
|
||||
- CELoss:
|
||||
|
|
|
@ -24,7 +24,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -24,7 +24,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -24,7 +24,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -24,7 +24,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -24,7 +24,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -24,7 +24,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -24,7 +24,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -24,7 +24,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -24,7 +24,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -24,7 +24,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -24,7 +24,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -24,7 +24,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -24,7 +24,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -24,7 +24,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -26,7 +26,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -26,7 +26,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -26,7 +26,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -26,7 +26,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -26,7 +26,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -26,7 +26,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -22,7 +22,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -30,7 +30,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -29,7 +29,7 @@ Arch:
|
|||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import copy
|
||||
import paddle
|
||||
import numpy as np
|
||||
|
@ -36,7 +37,7 @@ from ppcls.data import preprocess
|
|||
from ppcls.data.preprocess import transform
|
||||
|
||||
|
||||
def create_operators(params):
|
||||
def create_operators(params, class_num=None):
|
||||
"""
|
||||
create operators based on the config
|
||||
|
||||
|
@ -50,7 +51,10 @@ def create_operators(params):
|
|||
dict) and len(operator) == 1, "yaml format error"
|
||||
op_name = list(operator)[0]
|
||||
param = {} if operator[op_name] is None else operator[op_name]
|
||||
op = getattr(preprocess, op_name)(**param)
|
||||
op_func = getattr(preprocess, op_name)
|
||||
if "class_num" in inspect.getfullargspec(op_func).args:
|
||||
param.update({"class_num": class_num})
|
||||
op = op_func(**param)
|
||||
ops.append(op)
|
||||
|
||||
return ops
|
||||
|
@ -65,6 +69,7 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
|
|||
from ppcls.data.dataloader.dali import dali_dataloader
|
||||
return dali_dataloader(config, mode, paddle.device.get_device(), seed)
|
||||
|
||||
class_num = config.get("class_num", None)
|
||||
config_dataset = config[mode]['dataset']
|
||||
config_dataset = copy.deepcopy(config_dataset)
|
||||
dataset_name = config_dataset.pop('name')
|
||||
|
@ -104,7 +109,7 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
|
|||
return [np.stack(slot, axis=0) for slot in slots]
|
||||
|
||||
if isinstance(batch_transform, list):
|
||||
batch_ops = create_operators(batch_transform)
|
||||
batch_ops = create_operators(batch_transform, class_num)
|
||||
batch_collate_fn = mix_collate_fn
|
||||
else:
|
||||
batch_collate_fn = None
|
||||
|
|
|
@ -44,6 +44,14 @@ class BatchOperator(object):
|
|||
labels.append(item[1])
|
||||
return np.array(imgs), np.array(labels), bs
|
||||
|
||||
def _one_hot(self, targets):
|
||||
return np.eye(self.class_num, dtype="float32")[targets]
|
||||
|
||||
def _mix_target(self, targets0, targets1, lam):
|
||||
one_hots0 = self._one_hot(targets0)
|
||||
one_hots1 = self._one_hot(targets1)
|
||||
return one_hots0 * lam + one_hots1 * (1 - lam)
|
||||
|
||||
def __call__(self, batch):
|
||||
return batch
|
||||
|
||||
|
@ -51,7 +59,7 @@ class BatchOperator(object):
|
|||
class MixupOperator(BatchOperator):
|
||||
""" Mixup operator """
|
||||
|
||||
def __init__(self, alpha: float=1.):
|
||||
def __init__(self, class_num, alpha: float=1.):
|
||||
"""Build Mixup operator
|
||||
|
||||
Args:
|
||||
|
@ -64,21 +72,27 @@ class MixupOperator(BatchOperator):
|
|||
raise Exception(
|
||||
f"Parameter \"alpha\" of Mixup should be greater than 0. \"alpha\": {alpha}."
|
||||
)
|
||||
if not class_num:
|
||||
msg = "Please set \"Arch.class_num\" in config if use \"MixupOperator\"."
|
||||
logger.error(Exception(msg))
|
||||
raise Exception(msg)
|
||||
|
||||
self._alpha = alpha
|
||||
self.class_num = class_num
|
||||
|
||||
def __call__(self, batch):
|
||||
imgs, labels, bs = self._unpack(batch)
|
||||
idx = np.random.permutation(bs)
|
||||
lam = np.random.beta(self._alpha, self._alpha)
|
||||
lams = np.array([lam] * bs, dtype=np.float32)
|
||||
imgs = lam * imgs + (1 - lam) * imgs[idx]
|
||||
return list(zip(imgs, labels, labels[idx], lams))
|
||||
targets = self._mix_target(labels, labels[idx], lam)
|
||||
return list(zip(imgs, targets))
|
||||
|
||||
|
||||
class CutmixOperator(BatchOperator):
|
||||
""" Cutmix operator """
|
||||
|
||||
def __init__(self, alpha=0.2):
|
||||
def __init__(self, class_num, alpha=0.2):
|
||||
"""Build Cutmix operator
|
||||
|
||||
Args:
|
||||
|
@ -91,7 +105,13 @@ class CutmixOperator(BatchOperator):
|
|||
raise Exception(
|
||||
f"Parameter \"alpha\" of Cutmix should be greater than 0. \"alpha\": {alpha}."
|
||||
)
|
||||
if not class_num:
|
||||
msg = "Please set \"Arch.class_num\" in config if use \"CutmixOperator\"."
|
||||
logger.error(Exception(msg))
|
||||
raise Exception(msg)
|
||||
|
||||
self._alpha = alpha
|
||||
self.class_num = class_num
|
||||
|
||||
def _rand_bbox(self, size, lam):
|
||||
""" _rand_bbox """
|
||||
|
@ -121,18 +141,29 @@ class CutmixOperator(BatchOperator):
|
|||
imgs[:, :, bbx1:bbx2, bby1:bby2] = imgs[idx, :, bbx1:bbx2, bby1:bby2]
|
||||
lam = 1 - (float(bbx2 - bbx1) * (bby2 - bby1) /
|
||||
(imgs.shape[-2] * imgs.shape[-1]))
|
||||
lams = np.array([lam] * bs, dtype=np.float32)
|
||||
return list(zip(imgs, labels, labels[idx], lams))
|
||||
targets = self._mix_target(labels, labels[idx], lam)
|
||||
return list(zip(imgs, targets))
|
||||
|
||||
|
||||
class FmixOperator(BatchOperator):
|
||||
""" Fmix operator """
|
||||
|
||||
def __init__(self, alpha=1, decay_power=3, max_soft=0., reformulate=False):
|
||||
def __init__(self,
|
||||
class_num,
|
||||
alpha=1,
|
||||
decay_power=3,
|
||||
max_soft=0.,
|
||||
reformulate=False):
|
||||
if not class_num:
|
||||
msg = "Please set \"Arch.class_num\" in config if use \"FmixOperator\"."
|
||||
logger.error(Exception(msg))
|
||||
raise Exception(msg)
|
||||
|
||||
self._alpha = alpha
|
||||
self._decay_power = decay_power
|
||||
self._max_soft = max_soft
|
||||
self._reformulate = reformulate
|
||||
self.class_num = class_num
|
||||
|
||||
def __call__(self, batch):
|
||||
imgs, labels, bs = self._unpack(batch)
|
||||
|
@ -141,20 +172,27 @@ class FmixOperator(BatchOperator):
|
|||
lam, mask = sample_mask(self._alpha, self._decay_power, \
|
||||
size, self._max_soft, self._reformulate)
|
||||
imgs = mask * imgs + (1 - mask) * imgs[idx]
|
||||
return list(zip(imgs, labels, labels[idx], [lam] * bs))
|
||||
targets = self._mix_target(labels, labels[idx], lam)
|
||||
return list(zip(imgs, targets))
|
||||
|
||||
|
||||
class OpSampler(object):
|
||||
""" Sample a operator from """
|
||||
|
||||
def __init__(self, **op_dict):
|
||||
def __init__(self, class_num, **op_dict):
|
||||
"""Build OpSampler
|
||||
|
||||
Raises:
|
||||
Exception: The parameter \"prob\" of operator(s) is set incorrectly.
|
||||
"""
|
||||
if not class_num:
|
||||
msg = "Please set \"Arch.class_num\" in config if use \"OpSampler\"."
|
||||
logger.error(Exception(msg))
|
||||
raise Exception(msg)
|
||||
|
||||
if len(op_dict) < 1:
|
||||
msg = f"ConfigWarning: No operator in \"OpSampler\". \"OpSampler\" has been skipped."
|
||||
logger.warning(msg)
|
||||
|
||||
self.ops = {}
|
||||
total_prob = 0
|
||||
|
@ -165,12 +203,13 @@ class OpSampler(object):
|
|||
logger.warning(msg)
|
||||
prob = param.pop("prob", 0)
|
||||
total_prob += prob
|
||||
param.update({"class_num": class_num})
|
||||
op = eval(op_name)(**param)
|
||||
self.ops.update({op: prob})
|
||||
|
||||
if total_prob > 1:
|
||||
msg = f"ConfigError: The total prob of operators in \"OpSampler\" should be less 1."
|
||||
logger.error(msg)
|
||||
logger.error(Exception(msg))
|
||||
raise Exception(msg)
|
||||
|
||||
# add "None Op" when total_prob < 1; "None Op" does nothing
|
||||
|
|
|
@ -112,6 +112,8 @@ class Engine(object):
|
|||
}
|
||||
paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
|
||||
|
||||
class_num = config["Arch"].get("class_num", None)
|
||||
self.config["DataLoader"].update({"class_num": class_num})
|
||||
# build dataloader
|
||||
if self.mode == 'train':
|
||||
self.train_dataloader = build_dataloader(
|
||||
|
|
|
@ -36,25 +36,19 @@ def train_epoch(engine, epoch_id, print_batch_step):
|
|||
]
|
||||
batch_size = batch[0].shape[0]
|
||||
if not engine.config["Global"].get("use_multilabel", False):
|
||||
batch[1] = batch[1].reshape([-1, 1]).astype("int64")
|
||||
batch[1] = batch[1].reshape([batch_size, -1])
|
||||
engine.global_step += 1
|
||||
|
||||
if engine.config["DataLoader"]["Train"]["dataset"].get(
|
||||
"batch_transform_ops", None):
|
||||
gt_input = batch[1:]
|
||||
else:
|
||||
gt_input = batch[1]
|
||||
|
||||
# image input
|
||||
if engine.amp:
|
||||
with paddle.amp.auto_cast(custom_black_list={
|
||||
"flatten_contiguous_range", "greater_than"
|
||||
}):
|
||||
out = forward(engine, batch)
|
||||
loss_dict = engine.train_loss_func(out, gt_input)
|
||||
else:
|
||||
out = forward(engine, batch)
|
||||
loss_dict = engine.train_loss_func(out, gt_input)
|
||||
|
||||
loss_dict = engine.train_loss_func(out, batch[1])
|
||||
|
||||
# step opt and lr
|
||||
if engine.amp:
|
||||
|
|
|
@ -12,10 +12,14 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
from ppcls.utils import logger
|
||||
|
||||
|
||||
class CELoss(nn.Layer):
|
||||
"""
|
||||
|
@ -56,19 +60,8 @@ class CELoss(nn.Layer):
|
|||
return {"CELoss": loss}
|
||||
|
||||
|
||||
class MixCELoss(CELoss):
|
||||
"""
|
||||
Cross entropy loss with mix(mixup, cutmix, fixmix)
|
||||
"""
|
||||
|
||||
def __init__(self, epsilon=None):
|
||||
super().__init__()
|
||||
self.epsilon = epsilon
|
||||
|
||||
def __call__(self, input, batch):
|
||||
target0, target1, lam = batch
|
||||
loss0 = super().forward(input, target0)["CELoss"]
|
||||
loss1 = super().forward(input, target1)["CELoss"]
|
||||
loss = lam * loss0 + (1.0 - lam) * loss1
|
||||
loss = paddle.mean(loss)
|
||||
return {"MixCELoss": loss}
|
||||
class MixCELoss(object):
|
||||
def __init__(self, *args, **kwargs):
|
||||
msg = "\"MixCELos\" is deprecated, please use \"CELoss\" instead."
|
||||
logger.error(DeprecationWarning(msg))
|
||||
raise DeprecationWarning(msg)
|
||||
|
|
Loading…
Reference in New Issue