# Source: PaddleOCR/ppocr/modeling/backbones/rec_resnetv2.py
#
# (original listing: 1227 lines, 38 KiB, Python)
#
#

# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/resnetv2.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import collections.abc
from itertools import repeat
from collections import OrderedDict # pylint: disable=g-importing-member
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingUniform
from functools import partial
from typing import Union, Callable, Type, List, Tuple
# Per-channel normalization statistics (0.5 for each of the three channels).
IMAGENET_INCEPTION_MEAN = (0.5,) * 3
IMAGENET_INCEPTION_STD = (0.5,) * 3

# Reusable paddle initializers, named after their torch.nn.init counterparts.
normal_ = Normal(mean=0.0, std=0.01)
zeros_ = Constant(value=0.0)
ones_ = Constant(value=1.0)
# NOTE(review): despite its name this is a *uniform* Kaiming initializer;
# timm uses kaiming_normal_ here. Kept as-is so initialization is unchanged.
kaiming_normal_ = KaimingUniform(nonlinearity="relu")
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
# Fixed-arity converters, e.g. ``to_2tuple(3) == (3, 3)``.
to_1tuple, to_2tuple, to_3tuple, to_4tuple = (_ntuple(arity) for arity in range(1, 5))
# Generic factory: ``to_ntuple(k)`` builds a converter for arity ``k``.
to_ntuple = _ntuple
class StdConv2dSame(nn.Conv2D):
    """Conv2D with Weight Standardization and TF-style 'SAME' padding.

    Before every convolution each output filter of ``self.weight`` is
    standardized with ``F.batch_norm`` over its flattened values
    (``weight.reshape([1, out_channels, -1])``). When the requested padding
    cannot be expressed as a static symmetric pad (see ``get_padding_value``),
    the input is padded at runtime instead.

    Args:
        in_channel: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: padding spec forwarded to ``get_padding_value`` (default "SAME").
        dilation: convolution dilation.
        groups: number of convolution groups.
        bias_attr: forwarded to ``nn.Conv2D``.
        eps: epsilon used by the weight standardization.
        is_export: always use the tensor-based ("export") code path.
    """

    def __init__(
        self,
        in_channel,
        out_channels,
        kernel_size,
        stride=1,
        padding="SAME",
        dilation=1,
        groups=1,
        bias_attr=False,
        eps=1e-6,
        is_export=False,
    ):
        padding, is_dynamic = get_padding_value(
            padding, kernel_size, stride=stride, dilation=dilation
        )
        super().__init__(
            in_channel,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias_attr=bias_attr,
        )
        self.same_pad = is_dynamic
        self.export = is_export
        self.eps = eps

    def forward(self, x):
        # Eval mode always takes the export path. Decide per call rather than
        # overwriting ``self.export``; the old in-place flip stayed True forever
        # once the layer had been run in eval mode, even after train().
        export = self.export or not self.training
        if self.same_pad:
            if export:
                x = pad_same_export(x, self._kernel_size, self._stride, self._dilation)
            else:
                x = pad_same(x, self._kernel_size, self._stride, self._dilation)
        # Fresh placeholder statistics: batch_norm computes the per-filter
        # mean/var itself, these tensors are only required by the API.
        running_mean = paddle.zeros([self._out_channels], dtype="float32")
        running_variance = paddle.ones([self._out_channels], dtype="float32")
        flat_weight = self.weight.reshape([1, self._out_channels, -1])
        if export:
            # training=False + use_global_stats=False still normalizes with the
            # statistics of the input (the weight) but avoids training-mode ops.
            std_weight = F.batch_norm(
                flat_weight,
                running_mean,
                running_variance,
                momentum=0.0,
                epsilon=self.eps,
                use_global_stats=False,
            )
        else:
            std_weight = F.batch_norm(
                flat_weight,
                running_mean,
                running_variance,
                training=True,
                momentum=0.0,
                epsilon=self.eps,
            )
        weight = paddle.reshape(std_weight, self.weight.shape)
        return F.conv2d(
            x,
            weight,
            self.bias,
            self._stride,
            self._padding,
            self._dilation,
            self._groups,
        )
class StdConv2d(nn.Conv2D):
    """Conv2d with Weight Standardization. Used for BiT ResNet-V2 models.

    Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
        https://arxiv.org/abs/1903.10520v2

    Each output filter of ``self.weight`` is standardized with ``F.batch_norm``
    over its flattened values before the convolution.

    Args:
        in_channel: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: explicit padding; ``None`` uses ``get_padding``.
        dilation: convolution dilation.
        groups: number of convolution groups.
        bias: forwarded to ``nn.Conv2D`` as ``bias_attr``.
        eps: epsilon used by the weight standardization.
        is_export: always use the export-friendly standardization path. This
            argument is accepted because every block/stem builder in this file
            passes ``is_export`` to its ``conv_layer``.
    """

    def __init__(
        self,
        in_channel,
        out_channels,
        kernel_size,
        stride=1,
        padding=None,
        dilation=1,
        groups=1,
        bias=False,
        eps=1e-6,
        is_export=False,
    ):
        if padding is None:
            padding = get_padding(kernel_size, stride, dilation)
        super().__init__(
            in_channel,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias_attr=bias,
        )
        self.eps = eps
        self.export = is_export

    def forward(self, x):
        # Paddle's batch_norm needs real running-stat tensors (torch accepts None).
        running_mean = paddle.zeros([self._out_channels], dtype="float32")
        running_var = paddle.ones([self._out_channels], dtype="float32")
        flat_weight = self.weight.reshape([1, self._out_channels, -1])
        if self.export or not self.training:
            std_weight = F.batch_norm(
                flat_weight,
                running_mean,
                running_var,
                momentum=0.0,
                epsilon=self.eps,
                use_global_stats=False,
            )
        else:
            std_weight = F.batch_norm(
                flat_weight,
                running_mean,
                running_var,
                training=True,
                momentum=0.0,
                epsilon=self.eps,
            )
        weight = paddle.reshape(std_weight, self.weight.shape)
        return F.conv2d(
            x,
            weight,
            self.bias,
            self._stride,
            self._padding,
            self._dilation,
            self._groups,
        )
