116 lines
3.1 KiB
Python
116 lines
3.1 KiB
Python
import mmcv
|
|
import numpy as np
|
|
from torchvision import transforms
|
|
|
|
from ..builder import PIPELINES
|
|
|
|
|
|
@PIPELINES.register_module()
|
|
class RandomCrop(transforms.RandomCrop):
|
|
"""
|
|
"""
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
super(RandomCrop, self).__init__(*args, **kwargs)
|
|
|
|
def __call__(self, results):
|
|
results['img'] = super(RandomCrop, self).__call__(results['img'])
|
|
return results
|
|
|
|
|
|
@PIPELINES.register_module()
|
|
class RandomResizedCrop(transforms.RandomResizedCrop):
|
|
"""
|
|
"""
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
super(RandomResizedCrop, self).__init__(*args, **kwargs)
|
|
|
|
def __call__(self, results):
|
|
results['img'] = super(RandomResizedCrop,
|
|
self).__call__(results['img'])
|
|
return results
|
|
|
|
|
|
@PIPELINES.register_module()
|
|
class RandomHorizontalFlip(transforms.RandomHorizontalFlip):
|
|
"""
|
|
"""
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
super(RandomHorizontalFlip, self).__init__(*args, **kwargs)
|
|
|
|
def __call__(self, results):
|
|
results['img'] = super(RandomHorizontalFlip,
|
|
self).__call__(results['img'])
|
|
return results
|
|
|
|
|
|
@PIPELINES.register_module()
|
|
class Resize(transforms.Resize):
|
|
"""
|
|
"""
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
super(Resize, self).__init__(*args, **kwargs)
|
|
|
|
def __call__(self, results):
|
|
results['img'] = super(Resize, self).__call__(results['img'])
|
|
return results
|
|
|
|
|
|
@PIPELINES.register_module()
|
|
class CenterCrop(transforms.CenterCrop):
|
|
"""
|
|
"""
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
super(CenterCrop, self).__init__(*args, **kwargs)
|
|
|
|
def __call__(self, results):
|
|
results['img'] = super(CenterCrop, self).__call__(results['img'])
|
|
return results
|
|
|
|
|
|
@PIPELINES.register_module()
|
|
class ColorJitter(transforms.ColorJitter):
|
|
"""
|
|
"""
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
super(ColorJitter, self).__init__(*args, **kwargs)
|
|
|
|
def __call__(self, results):
|
|
results['img'] = super(ColorJitter, self).__call__(results['img'])
|
|
return results
|
|
|
|
|
|
@PIPELINES.register_module()
|
|
class Normalize(object):
|
|
"""Normalize the image.
|
|
|
|
Args:
|
|
mean (sequence): Mean values of 3 channels.
|
|
std (sequence): Std values of 3 channels.
|
|
to_rgb (bool): Whether to convert the image from BGR to RGB,
|
|
default is true.
|
|
"""
|
|
|
|
def __init__(self, mean, std, to_rgb=True):
|
|
self.mean = np.array(mean, dtype=np.float32)
|
|
self.std = np.array(std, dtype=np.float32)
|
|
self.to_rgb = to_rgb
|
|
|
|
def __call__(self, results):
|
|
for key in results.get('img_fields', ['img']):
|
|
results[key] = mmcv.imnormalize(results[key], self.mean, self.std,
|
|
self.to_rgb)
|
|
results['img_norm_cfg'] = dict(
|
|
mean=self.mean, std=self.std, to_rgb=self.to_rgb)
|
|
return results
|
|
|
|
def __repr__(self):
|
|
repr_str = self.__class__.__name__
|
|
repr_str += f'(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb})'
|
|
return repr_str
|