[Bugs] Fix bugs in colo optimwrapper (#1426)

pull/1428/head
whcao 2023-11-14 17:09:26 +08:00 committed by GitHub
parent 26f22ed283
commit 5a90805b1e
1 changed file with 10 additions and 5 deletions

@@ -94,20 +94,25 @@ class ColossalAIOptimWrapper(OptimWrapper):
 
     def __init__(self,
                  optimizer: torch.optim.Optimizer,
-                 booster: Booster,
+                 booster: Optional[Booster] = None,
                  accumulative_counts: int = 1):
         super().__init__(optimizer, accumulative_counts=accumulative_counts)
         self.booster = booster
 
     @contextmanager
     def optim_context(self, model: nn.Module):
+        assert isinstance(self.booster, Booster), \
+            'Please set the booster attribute before using ' \
+            '`ColossalAIOptimWrapper`.'
         if self.booster.plugin.support_no_sync():
-            sync_context = self.booster.no_sync(model, self.optimizer)
+            no_sync_context = self.booster.no_sync(model, self.optimizer)
         else:
             yield
             return
-        if not self.should_sync():
-            with sync_context:
+        if self.should_sync():
+            yield
+        else:
+            with no_sync_context:
                 yield
 
     def backward(self, loss: torch.Tensor, **kwargs) -> None:
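
The rewritten `optim_context` now yields on every path. Before this patch, when the plugin supported `no_sync` but `should_sync()` returned True, the generator fell off the end without yielding, so entering the context raised `RuntimeError: generator didn't yield`; the misnamed `sync_context` (actually a no-sync context) is also renamed. A minimal usage sketch, assuming a built `optim_wrapper` with `accumulative_counts > 1` and a `model`/`data_loader` created elsewhere; the call pattern follows mmengine's `OptimWrapper` API, with `update_params` kept inside the context so the backward pass (routed through the `backward` override shown above) is covered by `no_sync`:

for data in data_loader:
    with optim_wrapper.optim_context(model):
        loss = model(data)
        # Steps the optimizer only every `accumulative_counts` iterations;
        # on the intermediate iterations `no_sync` suppresses the gradient
        # all-reduce, which is the point of this context manager.
        optim_wrapper.update_params(loss)
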
@@ -305,7 +310,6 @@ class ColossalAIStrategy(BaseStrategy):
         # optim_wrapper is required by booster
         if optim_wrapper is not None and isinstance(optim_wrapper, dict):
             optim_wrapper.setdefault('type', 'ColossalAIOptimWrapper')
-            optim_wrapper.setdefault('booster', self.booster)
             optim_wrapper_type = OPTIM_WRAPPERS.get(optim_wrapper['type'])
             if optim_wrapper_type is None:
                 raise ValueError(f'Failed to find {optim_wrapper["type"]} in '
@@ -318,6 +322,7 @@ class ColossalAIStrategy(BaseStrategy):
                     '`ColossalAIOptimWrapper` (or subclass), but got '
                     f'{optim_wrapper_type}')
             optim_wrapper = self.build_optim_wrapper(optim_wrapper, model)
+            optim_wrapper.booster = self.booster  # type: ignore
 
         if optim_wrapper is not None:
             self.model, self.optim_wrapper = self._wrap(
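
These two strategy hunks move the booster hand-off: instead of injecting the live `Booster` into the optim_wrapper config dict via `setdefault`, the wrapper is built first and the booster attached afterwards as an attribute, which is why `__init__` above now accepts `booster=None` and `optim_context` asserts before first use. A small sketch of the resulting contract, assuming `optimizer`, `booster`, and `model` are built elsewhere:

optim_wrapper = ColossalAIOptimWrapper(optimizer)  # booster defaults to None
try:
    with optim_wrapper.optim_context(model):
        pass
except AssertionError as err:
    # 'Please set the booster attribute before using `ColossalAIOptimWrapper`.'
    print(err)

# What the strategy now does once the wrapper (and booster) exist:
optim_wrapper.booster = booster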