mirror of https://github.com/open-mmlab/mmcv.git
Change the epoch runner to use the data_loader from attributes rather than args. (#483)
parent
e92f826abc
commit
17e4732c49
|
@ -21,10 +21,10 @@ class EpochBasedRunner(BaseRunner):
|
||||||
self.model.train()
|
self.model.train()
|
||||||
self.mode = 'train'
|
self.mode = 'train'
|
||||||
self.data_loader = data_loader
|
self.data_loader = data_loader
|
||||||
self._max_iters = self._max_epochs * len(data_loader)
|
self._max_iters = self._max_epochs * len(self.data_loader)
|
||||||
self.call_hook('before_train_epoch')
|
self.call_hook('before_train_epoch')
|
||||||
time.sleep(2) # Prevent possible deadlock during epoch transition
|
time.sleep(2) # Prevent possible deadlock during epoch transition
|
||||||
for i, data_batch in enumerate(data_loader):
|
for i, data_batch in enumerate(self.data_loader):
|
||||||
self._inner_iter = i
|
self._inner_iter = i
|
||||||
self.call_hook('before_train_iter')
|
self.call_hook('before_train_iter')
|
||||||
if self.batch_processor is None:
|
if self.batch_processor is None:
|
||||||
|
@ -52,7 +52,7 @@ class EpochBasedRunner(BaseRunner):
|
||||||
self.data_loader = data_loader
|
self.data_loader = data_loader
|
||||||
self.call_hook('before_val_epoch')
|
self.call_hook('before_val_epoch')
|
||||||
time.sleep(2) # Prevent possible deadlock during epoch transition
|
time.sleep(2) # Prevent possible deadlock during epoch transition
|
||||||
for i, data_batch in enumerate(data_loader):
|
for i, data_batch in enumerate(self.data_loader):
|
||||||
self._inner_iter = i
|
self._inner_iter = i
|
||||||
self.call_hook('before_val_iter')
|
self.call_hook('before_val_iter')
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|
Loading…
Reference in New Issue