Improve ability to view usage data during training (#1940)

* Improve ability to view usage data during training

Incorrectly labeled samples in a dataset can cause problems such as gradient explosion during training, and such labeling errors can be caught by registering a hook function.
**Problem**: When a for loop traverses the iterator of the PyTorch `DataLoader` class, a hook cannot see which images are in the batch currently being processed, so it cannot pinpoint the source of the error.
**Solution**: Store the `data_batch` information on the `runner` during the train/val step so that it is available to hook functions.
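
With `data_batch` exposed on the runner, a hook can inspect the batch being processed when something goes wrong. The sketch below is illustrative only: it assumes the standard mmcv `Hook`/`HOOKS` registry API, and `BadSampleHook` and the `'img_metas'` batch key are hypothetical names that depend on how the dataset packs its batches.

```python
import torch
from mmcv.runner import HOOKS, Hook


@HOOKS.register_module()
class BadSampleHook(Hook):  # hypothetical name, for illustration only
    """Report the current batch when the training loss becomes NaN/Inf."""

    def after_train_iter(self, runner):
        # `runner.data_batch` is set just before `before_train_iter` and
        # deleted right after `after_train_iter`, so it is still valid here.
        loss = runner.outputs['loss']
        if torch.isnan(loss) or torch.isinf(loss):
            # Assumes the batch is a dict carrying an 'img_metas' field,
            # as is common for detection/segmentation datasets.
            metas = runner.data_batch.get('img_metas', 'unavailable')
            runner.logger.warning(
                f'Abnormal loss {loss.item()} at iter {runner.iter}; '
                f'current batch meta: {metas}')
```

Such a hook can be registered directly with `runner.register_hook(BadSampleHook())`, or through a `custom_hooks` config entry in downstream codebases that support it.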

* Update epoch_based_runner.py

* Update iter_based_runner.py

* Restrict the scope of runner.data_batch

* Restrict the scope of runner.data_batch

* Restrict the scope of runner.data_batch
wlf-darkmatter 2022-05-25 19:33:27 +08:00 committed by GitHub
parent 21bada3256
commit 544a6b1432
2 changed files with 8 additions and 1 deletion

epoch_based_runner.py

@@ -45,10 +45,12 @@ class EpochBasedRunner(BaseRunner):
         self.call_hook('before_train_epoch')
         time.sleep(2)  # Prevent possible deadlock during epoch transition
         for i, data_batch in enumerate(self.data_loader):
+            self.data_batch = data_batch
             self._inner_iter = i
             self.call_hook('before_train_iter')
             self.run_iter(data_batch, train_mode=True, **kwargs)
             self.call_hook('after_train_iter')
+            del self.data_batch
             self._iter += 1
         self.call_hook('after_train_epoch')
@@ -62,11 +64,12 @@ class EpochBasedRunner(BaseRunner):
         self.call_hook('before_val_epoch')
         time.sleep(2)  # Prevent possible deadlock during epoch transition
         for i, data_batch in enumerate(self.data_loader):
+            self.data_batch = data_batch
             self._inner_iter = i
             self.call_hook('before_val_iter')
             self.run_iter(data_batch, train_mode=False)
             self.call_hook('after_val_iter')
+            del self.data_batch
         self.call_hook('after_val_epoch')
     def run(self, data_loaders, workflow, max_epochs=None, **kwargs):

iter_based_runner.py

@@ -57,6 +57,7 @@ class IterBasedRunner(BaseRunner):
         self.data_loader = data_loader
         self._epoch = data_loader.epoch
         data_batch = next(data_loader)
+        self.data_batch = data_batch
         self.call_hook('before_train_iter')
         outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
         if not isinstance(outputs, dict):
@@ -65,6 +66,7 @@ class IterBasedRunner(BaseRunner):
             self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
         self.outputs = outputs
         self.call_hook('after_train_iter')
+        del self.data_batch
         self._inner_iter += 1
         self._iter += 1
@@ -74,6 +76,7 @@ class IterBasedRunner(BaseRunner):
         self.mode = 'val'
         self.data_loader = data_loader
         data_batch = next(data_loader)
+        self.data_batch = data_batch
         self.call_hook('before_val_iter')
         outputs = self.model.val_step(data_batch, **kwargs)
         if not isinstance(outputs, dict):
@@ -82,6 +85,7 @@ class IterBasedRunner(BaseRunner):
             self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
         self.outputs = outputs
         self.call_hook('after_val_iter')
+        del self.data_batch
         self._inner_iter += 1
     def run(self, data_loaders, workflow, max_iters=None, **kwargs):
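
Because `data_batch` is deleted as soon as the `after_*_iter` hooks have run (the "Restrict the scope of runner.data_batch" commits above), the attribute exists only inside the iteration window. A minimal sketch, again assuming the mmcv `Hook` API and a hypothetical hook name, of reading it defensively from a stage outside that window:

```python
from mmcv.runner import HOOKS, Hook


@HOOKS.register_module()
class EpochSummaryHook(Hook):  # hypothetical name, for illustration only
    def after_train_epoch(self, runner):
        # By the time this hook runs, the last `data_batch` has already
        # been deleted, so read it defensively rather than assuming it exists.
        last_batch = getattr(runner, 'data_batch', None)
        if last_batch is None:
            runner.logger.info('data_batch is out of scope after the epoch')
```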