[Feature] Add Lion optimizer (#952)
This commit is contained in:
parent
25dfe41c19
commit
fc9518e2c1
@@ -243,7 +243,7 @@ As shown in the above example, `OptimWrapperDict` exports learning rates and mom
### Configure the OptimWrapper in [Runner](runner.md)
We first need to configure the `optimizer` for the OptimWrapper. MMEngine automatically adds all optimizers in PyTorch to the `OPTIMIZERS` registry, and users can specify the optimizers they need in the form of a `dict`. All supported optimizers in PyTorch are listed [here](https://pytorch.org/docs/stable/optim.html#algorithms). In addition, 'DAdaptAdaGrad', 'DAdaptAdam', and 'DAdaptSGD' can be used by installing [dadaptation](https://github.com/facebookresearch/dadaptation).
We first need to configure the `optimizer` for the OptimWrapper. MMEngine automatically adds all optimizers in PyTorch to the `OPTIMIZERS` registry, and users can specify the optimizers they need in the form of a `dict`. All supported optimizers in PyTorch are listed [here](https://pytorch.org/docs/stable/optim.html#algorithms). In addition, `DAdaptAdaGrad`, `DAdaptAdam`, and `DAdaptSGD` can be used by installing [dadaptation](https://github.com/facebookresearch/dadaptation). The `Lion` optimizer can be used by installing [lion-pytorch](https://github.com/lucidrains/lion-pytorch).
Now we take setting up an SGD OptimWrapper as an example.
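A minimal sketch of such a configuration (the hyperparameter values are illustrative, not recommendations):

```python
# Sketch: an OptimWrapper config that selects SGD from the OPTIMIZERS registry.
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9))

# Optimizers from optional packages are selected the same way once installed,
# e.g. optimizer=dict(type='Lion', lr=1e-4, weight_decay=1e-2).
```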
@@ -243,7 +243,7 @@ print(optim_dict.get_momentum())  # {'gen.momentum': [0], 'disc.momentum': [0]}
### Configure the optimizer wrapper in the [Runner](./runner.md)
The optimizer wrapper takes an `optimizer` argument, so we first need to configure the `optimizer` for the optimizer wrapper. MMEngine automatically adds all PyTorch optimizers to the `OPTIMIZERS` registry, and users can specify the optimizer they need in the form of a dict. All supported optimizers are listed in the [PyTorch optimizer list](https://pytorch.org/docs/stable/optim.html#algorithms).
The optimizer wrapper takes an `optimizer` argument, so we first need to configure the `optimizer` for the optimizer wrapper. MMEngine automatically adds all PyTorch optimizers to the `OPTIMIZERS` registry, and users can specify the optimizer they need in the form of a dict. All supported optimizers are listed in the [PyTorch optimizer list](https://pytorch.org/docs/stable/optim.html#algorithms). In addition, the three optimizers `DAdaptAdaGrad`, `DAdaptAdam`, and `DAdaptSGD` can be used by installing [dadaptation](https://github.com/facebookresearch/dadaptation), and the `Lion` optimizer can be used by installing [lion-pytorch](https://github.com/lucidrains/lion-pytorch).
Take configuring an SGD optimizer wrapper as an example:
@@ -57,6 +57,26 @@ def register_dadaptation_optimizers() -> List[str]:
DADAPTATION_OPTIMIZERS = register_dadaptation_optimizers()


def register_lion_optimizers() -> List[str]:
    """Register Lion optimizer to the ``OPTIMIZERS`` registry.

    Returns:
        List[str]: A list of registered optimizers' names.
    """
    optimizers = []
    try:
        # lion-pytorch is an optional dependency; skip registration if absent.
        from lion_pytorch import Lion
    except ImportError:
        pass
    else:
        OPTIMIZERS.register_module(module=Lion)
        optimizers.append('Lion')
    return optimizers


LION_OPTIMIZERS = register_lion_optimizers()
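Once registered, `Lion` is built through the same config path as any other optimizer. A minimal sketch, assuming `lion-pytorch` is installed; the toy model and hyperparameter values are illustrative only:

```python
import torch.nn as nn

from mmengine.optim import build_optim_wrapper

model = nn.Linear(2, 2)  # toy model, for illustration only
# Select Lion by its registered name, exactly like a built-in optimizer.
optim_wrapper = build_optim_wrapper(
    model,
    dict(type='OptimWrapper',
         optimizer=dict(type='Lion', lr=1e-4, weight_decay=1e-2)))
```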
def build_optim_wrapper(model: nn.Module,
                        cfg: Union[dict, Config, ConfigDict]) -> OptimWrapper:
    """Build function of OptimWrapper.
@@ -1,4 +1,6 @@
coverage
dadaptation
lion-pytorch
lmdb
parameterized
pytest
@@ -14,6 +14,7 @@ from mmengine.optim import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
                            DefaultOptimWrapperConstructor, OptimWrapper,
                            build_optim_wrapper)
from mmengine.optim.optimizer.builder import (DADAPTATION_OPTIMIZERS,
                                              LION_OPTIMIZERS,
                                              TORCH_OPTIMIZERS)
from mmengine.registry import build_from_cfg
from mmengine.testing._internal import MultiProcessTestCase
@@ -34,6 +35,14 @@ def has_dadaptation() -> bool:
        return False


def has_lion() -> bool:
    try:
        import lion_pytorch  # noqa: F401
        return True
    except ImportError:
        return False


class ExampleModel(nn.Module):

    def __init__(self):
@@ -221,6 +230,10 @@ class TestBuilder(TestCase):
        assert set(dadaptation_optimizers).issubset(
            set(DADAPTATION_OPTIMIZERS))

    @unittest.skipIf(not has_lion(), 'lion-pytorch is not installed')
    def test_lion_optimizers(self):
        assert 'Lion' in LION_OPTIMIZERS

    def test_build_optimizer(self):
        # test build function without ``constructor`` and ``paramwise_cfg``
        optim_wrapper_cfg = dict(