[Feature] Support Adafactor Optimizer (#1361)
parent 53474ef1ba
commit d617bcafdd
@@ -121,3 +121,35 @@ runner = Runner(
)
runner.train()
```

## transformers

[transformers](https://github.com/huggingface/transformers) provides the `Adafactor` optimizer.

```{note}
If you use the optimizer provided by transformers, you need to upgrade mmengine to `0.8.5`.
```
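
You can quickly confirm that the installed mmengine is new enough before trying the example below. This is a minimal sketch and only assumes that mmengine exposes its version string as `mmengine.__version__`:

```python
# Minimal sanity check: the transformers optimizers need mmengine >= 0.8.5
# (see the note above). `mmengine.__version__` is the package version string.
import mmengine

print(mmengine.__version__)  # expect 0.8.5 or newer
```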

- Installation

```bash
pip install transformers
```

- Usage

Take `Adafactor` as an example (a registry-level sketch follows the `Runner` example below).

```python
runner = Runner(
    model=ResNet18(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader_cfg,
    # To view the input parameters for Adafactor, you can refer to
    # https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/optimization.py#L492
    optim_wrapper=dict(optimizer=dict(type='Adafactor', lr=1e-5,
                       weight_decay=1e-2, scale_parameter=False, relative_step=False)),
    train_cfg=dict(by_epoch=True, max_epochs=3),
)
runner.train()
```
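
Because the PR registers `Adafactor` in mmengine's `OPTIMIZERS` registry (see the `builder.py` hunk below), the same optimizer config can also be built directly from the registry without a `Runner`. The snippet below is a minimal sketch under that assumption; the `nn.Linear` model is only a placeholder for illustration:

```python
# Illustrative sketch (not part of the PR): build the registered Adafactor
# optimizer straight from the OPTIMIZERS registry; the placeholder model
# only supplies parameters.
import torch.nn as nn
from mmengine.registry import OPTIMIZERS

model = nn.Linear(8, 2)  # placeholder model
optimizer = OPTIMIZERS.build(
    dict(
        type='Adafactor',
        params=model.parameters(),
        lr=1e-5,
        weight_decay=1e-2,
        scale_parameter=False,
        relative_step=False))
print(type(optimizer))  # expected: transformers.optimization.Adafactor
```
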
@@ -157,6 +157,21 @@ def register_bitsandbytes_optimizers() -> List[str]:
BITSANDBYTES_OPTIMIZERS = register_bitsandbytes_optimizers()


def register_transformers_optimizers():
    transformer_optimizers = []
    try:
        from transformers import Adafactor
    except ImportError:
        pass
    else:
        OPTIMIZERS.register_module(name='Adafactor', module=Adafactor)
        transformer_optimizers.append('Adafactor')
    return transformer_optimizers


TRANSFORMERS_OPTIMIZERS = register_transformers_optimizers()


def build_optim_wrapper(model: nn.Module,
                        cfg: Union[dict, Config, ConfigDict]) -> OptimWrapper:
    """Build function of OptimWrapper.

@@ -11,3 +11,4 @@ neptune
parameterized
pydantic==1.10.9
pytest
transformers

@@ -17,7 +17,8 @@ from mmengine.optim import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
from mmengine.optim.optimizer.builder import (BITSANDBYTES_OPTIMIZERS,
                                              DADAPTATION_OPTIMIZERS,
                                              LION_OPTIMIZERS,
-                                              TORCH_OPTIMIZERS)
+                                              TORCH_OPTIMIZERS,
+                                              TRANSFORMERS_OPTIMIZERS)
from mmengine.registry import DefaultScope, Registry, build_from_cfg
from mmengine.testing._internal import MultiProcessTestCase
from mmengine.utils.dl_utils import TORCH_VERSION, mmcv_full_available

@@ -53,6 +54,14 @@ def has_bitsandbytes() -> bool:
        return False


def has_transformers() -> bool:
    try:
        import transformers  # noqa: F401
        return True
    except ImportError:
        return False


class ExampleModel(nn.Module):

    def __init__(self):

@@ -244,7 +253,7 @@ class TestBuilder(TestCase):
    def test_lion_optimizers(self):
        assert 'Lion' in LION_OPTIMIZERS

-    @unittest.skipIf(not has_bitsandbytes(), 'dadaptation is not installed')
+    @unittest.skipIf(not has_bitsandbytes(), 'bitsandbytes is not installed')
    def test_bitsandbytes_optimizers(self):
        bitsandbytes_optimizers = [
            'AdamW8bit', 'Adam8bit', 'Adagrad8bit', 'PagedAdam8bit',

@@ -254,6 +263,12 @@ class TestBuilder(TestCase):
        assert set(bitsandbytes_optimizers).issubset(
            set(BITSANDBYTES_OPTIMIZERS))

    @unittest.skipIf(not has_transformers(), 'transformers is not installed')
    def test_transformers_optimizers(self):
        transformers_optimizers = ['Adafactor']
        assert set(transformers_optimizers).issubset(
            set(TRANSFORMERS_OPTIMIZERS))

    def test_build_optimizer(self):
        # test build function without ``constructor`` and ``paramwise_cfg``
        optim_wrapper_cfg = dict(