236 lines
8.5 KiB
Python
236 lines
8.5 KiB
Python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import division
|
|
|
|
import nvidia.dali.fn as fn
|
|
import nvidia.dali.ops as ops
|
|
import nvidia.dali.types as types
|
|
|
|
|
|
class DecodeImage(ops.decoders.Image):
|
|
def __init__(self, *kargs, device="cpu", **kwargs):
|
|
super(DecodeImage, self).__init__(*kargs, device=device, **kwargs)
|
|
|
|
def __call__(self, data, **kwargs):
|
|
return super(DecodeImage, self).__call__(data, **kwargs)
|
|
|
|
|
|
class ToCHWImage(ops.Transpose):
|
|
def __init__(self, *kargs, device="cpu", **kwargs):
|
|
super(ToCHWImage, self).__init__(*kargs, device=device, **kwargs)
|
|
|
|
def __call__(self, data, **kwargs):
|
|
return super(ToCHWImage, self).__call__(data, **kwargs)
|
|
|
|
|
|
class ColorJitter(ops.ColorTwist):
|
|
def __init__(self,
|
|
*kargs,
|
|
device="cpu",
|
|
prob=1.0,
|
|
brightness_factor=0.0,
|
|
contrast_factor=0.0,
|
|
saturation_factor=0.0,
|
|
hue_factor=0.0,
|
|
**kwargs):
|
|
super(ColorJitter, self).__init__(*kargs, device=device, **kwargs)
|
|
self.brightness_factor = brightness_factor
|
|
self.contrast_factor = contrast_factor
|
|
self.saturation_factor = saturation_factor
|
|
self.hue_factor = hue_factor
|
|
self.rng = ops.random.CoinFlip(probability=prob)
|
|
|
|
def __call__(self, data, **kwargs):
|
|
do_jitter = self.rng()
|
|
brightness = fn.random.uniform(
|
|
range=(max(0, 1 - self.brightness_factor),
|
|
1 + self.brightness_factor)) * do_jitter
|
|
contrast = fn.random.uniform(
|
|
range=(max(0, 1 - self.contrast_factor),
|
|
1 + self.contrast_factor)) * do_jitter
|
|
saturation = fn.random.uniform(
|
|
range=(max(0, 1 - self.saturation_factor),
|
|
1 + self.saturation_factor)) * do_jitter
|
|
hue = fn.random.uniform(range=(-self.hue_factor,
|
|
self.hue_factor)) * do_jitter
|
|
return super(ColorJitter, self).__call__(
|
|
data,
|
|
brightness=brightness,
|
|
contrast=contrast,
|
|
saturation=saturation,
|
|
hue=hue,
|
|
**kwargs)
|
|
|
|
|
|
class DecodeRandomResizedCrop(ops.decoders.ImageRandomCrop):
|
|
def __init__(self,
|
|
*kargs,
|
|
device="cpu",
|
|
resize_x=224,
|
|
resize_y=224,
|
|
resize_short=None,
|
|
interp_type=types.DALIInterpType.INTERP_LINEAR,
|
|
**kwargs):
|
|
super(DecodeRandomResizedCrop, self).__init__(
|
|
*kargs, device=device, **kwargs)
|
|
if resize_short is None:
|
|
self.resize = ops.Resize(
|
|
device="gpu" if device == "mixed" else "cpu",
|
|
resize_x=resize_x,
|
|
resize_y=resize_y,
|
|
interp_type=interp_type)
|
|
else:
|
|
self.resize = ops.Resize(
|
|
device="gpu" if device == "mixed" else "cpu",
|
|
resize_short=resize_short,
|
|
interp_type=interp_type)
|
|
|
|
def __call__(self, data, **kwargs):
|
|
data = super(DecodeRandomResizedCrop, self).__call__(data, **kwargs)
|
|
data = self.resize(data)
|
|
return data
|
|
|
|
|
|
class CropMirrorNormalize(ops.CropMirrorNormalize):
|
|
def __init__(self, *kargs, device="cpu", prob=0.5, **kwargs):
|
|
super(CropMirrorNormalize, self).__init__(
|
|
*kargs, device=device, **kwargs)
|
|
self.rng = ops.random.CoinFlip(probability=prob)
|
|
|
|
def __call__(self, data, **kwargs):
|
|
do_mirror = self.rng()
|
|
return super(CropMirrorNormalize, self).__call__(
|
|
data, mirror=do_mirror, **kwargs)
|
|
|
|
|
|
class RandCropImage(ops.RandomResizedCrop):
|
|
def __init__(self, *kargs, device="cpu", **kwargs):
|
|
super(RandCropImage, self).__init__(*kargs, device=device, **kwargs)
|
|
|
|
def __call__(self, data, **kwargs):
|
|
return super(RandCropImage, self).__call__(data, **kwargs)
|
|
|
|
|
|
class CropImage(ops.Crop):
|
|
def __init__(self, *kargs, device="cpu", **kwargs):
|
|
super(CropImage, self).__init__(*kargs, device=device, **kwargs)
|
|
|
|
def __call__(self, data, **kwargs):
|
|
return super(CropImage, self).__call__(data, **kwargs)
|
|
|
|
|
|
class ResizeImage(ops.Resize):
|
|
def __init__(self, *kargs, device="cpu", **kwargs):
|
|
super(ResizeImage, self).__init__(*kargs, device=device, **kwargs)
|
|
|
|
def __call__(self, data, **kwargs):
|
|
return super(ResizeImage, self).__call__(data, **kwargs)
|
|
|
|
|
|
class RandFlipImage(ops.Flip):
|
|
def __init__(self, *kargs, device="cpu", prob=0.5, flip_code=1, **kwargs):
|
|
super(RandFlipImage, self).__init__(*kargs, device=device, **kwargs)
|
|
self.flip_code = flip_code
|
|
self.rng = ops.random.CoinFlip(probability=prob)
|
|
|
|
def __call__(self, data, **kwargs):
|
|
do_flip = self.rng()
|
|
if self.flip_code == 1:
|
|
return super(RandFlipImage, self).__call__(
|
|
data, horizontal=do_flip, vertical=0, **kwargs)
|
|
elif self.flip_code == 0:
|
|
return super(RandFlipImage, self).__call__(
|
|
data, horizontal=0, vertical=do_flip, **kwargs)
|
|
else:
|
|
return super(RandFlipImage, self).__call__(
|
|
data, horizontal=do_flip, vertical=do_flip, **kwargs)
|
|
|
|
|
|
class Pad(ops.Crop):
|
|
"""
|
|
use ops.Crop to implement Pad operator, for ops.Pad alwayls only pad in right and bottom.
|
|
"""
|
|
|
|
def __init__(self, *kargs, device="cpu", **kwargs):
|
|
super(Pad, self).__init__(*kargs, device=device, **kwargs)
|
|
|
|
def __call__(self, data, **kwargs):
|
|
return super(Pad, self).__call__(data, **kwargs)
|
|
|
|
|
|
class RandCropImageV2(ops.Crop):
|
|
def __init__(self, *kargs, device="cpu", **kwargs):
|
|
super(RandCropImageV2, self).__init__(*kargs, device=device, **kwargs)
|
|
self.rng_x = ops.random.Uniform(range=(0.0, 1.0))
|
|
self.rng_y = ops.random.Uniform(range=(0.0, 1.0))
|
|
|
|
def __call__(self, data, **kwargs):
|
|
pos_x = self.rng_x()
|
|
pos_y = self.rng_y()
|
|
return super(RandCropImageV2, self).__call__(
|
|
data, crop_pos_x=pos_x, crop_pos_y=pos_y, **kwargs)
|
|
|
|
|
|
class RandomCropImage(ops.Crop):
|
|
def __init__(self, *kargs, device="cpu", **kwargs):
|
|
super(RandomCropImage, self).__init__(*kargs, device=device, **kwargs)
|
|
self.rng_x = ops.random.Uniform(range=(0.0, 1.0))
|
|
self.rng_y = ops.random.Uniform(range=(0.0, 1.0))
|
|
|
|
def __call__(self, data, **kwargs):
|
|
pos_x = self.rng_x()
|
|
pos_y = self.rng_y()
|
|
return super(RandomCropImage, self).__call__(
|
|
data, crop_pos_x=pos_x, crop_pos_y=pos_y, **kwargs)
|
|
|
|
|
|
class RandomRotation(ops.Rotate):
|
|
def __init__(self, *kargs, device="cpu", prob=0.5, angle=0, **kwargs):
|
|
super(RandomRotation, self).__init__(*kargs, device=device, **kwargs)
|
|
self.rng = ops.random.CoinFlip(probability=prob)
|
|
discrete_angle = list(range(-angle, angle + 1))
|
|
self.rng_angle = ops.random.Uniform(values=discrete_angle)
|
|
|
|
def __call__(self, data, **kwargs):
|
|
do_rotate = self.rng()
|
|
angle = self.rng_angle()
|
|
flip_data = super(RandomRotation, self).__call__(
|
|
data,
|
|
angle=do_rotate * angle,
|
|
keep_size=True,
|
|
fill_value=0,
|
|
**kwargs)
|
|
return flip_data
|
|
|
|
|
|
class RandomRot90(ops.Rotate):
|
|
def __init__(self, *kargs, device="cpu", **kwargs):
|
|
super(RandomRot90, self).__init__(*kargs, device=device, **kwargs)
|
|
self.rng_angle = ops.random.Uniform(values=[0.0, 1.0, 2.0, 3.0])
|
|
|
|
def __call__(self, data, **kwargs):
|
|
angle = self.rng_angle() * 90.0
|
|
flip_data = super(RandomRot90, self).__call__(
|
|
data, angle=angle, keep_size=True, fill_value=0, **kwargs)
|
|
return flip_data
|
|
|
|
|
|
class NormalizeImage(ops.Normalize):
|
|
def __init__(self, *kargs, device="cpu", **kwargs):
|
|
super(NormalizeImage, self).__init__(*kargs, device=device, **kwargs)
|
|
|
|
def __call__(self, data, **kwargs):
|
|
return super(NormalizeImage, self).__call__(data, **kwargs)
|