diff --git a/docs/en/tutorials/optim_wrapper.md b/docs/en/tutorials/optim_wrapper.md index 09d548eb..b07a54b0 100644 --- a/docs/en/tutorials/optim_wrapper.md +++ b/docs/en/tutorials/optim_wrapper.md @@ -243,7 +243,7 @@ As shown in the above example, `OptimWrapperDict` exports learning rates and mom ### Configure the OptimWapper in [Runner](runner.md) -We first need to configure the `optimizer` for the OptimWrapper. MMEngine automatically adds all optimizers in PyTorch to the `OPTIMIZERS` registry, and users can specify the optimizers they need in the form of a `dict`. All supported optimizers in PyTorch are listed [here](https://pytorch.org/docs/stable/optim.html#algorithms). In addition, 'DAdaptAdaGrad', 'DAdaptAdam', and 'DAdaptSGD' can be used by installing [dadaptation](https://github.com/facebookresearch/dadaptation). +We first need to configure the `optimizer` for the OptimWrapper. MMEngine automatically adds all optimizers in PyTorch to the `OPTIMIZERS` registry, and users can specify the optimizers they need in the form of a `dict`. All supported optimizers in PyTorch are listed [here](https://pytorch.org/docs/stable/optim.html#algorithms). In addition, `DAdaptAdaGrad`, `DAdaptAdam`, and `DAdaptSGD` can be used by installing [dadaptation](https://github.com/facebookresearch/dadaptation). The `Lion` optimizer can be used by installing [lion-pytorch](https://github.com/lucidrains/lion-pytorch). Now we take setting up a SGD OptimWrapper as an example. 
diff --git a/docs/zh_cn/tutorials/optim_wrapper.md b/docs/zh_cn/tutorials/optim_wrapper.md index 2c9c7d6b..05e70365 100644 --- a/docs/zh_cn/tutorials/optim_wrapper.md +++ b/docs/zh_cn/tutorials/optim_wrapper.md @@ -243,7 +243,7 @@ print(optim_dict.get_momentum()) # {'gen.momentum': [0], 'disc.momentum': [0]} ### 在[执行器](./runner.md)中配置优化器封装 -优化器封装需要接受 `optimizer` 参数,因此我们首先需要为优化器封装配置 `optimizer`。MMEngine 会自动将 PyTorch 中的所有优化器都添加进 `OPTIMIZERS` 注册表中,用户可以用字典的形式来指定优化器,所有支持的优化器见 [PyTorch 优化器列表](https://pytorch.org/docs/stable/optim.html#algorithms)。 +优化器封装需要接受 `optimizer` 参数,因此我们首先需要为优化器封装配置 `optimizer`。MMEngine 会自动将 PyTorch 中的所有优化器都添加进 `OPTIMIZERS` 注册表中,用户可以用字典的形式来指定优化器,所有支持的优化器见 [PyTorch 优化器列表](https://pytorch.org/docs/stable/optim.html#algorithms)。另外,可以通过安装 [dadaptation](https://github.com/facebookresearch/dadaptation) 使用 `DAdaptAdaGrad`、`DAdaptAdam` 和 `DAdaptSGD` 3 个优化器。也可以通过安装 [lion-pytorch](https://github.com/lucidrains/lion-pytorch) 使用 `Lion` 优化器。 以配置一个 SGD 优化器封装为例: diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py index 5f5f099c..03199ecd 100644 --- a/mmengine/optim/optimizer/builder.py +++ b/mmengine/optim/optimizer/builder.py @@ -57,6 +57,26 @@ def register_dadaptation_optimizers() -> List[str]: DADAPTATION_OPTIMIZERS = register_dadaptation_optimizers() +def register_lion_optimizers() -> List[str]: + """Register Lion optimizer to the ``OPTIMIZERS`` registry. + + Returns: + List[str]: A list of registered optimizers' name. + """ + optimizers = [] + try: + from lion_pytorch import Lion + except ImportError: + pass + else: + OPTIMIZERS.register_module(module=Lion) + optimizers.append('Lion') + return optimizers + + +LION_OPTIMIZERS = register_lion_optimizers() + + def build_optim_wrapper(model: nn.Module, cfg: Union[dict, Config, ConfigDict]) -> OptimWrapper: """Build function of OptimWrapper. 
diff --git a/requirements/tests.txt b/requirements/tests.txt index debf7eb1..5228ea39 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -1,4 +1,6 @@ coverage +dadaptation +lion-pytorch lmdb parameterized pytest diff --git a/tests/test_optim/test_optimizer/test_optimizer.py b/tests/test_optim/test_optimizer/test_optimizer.py index 0033de3f..fa62db27 100644 --- a/tests/test_optim/test_optimizer/test_optimizer.py +++ b/tests/test_optim/test_optimizer/test_optimizer.py @@ -14,6 +14,7 @@ from mmengine.optim import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS, DefaultOptimWrapperConstructor, OptimWrapper, build_optim_wrapper) from mmengine.optim.optimizer.builder import (DADAPTATION_OPTIMIZERS, + LION_OPTIMIZERS, TORCH_OPTIMIZERS) from mmengine.registry import build_from_cfg from mmengine.testing._internal import MultiProcessTestCase @@ -34,6 +35,14 @@ def has_dadaptation() -> bool: return False +def has_lion() -> bool: + try: + import lion_pytorch # noqa: F401 + return True + except ImportError: + return False + + class ExampleModel(nn.Module): def __init__(self): @@ -221,6 +230,10 @@ class TestBuilder(TestCase): assert set(dadaptation_optimizers).issubset( set(DADAPTATION_OPTIMIZERS)) + @unittest.skipIf(not has_lion(), 'lion-pytorch is not installed') + def test_lion_optimizers(self): + assert 'Lion' in LION_OPTIMIZERS + def test_build_optimizer(self): # test build function without ``constructor`` and ``paramwise_cfg`` optim_wrapper_cfg = dict(