[Refactor] Refactor data_batch type and remove cur_dataloader in runner. (#171)

* [Refactor] Refactor data_batch type.

* fix sampler

* [Refactor] Remove cur_dataloader in runner.

* fix set_epoch
RangiLyu 2022-04-08 15:57:10 +08:00 committed by GitHub
parent ab8b51682f
commit 59cc08e3ac
22 changed files with 186 additions and 196 deletions

View File

@@ -2,18 +2,13 @@
 import itertools
 import math
 from typing import Iterator, Optional, Sized
-# from mmengine.dist import get_dist_info, sync_random_seed
-from unittest.mock import MagicMock
 
 import torch
 from torch.utils.data import Sampler
 
+from mmengine.dist import get_dist_info, sync_random_seed
 from mmengine.registry import DATA_SAMPLERS
 
-# TODO, need to remove those lines after implementing dist module
-get_dist_info = MagicMock(return_value=(0, 1))
-sync_random_seed = MagicMock(return_value=0)
-
 
 @DATA_SAMPLERS.register_module()
 class DefaultSampler(Sampler):
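With the MagicMock stubs gone, DefaultSampler reads real process-group state. A minimal sketch of what the two calls return in a single-process run, which is exactly the situation the removed mocks hard-coded:

from mmengine.dist import get_dist_info, sync_random_seed

rank, world_size = get_dist_info()  # (0, 1) when no process group is up
seed = sync_random_seed()           # one seed value, shared by every rank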

View File

@@ -1,13 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import random
-from typing import Any, Sequence, Tuple
+from typing import Sequence
 
 import numpy as np
 import torch
 
-from .base_data_element import BaseDataElement
-
-DATA_BATCH = Sequence[Tuple[Any, BaseDataElement]]
+DATA_BATCH = Sequence[dict]
 
 
 def worker_init_fn(worker_id: int, num_workers: int, rank: int,
@@ -36,10 +34,10 @@ def pseudo_collate(data_batch: DATA_BATCH) -> DATA_BATCH:
     nothing just returns ``data_batch``.
 
     Args:
-        data_batch (Sequence[Tuple[Any, BaseDataElement]]): Batch of data from
+        data_batch (Sequence[dict]): Batch of data from
             dataloader.
 
     Returns:
-        Sequence[Tuple[Any, BaseDataElement]]: Return input ``data_batch``.
+        Sequence[dict]: Return input ``data_batch``.
     """
     return data_batch
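The alias change above is the heart of this refactor: a batch is now a list of dicts keyed by inputs and data_sample rather than a list of (input, data_sample) tuples. A minimal sketch of the new layout with toy values (the pseudo_collate import path is an assumption):

import torch

from mmengine.data import pseudo_collate

data_batch = [
    dict(inputs=torch.rand(3, 10, 10), data_sample=dict(label=i))
    for i in range(4)
]
assert pseudo_collate(data_batch) is data_batch  # identity collate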

View File

@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, Iterator, List, Optional, Sequence, Tuple, Union
+from typing import Iterator, List, Optional, Sequence, Union
 
 from mmengine.data import BaseDataElement
 from ..registry.root import METRICS
@@ -37,23 +37,25 @@ class Evaluator:
         for metric in self.metrics:
             metric.dataset_meta = dataset_meta
 
-    def process(self, data_batch: Sequence[Tuple[Any, BaseDataElement]],
+    def process(self, data_batch: Sequence[dict],
                 predictions: Sequence[BaseDataElement]):
         """Convert ``BaseDataSample`` to dict and invoke process method of each
        metric.
 
         Args:
-            data_batch (Sequence[Tuple[Any, BaseDataElement]]): A batch of data
-                from the dataloader.
+            data_batch (Sequence[dict]): A batch of data from the dataloader.
             predictions (Sequence[BaseDataElement]): A batch of outputs from
                 the model.
         """
         _data_batch = []
-        for input, data in data_batch:
-            if isinstance(data, BaseDataElement):
-                _data_batch.append((input, data.to_dict()))
+        for data in data_batch:
+            if isinstance(data['data_sample'], BaseDataElement):
+                _data_batch.append(
+                    dict(
+                        inputs=data['inputs'],
+                        data_sample=data['data_sample'].to_dict()))
             else:
-                _data_batch.append((input, data))
+                _data_batch.append(data)
         _predictions = []
         for pred in predictions:
             if isinstance(pred, BaseDataElement):
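For illustration, the per-item conversion that Evaluator.process now applies, shown standalone with toy values: a BaseDataElement data sample is serialized with to_dict(), anything else passes through unchanged.

from mmengine.data import BaseDataElement

item = dict(inputs=[0.0], data_sample=BaseDataElement(label=1))
converted = dict(
    inputs=item['inputs'],
    data_sample=item['data_sample'].to_dict())  # plain dict for the metrics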

View File

@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import warnings
 from abc import ABCMeta, abstractmethod
-from typing import Any, List, Optional, Sequence, Tuple, Union
+from typing import Any, List, Optional, Sequence, Union
 
 from mmengine.dist import (broadcast_object_list, collect_results,
                            is_main_process)
@@ -50,15 +50,14 @@ class BaseMetric(metaclass=ABCMeta):
         self._dataset_meta = dataset_meta
 
     @abstractmethod
-    def process(self, data_batch: Sequence[Tuple[Any, dict]],
+    def process(self, data_batch: Sequence[dict],
                 predictions: Sequence[dict]) -> None:
        """Process one batch of data samples and predictions. The processed
         results should be stored in ``self.results``, which will be used to
         compute the metrics when all batches have been processed.
 
         Args:
-            data_batch (Sequence[Tuple[Any, dict]]): A batch of data
-                from the dataloader.
+            data_batch (Sequence[dict]): A batch of data from the dataloader.
             predictions (Sequence[dict]): A batch of outputs from
                 the model.
         """

View File

@@ -2,15 +2,14 @@
 import os.path as osp
 import warnings
 from pathlib import Path
-from typing import Any, Optional, Sequence, Tuple, Union
+from typing import Optional, Sequence, Union
 
-from mmengine.data import BaseDataElement
 from mmengine.dist import master_only
 from mmengine.fileio import FileClient
 from mmengine.registry import HOOKS
 from .hook import Hook
 
-DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]]
+DATA_BATCH = Optional[Sequence[dict]]
 
 
 @HOOKS.register_module()
@@ -185,8 +184,8 @@ class CheckpointHook(Hook):
         Args:
             runner (Runner): The runner of the training process.
             batch_idx (int): The index of the current batch in the train loop.
-            data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data
-                from dataloader. Defaults to None.
+            data_batch (Sequence[dict], optional): Data from dataloader.
+                Defaults to None.
             outputs (dict, optional): Outputs from model.
                 Defaults to None.
         """

View File

@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, Optional, Sequence, Tuple, Union
+from typing import Optional, Sequence, Union
 
 import torch
 
@@ -7,7 +7,7 @@ from mmengine.data import BaseDataElement
 from mmengine.registry import HOOKS
 from .hook import Hook
 
-DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]]
+DATA_BATCH = Optional[Sequence[dict]]
 
 
 @HOOKS.register_module()
@@ -46,8 +46,8 @@ class EmptyCacheHook(Hook):
         Args:
             runner (Runner): The runner of the training process.
             batch_idx (int): The index of the current batch in the loop.
-            data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data
-                from dataloader. Defaults to None.
+            data_batch (Sequence[dict], optional): Data from dataloader.
+                Defaults to None.
             outputs (dict or sequence, optional): Outputs from model.
                 Defaults to None.
             mode (str): Current mode of runner. Defaults to 'train'.

View File

@@ -1,9 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, Optional, Sequence, Tuple, Union
+from typing import Optional, Sequence, Union
 
 from mmengine.data import BaseDataElement
 
-DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]]
+DATA_BATCH = Optional[Sequence[dict]]
 
 
 class Hook:
@@ -174,8 +174,8 @@ class Hook:
         Args:
             runner (Runner): The runner of the training process.
             batch_idx (int): The index of the current batch in the train loop.
-            data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
-                Data from dataloader. Defaults to None.
+            data_batch (Sequence[dict], optional): Data from dataloader.
+                Defaults to None.
         """
         self._before_iter(
             runner, batch_idx=batch_idx, data_batch=data_batch, mode='train')
@@ -190,8 +190,8 @@ class Hook:
         Args:
             runner (Runner): The runner of the validation process.
             batch_idx (int): The index of the current batch in the val loop.
-            data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
-                Data from dataloader. Defaults to None.
+            data_batch (Sequence[dict], optional): Data from dataloader.
+                Defaults to None.
         """
         self._before_iter(
             runner, batch_idx=batch_idx, data_batch=data_batch, mode='val')
@@ -206,8 +206,8 @@ class Hook:
         Args:
             runner (Runner): The runner of the testing process.
             batch_idx (int): The index of the current batch in the test loop.
-            data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
-                Data from dataloader. Defaults to None.
+            data_batch (Sequence[dict], optional): Data from dataloader.
+                Defaults to None.
         """
         self._before_iter(
             runner, batch_idx=batch_idx, data_batch=data_batch, mode='test')
@@ -223,8 +223,8 @@ class Hook:
         Args:
             runner (Runner): The runner of the training process.
             batch_idx (int): The index of the current batch in the train loop.
-            data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
-                Data from dataloader. Defaults to None.
+            data_batch (Sequence[dict], optional): Data from dataloader.
+                Defaults to None.
             outputs (dict, optional): Outputs from model.
                 Defaults to None.
         """
@@ -247,8 +247,8 @@ class Hook:
         Args:
             runner (Runner): The runner of the validation process.
             batch_idx (int): The index of the current batch in the val loop.
-            data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
-                Data from dataloader. Defaults to None.
+            data_batch (Sequence[dict], optional): Data from dataloader.
+                Defaults to None.
             outputs (dict or sequence, optional): Outputs from
                 model. Defaults to None.
         """
@@ -271,8 +271,8 @@ class Hook:
         Args:
             runner (Runner): The runner of the training process.
             batch_idx (int): The index of the current batch in the test loop.
-            data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
-                Data from dataloader. Defaults to None.
+            data_batch (Sequence[dict], optional): Data from dataloader.
+                Defaults to None.
             outputs (dict, optional): Outputs from model.
                 Defaults to None.
         """
@@ -317,8 +317,8 @@ class Hook:
             runner (Runner): The runner of the training, validation or testing
                 process.
             batch_idx (int): The index of the current batch in the loop.
-            data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
-                Data from dataloader. Defaults to None.
+            data_batch (Sequence[dict], optional): Data from dataloader.
+                Defaults to None.
             mode (str): Current mode of runner. Defaults to 'train'.
         """
         pass
@@ -337,8 +337,8 @@ class Hook:
             runner (Runner): The runner of the training, validation or testing
                 process.
             batch_idx (int): The index of the current batch in the loop.
-            data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
-                Data from dataloader. Defaults to None.
+            data_batch (Sequence[dict], optional): Data from dataloader.
+                Defaults to None.
             outputs (Sequence[BaseDataElement], optional): Outputs from model.
                 Defaults to None.
             mode (str): Current mode of runner. Defaults to 'train'.
@@ -387,19 +387,19 @@ class Hook:
         """
         return (runner.iter + 1) % n == 0 if n > 0 else False
 
-    def end_of_epoch(self, runner, batch_idx: int) -> bool:
+    def end_of_epoch(self, dataloader, batch_idx: int) -> bool:
         """Check whether the current iteration reaches the last iteration of
-        current dataloader.
+        the dataloader.
 
         Args:
-            runner (Runner): The runner of the training, validation or testing
-                process.
+            dataloader (Dataloader): The dataloader of the training,
+                validation or testing process.
             batch_idx (int): The index of the current batch in the loop.
 
         Returns:
             bool: Whether reaches the end of current epoch or not.
         """
-        return batch_idx + 1 == len(runner.cur_dataloader)
+        return batch_idx + 1 == len(dataloader)
 
     def is_last_train_epoch(self, runner) -> bool:
         """Test whether current epoch is the last train epoch.

View File

@@ -1,12 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import time
-from typing import Any, Optional, Sequence, Tuple, Union
+from typing import Optional, Sequence, Union
 
 from mmengine.data import BaseDataElement
 from mmengine.registry import HOOKS
 from .hook import Hook
 
-DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]]
+DATA_BATCH = Optional[Sequence[dict]]
 
 
 @HOOKS.register_module()
@@ -37,8 +37,8 @@ class IterTimerHook(Hook):
         Args:
             runner (Runner): The runner of the training process.
             batch_idx (int): The index of the current batch in the loop.
-            data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data
-                from dataloader. Defaults to None.
+            data_batch (Sequence[dict], optional): Data from dataloader.
+                Defaults to None.
             mode (str): Current mode of runner. Defaults to 'train'.
         """
         # TODO: update for new logging system
@@ -57,8 +57,8 @@ class IterTimerHook(Hook):
         Args:
             runner (Runner): The runner of the training process.
             batch_idx (int): The index of the current batch in the loop.
-            data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data
-                from dataloader. Defaults to None.
+            data_batch (Sequence[dict], optional): Data from dataloader.
+                Defaults to None.
             outputs (dict or sequence, optional): Outputs from model. Defaults
                 to None.
             mode (str): Current mode of runner. Defaults to 'train'.

View File

@@ -5,18 +5,17 @@ import os
 import os.path as osp
 from collections import OrderedDict
 from pathlib import Path
-from typing import Any, Optional, Sequence, Tuple, Union
+from typing import Optional, Sequence, Union
 
 import torch
 
-from mmengine.data import BaseDataElement
 from mmengine.dist import master_only
 from mmengine.fileio import FileClient
 from mmengine.hooks import Hook
 from mmengine.registry import HOOKS
 from mmengine.utils import is_tuple_of, scandir
 
-DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]]
+DATA_BATCH = Optional[Sequence[dict]]
 
 
 @HOOKS.register_module()
@@ -183,15 +182,16 @@ class LoggerHook(Hook):
         Args:
             runner (Runner): The runner of the training process.
             batch_idx (int): The index of the current batch in the train loop.
-            data_batch (Sequence[BaseDataElement], optional): Data from
-                dataloader. Defaults to None.
+            data_batch (Sequence[dict], optional): Data from dataloader.
+                Defaults to None.
             outputs (dict, optional): Outputs from model.
                 Defaults to None.
         """
         self._inner_iter = batch_idx
         if runner.meta is not None and 'exp_name' in runner.meta:
             if (self.every_n_iters(runner, self.interval_exp_name)) or (
-                    self.by_epoch and self.end_of_epoch(runner, batch_idx)):
+                    self.by_epoch and self.end_of_epoch(
+                        runner.train_loop.dataloader, batch_idx)):
                 exp_info = f'Exp name: {runner.meta["exp_name"]}'
                 runner.logger.info(exp_info)
         if self.by_epoch and self.every_n_inner_iters(batch_idx,
@@ -199,7 +199,8 @@ class LoggerHook(Hook):
             self._log_train(runner)
         elif not self.by_epoch and self.every_n_iters(runner, self.interval):
             self._log_train(runner)
-        elif self.end_of_epoch(runner, batch_idx) and not self.ignore_last:
+        elif self.end_of_epoch(runner.train_loop.dataloader,
+                               batch_idx) and not self.ignore_last:
             # `runner.max_iters` may not be divisible by `self.interval`. if
             # `self.ignore_last==True`, the log of remaining iterations will
             # be recorded (Epoch [4][1000/1007], the logs of 998-1007
@@ -271,7 +272,7 @@ class LoggerHook(Hook):
         # by iter: Iter [100/100000]
         if self.by_epoch:
             log_str = f'Epoch [{cur_epoch}]' \
-                      f'[{cur_iter}/{len(runner.cur_dataloader)}]\t'
+                      f'[{cur_iter}/{len(runner.train_loop.dataloader)}]\t'
         else:
             log_str = f'Iter [{cur_iter}/{runner.train_loop.max_iters}]\t'
         log_str += f'{lr_momentum_str}, '
@@ -311,7 +312,7 @@ class LoggerHook(Hook):
         """
         tag = self._collect_info(runner, 'val')
         # Compatible with function `log` https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/logger/text.py # noqa E501
-        eval_iter = len(runner.cur_dataloader)
+        eval_iter = len(runner.val_loop.dataloader)
         cur_iter = self._get_iter(runner)
         cur_epoch = self._get_epoch(runner, 'val')
 
         # val/test time

View File

@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import os.path as osp
-from typing import Any, Optional, Sequence, Tuple
+from typing import Optional, Sequence, Tuple
 
 import cv2
 import numpy as np
@@ -41,26 +41,25 @@ class NaiveVisualizationHook(Hook):
             self,
             runner,
             batch_idx: int,
-            data_batch: Optional[Sequence[Tuple[Any, BaseDataElement]]] = None,
+            data_batch: Optional[Sequence[dict]] = None,
             outputs: Optional[Sequence[BaseDataElement]] = None) -> None:
         """Show or Write the predicted results.
 
         Args:
             runner (Runner): The runner of the training process.
             batch_idx (int): The index of the current batch in the test loop.
-            data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data
+            data_batch (Sequence[dict], optional): Data
                 from dataloader. Defaults to None.
             outputs (Sequence[BaseDataElement], optional): Outputs from model.
                 Defaults to None.
         """
         if self.every_n_iters(runner, self._interval):
-            inputs, data_samples = data_batch  # type: ignore
-            inputs = tensor2imgs(inputs,
-                                 **data_samples[0].get('img_norm_cfg', dict()))
-            for input, data_sample, output in zip(
-                    inputs,
-                    data_samples,  # type: ignore
-                    outputs):  # type: ignore
+            for data, output in zip(data_batch, outputs):  # type: ignore
+                input = data['inputs']
+                data_sample = data['data_sample']
+                input = tensor2imgs(input,
+                                    **data_sample.get('img_norm_cfg',
+                                                      dict()))[0]
                 # TODO We will implement a function to revert the augmentation
                 # in the future.
                 ori_shape = (data_sample.ori_width, data_sample.ori_height)

View File

@@ -1,16 +1,15 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import logging
-from typing import Any, List, Optional, Sequence, Tuple
+from typing import List, Optional, Sequence
 
 import torch
 from torch.nn.parameter import Parameter
 from torch.nn.utils import clip_grad
 
-from mmengine.data import BaseDataElement
 from mmengine.registry import HOOKS
 from .hook import Hook
 
-DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]]
+DATA_BATCH = Optional[Sequence[dict]]
 
 
 @HOOKS.register_module()
@@ -77,10 +76,9 @@ class OptimizerHook(Hook):
         Args:
             runner (Runner): The runner of the training process.
             batch_idx (int): The index of the current batch in the train loop.
-            data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data
-                from dataloader. In order to keep this interface consistent
-                with other hooks, we keep ``data_batch`` here.
-                Defaults to None.
+            data_batch (Sequence[dict], optional): Data from dataloader.
+                In order to keep this interface consistent with other hooks,
+                we keep ``data_batch`` here. Defaults to None.
             outputs (dict, optional): Outputs from model.
                 In order to keep this interface consistent with other hooks,
                 we keep ``outputs`` here. Defaults to None.

View File

@@ -1,11 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, Optional, Sequence, Tuple
+from typing import Optional, Sequence
 
-from mmengine.data import BaseDataElement
 from mmengine.registry import HOOKS
 from .hook import Hook
 
-DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]]
+DATA_BATCH = Optional[Sequence[dict]]
 
 
 @HOOKS.register_module()
@@ -25,10 +24,9 @@ class ParamSchedulerHook(Hook):
         Args:
             runner (Runner): The runner of the training process.
             batch_idx (int): The index of the current batch in the train loop.
-            data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data
-                from dataloader. In order to keep this interface consistent
-                with other hooks, we keep ``data_batch`` here.
-                Defaults to None.
+            data_batch (Sequence[dict], optional): Data from dataloader.
+                In order to keep this interface consistent with other hooks,
+                we keep ``data_batch`` here. Defaults to None.
             outputs (dict, optional): Outputs from model.
                 In order to keep this interface consistent with other hooks, we
                 keep ``data_batch`` here. Defaults to None.

View File

@@ -20,9 +20,11 @@ class DistSamplerSeedHook(Hook):
         Args:
             runner (Runner): The runner of the training process.
         """
-        if hasattr(runner.cur_dataloader.sampler, 'set_epoch'):
+        if hasattr(runner.train_loop.dataloader.sampler, 'set_epoch'):
             # in case the data loader uses `SequentialSampler` in Pytorch
-            runner.cur_dataloader.sampler.set_epoch(runner.epoch)
-        elif hasattr(runner.cur_dataloader.batch_sampler.sampler, 'set_epoch'):
+            runner.train_loop.dataloader.sampler.set_epoch(runner.epoch)
+        elif hasattr(runner.train_loop.dataloader.batch_sampler.sampler,
+                     'set_epoch'):
             # batch sampler in pytorch warps the sampler as its attributes.
-            runner.cur_dataloader.batch_sampler.sampler.set_epoch(runner.epoch)
+            runner.train_loop.dataloader.batch_sampler.sampler.set_epoch(
+                runner.epoch)
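What the hook automates, sketched with plain PyTorch (toy dataset; DistributedSampler stands in for any sampler exposing set_epoch): reseeding the shuffle each epoch keeps the permutation different across epochs yet identical on every rank.

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8).float())
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
loader = DataLoader(dataset, sampler=sampler, batch_size=2)

for epoch in range(2):
    sampler.set_epoch(epoch)  # the call DistSamplerSeedHook issues per epoch
    for batch in loader:
        pass  # training step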

View File

@@ -25,9 +25,6 @@ class BaseLoop(metaclass=ABCMeta):
         else:
             self.dataloader = dataloader
 
-        # TODO, used by `end_of_epoch` of `Hook`
-        self._runner.data_loader = self.dataloader
-
     @property
     def runner(self):
         return self._runner

View File

@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, Dict, List, Sequence, Tuple, Union
+import warnings
+from typing import Dict, List, Sequence, Union
 
 import torch
 from torch.utils.data import DataLoader
 
-from mmengine.data import BaseDataElement
 from mmengine.evaluator import Evaluator
 from mmengine.registry import LOOPS
 from mmengine.utils import is_list_of
@@ -40,7 +40,6 @@ class EpochBasedTrainLoop(BaseLoop):
 
     def run(self) -> None:
         """Launch training."""
-        self.runner.cur_dataloader = self.dataloader
         self.runner.call_hook('before_train')
 
         while self.runner._epoch < self._max_epochs:
@@ -62,13 +61,11 @@ class EpochBasedTrainLoop(BaseLoop):
         self.runner.call_hook('after_train_epoch')
         self.runner._epoch += 1
 
-    def run_iter(self, idx,
-                 data_batch: Sequence[Tuple[Any, BaseDataElement]]) -> None:
+    def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
         """Iterate one min-batch.
 
         Args:
-            data_batch (Sequence[Tuple[Any, BaseDataElement]]): Batch of data
-                from dataloader.
+            data_batch (Sequence[dict]): Batch of data from dataloader.
         """
         self.runner.call_hook(
             'before_train_iter', batch_idx=idx, data_batch=data_batch)
@@ -112,7 +109,6 @@ class IterBasedTrainLoop(BaseLoop):
 
     def run(self) -> None:
         """Launch training."""
-        self.runner.cur_dataloader = self.dataloader
         self.runner.call_hook('before_train')
         # In iteration-based training loop, we treat the whole training process
         # as a big epoch and execute the corresponding hook.
@@ -130,13 +126,11 @@ class IterBasedTrainLoop(BaseLoop):
         self.runner.call_hook('after_train_epoch')
         self.runner.call_hook('after_train')
 
-    def run_iter(self, data_batch: Sequence[Tuple[Any,
-                                                  BaseDataElement]]) -> None:
+    def run_iter(self, data_batch: Sequence[dict]) -> None:
         """Iterate one mini-batch.
 
         Args:
-            data_batch (Sequence[Tuple[Any, BaseDataElement]]): Batch of data
-                from dataloader.
+            data_batch (Sequence[dict]): Batch of data from dataloader.
         """
         self.runner.call_hook(
             'before_train_iter',
@@ -180,12 +174,17 @@ class ValLoop(BaseLoop):
             self.evaluator = runner.build_evaluator(evaluator)  # type: ignore
         else:
             self.evaluator = evaluator  # type: ignore
+        if hasattr(self.dataloader.dataset, 'metainfo'):
+            self.evaluator.dataset_meta = self.dataloader.dataset.metainfo
+        else:
+            warnings.warn(
+                f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
+                'metainfo. ``dataset_meta`` in evaluator and metric will be '
+                'None.')
         self.interval = interval
 
     def run(self):
         """Launch validation."""
-        self.runner.cur_dataloader = self.dataloader
         self.runner.call_hook('before_val')
         self.runner.call_hook('before_val_epoch')
         self.runner.model.eval()
@@ -201,11 +200,11 @@ class ValLoop(BaseLoop):
         self.runner.call_hook('after_val')
 
     @torch.no_grad()
-    def run_iter(self, idx, data_batch: Sequence[Tuple[Any, BaseDataElement]]):
+    def run_iter(self, idx, data_batch: Sequence[dict]):
         """Iterate one mini-batch.
 
         Args:
-            data_batch (Sequence[Tuple[Any, BaseDataElement]]): Batch of data
+            data_batch (Sequence[dict]): Batch of data
                 from dataloader.
         """
         self.runner.call_hook(
@@ -239,10 +238,16 @@ class TestLoop(BaseLoop):
             self.evaluator = runner.build_evaluator(evaluator)  # type: ignore
         else:
             self.evaluator = evaluator  # type: ignore
+        if hasattr(self.dataloader.dataset, 'metainfo'):
+            self.evaluator.dataset_meta = self.dataloader.dataset.metainfo
+        else:
+            warnings.warn(
+                f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
+                'metainfo. ``dataset_meta`` in evaluator and metric will be '
+                'None.')
 
     def run(self) -> None:
         """Launch test."""
-        self.runner.cur_dataloader = self.dataloader
         self.runner.call_hook('before_test')
         self.runner.call_hook('before_test_epoch')
         self.runner.model.eval()
@@ -258,13 +263,11 @@ class TestLoop(BaseLoop):
         self.runner.call_hook('after_test')
 
     @torch.no_grad()
-    def run_iter(self, idx,
-                 data_batch: Sequence[Tuple[Any, BaseDataElement]]) -> None:
+    def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
         """Iterate one mini-batch.
 
         Args:
-            data_batch (Sequence[Tuple[Any, BaseDataElement]]): Batch of data
-                from dataloader.
+            data_batch (Sequence[dict]): Batch of data from dataloader.
         """
         self.runner.call_hook(
             'before_test_iter', batch_idx=idx, data_batch=data_batch)
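The branch added to both ValLoop and TestLoop, restated standalone for clarity (the helper name is ours, the logic is the one above): if the dataset exposes metainfo, it is copied onto the evaluator at construction time; otherwise the loop warns and dataset_meta stays None.

import warnings


def attach_dataset_meta(evaluator, dataset):
    # Mirrors the hasattr branch added to ValLoop/TestLoop above.
    if hasattr(dataset, 'metainfo'):
        evaluator.dataset_meta = dataset.metainfo
    else:
        warnings.warn(
            f'Dataset {dataset.__class__.__name__} has no metainfo. '
            '``dataset_meta`` in evaluator and metric will be None.')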

View File

@@ -1215,6 +1215,8 @@ class Runner:
                 +----------------------+-------------------------+
                 | IterTimerHook        | NORMAL (40)             |
                 +----------------------+-------------------------+
+                | DistSamplerSeedHook  | NORMAL (40)             |
+                +----------------------+-------------------------+
                 | LoggerHook           | BELOW_NORMAL (60)       |
                 +----------------------+-------------------------+
                 | ParamSchedulerHook   | LOW (70)                |
@@ -1228,6 +1230,7 @@ class Runner:
             default_hooks = dict(
                 optimizer=dict(type='OptimizerHook', grad_clip=None),
                 timer=dict(type='IterTimerHook'),
+                sampler_seed=dict(type='DistSamplerSeedHook'),
                 logger=dict(type='LoggerHook'),
                 param_scheduler=dict(type='ParamSchedulerHook'),
                 checkpoint=dict(type='CheckpointHook', interval=1),
@@ -1252,6 +1255,7 @@ class Runner:
             logger=dict(type='LoggerHook'),
             param_scheduler=dict(type='ParamSchedulerHook'),
             checkpoint=dict(type='CheckpointHook', interval=1),
+            sampler_seed=dict(type='DistSamplerSeedHook'),
         )
         if hooks is not None:
             for name, hook in hooks.items():
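Since DistSamplerSeedHook is now registered by default under the sampler_seed key, a user config only needs to override the entries it changes. A minimal sketch, assuming the by-name merge performed by the loop above:

default_hooks = dict(
    # Override one default; timer, sampler_seed, logger and
    # param_scheduler keep the values registered above.
    checkpoint=dict(type='CheckpointHook', interval=5))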

View File

@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import math
-from typing import Any, Dict, List, Optional, Sequence, Tuple
+from typing import Dict, List, Optional, Sequence
 from unittest import TestCase
 
 import numpy as np
@@ -40,7 +40,7 @@ class ToyMetric(BaseMetric):
     def process(self, data_batch, predictions):
         results = [{
             'pred': pred.get('pred'),
-            'label': data[1].get('label')
+            'label': data['data_sample'].get('label')
         } for pred, data in zip(predictions, data_batch)]
         self.results.extend(results)
 
@@ -66,7 +66,7 @@ class NonPrefixedMetric(BaseMetric):
     """Evaluator with unassigned `default_prefix` to test the warning
     information."""
 
-    def process(self, data_batch: Sequence[Tuple[Any, dict]],
+    def process(self, data_batch: Sequence[dict],
                 predictions: Sequence[dict]) -> None:
         pass
 
@@ -79,8 +79,11 @@ def generate_test_results(size, batch_size, pred, label):
     bs_residual = size % batch_size
     for i in range(num_batch):
         bs = bs_residual if i == num_batch - 1 else batch_size
-        data_batch = [(np.zeros((3, 10, 10)), BaseDataElement(label=label))
-                      for _ in range(bs)]
+        data_batch = [
+            dict(
+                inputs=np.zeros((3, 10, 10)),
+                data_sample=BaseDataElement(label=label)) for _ in range(bs)
+        ]
         predictions = [BaseDataElement(pred=pred) for _ in range(bs)]
         yield (data_batch, predictions)
 
@@ -228,7 +231,10 @@ class TestEvaluator(TestCase):
         size = 10
-        all_data = [(np.zeros((3, 10, 10)), BaseDataElement(label=1))
-                    for _ in range(size)]
+        all_data = [
+            dict(
+                inputs=np.zeros((3, 10, 10)),
+                data_sample=BaseDataElement(label=1)) for _ in range(size)
+        ]
         all_predictions = [BaseDataElement(pred=0) for _ in range(size)]
         evaluator.offline_evaluate(all_data, all_predictions)

View File

@@ -157,18 +157,17 @@ class TestHook:
     def test_end_of_epoch(self):
         hook = Hook()
-        runner = Mock()
 
         # last inner iter
         batch_idx = 1
-        runner.cur_dataloader.__len__ = Mock(return_value=2)
-        runner.cur_dataloader.__len__ = Mock(return_value=2)
-        return_val = hook.end_of_epoch(runner, batch_idx)
+        dataloader = Mock()
+        dataloader.__len__ = Mock(return_value=2)
+        return_val = hook.end_of_epoch(dataloader, batch_idx)
         assert return_val
 
         # not the last inner iter
         batch_idx = 0
-        return_val = hook.end_of_epoch(runner, batch_idx)
+        return_val = hook.end_of_epoch(dataloader, batch_idx)
         assert not return_val
 
     def test_is_last_train_epoch(self):

View File

@@ -111,7 +111,7 @@ class TestLoggerHook:
         # Test end of the epoch.
         logger_hook = LoggerHook(by_epoch=True, ignore_last=False)
         logger_hook._log_train = MagicMock()
-        runner.cur_dataloader = [0] * 5
+        runner.train_loop.dataloader = [0] * 5
         batch_idx = 4
         logger_hook.after_train_iter(runner, batch_idx=batch_idx)
         logger_hook._log_train.assert_called()
@@ -341,7 +341,9 @@ class TestLoggerHook:
     def _setup_runner(self):
         runner = MagicMock()
         runner.epoch = 1
-        runner.cur_dataloader = [0] * 5
+        runner.train_loop.dataloader = [0] * 5
+        runner.val_loop.dataloader = [0] * 5
+        runner.test_loop.dataloader = [0] * 5
         runner.iter = 10
         runner.train_loop.max_iters = 50
         logger = logging.getLogger()

View File

@@ -16,70 +16,56 @@ class TestNaiveVisualizationHook:
         inputs = torch.randn(1, 3, 15, 15)
         batch_idx = 10
         # test with normalize, resize, pad
-        gt_datasamples = [
-            BaseDataElement(
-                metainfo=dict(
-                    img_norm_cfg=dict(
-                        mean=(0, 0, 0), std=(0.5, 0.5, 0.5), to_bgr=True),
-                    scale=(10, 10),
-                    pad_shape=(15, 15, 3),
-                    ori_height=5,
-                    ori_width=5,
-                    img_path='tmp.jpg'))
-        ]
+        gt_datasamples = BaseDataElement(
+            metainfo=dict(
+                img_norm_cfg=dict(
+                    mean=(0, 0, 0), std=(0.5, 0.5, 0.5), to_bgr=True),
+                scale=(10, 10),
+                pad_shape=(15, 15, 3),
+                ori_height=5,
+                ori_width=5,
+                img_path='tmp.jpg'))
         pred_datasamples = [BaseDataElement()]
-        data_batch = (inputs, gt_datasamples)
+        data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)]
         naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
                                                  pred_datasamples)
         # test with resize, pad
-        gt_datasamples = [
-            BaseDataElement(
-                metainfo=dict(
-                    scale=(10, 10),
-                    pad_shape=(15, 15, 3),
-                    ori_height=5,
-                    ori_width=5,
-                    img_path='tmp.jpg')),
-        ]
+        gt_datasamples = BaseDataElement(
+            metainfo=dict(
+                scale=(10, 10),
+                pad_shape=(15, 15, 3),
+                ori_height=5,
+                ori_width=5,
+                img_path='tmp.jpg'))
         pred_datasamples = [BaseDataElement()]
-        data_batch = (inputs, gt_datasamples)
+        data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)]
         naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
                                                  pred_datasamples)
         # test with only resize
-        gt_datasamples = [
-            BaseDataElement(
-                metainfo=dict(
-                    scale=(15, 15),
-                    ori_height=5,
-                    ori_width=5,
-                    img_path='tmp.jpg')),
-        ]
+        gt_datasamples = BaseDataElement(
+            metainfo=dict(
+                scale=(15, 15), ori_height=5, ori_width=5, img_path='tmp.jpg'))
         pred_datasamples = [BaseDataElement()]
-        data_batch = (inputs, gt_datasamples)
+        data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)]
         naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
                                                  pred_datasamples)
         # test with only pad
-        gt_datasamples = [
-            BaseDataElement(
-                metainfo=dict(
-                    pad_shape=(15, 15, 3),
-                    ori_height=5,
-                    ori_width=5,
-                    img_path='tmp.jpg')),
-        ]
+        gt_datasamples = BaseDataElement(
+            metainfo=dict(
+                pad_shape=(15, 15, 3),
+                ori_height=5,
+                ori_width=5,
+                img_path='tmp.jpg'))
         pred_datasamples = [BaseDataElement()]
-        data_batch = (inputs, gt_datasamples)
+        data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)]
         naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
                                                  pred_datasamples)
         # test no transform
-        gt_datasamples = [
-            BaseDataElement(
-                metainfo=dict(ori_height=15, ori_width=15,
-                              img_path='tmp.jpg')),
-        ]
+        gt_datasamples = BaseDataElement(
+            metainfo=dict(ori_height=15, ori_width=15, img_path='tmp.jpg'))
         pred_datasamples = [BaseDataElement()]
-        data_batch = (inputs, gt_datasamples)
+        data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)]
         naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
                                                  pred_datasamples)

View File

@@ -12,17 +12,18 @@ class TestDistSamplerSeedHook:
         # Test dataset sampler
         runner = Mock()
         runner.epoch = 1
-        runner.cur_dataloader = Mock()
-        runner.cur_dataloader.sampler = Mock()
-        runner.cur_dataloader.sampler.set_epoch = Mock()
+        runner.train_loop.dataloader = Mock()
+        runner.train_loop.dataloader.sampler = Mock()
+        runner.train_loop.dataloader.sampler.set_epoch = Mock()
         hook.before_train_epoch(runner)
-        runner.cur_dataloader.sampler.set_epoch.assert_called()
+        runner.train_loop.dataloader.sampler.set_epoch.assert_called()
         # Test batch sampler
         runner = Mock()
-        runner.cur_dataloader = Mock()
-        runner.cur_dataloader.sampler = Mock(spec_set=True)
-        runner.cur_dataloader.batch_sampler = Mock()
-        runner.cur_dataloader.batch_sampler.sampler = Mock()
-        runner.cur_dataloader.batch_sampler.sampler.set_epoch = Mock()
+        runner.train_loop.dataloader = Mock()
+        runner.train_loop.dataloader.sampler = Mock(spec_set=True)
+        runner.train_loop.dataloader.batch_sampler = Mock()
+        runner.train_loop.dataloader.batch_sampler.sampler = Mock()
+        runner.train_loop.dataloader.batch_sampler.sampler.set_epoch = Mock()
         hook.before_train_epoch(runner)
-        runner.cur_dataloader.batch_sampler.sampler.set_epoch.assert_called()
+        runner.train_loop.dataloader.\
+            batch_sampler.sampler.set_epoch.assert_called()

View File

@@ -36,7 +36,8 @@ class ToyModel(nn.Module):
         self.linear = nn.Linear(2, 1)
 
     def forward(self, data_batch, return_loss=False):
-        inputs, labels = zip(*data_batch)
+        inputs, labels = zip(
+            *map(lambda x: (x['inputs'], x['data_sample']), data_batch))
         device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
         inputs = torch.stack(inputs).to(device)
         labels = torch.stack(labels).to(device)
@@ -67,7 +68,7 @@ class CustomModelWrapper(nn.Module):
 
 @DATASETS.register_module()
 class ToyDataset(Dataset):
-    META = dict()  # type: ignore
+    METAINFO = dict()  # type: ignore
     data = torch.randn(12, 2)
     label = torch.ones(12)
 
@@ -75,7 +76,7 @@ class ToyDataset(Dataset):
         return self.data.size(0)
 
     def __getitem__(self, index):
-        return self.data[index], self.label[index]
+        return dict(inputs=self.data[index], data_sample=self.label[index])
 
 
 @METRICS.register_module()
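Putting the test fixtures together, a sketch of the refactored data path end to end (toy tensors): the dataset's __getitem__ yields dicts, the batch stays a list of dicts through the identity collate, and the model unpacks the two keys exactly as ToyModel.forward above does.

import torch

batch = [
    dict(inputs=torch.randn(2), data_sample=torch.ones(1)) for _ in range(4)
]
inputs, labels = zip(
    *map(lambda x: (x['inputs'], x['data_sample']), batch))
inputs = torch.stack(inputs)  # shape (4, 2)
labels = torch.stack(labels)  # shape (4, 1)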