2022-04-02 20:01:06 +08:00

132 lines
3.9 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import os

import numpy as np

from easycv.datasets.registry import DATASOURCES
from easycv.file import io
from easycv.utils.bbox_util import batched_cxcywh2xyxy_with_shape
from .voc import DetSourceVOC
# File extensions (with the leading dot) accepted as images. Candidates are
# lower-cased before the membership test, so entries here must be lower-case.
img_formats = [
    '.bmp',
    '.jpg',
    '.jpeg',
    '.png',
    '.tif',
    '.tiff',
    '.dng',
]

# File extensions (with the leading dot) accepted as label files.
label_formats = [
    '.txt',
]
@DATASOURCES.register_module
class DetSourceRaw(DetSourceVOC):
    """Detection data source reading images and plain-text label files.

    The data dir is laid out as follows:

    ```
    |- data_dir
        |-images
            |-1.jpg
            |-...
        |-labels
            |-1.txt
            |-...
    ```

    An image and a label file belong together when their paths relative to
    ``img_root_path`` / ``label_root_path`` are equal once the file extension
    is removed (``images/a/1.jpg`` <-> ``labels/a/1.txt``).

    Each label txt file holds one object per line. The first column is the
    label id, and columns 2 to 5 are coordinates relative to the image width
    and height ``[x_center, y_center, bbox_w, bbox_h]``:

    ```
    15 0.519398 0.544087 0.476359 0.572061
    2 0.501859 0.820726 0.996281 0.332178
    ...
    ```

    Example:
        data_source = DetSourceRaw(
            img_root_path='/your/data_dir/images',
            label_root_path='/your/data_dir/labels',
        )
    """

    def __init__(self,
                 img_root_path,
                 label_root_path,
                 cache_at_init=False,
                 cache_on_the_fly=False,
                 delimeter=' ',
                 **kwargs):
        """
        Args:
            img_root_path: images dir path
            label_root_path: labels dir path
            cache_at_init: if set True, will cache in memory in __init__ for
                faster training
            cache_on_the_fly: if set True, will cache in memory during
                training
            delimeter: column separator used inside the label files
            **kwargs: accepted for config compatibility and ignored here

        Raises:
            AssertionError: if no image or no label file is found.
            ValueError: if no image has a label file with a matching name.
        """
        self.cache_on_the_fly = cache_on_the_fly
        self.cache_at_init = cache_at_init
        self.delimeter = delimeter
        self.img_root_path = img_root_path
        self.label_root_path = label_root_path

        img_entries = self._list_entries(self.img_root_path, img_formats)
        assert len(
            img_entries) > 0, 'No images found in %s' % self.img_root_path

        label_entries = self._list_entries(self.label_root_path,
                                           label_formats)
        assert len(label_entries
                   ) > 0, 'No labels found in %s.' % self.label_root_path

        # Pair by name instead of by position in the two sorted lists: a
        # positional zip silently mislabels every following sample as soon as
        # one file is missing, and the two sort orders can differ anyway
        # because the extension takes part in the comparison.
        pairs, num_imgs_without_label, num_unused_labels = \
            self._pair_by_stem(img_entries, label_entries)

        if num_imgs_without_label:
            logging.warning(
                'Skipped %d image(s) in %s without a matching label file '
                'in %s', num_imgs_without_label, self.img_root_path,
                self.label_root_path)
        if num_unused_labels:
            logging.warning(
                'Ignored %d label file(s) in %s without a matching image '
                'in %s', num_unused_labels, self.label_root_path,
                self.img_root_path)
        if not pairs:
            raise ValueError(
                'No image in %s has a label file with the same name in %s' %
                (self.img_root_path, self.label_root_path))

        # Both lists are index-aligned: label_files[i] annotates img_files[i].
        self.img_files = [img_path for img_path, _ in pairs]
        self.label_files = [label_path for _, label_path in pairs]

        # TODO: filter samples whose label content is malformed
        self.samples_list = self.build_samples(pairs)

    @staticmethod
    def _list_entries(root_path, formats):
        """List the files under ``root_path`` whose extension is in ``formats``.

        Returns:
            A list of ``(full_path, stem)`` tuples sorted by ``full_path``.
            ``stem`` is the entry returned by ``io.listdir`` with its
            extension removed; it is the key used to pair images and labels.
        """
        entries = []
        for rel_path in io.listdir(root_path, recursive=True):
            stem, ext = os.path.splitext(rel_path)
            if ext.lower() in formats:
                entries.append((os.path.join(root_path, rel_path), stem))
        return sorted(entries)

    @staticmethod
    def _pair_by_stem(img_entries, label_entries):
        """Match every image with the label file sharing its stem.

        Args:
            img_entries: ``(full_path, stem)`` tuples of the images.
            label_entries: ``(full_path, stem)`` tuples of the label files.

        Returns:
            ``(pairs, num_imgs_without_label, num_unused_labels)`` where
            ``pairs`` is a list of ``(img_path, label_path)`` tuples in the
            order of ``img_entries``.
        """
        label_by_stem = {}
        for label_path, stem in label_entries:
            # Entries are sorted, so the first path wins deterministically
            # should several label extensions ever share one stem.
            label_by_stem.setdefault(stem, label_path)

        pairs = []
        used_stems = set()
        num_imgs_without_label = 0
        for img_path, stem in img_entries:
            label_path = label_by_stem.get(stem)
            if label_path is None:
                num_imgs_without_label += 1
                continue
            used_stems.add(stem)
            pairs.append((img_path, label_path))

        num_unused_labels = len(label_by_stem) - len(used_stems)
        return pairs, num_imgs_without_label, num_unused_labels

    def get_source_info(self, img_and_label):
        """Parse one label file into the annotation dict of its image.

        Args:
            img_and_label: sequence whose first item is the image path and
                whose second item is the label file path.

        Returns:
            ``{}`` when the label file contains no annotation line, otherwise
            a dict with:
                filename: the image path
                gt_bboxes: float32 array built from columns 2..end, still in
                    the relative format stored in the label file
                gt_labels: int64 array built from the first column

        Raises:
            ValueError: if the lines of the label file do not all have the
                same number of columns, or a field is not numeric.
        """
        img_path = img_and_label[0]
        label_path = img_and_label[1]

        with io.open(label_path, 'r') as f:
            lines = f.read().splitlines()

        rows = []
        for line in lines:
            # Dropping empty tokens tolerates blank lines, trailing or
            # repeated delimiters and stray '\r' characters, all of which
            # would otherwise produce a ragged array.
            fields = [tok.strip() for tok in line.split(self.delimeter)]
            fields = [tok for tok in fields if tok]
            if fields:
                rows.append(fields)

        if not rows:
            # Callers treat an empty dict as "no sample for this image".
            return {}

        num_columns = len(rows[0])
        if any(len(row) != num_columns for row in rows):
            raise ValueError(
                'Inconsistent number of columns in label file %s' %
                label_path)

        labels_and_boxes = np.array(rows)
        labels = labels_and_boxes[:, 0]
        bboxes = labels_and_boxes[:, 1:]

        return {
            'filename': img_path,
            'gt_bboxes': np.array(bboxes, dtype=np.float32),
            'gt_labels': labels.astype(np.int64)
        }

    def _build_sample_from_source_info(self, source_info):
        """Load the image of ``source_info`` and finalize its boxes.

        Args:
            source_info: dict produced by ``get_source_info``.

        Returns:
            ``{}`` when ``source_info`` has no ``filename`` key; otherwise
            ``source_info`` itself, extended with the entries returned by
            ``self.load_image``, with ``img_fields`` / ``bbox_fields`` and
            with ``gt_bboxes`` replaced by the output of
            ``batched_cxcywh2xyxy_with_shape``.
        """
        if 'filename' not in source_info:
            return {}

        # NOTE(review): the input dict is updated in place and returned, so
        # calling this twice on the same dict would convert gt_bboxes twice.
        # Kept as-is because the parent class may rely on the in-place
        # update -- confirm before switching to a copy.
        result_dict = source_info

        # load_image is not defined in this class; presumably inherited from
        # DetSourceVOC. It must provide at least 'img_shape'.
        img_info = self.load_image(source_info['filename'])
        result_dict.update(img_info)
        result_dict.update({
            'img_fields': ['img'],
            'bbox_fields': ['gt_bboxes']
        })

        # shape: h, w
        result_dict['gt_bboxes'] = batched_cxcywh2xyxy_with_shape(
            result_dict['gt_bboxes'], shape=img_info['img_shape'][:2])

        return result_dict