mmsegmentation/tests/test_engine/test_optimizer.py

34 lines
922 B
Python

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmengine.optim import build_optim_wrapper
class ExampleModel(nn.Module):
    """Tiny module used as a fixture for optimizer-building tests.

    It owns one bare parameter, a bias-free convolution, a convolution
    with a bias and a batch-norm layer, so an optimizer built from it
    receives several kinds of learnable tensors.
    """

    def __init__(self):
        super().__init__()
        # Stand-alone parameter registered directly on the module.
        self.param1 = nn.Parameter(torch.ones(1))
        # 1x1 convolution without a bias term (weight only).
        self.conv1 = nn.Conv2d(3, 4, kernel_size=1, bias=False)
        # 1x1 convolution with the default bias term.
        self.conv2 = nn.Conv2d(4, 2, kernel_size=1)
        # Normalisation layer contributing its own weight and bias.
        self.bn = nn.BatchNorm2d(2)

    def forward(self, x):
        """Return ``x`` unchanged; the layers exist only as parameters."""
        return x
# Hyper-parameters shared by the optimizer configs built in this module.
base_lr = 1e-2  # learning rate passed to the optimizer
base_wd = 1e-4  # weight decay passed to the optimizer
momentum = 0.9  # SGD momentum
def test_build_optimizer():
    """Build an ``OptimWrapper`` around SGD and check the result.

    The wrapper config names the optimizer by its registry string
    (``'SGD'``); the test verifies that ``build_optim_wrapper`` resolves
    it to ``torch.optim.SGD`` and that the configured hyper-parameters
    reach the constructed optimizer.
    """
    model = ExampleModel()
    optimizer_cfg = dict(
        type='SGD', lr=base_lr, weight_decay=base_wd, momentum=momentum)
    optim_wrapper_cfg = dict(type='OptimWrapper', optimizer=optimizer_cfg)

    optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
    optimizer = optim_wrapper.optimizer

    # The optimizer type must have been resolved from the registry name.
    assert isinstance(optimizer, torch.optim.SGD)

    # A type check alone would pass even if the config values were
    # dropped, so also confirm they became the optimizer defaults.
    assert optimizer.defaults['lr'] == base_lr
    assert optimizer.defaults['weight_decay'] == base_wd
    assert optimizer.defaults['momentum'] == momentum