class MaxPool2dSame(nn.MaxPool2D):
    """Tensorflow like 'SAME' wrapper for 2D max pooling.

    The input is padded with ``-inf`` at runtime so the output spatial size is
    ``ceil(input / stride)``, then pooled with no implicit padding.

    Args:
        kernel_size: pooling window (int or pair).
        stride: pooling stride; ``None`` means ``kernel_size`` (paddle's default).
        padding: ignored; the padding is computed from the input size.
        dilation: stored for reference only (paddle max pooling has no dilation).
        ceil_mode: forwarded to the pooling op.
        is_export: always use the tensor-based padding helper.
    """

    def __init__(
        self,
        kernel_size: int,
        stride=None,
        padding=0,
        dilation=1,
        ceil_mode=False,
        is_export=False,
    ):
        kernel_size = to_2tuple(kernel_size)
        # Resolve the default here: the SAME-padding helpers need concrete numbers.
        stride = to_2tuple(kernel_size if stride is None else stride)
        dilation = to_2tuple(dilation)
        # nn.MaxPool2D has no ``dilation`` parameter; its 4th positional slot is
        # ``ceil_mode``, so everything is passed by keyword.
        super(MaxPool2dSame, self).__init__(
            kernel_size, stride=stride, padding=(0, 0), ceil_mode=ceil_mode
        )
        self.dilation = dilation
        self.export = is_export

    def forward(self, x):
        # Decide per call instead of permanently flipping ``self.export``.
        if self.export or not self.training:
            x = pad_same_export(x, self.ksize, self.stride, value=-float("inf"))
        else:
            x = pad_same(x, self.ksize, self.stride, value=-float("inf"))
        return F.max_pool2d(x, self.ksize, self.stride, (0, 0), self.ceil_mode)
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
    """Return the symmetric padding for a conv/pool window.

    Computed as ``((stride - 1) + dilation * (kernel_size - 1)) // 2``; extra
    keyword arguments are ignored.
    """
    effective_extent = dilation * (kernel_size - 1)
    return (effective_extent + stride - 1) // 2
def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
    """Return True when 'SAME' padding can be applied as a fixed symmetric pad.

    That requires stride 1 and an even effective kernel extent
    ``dilation * (kernel_size - 1)``; extra keyword arguments are ignored.
    """
    if stride != 1:
        return False
    return (dilation * (kernel_size - 1)) % 2 == 0
def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
    """Resolve a padding spec into ``(padding, is_dynamic)``.

    * non-string ``padding`` is returned unchanged, not dynamic;
    * ``"valid"`` (any case) means no padding;
    * ``"same"`` (any case) is a static symmetric pad when ``is_static_pad``
      allows it, otherwise ``0`` with ``is_dynamic=True`` so the caller pads at
      runtime (TF-compatible, with extra runtime/memory cost);
    * any other string falls back to ``get_padding`` (PyTorch-style symmetric).

    ``kwargs`` (e.g. ``stride``, ``dilation``) are forwarded to the helpers.
    """
    if not isinstance(padding, str):
        return padding, False
    mode = padding.lower()
    if mode == "valid":
        return 0, False
    if mode == "same" and not is_static_pad(kernel_size, **kwargs):
        return 0, True
    return get_padding(kernel_size, **kwargs), False
def create_pool2d(pool_type, kernel_size, stride=None, is_export=False, **kwargs):
    """Build a 2D ``"avg"`` or ``"max"`` pooling layer.

    When the resolved padding is dynamic (TF 'SAME' that cannot be expressed
    statically), an ``AvgPool2dSame`` / ``MaxPool2dSame`` wrapper is returned.
    Otherwise a plain ``nn.AvgPool2D`` / ``nn.MaxPool2D`` with the resolved
    static padding is returned. Previously this path returned ``None``.

    Args:
        pool_type: ``"avg"`` or ``"max"``.
        kernel_size: pooling window.
        stride: pooling stride; falsy values mean ``kernel_size``.
        is_export: forwarded to ``MaxPool2dSame``.
        **kwargs: may contain ``padding`` (default ``""``); the remaining
            entries go to ``get_padding_value`` and to the pooling layer.

    Raises:
        AssertionError: for an unsupported ``pool_type`` (same exception type as
            before, but raised explicitly so it survives ``python -O``).
    """
    stride = stride or kernel_size
    padding = kwargs.pop("padding", "")
    padding, is_dynamic = get_padding_value(
        padding, kernel_size, stride=stride, **kwargs
    )
    if pool_type not in ("avg", "max"):
        raise AssertionError(f"Unsupported pool type {pool_type}")
    if is_dynamic:
        if pool_type == "avg":
            # AvgPool2dSame does not take ``is_export`` in its original signature.
            return AvgPool2dSame(kernel_size, stride=stride, **kwargs)
        return MaxPool2dSame(kernel_size, stride=stride, is_export=is_export, **kwargs)
    if pool_type == "avg":
        return nn.AvgPool2D(kernel_size, stride=stride, padding=padding, **kwargs)
    return nn.MaxPool2D(kernel_size, stride=stride, padding=padding, **kwargs)
def get_same_padding(x, k, s, d):
    """Total TF 'SAME' padding for one spatial dim (never negative).

    Args:
        x: input size along the dimension.
        k: kernel size.
        s: stride.
        d: dilation.
    """
    covered = (math.ceil(x / s) - 1) * s + (k - 1) * d + 1
    return max(covered - x, 0)
