mirror of https://github.com/open-mmlab/mmcv.git
Improve ability to view usage data during training (#1940)
* Improve ability to view usage data during training

Wrong labels in a dataset can cause problems such as gradient explosion during training, and a hook function can be set up to find them.

**Problem**: When a for loop is used to traverse the iterator of the PyTorch `DataLoader` class, a hook cannot obtain the information of the image currently being read, and thus cannot determine the source of the error.

**Solution**: Store the `data_batch` information on the `runner` during the train and val steps and expose it to the hook functions (a sketch of such a hook follows below).

* Update epoch_based_runner.py

* Update iter_based_runner.py

* Restrict the scope of `runner.data_batch`
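For illustration, here is a minimal sketch of a hook that uses the newly exposed `runner.data_batch` to report the batch being processed when the loss spikes. The hook name, the `loss_limit` threshold, and the `img_metas` key are assumptions for this sketch, not part of the commit:

from mmcv.runner import HOOKS, Hook


@HOOKS.register_module()
class BadBatchReporterHook(Hook):
    """Illustrative hook: report the current batch when the loss spikes.

    ``loss_limit`` and the ``img_metas`` key are assumptions of this
    sketch; adapt them to your model's outputs and dataset format.
    """

    def __init__(self, loss_limit=1e4):
        self.loss_limit = loss_limit

    def after_train_iter(self, runner):
        # runner.outputs is set by the runner after each train step.
        loss = runner.outputs.get('log_vars', {}).get('loss', 0.0)
        if loss > self.loss_limit:
            # runner.data_batch holds the batch of the current iteration;
            # it is set before 'before_train_iter' and deleted right after
            # 'after_train_iter', so it is only valid inside these hooks.
            batch = runner.data_batch
            metas = batch.get('img_metas') if isinstance(batch, dict) else None
            runner.logger.warning(
                f'loss={loss:.2f} exceeds {self.loss_limit}; '
                f'current batch metas: {metas}')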
parent 21bada3256
commit 544a6b1432
epoch_based_runner.py

@@ -45,10 +45,12 @@ class EpochBasedRunner(BaseRunner):
         self.call_hook('before_train_epoch')
         time.sleep(2)  # Prevent possible deadlock during epoch transition
         for i, data_batch in enumerate(self.data_loader):
+            self.data_batch = data_batch
             self._inner_iter = i
             self.call_hook('before_train_iter')
             self.run_iter(data_batch, train_mode=True, **kwargs)
             self.call_hook('after_train_iter')
+            del self.data_batch
             self._iter += 1
 
         self.call_hook('after_train_epoch')

@@ -62,11 +64,12 @@ class EpochBasedRunner(BaseRunner):
         self.call_hook('before_val_epoch')
         time.sleep(2)  # Prevent possible deadlock during epoch transition
         for i, data_batch in enumerate(self.data_loader):
+            self.data_batch = data_batch
             self._inner_iter = i
             self.call_hook('before_val_iter')
             self.run_iter(data_batch, train_mode=False)
             self.call_hook('after_val_iter')
-
+            del self.data_batch
         self.call_hook('after_val_epoch')
 
     def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
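With these changes, a hook can read `runner.data_batch` between `before_train_iter` and `after_train_iter` (and the val equivalents). Deleting the attribute right after the iteration's hooks have run, per the "Restrict the scope of runner.data_batch" commits, keeps hooks from reading a stale batch and keeps the runner from holding a reference to the batch between iterations.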
iter_based_runner.py

@@ -57,6 +57,7 @@ class IterBasedRunner(BaseRunner):
         self.data_loader = data_loader
         self._epoch = data_loader.epoch
         data_batch = next(data_loader)
+        self.data_batch = data_batch
         self.call_hook('before_train_iter')
         outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
         if not isinstance(outputs, dict):

@@ -65,6 +66,7 @@ class IterBasedRunner(BaseRunner):
             self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
         self.outputs = outputs
         self.call_hook('after_train_iter')
+        del self.data_batch
         self._inner_iter += 1
         self._iter += 1
 

@@ -74,6 +76,7 @@ class IterBasedRunner(BaseRunner):
         self.mode = 'val'
         self.data_loader = data_loader
         data_batch = next(data_loader)
+        self.data_batch = data_batch
         self.call_hook('before_val_iter')
         outputs = self.model.val_step(data_batch, **kwargs)
         if not isinstance(outputs, dict):

@@ -82,6 +85,7 @@ class IterBasedRunner(BaseRunner):
             self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
         self.outputs = outputs
         self.call_hook('after_val_iter')
+        del self.data_batch
         self._inner_iter += 1
 
     def run(self, data_loaders, workflow, max_iters=None, **kwargs):
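A minimal sketch of wiring such a hook into a runner; `model`, `optimizer`, `logger`, and `train_loader` are placeholders assumed to be built elsewhere, and the hook class is the illustrative one sketched above:

from mmcv.runner import EpochBasedRunner

# model, optimizer, logger and train_loader are placeholders.
runner = EpochBasedRunner(
    model,
    optimizer=optimizer,
    work_dir='./work_dir',
    logger=logger,
    max_epochs=12)
runner.register_hook(BadBatchReporterHook(loss_limit=1e4))
runner.run([train_loader], [('train', 1)])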