mirror of
https://github.com/alibaba/EasyCV.git
synced 2025-06-03 14:49:00 +08:00
262 lines
8.6 KiB
Python
262 lines
8.6 KiB
Python
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||
|
import logging
|
||
|
import os
|
||
|
import time
|
||
|
import xml.etree.ElementTree as ET
|
||
|
from multiprocessing import Pool, cpu_count
|
||
|
|
||
|
import cv2
|
||
|
import numpy as np
|
||
|
from mmcv.runner.dist_utils import get_dist_info
|
||
|
from PIL import Image
|
||
|
from tqdm import tqdm
|
||
|
|
||
|
from easycv.datasets.registry import DATASOURCES
|
||
|
from easycv.file import io
|
||
|
from easycv.utils.constant import MAX_READ_IMAGE_TRY_TIMES
|
||
|
|
||
|
# Image file extensions (lower-case, with leading dot). Not referenced by the
# code visible in this file -- presumably kept for external users; TODO confirm.
img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.tiff', '.dng']
|
||
|
|
||
|
|
||
|
def parse_xml(xml_path, classes):
    """Parse a PASCAL VOC style xml annotation file into boxes and labels.

    Objects marked as difficult are skipped silently. Objects whose class
    name is not in `classes` are skipped with a warning.

    Args:
        xml_path: path of the xml annotation file, opened with `io.open`.
        classes: sequence of class names; the position of a name in this
            sequence is used as its label id.

    Returns:
        dict with two keys:
            gt_bboxes: float32 array of shape (n, 4), each row being
                [xmin, ymin, xmax, ymax] as written in the file.
            gt_labels: int64 array of shape (n,).
        `n` is 0 when no object is kept.
    """
    with io.open(xml_path, 'r') as f:
        tree = ET.parse(f)
    root = tree.getroot()

    gt_bboxes = []
    gt_labels = []
    for obj in root.iter('object'):
        # A missing or empty <difficult> tag is treated as "not difficult"
        # instead of crashing on `None.text` / `int(None)`.
        difficult = (obj.findtext('difficult') or '').strip()
        if difficult and int(difficult) == 1:
            continue

        cls = obj.find('name').text
        if cls not in classes:
            logging.warning(
                'class: %s not in given class list, skip the object!' % cls)
            continue

        xmlbox = obj.find('bndbox')
        box = tuple(
            float(xmlbox.find(tag).text)
            for tag in ('xmin', 'ymin', 'xmax', 'ymax'))
        gt_bboxes.append(box)
        gt_labels.append(classes.index(cls))

    img_info = {
        # reshape keeps the (n, 4) layout even when no box was kept, so empty
        # and non-empty annotations can be handled uniformly downstream.
        'gt_bboxes': np.array(gt_bboxes, dtype=np.float32).reshape(-1, 4),
        'gt_labels': np.array(gt_labels, dtype=np.int64)
    }

    return img_info
