[Refactor] refactor file io and add ut (#582)

* remove fileclient

* add ut
pull/603/head
Yixiao Fang 2022-11-18 16:44:54 +08:00 committed by GitHub
parent 1835af3a91
commit a4e9f8f0a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 56 additions and 25 deletions

View File

@ -3,7 +3,7 @@ from typing import List, Optional, Union
import numpy as np
from mmcls.datasets import CustomDataset
from mmengine import FileClient
from mmengine.fileio import join_path
from mmselfsup.registry import DATASETS
@ -48,7 +48,7 @@ class ImageList(CustomDataset):
...
Args:
ann_file (str): Annotation file path. Defaults to None.
ann_file (str): Annotation file path.
metainfo (dict, optional): Meta information for dataset, such as class
information. Defaults to None.
data_root (str): The root directory for ``data_prefix`` and
@ -62,7 +62,7 @@ class ImageList(CustomDataset):
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')
def __init__(self,
ann_file: str = '',
ann_file: str,
metainfo: Optional[dict] = None,
data_root: str = '',
data_prefix: Union[str, dict] = '',
@ -76,32 +76,24 @@ class ImageList(CustomDataset):
**kwargs)
def load_data_list(self) -> List[dict]:
    """Load samples from a single annotation file, labeled or unlabeled.

    Each non-blank line of ``self.ann_file`` is split on whitespace. The
    first field is the image file name; when the first non-blank line has
    exactly two fields, the second field of every line is read as the
    label. Otherwise every sample gets the placeholder label ``-1``.

    Side effects: stores the raw lines in ``self.samples`` and the
    labeled/unlabeled flag in ``self.has_labels``.

    Returns:
        List[dict]: A list of data information. Each item has the keys
        ``img_prefix``, ``img_path``, ``img_info`` and ``gt_label``.

    Raises:
        AssertionError: If ``self.ann_file`` is an empty string.
    """
    # Kept as an assert: the unit test for this dataset expects an
    # AssertionError when no annotation file is given.
    assert self.ann_file != ''

    with open(self.ann_file, 'r') as f:
        self.samples = f.readlines()

    # Blank lines carry no sample; dropping them avoids an IndexError on
    # ``fields[0]`` (e.g. a trailing empty line at the end of the file).
    rows = [fields for fields in map(str.split, self.samples) if fields]

    # An annotation file without any sample is treated as unlabeled and
    # yields an empty list instead of failing on ``rows[0]``.
    self.has_labels = bool(rows) and len(rows[0]) == 2

    data_list = []
    for fields in rows:
        info = {'img_prefix': self.img_prefix}
        info['img_path'] = join_path(self.img_prefix, fields[0])
        info['img_info'] = {'filename': fields[0]}
        labels = fields[1] if self.has_labels else -1
        info['gt_label'] = np.array(labels, dtype=np.int64)
        data_list.append(info)
    return data_list

View File

@ -0,0 +1,2 @@
color.jpg
gray.jpg

View File

@ -0,0 +1,37 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import numpy as np
import pytest
from mmselfsup.datasets import ImageList
from mmselfsup.utils import register_all_modules
# Pipeline config handed to the dataset under test: load the image from
# disk, then apply a random resized crop of size 4.
train_pipeline = [
    {'type': 'LoadImageFromFile', 'file_client_args': {'backend': 'disk'}},
    {'type': 'RandomResizedCrop', 'size': 4},
]
def test_image_list_dataset():
    """Check ``ImageList`` with an empty and with an unlabeled annotation file.

    An empty ``ann_file`` must raise ``AssertionError``. The annotation
    file ``data_list_no_label.txt`` must yield two samples, and the first
    sample's ``gt_label`` must equal ``-1``.
    """
    register_all_modules()
    data_root = osp.join(osp.dirname(__file__), '..', 'data')

    # Case 1: an empty annotation path is rejected at construction time.
    empty_cfg = {
        'ann_file': '',
        'metainfo': None,
        'data_root': data_root,
        'pipeline': train_pipeline,
    }
    with pytest.raises(AssertionError):
        ImageList(**empty_cfg)

    # Case 2: an annotation file whose lines carry no label column.
    unlabeled_cfg = {
        'ann_file': osp.join(data_root, 'data_list_no_label.txt'),
        'metainfo': None,
        'data_root': data_root,
        'pipeline': train_pipeline,
    }
    dataset = ImageList(**unlabeled_cfg)
    assert len(dataset) == 2
    assert dataset[0]['gt_label'] == np.array(-1)