mirror of https://github.com/alibaba/EasyCV.git
57 lines
1.6 KiB
Python
57 lines
1.6 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
from timm.data import Mixup
|
|
from timm.data.mixup import mixup_target
|
|
|
|
from .registry import HOOKS
|
|
|
|
|
|
class BaseCollateHook:
    """Base collate-fn hook used when building a dataloader.

    Override ``before_collate`` to process the list of samples before it is
    merged into a mini-batch of Tensor(s), and/or ``after_collate`` to process
    the merged mini-batch. Both default implementations hand their input back
    unchanged.
    """

    def __init__(self) -> None:
        """Construct the hook; the base class keeps no state."""

    def before_collate(self, batch):
        """Run before samples are merged; returns ``batch`` untouched."""
        unchanged = batch
        return unchanged

    def after_collate(self, batch):
        """Run after samples are merged; returns ``batch`` untouched."""
        unchanged = batch
        return unchanged
|
|
|
|
|
|
@HOOKS.register_module()
class MixupCollateHook(BaseCollateHook):
    """Apply timm Mixup / CutMix to a collated mini-batch.

    Should be used after a list of samples has been merged into a mini-batch
    of Tensor(s), i.e. through ``after_collate``. All keyword arguments are
    forwarded unchanged to ``timm.data.Mixup``.
    """

    def __init__(self, **kwargs):
        super().__init__()
        # Local import: only needed for the one-off compatibility probe below.
        import inspect

        self.mixup = Mixup(**kwargs)
        # Some timm releases declare ``mixup_target(..., device=...)`` while
        # newer ones dropped that parameter; passing ``device`` to the latter
        # raises TypeError. Probe once and only pass it when it is accepted.
        self._target_takes_device = (
            'device' in inspect.signature(mixup_target).parameters)

    def after_collate(self, results):
        """Mix the images of a batch and replace its labels with mixed targets.

        Args:
            results (dict): Collated batch. ``results['img']`` is read as a
                tensor whose first dimension is the batch size;
                ``results['gt_labels']`` is passed to timm's ``mixup_target``
                (presumably integer class indices -- confirm against the
                dataset pipeline).

        Returns:
            dict: The same ``results`` dict with ``'img'`` holding the mixed
            images and ``'gt_labels'`` holding the targets produced by
            ``mixup_target``.

        Raises:
            AssertionError: If the batch size is odd.
        """
        batch_size = results['img'].size()[0]
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``; AssertionError is kept for backward compatibility.
        if batch_size % 2 != 0:
            raise AssertionError(
                'Batch size should be even when using this, but got {}'.format(
                    batch_size))

        samples = results['img']
        targets = results['gt_labels']

        # NOTE(review): these are private timm helpers. Their return value is
        # used only as the mixing coefficient ``lam``; ``samples`` is stored
        # back below, so they are relied on to mix it in place.
        if self.mixup.mode == 'elem':
            lam = self.mixup._mix_elem(samples)
        elif self.mixup.mode == 'pair':
            lam = self.mixup._mix_pair(samples)
        else:
            lam = self.mixup._mix_batch(samples)

        target_kwargs = {}
        # Default to True (the original behaviour) if __init__ was bypassed.
        if getattr(self, '_target_takes_device', True):
            target_kwargs['device'] = samples.device
        targets = mixup_target(
            target=targets,
            num_classes=self.mixup.num_classes,
            lam=lam,
            smoothing=self.mixup.label_smoothing,
            **target_kwargs)

        results['img'] = samples
        results['gt_labels'] = targets

        return results
|