|
||
|
|
||
|
|
||
|
@DATASOURCES.register_module
class DetSourceVOC(object):
    """Detection data source for datasets stored in the PASCAL VOC layout.

    data dir is as follows:
    ```
    |- voc_data
        |-ImageSets
            |-Main
                |-train.txt
                |-...
        |-JPEGImages
            |-00001.jpg
            |-...
        |-Annotations
            |-00001.xml
            |-...

    ```
    Example1:
        data_source = DetSourceVOC(
            path='/your/voc_data/ImageSets/Main/train.txt',
            classes=${VOC_CLASSES},
        )
    Example2:
        data_source = DetSourceVOC(
            path='/your/voc_data/train.txt',
            classes=${VOC_CLASSES},
            img_root_path='/your/voc_data/images',
            label_root_path='/your/voc_data/annotations'
        )
    """

    def __init__(self,
                 path,
                 classes=None,
                 img_root_path=None,
                 label_root_path=None,
                 cache_at_init=False,
                 cache_on_the_fly=False,
                 img_suffix='.jpg',
                 label_suffix='.xml',
                 **kwargs):
        """
        Args:
            path: path of img id list file in ImageSets/Main/
            classes: classes list, defaults to an empty list
            img_root_path: image dir path, if None, default to detect the image dir by the relative path of the `path`
                according to the VOC data format.
            label_root_path: label dir path, if None, default to detect the label dir by the relative path of the `path`
                according to the VOC data format.
            cache_at_init: if set True, will cache in memory in __init__ for faster training
            cache_on_the_fly: if set True, will cache in memory during training
            img_suffix: suffix of image file
            label_suffix: suffix of label file
            **kwargs: accepted for config compatibility and ignored
        """
        # A fresh list per instance instead of a shared mutable default.
        self.CLASSES = [] if classes is None else classes
        self.rank, self.world_size = get_dist_info()
        self.path = path
        self.img_root_path = img_root_path
        self.label_root_path = label_root_path
        self.cache_at_init = cache_at_init
        self.cache_on_the_fly = cache_on_the_fly

        # Fall back to the standard VOC sibling dirs of `ImageSets/Main`.
        if not self.img_root_path:
            self.img_root_path = os.path.join(
                self.path.split('ImageSets/Main')[0], 'JPEGImages')
        if not self.label_root_path:
            self.label_root_path = os.path.join(
                self.path.split('ImageSets/Main')[0], 'Annotations')

        imgs_path_list = []
        labels_path_list = []
        with io.open(self.path, 'r') as t:
            id_lines = t.read().splitlines()
        for id_line in id_lines:
            # Only the first space-separated field of a line is the image id.
            img_id = id_line.strip().split(' ')[0]
            if not img_id:
                # Blank lines would otherwise become bogus '<root>/.jpg' paths.
                continue
            imgs_path_list.append(
                os.path.join(self.img_root_path, img_id + img_suffix))
            labels_path_list.append(
                os.path.join(self.label_root_path, img_id + label_suffix))

        # TODO: filter bad sample
        self.samples_list = self.build_samples(
            list(zip(imgs_path_list, labels_path_list)))

    def get_source_info(self, img_and_label):
        """Parse the annotation of one sample without loading its image.

        Args:
            img_and_label: pair of (image path, label path).

        Returns:
            dict with `gt_bboxes`, `gt_labels` (from `parse_xml`) and
            `filename` (the image path).
        """
        img_path = img_and_label[0]
        label_path = img_and_label[1]
        source_info = parse_xml(label_path, self.CLASSES)
        source_info.update({'filename': img_path})

        return source_info

    def _build_sample_from_source_info(self, source_info):
        """Build a full sample (annotation + image) from a source info dict.

        Args:
            source_info: dict as returned by `get_source_info`.

        Returns:
            A new dict holding the entries of `source_info` plus the image
            entries from `load_image`, `img_fields` and `bbox_fields`;
            an empty dict if `source_info` has no `filename`.
        """
        if 'filename' not in source_info:
            return {}

        # Work on a shallow copy: updating `source_info` in place would store
        # the decoded image in `self.samples_list` even when both cache flags
        # are False, because the caller passes the stored dict itself.
        result_dict = dict(source_info)
        result_dict.update(self.load_image(source_info['filename']))
        result_dict.update({
            'img_fields': ['img'],
            'bbox_fields': ['gt_bboxes']
        })

        return result_dict

    def build_sample(self, data):
        """Build one entry of `samples_list` from an (image, label) path pair.

        The image is only loaded here when `cache_at_init` is set.
        """
        result_dict = self.get_source_info(data)

        if self.cache_at_init:
            result_dict = self._build_sample_from_source_info(result_dict)

        return result_dict

    def build_samples(self, iterable):
        """Run `build_sample` over `iterable` in a process pool.

        Args:
            iterable: sized collection of (image path, label path) pairs.

        Returns:
            list of the non-empty sample dicts, in the order of `iterable`.
        """
        samples_list = []
        # Pool rejects processes=0, which `cpu_count() / 2` yields on a
        # single-core host.
        proc_num = max(1, cpu_count() // 2)
        with Pool(processes=proc_num) as p:
            with tqdm(total=len(iterable), desc='Scanning images') as pbar:
                # `imap` (not `imap_unordered`) keeps the sample order
                # deterministic across runs and processes.
                for result_dict in p.imap(self.build_sample, iterable):
                    if result_dict:
                        samples_list.append(result_dict)
                    pbar.update()

        return samples_list

    def load_image(self, img_path):
        """Read an image as a BGR float32 array, retrying on failure.

        Args:
            img_path: path of the image, opened with `io.open`.

        Returns:
            dict with `img` (float32 array in BGR channel order),
            `img_shape` and `ori_img_shape` (both `img.shape`).

        Raises:
            ValueError: if the image could not be read after
                `MAX_READ_IMAGE_TRY_TIMES` attempts.
        """
        result = {}
        img = None
        for try_cnt in range(MAX_READ_IMAGE_TRY_TIMES):
            try:
                with io.open(img_path, 'rb') as infile:
                    # cv2.imdecode may corrupt when the img is broken
                    image = Image.open(infile)
                    # Grayscale / palette / CMYK images must be converted
                    # first, otherwise the RGB->BGR conversion fails or
                    # yields wrong colors.
                    if image.mode != 'RGB':
                        image = image.convert('RGB')
                    img = cv2.cvtColor(
                        np.asarray(image, dtype=np.uint8), cv2.COLOR_RGB2BGR)
                break
            except Exception as e:
                img = None
                logging.warning('Image load error, try %s : %s (%s)', try_cnt,
                                img_path, e)
                # No point in waiting after the last attempt.
                if try_cnt + 1 < MAX_READ_IMAGE_TRY_TIMES:
                    time.sleep(2)

        if img is None:
            raise ValueError('Read Image Times Out: ' + img_path)

        result['img'] = img.astype(np.float32)
        result['img_shape'] = img.shape  # h, w, c
        result['ori_img_shape'] = img.shape

        return result

    def get_length(self):
        """Return the number of samples."""
        return len(self.samples_list)

    def __len__(self):
        return self.get_length()

    def get_ann_info(self, idx):
        """
        Get raw annotation info, include bounding boxes, labels and so on.
        `bboxes` format is as [x1, y1, x2, y2] without normalization.
        """
        sample_info = self.samples_list[idx]
        if sample_info.get('gt_labels', None) is None:
            sample_info = self._build_sample_from_source_info(sample_info)
            if self.cache_at_init or self.cache_on_the_fly:
                self.samples_list[idx] = sample_info

        annotations = {
            'bboxes': sample_info['gt_bboxes'],
            'labels': sample_info['gt_labels'],
            'groundtruth_is_crowd': np.zeros_like(sample_info['gt_labels'])
        }

        return annotations

    def get_sample(self, idx):
        """Return the sample at `idx`, falling back to later samples on error.

        If the sample cannot be built (e.g. its image cannot be read), the
        following samples are tried in turn, wrapping around at the end of
        the list. The built sample is stored back into `samples_list` only
        when `cache_at_init` or `cache_on_the_fly` is set.

        Raises:
            IndexError: if `idx` is out of range.
            ValueError: if no sample could be loaded at all.
        """
        num_samples = len(self.samples_list)
        cur_idx = idx
        # Every sample is attempted at most once, so a fully broken dataset
        # ends in an error instead of unbounded recursion.
        for _ in range(max(num_samples, 1)):
            source_info = self.samples_list[cur_idx]
            result_dict = source_info
            try:
                if source_info.get('img', None) is None:
                    result_dict = self._build_sample_from_source_info(
                        source_info)
                    if result_dict and (self.cache_at_init
                                        or self.cache_on_the_fly):
                        self.samples_list[cur_idx] = result_dict
            except Exception as e:
                logging.warning(e)
                # Mark the sample as failed; returning `source_info` here
                # would hand out a sample without an image.
                result_dict = {}

            if result_dict:
                return result_dict

            logging.warning(
                'Something wrong with current sample %s,Try load next sample...'
                % source_info.get('filename', ''))
            cur_idx = (cur_idx + 1) % num_samples

        raise ValueError('All samples failed to load!')
|