Use build_runner (#54)
* Use build_runner in train api * Support iter in eval_hook * Add runner section * Add test_eval_hook * Pin mmcv version in install docs * Replace max_iters with max_epochs * Set by_epoch=True as default * Remove trailing space * Replace DeprecationWarning with UserWarning * pre-commit * Fix testspull/69/head^2
parent
f7a916f309
commit
4f4f2957ef
|
@ -3,4 +3,4 @@ optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
|
||||||
optimizer_config = dict(grad_clip=None)
|
optimizer_config = dict(grad_clip=None)
|
||||||
# learning policy
|
# learning policy
|
||||||
lr_config = dict(policy='step', step=[100, 150])
|
lr_config = dict(policy='step', step=[100, 150])
|
||||||
total_epochs = 200
|
runner = dict(type='EpochBasedRunner', max_epochs=200)
|
||||||
|
|
|
@ -14,4 +14,4 @@ lr_config = dict(
|
||||||
warmup='constant',
|
warmup='constant',
|
||||||
warmup_iters=5000,
|
warmup_iters=5000,
|
||||||
)
|
)
|
||||||
total_epochs = 300
|
runner = dict(type='EpochBasedRunner', max_epochs=300)
|
||||||
|
|
|
@ -9,4 +9,4 @@ lr_config = dict(
|
||||||
warmup_iters=2500,
|
warmup_iters=2500,
|
||||||
warmup_ratio=0.25,
|
warmup_ratio=0.25,
|
||||||
step=[30, 60, 90])
|
step=[30, 60, 90])
|
||||||
total_epochs = 100
|
runner = dict(type='EpochBasedRunner', max_epochs=100)
|
||||||
|
|
|
@ -9,4 +9,4 @@ lr_config = dict(
|
||||||
warmup='linear',
|
warmup='linear',
|
||||||
warmup_iters=2500,
|
warmup_iters=2500,
|
||||||
warmup_ratio=0.25)
|
warmup_ratio=0.25)
|
||||||
total_epochs = 100
|
runner = dict(type='EpochBasedRunner', max_epochs=100)
|
||||||
|
|
|
@ -3,4 +3,4 @@ optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
|
||||||
optimizer_config = dict(grad_clip=None)
|
optimizer_config = dict(grad_clip=None)
|
||||||
# learning policy
|
# learning policy
|
||||||
lr_config = dict(policy='step', step=[30, 60, 90])
|
lr_config = dict(policy='step', step=[30, 60, 90])
|
||||||
total_epochs = 100
|
runner = dict(type='EpochBasedRunner', max_epochs=100)
|
||||||
|
|
|
@ -3,4 +3,4 @@ optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
|
||||||
optimizer_config = dict(grad_clip=None)
|
optimizer_config = dict(grad_clip=None)
|
||||||
# learning policy
|
# learning policy
|
||||||
lr_config = dict(policy='step', step=[40, 80, 120])
|
lr_config = dict(policy='step', step=[40, 80, 120])
|
||||||
total_epochs = 140
|
runner = dict(type='EpochBasedRunner', max_epochs=140)
|
||||||
|
|
|
@ -3,4 +3,4 @@ optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
|
||||||
optimizer_config = dict(grad_clip=None)
|
optimizer_config = dict(grad_clip=None)
|
||||||
# learning policy
|
# learning policy
|
||||||
lr_config = dict(policy='CosineAnnealing', min_lr=0)
|
lr_config = dict(policy='CosineAnnealing', min_lr=0)
|
||||||
total_epochs = 100
|
runner = dict(type='EpochBasedRunner', max_epochs=100)
|
||||||
|
|
|
@ -3,4 +3,4 @@ optimizer = dict(type='SGD', lr=0.045, momentum=0.9, weight_decay=0.00004)
|
||||||
optimizer_config = dict(grad_clip=None)
|
optimizer_config = dict(grad_clip=None)
|
||||||
# learning policy
|
# learning policy
|
||||||
lr_config = dict(policy='step', gamma=0.98, step=1)
|
lr_config = dict(policy='step', gamma=0.98, step=1)
|
||||||
total_epochs = 300
|
runner = dict(type='EpochBasedRunner', max_epochs=300)
|
||||||
|
|
|
@ -47,7 +47,7 @@ log_config = dict(
|
||||||
])
|
])
|
||||||
# yapf:enable
|
# yapf:enable
|
||||||
# runtime settings
|
# runtime settings
|
||||||
total_epochs = 20
|
runner = dict(type='EpochBasedRunner', max_epochs=20)
|
||||||
dist_params = dict(backend='nccl')
|
dist_params = dict(backend='nccl')
|
||||||
log_level = 'INFO'
|
log_level = 'INFO'
|
||||||
work_dir = './work_dirs/mnist/'
|
work_dir = './work_dirs/mnist/'
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
|
|
||||||
- Python 3.6+
|
- Python 3.6+
|
||||||
- PyTorch 1.3+
|
- PyTorch 1.3+
|
||||||
- [mmcv](https://github.com/open-mmlab/mmcv)
|
- [mmcv](https://github.com/open-mmlab/mmcv) 1.1.4+
|
||||||
|
|
||||||
|
|
||||||
### Install mmclassification
|
### Install mmclassification
|
||||||
|
|
|
@ -75,7 +75,7 @@ optimizer_config = dict(grad_clip=None)
|
||||||
lr_config = dict(
|
lr_config = dict(
|
||||||
policy='step',
|
policy='step',
|
||||||
step=[15])
|
step=[15])
|
||||||
total_epochs = 20
|
runner = dict(type='EpochBasedRunner', max_epochs=200)
|
||||||
log_config = dict(interval=100)
|
log_config = dict(interval=100)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
import random
|
import random
|
||||||
|
import warnings
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
||||||
from mmcv.runner import DistSamplerSeedHook, EpochBasedRunner, build_optimizer
|
from mmcv.runner import DistSamplerSeedHook, build_optimizer, build_runner
|
||||||
|
|
||||||
from mmcls.core import (DistEvalHook, DistOptimizerHook, EvalHook,
|
from mmcls.core import (DistEvalHook, DistOptimizerHook, EvalHook,
|
||||||
Fp16OptimizerHook)
|
Fp16OptimizerHook)
|
||||||
|
@ -70,12 +71,26 @@ def train_model(model,
|
||||||
|
|
||||||
# build runner
|
# build runner
|
||||||
optimizer = build_optimizer(model, cfg.optimizer)
|
optimizer = build_optimizer(model, cfg.optimizer)
|
||||||
runner = EpochBasedRunner(
|
|
||||||
model,
|
if cfg.get('runner') is None:
|
||||||
optimizer=optimizer,
|
cfg.runner = {
|
||||||
work_dir=cfg.work_dir,
|
'type': 'EpochBasedRunner',
|
||||||
logger=logger,
|
'max_epochs': cfg.total_epochs
|
||||||
meta=meta)
|
}
|
||||||
|
warnings.warn(
|
||||||
|
'config is now expected to have a `runner` section, '
|
||||||
|
'please set `runner` in your config.', UserWarning)
|
||||||
|
|
||||||
|
runner = build_runner(
|
||||||
|
cfg.runner,
|
||||||
|
default_args=dict(
|
||||||
|
model=model,
|
||||||
|
batch_processor=None,
|
||||||
|
optimizer=optimizer,
|
||||||
|
work_dir=cfg.work_dir,
|
||||||
|
logger=logger,
|
||||||
|
meta=meta))
|
||||||
|
|
||||||
# an ugly walkaround to make the .log and .log.json filenames the same
|
# an ugly walkaround to make the .log and .log.json filenames the same
|
||||||
runner.timestamp = timestamp
|
runner.timestamp = timestamp
|
||||||
|
|
||||||
|
@ -107,6 +122,7 @@ def train_model(model,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
round_up=False)
|
round_up=False)
|
||||||
eval_cfg = cfg.get('evaluation', {})
|
eval_cfg = cfg.get('evaluation', {})
|
||||||
|
eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
|
||||||
eval_hook = DistEvalHook if distributed else EvalHook
|
eval_hook = DistEvalHook if distributed else EvalHook
|
||||||
runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
|
runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
|
||||||
|
|
||||||
|
@ -114,4 +130,4 @@ def train_model(model,
|
||||||
runner.resume(cfg.resume_from)
|
runner.resume(cfg.resume_from)
|
||||||
elif cfg.load_from:
|
elif cfg.load_from:
|
||||||
runner.load_checkpoint(cfg.load_from)
|
runner.load_checkpoint(cfg.load_from)
|
||||||
runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
|
runner.run(data_loaders, cfg.workflow)
|
||||||
|
|
|
@ -12,21 +12,30 @@ class EvalHook(Hook):
|
||||||
interval (int): Evaluation interval (by epochs). Default: 1.
|
interval (int): Evaluation interval (by epochs). Default: 1.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dataloader, interval=1, **eval_kwargs):
|
def __init__(self, dataloader, interval=1, by_epoch=True, **eval_kwargs):
|
||||||
if not isinstance(dataloader, DataLoader):
|
if not isinstance(dataloader, DataLoader):
|
||||||
raise TypeError('dataloader must be a pytorch DataLoader, but got'
|
raise TypeError('dataloader must be a pytorch DataLoader, but got'
|
||||||
f' {type(dataloader)}')
|
f' {type(dataloader)}')
|
||||||
self.dataloader = dataloader
|
self.dataloader = dataloader
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
self.eval_kwargs = eval_kwargs
|
self.eval_kwargs = eval_kwargs
|
||||||
|
self.by_epoch = by_epoch
|
||||||
|
|
||||||
def after_train_epoch(self, runner):
|
def after_train_epoch(self, runner):
|
||||||
if not self.every_n_epochs(runner, self.interval):
|
if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
|
||||||
return
|
return
|
||||||
from mmcls.apis import single_gpu_test
|
from mmcls.apis import single_gpu_test
|
||||||
results = single_gpu_test(runner.model, self.dataloader, show=False)
|
results = single_gpu_test(runner.model, self.dataloader, show=False)
|
||||||
self.evaluate(runner, results)
|
self.evaluate(runner, results)
|
||||||
|
|
||||||
|
def after_train_iter(self, runner):
|
||||||
|
if self.by_epoch or not self.every_n_iters(runner, self.interval):
|
||||||
|
return
|
||||||
|
from mmcls.apis import single_gpu_test
|
||||||
|
runner.log_buffer.clear()
|
||||||
|
results = single_gpu_test(runner.model, self.dataloader, show=False)
|
||||||
|
self.evaluate(runner, results)
|
||||||
|
|
||||||
def evaluate(self, runner, results):
|
def evaluate(self, runner, results):
|
||||||
eval_res = self.dataloader.dataset.evaluate(
|
eval_res = self.dataloader.dataset.evaluate(
|
||||||
results, logger=runner.logger, **self.eval_kwargs)
|
results, logger=runner.logger, **self.eval_kwargs)
|
||||||
|
@ -51,6 +60,7 @@ class DistEvalHook(EvalHook):
|
||||||
dataloader,
|
dataloader,
|
||||||
interval=1,
|
interval=1,
|
||||||
gpu_collect=False,
|
gpu_collect=False,
|
||||||
|
by_epoch=True,
|
||||||
**eval_kwargs):
|
**eval_kwargs):
|
||||||
if not isinstance(dataloader, DataLoader):
|
if not isinstance(dataloader, DataLoader):
|
||||||
raise TypeError('dataloader must be a pytorch DataLoader, but got '
|
raise TypeError('dataloader must be a pytorch DataLoader, but got '
|
||||||
|
@ -58,10 +68,11 @@ class DistEvalHook(EvalHook):
|
||||||
self.dataloader = dataloader
|
self.dataloader = dataloader
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
self.gpu_collect = gpu_collect
|
self.gpu_collect = gpu_collect
|
||||||
|
self.by_epoch = by_epoch
|
||||||
self.eval_kwargs = eval_kwargs
|
self.eval_kwargs = eval_kwargs
|
||||||
|
|
||||||
def after_train_epoch(self, runner):
|
def after_train_epoch(self, runner):
|
||||||
if not self.every_n_epochs(runner, self.interval):
|
if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
|
||||||
return
|
return
|
||||||
from mmcls.apis import multi_gpu_test
|
from mmcls.apis import multi_gpu_test
|
||||||
results = multi_gpu_test(
|
results = multi_gpu_test(
|
||||||
|
@ -72,3 +83,17 @@ class DistEvalHook(EvalHook):
|
||||||
if runner.rank == 0:
|
if runner.rank == 0:
|
||||||
print('\n')
|
print('\n')
|
||||||
self.evaluate(runner, results)
|
self.evaluate(runner, results)
|
||||||
|
|
||||||
|
def after_train_iter(self, runner):
|
||||||
|
if self.by_epoch or not self.every_n_iters(runner, self.interval):
|
||||||
|
return
|
||||||
|
from mmcls.apis import multi_gpu_test
|
||||||
|
runner.log_buffer.clear()
|
||||||
|
results = multi_gpu_test(
|
||||||
|
runner.model,
|
||||||
|
self.dataloader,
|
||||||
|
tmpdir=osp.join(runner.work_dir, '.eval_hook'),
|
||||||
|
gpu_collect=self.gpu_collect)
|
||||||
|
if runner.rank == 0:
|
||||||
|
print('\n')
|
||||||
|
self.evaluate(runner, results)
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
mmcv
|
mmcv>=1.1.4
|
||||||
torch
|
torch
|
||||||
torchvision
|
torchvision
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
matplotlib
|
matplotlib
|
||||||
mmcv
|
mmcv>=1.1.4
|
||||||
numpy
|
numpy
|
||||||
|
|
|
@ -0,0 +1,197 @@
|
||||||
|
import logging
|
||||||
|
import tempfile
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import mmcv.runner
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from mmcv.runner import obj_from_dict
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
|
from mmcls.apis import single_gpu_test
|
||||||
|
from mmcls.core import DistEvalHook, EvalHook
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleDataset(Dataset):
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
results = dict(img=torch.tensor([1]), img_metas=dict())
|
||||||
|
return results
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(ExampleModel, self).__init__()
|
||||||
|
self.test_cfg = None
|
||||||
|
self.conv = nn.Conv2d(3, 3, 3)
|
||||||
|
|
||||||
|
def forward(self, img, img_metas, test_mode=False, **kwargs):
|
||||||
|
return img
|
||||||
|
|
||||||
|
def train_step(self, data_batch, optimizer):
|
||||||
|
loss = self.forward(**data_batch)
|
||||||
|
return dict(loss=loss)
|
||||||
|
|
||||||
|
|
||||||
|
def test_iter_eval_hook():
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
test_dataset = ExampleModel()
|
||||||
|
data_loader = [
|
||||||
|
DataLoader(
|
||||||
|
test_dataset,
|
||||||
|
batch_size=1,
|
||||||
|
sampler=None,
|
||||||
|
num_worker=0,
|
||||||
|
shuffle=False)
|
||||||
|
]
|
||||||
|
EvalHook(data_loader, by_epoch=False)
|
||||||
|
|
||||||
|
test_dataset = ExampleDataset()
|
||||||
|
test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
|
||||||
|
loader = DataLoader(test_dataset, batch_size=1)
|
||||||
|
model = ExampleModel()
|
||||||
|
data_loader = DataLoader(
|
||||||
|
test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
|
||||||
|
optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
|
||||||
|
optimizer = obj_from_dict(optim_cfg, torch.optim,
|
||||||
|
dict(params=model.parameters()))
|
||||||
|
|
||||||
|
# test EvalHook
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
eval_hook = EvalHook(data_loader, by_epoch=False)
|
||||||
|
runner = mmcv.runner.IterBasedRunner(
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
work_dir=tmpdir,
|
||||||
|
logger=logging.getLogger(),
|
||||||
|
max_iters=1)
|
||||||
|
runner.register_hook(eval_hook)
|
||||||
|
runner.run([loader], [('train', 1)], 1)
|
||||||
|
test_dataset.evaluate.assert_called_with([torch.tensor([1])],
|
||||||
|
logger=runner.logger)
|
||||||
|
|
||||||
|
|
||||||
|
def test_epoch_eval_hook():
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
test_dataset = ExampleModel()
|
||||||
|
data_loader = [
|
||||||
|
DataLoader(
|
||||||
|
test_dataset,
|
||||||
|
batch_size=1,
|
||||||
|
sampler=None,
|
||||||
|
num_worker=0,
|
||||||
|
shuffle=False)
|
||||||
|
]
|
||||||
|
EvalHook(data_loader, by_epoch=True)
|
||||||
|
|
||||||
|
test_dataset = ExampleDataset()
|
||||||
|
test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
|
||||||
|
loader = DataLoader(test_dataset, batch_size=1)
|
||||||
|
model = ExampleModel()
|
||||||
|
data_loader = DataLoader(
|
||||||
|
test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
|
||||||
|
optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
|
||||||
|
optimizer = obj_from_dict(optim_cfg, torch.optim,
|
||||||
|
dict(params=model.parameters()))
|
||||||
|
|
||||||
|
# test EvalHook with interval
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
eval_hook = EvalHook(data_loader, by_epoch=True, interval=2)
|
||||||
|
runner = mmcv.runner.EpochBasedRunner(
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
work_dir=tmpdir,
|
||||||
|
logger=logging.getLogger(),
|
||||||
|
max_epochs=2)
|
||||||
|
runner.register_hook(eval_hook)
|
||||||
|
runner.run([loader], [('train', 1)])
|
||||||
|
test_dataset.evaluate.assert_called_once_with([torch.tensor([1])],
|
||||||
|
logger=runner.logger)
|
||||||
|
|
||||||
|
|
||||||
|
def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
|
||||||
|
results = single_gpu_test(model, data_loader)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
@patch('mmcls.apis.multi_gpu_test', multi_gpu_test)
|
||||||
|
def test_dist_eval_hook():
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
test_dataset = ExampleModel()
|
||||||
|
data_loader = [
|
||||||
|
DataLoader(
|
||||||
|
test_dataset,
|
||||||
|
batch_size=1,
|
||||||
|
sampler=None,
|
||||||
|
num_worker=0,
|
||||||
|
shuffle=False)
|
||||||
|
]
|
||||||
|
DistEvalHook(data_loader, by_epoch=False)
|
||||||
|
|
||||||
|
test_dataset = ExampleDataset()
|
||||||
|
test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
|
||||||
|
loader = DataLoader(test_dataset, batch_size=1)
|
||||||
|
model = ExampleModel()
|
||||||
|
data_loader = DataLoader(
|
||||||
|
test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
|
||||||
|
optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
|
||||||
|
optimizer = obj_from_dict(optim_cfg, torch.optim,
|
||||||
|
dict(params=model.parameters()))
|
||||||
|
|
||||||
|
# test DistEvalHook
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
eval_hook = DistEvalHook(data_loader, by_epoch=False)
|
||||||
|
runner = mmcv.runner.IterBasedRunner(
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
work_dir=tmpdir,
|
||||||
|
logger=logging.getLogger(),
|
||||||
|
max_iters=1)
|
||||||
|
runner.register_hook(eval_hook)
|
||||||
|
runner.run([loader], [('train', 1)])
|
||||||
|
test_dataset.evaluate.assert_called_with([torch.tensor([1])],
|
||||||
|
logger=runner.logger)
|
||||||
|
|
||||||
|
|
||||||
|
@patch('mmcls.apis.multi_gpu_test', multi_gpu_test)
|
||||||
|
def test_dist_eval_hook_epoch():
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
test_dataset = ExampleModel()
|
||||||
|
data_loader = [
|
||||||
|
DataLoader(
|
||||||
|
test_dataset,
|
||||||
|
batch_size=1,
|
||||||
|
sampler=None,
|
||||||
|
num_worker=0,
|
||||||
|
shuffle=False)
|
||||||
|
]
|
||||||
|
DistEvalHook(data_loader)
|
||||||
|
|
||||||
|
test_dataset = ExampleDataset()
|
||||||
|
test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
|
||||||
|
loader = DataLoader(test_dataset, batch_size=1)
|
||||||
|
model = ExampleModel()
|
||||||
|
data_loader = DataLoader(
|
||||||
|
test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
|
||||||
|
optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
|
||||||
|
optimizer = obj_from_dict(optim_cfg, torch.optim,
|
||||||
|
dict(params=model.parameters()))
|
||||||
|
|
||||||
|
# test DistEvalHook
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
eval_hook = DistEvalHook(data_loader, by_epoch=True, interval=2)
|
||||||
|
runner = mmcv.runner.EpochBasedRunner(
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
work_dir=tmpdir,
|
||||||
|
logger=logging.getLogger(),
|
||||||
|
max_epochs=2)
|
||||||
|
runner.register_hook(eval_hook)
|
||||||
|
runner.run([loader], [('train', 1)])
|
||||||
|
test_dataset.evaluate.assert_called_with([torch.tensor([1])],
|
||||||
|
logger=runner.logger)
|
Loading…
Reference in New Issue