Merge pull request #49 from xieenze/master

Replace 'GaussianBlur' from Opencv to PIL to speed up.
This commit is contained in:
Xiaohang Zhan 2020-09-27 16:07:57 +08:00 committed by GitHub
commit a343437542
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 12 additions and 20 deletions

View File

@ -52,8 +52,7 @@ train_pipeline = [
dict( dict(
type='GaussianBlur', type='GaussianBlur',
sigma_min=0.1, sigma_min=0.1,
sigma_max=2.0, sigma_max=2.0)
kernel_size=23)
], ],
p=1.), p=1.),
dict(type='RandomAppliedTrans', dict(type='RandomAppliedTrans',

View File

@ -52,8 +52,7 @@ train_pipeline = [
dict( dict(
type='GaussianBlur', type='GaussianBlur',
sigma_min=0.1, sigma_min=0.1,
sigma_max=2.0, sigma_max=2.0)
kernel_size=23)
], ],
p=1.), p=1.),
dict(type='RandomAppliedTrans', dict(type='RandomAppliedTrans',

View File

@ -48,8 +48,7 @@ train_pipeline = [
dict( dict(
type='GaussianBlur', type='GaussianBlur',
sigma_min=0.1, sigma_min=0.1,
sigma_max=2.0, sigma_max=2.0)
kernel_size=23)
], ],
p=0.5), p=0.5),
dict(type='RandomHorizontalFlip'), dict(type='RandomHorizontalFlip'),

View File

@ -49,8 +49,7 @@ train_pipeline = [
dict( dict(
type='GaussianBlur', type='GaussianBlur',
sigma_min=0.1, sigma_min=0.1,
sigma_max=2.0, sigma_max=2.0)
kernel_size=23)
], ],
p=0.5), p=0.5),
dict(type='RandomHorizontalFlip'), dict(type='RandomHorizontalFlip'),

View File

@ -47,8 +47,7 @@ train_pipeline = [
dict( dict(
type='GaussianBlur', type='GaussianBlur',
sigma_min=0.1, sigma_min=0.1,
sigma_max=2.0, sigma_max=2.0)
kernel_size=23)
], ],
p=0.5), p=0.5),
dict(type='ToTensor'), dict(type='ToTensor'),

View File

@ -46,8 +46,7 @@ train_pipeline = [
dict( dict(
type='GaussianBlur', type='GaussianBlur',
sigma_min=0.1, sigma_min=0.1,
sigma_max=2.0, sigma_max=2.0)
kernel_size=23)
], ],
p=0.5), p=0.5),
dict(type='ToTensor'), dict(type='ToTensor'),

View File

@ -47,8 +47,7 @@ train_pipeline = [
dict( dict(
type='GaussianBlur', type='GaussianBlur',
sigma_min=0.1, sigma_min=0.1,
sigma_max=2.0, sigma_max=2.0)
kernel_size=23)
], ],
p=0.5), p=0.5),
dict(type='ToTensor'), dict(type='ToTensor'),

View File

@ -1,7 +1,8 @@
import cv2 import cv2
import inspect import inspect
import numpy as np import numpy as np
from PIL import Image from PIL import Image, ImageFilter
import torch import torch
from torchvision import transforms as _transforms from torchvision import transforms as _transforms
@ -80,16 +81,14 @@ class Lighting(object):
class GaussianBlur(object): class GaussianBlur(object):
"""Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709.""" """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709."""
def __init__(self, sigma_min, sigma_max, kernel_size): def __init__(self, sigma_min, sigma_max):
self.sigma_min = sigma_min self.sigma_min = sigma_min
self.sigma_max = sigma_max self.sigma_max = sigma_max
self.kernel_size = kernel_size
def __call__(self, img): def __call__(self, img):
sigma = np.random.uniform(self.sigma_min, self.sigma_max) sigma = np.random.uniform(self.sigma_min, self.sigma_max)
img = cv2.GaussianBlur( img = img.filter(ImageFilter.GaussianBlur(radius=sigma))
np.array(img), (self.kernel_size, self.kernel_size), sigma) return img
return Image.fromarray(img.astype(np.uint8))
def __repr__(self): def __repr__(self):
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__