[Enhance] Unify the parameter style of DeepSpeedStrategy (#1320)
parent a53c2802a6
commit 714c8eedc3
@@ -219,10 +219,10 @@ class DeepSpeedStrategy(BaseStrategy):
             config for deepspeed. Defaults to None.
         zero_optimization (dict, optional): Enabling and configuring ZeRO
             memory optimizations. Defaults to None.
-        gradient_clipping (float): Enable gradient clipping with value.
-            Defaults to 1.0.
+        gradient_clipping (float, optional): Enable gradient clipping with
+            value. Defaults to None.
         fp16 (dict, optional): Configuration for using mixed precision/FP16
-            training that leverages NVIDIA's Apex package.
+            training that leverages NVIDIA's Apex package. Defaults to None.
         inputs_to_half (list[int or str], optional): Which inputs are to
             converted to half precision. Defaults to None.
             If ``fp16`` is enabled, it also should be set.
@@ -239,6 +239,12 @@ class DeepSpeedStrategy(BaseStrategy):
             offloading parameter and optimizer states to persistent (NVMe)
             storage. This module uses Linux native asynchronous I/O (libaio).
             Defaults to None.
+        train_micro_batch_size_per_gpu (int, optional): Batch size to be
+            processed by one GPU in one step (without gradient accumulation).
+            Defaults to None.
+        gradient_accumulation_steps (int, optional): Number of training steps
+            to accumulate gradients before averaging and applying them.
+            Defaults to None.
     """

     def __init__(
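The docstring edits above capture the point of the change: every DeepSpeed-specific option now defaults to None, meaning "leave this key alone unless the caller sets it explicitly". A minimal usage sketch of the unified style follows; the import path and the concrete values are illustrative assumptions, not part of this diff, and require mmengine with deepspeed installed.

# Illustrative only: assumes DeepSpeedStrategy is importable from
# mmengine._strategy and that deepspeed is installed.
from mmengine._strategy import DeepSpeedStrategy

# Style 1: put everything into a DeepSpeed config dict (or a JSON file path).
strategy = DeepSpeedStrategy(
    config=dict(
        train_micro_batch_size_per_gpu=2,
        gradient_clipping=0.5,
    ))

# Style 2: pass the same options as keyword arguments. Anything left as None
# is simply not written into the config, so it no longer overrides values
# coming from `config` with a hard-coded default such as 1.0.
strategy = DeepSpeedStrategy(
    train_micro_batch_size_per_gpu=2,
    gradient_clipping=0.5,
    fp16=dict(enabled=True),
)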
@@ -247,7 +253,7 @@ class DeepSpeedStrategy(BaseStrategy):
         # the following args are for deepspeed
         config: Union[str, dict, None] = None,
         zero_optimization: Optional[dict] = None,
-        gradient_clipping: float = 1.0,
+        gradient_clipping: Optional[float] = None,
         fp16: Optional[dict] = None,
         inputs_to_half: Optional[List[Union[int, str]]] = None,
         bf16: Optional[dict] = None,
@@ -255,7 +261,7 @@ class DeepSpeedStrategy(BaseStrategy):
         activation_checkpointing: Optional[dict] = None,
         aio: Optional[dict] = None,
         train_micro_batch_size_per_gpu: Optional[int] = None,
-        gradient_accumulation_steps: int = 1,
+        gradient_accumulation_steps: Optional[int] = None,
         # disable the log printed by deepspeed
         steps_per_print: int = 10000000000000,
         # the following args are for BaseStrategy
@@ -270,7 +276,8 @@ class DeepSpeedStrategy(BaseStrategy):
         self.config = self._parse_config(config)
         if zero_optimization is not None:
             self.config['zero_optimization'] = zero_optimization
-        self.config['gradient_clipping'] = gradient_clipping
+        if gradient_clipping is not None:
+            self.config['gradient_clipping'] = gradient_clipping
         if fp16 is not None:
             self.config['fp16'] = fp16
         if bf16 is not None:
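The hunk above applies the same rule to gradient_clipping that zero_optimization and fp16 already follow: a keyword argument is copied into self.config only when it was actually passed. A standalone sketch of that merge rule, using a hypothetical merge_explicit_args helper that is not part of mmengine:

from typing import Any, Dict

def merge_explicit_args(config: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
    """Copy only explicitly-set (non-None) keyword arguments into the config."""
    merged = dict(config)
    for key, value in kwargs.items():
        if value is not None:          # None means "caller did not set it"
            merged[key] = value
    return merged

# Values already present in the config dict survive unless overridden.
cfg = merge_explicit_args(
    {'gradient_clipping': 0.5},
    gradient_clipping=None,            # not given -> keeps 0.5
    zero_optimization={'stage': 2},    # given -> added
)
assert cfg == {'gradient_clipping': 0.5, 'zero_optimization': {'stage': 2}}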
@@ -281,21 +288,14 @@ class DeepSpeedStrategy(BaseStrategy):
             self.config['activation_checkpointing'] = activation_checkpointing
         if aio is not None:
             self.config['aio'] = aio

-        if ('train_micro_batch_size_per_gpu' not in self.config
-                and 'train_batch_size' not in self.config):
-            assert train_micro_batch_size_per_gpu is not None, (
-                '`train_micro_batch_size_per_gpu` or `train_batch_size` '
-                'should be set!')
-            self.config['train_micro_batch_size_per_gpu'] = \
-                train_micro_batch_size_per_gpu
+        if train_micro_batch_size_per_gpu is not None:
+            self.config['train_micro_batch_size_per_gpu'] = \
+                train_micro_batch_size_per_gpu

-        self.config['gradient_accumulation_steps'] = \
-            gradient_accumulation_steps
+        if gradient_accumulation_steps is not None:
+            self.config['gradient_accumulation_steps'] = \
+                gradient_accumulation_steps
+        else:
+            self.config.setdefault('gradient_accumulation_steps', 1)
         self.config['steps_per_print'] = steps_per_print
         self._inputs_to_half = inputs_to_half
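In this last hunk the up-front assert is dropped: train_micro_batch_size_per_gpu is written only when given, and gradient_accumulation_steps keeps its old effective default of 1 through setdefault, so an explicit argument beats a value in the config, which in turn beats the fallback. A small sketch of that precedence with an illustrative helper, not the actual class:

from typing import Any, Dict, Optional

def resolve_grad_accum(config: Dict[str, Any],
                       gradient_accumulation_steps: Optional[int] = None) -> Dict[str, Any]:
    """Mirror the precedence in the diff: kwarg > config value > default of 1."""
    config = dict(config)
    if gradient_accumulation_steps is not None:
        config['gradient_accumulation_steps'] = gradient_accumulation_steps
    else:
        config.setdefault('gradient_accumulation_steps', 1)
    return config

assert resolve_grad_accum({})['gradient_accumulation_steps'] == 1
assert resolve_grad_accum({'gradient_accumulation_steps': 8})['gradient_accumulation_steps'] == 8
assert resolve_grad_accum({'gradient_accumulation_steps': 8}, 4)['gradient_accumulation_steps'] == 4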