[Bugs] Fix bugs in colo optimwrapper (#1426)
parent 26f22ed283
commit 5a90805b1e
@@ -94,20 +94,25 @@ class ColossalAIOptimWrapper(OptimWrapper):
 
     def __init__(self,
                  optimizer: torch.optim.Optimizer,
-                 booster: Booster,
+                 booster: Optional[Booster] = None,
                  accumulative_counts: int = 1):
         super().__init__(optimizer, accumulative_counts=accumulative_counts)
         self.booster = booster
 
     @contextmanager
     def optim_context(self, model: nn.Module):
+        assert isinstance(self.booster, Booster), \
+            'Please set the booster attribute before using ' \
+            '`ColossalAIOptimWrapper`.'
         if self.booster.plugin.support_no_sync():
-            sync_context = self.booster.no_sync(model, self.optimizer)
+            no_sync_context = self.booster.no_sync(model, self.optimizer)
         else:
             yield
             return
-        if not self.should_sync():
-            with sync_context:
+        if self.should_sync():
+            yield
+        else:
+            with no_sync_context:
                 yield
 
     def backward(self, loss: torch.Tensor, **kwargs) -> None:
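Why `optim_context` was rewritten: it is a `@contextmanager` generator, and the old version only yielded on the `not self.should_sync()` branch, so every synchronizing step raised `RuntimeError: generator didn't yield` on entry; the fix yields on every path and renames the misleading `sync_context` (which actually held `booster.no_sync(...)`) to `no_sync_context`. A minimal, self-contained repro of the old control-flow bug, with the ColossalAI pieces stubbed out (`should_sync` here is a plain flag standing in for the real method):

    from contextlib import contextmanager

    @contextmanager
    def old_optim_context(should_sync: bool):
        # Old control flow, booster.no_sync(...) elided: when should_sync
        # is True, no branch yields, so entering the context fails.
        if not should_sync:
            yield

    try:
        with old_optim_context(should_sync=True):
            pass
    except RuntimeError as err:
        print(err)  # -> generator didn't yield

With the fix, a gradient-accumulation loop can wrap each step in `with optim_wrapper.optim_context(model): ...`, and gradient synchronization is skipped via `booster.no_sync` on all but the accumulation-boundary steps.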
@@ -305,7 +310,6 @@ class ColossalAIStrategy(BaseStrategy):
         # optim_wrapper is required by booster
         if optim_wrapper is not None and isinstance(optim_wrapper, dict):
             optim_wrapper.setdefault('type', 'ColossalAIOptimWrapper')
-            optim_wrapper.setdefault('booster', self.booster)
             optim_wrapper_type = OPTIM_WRAPPERS.get(optim_wrapper['type'])
             if optim_wrapper_type is None:
                 raise ValueError(f'Failed to find {optim_wrapper["type"]} in '
@@ -318,6 +322,7 @@ class ColossalAIStrategy(BaseStrategy):
                                  '`ColossalAIOptimWrapper` (or subclass), but got '
                                  f'{optim_wrapper_type}')
             optim_wrapper = self.build_optim_wrapper(optim_wrapper, model)
+            optim_wrapper.booster = self.booster  # type: ignore
 
         if optim_wrapper is not None:
             self.model, self.optim_wrapper = self._wrap(
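Tying the two strategy hunks together: the dict config no longer carries the live `Booster` (the `setdefault('booster', self.booster)` call is gone); the instance is assigned to the wrapper only after `build_optim_wrapper` returns, which is why `booster` became `Optional` in `__init__` and `optim_context` gained the assert above. A sketch of the config this path consumes; the optimizer settings are assumed examples, not values from the patch:

    # Hypothetical user config: plain data only, no Booster object inside.
    optim_wrapper = dict(
        type='ColossalAIOptimWrapper',
        optimizer=dict(type='SGD', lr=0.01),
        accumulative_counts=4,
    )
    # The strategy builds the wrapper from this dict and attaches the
    # live Booster afterwards, mirroring the two lines in the hunk above:
    #     optim_wrapper = self.build_optim_wrapper(optim_wrapper, model)
    #     optim_wrapper.booster = self.booster  # type: ignore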