""" PP-HGNet (V1 & V2)

Reference:
https://github.com/PaddlePaddle/PaddleClas/blob/develop/docs/zh_CN/models/ImageNet1k/PP-HGNetV2.md
The PaddlePaddle implementation of PP-HGNet (https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/docs/en/models/PP-HGNet_en.md)
PP-HGNet: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet.py
PP-HGNetV2: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet_v2.py
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import SelectAdaptivePool2d, DropPath, create_conv2d
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs

__all__ = ['HighPerfGpuNet']


class LearnableAffineBlock(nn.Module):
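    """Learnable scalar affine transform, y = scale * x + bias.

    Applied after the activation in PP-HGNetV2 (the 'LAB' trick), adding a
    small amount of capacity to the lighter variants at negligible cost.
    """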
    def __init__(
            self,
            scale_value=1.0,
            bias_value=0.0
    ):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor([scale_value]))
        self.bias = nn.Parameter(torch.tensor([bias_value]))

    def forward(self, x):
        return self.scale * x + self.bias


class ConvBNAct(nn.Module):
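    """Conv2d + BatchNorm2d + optional ReLU and LearnableAffineBlock.

    padding='' defers to timm's create_conv2d, which resolves symmetric
    'same'-style padding from the kernel size and stride (this resolves to 0
    for the even 2x2 kernels, hence the explicit F.pad calls in StemV2).
    """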
    def __init__(
            self,
            in_chs,
            out_chs,
            kernel_size,
            stride=1,
            groups=1,
            padding='',
            use_act=True,
            use_lab=False
    ):
        super().__init__()
        self.use_act = use_act
        self.use_lab = use_lab
        self.conv = create_conv2d(
            in_chs,
            out_chs,
            kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
        )
        self.bn = nn.BatchNorm2d(out_chs)
        if self.use_act:
            self.act = nn.ReLU()
        else:
            self.act = nn.Identity()
        if self.use_act and self.use_lab:
            self.lab = LearnableAffineBlock()
        else:
            self.lab = nn.Identity()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)
        x = self.lab(x)
        return x


class LightConvBNAct(nn.Module):
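    """Factorized conv: a 1x1 pointwise conv (no activation) followed by a
    depthwise kernel_size conv (groups == out_chs) with activation.
    """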
    def __init__(
            self,
            in_chs,
            out_chs,
            kernel_size,
            groups=1,
            use_lab=False
    ):
        super().__init__()
        self.conv1 = ConvBNAct(
            in_chs,
            out_chs,
            kernel_size=1,
            use_act=False,
            use_lab=use_lab,
        )
        self.conv2 = ConvBNAct(
            out_chs,
            out_chs,
            kernel_size=kernel_size,
            groups=out_chs,
            use_act=True,
            use_lab=use_lab,
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class EseModule(nn.Module):
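    """Effective Squeeze-Excitation (as in VoVNetV2): global average pool ->
    1x1 conv -> sigmoid, used to gate the input channels.
    """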
    def __init__(self, chs):
        super().__init__()
        self.conv = nn.Conv2d(
            chs,
            chs,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        identity = x
        x = x.mean((2, 3), keepdim=True)
        x = self.conv(x)
        x = self.sigmoid(x)
        return torch.mul(identity, x)


class StemV1(nn.Module):
    # for PP-HGNet
    def __init__(self, stem_chs):
        super().__init__()
        self.stem = nn.Sequential(*[
            ConvBNAct(
                stem_chs[i],
                stem_chs[i + 1],
                kernel_size=3,
                stride=2 if i == 0 else 1,
            ) for i in range(len(stem_chs) - 1)
        ])
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x = self.stem(x)
        x = self.pool(x)
        return x

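# Shape sketch for StemV1 (hgnet_tiny numbers; the 224x224 input is illustrative):
#   stem_chs = [3, 48, 48, 96]: 3x224x224 -> conv s2 -> 48x112x112
#   -> two s1 convs -> 96x112x112 -> maxpool s2 -> 96x56x56 (stride 4 overall).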

class StemV2(nn.Module):
    # for PP-HGNetv2
    def __init__(self, in_chs, mid_chs, out_chs, use_lab=False):
        super().__init__()
        self.stem1 = ConvBNAct(
            in_chs,
            mid_chs,
            kernel_size=3,
            stride=2,
            use_lab=use_lab,
        )
        self.stem2a = ConvBNAct(
            mid_chs,
            mid_chs // 2,
            kernel_size=2,
            stride=1,
            use_lab=use_lab,
        )
        self.stem2b = ConvBNAct(
            mid_chs // 2,
            mid_chs,
            kernel_size=2,
            stride=1,
            use_lab=use_lab,
        )
        self.stem3 = ConvBNAct(
            mid_chs * 2,
            mid_chs,
            kernel_size=3,
            stride=2,
            use_lab=use_lab,
        )
        self.stem4 = ConvBNAct(
            mid_chs,
            out_chs,
            kernel_size=1,
            stride=1,
            use_lab=use_lab,
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=1, ceil_mode=True)

    def forward(self, x):
        x = self.stem1(x)
        # Right/bottom-only padding so the 2x2 convs and the stride-1 pool
        # keep the two parallel branches at the same spatial size.
        x = F.pad(x, (0, 1, 0, 1))
        x2 = self.stem2a(x)
        x2 = F.pad(x2, (0, 1, 0, 1))
        x2 = self.stem2b(x2)
        x1 = self.pool(x)
        x = torch.cat([x1, x2], dim=1)
        x = self.stem3(x)
        x = self.stem4(x)
        return x

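# Shape sketch for StemV2 (hgnetv2_b0 numbers; the 224x224 input is illustrative):
#   3x224x224 -> stem1 s2 -> 16x112x112 -> pad to 113 -> two parallel branches
#   (maxpool k2 s1 / stem2a+stem2b 2x2 convs), each back at 16x112x112
#   -> concat 32ch -> stem3 s2 -> 16x56x56 -> stem4 1x1 -> 16x56x56 (stride 4).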

class HighPerfGpuBlock(nn.Module):
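    """HGNet block: a stack of (light) convs whose outputs are all retained
    and concatenated with the block input (one-shot aggregation), then fused
    down to out_chs with 1x1 conv(s) plus ESE attention ('ese', V1) or a
    squeeze/excitation pair of 1x1 convs ('se', V2).
    """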
    def __init__(
            self,
            in_chs,
            mid_chs,
            out_chs,
            layer_num,
            kernel_size=3,
            residual=False,
            light_block=False,
            use_lab=False,
            agg='ese',
            drop_path=0.,
    ):
        super().__init__()
        self.residual = residual

        self.layers = nn.ModuleList()
        for i in range(layer_num):
            if light_block:
                self.layers.append(
                    LightConvBNAct(
                        in_chs if i == 0 else mid_chs,
                        mid_chs,
                        kernel_size=kernel_size,
                        use_lab=use_lab,
                    )
                )
            else:
                self.layers.append(
                    ConvBNAct(
                        in_chs if i == 0 else mid_chs,
                        mid_chs,
                        kernel_size=kernel_size,
                        stride=1,
                        use_lab=use_lab,
                    )
                )

        # feature aggregation
        total_chs = in_chs + layer_num * mid_chs
        if agg == 'se':
            aggregation_squeeze_conv = ConvBNAct(
                total_chs,
                out_chs // 2,
                kernel_size=1,
                stride=1,
                use_lab=use_lab,
            )
            aggregation_excitation_conv = ConvBNAct(
                out_chs // 2,
                out_chs,
                kernel_size=1,
                stride=1,
                use_lab=use_lab,
            )
            self.aggregation = nn.Sequential(
                aggregation_squeeze_conv,
                aggregation_excitation_conv,
            )
        else:
            aggregation_conv = ConvBNAct(
                total_chs,
                out_chs,
                kernel_size=1,
                stride=1,
                use_lab=use_lab,
            )
            att = EseModule(out_chs)
            self.aggregation = nn.Sequential(
                aggregation_conv,
                att,
            )

        self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()

    def forward(self, x):
        identity = x
        output = [x]
        for layer in self.layers:
            x = layer(x)
            output.append(x)
        x = torch.cat(output, dim=1)
        x = self.aggregation(x)
        if self.residual:
            x = self.drop_path(x) + identity
        return x

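# Channel bookkeeping example (hgnetv2_b0 stage1 numbers):
#   in_chs=16, mid_chs=16, layer_num=3 -> concat width 16 + 3 * 16 = 64,
#   which the aggregation then maps to out_chs=64 (via 32 in the 'se'
#   squeeze/excitation pair of 1x1 convs).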

class HighPerfGpuStage(nn.Module):
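    """A stack of HighPerfGpuBlocks behind an optional depthwise stride-2
    downsample conv. The first block changes width and has no residual; the
    rest are residual with optional stochastic depth.
    """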
    def __init__(
            self,
            in_chs,
            mid_chs,
            out_chs,
            block_num,
            layer_num,
            downsample=True,
            stride=2,
            light_block=False,
            kernel_size=3,
            use_lab=False,
            agg='ese',
            drop_path=0.,
    ):
        super().__init__()
        self.grad_checkpointing = False
        if downsample:
            self.downsample = ConvBNAct(
                in_chs,
                in_chs,
                kernel_size=3,
                stride=stride,
                groups=in_chs,
                use_act=False,
                use_lab=use_lab,
            )
        else:
            self.downsample = nn.Identity()

        blocks_list = []
        for i in range(block_num):
            blocks_list.append(
                HighPerfGpuBlock(
                    in_chs if i == 0 else out_chs,
                    mid_chs,
                    out_chs,
                    layer_num,
                    residual=i != 0,
                    kernel_size=kernel_size,
                    light_block=light_block,
                    use_lab=use_lab,
                    agg=agg,
                    drop_path=drop_path[i] if isinstance(drop_path, (list, tuple)) else drop_path,
                )
            )
        self.blocks = nn.Sequential(*blocks_list)

    def forward(self, x):
        x = self.downsample(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        return x


class ClassifierHead(nn.Module):
    def __init__(
            self,
            num_features,
            num_classes,
            pool_type='avg',
            drop_rate=0.,
            use_last_conv=True,
            class_expand=2048,
            use_lab=False
    ):
        super().__init__()
        self.global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=False, input_fmt='NCHW')
        if use_last_conv:
            last_conv = nn.Conv2d(
                num_features,
                class_expand,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            )
            act = nn.ReLU()
            if use_lab:
                lab = LearnableAffineBlock()
                self.last_conv = nn.Sequential(last_conv, act, lab)
            else:
                self.last_conv = nn.Sequential(last_conv, act)
        else:
            self.last_conv = nn.Identity()

        if drop_rate > 0:
            self.dropout = nn.Dropout(drop_rate)
        else:
            self.dropout = nn.Identity()

        self.flatten = nn.Flatten()
        self.fc = nn.Linear(class_expand if use_last_conv else num_features, num_classes)

    def forward(self, x, pre_logits: bool = False):
        x = self.global_pool(x)
        x = self.last_conv(x)
        x = self.dropout(x)
        x = self.flatten(x)
        if pre_logits:
            return x
        x = self.fc(x)
        return x

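# Pre-logits sketch (dims assume the default use_last_conv / class_expand=2048):
#   head = ClassifierHead(2048, num_classes=1000)
#   feats = head(torch.randn(2, 2048, 7, 7), pre_logits=True)  # -> (2, 2048)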

class HighPerfGpuNet(nn.Module):

    def __init__(
            self,
            cfg,
            in_chans=3,
            num_classes=1000,
            global_pool='avg',
            use_last_conv=True,
            class_expand=2048,
            drop_rate=0.,
            drop_path_rate=0.,
            use_lab=False,
            **kwargs,
    ):
        super().__init__()
        stem_type = cfg["stem_type"]
        stem_chs = cfg["stem_chs"]
        stages_cfg = [cfg["stage1"], cfg["stage2"], cfg["stage3"], cfg["stage4"]]
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.use_last_conv = use_last_conv
        self.class_expand = class_expand
        self.use_lab = use_lab

        assert stem_type in ['v1', 'v2']
        if stem_type == 'v2':
            self.stem = StemV2(
                in_chs=in_chans,
                mid_chs=stem_chs[0],
                out_chs=stem_chs[1],
                use_lab=use_lab)
        else:
            self.stem = StemV1([in_chans] + stem_chs)

        current_stride = 4

        stages = []
        self.feature_info = []
        block_depths = [c[3] for c in stages_cfg]
        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(block_depths)).split(block_depths)]
        for i, stage_config in enumerate(stages_cfg):
            in_chs, mid_chs, out_chs, block_num, downsample, light_block, kernel_size, layer_num = stage_config
            stages += [HighPerfGpuStage(
                in_chs=in_chs,
                mid_chs=mid_chs,
                out_chs=out_chs,
                block_num=block_num,
                layer_num=layer_num,
                downsample=downsample,
                light_block=light_block,
                kernel_size=kernel_size,
                use_lab=use_lab,
                agg='ese' if stem_type == 'v1' else 'se',
                drop_path=dpr[i],
            )]
            self.num_features = out_chs
            if downsample:
                current_stride *= 2
            self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')]
        self.stages = nn.Sequential(*stages)

        if num_classes > 0:
            self.head = ClassifierHead(
                self.num_features,
                num_classes=num_classes,
                pool_type=global_pool,
                drop_rate=drop_rate,
                use_last_conv=use_last_conv,
                class_expand=class_expand,
                use_lab=use_lab
            )
        else:
            if global_pool:
                self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
            else:
                self.head = nn.Identity()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.zeros_(m.bias)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        return dict(
            stem=r'^stem',
            blocks=r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.blocks\.(\d+)',
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        for s in self.stages:
            s.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        self.num_classes = num_classes
        if num_classes > 0:
            self.head = ClassifierHead(
                self.num_features,
                num_classes=num_classes,
                pool_type=global_pool,
                drop_rate=self.drop_rate,
                use_last_conv=self.use_last_conv,
                class_expand=self.class_expand,
                use_lab=self.use_lab)
        else:
            if global_pool:
                self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
            else:
                self.head = nn.Identity()

    def forward_features(self, x):
        x = self.stem(x)
        return self.stages(x)

    def forward_head(self, x, pre_logits: bool = False):
        return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x

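# Forward split sketch: forward_features returns the final stage map, and
# forward_head pools/classifies. Spatial sizes assume a 224x224 input:
#   model = HighPerfGpuNet(model_cfgs['hgnetv2_b0'])  # cfg dict defined below
#   fmap = model.forward_features(torch.randn(1, 3, 224, 224))  # (1, 1024, 7, 7)
#   logits = model.forward_head(fmap)                           # (1, 1000)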

model_cfgs = dict(
    # PP-HGNet
    hgnet_tiny={
        "stem_type": 'v1',
        "stem_chs": [48, 48, 96],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [96, 96, 224, 1, False, False, 3, 5],
        "stage2": [224, 128, 448, 1, True, False, 3, 5],
        "stage3": [448, 160, 512, 2, True, False, 3, 5],
        "stage4": [512, 192, 768, 1, True, False, 3, 5],
    },
    hgnet_small={
        "stem_type": 'v1',
        "stem_chs": [64, 64, 128],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [128, 128, 256, 1, False, False, 3, 6],
        "stage2": [256, 160, 512, 1, True, False, 3, 6],
        "stage3": [512, 192, 768, 2, True, False, 3, 6],
        "stage4": [768, 224, 1024, 1, True, False, 3, 6],
    },
    hgnet_base={
        "stem_type": 'v1',
        "stem_chs": [96, 96, 160],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [160, 192, 320, 1, False, False, 3, 7],
        "stage2": [320, 224, 640, 2, True, False, 3, 7],
        "stage3": [640, 256, 960, 3, True, False, 3, 7],
        "stage4": [960, 288, 1280, 2, True, False, 3, 7],
    },
    # PP-HGNetv2
    hgnetv2_b0={
        "stem_type": 'v2',
        "stem_chs": [16, 16],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [16, 16, 64, 1, False, False, 3, 3],
        "stage2": [64, 32, 256, 1, True, False, 3, 3],
        "stage3": [256, 64, 512, 2, True, True, 5, 3],
        "stage4": [512, 128, 1024, 1, True, True, 5, 3],
    },
    hgnetv2_b1={
        "stem_type": 'v2',
        "stem_chs": [24, 32],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [32, 32, 64, 1, False, False, 3, 3],
        "stage2": [64, 48, 256, 1, True, False, 3, 3],
        "stage3": [256, 96, 512, 2, True, True, 5, 3],
        "stage4": [512, 192, 1024, 1, True, True, 5, 3],
    },
    hgnetv2_b2={
        "stem_type": 'v2',
        "stem_chs": [24, 32],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [32, 32, 96, 1, False, False, 3, 4],
        "stage2": [96, 64, 384, 1, True, False, 3, 4],
        "stage3": [384, 128, 768, 3, True, True, 5, 4],
        "stage4": [768, 256, 1536, 1, True, True, 5, 4],
    },
    hgnetv2_b3={
        "stem_type": 'v2',
        "stem_chs": [24, 32],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [32, 32, 128, 1, False, False, 3, 5],
        "stage2": [128, 64, 512, 1, True, False, 3, 5],
        "stage3": [512, 128, 1024, 3, True, True, 5, 5],
        "stage4": [1024, 256, 2048, 1, True, True, 5, 5],
    },
    hgnetv2_b4={
        "stem_type": 'v2',
        "stem_chs": [32, 48],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [48, 48, 128, 1, False, False, 3, 6],
        "stage2": [128, 96, 512, 1, True, False, 3, 6],
        "stage3": [512, 192, 1024, 3, True, True, 5, 6],
        "stage4": [1024, 384, 2048, 1, True, True, 5, 6],
    },
    hgnetv2_b5={
        "stem_type": 'v2',
        "stem_chs": [32, 64],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [64, 64, 128, 1, False, False, 3, 6],
        "stage2": [128, 128, 512, 2, True, False, 3, 6],
        "stage3": [512, 256, 1024, 5, True, True, 5, 6],
        "stage4": [1024, 512, 2048, 2, True, True, 5, 6],
    },
    hgnetv2_b6={
        "stem_type": 'v2',
        "stem_chs": [48, 96],
        # in_chs, mid_chs, out_chs, blocks, downsample, light_block, kernel_size, layer_num
        "stage1": [96, 96, 192, 2, False, False, 3, 6],
        "stage2": [192, 192, 512, 3, True, False, 3, 6],
        "stage3": [512, 384, 1024, 6, True, True, 5, 6],
        "stage4": [1024, 768, 2048, 3, True, True, 5, 6],
    },
)


def _create_hgnet(variant, pretrained=False, **kwargs):
    out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
    return build_model_with_cfg(
        HighPerfGpuNet,
        variant,
        pretrained,
        model_cfg=model_cfgs[variant],
        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
        **kwargs,
    )

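# Feature-extraction sketch (standard timm features_only plumbing; assumes
# timm and torch are importable in client code):
#   m = timm.create_model('hgnet_small', features_only=True, out_indices=(1, 2, 3))
#   feats = m(torch.randn(1, 3, 224, 224))  # maps at strides 8, 16, 32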

def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.965, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'classifier': 'head.fc', 'first_conv': 'stem.stem1.conv',
        'test_crop_pct': 1.0, 'test_input_size': (3, 288, 288),
        **kwargs,
    }


default_cfgs = generate_default_cfgs({
    'hgnet_tiny.paddle_in1k': _cfg(
        first_conv='stem.stem.0.conv',
        hf_hub_id='timm/'),
    'hgnet_tiny.ssld_in1k': _cfg(
        first_conv='stem.stem.0.conv',
        hf_hub_id='timm/'),
    'hgnet_small.paddle_in1k': _cfg(
        first_conv='stem.stem.0.conv',
        hf_hub_id='timm/'),
    'hgnet_small.ssld_in1k': _cfg(
        first_conv='stem.stem.0.conv',
        hf_hub_id='timm/'),
    'hgnet_base.ssld_in1k': _cfg(
        first_conv='stem.stem.0.conv',
        hf_hub_id='timm/'),
    'hgnetv2_b0.ssld_stage2_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b0.ssld_stage1_in22k_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b1.ssld_stage2_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b1.ssld_stage1_in22k_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b2.ssld_stage2_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b2.ssld_stage1_in22k_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b3.ssld_stage2_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b3.ssld_stage1_in22k_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b4.ssld_stage2_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b4.ssld_stage1_in22k_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b5.ssld_stage2_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b5.ssld_stage1_in22k_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b6.ssld_stage2_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'hgnetv2_b6.ssld_stage1_in22k_in1k': _cfg(
        hf_hub_id='timm/'),
})


@register_model
def hgnet_tiny(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnet_tiny', pretrained=pretrained, **kwargs)


@register_model
def hgnet_small(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnet_small', pretrained=pretrained, **kwargs)


@register_model
def hgnet_base(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnet_base', pretrained=pretrained, **kwargs)


@register_model
def hgnetv2_b0(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnetv2_b0', pretrained=pretrained, use_lab=True, **kwargs)


@register_model
def hgnetv2_b1(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnetv2_b1', pretrained=pretrained, use_lab=True, **kwargs)


@register_model
def hgnetv2_b2(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnetv2_b2', pretrained=pretrained, use_lab=True, **kwargs)


@register_model
def hgnetv2_b3(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnetv2_b3', pretrained=pretrained, use_lab=True, **kwargs)


@register_model
def hgnetv2_b4(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnetv2_b4', pretrained=pretrained, **kwargs)


@register_model
def hgnetv2_b5(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnetv2_b5', pretrained=pretrained, **kwargs)


@register_model
def hgnetv2_b6(pretrained=False, **kwargs) -> HighPerfGpuNet:
    return _create_hgnet('hgnetv2_b6', pretrained=pretrained, **kwargs)
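

# Quick smoke test (a sketch; this module uses relative imports, so build
# models through the timm factory rather than running the file directly):
#   import timm, torch
#   model = timm.create_model('hgnetv2_b1', pretrained=False)
#   out = model(torch.randn(1, 3, 224, 224))
#   assert out.shape == (1, 1000)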