[Enhancement]: The ability to sort the dataset in tools/test.py. (#244)

* Add sort-data arg to test.py

* Set is_sort_dataset to True.

* Add a check for the possibility of sorting.

* lint

* Added mmdeploy.utils.dataset.

* Add unit test

Co-authored-by: SingleZombie <singlezombie@163.com>
This commit is contained in:
Semyon Bevzyuk 2021-12-09 11:21:50 +03:00 committed by GitHub
parent ad72c19482
commit 3659b515eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 111 additions and 6 deletions

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import logging
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from typing import Any, Dict, Optional, Sequence, Tuple, Union from typing import Any, Dict, Optional, Sequence, Tuple, Union
@ -8,6 +9,7 @@ import torch
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
from mmdeploy.utils import get_codebase from mmdeploy.utils import get_codebase
from mmdeploy.utils.dataset import is_can_sort_dataset, sort_dataset
class BaseTask(metaclass=ABCMeta): class BaseTask(metaclass=ABCMeta):
@ -66,6 +68,7 @@ class BaseTask(metaclass=ABCMeta):
def build_dataset(self, def build_dataset(self,
dataset_cfg: Union[str, mmcv.Config], dataset_cfg: Union[str, mmcv.Config],
dataset_type: str = 'val', dataset_type: str = 'val',
is_sort_dataset: bool = True,
**kwargs) -> Dataset: **kwargs) -> Dataset:
"""Build dataset for different codebase. """Build dataset for different codebase.
@ -74,12 +77,22 @@ class BaseTask(metaclass=ABCMeta):
object. object.
dataset_type (str): Specifying dataset type, e.g.: 'train', 'test', dataset_type (str): Specifying dataset type, e.g.: 'train', 'test',
'val', defaults to 'val'. 'val', defaults to 'val'.
is_sort_dataset (bool): When 'True', the dataset will be sorted
by image shape in ascending order if 'dataset_cfg'
contains information about height and width.
Returns: Returns:
Dataset: The built dataset. Dataset: The built dataset.
""" """
return self.codebase_class.build_dataset(dataset_cfg, dataset_type, dataset = self.codebase_class.build_dataset(dataset_cfg, dataset_type,
**kwargs) **kwargs)
if is_sort_dataset:
if is_can_sort_dataset(dataset):
sort_dataset(dataset)
else:
logging.info('Sorting the dataset by \'height\' and \'width\' '
'is not possible.')
return dataset
def build_dataloader(self, dataset: Dataset, samples_per_gpu: int, def build_dataloader(self, dataset: Dataset, samples_per_gpu: int,
workers_per_gpu: int, **kwargs) -> DataLoader: workers_per_gpu: int, **kwargs) -> DataLoader:

36
mmdeploy/utils/dataset.py Normal file
View File

@ -0,0 +1,36 @@
# Copyright (c) OpenMMLab. All rights reserved.
from torch.utils.data import Dataset
def is_can_sort_dataset(dataset: Dataset) -> bool:
"""Checking for the possibility of sorting the dataset by fields 'height'
and 'width'.
Args:
dataset (Dataset): The dataset.
Returns:
bool: Is it possible or not to sort the dataset.
"""
is_sort_possible = \
hasattr(dataset, 'data_infos') and \
dataset.data_infos and \
all(key in dataset.data_infos[0] for key in ('height', 'width'))
return is_sort_possible
def sort_dataset(dataset: Dataset) -> Dataset:
"""Sorts the dataset by image height and width.
Args:
dataset (Dataset): The dataset.
Returns:
Dataset: Sorted dataset.
"""
sort_data_infos = sorted(
dataset.data_infos, key=lambda e: (e['height'], e['width']))
sort_img_ids = [e['id'] for e in sort_data_infos]
dataset.data_infos = sort_data_infos
dataset.img_ids = sort_img_ids
return dataset

View File

@ -7,15 +7,11 @@ import pytest
import torch import torch
from mmocr.models.textdet.necks import FPNC from mmocr.models.textdet.necks import FPNC
from mmdeploy.apis.onnxruntime import is_available as ort_available
from mmdeploy.core import RewriterContext, patch_model from mmdeploy.core import RewriterContext, patch_model
from mmdeploy.utils import Backend from mmdeploy.utils import Backend
from mmdeploy.utils.test import (WrapModel, check_backend, get_model_outputs, from mmdeploy.utils.test import (WrapModel, check_backend, get_model_outputs,
get_rewrite_outputs) get_rewrite_outputs)
onnxruntime_skip = not ort_available()
cuda_skip = not torch.cuda.is_available()
class FPNCNeckModel(FPNC): class FPNCNeckModel(FPNC):

View File

@ -0,0 +1,60 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.utils.dataset import is_can_sort_dataset, sort_dataset
class DummyDataset():
def __init__(self, data_infos=None):
if data_infos:
self.data_infos = data_infos
emtpy_dataset = DummyDataset()
dataset = DummyDataset([{
'id': 0,
'height': 0,
'width': 0
}, {
'id': 1,
'height': 1,
'width': 1
}, {
'id': 2,
'height': 1,
'width': 0
}, {
'id': 3,
'height': 0,
'width': 1
}])
class TestIsCanSortDataset:
def test_is_can_sort_dataset_false(self):
assert not is_can_sort_dataset(emtpy_dataset)
def test_is_can_sort_dataset_True(self):
assert is_can_sort_dataset(dataset)
def test_sort_dataset():
result_dataset = sort_dataset(dataset)
assert result_dataset.data_infos == [{
'id': 0,
'height': 0,
'width': 0
}, {
'id': 3,
'height': 0,
'width': 1
}, {
'id': 2,
'height': 1,
'width': 0
}, {
'id': 1,
'height': 1,
'width': 1
}]
assert result_dataset.img_ids == [0, 3, 2, 1]