mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge pull request #148 from rwightman/drop_block_improve
Improve dropblock impl, add fast variant, better AMP speed, inplace…
This commit is contained in:
commit
dab9935b36
@ -22,44 +22,89 @@ import math
|
||||
|
||||
|
||||
def drop_block_2d(
|
||||
x, drop_prob: float = 0.1, training: bool = False, block_size: int = 7,
|
||||
gamma_scale: float = 1.0, drop_with_noise: bool = False):
|
||||
x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
|
||||
with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
|
||||
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
|
||||
|
||||
DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
|
||||
runs with success, but needs further validation and possibly optimization for lower runtime impact.
|
||||
|
||||
"""
|
||||
if drop_prob == 0. or not training:
|
||||
return x
|
||||
_, _, height, width = x.shape
|
||||
total_size = width * height
|
||||
clipped_block_size = min(block_size, min(width, height))
|
||||
B, C, H, W = x.shape
|
||||
total_size = W * H
|
||||
clipped_block_size = min(block_size, min(W, H))
|
||||
# seed_drop_rate, the gamma parameter
|
||||
seed_drop_rate = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
|
||||
(width - block_size + 1) *
|
||||
(height - block_size + 1))
|
||||
gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
|
||||
(W - block_size + 1) * (H - block_size + 1))
|
||||
|
||||
# Forces the block to be inside the feature map.
|
||||
w_i, h_i = torch.meshgrid(torch.arange(width).to(x.device), torch.arange(height).to(x.device))
|
||||
valid_block = ((w_i >= clipped_block_size // 2) & (w_i < width - (clipped_block_size - 1) // 2)) & \
|
||||
((h_i >= clipped_block_size // 2) & (h_i < height - (clipped_block_size - 1) // 2))
|
||||
valid_block = torch.reshape(valid_block, (1, 1, height, width)).float()
|
||||
w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device))
|
||||
valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
|
||||
((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
|
||||
valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)
|
||||
|
||||
uniform_noise = torch.rand_like(x, dtype=torch.float32)
|
||||
block_mask = ((2 - seed_drop_rate - valid_block + uniform_noise) >= 1).float()
|
||||
if batchwise:
|
||||
# one mask for whole batch, quite a bit faster
|
||||
uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
|
||||
else:
|
||||
uniform_noise = torch.rand_like(x)
|
||||
block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
|
||||
block_mask = -F.max_pool2d(
|
||||
-block_mask,
|
||||
kernel_size=clipped_block_size, # block_size, ???
|
||||
kernel_size=clipped_block_size, # block_size,
|
||||
stride=1,
|
||||
padding=clipped_block_size // 2)
|
||||
|
||||
if drop_with_noise:
|
||||
normal_noise = torch.randn_like(x)
|
||||
x = x * block_mask + normal_noise * (1 - block_mask)
|
||||
if with_noise:
|
||||
normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
|
||||
if inplace:
|
||||
x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
|
||||
else:
|
||||
x = x * block_mask + normal_noise * (1 - block_mask)
|
||||
else:
|
||||
normalize_scale = block_mask.numel() / (torch.sum(block_mask) + 1e-7)
|
||||
x = x * block_mask * normalize_scale
|
||||
normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
|
||||
if inplace:
|
||||
x.mul_(block_mask * normalize_scale)
|
||||
else:
|
||||
x = x * block_mask * normalize_scale
|
||||
return x
|
||||
|
||||
|
||||
def drop_block_fast_2d(
|
||||
x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7,
|
||||
gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
|
||||
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
|
||||
|
||||
DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid
|
||||
block mask at edges.
|
||||
"""
|
||||
B, C, H, W = x.shape
|
||||
total_size = W * H
|
||||
clipped_block_size = min(block_size, min(W, H))
|
||||
gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
|
||||
(W - block_size + 1) * (H - block_size + 1))
|
||||
|
||||
if batchwise:
|
||||
# one mask for whole batch, quite a bit faster
|
||||
block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma
|
||||
else:
|
||||
# mask per batch element
|
||||
block_mask = torch.rand_like(x) < gamma
|
||||
block_mask = F.max_pool2d(
|
||||
block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2)
|
||||
|
||||
if with_noise:
|
||||
normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
|
||||
if inplace:
|
||||
x.mul_(1. - block_mask).add_(normal_noise * block_mask)
|
||||
else:
|
||||
x = x * (1. - block_mask) + normal_noise * block_mask
|
||||
else:
|
||||
block_mask = 1 - block_mask
|
||||
normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(dtype=x.dtype)
|
||||
if inplace:
|
||||
x.mul_(block_mask * normalize_scale)
|
||||
else:
|
||||
x = x * block_mask * normalize_scale
|
||||
return x
|
||||
|
||||
|
||||
@ -70,15 +115,28 @@ class DropBlock2d(nn.Module):
|
||||
drop_prob=0.1,
|
||||
block_size=7,
|
||||
gamma_scale=1.0,
|
||||
with_noise=False):
|
||||
with_noise=False,
|
||||
inplace=False,
|
||||
batchwise=False,
|
||||
fast=True):
|
||||
super(DropBlock2d, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
self.gamma_scale = gamma_scale
|
||||
self.block_size = block_size
|
||||
self.with_noise = with_noise
|
||||
self.inplace = inplace
|
||||
self.batchwise = batchwise
|
||||
self.fast = fast # FIXME finish comparisons of fast vs not
|
||||
|
||||
def forward(self, x):
|
||||
return drop_block_2d(x, self.drop_prob, self.training, self.block_size, self.gamma_scale, self.with_noise)
|
||||
if not self.training or not self.drop_prob:
|
||||
return x
|
||||
if self.fast:
|
||||
return drop_block_fast_2d(
|
||||
x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
|
||||
else:
|
||||
return drop_block_2d(
|
||||
x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
|
||||
|
||||
|
||||
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
||||
|
@ -31,25 +31,24 @@ class RadixSoftmax(nn.Module):
|
||||
class SplitAttnConv2d(nn.Module):
|
||||
"""Split-Attention Conv2d
|
||||
"""
|
||||
def __init__(self, in_channels, channels, kernel_size, stride=1, padding=0,
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
|
||||
dilation=1, groups=1, bias=False, radix=2, reduction_factor=4,
|
||||
act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs):
|
||||
super(SplitAttnConv2d, self).__init__()
|
||||
self.radix = radix
|
||||
self.cardinality = groups
|
||||
self.channels = channels
|
||||
mid_chs = channels * radix
|
||||
self.drop_block = drop_block
|
||||
mid_chs = out_channels * radix
|
||||
attn_chs = max(in_channels * radix // reduction_factor, 32)
|
||||
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels, mid_chs, kernel_size, stride, padding, dilation,
|
||||
groups=groups * radix, bias=bias, **kwargs)
|
||||
self.bn0 = norm_layer(mid_chs) if norm_layer is not None else None
|
||||
self.act0 = act_layer(inplace=True)
|
||||
self.fc1 = nn.Conv2d(channels, attn_chs, 1, groups=self.cardinality)
|
||||
self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups)
|
||||
self.bn1 = norm_layer(attn_chs) if norm_layer is not None else None
|
||||
self.act1 = act_layer(inplace=True)
|
||||
self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=self.cardinality)
|
||||
self.drop_block = drop_block
|
||||
self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups)
|
||||
self.rsoftmax = RadixSoftmax(radix, groups)
|
||||
|
||||
def forward(self, x):
|
||||
@ -63,7 +62,7 @@ class SplitAttnConv2d(nn.Module):
|
||||
B, RC, H, W = x.shape
|
||||
if self.radix > 1:
|
||||
x = x.reshape((B, self.radix, RC // self.radix, H, W))
|
||||
x_gap = torch.sum(x, dim=1)
|
||||
x_gap = x.sum(dim=1)
|
||||
else:
|
||||
x_gap = x
|
||||
x_gap = F.adaptive_avg_pool2d(x_gap, 1)
|
||||
|
@ -76,10 +76,10 @@ class ResNestBottleneck(nn.Module):
|
||||
else:
|
||||
avd_stride = 0
|
||||
self.radix = radix
|
||||
self.drop_block = drop_block
|
||||
|
||||
self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
|
||||
self.bn1 = norm_layer(group_width)
|
||||
self.drop_block1 = drop_block if drop_block is not None else None
|
||||
self.act1 = act_layer(inplace=True)
|
||||
self.avd_first = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and avd_first else None
|
||||
|
||||
@ -88,20 +88,17 @@ class ResNestBottleneck(nn.Module):
|
||||
group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation,
|
||||
dilation=first_dilation, groups=cardinality, radix=radix, norm_layer=norm_layer, drop_block=drop_block)
|
||||
self.bn2 = None # FIXME revisit, here to satisfy current torchscript fussyness
|
||||
self.drop_block2 = None
|
||||
self.act2 = None
|
||||
else:
|
||||
self.conv2 = nn.Conv2d(
|
||||
group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation,
|
||||
dilation=first_dilation, groups=cardinality, bias=False)
|
||||
self.bn2 = norm_layer(group_width)
|
||||
self.drop_block2 = drop_block if drop_block is not None else None
|
||||
self.act2 = act_layer(inplace=True)
|
||||
self.avd_last = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and not avd_first else None
|
||||
|
||||
self.conv3 = nn.Conv2d(group_width, planes * 4, kernel_size=1, bias=False)
|
||||
self.bn3 = norm_layer(planes*4)
|
||||
self.drop_block3 = drop_block if drop_block is not None else None
|
||||
self.act3 = act_layer(inplace=True)
|
||||
self.downsample = downsample
|
||||
|
||||
@ -113,8 +110,8 @@ class ResNestBottleneck(nn.Module):
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
if self.drop_block1 is not None:
|
||||
out = self.drop_block1(out)
|
||||
if self.drop_block is not None:
|
||||
out = self.drop_block(out)
|
||||
out = self.act1(out)
|
||||
|
||||
if self.avd_first is not None:
|
||||
@ -123,8 +120,8 @@ class ResNestBottleneck(nn.Module):
|
||||
out = self.conv2(out)
|
||||
if self.bn2 is not None:
|
||||
out = self.bn2(out)
|
||||
if self.drop_block2 is not None:
|
||||
out = self.drop_block2(out)
|
||||
if self.drop_block is not None:
|
||||
out = self.drop_block(out)
|
||||
out = self.act2(out)
|
||||
|
||||
if self.avd_last is not None:
|
||||
@ -132,8 +129,8 @@ class ResNestBottleneck(nn.Module):
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
if self.drop_block3 is not None:
|
||||
out = self.drop_block3(out)
|
||||
if self.drop_block is not None:
|
||||
out = self.drop_block(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user