From 3659b515eb96e7283d7ea5ef2637ce2d2875ab4d Mon Sep 17 00:00:00 2001 From: Semyon Bevzyuk Date: Thu, 9 Dec 2021 11:21:50 +0300 Subject: [PATCH] [Enhancement]: The ability to sort the dataset in tools/test.py. (#244) * Add sort-data arg to test.py * Set is_sort_dataset to True. * Add a check for the possibility of sorting. * lint * Added mmdeploy.utils.dataset. * Add unit test Co-authored-by: SingleZombie --- mmdeploy/codebase/base/task.py | 17 +++++- mmdeploy/utils/dataset.py | 36 +++++++++++ .../test_mmocr/test_mmocr_models.py | 4 -- tests/test_utils/test_dataset.py | 60 +++++++++++++++++++ 4 files changed, 111 insertions(+), 6 deletions(-) create mode 100644 mmdeploy/utils/dataset.py create mode 100644 tests/test_utils/test_dataset.py diff --git a/mmdeploy/codebase/base/task.py b/mmdeploy/codebase/base/task.py index e557f8e48..19409b0bd 100644 --- a/mmdeploy/codebase/base/task.py +++ b/mmdeploy/codebase/base/task.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import logging from abc import ABCMeta, abstractmethod from typing import Any, Dict, Optional, Sequence, Tuple, Union @@ -8,6 +9,7 @@ import torch from torch.utils.data import DataLoader, Dataset from mmdeploy.utils import get_codebase +from mmdeploy.utils.dataset import is_can_sort_dataset, sort_dataset class BaseTask(metaclass=ABCMeta): @@ -66,6 +68,7 @@ class BaseTask(metaclass=ABCMeta): def build_dataset(self, dataset_cfg: Union[str, mmcv.Config], dataset_type: str = 'val', + is_sort_dataset: bool = True, **kwargs) -> Dataset: """Build dataset for different codebase. @@ -74,12 +77,22 @@ class BaseTask(metaclass=ABCMeta): object. dataset_type (str): Specifying dataset type, e.g.: 'train', 'test', 'val', defaults to 'val'. + is_sort_dataset (bool): When 'True', the dataset will be sorted + by image shape in ascending order if 'dataset_cfg' + contains information about height and width. Returns: Dataset: The built dataset. """ - return self.codebase_class.build_dataset(dataset_cfg, dataset_type, - **kwargs) + dataset = self.codebase_class.build_dataset(dataset_cfg, dataset_type, + **kwargs) + if is_sort_dataset: + if is_can_sort_dataset(dataset): + sort_dataset(dataset) + else: + logging.info('Sorting the dataset by \'height\' and \'width\' ' + 'is not possible.') + return dataset def build_dataloader(self, dataset: Dataset, samples_per_gpu: int, workers_per_gpu: int, **kwargs) -> DataLoader: diff --git a/mmdeploy/utils/dataset.py b/mmdeploy/utils/dataset.py new file mode 100644 index 000000000..c4bd2f734 --- /dev/null +++ b/mmdeploy/utils/dataset.py @@ -0,0 +1,36 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from torch.utils.data import Dataset + + +def is_can_sort_dataset(dataset: Dataset) -> bool: + """Checking for the possibility of sorting the dataset by fields 'height' + and 'width'. + + Args: + dataset (Dataset): The dataset. + + Returns: + bool: Is it possible or not to sort the dataset. + """ + is_sort_possible = \ + hasattr(dataset, 'data_infos') and \ + dataset.data_infos and \ + all(key in dataset.data_infos[0] for key in ('height', 'width')) + return is_sort_possible + + +def sort_dataset(dataset: Dataset) -> Dataset: + """Sorts the dataset by image height and width. + + Args: + dataset (Dataset): The dataset. + + Returns: + Dataset: Sorted dataset. + """ + sort_data_infos = sorted( + dataset.data_infos, key=lambda e: (e['height'], e['width'])) + sort_img_ids = [e['id'] for e in sort_data_infos] + dataset.data_infos = sort_data_infos + dataset.img_ids = sort_img_ids + return dataset diff --git a/tests/test_codebase/test_mmocr/test_mmocr_models.py b/tests/test_codebase/test_mmocr/test_mmocr_models.py index 93ad7a34e..15e812ba2 100644 --- a/tests/test_codebase/test_mmocr/test_mmocr_models.py +++ b/tests/test_codebase/test_mmocr/test_mmocr_models.py @@ -7,15 +7,11 @@ import pytest import torch from mmocr.models.textdet.necks import FPNC -from mmdeploy.apis.onnxruntime import is_available as ort_available from mmdeploy.core import RewriterContext, patch_model from mmdeploy.utils import Backend from mmdeploy.utils.test import (WrapModel, check_backend, get_model_outputs, get_rewrite_outputs) -onnxruntime_skip = not ort_available() -cuda_skip = not torch.cuda.is_available() - class FPNCNeckModel(FPNC): diff --git a/tests/test_utils/test_dataset.py b/tests/test_utils/test_dataset.py new file mode 100644 index 000000000..877b8b6f4 --- /dev/null +++ b/tests/test_utils/test_dataset.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdeploy.utils.dataset import is_can_sort_dataset, sort_dataset + + +class DummyDataset(): + + def __init__(self, data_infos=None): + if data_infos: + self.data_infos = data_infos + + +emtpy_dataset = DummyDataset() +dataset = DummyDataset([{ + 'id': 0, + 'height': 0, + 'width': 0 +}, { + 'id': 1, + 'height': 1, + 'width': 1 +}, { + 'id': 2, + 'height': 1, + 'width': 0 +}, { + 'id': 3, + 'height': 0, + 'width': 1 +}]) + + +class TestIsCanSortDataset: + + def test_is_can_sort_dataset_false(self): + assert not is_can_sort_dataset(emtpy_dataset) + + def test_is_can_sort_dataset_True(self): + assert is_can_sort_dataset(dataset) + + +def test_sort_dataset(): + result_dataset = sort_dataset(dataset) + assert result_dataset.data_infos == [{ + 'id': 0, + 'height': 0, + 'width': 0 + }, { + 'id': 3, + 'height': 0, + 'width': 1 + }, { + 'id': 2, + 'height': 1, + 'width': 0 + }, { + 'id': 1, + 'height': 1, + 'width': 1 + }] + assert result_dataset.img_ids == [0, 3, 2, 1]