EasyCV/easycv/hooks/collate_hook.py

57 lines
1.6 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
from timm.data import Mixup
from timm.data.mixup import mixup_target
from .registry import HOOKS
class BaseCollateHook(object):
    """Base class for hooks that run around the dataloader's collate step.

    Subclasses override ``before_collate`` and/or ``after_collate`` to process
    data before or after a list of samples is merged into a mini-batch of
    Tensor(s). Both default implementations hand their argument back
    unchanged.
    """

    def __init__(self) -> None:
        """Create the hook; the base class keeps no state."""

    def before_collate(self, batch):
        """Return ``batch`` as-is; override to act before samples are merged."""
        return batch

    def after_collate(self, batch):
        """Return ``batch`` as-is; override to act on the merged mini-batch."""
        return batch
@HOOKS.register_module()
class MixupCollateHook(BaseCollateHook):
    """Apply mixup to an already-collated batch.

    Intended to run after a list of samples has been merged into a mini-batch
    of Tensor(s). The mixing itself is delegated to a ``timm.data.Mixup``
    instance built from the constructor's keyword arguments.
    """

    def __init__(self, **kwargs):
        # Every keyword argument is forwarded untouched to timm's Mixup.
        self.mixup = Mixup(**kwargs)

    def after_collate(self, results):
        """Mix the images and labels of a collated batch.

        Args:
            results: Mapping that holds the batch images under ``'img'`` and
                the labels under ``'gt_labels'``.

        Returns:
            The same ``results`` object, with ``'img'`` re-assigned and
            ``'gt_labels'`` replaced by the output of ``mixup_target``.

        Raises:
            AssertionError: If the first dimension of ``results['img']`` is
                not even.
        """
        images = results['img']
        num_samples = images.size()[0]
        assert num_samples % 2 == 0, 'Batch size should be even when using this, but get {}'.format(
            num_samples)

        labels = results['gt_labels']

        # Choose the mixing routine from the configured mode; any mode other
        # than 'elem' or 'pair' falls through to the batch-level routine.
        mixer = self.mixup
        method_name = {
            'elem': '_mix_elem',
            'pair': '_mix_pair',
        }.get(mixer.mode, '_mix_batch')
        # NOTE(review): the _mix_* helpers return only ``lam``; they are
        # presumably expected to modify ``images`` in place - confirm
        # against the installed timm version.
        lam = getattr(mixer, method_name)(images)

        mixed_labels = mixup_target(
            target=labels,
            num_classes=mixer.num_classes,
            lam=lam,
            smoothing=mixer.label_smoothing,
            device=images.device)

        results['img'] = images
        results['gt_labels'] = mixed_labels
        return results