[Feature] Add Lion optimizer (#952)

This commit is contained in:
Zaida Zhou 2023-02-23 11:24:50 +08:00 committed by GitHub
parent 25dfe41c19
commit fc9518e2c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 37 additions and 2 deletions

View File

@ -243,7 +243,7 @@ As shown in the above example, `OptimWrapperDict` exports learning rates and mom
### Configure the OptimWapper in [Runner](runner.md)
We first need to configure the `optimizer` for the OptimWrapper. MMEngine automatically adds all optimizers in PyTorch to the `OPTIMIZERS` registry, and users can specify the optimizers they need in the form of a `dict`. All supported optimizers in PyTorch are listed [here](https://pytorch.org/docs/stable/optim.html#algorithms). In addition, 'DAdaptAdaGrad', 'DAdaptAdam', and 'DAdaptSGD' can be used by installing [dadaptation](https://github.com/facebookresearch/dadaptation).
We first need to configure the `optimizer` for the OptimWrapper. MMEngine automatically adds all optimizers in PyTorch to the `OPTIMIZERS` registry, and users can specify the optimizers they need in the form of a `dict`. All supported optimizers in PyTorch are listed [here](https://pytorch.org/docs/stable/optim.html#algorithms). In addition, `DAdaptAdaGrad`, `DAdaptAdam`, and `DAdaptSGD` can be used by installing [dadaptation](https://github.com/facebookresearch/dadaptation). `Lion` optimizer can used by install [lion-pytorch](https://github.com/lucidrains/lion-pytorch)。
Now we take setting up a SGD OptimWrapper as an example.

View File

@ -243,7 +243,7 @@ print(optim_dict.get_momentum()) # {'gen.momentum': [0], 'disc.momentum': [0]}
### 在[执行器](./runner.md)中配置优化器封装
优化器封装需要接受 `optimizer` 参数,因此我们首先需要为优化器封装配置 `optimizer`。MMEngine 会自动将 PyTorch 中的所有优化器都添加进 `OPTIMIZERS` 注册表中,用户可以用字典的形式来指定优化器,所有支持的优化器见 [PyTorch 优化器列表](https://pytorch.org/docs/stable/optim.html#algorithms)。
优化器封装需要接受 `optimizer` 参数,因此我们首先需要为优化器封装配置 `optimizer`。MMEngine 会自动将 PyTorch 中的所有优化器都添加进 `OPTIMIZERS` 注册表中,用户可以用字典的形式来指定优化器,所有支持的优化器见 [PyTorch 优化器列表](https://pytorch.org/docs/stable/optim.html#algorithms)。另外,可以通过安装 [dadaptation](https://github.com/facebookresearch/dadaptation) 使用 `DAdaptAdaGrad``DAdaptAdam``DAdaptSGD` 3 个优化器。也可以通过安装 [lion-pytorch](https://github.com/lucidrains/lion-pytorch) 使用 `Lion` 优化器。
以配置一个 SGD 优化器封装为例:

View File

@ -57,6 +57,26 @@ def register_dadaptation_optimizers() -> List[str]:
DADAPTATION_OPTIMIZERS = register_dadaptation_optimizers()
def register_lion_optimizers() -> List[str]:
"""Register Lion optimizer to the ``OPTIMIZERS`` registry.
Returns:
List[str]: A list of registered optimizers' name.
"""
optimizers = []
try:
from lion_pytorch import Lion
except ImportError:
pass
else:
OPTIMIZERS.register_module(module=Lion)
optimizers.append('Lion')
return optimizers
LION_OPTIMIZERS = register_lion_optimizers()
def build_optim_wrapper(model: nn.Module,
cfg: Union[dict, Config, ConfigDict]) -> OptimWrapper:
"""Build function of OptimWrapper.

View File

@ -1,4 +1,6 @@
coverage
dadaptation
lion-pytorch
lmdb
parameterized
pytest

View File

@ -14,6 +14,7 @@ from mmengine.optim import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
DefaultOptimWrapperConstructor, OptimWrapper,
build_optim_wrapper)
from mmengine.optim.optimizer.builder import (DADAPTATION_OPTIMIZERS,
LION_OPTIMIZERS,
TORCH_OPTIMIZERS)
from mmengine.registry import build_from_cfg
from mmengine.testing._internal import MultiProcessTestCase
@ -34,6 +35,14 @@ def has_dadaptation() -> bool:
return False
def has_lion() -> bool:
try:
import lion_pytorch # noqa: F401
return True
except ImportError:
return False
class ExampleModel(nn.Module):
def __init__(self):
@ -221,6 +230,10 @@ class TestBuilder(TestCase):
assert set(dadaptation_optimizers).issubset(
set(DADAPTATION_OPTIMIZERS))
@unittest.skipIf(not has_lion(), 'lion-pytorch is not installed')
def test_lion_optimizers(self):
assert 'Lion' in LION_OPTIMIZERS
def test_build_optimizer(self):
# test build function without ``constructor`` and ``paramwise_cfg``
optim_wrapper_cfg = dict(