Add torch.no_grad() decorator to the whole val workflow (#777)

pull/795/head
Zhiyuan Chen 2021-01-08 13:16:49 +08:00 committed by GitHub
parent 477f0c0a39
commit daab369e99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 2 additions and 2 deletions

View File

@@ -54,6 +54,7 @@ class EpochBasedRunner(BaseRunner):
self.call_hook('after_train_epoch')
self._epoch += 1
@torch.no_grad()
def val(self, data_loader, **kwargs):
self.model.eval()
self.mode = 'val'
@@ -63,8 +64,7 @@ class EpochBasedRunner(BaseRunner):
for i, data_batch in enumerate(self.data_loader):
self._inner_iter = i
self.call_hook('before_val_iter')
with torch.no_grad():
self.run_iter(data_batch, train_mode=False)
self.run_iter(data_batch, train_mode=False)
self.call_hook('after_val_iter')
self.call_hook('after_val_epoch')