mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Enhancement]: The ability to sort the dataset in tools/test.py. (#244)
* Add sort-data arg to test.py * Set is_sort_dataset to True. * Add a check for the possibility of sorting. * lint * Added mmdeploy.utils.dataset. * Add unit test Co-authored-by: SingleZombie <singlezombie@163.com>
This commit is contained in:
parent
ad72c19482
commit
3659b515eb
@ -1,4 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Any, Dict, Optional, Sequence, Tuple, Union
|
||||
|
||||
@ -8,6 +9,7 @@ import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from mmdeploy.utils import get_codebase
|
||||
from mmdeploy.utils.dataset import is_can_sort_dataset, sort_dataset
|
||||
|
||||
|
||||
class BaseTask(metaclass=ABCMeta):
|
||||
@ -66,6 +68,7 @@ class BaseTask(metaclass=ABCMeta):
|
||||
def build_dataset(self,
|
||||
dataset_cfg: Union[str, mmcv.Config],
|
||||
dataset_type: str = 'val',
|
||||
is_sort_dataset: bool = True,
|
||||
**kwargs) -> Dataset:
|
||||
"""Build dataset for different codebase.
|
||||
|
||||
@ -74,12 +77,22 @@ class BaseTask(metaclass=ABCMeta):
|
||||
object.
|
||||
dataset_type (str): Specifying dataset type, e.g.: 'train', 'test',
|
||||
'val', defaults to 'val'.
|
||||
is_sort_dataset (bool): When 'True', the dataset will be sorted
|
||||
by image shape in ascending order if 'dataset_cfg'
|
||||
contains information about height and width.
|
||||
|
||||
Returns:
|
||||
Dataset: The built dataset.
|
||||
"""
|
||||
return self.codebase_class.build_dataset(dataset_cfg, dataset_type,
|
||||
**kwargs)
|
||||
dataset = self.codebase_class.build_dataset(dataset_cfg, dataset_type,
|
||||
**kwargs)
|
||||
if is_sort_dataset:
|
||||
if is_can_sort_dataset(dataset):
|
||||
sort_dataset(dataset)
|
||||
else:
|
||||
logging.info('Sorting the dataset by \'height\' and \'width\' '
|
||||
'is not possible.')
|
||||
return dataset
|
||||
|
||||
def build_dataloader(self, dataset: Dataset, samples_per_gpu: int,
|
||||
workers_per_gpu: int, **kwargs) -> DataLoader:
|
||||
|
36
mmdeploy/utils/dataset.py
Normal file
36
mmdeploy/utils/dataset.py
Normal file
@ -0,0 +1,36 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
def is_can_sort_dataset(dataset: Dataset) -> bool:
|
||||
"""Checking for the possibility of sorting the dataset by fields 'height'
|
||||
and 'width'.
|
||||
|
||||
Args:
|
||||
dataset (Dataset): The dataset.
|
||||
|
||||
Returns:
|
||||
bool: Is it possible or not to sort the dataset.
|
||||
"""
|
||||
is_sort_possible = \
|
||||
hasattr(dataset, 'data_infos') and \
|
||||
dataset.data_infos and \
|
||||
all(key in dataset.data_infos[0] for key in ('height', 'width'))
|
||||
return is_sort_possible
|
||||
|
||||
|
||||
def sort_dataset(dataset: Dataset) -> Dataset:
|
||||
"""Sorts the dataset by image height and width.
|
||||
|
||||
Args:
|
||||
dataset (Dataset): The dataset.
|
||||
|
||||
Returns:
|
||||
Dataset: Sorted dataset.
|
||||
"""
|
||||
sort_data_infos = sorted(
|
||||
dataset.data_infos, key=lambda e: (e['height'], e['width']))
|
||||
sort_img_ids = [e['id'] for e in sort_data_infos]
|
||||
dataset.data_infos = sort_data_infos
|
||||
dataset.img_ids = sort_img_ids
|
||||
return dataset
|
@ -7,15 +7,11 @@ import pytest
|
||||
import torch
|
||||
from mmocr.models.textdet.necks import FPNC
|
||||
|
||||
from mmdeploy.apis.onnxruntime import is_available as ort_available
|
||||
from mmdeploy.core import RewriterContext, patch_model
|
||||
from mmdeploy.utils import Backend
|
||||
from mmdeploy.utils.test import (WrapModel, check_backend, get_model_outputs,
|
||||
get_rewrite_outputs)
|
||||
|
||||
onnxruntime_skip = not ort_available()
|
||||
cuda_skip = not torch.cuda.is_available()
|
||||
|
||||
|
||||
class FPNCNeckModel(FPNC):
|
||||
|
||||
|
60
tests/test_utils/test_dataset.py
Normal file
60
tests/test_utils/test_dataset.py
Normal file
@ -0,0 +1,60 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmdeploy.utils.dataset import is_can_sort_dataset, sort_dataset
|
||||
|
||||
|
||||
class DummyDataset():
|
||||
|
||||
def __init__(self, data_infos=None):
|
||||
if data_infos:
|
||||
self.data_infos = data_infos
|
||||
|
||||
|
||||
emtpy_dataset = DummyDataset()
|
||||
dataset = DummyDataset([{
|
||||
'id': 0,
|
||||
'height': 0,
|
||||
'width': 0
|
||||
}, {
|
||||
'id': 1,
|
||||
'height': 1,
|
||||
'width': 1
|
||||
}, {
|
||||
'id': 2,
|
||||
'height': 1,
|
||||
'width': 0
|
||||
}, {
|
||||
'id': 3,
|
||||
'height': 0,
|
||||
'width': 1
|
||||
}])
|
||||
|
||||
|
||||
class TestIsCanSortDataset:
|
||||
|
||||
def test_is_can_sort_dataset_false(self):
|
||||
assert not is_can_sort_dataset(emtpy_dataset)
|
||||
|
||||
def test_is_can_sort_dataset_True(self):
|
||||
assert is_can_sort_dataset(dataset)
|
||||
|
||||
|
||||
def test_sort_dataset():
|
||||
result_dataset = sort_dataset(dataset)
|
||||
assert result_dataset.data_infos == [{
|
||||
'id': 0,
|
||||
'height': 0,
|
||||
'width': 0
|
||||
}, {
|
||||
'id': 3,
|
||||
'height': 0,
|
||||
'width': 1
|
||||
}, {
|
||||
'id': 2,
|
||||
'height': 1,
|
||||
'width': 0
|
||||
}, {
|
||||
'id': 1,
|
||||
'height': 1,
|
||||
'width': 1
|
||||
}]
|
||||
assert result_dataset.img_ids == [0, 3, 2, 1]
|
Loading…
x
Reference in New Issue
Block a user