diff --git a/mmengine/model/wrappers/distributed.py b/mmengine/model/wrappers/distributed.py
index 36310900..741aafe2 100644
--- a/mmengine/model/wrappers/distributed.py
+++ b/mmengine/model/wrappers/distributed.py
@@ -119,10 +119,10 @@ class MMDistributedDataParallel(DistributedDataParallel):
         with optim_wrapper.optim_context(self):
             data = self.module.data_preprocessor(data, training=True)
             losses = self._run_forward(data, mode='loss')
-            if self.detect_anomalous_params:
-                detect_anomalous_params(losses, model=self)
         parsed_loss, log_vars = self.module.parse_losses(losses)
         optim_wrapper.update_params(parsed_loss)
+        if self.detect_anomalous_params:
+            detect_anomalous_params(parsed_loss, model=self)
         return log_vars
 
     def val_step(self, data: Union[dict, tuple, list]) -> list:
diff --git a/tests/test_model/test_wrappers/test_model_wrapper.py b/tests/test_model/test_wrappers/test_model_wrapper.py
index 4884b826..999f1fed 100644
--- a/tests/test_model/test_wrappers/test_model_wrapper.py
+++ b/tests/test_model/test_wrappers/test_model_wrapper.py
@@ -109,6 +109,16 @@ class TestDistributedDataParallel(MultiProcessTestCase):
         assert_allclose(all_grads[0], torch.zeros_like(all_grads[0]))
         assert_allclose(all_grads[1], torch.zeros_like(all_grads[0]))
+        # Test enabling detect_anomalous_params.
+        ddp_model = MMDistributedDataParallel(
+            module=model, detect_anomalous_params=True)
+        optimizer = SGD(ddp_model.parameters(), lr=0)
+        optim_wrapper = AmpOptimWrapper(
+            optimizer=optimizer, accumulative_counts=3)
+        inputs = torch.randn(1, 3, 1, 1).cuda() * self.rank * 255
+        data = dict(inputs=inputs, data_sample=None)
+        res = ddp_model.train_step(data, optim_wrapper=optim_wrapper)['loss']
+
     def test_val_step(self):
         self._init_dist_env(self.rank, self.world_size)
         model = ToyModel()
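
Note on the change: `detect_anomalous_params` works by walking the autograd graph hanging off a single loss tensor, so it needs the scalar `parsed_loss` rather than the raw `losses` dict that `_run_forward` returns (a plain dict has no `grad_fn` to traverse). The sketch below is a minimal, illustrative approximation of that traversal, not mmengine's actual implementation; the helper name `find_unused_params` and the toy `ModuleDict` model are invented here for demonstration.

```python
# Illustrative sketch only; the real helper lives in
# mmengine.model.wrappers.utils.detect_anomalous_params.
import torch
import torch.nn as nn


def find_unused_params(loss: torch.Tensor, model: nn.Module) -> set:
    """Names of trainable params absent from ``loss``'s autograd graph."""
    params_in_graph = set()
    visited = set()

    def traverse(grad_fn):
        if grad_fn is None or grad_fn in visited:
            return
        visited.add(grad_fn)
        # AccumulateGrad leaf nodes expose the parameter they feed into.
        if hasattr(grad_fn, 'variable'):
            params_in_graph.add(grad_fn.variable)
        for parent, _ in grad_fn.next_functions:
            traverse(parent)

    # ``loss`` must be a single tensor with a ``grad_fn``; a dict of raw
    # losses has none, which is why the diff switches the call to
    # ``parsed_loss``.
    traverse(loss.grad_fn)
    return {
        name for name, param in model.named_parameters()
        if param.requires_grad and param not in params_in_graph
    }


if __name__ == '__main__':
    toy = nn.ModuleDict({
        'used': nn.Linear(2, 2),
        'unused': nn.Linear(2, 2),  # never touched in the forward pass
    })
    loss = toy['used'](torch.randn(1, 2)).sum()
    print(find_unused_params(loss, toy))
    # -> {'unused.weight', 'unused.bias'}
```

The traversal only reads graph topology (`next_functions` and the leaf `variable` attributes), which stays reachable through `parsed_loss` even after `backward()` has freed intermediate buffers, so running the check after `optim_wrapper.update_params(parsed_loss)` appears safe; the new test exercises exactly this ordering under `AmpOptimWrapper` with gradient accumulation.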