mirror of https://github.com/FoundationVision/GLEE
67 lines
2.1 KiB
Python
67 lines
2.1 KiB
Python
import unittest
|
|
|
|
from detectron2.solver.build import _expand_param_groups, reduce_param_groups
|
|
|
|
|
|
class TestOptimizer(unittest.TestCase):
|
|
def testExpandParamsGroups(self):
|
|
params = [
|
|
{
|
|
"params": ["p1", "p2", "p3", "p4"],
|
|
"lr": 1.0,
|
|
"weight_decay": 3.0,
|
|
},
|
|
{
|
|
"params": ["p2", "p3", "p5"],
|
|
"lr": 2.0,
|
|
"momentum": 2.0,
|
|
},
|
|
{
|
|
"params": ["p1"],
|
|
"weight_decay": 4.0,
|
|
},
|
|
]
|
|
out = _expand_param_groups(params)
|
|
gt = [
|
|
dict(params=["p1"], lr=1.0, weight_decay=4.0), # noqa
|
|
dict(params=["p2"], lr=2.0, weight_decay=3.0, momentum=2.0), # noqa
|
|
dict(params=["p3"], lr=2.0, weight_decay=3.0, momentum=2.0), # noqa
|
|
dict(params=["p4"], lr=1.0, weight_decay=3.0), # noqa
|
|
dict(params=["p5"], lr=2.0, momentum=2.0), # noqa
|
|
]
|
|
self.assertEqual(out, gt)
|
|
|
|
def testReduceParamGroups(self):
|
|
params = [
|
|
dict(params=["p1"], lr=1.0, weight_decay=4.0), # noqa
|
|
dict(params=["p2", "p6"], lr=2.0, weight_decay=3.0, momentum=2.0), # noqa
|
|
dict(params=["p3"], lr=2.0, weight_decay=3.0, momentum=2.0), # noqa
|
|
dict(params=["p4"], lr=1.0, weight_decay=3.0), # noqa
|
|
dict(params=["p5"], lr=2.0, momentum=2.0), # noqa
|
|
]
|
|
gt_groups = [
|
|
{
|
|
"lr": 1.0,
|
|
"weight_decay": 4.0,
|
|
"params": ["p1"],
|
|
},
|
|
{
|
|
"lr": 2.0,
|
|
"weight_decay": 3.0,
|
|
"momentum": 2.0,
|
|
"params": ["p2", "p6", "p3"],
|
|
},
|
|
{
|
|
"lr": 1.0,
|
|
"weight_decay": 3.0,
|
|
"params": ["p4"],
|
|
},
|
|
{
|
|
"lr": 2.0,
|
|
"momentum": 2.0,
|
|
"params": ["p5"],
|
|
},
|
|
]
|
|
out = reduce_param_groups(params)
|
|
self.assertEqual(out, gt_groups)
|