[Bugs] Fix bugs in colo optimwrapper (#1426)
parent 26f22ed283
commit 5a90805b1e
@@ -94,20 +94,25 @@ class ColossalAIOptimWrapper(OptimWrapper):
 
     def __init__(self,
                  optimizer: torch.optim.Optimizer,
-                 booster: Booster,
+                 booster: Optional[Booster] = None,
                  accumulative_counts: int = 1):
         super().__init__(optimizer, accumulative_counts=accumulative_counts)
         self.booster = booster
 
     @contextmanager
     def optim_context(self, model: nn.Module):
+        assert isinstance(self.booster, Booster), \
+            'Please set the booster attribute before using ' \
+            '`ColossalAIOptimWrapper`.'
         if self.booster.plugin.support_no_sync():
-            sync_context = self.booster.no_sync(model, self.optimizer)
+            no_sync_context = self.booster.no_sync(model, self.optimizer)
         else:
             yield
             return
-        if not self.should_sync():
-            with sync_context:
+        if self.should_sync():
+            yield
+        else:
+            with no_sync_context:
                 yield
 
     def backward(self, loss: torch.Tensor, **kwargs) -> None:
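The core of this hunk is the control flow of `optim_context`. The old code only yielded on gradient-accumulation steps (`not self.should_sync()`), so on a step where gradients should actually be synchronized the `@contextmanager` generator never yielded and entering the context raised `RuntimeError: generator didn't yield`; the variable is also renamed from `sync_context` to `no_sync_context` to match what `booster.no_sync()` returns. Below is a minimal, self-contained sketch of that failure mode and of the fixed flow, using hypothetical function names and plain Python rather than the real MMEngine/ColossalAI classes:

from contextlib import contextmanager

@contextmanager
def old_optim_context(should_sync):
    # Old flow: only accumulation steps yield, so a sync step leaves the
    # generator without a yield and __enter__ raises RuntimeError.
    if not should_sync:
        yield

@contextmanager
def new_optim_context(should_sync):
    # New flow: every path yields exactly once; in the real wrapper the
    # accumulation branch additionally runs under booster.no_sync(...).
    if should_sync:
        yield
    else:
        yield

try:
    with old_optim_context(should_sync=True):
        pass
except RuntimeError as err:
    print(err)  # "generator didn't yield"

with new_optim_context(should_sync=True):
    print('sync step now enters the context normally')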
@@ -305,7 +310,6 @@ class ColossalAIStrategy(BaseStrategy):
         # optim_wrapper is required by booster
         if optim_wrapper is not None and isinstance(optim_wrapper, dict):
             optim_wrapper.setdefault('type', 'ColossalAIOptimWrapper')
-            optim_wrapper.setdefault('booster', self.booster)
             optim_wrapper_type = OPTIM_WRAPPERS.get(optim_wrapper['type'])
             if optim_wrapper_type is None:
                 raise ValueError(f'Failed to find {optim_wrapper["type"]} in '
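This removal pairs with the next hunk: because `booster` is now optional in `ColossalAIOptimWrapper.__init__`, the strategy no longer injects the live `Booster` instance into the config dict before building the wrapper. A hypothetical user-side config sketch under that assumption (the optimizer settings are illustrative only, not prescriptive):

# Hypothetical config sketch: no 'booster' key is needed in the optim_wrapper
# dict anymore; the strategy attaches the Booster to the built wrapper later.
optim_wrapper = dict(
    type='ColossalAIOptimWrapper',
    optimizer=dict(type='AdamW', lr=1e-3),  # illustrative optimizer config
    accumulative_counts=4,
)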
@@ -318,6 +322,7 @@ class ColossalAIStrategy(BaseStrategy):
                     '`ColossalAIOptimWrapper` (or subclass), but got '
                     f'{optim_wrapper_type}')
             optim_wrapper = self.build_optim_wrapper(optim_wrapper, model)
+            optim_wrapper.booster = self.booster  # type: ignore
 
         if optim_wrapper is not None:
             self.model, self.optim_wrapper = self._wrap(
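Taken together with the previous hunk, the wiring becomes: build the wrapper from a plain config (with `booster` defaulting to `None`), then attach the live booster on the built instance; the new assert in `optim_context` catches the case where that attachment is forgotten. A self-contained sketch of that order with stand-in classes (hypothetical names, no ColossalAI or MMEngine required):

from contextlib import contextmanager

class FakeBooster:
    """Stand-in for colossalai.booster.Booster in this sketch."""

class FakeOptimWrapper:
    """Stand-in for ColossalAIOptimWrapper: built first, booster attached later."""

    def __init__(self, booster=None):
        self.booster = booster

    @contextmanager
    def optim_context(self):
        # Mirrors the new assert: fail loudly if the booster was never attached.
        assert isinstance(self.booster, FakeBooster), \
            'Please set the booster attribute before using the wrapper.'
        yield

wrapper = FakeOptimWrapper()          # built from config, booster unknown yet
try:
    with wrapper.optim_context():
        pass
except AssertionError as err:
    print(err)                        # guard fires: booster was never attached

wrapper.booster = FakeBooster()       # strategy attaches the booster post-build
with wrapper.optim_context():
    print('booster attached, context is usable')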