"""
MambaOut models for image classification.
Some implementations are modified from:
timm (https://github.com/rwightman/pytorch-image-models),
MetaFormer (https://github.com/sail-sg/metaformer),
InceptionNeXt (https://github.com/sail-sg/inceptionnext)
"""
import re
from collections import OrderedDict
from typing import Optional

import torch
from torch import nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead, get_act_layer
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import register_model


class Stem(nn.Module):
    r""" Code modified from InternImage:
        https://github.com/OpenGVLab/InternImage
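
    Two stride-2 3x3 convs give an overall stride of 4 (NCHW input, NHWC output).

    Example (illustrative):
        >>> stem = Stem(in_chs=3, out_chs=96)
        >>> stem(torch.randn(1, 3, 224, 224)).shape
        torch.Size([1, 56, 56, 96])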
    """

    def __init__(
            self,
            in_chs=3,
            out_chs=96,
            mid_norm: bool = True,
            act_layer=nn.GELU,
            norm_layer=LayerNorm,
    ):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_chs,
            out_chs // 2,
            kernel_size=3,
            stride=2,
            padding=1
        )
        self.norm1 = norm_layer(out_chs // 2) if mid_norm else None
        self.act = act_layer()
        self.conv2 = nn.Conv2d(
            out_chs // 2,
            out_chs,
            kernel_size=3,
            stride=2,
            padding=1
        )
        self.norm2 = norm_layer(out_chs)

    def forward(self, x):
        x = self.conv1(x)
        if self.norm1 is not None:
            x = x.permute(0, 2, 3, 1)
            x = self.norm1(x)
            x = x.permute(0, 3, 1, 2)
        x = self.act(x)
        x = self.conv2(x)
        x = x.permute(0, 2, 3, 1)
        x = self.norm2(x)
        return x


class DownsampleNormFirst(nn.Module):
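    """ Stride-2 3x3 conv downsample with normalization applied before the conv (NHWC in/out). """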

    def __init__(
            self,
            in_chs=96,
            out_chs=192,
            norm_layer=LayerNorm,
    ):
        super().__init__()
        self.norm = norm_layer(in_chs)
        self.conv = nn.Conv2d(
            in_chs,
            out_chs,
            kernel_size=3,
            stride=2,
            padding=1
        )

    def forward(self, x):
        x = self.norm(x)
        x = x.permute(0, 3, 1, 2)
        x = self.conv(x)
        x = x.permute(0, 2, 3, 1)
        return x


class Downsample(nn.Module):
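    """ Stride-2 3x3 conv downsample with normalization applied after the conv (NHWC in/out). """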

    def __init__(
            self,
            in_chs=96,
            out_chs=192,
            norm_layer=LayerNorm,
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            in_chs,
            out_chs,
            kernel_size=3,
            stride=2,
            padding=1
        )
        self.norm = norm_layer(out_chs)

    def forward(self, x):
        x = x.permute(0, 3, 1, 2)
        x = self.conv(x)
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        return x


class MlpHead(nn.Module):
    """ MLP classification head
    """

    def __init__(
            self,
            in_features,
            num_classes=1000,
            pool_type='avg',
            act_layer=nn.GELU,
            mlp_ratio=4,
            norm_layer=LayerNorm,
            drop_rate=0.,
            bias=True,
    ):
        super().__init__()
        if mlp_ratio is not None:
            hidden_size = int(mlp_ratio * in_features)
        else:
            hidden_size = None
        self.pool_type = pool_type
        self.in_features = in_features
        self.hidden_size = hidden_size or in_features

        self.norm = norm_layer(in_features)
        if hidden_size:
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(in_features, hidden_size)),
                ('act', act_layer()),
                ('norm', norm_layer(hidden_size))
            ]))
            self.num_features = hidden_size
        else:
            self.num_features = in_features
            self.pre_logits = nn.Identity()

        self.fc = nn.Linear(self.num_features, num_classes, bias=bias)
        self.head_dropout = nn.Dropout(drop_rate)

    def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False):
        if pool_type is not None:
            self.pool_type = pool_type
        if reset_other:
            self.norm = nn.Identity()
            self.pre_logits = nn.Identity()
            self.num_features = self.in_features
        self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward(self, x, pre_logits: bool = False):
        if self.pool_type == 'avg':
            x = x.mean((1, 2))
        x = self.norm(x)
        x = self.pre_logits(x)
        x = self.head_dropout(x)
        if pre_logits:
            return x
        x = self.fc(x)
        return x


class GatedConvBlock(nn.Module):
    r""" Our implementation of Gated CNN Block: https://arxiv.org/pdf/1612.08083
    Args:
        conv_ratio: control the number of channels to conduct depthwise convolution.
            Conduct convolution on partial channels can improve paraitcal efficiency.
            The idea of partial channels is from ShuffleNet V2 (https://arxiv.org/abs/1807.11164) and
            also used by InceptionNeXt (https://arxiv.org/abs/2303.16900) and FasterNet (https://arxiv.org/abs/2303.03667)
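
    Example (illustrative; NHWC input, output shape matches input):
        >>> block = GatedConvBlock(dim=64)
        >>> block(torch.randn(1, 14, 14, 64)).shape
        torch.Size([1, 14, 14, 64])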
    """

    def __init__(
            self,
            dim,
            expansion_ratio=8 / 3,
            kernel_size=7,
            conv_ratio=1.0,
            ls_init_value=None,
            norm_layer=LayerNorm,
            act_layer=nn.GELU,
            drop_path=0.,
            **kwargs
    ):
        super().__init__()
        self.norm = norm_layer(dim)
        hidden = int(expansion_ratio * dim)
        self.fc1 = nn.Linear(dim, hidden * 2)
        self.act = act_layer()
        conv_channels = int(conv_ratio * dim)
        self.split_indices = (hidden, hidden - conv_channels, conv_channels)
        self.conv = nn.Conv2d(
            conv_channels,
            conv_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=conv_channels
        )
        self.fc2 = nn.Linear(hidden, dim)
        self.ls = LayerScale(dim, init_values=ls_init_value) if ls_init_value is not None else nn.Identity()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        shortcut = x  # [B, H, W, C]
        x = self.norm(x)
        x = self.fc1(x)
        g, i, c = torch.split(x, self.split_indices, dim=-1)
        c = c.permute(0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W]
        c = self.conv(c)
        c = c.permute(0, 2, 3, 1)  # [B, C, H, W] -> [B, H, W, C]
        x = self.fc2(self.act(g) * torch.cat((i, c), dim=-1))
        x = self.ls(x)
        x = self.drop_path(x)
        return x + shortcut


class MambaOutStage(nn.Module):
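    """ One MambaOut stage: an optional stride-2 downsample followed by a stack of Gated CNN blocks.

    Example (illustrative):
        >>> stage = MambaOutStage(96, 192, depth=2, downsample='conv')
        >>> stage(torch.randn(1, 28, 28, 96)).shape  # NHWC in/out
        torch.Size([1, 14, 14, 192])
    """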

    def __init__(
            self,
            dim,
            dim_out: Optional[int] = None,
            depth: int = 4,
            expansion_ratio=8 / 3,
            kernel_size=7,
            conv_ratio=1.0,
            downsample: str = '',
            ls_init_value: Optional[float] = None,
            norm_layer=LayerNorm,
            act_layer=nn.GELU,
            drop_path=0.,
    ):
        super().__init__()
        dim_out = dim_out or dim
        self.grad_checkpointing = False

        if downsample == 'conv':
            self.downsample = Downsample(dim, dim_out, norm_layer=norm_layer)
        elif downsample == 'conv_nf':
            self.downsample = DownsampleNormFirst(dim, dim_out, norm_layer=norm_layer)
        else:
            assert dim == dim_out
            self.downsample = nn.Identity()

        self.blocks = nn.Sequential(*[
            GatedConvBlock(
                dim=dim_out,
                expansion_ratio=expansion_ratio,
                kernel_size=kernel_size,
                conv_ratio=conv_ratio,
                ls_init_value=ls_init_value,
                norm_layer=norm_layer,
                act_layer=act_layer,
                drop_path=drop_path[j] if isinstance(drop_path, (list, tuple)) else drop_path,
            )
            for j in range(depth)
        ])

    def forward(self, x):
        x = self.downsample(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        return x


class MambaOut(nn.Module):
    r""" MetaFormer
        A PyTorch impl of : `MetaFormer Baselines for Vision`  -
          https://arxiv.org/abs/2210.13452

    Args:
        in_chans (int): Number of input image channels. Default: 3.
        num_classes (int): Number of classes for classification head. Default: 1000.
        depths (list or tuple): Number of blocks at each stage. Default: [3, 3, 9, 3].
        dims (int): Feature dimension at each stage. Default: [96, 192, 384, 576].
        downsample_layers: (list or tuple): Downsampling layers before each stage.
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        output_norm: norm before classifier head. Default: partial(nn.LayerNorm, eps=1e-6).
        head_fn: classification head. Default: nn.Linear.
        head_dropout (float): dropout for MLP classifier. Default: 0.
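
    Example (illustrative; assumes this module is registered with timm):
        >>> import timm
        >>> model = timm.create_model('mambaout_tiny', pretrained=False)
        >>> model(torch.randn(1, 3, 224, 224)).shape
        torch.Size([1, 1000])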
    """

    def __init__(
            self,
            in_chans=3,
            num_classes=1000,
            depths=(3, 3, 9, 3),
            dims=(96, 192, 384, 576),
            norm_layer=LayerNorm,
            act_layer=nn.GELU,
            conv_ratio=1.0,
            expansion_ratio=8/3,
            kernel_size=7,
            stem_mid_norm=True,
            ls_init_value=None,
            downsample='conv',
            drop_path_rate=0.,
            drop_rate=0.,
            head_fn='default',
            **kwargs,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.output_fmt = 'NHWC'
        if not isinstance(depths, (list, tuple)):
            depths = [depths]  # single-stage model
        if not isinstance(dims, (list, tuple)):
            dims = [dims]
        act_layer = get_act_layer(act_layer)

        num_stage = len(depths)
        self.num_stage = num_stage
        self.feature_info = []

        self.stem = Stem(
            in_chans,
            dims[0],
            mid_norm=stem_mid_norm,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )
        prev_dim = dims[0]
        dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
        curr_stride = 4
        self.stages = nn.Sequential()
        for i in range(num_stage):
            dim = dims[i]
            stride = 2 if curr_stride == 2 or i > 0 else 1
            curr_stride *= stride
            stage = MambaOutStage(
                dim=prev_dim,
                dim_out=dim,
                depth=depths[i],
                kernel_size=kernel_size,
                conv_ratio=conv_ratio,
                expansion_ratio=expansion_ratio,
                downsample=downsample if i > 0 else '',
                ls_init_value=ls_init_value,
                norm_layer=norm_layer,
                act_layer=act_layer,
                drop_path=dp_rates[i],
            )
            self.stages.append(stage)
            prev_dim = dim
            # NOTE: feature_info assumes stage 0 applies no downsample (stride 1) and all later stages stride 2
            self.feature_info += [dict(num_chs=prev_dim, reduction=curr_stride, module=f'stages.{i}')]

        if head_fn == 'default':
            # specific to this model, an unusual pool -> norm -> fc -> act -> norm -> fc combo
            self.head = MlpHead(
                prev_dim,
                num_classes,
                pool_type='avg',
                drop_rate=drop_rate,
                norm_layer=norm_layer,
            )
        else:
            # more typical pool -> norm -> fc -> act -> fc
            self.head = ClNormMlpClassifierHead(
                prev_dim,
                num_classes,
                hidden_size=int(prev_dim * 4),
                pool_type='avg',
                norm_layer=norm_layer,
                drop_rate=drop_rate,
            )
        self.num_features = prev_dim
        self.head_hidden_size = self.head.num_features

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        return dict(
            stem=r'^stem',
            blocks=r'^stages\.(\d+)' if coarse else [
                (r'^stages\.(\d+)\.downsample', (0,)),  # blocks
                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
            ]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        for s in self.stages:
            s.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        return self.head.fc

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
        self.num_classes = num_classes
        self.head.reset(num_classes, global_pool)

    def forward_features(self, x):
        x = self.stem(x)
        x = self.stages(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        x = self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x


def checkpoint_filter_fn(state_dict, model):
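    """ Remap original MambaOut checkpoint keys to this implementation's module layout.

    Illustrative remappings performed by the substitutions below:
        'downsample_layers.0.conv1.weight' -> 'stem.conv1.weight'
        'stages.1.0.fc1.weight'            -> 'stages.1.blocks.0.fc1.weight'
        'norm.weight'                      -> 'head.norm.weight'
        'head.fc1.weight'                  -> 'head.pre_logits.fc.weight'
    """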
    if 'model' in state_dict:
        state_dict = state_dict['model']

    out_dict = {}
    for k, v in state_dict.items():
        k = k.replace('downsample_layers.0.', 'stem.')
        k = re.sub(r'stages\.([0-9]+)\.([0-9]+)', r'stages.\1.blocks.\2', k)
        k = re.sub(r'downsample_layers\.([0-9]+)', r'stages.\1.downsample', k)
        # remap head names
        if k.startswith('norm.'):
            # the final norm moves into the head since it is applied after pooling
            k = k.replace('norm.', 'head.norm.')
        elif k.startswith('head.'):
            k = k.replace('head.fc1.', 'head.pre_logits.fc.')
            k = k.replace('head.norm.', 'head.pre_logits.norm.')
            k = k.replace('head.fc2.', 'head.fc.')
        out_dict[k] = v

    return out_dict


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 1.0, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head.fc',
        **kwargs
    }


default_cfgs = {
    'mambaout_femto': _cfg(
        url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_femto.pth'),
    'mambaout_kobe': _cfg(
        url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_kobe.pth'),
    'mambaout_tiny': _cfg(
        url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_tiny.pth'),
    'mambaout_small': _cfg(
        url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_small.pth'),
    'mambaout_base': _cfg(
        url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_base.pth'),
    'mambaout_small_rw': _cfg(),
    'mambaout_base_slim_rw': _cfg(),
    'mambaout_base_plus_rw': _cfg(),
    'test_mambaout': _cfg(input_size=(3, 160, 160), pool_size=(5, 5)),
}


def _create_mambaout(variant, pretrained=False, **kwargs):
    model = build_model_with_cfg(
        MambaOut, variant, pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
        **kwargs,
    )
    return model
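

# Feature extraction works through timm's standard features_only mechanism, enabled by the
# feature_cfg above; an illustrative sketch (not executed at import time):
#   model = timm.create_model('mambaout_tiny', features_only=True)
#   model.feature_info.reduction()  # expected: [4, 8, 16, 32]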


# a series of MambaOut models
@register_model
def mambaout_femto(pretrained=False, **kwargs):
    model_args = dict(depths=(3, 3, 9, 3), dims=(48, 96, 192, 288))
    return _create_mambaout('mambaout_femto', pretrained=pretrained, **dict(model_args, **kwargs))


# Kobe Memorial Version with 24 Gated CNN blocks
@register_model
def mambaout_kobe(pretrained=False, **kwargs):
    model_args = dict(depths=[3, 3, 15, 3], dims=[48, 96, 192, 288])
    return _create_mambaout('mambaout_kobe', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def mambaout_tiny(pretrained=False, **kwargs):
    model_args = dict(depths=[3, 3, 9, 3], dims=[96, 192, 384, 576])
    return _create_mambaout('mambaout_tiny', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def mambaout_small(pretrained=False, **kwargs):
    model_args = dict(depths=[3, 4, 27, 3], dims=[96, 192, 384, 576])
    return _create_mambaout('mambaout_small', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def mambaout_base(pretrained=False, **kwargs):
    model_args = dict(depths=[3, 4, 27, 3], dims=[128, 256, 512, 768])
    return _create_mambaout('mambaout_base', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def mambaout_small_rw(pretrained=False, **kwargs):
    model_args = dict(
        depths=[3, 4, 27, 3],
        dims=[96, 192, 384, 576],
        stem_mid_norm=False,
        downsample='conv_nf',
        ls_init_value=1e-6,
        head_fn='norm_mlp',
    )
    return _create_mambaout('mambaout_small_rw', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def mambaout_base_slim_rw(pretrained=False, **kwargs):
    model_args = dict(
        depths=(3, 4, 27, 3),
        dims=(128, 256, 512, 768),
        expansion_ratio=2.5,
        conv_ratio=1.25,
        stem_mid_norm=False,
        downsample='conv_nf',
        ls_init_value=1e-6,
        head_fn='norm_mlp',
    )
    return _create_mambaout('mambaout_base_slim_rw', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def mambaout_base_plus_rw(pretrained=False, **kwargs):
    model_args = dict(
        depths=(3, 4, 27, 3),
        dims=(128, 256, 512, 768),
        expansion_ratio=3.0,
        conv_ratio=1.5,
        stem_mid_norm=False,
        downsample='conv_nf',
        ls_init_value=1e-6,
        act_layer='silu',
        head_fn='norm_mlp',
    )
    return _create_mambaout('mambaout_base_plus_rw', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def test_mambaout(pretrained=False, **kwargs):
    model_args = dict(
        depths=(1, 1, 3, 1),
        dims=(16, 32, 48, 64),
        expansion_ratio=3,
        stem_mid_norm=False,
        downsample='conv_nf',
        ls_init_value=1e-4,
        act_layer='silu',
        head_fn='norm_mlp',
    )
    return _create_mambaout('test_mambaout', pretrained=pretrained, **dict(model_args, **kwargs))