[Fix] Fix detect_anomalous_params (#588)

This commit is contained in:
Mashiro 2022-10-08 19:48:35 +08:00 committed by GitHub
parent 1f63d2436c
commit 89146a5b90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 2 deletions

View File

@ -119,10 +119,10 @@ class MMDistributedDataParallel(DistributedDataParallel):
with optim_wrapper.optim_context(self):
data = self.module.data_preprocessor(data, training=True)
losses = self._run_forward(data, mode='loss')
if self.detect_anomalous_params:
detect_anomalous_params(losses, model=self)
parsed_loss, log_vars = self.module.parse_losses(losses)
optim_wrapper.update_params(parsed_loss)
if self.detect_anomalous_params:
detect_anomalous_params(parsed_loss, model=self)
return log_vars
def val_step(self, data: Union[dict, tuple, list]) -> list:

View File

@ -109,6 +109,16 @@ class TestDistributedDataParallel(MultiProcessTestCase):
assert_allclose(all_grads[0], torch.zeros_like(all_grads[0]))
assert_allclose(all_grads[1], torch.zeros_like(all_grads[0]))
# Test enable detect_anomalous_params.
ddp_model = MMDistributedDataParallel(
module=model, detect_anomalous_params=True)
optimizer = SGD(ddp_model.parameters(), lr=0)
optim_wrapper = AmpOptimWrapper(
optimizer=optimizer, accumulative_counts=3)
inputs = torch.randn(1, 3, 1, 1).cuda() * self.rank * 255
data = dict(inputs=inputs, data_sample=None)
res = ddp_model.train_step(data, optim_wrapper=optim_wrapper)['loss']
def test_val_step(self):
self._init_dist_env(self.rank, self.world_size)
model = ToyModel()