[Feature] Support Adafactor Optimizer (#1361)

pull/1363/head
takuoko 2023-09-21 17:30:24 +09:00 committed by GitHub
parent 53474ef1ba
commit d617bcafdd
4 changed files with 65 additions and 2 deletions


@@ -121,3 +121,35 @@ runner = Runner(
)
runner.train()
```
## transformers
[transformers](https://github.com/huggingface/transformers) provides the `Adafactor` optimizer.
```{note}
If you use the optimizer provided by transformers, you need to upgrade mmengine to `0.8.5`.
```
- Installation
```bash
pip install transformers
```
- Usage
Take `Adafactor` as an example.
```python
runner = Runner(
model=ResNet18(),
work_dir='./work_dir',
train_dataloader=train_dataloader_cfg,
# To view the input parameters for Adafactor, you can refer to
# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/optimization.py#L492
optim_wrapper=dict(optimizer=dict(type='Adafactor', lr=1e-5,
weight_decay=1e-2, scale_parameter=False, relative_step=False)),
train_cfg=dict(by_epoch=True, max_epochs=3),
)
runner.train()
```
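Because `Adafactor` is added to MMEngine's `OPTIMIZERS` registry when transformers is available, it can also be built straight from the registry, without going through a `Runner`. The snippet below is a minimal sketch under that assumption; the `nn.Linear` placeholder model and the hyperparameter values are illustrative only.
```python
import torch.nn as nn
from mmengine.optim import OPTIMIZERS  # importing mmengine.optim triggers the registration

# Placeholder model; any nn.Module with parameters() would do.
model = nn.Linear(8, 4)

# 'type' selects the registered class; the remaining keys are forwarded to
# transformers' Adafactor constructor.
optimizer = OPTIMIZERS.build(
    dict(
        type='Adafactor',
        params=model.parameters(),
        lr=1e-5,
        weight_decay=1e-2,
        scale_parameter=False,
        relative_step=False))
```
When training with a `Runner`, the same registry lookup happens automatically while the `optim_wrapper` config shown above is built.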


@@ -157,6 +157,21 @@ def register_bitsandbytes_optimizers() -> List[str]:
BITSANDBYTES_OPTIMIZERS = register_bitsandbytes_optimizers()
def register_transformers_optimizers():
transformer_optimizers = []
try:
from transformers import Adafactor
except ImportError:
pass
else:
OPTIMIZERS.register_module(name='Adafactor', module=Adafactor)
transformer_optimizers.append('Adafactor')
return transformer_optimizers
TRANSFORMERS_OPTIMIZERS = register_transformers_optimizers()
def build_optim_wrapper(model: nn.Module,
cfg: Union[dict, Config, ConfigDict]) -> OptimWrapper:
"""Build function of OptimWrapper.


@@ -11,3 +11,4 @@ neptune
parameterized
pydantic==1.10.9
pytest
transformers


@@ -17,7 +17,8 @@ from mmengine.optim import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
from mmengine.optim.optimizer.builder import (BITSANDBYTES_OPTIMIZERS,
DADAPTATION_OPTIMIZERS,
LION_OPTIMIZERS,
TORCH_OPTIMIZERS)
TORCH_OPTIMIZERS,
TRANSFORMERS_OPTIMIZERS)
from mmengine.registry import DefaultScope, Registry, build_from_cfg
from mmengine.testing._internal import MultiProcessTestCase
from mmengine.utils.dl_utils import TORCH_VERSION, mmcv_full_available
@@ -53,6 +54,14 @@ def has_bitsandbytes() -> bool:
return False
def has_transformers() -> bool:
try:
import transformers # noqa: F401
return True
except ImportError:
return False
class ExampleModel(nn.Module):
def __init__(self):
@@ -244,7 +253,7 @@ class TestBuilder(TestCase):
def test_lion_optimizers(self):
assert 'Lion' in LION_OPTIMIZERS
@unittest.skipIf(not has_bitsandbytes(), 'dadaptation is not installed')
@unittest.skipIf(not has_bitsandbytes(), 'bitsandbytes is not installed')
def test_bitsandbytes_optimizers(self):
bitsandbytes_optimizers = [
'AdamW8bit', 'Adam8bit', 'Adagrad8bit', 'PagedAdam8bit',
@@ -254,6 +263,12 @@ class TestBuilder(TestCase):
assert set(bitsandbytes_optimizers).issubset(
set(BITSANDBYTES_OPTIMIZERS))
@unittest.skipIf(not has_transformers(), 'transformers is not installed')
def test_transformers_optimizers(self):
transformers_optimizers = ['Adafactor']
assert set(transformers_optimizers).issubset(
set(TRANSFORMERS_OPTIMIZERS))
def test_build_optimizer(self):
# test build function without ``constructor`` and ``paramwise_cfg``
optim_wrapper_cfg = dict(