[Enhancement]: The ability to sort the dataset in tools/test.py. (#244)

* Add sort-data arg to test.py

* Set is_sort_dataset to True.

* Add a check for the possibility of sorting.

* lint

* Added mmdeploy.utils.dataset.

* Add unit test

Co-authored-by: SingleZombie <singlezombie@163.com>
This commit is contained in:
Semyon Bevzyuk 2021-12-09 11:21:50 +03:00 committed by GitHub
parent ad72c19482
commit 3659b515eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 111 additions and 6 deletions

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, Optional, Sequence, Tuple, Union
@ -8,6 +9,7 @@ import torch
from torch.utils.data import DataLoader, Dataset
from mmdeploy.utils import get_codebase
from mmdeploy.utils.dataset import is_can_sort_dataset, sort_dataset
class BaseTask(metaclass=ABCMeta):
@ -66,6 +68,7 @@ class BaseTask(metaclass=ABCMeta):
def build_dataset(self,
dataset_cfg: Union[str, mmcv.Config],
dataset_type: str = 'val',
is_sort_dataset: bool = True,
**kwargs) -> Dataset:
"""Build dataset for different codebase.
@ -74,12 +77,22 @@ class BaseTask(metaclass=ABCMeta):
object.
dataset_type (str): Specifying dataset type, e.g.: 'train', 'test',
'val', defaults to 'val'.
is_sort_dataset (bool): When 'True', the dataset will be sorted
by image shape in ascending order if 'dataset_cfg'
contains information about height and width.
Returns:
Dataset: The built dataset.
"""
return self.codebase_class.build_dataset(dataset_cfg, dataset_type,
**kwargs)
dataset = self.codebase_class.build_dataset(dataset_cfg, dataset_type,
**kwargs)
if is_sort_dataset:
if is_can_sort_dataset(dataset):
sort_dataset(dataset)
else:
logging.info('Sorting the dataset by \'height\' and \'width\' '
'is not possible.')
return dataset
def build_dataloader(self, dataset: Dataset, samples_per_gpu: int,
workers_per_gpu: int, **kwargs) -> DataLoader:

36
mmdeploy/utils/dataset.py Normal file
View File

@ -0,0 +1,36 @@
# Copyright (c) OpenMMLab. All rights reserved.
from torch.utils.data import Dataset
def is_can_sort_dataset(dataset: Dataset) -> bool:
"""Checking for the possibility of sorting the dataset by fields 'height'
and 'width'.
Args:
dataset (Dataset): The dataset.
Returns:
bool: Is it possible or not to sort the dataset.
"""
is_sort_possible = \
hasattr(dataset, 'data_infos') and \
dataset.data_infos and \
all(key in dataset.data_infos[0] for key in ('height', 'width'))
return is_sort_possible
def sort_dataset(dataset: Dataset) -> Dataset:
"""Sorts the dataset by image height and width.
Args:
dataset (Dataset): The dataset.
Returns:
Dataset: Sorted dataset.
"""
sort_data_infos = sorted(
dataset.data_infos, key=lambda e: (e['height'], e['width']))
sort_img_ids = [e['id'] for e in sort_data_infos]
dataset.data_infos = sort_data_infos
dataset.img_ids = sort_img_ids
return dataset

View File

@ -7,15 +7,11 @@ import pytest
import torch
from mmocr.models.textdet.necks import FPNC
from mmdeploy.apis.onnxruntime import is_available as ort_available
from mmdeploy.core import RewriterContext, patch_model
from mmdeploy.utils import Backend
from mmdeploy.utils.test import (WrapModel, check_backend, get_model_outputs,
get_rewrite_outputs)
onnxruntime_skip = not ort_available()
cuda_skip = not torch.cuda.is_available()
class FPNCNeckModel(FPNC):

View File

@ -0,0 +1,60 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.utils.dataset import is_can_sort_dataset, sort_dataset
class DummyDataset():
def __init__(self, data_infos=None):
if data_infos:
self.data_infos = data_infos
emtpy_dataset = DummyDataset()
dataset = DummyDataset([{
'id': 0,
'height': 0,
'width': 0
}, {
'id': 1,
'height': 1,
'width': 1
}, {
'id': 2,
'height': 1,
'width': 0
}, {
'id': 3,
'height': 0,
'width': 1
}])
class TestIsCanSortDataset:
def test_is_can_sort_dataset_false(self):
assert not is_can_sort_dataset(emtpy_dataset)
def test_is_can_sort_dataset_True(self):
assert is_can_sort_dataset(dataset)
def test_sort_dataset():
result_dataset = sort_dataset(dataset)
assert result_dataset.data_infos == [{
'id': 0,
'height': 0,
'width': 0
}, {
'id': 3,
'height': 0,
'width': 1
}, {
'id': 2,
'height': 1,
'width': 0
}, {
'id': 1,
'height': 1,
'width': 1
}]
assert result_dataset.img_ids == [0, 3, 2, 1]