mirror of https://github.com/open-mmlab/mmengine.git
[Fix] Fix detect_anomalous_params (#588)
commit 89146a5b90
parent 1f63d2436c
@@ -119,10 +119,10 @@ class MMDistributedDataParallel(DistributedDataParallel):
         with optim_wrapper.optim_context(self):
             data = self.module.data_preprocessor(data, training=True)
             losses = self._run_forward(data, mode='loss')
-            if self.detect_anomalous_params:
-                detect_anomalous_params(losses, model=self)
         parsed_loss, log_vars = self.module.parse_losses(losses)
         optim_wrapper.update_params(parsed_loss)
+        if self.detect_anomalous_params:
+            detect_anomalous_params(parsed_loss, model=self)
         return log_vars
 
     def val_step(self, data: Union[dict, tuple, list]) -> list:
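The fix passes `parsed_loss`, the single loss tensor returned by `self.module.parse_losses(losses)`, to `detect_anomalous_params` instead of the raw `losses` dict: a graph-based detector needs a tensor with a `grad_fn` to walk, not a dict of tensors. Below is a minimal sketch of that kind of autograd-graph walk; `find_used_params` is a hypothetical helper for illustration, not mmengine's API.

import torch
import torch.nn as nn


def find_used_params(loss: torch.Tensor, model: nn.Module) -> set:
    """Collect every parameter reachable from ``loss`` via the autograd
    graph (hypothetical helper, not mmengine's implementation)."""
    used, seen = set(), set()
    stack = [loss.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        if hasattr(fn, 'variable'):  # AccumulateGrad nodes hold leaf tensors
            used.add(fn.variable)
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return used


model = nn.Linear(2, 2)
loss = model(torch.randn(1, 2)).sum()  # a single tensor, like parsed_loss
used = find_used_params(loss, model)
unused = [n for n, p in model.named_parameters() if p not in used]
print('anomalous params:', unused)  # -> [] (every param reaches the loss)

Note that calling the helper on the `losses` dict would fail immediately (a dict has no `grad_fn`), which is why the reduced `parsed_loss` is the right argument.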
@@ -109,6 +109,16 @@ class TestDistributedDataParallel(MultiProcessTestCase):
         assert_allclose(all_grads[0], torch.zeros_like(all_grads[0]))
         assert_allclose(all_grads[1], torch.zeros_like(all_grads[0]))
 
+        # Test enable detect_anomalous_params.
+        ddp_model = MMDistributedDataParallel(
+            module=model, detect_anomalous_params=True)
+        optimizer = SGD(ddp_model.parameters(), lr=0)
+        optim_wrapper = AmpOptimWrapper(
+            optimizer=optimizer, accumulative_counts=3)
+        inputs = torch.randn(1, 3, 1, 1).cuda() * self.rank * 255
+        data = dict(inputs=inputs, data_sample=None)
+        res = ddp_model.train_step(data, optim_wrapper=optim_wrapper)['loss']
+
     def test_val_step(self):
         self._init_dist_env(self.rank, self.world_size)
         model = ToyModel()
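The new test enables `detect_anomalous_params=True` together with an `AmpOptimWrapper` and runs one `train_step`, exercising the fixed code path. For illustration, the condition such a check is meant to flag is a parameter that never reaches the loss; the sketch below reuses the hypothetical `find_used_params` helper from above, and the module name is likewise illustrative.

import torch
import torch.nn as nn


class DeadBranchModel(nn.Module):
    """Toy module where ``self.unused`` never contributes to the loss."""

    def __init__(self):
        super().__init__()
        self.used = nn.Linear(3, 1)
        self.unused = nn.Linear(3, 1)  # never called in forward

    def forward(self, x):
        return self.used(x)


model = DeadBranchModel()
loss = model(torch.randn(4, 3)).sum()
used = find_used_params(loss, model)  # helper from the sketch above
print([n for n, p in model.named_parameters() if p not in used])
# -> ['unused.weight', 'unused.bias']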