def get_same_padding_export(x, k, s, d):
    """Tensor version of ``get_same_padding`` for export / static-graph use.

    Returns a tensor holding ``max((ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)``.

    Args:
        x: input size along the dimension.
        k: kernel size.
        s: stride.
        d: dilation.
    """
    x = paddle.to_tensor(x)
    k = paddle.to_tensor(k)
    s = paddle.to_tensor(s)
    d = paddle.to_tensor(d)
    pad = (paddle.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x
    # ``paddle.max(pad, 0)`` reduces over axis 0 instead of clamping at zero,
    # which let negative padding (e.g. 1x1 kernel, stride 2, even size) through.
    return paddle.maximum(pad, paddle.zeros_like(pad))
def pad_same_export(x, k, s, d=(1, 1), value=0):
    """Tensor-based counterpart of ``pad_same`` used on the export path.

    Pads the last two dims (H, W) of ``x`` with ``value`` so that a window of
    size ``k``, stride ``s`` and dilation ``d`` yields TF 'SAME' output size.
    Width padding is listed first, then height; each is split so the larger
    half comes second.
    """
    in_h, in_w = x.shape[-2:]
    pad_h = get_same_padding_export(in_h, k[0], s[0], d[0])
    pad_w = get_same_padding_export(in_w, k[1], s[1], d[1])
    pad_h = pad_h.cast(paddle.int32)
    pad_w = pad_w.cast(paddle.int32)
    half_w = (pad_w // 2).cast(paddle.int32)
    half_h = (pad_h // 2).cast(paddle.int32)
    pad_list = paddle.to_tensor(
        [
            half_w,
            (pad_w - pad_w // 2).cast(paddle.int32),
            half_h,
            (pad_h - pad_h // 2).cast(paddle.int32),
        ]
    )
    if pad_h > 0 or pad_w > 0:
        # Stacking shape-[1] tensors yields a [4, 1] list; F.pad wants [4].
        if len(pad_list.shape) == 2:
            pad_list = pad_list.squeeze(1)
        x = F.pad(x, pad_list.cast(paddle.int32), value=value)
    return x
def pad_same(x, k, s, d=(1, 1), value=0, pad_h=None, pad_w=None):
    """Pad the last two dims (H, W) of ``x`` for TF 'SAME' output size.

    Args:
        x: input tensor; only its last two dims are inspected.
        k, s, d: per-dim (H, W) kernel size, stride and dilation.
        value: fill value for the padding.
        pad_h, pad_w: accepted for signature compatibility; they are always
            recomputed from the input shape, so any passed value is ignored.
    """
    in_h, in_w = x.shape[-2:]
    pad_h = get_same_padding(in_h, k[0], s[0], d[0])
    pad_w = get_same_padding(in_w, k[1], s[1], d[1])
    if pad_h <= 0 and pad_w <= 0:
        return x
    left = pad_w // 2
    top = pad_h // 2
    # Width entries first, then height; the larger half goes second.
    return F.pad(x, [left, pad_w - left, top, pad_h - top], value=value)
class AvgPool2dSame(nn.AvgPool2D):
    """Tensorflow like 'SAME' wrapper for 2D average pooling.

    The input is zero-padded at runtime so the output spatial size is
    ``ceil(input / stride)``, then pooled with no implicit padding.

    Args:
        kernel_size: pooling window (int or pair).
        stride: pooling stride; ``None`` means ``kernel_size`` (paddle's default).
        padding: ignored; the padding is computed from the input size.
        ceil_mode: forwarded to the pooling op.
        count_include_pad: PyTorch-style flag, translated to paddle's
            ``exclusive = not count_include_pad``.
        exclusive: paddle-style flag; when given it overrides ``count_include_pad``.
        is_export: always use the tensor-based padding helper.
    """

    def __init__(
        self,
        kernel_size: int,
        stride=None,
        padding=0,
        ceil_mode=False,
        count_include_pad=True,
        exclusive=None,
        is_export=False,
    ):
        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(kernel_size if stride is None else stride)
        if exclusive is None:
            exclusive = not count_include_pad
        # Keyword arguments: nn.AvgPool2D's 5th positional slot is ``exclusive``,
        # whose meaning is the inverse of ``count_include_pad``.
        super(AvgPool2dSame, self).__init__(
            kernel_size,
            stride=stride,
            padding=(0, 0),
            ceil_mode=ceil_mode,
            exclusive=exclusive,
        )
        self.count_include_pad = not exclusive
        self.export = is_export

    def forward(self, x):
        if self.export or not self.training:
            x = pad_same_export(x, self.ksize, self.stride)
        else:
            x = pad_same(x, self.ksize, self.stride)
        # nn.AvgPool2D stores its config as ksize/stride/ceil_mode/exclusive.
        return F.avg_pool2d(
            x,
            self.ksize,
            self.stride,
            (0, 0),
            self.ceil_mode,
            self.exclusive,
        )
def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Drop whole samples of a residual branch (stochastic depth).

    Each sample along axis 0 is kept with probability ``1 - drop_prob`` and
    zeroed otherwise. When ``scale_by_keep`` is true, kept samples are divided
    by the keep probability so the expected value is unchanged.

    Args:
        x: input tensor of any rank >= 1.
        drop_prob: probability of dropping a sample; ``None``/0 disables it.
        training: dropping is only applied when True.
        scale_by_keep: rescale the kept samples by ``1 / keep_prob``.

    Returns:
        ``x`` itself when disabled, otherwise the masked (optionally rescaled) tensor.
    """
    # ``not drop_prob`` also covers DropPath's default ``drop_prob=None``.
    if not drop_prob or not training:
        return x
    keep_prob = 1.0 - drop_prob
    # One draw per sample, broadcast over the remaining dims (any rank works).
    shape = (paddle.shape(x)[0],) + (1,) * (x.ndim - 1)
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    random_tensor = paddle.floor(keep_prob + paddle.rand(shape, dtype=x.dtype))
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor = random_tensor / keep_prob
    return x * random_tensor
class DropPath(nn.Layer):
    """Layer wrapper around ``drop_path`` (per-sample stochastic depth).

    Intended for the main path of residual blocks; it is only active while the
    layer is in training mode.
    """

    def __init__(self, drop_prob=None, scale_by_keep=True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(
            x,
            drop_prob=self.drop_prob,
            training=self.training,
            scale_by_keep=self.scale_by_keep,
        )
def adaptive_pool_feat_mult(pool_type="avg"):
    """Channel multiplier of a global pool: 2 for "catavgmax", else 1."""
    return 2 if pool_type == "catavgmax" else 1
class SelectAdaptivePool2d(nn.Layer):
    """Selectable global pooling layer with dynamic input kernel size.

    Supported ``pool_type`` values:
        ``""`` (or any falsy value) -- no pooling (identity)
        ``"fast"`` / ``"avg"``        -- adaptive average pooling to ``output_size``
        ``"max"``                     -- adaptive max pooling to ``output_size``
        ``"avgmax"``                  -- mean of the average- and max-pooled maps
        ``"catavgmax"``               -- average- and max-pooled maps concatenated
                                         on axis 1 (see ``feat_mult``)

    Previously only ``""`` created ``self.pool``, so the default ``"fast"`` and
    the ``"avg"`` used by ``ClassifierHead`` failed in ``forward``.

    Raises:
        ValueError: for any other ``pool_type``.
    """

    def __init__(self, output_size=1, pool_type="fast", flatten=False):
        super(SelectAdaptivePool2d, self).__init__()
        # Convert other falsy values to empty string for consistent typing.
        self.pool_type = pool_type or ""
        self.flatten = nn.Flatten(1) if flatten else nn.Identity()
        if self.pool_type == "":
            self.pool = nn.Identity()  # pass through
        elif self.pool_type in ("fast", "avg"):
            self.pool = nn.AdaptiveAvgPool2D(output_size)
        elif self.pool_type == "max":
            self.pool = nn.AdaptiveMaxPool2D(output_size)
        elif self.pool_type in ("avgmax", "catavgmax"):
            # [0] = average pool, [1] = max pool; combined in forward().
            self.pool = nn.LayerList(
                [nn.AdaptiveAvgPool2D(output_size), nn.AdaptiveMaxPool2D(output_size)]
            )
        else:
            raise ValueError(f"Invalid pool type: {self.pool_type}")

    def is_identity(self):
        return not self.pool_type

    def forward(self, x):
        if self.pool_type == "avgmax":
            x = 0.5 * (self.pool[0](x) + self.pool[1](x))
        elif self.pool_type == "catavgmax":
            x = paddle.concat([self.pool[0](x), self.pool[1](x)], axis=1)
        else:
            x = self.pool(x)
        x = self.flatten(x)
        return x

    def feat_mult(self):
        return adaptive_pool_feat_mult(self.pool_type)

    def __repr__(self):
        return (
            self.__class__.__name__
            + " ("
            + "pool_type="
            + self.pool_type
            + ", flatten="
            + str(self.flatten)
            + ")"
        )
def _create_pool(num_features, num_classes, pool_type="avg", use_conv=False):
    """Build the global pool of a classifier head.

    Returns:
        ``(global_pool, num_pooled_features)`` where the feature count is
        ``num_features`` times the pool's ``feat_mult()``.
    """
    if pool_type:
        # A Linear classifier needs a flat input; a 1x1 conv classifier does not.
        flatten_in_pool = not use_conv
    else:
        assert (
            num_classes == 0 or use_conv
        ), "Pooling can only be disabled if classifier is also removed or conv classifier is used"
        # Pass-through pooling: nothing to flatten.
        flatten_in_pool = False
    global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool)
    return global_pool, num_features * global_pool.feat_mult()
def _create_fc(num_features, num_classes, use_conv=False):
    """Build the classifier layer.

    Returns an identity when ``num_classes <= 0``. Otherwise it returns a 1x1
    ``nn.Conv2D`` when ``use_conv``, else an ``nn.Linear``, both with bias.
    """
    if num_classes <= 0:
        return nn.Identity()  # pass-through (no classifier)
    if use_conv:
        return nn.Conv2D(num_features, num_classes, 1, bias_attr=True)
    return nn.Linear(num_features, num_classes, bias_attr=True)
class ClassifierHead(nn.Layer):
    """Classifier head: global pool -> optional dropout -> fc -> optional flatten.

    With a conv classifier (``use_conv``) and pooling enabled, the output of the
    1x1 conv is flattened after the classifier.
    """

    def __init__(
        self, in_chs, num_classes, pool_type="avg", drop_rate=0.0, use_conv=False
    ):
        super().__init__()
        self.drop_rate = drop_rate
        pool, pooled_chs = _create_pool(
            in_chs, num_classes, pool_type, use_conv=use_conv
        )
        self.global_pool = pool
        self.fc = _create_fc(pooled_chs, num_classes, use_conv=use_conv)
        needs_flatten = use_conv and pool_type
        self.flatten = nn.Flatten(1) if needs_flatten else nn.Identity()

    def forward(self, x):
        pooled = self.global_pool(x)
        if self.drop_rate:
            pooled = F.dropout(
                pooled, p=float(self.drop_rate), training=self.training
            )
        return self.flatten(self.fc(pooled))
class EvoNormBatch2d(nn.Layer):
    """EvoNorm (batch variant) for 4-D NCHW inputs, ported to Paddle APIs.

    With ``apply_act`` the input is divided by
    ``max(x * v + sqrt(instance_var + eps), sqrt(batch_var + eps))``, where
    ``batch_var`` is the per-channel batch variance in training and the running
    variance in eval. An affine ``weight``/``bias`` is applied at the end.

    Args:
        num_features: number of channels.
        apply_act: enable the ``v``-gated normalization above.
        momentum: update rate of ``running_var``; its new value is
            ``momentum * unbiased_batch_var + (1 - momentum) * running_var``.
        eps: numerical epsilon.
        drop_block: unused; kept for interface compatibility.
    """

    def __init__(
        self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None
    ):
        super(EvoNormBatch2d, self).__init__()
        self.apply_act = apply_act  # apply activation (non-linearity)
        self.momentum = momentum
        self.eps = eps
        # create_parameter takes a *shape*; values are set in reset_parameters().
        self.weight = self.create_parameter([num_features], dtype="float32")
        self.bias = self.create_parameter(
            [num_features], dtype="float32", is_bias=True
        )
        self.v = (
            self.create_parameter([num_features], dtype="float32")
            if apply_act
            else None
        )
        self.register_buffer("running_var", paddle.ones([num_features]))
        self.reset_parameters()

    def reset_parameters(self):
        ones_(self.weight)
        zeros_(self.bias)
        if self.apply_act:
            ones_(self.v)

    def forward(self, x):
        x_type = x.dtype
        if self.v is not None:
            running_var = self.running_var.reshape([1, -1, 1, 1])
            if self.training:
                var = x.var(axis=[0, 2, 3], unbiased=False, keepdim=True)
                # Elements per channel, used for the unbiased-variance correction.
                n = x.numel().astype(var.dtype) / x.shape[1]
                running_var = var.detach() * self.momentum * (
                    n / (n - 1)
                ) + running_var * (1 - self.momentum)
                paddle.assign(
                    running_var.reshape(self.running_var.shape), self.running_var
                )
            else:
                var = running_var
            v = self.v.astype(x_type).reshape([1, -1, 1, 1])
            d = x * v + (
                x.var(axis=[2, 3], unbiased=False, keepdim=True) + self.eps
            ).sqrt().astype(x_type)
            d = paddle.maximum(d, (var + self.eps).sqrt().astype(x_type))
            x = x / d
        return x * self.weight.reshape([1, -1, 1, 1]) + self.bias.reshape(
            [1, -1, 1, 1]
        )
class EvoNormSample2d(nn.Layer):
    """EvoNorm (sample variant) for 4-D NCHW inputs, ported to Paddle APIs.

    With ``apply_act``, computes ``x * sigmoid(x * v)`` and divides it by the
    per-sample, per-group standard deviation of ``x``. An affine
    ``weight``/``bias`` is applied at the end.

    Args:
        num_features: number of channels.
        apply_act: enable the gated group normalization above.
        groups: number of channel groups (must divide the channel count).
        eps: numerical epsilon.
        drop_block: unused; kept for interface compatibility.
    """

    def __init__(
        self, num_features, apply_act=True, groups=32, eps=1e-5, drop_block=None
    ):
        super(EvoNormSample2d, self).__init__()
        self.apply_act = apply_act
        self.groups = groups
        self.eps = eps
        # create_parameter takes a *shape*; values are set in reset_parameters().
        self.weight = self.create_parameter([num_features], dtype="float32")
        self.bias = self.create_parameter(
            [num_features], dtype="float32", is_bias=True
        )
        self.v = (
            self.create_parameter([num_features], dtype="float32")
            if apply_act
            else None
        )
        self.reset_parameters()

    def reset_parameters(self):
        ones_(self.weight)
        zeros_(self.bias)
        if self.apply_act:
            ones_(self.v)

    def forward(self, x):
        B, C, H, W = x.shape
        if self.v is not None:
            n = x * F.sigmoid(x * self.v.reshape([1, -1, 1, 1]))
            grouped = x.reshape([B, self.groups, -1])
            group_std = (
                grouped.var(axis=-1, unbiased=False, keepdim=True) + self.eps
            ).sqrt()
            x = (n.reshape([B, self.groups, -1]) / group_std).reshape([B, C, H, W])
        return x * self.weight.reshape([1, -1, 1, 1]) + self.bias.reshape([1, -1, 1, 1])
class GroupNormAct(nn.GroupNorm):
    """Group normalization followed by an optional activation.

    NOTE: ``num_channels`` comes before ``num_groups`` (flipped w.r.t.
    ``nn.GroupNorm``) so ``partial(GroupNormAct, num_groups=32)`` can be used as
    a ``norm_layer(channels)`` factory.

    Args:
        num_channels: number of input channels.
        num_groups: number of groups.
        eps: numerical epsilon.
        affine: learnable per-channel scale/shift. ``False`` is now honoured
            via paddle's ``weight_attr=False`` / ``bias_attr=False``.
        apply_act: apply ``act_layer()`` after normalization.
        act_layer: activation class; ``None`` disables the activation.
        drop_block: unused; kept for interface compatibility.
    """

    def __init__(
        self,
        num_channels,
        num_groups=32,
        eps=1e-5,
        affine=True,
        apply_act=True,
        act_layer=nn.ReLU,
        drop_block=None,
    ):
        super(GroupNormAct, self).__init__(
            num_groups,
            num_channels,
            epsilon=eps,
            weight_attr=None if affine else False,
            bias_attr=None if affine else False,
        )
        if affine:
            # The base layer already created weight/bias; make the init explicit.
            ones_(self.weight)
            zeros_(self.bias)
        if act_layer is not None and apply_act:
            self.act = act_layer()
        else:
            self.act = nn.Identity()

    def forward(self, x):
        x = F.group_norm(
            x,
            num_groups=self._num_groups,
            epsilon=self._epsilon,
            weight=self.weight,
            bias=self.bias,
        )
        return self.act(x)
class BatchNormAct2d(nn.BatchNorm2D):
    """BatchNorm2D followed by an optional activation, with PyTorch-style args.

    Args:
        num_features: number of channels.
        eps: numerical epsilon.
        momentum: PyTorch-convention momentum (weight of the *new* batch
            statistic). It is converted to paddle's convention (weight of the
            *old* running value) as ``1 - momentum``.
        affine: learnable scale/shift (``False`` maps to paddle's
            ``weight_attr=False`` / ``bias_attr=False``).
        track_running_stats: ``True`` means paddle's default behavior (batch
            statistics while training, running statistics in eval). ``False``
            always uses batch statistics. Previously ``True`` was passed as
            ``use_global_stats=True``, which froze the layer to its running
            statistics even during training.
        apply_act: apply ``act_layer()`` after normalization.
        act_layer: activation class; ``None`` disables the activation.
        drop_block: unused; kept for interface compatibility.
    """

    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
        apply_act=True,
        act_layer=nn.ReLU,
        drop_block=None,
    ):
        super(BatchNormAct2d, self).__init__(
            num_features,
            momentum=1.0 - momentum,
            epsilon=eps,
            weight_attr=None if affine else False,
            bias_attr=None if affine else False,
            use_global_stats=None if track_running_stats else False,
        )
        if act_layer is not None and apply_act:
            self.act = act_layer()
        else:
            self.act = nn.Identity()

    def _forward_python(self, x):
        return super(BatchNormAct2d, self).forward(x)

    def forward(self, x):
        x = self._forward_python(x)
        x = self.act(x)
        return x
def adapt_input_conv(in_chans, conv_weight):
    """Adapt a pretrained first-conv weight (O, I, kH, kW) to ``in_chans`` inputs.

    * ``in_chans == 1``: sum over the RGB input channels (for ``I > 3``
      space-to-depth stems, sum within each group of 3).
    * ``in_chans != 3``: tile the RGB weights up to ``in_chans`` channels and
      rescale by ``3 / in_chans``.
    * ``in_chans == 3``: returned unchanged.

    Computation happens in float32 and the result is cast back to the input
    dtype. Uses Paddle tensor APIs; the previous torch-style calls
    (``.float()``, ``.sum(dim=)``, ``.repeat``, ``.to``) fail on paddle tensors.

    Raises:
        NotImplementedError: if ``in_chans`` is not 1 or 3 and ``I != 3``.
    """
    conv_type = conv_weight.dtype
    # Some weights are stored in half precision; sum in float32.
    conv_weight = conv_weight.astype("float32")
    O, I, J, K = conv_weight.shape
    if in_chans == 1:
        if I > 3:
            assert conv_weight.shape[1] % 3 == 0
            # For models with space2depth stems
            conv_weight = conv_weight.reshape([O, I // 3, 3, J, K])
            conv_weight = conv_weight.sum(axis=2, keepdim=False)
        else:
            conv_weight = conv_weight.sum(axis=1, keepdim=True)
    elif in_chans != 3:
        if I != 3:
            raise NotImplementedError("Weight format not supported by conversion.")
        # NOTE this strategy should be better than random init, but there could be other combinations of
        # the original RGB input layer weights that'd work better for specific cases.
        num_repeats = int(math.ceil(in_chans / 3))
        conv_weight = paddle.tile(conv_weight, [1, num_repeats, 1, 1])[
            :, :in_chans, :, :
        ]
        conv_weight = conv_weight * (3 / float(in_chans))
    return conv_weight.astype(conv_type)
def named_apply(
    fn: Callable, module: nn.Layer, name="", depth_first=True, include_root=False
) -> nn.Layer:
    """Call ``fn(module=..., name=...)`` on every descendant of ``module``.

    Children get dotted names (``"parent.child"``). With ``depth_first`` each
    layer is visited after its children (post-order), otherwise before them
    (pre-order). ``module`` itself is only visited when ``include_root``.

    Returns:
        ``module`` (unchanged object, for chaining).
    """
    if include_root and not depth_first:
        fn(module=module, name=name)
    for child_name, child in module.named_children():
        qualified = f"{name}.{child_name}" if name else child_name
        named_apply(
            fn=fn,
            module=child,
            name=qualified,
            depth_first=depth_first,
            include_root=True,
        )
    if include_root and depth_first:
        fn(module=module, name=name)
    return module
def _cfg(url="", **kwargs):
    """Return a default pretrained-model config dict; ``kwargs`` override/extend it."""
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=(7, 7),
        crop_pct=0.875,
        interpolation="bilinear",
        mean=IMAGENET_INCEPTION_MEAN,
        std=IMAGENET_INCEPTION_STD,
        first_conv="stem.conv",
        classifier="head.fc",
    )
    cfg.update(kwargs)
    return cfg
def make_div(v, divisor=8):
    """Round ``v`` to the nearest multiple of ``divisor`` (minimum ``divisor``).

    If rounding lands below 90% of ``v``, one extra ``divisor`` is added.
    """
    rounded = max(divisor, (int(v + divisor / 2) // divisor) * divisor)
    return rounded + divisor if rounded < 0.9 * v else rounded
class PreActBottleneck(nn.Layer):
    """Pre-activation (v2) bottleneck block.

    Follows the implementation of "Identity Mappings in Deep Residual Networks":
    https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua

    Except it puts the stride on 3x3 conv when available.

    ``norm_layer`` is applied *before* each conv. The projection shortcut, when
    present, is fed the pre-activated input. ``act_layer`` is accepted for
    interface parity but unused here; any activation comes from ``norm_layer``.
    ``is_export`` is forwarded to every conv, now including the shortcut
    projection (as ``Bottleneck`` already did).
    """

    def __init__(
        self,
        in_chs,
        out_chs=None,
        bottle_ratio=0.25,
        stride=1,
        dilation=1,
        first_dilation=None,
        groups=1,
        act_layer=None,
        conv_layer=None,
        norm_layer=None,
        proj_layer=None,
        drop_path_rate=0.0,
        is_export=False,
    ):
        super().__init__()
        first_dilation = first_dilation or dilation
        conv_layer = conv_layer or StdConv2d
        norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
        out_chs = out_chs or in_chs
        mid_chs = make_div(out_chs * bottle_ratio)
        if proj_layer is not None:
            self.downsample = proj_layer(
                in_chs,
                out_chs,
                stride=stride,
                dilation=dilation,
                first_dilation=first_dilation,
                preact=True,
                conv_layer=conv_layer,
                norm_layer=norm_layer,
                is_export=is_export,
            )
        else:
            self.downsample = None
        self.norm1 = norm_layer(in_chs)
        self.conv1 = conv_layer(in_chs, mid_chs, 1, is_export=is_export)
        self.norm2 = norm_layer(mid_chs)
        self.conv2 = conv_layer(
            mid_chs,
            mid_chs,
            3,
            stride=stride,
            dilation=first_dilation,
            groups=groups,
            is_export=is_export,
        )
        self.norm3 = norm_layer(mid_chs)
        self.conv3 = conv_layer(mid_chs, out_chs, 1, is_export=is_export)
        self.drop_path = (
            DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
        )

    def zero_init_last(self):
        zeros_(self.conv3.weight)

    def forward(self, x):
        x_preact = self.norm1(x)
        # shortcut branch
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x_preact)
        # residual branch
        x = self.conv1(x_preact)
        x = self.conv2(self.norm2(x))
        x = self.conv3(self.norm3(x))
        x = self.drop_path(x)
        return x + shortcut
class Bottleneck(nn.Layer):
    """Non Pre-activation bottleneck block, equiv to V1.5/V1b Bottleneck. Used for ViT.

    conv1x1 -> norm -> conv3x3 (carries the stride) -> norm -> conv1x1 -> norm
    (no activation) -> drop-path, then ``act_layer`` is applied to the sum with
    the (optionally projected) shortcut.
    """

    def __init__(
        self,
        in_chs,
        out_chs=None,
        bottle_ratio=0.25,
        stride=1,
        dilation=1,
        first_dilation=None,
        groups=1,
        act_layer=None,
        conv_layer=None,
        norm_layer=None,
        proj_layer=None,
        drop_path_rate=0.0,
        is_export=False,
    ):
        super().__init__()
        first_dilation = first_dilation or dilation
        act_layer = act_layer or nn.ReLU
        conv_layer = conv_layer or StdConv2d
        norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
        out_chs = out_chs or in_chs
        mid_chs = make_div(out_chs * bottle_ratio)

        self.downsample = (
            proj_layer(
                in_chs,
                out_chs,
                stride=stride,
                dilation=dilation,
                preact=False,
                conv_layer=conv_layer,
                norm_layer=norm_layer,
                is_export=is_export,
            )
            if proj_layer is not None
            else None
        )

        self.conv1 = conv_layer(in_chs, mid_chs, 1, is_export=is_export)
        self.norm1 = norm_layer(mid_chs)
        self.conv2 = conv_layer(
            mid_chs,
            mid_chs,
            3,
            stride=stride,
            dilation=first_dilation,
            groups=groups,
            is_export=is_export,
        )
        self.norm2 = norm_layer(mid_chs)
        self.conv3 = conv_layer(mid_chs, out_chs, 1, is_export=is_export)
        self.norm3 = norm_layer(out_chs, apply_act=False)
        if drop_path_rate > 0:
            self.drop_path = DropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()
        self.act3 = act_layer()

    def zero_init_last(self):
        zeros_(self.norm3.weight)

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.norm1(self.conv1(x))
        out = self.norm2(self.conv2(out))
        out = self.drop_path(self.norm3(self.conv3(out)))
        return self.act3(out + shortcut)
class DownsampleConv(nn.Layer):
    """1x1 (optionally strided) convolution projection for a residual shortcut.

    With ``preact`` the conv output is returned directly; otherwise it is
    followed by ``norm_layer(out_chs, apply_act=False)``. ``dilation`` and
    ``first_dilation`` are accepted for interface parity and unused.
    """

    def __init__(
        self,
        in_chs,
        out_chs,
        stride=1,
        dilation=1,
        first_dilation=None,
        preact=True,
        conv_layer=None,
        norm_layer=None,
        is_export=False,
    ):
        super().__init__()
        self.conv = conv_layer(in_chs, out_chs, 1, stride=stride, is_export=is_export)
        if preact:
            self.norm = nn.Identity()
        else:
            self.norm = norm_layer(out_chs, apply_act=False)

    def forward(self, x):
        projected = self.conv(x)
        return self.norm(projected)
class DownsampleAvg(nn.Layer):
    """Average-pool + 1x1 conv projection for a residual shortcut ('D' variants).

    The spatial reduction is done by a 2x2 average pool (stride ``stride`` when
    ``dilation == 1``, otherwise a stride-1 SAME-padded pool), followed by a
    stride-1 1x1 conv and, unless ``preact``, ``norm_layer(out_chs, apply_act=False)``.
    ``first_dilation`` is accepted for interface parity and unused.
    """

    def __init__(
        self,
        in_chs,
        out_chs,
        stride=1,
        dilation=1,
        first_dilation=None,
        preact=True,
        conv_layer=None,
        norm_layer=None,
        is_export=False,
    ):
        super(DownsampleAvg, self).__init__()
        avg_stride = stride if dilation == 1 else 1
        if stride > 1 or dilation > 1:
            if avg_stride == 1 and dilation > 1:
                # AvgPool2dSame does not accept paddle's ``exclusive`` keyword
                # (passing it raised TypeError). It pads explicitly and pools
                # with no implicit padding at stride 1, so the flag has no
                # effect on the result here.
                self.pool = AvgPool2dSame(2, avg_stride, ceil_mode=True)
            else:
                self.pool = nn.AvgPool2D(
                    2, avg_stride, ceil_mode=True, exclusive=False
                )
        else:
            self.pool = nn.Identity()
        self.conv = conv_layer(in_chs, out_chs, 1, stride=1, is_export=is_export)
        self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False)

    def forward(self, x):
        return self.norm(self.conv(self.pool(x)))
class ResNetStage(nn.Layer):
    """One ResNet stage: ``depth`` residual blocks run in sequence.

    Only the first block gets ``stride``, the projection shortcut
    (``DownsampleAvg`` if ``avg_down`` else ``DownsampleConv``) and the stage
    entry dilation (1 when ``dilation`` is 1 or 2, else 2). Later blocks use
    stride 1, no projection and ``dilation`` as their first dilation.
    ``block_dpr`` optionally gives a per-block drop-path rate.
    """

    def __init__(
        self,
        in_chs,
        out_chs,
        stride,
        dilation,
        depth,
        bottle_ratio=0.25,
        groups=1,
        avg_down=False,
        block_dpr=None,
        block_fn=PreActBottleneck,
        is_export=False,
        act_layer=None,
        conv_layer=None,
        norm_layer=None,
        **block_kwargs,
    ):
        super().__init__()
        entry_dilation = 1 if dilation in (1, 2) else 2
        shared_layers = dict(
            act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer
        )
        shortcut_fn = DownsampleAvg if avg_down else DownsampleConv
        self.blocks = nn.Sequential()
        for idx in range(depth):
            is_first = idx == 0
            block = block_fn(
                in_chs if is_first else out_chs,
                out_chs,
                stride=stride if is_first else 1,
                dilation=dilation,
                bottle_ratio=bottle_ratio,
                groups=groups,
                first_dilation=entry_dilation if is_first else dilation,
                proj_layer=shortcut_fn if is_first else None,
                drop_path_rate=block_dpr[idx] if block_dpr else 0.0,
                is_export=is_export,
                **shared_layers,
                **block_kwargs,
            )
            self.blocks.add_sublayer(str(idx), block)

    def forward(self, x):
        return self.blocks(x)
def is_stem_deep(stem_type):
    """True when ``stem_type`` names a 3-conv stem ("deep*" or "tiered")."""
    return "deep" in stem_type or "tiered" in stem_type
def create_resnetv2_stem(
    in_chs,
    out_chs=64,
    stem_type="",
    preact=True,
    conv_layer=StdConv2d,
    norm_layer=partial(GroupNormAct, num_groups=32),
    is_export=False,
):
    """Build the ResNetV2 stem as an ``nn.Sequential``.

    Deep/tiered stems use three 3x3 convs (first one stride 2); the others use a
    single 7x7 stride-2 conv. A norm follows the last conv only when ``preact``
    is False. The stem always ends with a stride-2 3x3 max pool whose padding
    depends on ``stem_type``:
      * contains "fixed": constant pad of 1, then an unpadded pool (BiT style);
      * contains "same": input-size based 'SAME' pool via ``create_pool2d``;
      * otherwise: symmetric padding of 1.
    """
    assert stem_type in (
        "",
        "fixed",
        "same",
        "deep",
        "deep_fixed",
        "deep_same",
        "tiered",
    )
    # (name, layer) pairs in registration order.
    layers = []
    if is_stem_deep(stem_type):
        if "tiered" in stem_type:
            mid_a, mid_b = 3 * out_chs // 8, out_chs // 2  # 'T' resnets
        else:
            mid_a, mid_b = out_chs // 2, out_chs // 2  # 'D' ResNets
        layers.append(
            ("conv1", conv_layer(in_chs, mid_a, kernel_size=3, stride=2, is_export=is_export))
        )
        layers.append(("norm1", norm_layer(mid_a)))
        layers.append(
            ("conv2", conv_layer(mid_a, mid_b, kernel_size=3, stride=1, is_export=is_export))
        )
        layers.append(("norm2", norm_layer(mid_b)))
        layers.append(
            ("conv3", conv_layer(mid_b, out_chs, kernel_size=3, stride=1, is_export=is_export))
        )
        if not preact:
            layers.append(("norm3", norm_layer(out_chs)))
    else:
        layers.append(
            ("conv", conv_layer(in_chs, out_chs, kernel_size=7, stride=2, is_export=is_export))
        )
        if not preact:
            layers.append(("norm", norm_layer(out_chs)))

    if "fixed" in stem_type:
        layers.append(
            (
                "pad",
                paddle.nn.Pad2D(
                    1, mode="constant", value=0.0, data_format="NCHW", name=None
                ),
            )
        )
        layers.append(("pool", nn.MaxPool2D(kernel_size=3, stride=2, padding=0)))
    elif "same" in stem_type:
        layers.append(
            (
                "pool",
                create_pool2d(
                    "max", kernel_size=3, stride=2, padding="same", is_export=is_export
                ),
            )
        )
    else:
        layers.append(("pool", nn.MaxPool2D(kernel_size=3, stride=2, padding=1)))

    stem_seq = nn.Sequential()
    for layer_name, layer in layers:
        stem_seq.add_sublayer(layer_name, layer)
    return stem_seq
class ResNetV2(nn.Layer):
    """Implementation of Pre-activation (v2) ResNet model.

    The network is ``stem -> stages -> norm -> head``. ``forward_features``
    returns the output of ``norm`` (an identity when ``preact`` is False) and
    ``forward`` additionally applies the ``ClassifierHead``.

    NOTE(review): earlier docs described the input as ``[N, 1, H, W]`` and the
    extracted features as ``[N, 1, H//16, W//16]``. That depends on
    ``in_chans``, ``output_stride`` and the configured stages and is not
    enforced here -- confirm against the caller's config.

    Args:
        layers: number of blocks in each stage.
        channels: per-stage output channels before ``width_factor`` scaling.
        num_classes: classifier outputs passed to ``ClassifierHead``.
        in_chans: number of input image channels.
        global_pool: pooling type passed to ``ClassifierHead``.
        output_stride: once the accumulated stride reaches this value, later
            stages use dilation instead of striding.
        width_factor: multiplier applied to stem and stage channels.
        stem_chs: stem width before ``width_factor`` scaling.
        stem_type: stem variant (see ``create_resnetv2_stem``).
        avg_down: use ``DownsampleAvg`` instead of ``DownsampleConv`` shortcuts.
        preact: build ``PreActBottleneck`` blocks (True) or ``Bottleneck`` (False).
        act_layer, conv_layer, norm_layer: layer factories forwarded to the stages.
        drop_rate: dropout rate used in the classifier head.
        drop_path_rate: stochastic-depth rate, increased linearly from 0 across
            all blocks of all stages.
        zero_init_last: call each block's ``zero_init_last`` during init.
        is_export: forwarded to the stem and stages.
    """

    def __init__(
        self,
        layers,
        channels=(256, 512, 1024, 2048),
        num_classes=1000,
        in_chans=3,
        global_pool="avg",
        output_stride=32,
        width_factor=1,
        stem_chs=64,
        stem_type="",
        avg_down=False,
        preact=True,
        act_layer=nn.ReLU,
        conv_layer=StdConv2d,
        norm_layer=partial(GroupNormAct, num_groups=32),
        drop_rate=0.0,
        drop_path_rate=0.0,
        zero_init_last=False,
        is_export=False,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.is_export = is_export
        wf = width_factor
        # List of dicts describing the intermediate feature maps.
        self.feature_info = []
        stem_chs = make_div(stem_chs * wf)
        self.stem = create_resnetv2_stem(
            in_chans,
            stem_chs,
            stem_type,
            preact,
            conv_layer=conv_layer,
            norm_layer=norm_layer,
            is_export=is_export,
        )
        stem_feat = (
            ("stem.conv3" if is_stem_deep(stem_type) else "stem.conv")
            if preact
            else "stem.norm"
        )
        self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat))
        prev_chs = stem_chs
        # The stem downsamples by 4 (stride-2 conv followed by a stride-2 pool).
        curr_stride = 4
        dilation = 1
        # Linearly increasing drop-path rates, split into one list per stage.
        block_dprs = [
            x.tolist()
            for x in paddle.linspace(0, drop_path_rate, sum(layers)).split(layers)
        ]
        block_fn = PreActBottleneck if preact else Bottleneck
        self.stages = nn.Sequential()
        for stage_idx, (d, c, bdpr) in enumerate(zip(layers, channels, block_dprs)):
            out_chs = make_div(c * wf)
            stride = 1 if stage_idx == 0 else 2
            # Past the requested output stride, trade striding for dilation.
            if curr_stride >= output_stride:
                dilation *= stride
                stride = 1
            stage = ResNetStage(
                prev_chs,
                out_chs,
                stride=stride,
                dilation=dilation,
                depth=d,
                avg_down=avg_down,
                act_layer=act_layer,
                conv_layer=conv_layer,
                norm_layer=norm_layer,
                block_dpr=bdpr,
                block_fn=block_fn,
                is_export=is_export,
            )
            prev_chs = out_chs
            curr_stride *= stride
            self.feature_info += [
                dict(
                    num_chs=prev_chs,
                    reduction=curr_stride,
                    module=f"stages.{stage_idx}",
                )
            ]
            self.stages.add_sublayer(str(stage_idx), stage)
        self.num_features = prev_chs
        # Pre-activation nets need a final norm after the last residual add.
        self.norm = norm_layer(self.num_features) if preact else nn.Identity()
        self.head = ClassifierHead(
            self.num_features,
            num_classes,
            pool_type=global_pool,
            drop_rate=self.drop_rate,
            use_conv=True,
        )
        self.init_weights(zero_init_last=zero_init_last)

    def init_weights(self, zero_init_last=True):
        """Apply ``_init_weights`` to every sub-layer (the root itself is skipped)."""
        named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)

    def load_pretrained(self, checkpoint_path, prefix="resnet/"):
        """Load weights from ``checkpoint_path`` via ``_load_weights``."""
        _load_weights(self, checkpoint_path, prefix)

    def get_classifier(self):
        """Return the final classifier layer of the head."""
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool="avg"):
        """Replace the head with a fresh ``ClassifierHead`` for ``num_classes``."""
        self.num_classes = num_classes
        self.head = ClassifierHead(
            self.num_features,
            num_classes,
            pool_type=global_pool,
            drop_rate=self.drop_rate,
            use_conv=True,
        )

    def forward_features(self, x):
        """Run stem, stages and the final norm (no classifier head)."""
        x = self.stem(x)
        x = self.stages(x)
        x = self.norm(x)
        return x

    def forward(self, x):
        """Run ``forward_features`` followed by the classifier head."""
        x = self.forward_features(x)
        x = self.head(x)
        return x
def _init_weights(module: nn.Layer, name: str = "", zero_init_last=True):
    """Initialize one layer; intended to be applied via ``named_apply``.

    * ``nn.Linear`` layers and a ``nn.Conv2D`` whose name contains "head.fc":
      ``normal_`` weights, zero bias.
    * other ``nn.Conv2D`` layers: ``kaiming_normal_`` weights, zero bias.
    * ``nn.BatchNorm2D`` / ``nn.LayerNorm`` / ``nn.GroupNorm``: weight 1, bias 0.
    * otherwise, when ``zero_init_last``, call the layer's ``zero_init_last()``
      if it has one.

    Missing (``None``) weights/biases, e.g. layers built with
    ``bias_attr=False``, are skipped in every branch instead of crashing.
    NOTE(review): module-level ``kaiming_normal_`` is currently a
    KaimingUniform initializer despite its name.
    """
    if isinstance(module, nn.Linear) or (
        "head.fc" in name and isinstance(module, nn.Conv2D)
    ):
        normal_(module.weight)
        if module.bias is not None:
            zeros_(module.bias)
    elif isinstance(module, nn.Conv2D):
        kaiming_normal_(module.weight)
        if module.bias is not None:
            zeros_(module.bias)
    elif isinstance(module, (nn.BatchNorm2D, nn.LayerNorm, nn.GroupNorm)):
        if module.weight is not None:
            ones_(module.weight)
        if module.bias is not None:
            zeros_(module.bias)
    elif zero_init_last and hasattr(module, "zero_init_last"):
        module.zero_init_last()
@paddle.no_grad()
def _load_weights(model: nn.Layer, checkpoint_path: str, prefix: str = "resnet/"):
import numpy as np
def t2p(conv_weights):
"""Possibly convert HWIO to OIHW."""
if conv_weights.ndim == 4:
conv_weights = conv_weights.transpose([3, 2, 0, 1])
return paddle.to_tensor(conv_weights)
weights = np.load(checkpoint_path)
stem_conv_w = adapt_input_conv(
model.stem.conv.weight.shape[1],
t2p(weights[f"{prefix}root_block/standardized_conv2d/kernel"]),
)
model.stem.conv.weight.copy_(stem_conv_w)
model.norm.weight.copy_(t2p(weights[f"{prefix}group_norm/gamma"]))
model.norm.bias.copy_(t2p(weights[f"{prefix}group_norm/beta"]))
if (
isinstance(getattr(model.head, "fc", None), nn.Conv2D)
and model.head.fc.weight.shape[0]
== weights[f"{prefix}head/conv2d/kernel"].shape[-1]
):
model.head.fc.weight.copy_(t2p(weights[f"{prefix}head/conv2d/kernel"]))
model.head.fc.bias.copy_(t2p(weights[f"{prefix}head/conv2d/bias"]))
for i, (sname, stage) in enumerate(model.stages.named_children()):
for j, (bname, block) in enumerate(stage.blocks.named_children()):
cname = "standardized_conv2d"
block_prefix = f"{prefix}block{i + 1}/unit{j + 1:02d}/"
block.conv1.weight.copy_(t2p(weights[f"{block_prefix}a/{cname}/kernel"]))
block.conv2.weight.copy_(t2p(weights[f"{block_prefix}b/{cname}/kernel"]))
block.conv3.weight.copy_(t2p(weights[f"{block_prefix}c/{cname}/kernel"]))
block.norm1.weight.copy_(t2p(weights[f"{block_prefix}a/group_norm/gamma"]))
block.norm2.weight.copy_(t2p(weights[f"{block_prefix}b/group_norm/gamma"]))
block.norm3.weight.copy_(t2p(weights[f"{block_prefix}c/group_norm/gamma"]))
block.norm1.bias.copy_(t2p(weights[f"{block_prefix}a/group_norm/beta"]))
block.norm2.bias.copy_(t2p(weights[f"{block_prefix}b/group_norm/beta"]))
block.norm3.bias.copy_(t2p(weights[f"{block_prefix}c/group_norm/beta"]))
if block.downsample is not None:
w = weights[f"{block_prefix}a/proj/{cname}/kernel"]
block.downsample.conv.weight.copy_(t2p(w))