# encoding: utf-8
"""
@author:  liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import copy

import torch
import torch.nn.functional as F
from torch import nn

from .build import META_ARCH_REGISTRY
from ..backbones import build_backbone
from ..backbones.resnet import Bottleneck
from ..heads import build_reid_heads
from ..model_utils import weights_init_kaiming
from ...layers import GeneralizedMeanPoolingP, Flatten


@META_ARCH_REGISTRY.register()
class MGN(nn.Module):
    """Three-branch re-identification model on top of a shared backbone stem.

    The stem is the backbone up to and including the first block of
    ``layer3``. Each branch owns a private copy of the remaining ``layer3``
    blocks followed by a ``layer4``-shaped stage:

    * ``b1`` -- one global output.
    * ``b2`` -- one global output plus two stripes (``b21``, ``b22``) obtained
      by chunking the branch feature map along dim 2.
    * ``b3`` -- one global output plus three stripes (``b31``, ``b32``,
      ``b33``) obtained the same way.

    Every output has its own pool-reduce block (``<name>_pool``) and its own
    head (``<name>_head``), giving eight outputs in total.

    NOTE(review): the class name presumably stands for "Multiple Granularity
    Network" -- confirm against the project documentation.
    """

    # Order in which the heads are evaluated during a forward pass.
    _COMPUTE_ORDER = ("b1", "b2", "b21", "b22", "b3", "b31", "b32", "b33")
    # Order in which outputs are exposed to callers: the three global
    # outputs first, then the stripe outputs.
    _OUTPUT_ORDER = ("b1", "b2", "b3", "b21", "b22", "b31", "b32", "b33")
    # Number of leading entries of ``_OUTPUT_ORDER`` that are global outputs.
    _NUM_GLOBAL = 3

    def __init__(self, cfg):
        """Build the stem, the three branches and the eight pool/head pairs.

        Args:
            cfg: project config node. This constructor reads
                ``cfg.MODEL.HEADS.POOL_LAYER`` directly and passes ``cfg`` on
                to ``build_backbone`` and ``build_reid_heads``.
        """
        super().__init__()
        self._cfg = cfg
        # backbone
        backbone = build_backbone(cfg)
        self.backbone = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3[0]
        )
        res_conv4 = nn.Sequential(*backbone.layer3[1:])
        res_g_conv5 = backbone.layer4

        # A second layer4-shaped stage, built without an explicit stride
        # argument and initialised from the backbone's layer4 weights.
        res_p_conv5 = nn.Sequential(
            Bottleneck(1024, 512, downsample=nn.Sequential(
                nn.Conv2d(1024, 2048, 1, bias=False), nn.BatchNorm2d(2048))),
            Bottleneck(2048, 512),
            Bottleneck(2048, 512))
        res_p_conv5.load_state_dict(backbone.layer4.state_dict())

        pool_type = cfg.MODEL.HEADS.POOL_LAYER
        if pool_type == 'avgpool':
            pool_layer = nn.AdaptiveAvgPool2d(1)
        elif pool_type == 'maxpool':
            pool_layer = nn.AdaptiveMaxPool2d(1)
        elif pool_type == 'gempool':
            pool_layer = GeneralizedMeanPoolingP()
        else:
            # Any other value disables pooling.
            pool_layer = nn.Identity()

        # NOTE: the single ``pool_layer`` instance created above is handed to
        # every pool-reduce block below, so any state it holds is shared by
        # all eight blocks.

        # branch1
        self.b1 = nn.Sequential(
            copy.deepcopy(res_conv4), copy.deepcopy(res_g_conv5)
        )
        self._add_pool_and_head("b1", pool_layer)

        # branch2
        self.b2 = nn.Sequential(
            copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5)
        )
        for name in ("b2", "b21", "b22"):
            self._add_pool_and_head(name, pool_layer)

        # branch3
        self.b3 = nn.Sequential(
            copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5)
        )
        for name in ("b3", "b31", "b32", "b33"):
            self._add_pool_and_head(name, pool_layer)

    def _add_pool_and_head(self, name, pool_layer, reduce_dim=256):
        """Register ``<name>_pool`` and ``<name>_head`` on this module.

        The pool-reduce block is created before the head so that module
        registration order and initialisation order stay fixed.
        """
        setattr(self, name + "_pool",
                self._build_pool_reduce(pool_layer, reduce_dim=reduce_dim))
        setattr(self, name + "_head",
                build_reid_heads(self._cfg, reduce_dim, nn.Identity()))

    def _head(self, name):
        """Return the head module registered for output ``name``."""
        return getattr(self, name + "_head")

    def _build_pool_reduce(self, pool_layer, input_dim=2048, reduce_dim=256):
        """Build a pool -> 1x1 conv -> BN -> ReLU -> flatten block.

        Args:
            pool_layer: pooling module placed first. It is inserted as-is
                (not copied), so the caller's instance is shared with the
                returned block.
            input_dim: number of input channels of the 1x1 convolution.
            reduce_dim: number of output channels of the 1x1 convolution.

        Returns:
            The ``nn.Sequential`` block, after ``weights_init_kaiming`` has
            been applied to it and all of its submodules.
        """
        pool_reduce = nn.Sequential(
            pool_layer,
            nn.Conv2d(input_dim, reduce_dim, 1, bias=False),
            nn.BatchNorm2d(reduce_dim),
            nn.ReLU(True),
            Flatten()
        )
        pool_reduce.apply(weights_init_kaiming)
        return pool_reduce

    def _reduced_features(self, images):
        """Run stem, branches and pool-reduce blocks shared by train and eval.

        Returns:
            dict mapping each output name in ``_COMPUTE_ORDER`` to the result
            of its ``<name>_pool`` block. Heads are NOT applied here.
        """
        # presumably (bs, 1024, H, W): the branch stages built in __init__
        # take 1024 input channels -- TODO confirm against the backbone.
        stem = self.backbone(images)

        b1_map = self.b1(stem)
        b2_map = self.b2(stem)
        b3_map = self.b3(stem)

        # Split the branch maps along dim 2 into 2 and 3 stripes respectively.
        b21_map, b22_map = torch.chunk(b2_map, 2, dim=2)
        b31_map, b32_map, b33_map = torch.chunk(b3_map, 3, dim=2)

        maps = {
            "b1": b1_map,
            "b2": b2_map, "b21": b21_map, "b22": b22_map,
            "b3": b3_map, "b31": b31_map, "b32": b32_map, "b33": b33_map,
        }
        return {name: getattr(self, name + "_pool")(maps[name])
                for name in self._COMPUTE_ORDER}

    def forward(self, inputs):
        """Run the model in training or evaluation mode.

        Args:
            inputs: dict with keys ``"images"`` and ``"targets"``; the key
                ``"camid"`` is additionally read in evaluation mode.

        Returns:
            Evaluation mode: ``(pred_feat, targets, camid)`` where
            ``pred_feat`` is the output of :meth:`inference`.

            Training mode: ``(logits, global_feats, targets)`` where
            ``logits`` is an 8-tuple ordered b1, b2, b3, b21, b22, b31, b32,
            b33 and ``global_feats`` is a 3-tuple holding the second value
            returned by the b1, b2 and b3 heads.
        """
        images = inputs["images"]
        targets = inputs["targets"]

        if not self.training:
            pred_feat = self.inference(images)
            return pred_feat, targets, inputs["camid"]

        reduced = self._reduced_features(images)

        logits, feats = {}, {}
        for name in self._COMPUTE_ORDER:
            logits[name], feats[name] = self._head(name)(reduced[name], targets)

        global_names = self._OUTPUT_ORDER[:self._NUM_GLOBAL]
        return tuple(logits[name] for name in self._OUTPUT_ORDER), \
               tuple(feats[name] for name in global_names), \
               targets

    def inference(self, images):
        """Compute the evaluation embedding for a batch of images.

        Must only be called in evaluation mode. Each head is called with the
        pooled feature only; the eight head outputs are concatenated along
        dim 1 (order b1, b2, b3, b21, b22, b31, b32, b33) and passed through
        ``F.normalize`` with its default arguments.
        """
        assert not self.training
        reduced = self._reduced_features(images)

        embeddings = {name: self._head(name)(reduced[name])
                      for name in self._COMPUTE_ORDER}
        pred_feat = torch.cat([embeddings[name] for name in self._OUTPUT_ORDER], dim=1)

        return F.normalize(pred_feat)

    def losses(self, outputs):
        """Collect the per-output losses from a training-mode forward result.

        Args:
            outputs: the ``(logits, global_feats, targets)`` triple returned
                by :meth:`forward` in training mode.

        Returns:
            dict merging the loss dicts of all eight outputs. Each head's
            ``losses`` is called with the prefix ``"<name>_"``. The three
            global outputs pass their feature tensor; the stripe outputs pass
            ``None`` in its place.
        """
        logits, global_feats, targets = outputs[0], outputs[1], outputs[2]

        loss_dict = {}
        for idx, name in enumerate(self._OUTPUT_ORDER):
            feat = global_feats[idx] if idx < self._NUM_GLOBAL else None
            # Route each loss through the head that produced its logits.
            loss_dict.update(
                self._head(name).losses(self._cfg, logits[idx], feat, targets, name + '_'))
        return loss_dict