[Refactor] Refactor data_batch type and remove cur_dataloader in runner. (#171)
* [Refactor] Refactor data_batch type. * fix sampler * [Refactor] Remove cur_dataloader in runner. * fix set_epochpull/172/head
parent
ab8b51682f
commit
59cc08e3ac
|
@ -2,18 +2,13 @@
|
|||
import itertools
|
||||
import math
|
||||
from typing import Iterator, Optional, Sized
|
||||
# from mmengine.dist import get_dist_info, sync_random_seed
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
from mmengine.dist import get_dist_info, sync_random_seed
|
||||
from mmengine.registry import DATA_SAMPLERS
|
||||
|
||||
# TODO, need to remove those lines after implementing dist module
|
||||
get_dist_info = MagicMock(return_value=(0, 1))
|
||||
sync_random_seed = MagicMock(return_value=0)
|
||||
|
||||
|
||||
@DATA_SAMPLERS.register_module()
|
||||
class DefaultSampler(Sampler):
|
||||
|
|
|
@ -1,13 +1,11 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import random
|
||||
from typing import Any, Sequence, Tuple
|
||||
from typing import Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .base_data_element import BaseDataElement
|
||||
|
||||
DATA_BATCH = Sequence[Tuple[Any, BaseDataElement]]
|
||||
DATA_BATCH = Sequence[dict]
|
||||
|
||||
|
||||
def worker_init_fn(worker_id: int, num_workers: int, rank: int,
|
||||
|
@ -36,10 +34,10 @@ def pseudo_collate(data_batch: DATA_BATCH) -> DATA_BATCH:
|
|||
nothing just returns ``data_batch``.
|
||||
|
||||
Args:
|
||||
data_batch (Sequence[Tuple[Any, BaseDataElement]]): Batch of data from
|
||||
data_batch (Sequence[dict]): Batch of data from
|
||||
dataloader.
|
||||
|
||||
Returns:
|
||||
Sequence[Tuple[Any, BaseDataElement]]: Return input ``data_batch``.
|
||||
Sequence[dict]: Return input ``data_batch``.
|
||||
"""
|
||||
return data_batch
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Any, Iterator, List, Optional, Sequence, Tuple, Union
|
||||
from typing import Iterator, List, Optional, Sequence, Union
|
||||
|
||||
from mmengine.data import BaseDataElement
|
||||
from ..registry.root import METRICS
|
||||
|
@ -37,23 +37,25 @@ class Evaluator:
|
|||
for metric in self.metrics:
|
||||
metric.dataset_meta = dataset_meta
|
||||
|
||||
def process(self, data_batch: Sequence[Tuple[Any, BaseDataElement]],
|
||||
def process(self, data_batch: Sequence[dict],
|
||||
predictions: Sequence[BaseDataElement]):
|
||||
"""Convert ``BaseDataSample`` to dict and invoke process method of each
|
||||
metric.
|
||||
|
||||
Args:
|
||||
data_batch (Sequence[Tuple[Any, BaseDataElement]]): A batch of data
|
||||
from the dataloader.
|
||||
data_batch (Sequence[dict]): A batch of data from the dataloader.
|
||||
predictions (Sequence[BaseDataElement]): A batch of outputs from
|
||||
the model.
|
||||
"""
|
||||
_data_batch = []
|
||||
for input, data in data_batch:
|
||||
if isinstance(data, BaseDataElement):
|
||||
_data_batch.append((input, data.to_dict()))
|
||||
for data in data_batch:
|
||||
if isinstance(data['data_sample'], BaseDataElement):
|
||||
_data_batch.append(
|
||||
dict(
|
||||
inputs=data['inputs'],
|
||||
data_sample=data['data_sample'].to_dict()))
|
||||
else:
|
||||
_data_batch.append((input, data))
|
||||
_data_batch.append(data)
|
||||
_predictions = []
|
||||
for pred in predictions:
|
||||
if isinstance(pred, BaseDataElement):
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Any, List, Optional, Sequence, Tuple, Union
|
||||
from typing import Any, List, Optional, Sequence, Union
|
||||
|
||||
from mmengine.dist import (broadcast_object_list, collect_results,
|
||||
is_main_process)
|
||||
|
@ -50,15 +50,14 @@ class BaseMetric(metaclass=ABCMeta):
|
|||
self._dataset_meta = dataset_meta
|
||||
|
||||
@abstractmethod
|
||||
def process(self, data_batch: Sequence[Tuple[Any, dict]],
|
||||
def process(self, data_batch: Sequence[dict],
|
||||
predictions: Sequence[dict]) -> None:
|
||||
"""Process one batch of data samples and predictions. The processed
|
||||
results should be stored in ``self.results``, which will be used to
|
||||
compute the metrics when all batches have been processed.
|
||||
|
||||
Args:
|
||||
data_batch (Sequence[Tuple[Any, dict]]): A batch of data
|
||||
from the dataloader.
|
||||
data_batch (Sequence[dict]): A batch of data from the dataloader.
|
||||
predictions (Sequence[dict]): A batch of outputs from
|
||||
the model.
|
||||
"""
|
||||
|
|
|
@ -2,15 +2,14 @@
|
|||
import os.path as osp
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Sequence, Tuple, Union
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
from mmengine.data import BaseDataElement
|
||||
from mmengine.dist import master_only
|
||||
from mmengine.fileio import FileClient
|
||||
from mmengine.registry import HOOKS
|
||||
from .hook import Hook
|
||||
|
||||
DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]]
|
||||
DATA_BATCH = Optional[Sequence[dict]]
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
|
@ -185,8 +184,8 @@ class CheckpointHook(Hook):
|
|||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the train loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data
|
||||
from dataloader. Defaults to None.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
outputs (dict, optional): Outputs from model.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Any, Optional, Sequence, Tuple, Union
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -7,7 +7,7 @@ from mmengine.data import BaseDataElement
|
|||
from mmengine.registry import HOOKS
|
||||
from .hook import Hook
|
||||
|
||||
DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]]
|
||||
DATA_BATCH = Optional[Sequence[dict]]
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
|
@ -46,8 +46,8 @@ class EmptyCacheHook(Hook):
|
|||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data
|
||||
from dataloader. Defaults to None.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
outputs (dict or sequence, optional): Outputs from model.
|
||||
Defaults to None.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Any, Optional, Sequence, Tuple, Union
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
from mmengine.data import BaseDataElement
|
||||
|
||||
DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]]
|
||||
DATA_BATCH = Optional[Sequence[dict]]
|
||||
|
||||
|
||||
class Hook:
|
||||
|
@ -174,8 +174,8 @@ class Hook:
|
|||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the train loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
"""
|
||||
self._before_iter(
|
||||
runner, batch_idx=batch_idx, data_batch=data_batch, mode='train')
|
||||
|
@ -190,8 +190,8 @@ class Hook:
|
|||
Args:
|
||||
runner (Runner): The runner of the validation process.
|
||||
batch_idx (int): The index of the current batch in the val loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
"""
|
||||
self._before_iter(
|
||||
runner, batch_idx=batch_idx, data_batch=data_batch, mode='val')
|
||||
|
@ -206,8 +206,8 @@ class Hook:
|
|||
Args:
|
||||
runner (Runner): The runner of the testing process.
|
||||
batch_idx (int): The index of the current batch in the test loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
"""
|
||||
self._before_iter(
|
||||
runner, batch_idx=batch_idx, data_batch=data_batch, mode='test')
|
||||
|
@ -223,8 +223,8 @@ class Hook:
|
|||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the train loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
outputs (dict, optional): Outputs from model.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
@ -247,8 +247,8 @@ class Hook:
|
|||
Args:
|
||||
runner (Runner): The runner of the validation process.
|
||||
batch_idx (int): The index of the current batch in the val loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
outputs (dict or sequence, optional): Outputs from
|
||||
model. Defaults to None.
|
||||
"""
|
||||
|
@ -271,8 +271,8 @@ class Hook:
|
|||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the test loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
outputs (dict, optional): Outputs from model.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
@ -317,8 +317,8 @@ class Hook:
|
|||
runner (Runner): The runner of the training, validation or testing
|
||||
process.
|
||||
batch_idx (int): The index of the current batch in the loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
pass
|
||||
|
@ -337,8 +337,8 @@ class Hook:
|
|||
runner (Runner): The runner of the training, validation or testing
|
||||
process.
|
||||
batch_idx (int): The index of the current batch in the loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
outputs (Sequence[BaseDataElement], optional): Outputs from model.
|
||||
Defaults to None.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
|
@ -387,19 +387,19 @@ class Hook:
|
|||
"""
|
||||
return (runner.iter + 1) % n == 0 if n > 0 else False
|
||||
|
||||
def end_of_epoch(self, runner, batch_idx: int) -> bool:
|
||||
def end_of_epoch(self, dataloader, batch_idx: int) -> bool:
|
||||
"""Check whether the current iteration reaches the last iteration of
|
||||
current dataloader.
|
||||
the dataloader.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training, validation or testing
|
||||
process.
|
||||
dataloader (Dataloader): The dataloader of the training,
|
||||
validation or testing process.
|
||||
batch_idx (int): The index of the current batch in the loop.
|
||||
|
||||
Returns:
|
||||
bool: Whether reaches the end of current epoch or not.
|
||||
"""
|
||||
return batch_idx + 1 == len(runner.cur_dataloader)
|
||||
return batch_idx + 1 == len(dataloader)
|
||||
|
||||
def is_last_train_epoch(self, runner) -> bool:
|
||||
"""Test whether current epoch is the last train epoch.
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import time
|
||||
from typing import Any, Optional, Sequence, Tuple, Union
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
from mmengine.data import BaseDataElement
|
||||
from mmengine.registry import HOOKS
|
||||
from .hook import Hook
|
||||
|
||||
DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]]
|
||||
DATA_BATCH = Optional[Sequence[dict]]
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
|
@ -37,8 +37,8 @@ class IterTimerHook(Hook):
|
|||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data
|
||||
from dataloader. Defaults to None.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
# TODO: update for new logging system
|
||||
|
@ -57,8 +57,8 @@ class IterTimerHook(Hook):
|
|||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data
|
||||
from dataloader. Defaults to None.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
outputs (dict or sequence, optional): Outputs from model. Defaults
|
||||
to None.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
|
|
|
@ -5,18 +5,17 @@ import os
|
|||
import os.path as osp
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Sequence, Tuple, Union
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
import torch
|
||||
|
||||
from mmengine.data import BaseDataElement
|
||||
from mmengine.dist import master_only
|
||||
from mmengine.fileio import FileClient
|
||||
from mmengine.hooks import Hook
|
||||
from mmengine.registry import HOOKS
|
||||
from mmengine.utils import is_tuple_of, scandir
|
||||
|
||||
DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]]
|
||||
DATA_BATCH = Optional[Sequence[dict]]
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
|
@ -183,15 +182,16 @@ class LoggerHook(Hook):
|
|||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the train loop.
|
||||
data_batch (Sequence[BaseDataElement], optional): Data from
|
||||
dataloader. Defaults to None.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
outputs (dict, optional): Outputs from model.
|
||||
Defaults to None.
|
||||
"""
|
||||
self._inner_iter = batch_idx
|
||||
if runner.meta is not None and 'exp_name' in runner.meta:
|
||||
if (self.every_n_iters(runner, self.interval_exp_name)) or (
|
||||
self.by_epoch and self.end_of_epoch(runner, batch_idx)):
|
||||
self.by_epoch and self.end_of_epoch(
|
||||
runner.train_loop.dataloader, batch_idx)):
|
||||
exp_info = f'Exp name: {runner.meta["exp_name"]}'
|
||||
runner.logger.info(exp_info)
|
||||
if self.by_epoch and self.every_n_inner_iters(batch_idx,
|
||||
|
@ -199,7 +199,8 @@ class LoggerHook(Hook):
|
|||
self._log_train(runner)
|
||||
elif not self.by_epoch and self.every_n_iters(runner, self.interval):
|
||||
self._log_train(runner)
|
||||
elif self.end_of_epoch(runner, batch_idx) and not self.ignore_last:
|
||||
elif self.end_of_epoch(runner.train_loop.dataloader,
|
||||
batch_idx) and not self.ignore_last:
|
||||
# `runner.max_iters` may not be divisible by `self.interval`. if
|
||||
# `self.ignore_last==True`, the log of remaining iterations will
|
||||
# be recorded (Epoch [4][1000/1007], the logs of 998-1007
|
||||
|
@ -271,7 +272,7 @@ class LoggerHook(Hook):
|
|||
# by iter: Iter [100/100000]
|
||||
if self.by_epoch:
|
||||
log_str = f'Epoch [{cur_epoch}]' \
|
||||
f'[{cur_iter}/{len(runner.cur_dataloader)}]\t'
|
||||
f'[{cur_iter}/{len(runner.train_loop.dataloader)}]\t'
|
||||
else:
|
||||
log_str = f'Iter [{cur_iter}/{runner.train_loop.max_iters}]\t'
|
||||
log_str += f'{lr_momentum_str}, '
|
||||
|
@ -311,7 +312,7 @@ class LoggerHook(Hook):
|
|||
"""
|
||||
tag = self._collect_info(runner, 'val')
|
||||
# Compatible with function `log` https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/logger/text.py # noqa E501
|
||||
eval_iter = len(runner.cur_dataloader)
|
||||
eval_iter = len(runner.val_loop.dataloader)
|
||||
cur_iter = self._get_iter(runner)
|
||||
cur_epoch = self._get_epoch(runner, 'val')
|
||||
# val/test time
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
from typing import Any, Optional, Sequence, Tuple
|
||||
from typing import Optional, Sequence, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
@ -41,26 +41,25 @@ class NaiveVisualizationHook(Hook):
|
|||
self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: Optional[Sequence[Tuple[Any, BaseDataElement]]] = None,
|
||||
data_batch: Optional[Sequence[dict]] = None,
|
||||
outputs: Optional[Sequence[BaseDataElement]] = None) -> None:
|
||||
"""Show or Write the predicted results.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the test loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data
|
||||
data_batch (Sequence[dict], optional): Data
|
||||
from dataloader. Defaults to None.
|
||||
outputs (Sequence[BaseDataElement], optional): Outputs from model.
|
||||
Defaults to None.
|
||||
"""
|
||||
if self.every_n_iters(runner, self._interval):
|
||||
inputs, data_samples = data_batch # type: ignore
|
||||
inputs = tensor2imgs(inputs,
|
||||
**data_samples[0].get('img_norm_cfg', dict()))
|
||||
for input, data_sample, output in zip(
|
||||
inputs,
|
||||
data_samples, # type: ignore
|
||||
outputs): # type: ignore
|
||||
for data, output in zip(data_batch, outputs): # type: ignore
|
||||
input = data['inputs']
|
||||
data_sample = data['data_sample']
|
||||
input = tensor2imgs(input,
|
||||
**data_sample.get('img_norm_cfg',
|
||||
dict()))[0]
|
||||
# TODO We will implement a function to revert the augmentation
|
||||
# in the future.
|
||||
ori_shape = (data_sample.ori_width, data_sample.ori_height)
|
||||
|
|
|
@ -1,16 +1,15 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
from typing import Any, List, Optional, Sequence, Tuple
|
||||
from typing import List, Optional, Sequence
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.nn.utils import clip_grad
|
||||
|
||||
from mmengine.data import BaseDataElement
|
||||
from mmengine.registry import HOOKS
|
||||
from .hook import Hook
|
||||
|
||||
DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]]
|
||||
DATA_BATCH = Optional[Sequence[dict]]
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
|
@ -77,10 +76,9 @@ class OptimizerHook(Hook):
|
|||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the train loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data
|
||||
from dataloader. In order to keep this interface consistent
|
||||
with other hooks, we keep ``data_batch`` here.
|
||||
Defaults to None.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
In order to keep this interface consistent with other hooks,
|
||||
we keep ``data_batch`` here. Defaults to None.
|
||||
outputs (dict, optional): Outputs from model.
|
||||
In order to keep this interface consistent with other hooks,
|
||||
we keep ``outputs`` here. Defaults to None.
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Any, Optional, Sequence, Tuple
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from mmengine.data import BaseDataElement
|
||||
from mmengine.registry import HOOKS
|
||||
from .hook import Hook
|
||||
|
||||
DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]]
|
||||
DATA_BATCH = Optional[Sequence[dict]]
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
|
@ -25,10 +24,9 @@ class ParamSchedulerHook(Hook):
|
|||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the train loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data
|
||||
from dataloader. In order to keep this interface consistent
|
||||
with other hooks, we keep ``data_batch`` here.
|
||||
Defaults to None.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
In order to keep this interface consistent with other hooks,
|
||||
we keep ``data_batch`` here. Defaults to None.
|
||||
outputs (dict, optional): Outputs from model.
|
||||
In order to keep this interface consistent with other hooks, we
|
||||
keep ``data_batch`` here. Defaults to None.
|
||||
|
|
|
@ -20,9 +20,11 @@ class DistSamplerSeedHook(Hook):
|
|||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
"""
|
||||
if hasattr(runner.cur_dataloader.sampler, 'set_epoch'):
|
||||
if hasattr(runner.train_loop.dataloader.sampler, 'set_epoch'):
|
||||
# in case the data loader uses `SequentialSampler` in Pytorch
|
||||
runner.cur_dataloader.sampler.set_epoch(runner.epoch)
|
||||
elif hasattr(runner.cur_dataloader.batch_sampler.sampler, 'set_epoch'):
|
||||
runner.train_loop.dataloader.sampler.set_epoch(runner.epoch)
|
||||
elif hasattr(runner.train_loop.dataloader.batch_sampler.sampler,
|
||||
'set_epoch'):
|
||||
# batch sampler in pytorch warps the sampler as its attributes.
|
||||
runner.cur_dataloader.batch_sampler.sampler.set_epoch(runner.epoch)
|
||||
runner.train_loop.dataloader.batch_sampler.sampler.set_epoch(
|
||||
runner.epoch)
|
||||
|
|
|
@ -25,9 +25,6 @@ class BaseLoop(metaclass=ABCMeta):
|
|||
else:
|
||||
self.dataloader = dataloader
|
||||
|
||||
# TODO, used by `end_of_epoch` of `Hook`
|
||||
self._runner.data_loader = self.dataloader
|
||||
|
||||
@property
|
||||
def runner(self):
|
||||
return self._runner
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Any, Dict, List, Sequence, Tuple, Union
|
||||
import warnings
|
||||
from typing import Dict, List, Sequence, Union
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from mmengine.data import BaseDataElement
|
||||
from mmengine.evaluator import Evaluator
|
||||
from mmengine.registry import LOOPS
|
||||
from mmengine.utils import is_list_of
|
||||
|
@ -40,7 +40,6 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||
|
||||
def run(self) -> None:
|
||||
"""Launch training."""
|
||||
self.runner.cur_dataloader = self.dataloader
|
||||
self.runner.call_hook('before_train')
|
||||
|
||||
while self.runner._epoch < self._max_epochs:
|
||||
|
@ -62,13 +61,11 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||
self.runner.call_hook('after_train_epoch')
|
||||
self.runner._epoch += 1
|
||||
|
||||
def run_iter(self, idx,
|
||||
data_batch: Sequence[Tuple[Any, BaseDataElement]]) -> None:
|
||||
def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
|
||||
"""Iterate one min-batch.
|
||||
|
||||
Args:
|
||||
data_batch (Sequence[Tuple[Any, BaseDataElement]]): Batch of data
|
||||
from dataloader.
|
||||
data_batch (Sequence[dict]): Batch of data from dataloader.
|
||||
"""
|
||||
self.runner.call_hook(
|
||||
'before_train_iter', batch_idx=idx, data_batch=data_batch)
|
||||
|
@ -112,7 +109,6 @@ class IterBasedTrainLoop(BaseLoop):
|
|||
|
||||
def run(self) -> None:
|
||||
"""Launch training."""
|
||||
self.runner.cur_dataloader = self.dataloader
|
||||
self.runner.call_hook('before_train')
|
||||
# In iteration-based training loop, we treat the whole training process
|
||||
# as a big epoch and execute the corresponding hook.
|
||||
|
@ -130,13 +126,11 @@ class IterBasedTrainLoop(BaseLoop):
|
|||
self.runner.call_hook('after_train_epoch')
|
||||
self.runner.call_hook('after_train')
|
||||
|
||||
def run_iter(self, data_batch: Sequence[Tuple[Any,
|
||||
BaseDataElement]]) -> None:
|
||||
def run_iter(self, data_batch: Sequence[dict]) -> None:
|
||||
"""Iterate one mini-batch.
|
||||
|
||||
Args:
|
||||
data_batch (Sequence[Tuple[Any, BaseDataElement]]): Batch of data
|
||||
from dataloader.
|
||||
data_batch (Sequence[dict]): Batch of data from dataloader.
|
||||
"""
|
||||
self.runner.call_hook(
|
||||
'before_train_iter',
|
||||
|
@ -180,12 +174,17 @@ class ValLoop(BaseLoop):
|
|||
self.evaluator = runner.build_evaluator(evaluator) # type: ignore
|
||||
else:
|
||||
self.evaluator = evaluator # type: ignore
|
||||
|
||||
if hasattr(self.dataloader.dataset, 'metainfo'):
|
||||
self.evaluator.dataset_meta = self.dataloader.dataset.metainfo
|
||||
else:
|
||||
warnings.warn(
|
||||
f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
|
||||
'metainfo. ``dataset_meta`` in evaluator and metric will be '
|
||||
'None.')
|
||||
self.interval = interval
|
||||
|
||||
def run(self):
|
||||
"""Launch validation."""
|
||||
self.runner.cur_dataloader = self.dataloader
|
||||
self.runner.call_hook('before_val')
|
||||
self.runner.call_hook('before_val_epoch')
|
||||
self.runner.model.eval()
|
||||
|
@ -201,11 +200,11 @@ class ValLoop(BaseLoop):
|
|||
self.runner.call_hook('after_val')
|
||||
|
||||
@torch.no_grad()
|
||||
def run_iter(self, idx, data_batch: Sequence[Tuple[Any, BaseDataElement]]):
|
||||
def run_iter(self, idx, data_batch: Sequence[dict]):
|
||||
"""Iterate one mini-batch.
|
||||
|
||||
Args:
|
||||
data_batch (Sequence[Tuple[Any, BaseDataElement]]): Batch of data
|
||||
data_batch (Sequence[dict]): Batch of data
|
||||
from dataloader.
|
||||
"""
|
||||
self.runner.call_hook(
|
||||
|
@ -239,10 +238,16 @@ class TestLoop(BaseLoop):
|
|||
self.evaluator = runner.build_evaluator(evaluator) # type: ignore
|
||||
else:
|
||||
self.evaluator = evaluator # type: ignore
|
||||
if hasattr(self.dataloader.dataset, 'metainfo'):
|
||||
self.evaluator.dataset_meta = self.dataloader.dataset.metainfo
|
||||
else:
|
||||
warnings.warn(
|
||||
f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
|
||||
'metainfo. ``dataset_meta`` in evaluator and metric will be '
|
||||
'None.')
|
||||
|
||||
def run(self) -> None:
|
||||
"""Launch test."""
|
||||
self.runner.cur_dataloader = self.dataloader
|
||||
self.runner.call_hook('before_test')
|
||||
self.runner.call_hook('before_test_epoch')
|
||||
self.runner.model.eval()
|
||||
|
@ -258,13 +263,11 @@ class TestLoop(BaseLoop):
|
|||
self.runner.call_hook('after_test')
|
||||
|
||||
@torch.no_grad()
|
||||
def run_iter(self, idx,
|
||||
data_batch: Sequence[Tuple[Any, BaseDataElement]]) -> None:
|
||||
def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
|
||||
"""Iterate one mini-batch.
|
||||
|
||||
Args:
|
||||
data_batch (Sequence[Tuple[Any, BaseDataElement]]): Batch of data
|
||||
from dataloader.
|
||||
data_batch (Sequence[dict]): Batch of data from dataloader.
|
||||
"""
|
||||
self.runner.call_hook(
|
||||
'before_test_iter', batch_idx=idx, data_batch=data_batch)
|
||||
|
|
|
@ -1215,6 +1215,8 @@ class Runner:
|
|||
+----------------------+-------------------------+
|
||||
| IterTimerHook | NORMAL (40) |
|
||||
+----------------------+-------------------------+
|
||||
| DistSamplerSeedHook | NORMAL (40) |
|
||||
+----------------------+-------------------------+
|
||||
| LoggerHook | BELOW_NORMAL (60) |
|
||||
+----------------------+-------------------------+
|
||||
| ParamSchedulerHook | LOW (70) |
|
||||
|
@ -1228,6 +1230,7 @@ class Runner:
|
|||
default_hooks = dict(
|
||||
optimizer=dict(type='OptimizerHook', grad_clip=None),
|
||||
timer=dict(type='IterTimerHook'),
|
||||
sampler_seed=dict(type='DistSamplerSeedHook'),
|
||||
logger=dict(type='LoggerHook'),
|
||||
param_scheduler=dict(type='ParamSchedulerHook'),
|
||||
checkpoint=dict(type='CheckpointHook', interval=1),
|
||||
|
@ -1252,6 +1255,7 @@ class Runner:
|
|||
logger=dict(type='LoggerHook'),
|
||||
param_scheduler=dict(type='ParamSchedulerHook'),
|
||||
checkpoint=dict(type='CheckpointHook', interval=1),
|
||||
sampler_seed=dict(type='DistSamplerSeedHook'),
|
||||
)
|
||||
if hooks is not None:
|
||||
for name, hook in hooks.items():
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
from unittest import TestCase
|
||||
|
||||
import numpy as np
|
||||
|
@ -40,7 +40,7 @@ class ToyMetric(BaseMetric):
|
|||
def process(self, data_batch, predictions):
|
||||
results = [{
|
||||
'pred': pred.get('pred'),
|
||||
'label': data[1].get('label')
|
||||
'label': data['data_sample'].get('label')
|
||||
} for pred, data in zip(predictions, data_batch)]
|
||||
self.results.extend(results)
|
||||
|
||||
|
@ -66,7 +66,7 @@ class NonPrefixedMetric(BaseMetric):
|
|||
"""Evaluator with unassigned `default_prefix` to test the warning
|
||||
information."""
|
||||
|
||||
def process(self, data_batch: Sequence[Tuple[Any, dict]],
|
||||
def process(self, data_batch: Sequence[dict],
|
||||
predictions: Sequence[dict]) -> None:
|
||||
pass
|
||||
|
||||
|
@ -79,8 +79,11 @@ def generate_test_results(size, batch_size, pred, label):
|
|||
bs_residual = size % batch_size
|
||||
for i in range(num_batch):
|
||||
bs = bs_residual if i == num_batch - 1 else batch_size
|
||||
data_batch = [(np.zeros((3, 10, 10)), BaseDataElement(label=label))
|
||||
for _ in range(bs)]
|
||||
data_batch = [
|
||||
dict(
|
||||
inputs=np.zeros((3, 10, 10)),
|
||||
data_sample=BaseDataElement(label=label)) for _ in range(bs)
|
||||
]
|
||||
predictions = [BaseDataElement(pred=pred) for _ in range(bs)]
|
||||
yield (data_batch, predictions)
|
||||
|
||||
|
@ -228,7 +231,10 @@ class TestEvaluator(TestCase):
|
|||
|
||||
size = 10
|
||||
|
||||
all_data = [(np.zeros((3, 10, 10)), BaseDataElement(label=1))
|
||||
for _ in range(size)]
|
||||
all_data = [
|
||||
dict(
|
||||
inputs=np.zeros((3, 10, 10)),
|
||||
data_sample=BaseDataElement(label=1)) for _ in range(size)
|
||||
]
|
||||
all_predictions = [BaseDataElement(pred=0) for _ in range(size)]
|
||||
evaluator.offline_evaluate(all_data, all_predictions)
|
||||
|
|
|
@ -157,18 +157,17 @@ class TestHook:
|
|||
|
||||
def test_end_of_epoch(self):
|
||||
hook = Hook()
|
||||
runner = Mock()
|
||||
|
||||
# last inner iter
|
||||
batch_idx = 1
|
||||
runner.cur_dataloader.__len__ = Mock(return_value=2)
|
||||
runner.cur_dataloader.__len__ = Mock(return_value=2)
|
||||
return_val = hook.end_of_epoch(runner, batch_idx)
|
||||
dataloader = Mock()
|
||||
dataloader.__len__ = Mock(return_value=2)
|
||||
return_val = hook.end_of_epoch(dataloader, batch_idx)
|
||||
assert return_val
|
||||
|
||||
# not the last inner iter
|
||||
batch_idx = 0
|
||||
return_val = hook.end_of_epoch(runner, batch_idx)
|
||||
return_val = hook.end_of_epoch(dataloader, batch_idx)
|
||||
assert not return_val
|
||||
|
||||
def test_is_last_train_epoch(self):
|
||||
|
|
|
@ -111,7 +111,7 @@ class TestLoggerHook:
|
|||
# Test end of the epoch.
|
||||
logger_hook = LoggerHook(by_epoch=True, ignore_last=False)
|
||||
logger_hook._log_train = MagicMock()
|
||||
runner.cur_dataloader = [0] * 5
|
||||
runner.train_loop.dataloader = [0] * 5
|
||||
batch_idx = 4
|
||||
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||
logger_hook._log_train.assert_called()
|
||||
|
@ -341,7 +341,9 @@ class TestLoggerHook:
|
|||
def _setup_runner(self):
|
||||
runner = MagicMock()
|
||||
runner.epoch = 1
|
||||
runner.cur_dataloader = [0] * 5
|
||||
runner.train_loop.dataloader = [0] * 5
|
||||
runner.val_loop.dataloader = [0] * 5
|
||||
runner.test_loop.dataloader = [0] * 5
|
||||
runner.iter = 10
|
||||
runner.train_loop.max_iters = 50
|
||||
logger = logging.getLogger()
|
||||
|
|
|
@ -16,70 +16,56 @@ class TestNaiveVisualizationHook:
|
|||
inputs = torch.randn(1, 3, 15, 15)
|
||||
batch_idx = 10
|
||||
# test with normalize, resize, pad
|
||||
gt_datasamples = [
|
||||
BaseDataElement(
|
||||
metainfo=dict(
|
||||
img_norm_cfg=dict(
|
||||
mean=(0, 0, 0), std=(0.5, 0.5, 0.5), to_bgr=True),
|
||||
scale=(10, 10),
|
||||
pad_shape=(15, 15, 3),
|
||||
ori_height=5,
|
||||
ori_width=5,
|
||||
img_path='tmp.jpg'))
|
||||
]
|
||||
gt_datasamples = BaseDataElement(
|
||||
metainfo=dict(
|
||||
img_norm_cfg=dict(
|
||||
mean=(0, 0, 0), std=(0.5, 0.5, 0.5), to_bgr=True),
|
||||
scale=(10, 10),
|
||||
pad_shape=(15, 15, 3),
|
||||
ori_height=5,
|
||||
ori_width=5,
|
||||
img_path='tmp.jpg'))
|
||||
pred_datasamples = [BaseDataElement()]
|
||||
data_batch = (inputs, gt_datasamples)
|
||||
data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)]
|
||||
naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
|
||||
pred_datasamples)
|
||||
# test with resize, pad
|
||||
gt_datasamples = [
|
||||
BaseDataElement(
|
||||
metainfo=dict(
|
||||
scale=(10, 10),
|
||||
pad_shape=(15, 15, 3),
|
||||
ori_height=5,
|
||||
ori_width=5,
|
||||
img_path='tmp.jpg')),
|
||||
]
|
||||
gt_datasamples = BaseDataElement(
|
||||
metainfo=dict(
|
||||
scale=(10, 10),
|
||||
pad_shape=(15, 15, 3),
|
||||
ori_height=5,
|
||||
ori_width=5,
|
||||
img_path='tmp.jpg'))
|
||||
pred_datasamples = [BaseDataElement()]
|
||||
data_batch = (inputs, gt_datasamples)
|
||||
data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)]
|
||||
naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
|
||||
pred_datasamples)
|
||||
# test with only resize
|
||||
gt_datasamples = [
|
||||
BaseDataElement(
|
||||
metainfo=dict(
|
||||
scale=(15, 15),
|
||||
ori_height=5,
|
||||
ori_width=5,
|
||||
img_path='tmp.jpg')),
|
||||
]
|
||||
gt_datasamples = BaseDataElement(
|
||||
metainfo=dict(
|
||||
scale=(15, 15), ori_height=5, ori_width=5, img_path='tmp.jpg'))
|
||||
pred_datasamples = [BaseDataElement()]
|
||||
data_batch = (inputs, gt_datasamples)
|
||||
data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)]
|
||||
naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
|
||||
pred_datasamples)
|
||||
|
||||
# test with only pad
|
||||
gt_datasamples = [
|
||||
BaseDataElement(
|
||||
metainfo=dict(
|
||||
pad_shape=(15, 15, 3),
|
||||
ori_height=5,
|
||||
ori_width=5,
|
||||
img_path='tmp.jpg')),
|
||||
]
|
||||
gt_datasamples = BaseDataElement(
|
||||
metainfo=dict(
|
||||
pad_shape=(15, 15, 3),
|
||||
ori_height=5,
|
||||
ori_width=5,
|
||||
img_path='tmp.jpg'))
|
||||
pred_datasamples = [BaseDataElement()]
|
||||
data_batch = (inputs, gt_datasamples)
|
||||
data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)]
|
||||
naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
|
||||
pred_datasamples)
|
||||
|
||||
# test no transform
|
||||
gt_datasamples = [
|
||||
BaseDataElement(
|
||||
metainfo=dict(ori_height=15, ori_width=15,
|
||||
img_path='tmp.jpg')),
|
||||
]
|
||||
gt_datasamples = BaseDataElement(
|
||||
metainfo=dict(ori_height=15, ori_width=15, img_path='tmp.jpg'))
|
||||
pred_datasamples = [BaseDataElement()]
|
||||
data_batch = (inputs, gt_datasamples)
|
||||
data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)]
|
||||
naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch,
|
||||
pred_datasamples)
|
||||
|
|
|
@ -12,17 +12,18 @@ class TestDistSamplerSeedHook:
|
|||
# Test dataset sampler
|
||||
runner = Mock()
|
||||
runner.epoch = 1
|
||||
runner.cur_dataloader = Mock()
|
||||
runner.cur_dataloader.sampler = Mock()
|
||||
runner.cur_dataloader.sampler.set_epoch = Mock()
|
||||
runner.train_loop.dataloader = Mock()
|
||||
runner.train_loop.dataloader.sampler = Mock()
|
||||
runner.train_loop.dataloader.sampler.set_epoch = Mock()
|
||||
hook.before_train_epoch(runner)
|
||||
runner.cur_dataloader.sampler.set_epoch.assert_called()
|
||||
runner.train_loop.dataloader.sampler.set_epoch.assert_called()
|
||||
# Test batch sampler
|
||||
runner = Mock()
|
||||
runner.cur_dataloader = Mock()
|
||||
runner.cur_dataloader.sampler = Mock(spec_set=True)
|
||||
runner.cur_dataloader.batch_sampler = Mock()
|
||||
runner.cur_dataloader.batch_sampler.sampler = Mock()
|
||||
runner.cur_dataloader.batch_sampler.sampler.set_epoch = Mock()
|
||||
runner.train_loop.dataloader = Mock()
|
||||
runner.train_loop.dataloader.sampler = Mock(spec_set=True)
|
||||
runner.train_loop.dataloader.batch_sampler = Mock()
|
||||
runner.train_loop.dataloader.batch_sampler.sampler = Mock()
|
||||
runner.train_loop.dataloader.batch_sampler.sampler.set_epoch = Mock()
|
||||
hook.before_train_epoch(runner)
|
||||
runner.cur_dataloader.batch_sampler.sampler.set_epoch.assert_called()
|
||||
runner.train_loop.dataloader.\
|
||||
batch_sampler.sampler.set_epoch.assert_called()
|
||||
|
|
|
@ -36,7 +36,8 @@ class ToyModel(nn.Module):
|
|||
self.linear = nn.Linear(2, 1)
|
||||
|
||||
def forward(self, data_batch, return_loss=False):
|
||||
inputs, labels = zip(*data_batch)
|
||||
inputs, labels = zip(
|
||||
*map(lambda x: (x['inputs'], x['data_sample']), data_batch))
|
||||
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
||||
inputs = torch.stack(inputs).to(device)
|
||||
labels = torch.stack(labels).to(device)
|
||||
|
@ -67,7 +68,7 @@ class CustomModelWrapper(nn.Module):
|
|||
|
||||
@DATASETS.register_module()
|
||||
class ToyDataset(Dataset):
|
||||
META = dict() # type: ignore
|
||||
METAINFO = dict() # type: ignore
|
||||
data = torch.randn(12, 2)
|
||||
label = torch.ones(12)
|
||||
|
||||
|
@ -75,7 +76,7 @@ class ToyDataset(Dataset):
|
|||
return self.data.size(0)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.data[index], self.label[index]
|
||||
return dict(inputs=self.data[index], data_sample=self.label[index])
|
||||
|
||||
|
||||
@METRICS.register_module()
|
||||
|
|
Loading…
Reference in New Issue