# Mirror of https://github.com/alibaba/EasyCV.git
# 149 lines, 4.3 KiB, Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
from mmcv.runner import auto_fp16
|
|
from PIL import Image
|
|
|
|
|
|
class Grid(object):
    """GridMask-style augmentation for a single 3-D image tensor.

    A grid of horizontal and/or vertical stripes is zeroed out of a mask
    of ones; the mask is optionally rotated, centre-cropped to the image
    size and multiplied into the image.

    Args:
        use_h: If truthy, mask horizontal stripes (rows).
        use_w: If truthy, mask vertical stripes (columns).
        rotate: Exclusive upper bound of the random rotation angle passed
            to ``PIL.Image.rotate``; the default ``1`` always yields 0.
        offset: If truthy, masked pixels are filled with uniform noise in
            ``[-1, 1)`` instead of zeros.
        ratio: Stripe thickness as a fraction of the grid period. The
            special value ``1`` samples the thickness at random.
        mode: ``1`` inverts the mask (keep the stripes, drop the rest).
        prob: Probability of applying the augmentation.
    """

    def __init__(self,
                 use_h,
                 use_w,
                 rotate=1,
                 offset=False,
                 ratio=0.5,
                 mode=0,
                 prob=1.):
        self.use_h = use_h
        self.use_w = use_w
        self.rotate = rotate
        self.offset = offset
        self.ratio = ratio
        self.mode = mode
        # ``st_prob`` keeps the configured value; ``prob`` is the one
        # actually used and may be rescaled by ``set_prob``.
        self.st_prob = prob
        self.prob = prob

    def set_prob(self, epoch, max_epoch):
        """Set ``prob`` to ``st_prob * epoch / max_epoch``."""
        self.prob = self.st_prob * epoch / max_epoch

    def _build_mask(self, h, w, d):
        """Return an ``(h, w)`` uint8 numpy mask with grid period ``d``.

        Random numbers are drawn in the order: stripe thickness (only
        when ``ratio == 1``), row start, column start, rotation angle.
        """
        if self.ratio == 1:
            self.l = np.random.randint(1, d)
        else:
            # Round d * ratio, then clamp to [1, d - 1] so that a stripe
            # never covers a whole period.
            self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1)

        # Work on a canvas 1.5x larger so the centre crop taken after
        # the rotation still comes from inside the canvas.
        hh = int(1.5 * h)
        ww = int(1.5 * w)
        mask = np.ones((hh, ww), np.float32)
        st_h = np.random.randint(d)
        st_w = np.random.randint(d)
        if self.use_h:
            for i in range(hh // d):
                s = d * i + st_h
                mask[s:min(s + self.l, hh), :] = 0
        if self.use_w:
            for i in range(ww // d):
                s = d * i + st_w
                mask[:, s:min(s + self.l, ww)] = 0

        r = np.random.randint(self.rotate)
        mask = mask.astype(np.uint8)
        if r != 0:
            # A zero-degree rotation is a plain copy, so the PIL
            # round-trip is only needed for a non-zero angle.
            mask = np.asarray(Image.fromarray(mask).rotate(r))

        top = (hh - h) // 2
        left = (ww - w) // 2
        return mask[top:top + h, left:left + w]

    def __call__(self, img, label):
        """Apply the grid mask to ``img`` with probability ``prob``.

        Args:
            img: 3-D tensor whose dimensions 1 and 2 are height and width.
            label: Passed through untouched.

        Returns:
            Tuple ``(img, label)``. ``img`` is the input object itself
            when the augmentation is skipped.
        """
        if np.random.rand() > self.prob:
            return img, label
        h = img.size(1)
        w = img.size(2)
        self.d1 = 2
        self.d2 = min(h, w)
        if self.d2 <= self.d1:
            # No valid grid period exists for such a small image
            # (np.random.randint(2, 2) would raise); leave it unchanged.
            return img, label
        d = np.random.randint(self.d1, self.d2)

        # Keep the mask on the same device as the image so that the
        # transform also works on non-CPU tensors.
        mask = torch.from_numpy(self._build_mask(h, w, d)).float()
        mask = mask.to(img.device)
        if self.mode == 1:
            mask = 1 - mask

        mask = mask.expand_as(img)
        if self.offset:
            offset = torch.from_numpy(2 * (np.random.rand(h, w) - 0.5))
            offset = offset.float().to(img.device)
            # Noise is only added where the mask removed the image.
            img = img * mask + (1 - mask) * offset
        else:
            img = img * mask

        return img, label
|
|
|
|
|
|
class GridMask(nn.Module):
    """GridMask augmentation as an ``nn.Module`` for 4-D batches.

    One random grid mask is generated per forward call and shared by
    every image and channel of the batch. The module is a no-op when it
    is not in training mode.

    Args:
        use_h: If truthy, mask horizontal stripes (rows).
        use_w: If truthy, mask vertical stripes (columns).
        rotate: Exclusive upper bound of the random rotation angle passed
            to ``PIL.Image.rotate``; the default ``1`` always yields 0.
        offset: If truthy, masked pixels are filled with uniform noise in
            ``[-1, 1)`` instead of zeros.
        ratio: Stripe thickness as a fraction of the grid period.
        mode: ``1`` inverts the mask (keep the stripes, drop the rest).
        prob: Probability of applying the augmentation.
    """

    def __init__(self,
                 use_h,
                 use_w,
                 rotate=1,
                 offset=False,
                 ratio=0.5,
                 mode=0,
                 prob=1.):
        super(GridMask, self).__init__()
        self.use_h = use_h
        self.use_w = use_w
        self.rotate = rotate
        self.offset = offset
        self.ratio = ratio
        self.mode = mode
        # ``st_prob`` keeps the configured value; ``prob`` is the one
        # actually used and may be rescaled by ``set_prob``.
        self.st_prob = prob
        self.prob = prob
        # NOTE(review): mmcv's ``auto_fp16`` appears to look for an
        # attribute named ``fp16_enabled``; the spelling here is kept
        # unchanged to preserve current behaviour - confirm the intent.
        self.fp16_enable = False

    def set_prob(self, epoch, max_epoch):
        """Set ``prob`` to ``st_prob * epoch / max_epoch``."""
        self.prob = self.st_prob * epoch / max_epoch

    def _build_mask(self, h, w):
        """Return a random ``(h, w)`` uint8 numpy grid mask.

        Random numbers are drawn in the order: grid period, row start,
        column start, rotation angle.
        """
        d = np.random.randint(2, h)
        # Round d * ratio, then clamp to [1, d - 1] so that a stripe
        # never covers a whole period.
        self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1)

        # Work on a canvas 1.5x larger so the centre crop taken after
        # the rotation still comes from inside the canvas.
        hh = int(1.5 * h)
        ww = int(1.5 * w)
        mask = np.ones((hh, ww), np.float32)
        st_h = np.random.randint(d)
        st_w = np.random.randint(d)
        if self.use_h:
            for i in range(hh // d):
                s = d * i + st_h
                mask[s:min(s + self.l, hh), :] = 0
        if self.use_w:
            for i in range(ww // d):
                s = d * i + st_w
                mask[:, s:min(s + self.l, ww)] = 0

        r = np.random.randint(self.rotate)
        mask = mask.astype(np.uint8)
        if r != 0:
            # A zero-degree rotation is a plain copy, so the PIL
            # round-trip is only needed for a non-zero angle.
            mask = np.asarray(Image.fromarray(mask).rotate(r))

        top = (hh - h) // 2
        left = (ww - w) // 2
        return mask[top:top + h, left:left + w]

    @auto_fp16()
    def forward(self, x):
        """Mask a batch ``x`` of shape ``(n, c, h, w)``.

        Returns ``x`` itself when the augmentation is skipped (random
        draw above ``prob``, module not in training mode, or ``h`` too
        small for a grid); otherwise a new tensor of the same shape.
        """
        # The random draw is made first, as before, so the numpy RNG
        # stream is consumed identically in train and eval mode.
        if np.random.rand() > self.prob or not self.training:
            return x
        n, c, h, w = x.size()
        if h <= 2:
            # np.random.randint(2, h) has no valid period to draw.
            return x
        # reshape (not view) so non-contiguous inputs are accepted too.
        x = x.reshape(-1, h, w)

        # Follow the input's device and dtype instead of forcing CUDA.
        mask = torch.from_numpy(self._build_mask(h, w)).to(
            device=x.device, dtype=x.dtype)
        if self.mode == 1:
            mask = 1 - mask
        mask = mask.expand_as(x)
        if self.offset:
            offset = torch.from_numpy(2 * (np.random.rand(h, w) - 0.5)).to(
                device=x.device, dtype=x.dtype)
            # Noise is only added where the mask removed the input.
            x = x * mask + offset * (1 - mask)
        else:
            x = x * mask

        return x.reshape(n, c, h, w)
|