# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
import torch.nn as nn
from mmcv.runner import DefaultOptimizerConstructor

from mmseg.core.builder import (OPTIMIZER_BUILDERS, build_optimizer,
                                build_optimizer_constructor)


class ExampleModel(nn.Module):
    """Small toy network used as the optimization target in these tests.

    It owns a bare parameter, two 1x1 convolutions (the first one built
    without a bias) and a 2-channel batch-norm layer. ``forward`` is an
    identity and does not use any of the layers.
    """

    def __init__(self):
        super(ExampleModel, self).__init__()
        # Attribute assignment order determines the order of parameters()
        # and named_parameters(), so it is kept unchanged.
        self.param1 = nn.Parameter(torch.ones(1))
        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=4, kernel_size=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels=4, out_channels=2, kernel_size=1)
        self.bn = nn.BatchNorm2d(num_features=2)

    def forward(self, x):
        """Return ``x`` unchanged; none of the layers are applied."""
        return x


# Shared SGD hyper-parameters used by the optimizer configs in this module.
base_lr, base_wd, momentum = 0.01, 1e-4, 0.9


def test_build_optimizer_constructor():
    """Check that optimizer constructors are resolved by type name.

    Three cases are covered: ``DefaultOptimizerConstructor`` (imported from
    ``mmcv.runner``, i.e. not defined in mmseg) resolved via the parent
    registry, a constructor registered directly in ``OPTIMIZER_BUILDERS``,
    and an unregistered type name, which must raise ``KeyError``.
    """
    optimizer_cfg = dict(
        type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
    optim_constructor_cfg = dict(
        type='DefaultOptimizerConstructor', optimizer_cfg=optimizer_cfg)
    optim_constructor = build_optimizer_constructor(optim_constructor_cfg)
    # Test whether optimizer constructor can be built from parent.
    assert type(optim_constructor) is DefaultOptimizerConstructor

    # force=True: the registry is global, and without it a second run of
    # this test in the same process would raise KeyError ("already
    # registered") before reaching any assertion.
    @OPTIMIZER_BUILDERS.register_module(force=True)
    class MyOptimizerConstructor(DefaultOptimizerConstructor):
        pass

    optim_constructor_cfg = dict(
        type='MyOptimizerConstructor', optimizer_cfg=optimizer_cfg)
    optim_constructor = build_optimizer_constructor(optim_constructor_cfg)
    # Test optimizer constructor can be built from child registry.
    assert type(optim_constructor) is MyOptimizerConstructor

    # Test unregistered constructor cannot be built
    with pytest.raises(KeyError):
        build_optimizer_constructor(dict(type='A'))


def test_build_optimizer():
    """Check that ``build_optimizer`` returns a correctly configured SGD.

    The config has no ``constructor`` or ``paramwise_cfg`` key, so all of
    the model's parameters are expected in a single parameter group. That
    group should carry the hyper-parameters given in ``optimizer_cfg``.
    """
    model = ExampleModel()
    optimizer_cfg = dict(
        type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
    optimizer = build_optimizer(model, optimizer_cfg)
    # test whether optimizer is successfully built from parent.
    assert isinstance(optimizer, torch.optim.SGD)

    # The hyper-parameters from the config must reach the optimizer;
    # checking only the type would not catch a builder that drops them.
    assert optimizer.defaults['lr'] == base_lr
    assert optimizer.defaults['momentum'] == momentum
    assert optimizer.defaults['weight_decay'] == base_wd

    # Without a paramwise_cfg, every model parameter is optimized exactly
    # once, inside a single parameter group.
    assert len(optimizer.param_groups) == 1
    group_params = optimizer.param_groups[0]['params']
    model_params = list(model.parameters())
    assert len(group_params) == len(model_params)
    model_param_ids = {id(p) for p in model_params}
    assert {id(p) for p in group_params} == model_param_ids