mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Enhancement]: The ability to sort the dataset in tools/test.py. (#244)
* Add sort-data arg to test.py * Set is_sort_dataset to True. * Add a check for the possibility of sorting. * lint * Added mmdeploy.utils.dataset. * Add unit test Co-authored-by: SingleZombie <singlezombie@163.com>
This commit is contained in:
parent
ad72c19482
commit
3659b515eb
@ -1,4 +1,5 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import logging
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
from typing import Any, Dict, Optional, Sequence, Tuple, Union
|
from typing import Any, Dict, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
@ -8,6 +9,7 @@ import torch
|
|||||||
from torch.utils.data import DataLoader, Dataset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
from mmdeploy.utils import get_codebase
|
from mmdeploy.utils import get_codebase
|
||||||
|
from mmdeploy.utils.dataset import is_can_sort_dataset, sort_dataset
|
||||||
|
|
||||||
|
|
||||||
class BaseTask(metaclass=ABCMeta):
|
class BaseTask(metaclass=ABCMeta):
|
||||||
@ -66,6 +68,7 @@ class BaseTask(metaclass=ABCMeta):
|
|||||||
def build_dataset(self,
|
def build_dataset(self,
|
||||||
dataset_cfg: Union[str, mmcv.Config],
|
dataset_cfg: Union[str, mmcv.Config],
|
||||||
dataset_type: str = 'val',
|
dataset_type: str = 'val',
|
||||||
|
is_sort_dataset: bool = True,
|
||||||
**kwargs) -> Dataset:
|
**kwargs) -> Dataset:
|
||||||
"""Build dataset for different codebase.
|
"""Build dataset for different codebase.
|
||||||
|
|
||||||
@ -74,12 +77,22 @@ class BaseTask(metaclass=ABCMeta):
|
|||||||
object.
|
object.
|
||||||
dataset_type (str): Specifying dataset type, e.g.: 'train', 'test',
|
dataset_type (str): Specifying dataset type, e.g.: 'train', 'test',
|
||||||
'val', defaults to 'val'.
|
'val', defaults to 'val'.
|
||||||
|
is_sort_dataset (bool): When 'True', the dataset will be sorted
|
||||||
|
by image shape in ascending order if 'dataset_cfg'
|
||||||
|
contains information about height and width.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dataset: The built dataset.
|
Dataset: The built dataset.
|
||||||
"""
|
"""
|
||||||
return self.codebase_class.build_dataset(dataset_cfg, dataset_type,
|
dataset = self.codebase_class.build_dataset(dataset_cfg, dataset_type,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
if is_sort_dataset:
|
||||||
|
if is_can_sort_dataset(dataset):
|
||||||
|
sort_dataset(dataset)
|
||||||
|
else:
|
||||||
|
logging.info('Sorting the dataset by \'height\' and \'width\' '
|
||||||
|
'is not possible.')
|
||||||
|
return dataset
|
||||||
|
|
||||||
def build_dataloader(self, dataset: Dataset, samples_per_gpu: int,
|
def build_dataloader(self, dataset: Dataset, samples_per_gpu: int,
|
||||||
workers_per_gpu: int, **kwargs) -> DataLoader:
|
workers_per_gpu: int, **kwargs) -> DataLoader:
|
||||||
|
36
mmdeploy/utils/dataset.py
Normal file
36
mmdeploy/utils/dataset.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
|
||||||
|
def is_can_sort_dataset(dataset: Dataset) -> bool:
    """Check whether the dataset can be sorted by 'height' and 'width'.

    Sorting is considered possible when the dataset exposes a non-empty
    ``data_infos`` attribute and every entry of it contains both the
    'height' and the 'width' keys.

    Args:
        dataset (Dataset): The dataset.

    Returns:
        bool: Is it possible or not to sort the dataset.
    """
    data_infos = getattr(dataset, 'data_infos', None)
    # A missing, ``None`` or empty ``data_infos`` must yield a real ``False``
    # instead of leaking the falsy object itself to the caller.
    if not data_infos:
        return False
    # Every entry is inspected, not only the first one: sorting reads both
    # keys from each entry, so one incomplete entry makes sorting impossible.
    return all('height' in info and 'width' in info for info in data_infos)
|
||||||
|
|
||||||
|
|
||||||
|
def sort_dataset(dataset: Dataset) -> Dataset:
    """Sort the dataset in place by image height, then by image width.

    ``dataset.data_infos`` is replaced with a new list ordered ascending by
    the ``(height, width)`` pair of each entry. The sort is stable, so
    entries with equal height and width keep their relative order. When
    every entry carries an 'id' key, ``dataset.img_ids`` is rewritten to
    follow the new order; otherwise ``img_ids`` is left untouched.

    Args:
        dataset (Dataset): The dataset. Must expose ``data_infos`` whose
            entries contain the 'height' and 'width' keys.

    Returns:
        Dataset: Sorted dataset (the same object that was passed in).

    Raises:
        KeyError: If an entry lacks the 'height' or 'width' key.
    """
    sorted_infos = sorted(
        dataset.data_infos, key=lambda info: (info['height'], info['width']))
    dataset.data_infos = sorted_infos
    # 'id' is not one of the keys is_can_sort_dataset verifies, so it may be
    # absent. Only refresh img_ids when every entry can provide one.
    if all('id' in info for info in sorted_infos):
        dataset.img_ids = [info['id'] for info in sorted_infos]
    return dataset
|
@ -7,15 +7,11 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
from mmocr.models.textdet.necks import FPNC
|
from mmocr.models.textdet.necks import FPNC
|
||||||
|
|
||||||
from mmdeploy.apis.onnxruntime import is_available as ort_available
|
|
||||||
from mmdeploy.core import RewriterContext, patch_model
|
from mmdeploy.core import RewriterContext, patch_model
|
||||||
from mmdeploy.utils import Backend
|
from mmdeploy.utils import Backend
|
||||||
from mmdeploy.utils.test import (WrapModel, check_backend, get_model_outputs,
|
from mmdeploy.utils.test import (WrapModel, check_backend, get_model_outputs,
|
||||||
get_rewrite_outputs)
|
get_rewrite_outputs)
|
||||||
|
|
||||||
onnxruntime_skip = not ort_available()
|
|
||||||
cuda_skip = not torch.cuda.is_available()
|
|
||||||
|
|
||||||
|
|
||||||
class FPNCNeckModel(FPNC):
|
class FPNCNeckModel(FPNC):
|
||||||
|
|
||||||
|
60
tests/test_utils/test_dataset.py
Normal file
60
tests/test_utils/test_dataset.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from mmdeploy.utils.dataset import is_can_sort_dataset, sort_dataset
|
||||||
|
|
||||||
|
|
||||||
|
class DummyDataset:
    """Minimal stand-in for a dataset, used by the tests in this module.

    Args:
        data_infos (list[dict], optional): Per-sample info records. When the
            argument is falsy (``None`` or empty) the instance is created
            without a ``data_infos`` attribute at all.
    """

    def __init__(self, data_infos=None):
        # Guard clause: nothing to store, leave the attribute undefined.
        if not data_infos:
            return
        self.data_infos = data_infos
|
||||||
|
|
||||||
|
|
||||||
|
# Fixture without any ``data_infos`` attribute. The identifier's spelling is
# kept as is because the tests in this module refer to it by this name.
emtpy_dataset = DummyDataset()

# Fixture with four records covering every (height, width) combination of
# 0 and 1, deliberately listed out of sorted order.
dataset = DummyDataset([
    dict(id=0, height=0, width=0),
    dict(id=1, height=1, width=1),
    dict(id=2, height=1, width=0),
    dict(id=3, height=0, width=1),
])
|
||||||
|
|
||||||
|
|
||||||
|
class TestIsCanSortDataset:
    """Checks for ``is_can_sort_dataset`` against the module fixtures."""

    def test_is_can_sort_dataset_false(self):
        """The fixture created without ``data_infos`` is not sortable."""
        can_sort = is_can_sort_dataset(emtpy_dataset)
        assert not can_sort

    def test_is_can_sort_dataset_True(self):
        """The fixture whose records hold height and width is sortable."""
        can_sort = is_can_sort_dataset(dataset)
        assert can_sort
|
||||||
|
|
||||||
|
|
||||||
|
def test_sort_dataset():
    """``sort_dataset`` orders records by (height, width) and syncs ids."""
    # sort_dataset mutates its argument. Sort a copy so the module-level
    # fixture shared with the other tests is not reordered, and does not
    # gain an ``img_ids`` attribute, as a side effect of this test.
    unsorted = DummyDataset([dict(info) for info in dataset.data_infos])
    result_dataset = sort_dataset(unsorted)
    assert result_dataset.data_infos == [{
        'id': 0,
        'height': 0,
        'width': 0
    }, {
        'id': 3,
        'height': 0,
        'width': 1
    }, {
        'id': 2,
        'height': 1,
        'width': 0
    }, {
        'id': 1,
        'height': 1,
        'width': 1
    }]
    assert result_dataset.img_ids == [0, 3, 2, 1]
|
Loading…
x
Reference in New Issue
Block a user