mirror of https://github.com/alibaba/EasyCV.git
fix multi-process reading of detection datasource and accelerate data preprocessing (#23)
* fix multi-process reading of detection datasource and accelerate data preprocessing
pull/38/head
parent
81b91086eb
commit
c6ad4c7858
|
@ -0,0 +1,188 @@
|
||||||
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
import functools
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from abc import abstractmethod
|
||||||
|
from multiprocessing import Pool, cpu_count
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from mmcv.runner.dist_utils import get_dist_info
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from easycv.file import io
|
||||||
|
from easycv.utils.constant import MAX_READ_IMAGE_TRY_TIMES
|
||||||
|
|
||||||
|
|
||||||
|
def load_image(img_path):
    """Read an image file and return it as a BGR float32 array.

    The file is decoded with PIL and converted to BGR with OpenCV. Reading is
    attempted up to ``MAX_READ_IMAGE_TRY_TIMES`` times, sleeping two seconds
    between attempts.

    Args:
        img_path (str): path of the image, opened through ``io.open``.

    Returns:
        dict: with keys
            ``img``: the image as ``np.float32`` in BGR channel order,
            ``img_shape``: shape of the decoded image (h, w, c),
            ``ori_img_shape``: same value as ``img_shape``.

    Raises:
        ValueError: if the image could not be read in any attempt.
    """
    img = None
    last_error = None
    for try_cnt in range(MAX_READ_IMAGE_TRY_TIMES):
        try:
            with io.open(img_path, 'rb') as infile:
                # cv2.imdecode may corrupt when the img is broken,
                # so the file is decoded with PIL instead.
                image = Image.open(infile)
                # Grayscale / palette images have no RGB channel axis and
                # would make the RGB->BGR conversion below fail every time.
                if image.mode != 'RGB':
                    image = image.convert('RGB')
                img = cv2.cvtColor(
                    np.asarray(image, dtype=np.uint8), cv2.COLOR_RGB2BGR)
            break
        except Exception as e:
            # Keep the cause so the final error is diagnosable; do not catch
            # KeyboardInterrupt / SystemExit like a bare ``except`` would.
            last_error = e
            logging.warning('Image load error, try %s : %s (%s)', try_cnt,
                            img_path, e)
            # No point in waiting after the last attempt.
            if try_cnt + 1 < MAX_READ_IMAGE_TRY_TIMES:
                time.sleep(2)

    if img is None:
        raise ValueError('Read Image Times Out: ' + img_path) from last_error

    result = {
        'img': img.astype(np.float32),
        'img_shape': img.shape,  # h, w, c
        'ori_img_shape': img.shape,
    }
    return result
|
||||||
|
|
||||||
|
|
||||||
|
def build_sample(source_item, classes, parse_fn, load_img):
    """Build sample info from source item.

    Args:
        source_item: item of source iterator
        classes: classes list
        parse_fn: parse function to parse source_item, only accepts two
            params: source_item and classes
        load_img: load image or not, if true, cache all images in memory
            at init

    Returns:
        dict: the dict returned by ``parse_fn``; when ``load_img`` is true and
        the dict contains ``filename``, it is extended with the output of
        ``load_image``.
    """
    result_dict = parse_fn(source_item, classes)

    # A parser may return a result without 'filename' (for example a row that
    # carries no annotation). Such a result cannot be loaded, so hand it back
    # unchanged instead of raising KeyError inside a pool worker.
    if load_img and result_dict and 'filename' in result_dict:
        result_dict.update(load_image(result_dict['filename']))

    return result_dict
|
||||||
|
|
||||||
|
|
||||||
|
class DetSourceBase(object):
    """Common logic of detection data sources.

    Subclasses implement ``get_source_iterator`` and pass a ``parse_fn``; this
    class parses all source items in a process pool, optionally caches the
    decoded images, and serves samples with a skip-on-failure policy.
    """

    def __init__(self,
                 classes=None,
                 cache_at_init=False,
                 cache_on_the_fly=False,
                 parse_fn=None,
                 num_processes=int(cpu_count() / 2),
                 **kwargs):
        """
        Args:
            classes: classes list
            cache_at_init: if set True, will cache in memory in __init__ for faster training
            cache_on_the_fly: if set True, will cache in memory during training
            parse_fn: parse function to parse source iterator, it is called
                as ``parse_fn(source_item, classes)`` and should return dict containing:
                    gt_bboxes(np.ndarray): Float32 numpy array of shape [num_boxes, 4].
                        NOTE(review): get_ann_info documents [x1, y1, x2, y2] in
                        absolute image coordinates — confirm for every parser.
                    gt_labels(np.ndarray): Integer numpy array of shape [num_boxes]
                        containing the class of each box.
                        NOTE(review): presumably indices into `classes` — confirm.
                    filename(str): absolute file path.
            num_processes: number of processes to parse samples

        Raises:
            ValueError: if both `cache_at_init` and `cache_on_the_fly` are True.
        """
        # Avoid sharing one mutable default list between instances.
        self.CLASSES = [] if classes is None else classes
        self.rank, self.world_size = get_dist_info()
        self.cache_at_init = cache_at_init
        self.cache_on_the_fly = cache_on_the_fly
        # int(cpu_count() / 2) is 0 on a single-core host and Pool rejects
        # a process count below 1.
        if num_processes is not None:
            num_processes = max(1, int(num_processes))
        self.num_processes = num_processes

        if self.cache_at_init and self.cache_on_the_fly:
            raise ValueError(
                'Only one of `cache_on_the_fly` and `cache_at_init` can be True!'
            )
        source_iter = self.get_source_iterator()

        process_fn = functools.partial(
            build_sample,
            parse_fn=parse_fn,
            classes=self.CLASSES,
            # Same truthiness test as the one used in get_sample.
            load_img=bool(cache_at_init),
        )
        self.samples_list = self.build_samples(
            source_iter, process_fn=process_fn)
        self.num_samples = self.get_length()
        # An error will be raised if failed to load _max_retry_num times in a row
        self._max_retry_num = self.num_samples
        self._retry_count = 0

    @abstractmethod
    def get_source_iterator(self):
        """Return data list iterator, source iterator will be passed to parse_fn,
        and parse_fn will receive params of item of source iter and classes for parsing.
        What parse_fn needs is what the source iterator returns.
        """
        raise NotImplementedError

    def build_samples(self, iterable, process_fn):
        """Apply ``process_fn`` to every item of ``iterable`` in a process pool.

        Args:
            iterable: sized collection of source items (``len`` is used for
                the progress bar).
            process_fn: picklable callable taking one source item.

        Returns:
            list: the non-empty results, in completion order (unordered).
        """
        samples_list = []
        with Pool(processes=self.num_processes) as p:
            with tqdm(total=len(iterable), desc='Scanning images') as pbar:
                for result_dict in p.imap_unordered(process_fn, iterable):
                    # Empty results are dropped.
                    if result_dict:
                        samples_list.append(result_dict)
                    pbar.update()

        return samples_list

    def get_length(self):
        """Return the number of parsed samples."""
        return len(self.samples_list)

    def __len__(self):
        return self.get_length()

    def get_ann_info(self, idx):
        """
        Get raw annotation info, include bounding boxes, labels and so on.
        `bboxes` format is as [x1, y1, x2, y2] without normalization.
        """
        sample_info = self.samples_list[idx]

        groundtruth_is_crowd = sample_info.get('groundtruth_is_crowd', None)
        if groundtruth_is_crowd is None:
            # Default: no box is marked as crowd.
            groundtruth_is_crowd = np.zeros_like(sample_info['gt_labels'])

        annotations = {
            'bboxes': sample_info['gt_bboxes'],
            'labels': sample_info['gt_labels'],
            'groundtruth_is_crowd': groundtruth_is_crowd
        }

        return annotations

    def post_process_fn(self, result_dict):
        """Fill in ``img_fields`` / ``bbox_fields`` when they are missing.

        Subclasses may override this hook to adjust a sample after loading.
        """
        if result_dict.get('img_fields', None) is None:
            result_dict['img_fields'] = ['img']
        if result_dict.get('bbox_fields', None) is None:
            result_dict['bbox_fields'] = ['gt_bboxes']

        return result_dict

    def get_sample(self, idx):
        """Return the sample at ``idx``, falling back to following samples.

        When loading or post-processing a sample fails, the next index
        (wrapping around) is tried, until ``_max_retry_num`` consecutive
        failures have been counted.

        Raises:
            ValueError: after ``_max_retry_num`` consecutive failures.
        """
        # Iterative retry: a recursive retry could exceed the interpreter's
        # recursion limit on large datasets with many broken samples.
        while True:
            stored = self.samples_list[idx]
            try:
                # Work on a shallow copy so that neither the decoded image nor
                # the result of post_process_fn is written into samples_list
                # unless caching was requested.
                result_dict = dict(stored)
                if not self.cache_at_init and result_dict.get('img',
                                                              None) is None:
                    result_dict.update(load_image(result_dict['filename']))
                    if self.cache_on_the_fly:
                        self.samples_list[idx] = dict(result_dict)

                result_dict = self.post_process_fn(result_dict)
                # load success, reset to 0
                self._retry_count = 0
                return result_dict
            except Exception as e:
                logging.error(e)

            logging.warning(
                'Something wrong with current sample %s,Try load next sample...'
                % stored.get('filename', ''))
            self._retry_count += 1
            if self._retry_count >= self._max_retry_num:
                raise ValueError('All samples failed to load!')

            idx = (idx + 1) % self.num_samples
|
|
@ -1,13 +1,13 @@
|
||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from multiprocessing import cpu_count
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mmcv.runner.dist_utils import get_dist_info
|
|
||||||
|
|
||||||
|
from easycv.datasets.detection.data_sources.base import DetSourceBase
|
||||||
from easycv.datasets.registry import DATASOURCES
|
from easycv.datasets.registry import DATASOURCES
|
||||||
from easycv.file import io
|
from easycv.file import io
|
||||||
from .voc import DetSourceVOC
|
|
||||||
|
|
||||||
|
|
||||||
def get_prior_task_id(keys):
|
def get_prior_task_id(keys):
|
||||||
|
@ -44,7 +44,7 @@ def is_itag_v2(row):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def parser_manifest_row_str(row_str):
|
def parser_manifest_row_str(row_str, classes):
|
||||||
row = json.loads(row_str.strip())
|
row = json.loads(row_str.strip())
|
||||||
_is_itag_v2 = is_itag_v2(row)
|
_is_itag_v2 = is_itag_v2(row)
|
||||||
|
|
||||||
|
@ -77,7 +77,7 @@ def parser_manifest_row_str(row_str):
|
||||||
if not ann_json:
|
if not ann_json:
|
||||||
return parse_results
|
return parse_results
|
||||||
|
|
||||||
bboxes, class_names = [], []
|
bboxes, gt_labels = [], []
|
||||||
for result in ann_json['results']:
|
for result in ann_json['results']:
|
||||||
if result['type'] != 'image':
|
if result['type'] != 'image':
|
||||||
continue
|
continue
|
||||||
|
@ -100,7 +100,7 @@ def parser_manifest_row_str(row_str):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Not support multi label, get class name %s!' %
|
'Not support multi label, get class name %s!' %
|
||||||
class_name)
|
class_name)
|
||||||
class_names.append(class_name[0])
|
gt_labels.append(classes.index(class_name[0]))
|
||||||
else:
|
else:
|
||||||
if obj['type'] != 'image/rectangleLabel':
|
if obj['type'] != 'image/rectangleLabel':
|
||||||
logging.warning(
|
logging.warning(
|
||||||
|
@ -113,18 +113,18 @@ def parser_manifest_row_str(row_str):
|
||||||
bnd = [x, y, x + w, y + h]
|
bnd = [x, y, x + w, y + h]
|
||||||
class_name = obj['labels'][0]
|
class_name = obj['labels'][0]
|
||||||
bboxes.append(bnd)
|
bboxes.append(bnd)
|
||||||
class_names.append(class_name)
|
gt_labels.append(classes.index(class_name))
|
||||||
break
|
break
|
||||||
|
|
||||||
parse_results['gt_bboxes'] = bboxes
|
|
||||||
parse_results['class_names'] = class_names
|
|
||||||
parse_results['filename'] = img_url
|
parse_results['filename'] = img_url
|
||||||
|
parse_results['gt_bboxes'] = np.array(bboxes, dtype=np.float32)
|
||||||
|
parse_results['gt_labels'] = np.array(gt_labels, dtype=np.int64)
|
||||||
|
|
||||||
return parse_results
|
return parse_results
|
||||||
|
|
||||||
|
|
||||||
@DATASOURCES.register_module
|
@DATASOURCES.register_module
|
||||||
class DetSourcePAI(DetSourceVOC):
|
class DetSourcePAI(DetSourceBase):
|
||||||
"""
|
"""
|
||||||
data format please refer to: https://help.aliyun.com/document_detail/311173.html
|
data format please refer to: https://help.aliyun.com/document_detail/311173.html
|
||||||
"""
|
"""
|
||||||
|
@ -134,6 +134,8 @@ class DetSourcePAI(DetSourceVOC):
|
||||||
classes=[],
|
classes=[],
|
||||||
cache_at_init=False,
|
cache_at_init=False,
|
||||||
cache_on_the_fly=False,
|
cache_on_the_fly=False,
|
||||||
|
parse_fn=parser_manifest_row_str,
|
||||||
|
num_processes=int(cpu_count() / 2),
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
@ -141,30 +143,19 @@ class DetSourcePAI(DetSourceVOC):
|
||||||
classes: classes list
|
classes: classes list
|
||||||
cache_at_init: if set True, will cache in memory in __init__ for faster training
|
cache_at_init: if set True, will cache in memory in __init__ for faster training
|
||||||
cache_on_the_fly: if set True, will cache in memroy during training
|
cache_on_the_fly: if set True, will cache in memroy during training
|
||||||
|
parse_fn: parse function to parse item of source iterator
|
||||||
|
num_processes: number of processes to parse samples
|
||||||
"""
|
"""
|
||||||
self.CLASSES = classes
|
|
||||||
self.rank, self.world_size = get_dist_info()
|
|
||||||
self.manifest_path = path
|
|
||||||
self.cache_at_init = cache_at_init
|
|
||||||
self.cache_on_the_fly = cache_on_the_fly
|
|
||||||
if self.cache_at_init and self.cache_on_the_fly:
|
|
||||||
raise ValueError(
|
|
||||||
'Only one of `cache_on_the_fly` and `cache_at_init` can be True!'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
self.manifest_path = path
|
||||||
|
super(DetSourcePAI, self).__init__(
|
||||||
|
classes=classes,
|
||||||
|
cache_at_init=cache_at_init,
|
||||||
|
cache_on_the_fly=cache_on_the_fly,
|
||||||
|
parse_fn=parse_fn,
|
||||||
|
num_processes=num_processes)
|
||||||
|
|
||||||
|
def get_source_iterator(self):
|
||||||
with io.open(self.manifest_path, 'r') as f:
|
with io.open(self.manifest_path, 'r') as f:
|
||||||
rows = f.read().splitlines()
|
rows = f.read().splitlines()
|
||||||
|
return rows
|
||||||
self.samples_list = self.build_samples(rows)
|
|
||||||
|
|
||||||
def get_source_info(self, row_str):
|
|
||||||
source_info = parser_manifest_row_str(row_str)
|
|
||||||
source_info['gt_bboxes'] = np.array(
|
|
||||||
source_info['gt_bboxes'], dtype=np.float32)
|
|
||||||
source_info['gt_labels'] = np.array([
|
|
||||||
self.CLASSES.index(class_name)
|
|
||||||
for class_name in source_info['class_names']
|
|
||||||
],
|
|
||||||
dtype=np.int64)
|
|
||||||
|
|
||||||
return source_info
|
|
||||||
|
|
|
@ -1,20 +1,45 @@
|
||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from multiprocessing import cpu_count
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from easycv.datasets.registry import DATASOURCES
|
from easycv.datasets.registry import DATASOURCES
|
||||||
from easycv.file import io
|
from easycv.file import io
|
||||||
from easycv.utils.bbox_util import batched_cxcywh2xyxy_with_shape
|
from easycv.utils.bbox_util import batched_cxcywh2xyxy_with_shape
|
||||||
from .voc import DetSourceVOC
|
from .base import DetSourceBase
|
||||||
|
|
||||||
img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.tiff', '.dng']
|
img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.tiff', '.dng']
|
||||||
label_formats = ['.txt']
|
label_formats = ['.txt']
|
||||||
|
|
||||||
|
|
||||||
|
def parse_raw(source_iter, classes=None, delimeter=' '):
    """Parse one (image path, label path) pair of a raw txt dataset.

    Each non-blank line of the label file is split by ``delimeter``; the
    first field is the label and the remaining fields are the box values.

    Args:
        source_iter: tuple of ``(img_path, label_path)``.
        classes: unused here, accepted so the function matches the
            ``parse_fn(source_item, classes)`` calling convention.
        delimeter: field separator of the label file.

    Returns:
        dict: always contains ``filename``; contains ``gt_bboxes``
        (float32) and ``gt_labels`` (int64) when the label file has at
        least one non-blank line.
    """
    img_path, label_path = source_iter

    source_info = {'filename': img_path}

    with io.open(label_path, 'r') as f:
        # Skip blank lines and strip surrounding whitespace: an empty or
        # space-terminated line would otherwise yield a ragged row and break
        # the array construction / column slicing below.
        rows = [
            line.strip().split(delimeter) for line in f.read().splitlines()
            if line.strip()
        ]

    if not rows:
        return source_info

    labels_and_boxes = np.array(rows)
    labels = labels_and_boxes[:, 0]
    bboxes = labels_and_boxes[:, 1:]

    source_info.update({
        'gt_bboxes': np.array(bboxes, dtype=np.float32),
        'gt_labels': labels.astype(np.int64)
    })

    return source_info
|
||||||
|
|
||||||
|
|
||||||
@DATASOURCES.register_module
|
@DATASOURCES.register_module
|
||||||
class DetSourceRaw(DetSourceVOC):
|
class DetSourceRaw(DetSourceBase):
|
||||||
"""
|
"""
|
||||||
data dir is as follows:
|
data dir is as follows:
|
||||||
```
|
```
|
||||||
|
@ -45,24 +70,39 @@ class DetSourceRaw(DetSourceVOC):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
img_root_path,
|
img_root_path,
|
||||||
label_root_path,
|
label_root_path,
|
||||||
|
classes=[],
|
||||||
cache_at_init=False,
|
cache_at_init=False,
|
||||||
cache_on_the_fly=False,
|
cache_on_the_fly=False,
|
||||||
delimeter=' ',
|
delimeter=' ',
|
||||||
|
parse_fn=parse_raw,
|
||||||
|
num_processes=int(cpu_count() / 2),
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
img_root_path: images dir path
|
img_root_path: images dir path
|
||||||
label_root_path: labels dir path
|
label_root_path: labels dir path
|
||||||
|
classes(list, optional): classes list
|
||||||
cache_at_init: if set True, will cache in memory in __init__ for faster training
|
cache_at_init: if set True, will cache in memory in __init__ for faster training
|
||||||
cache_on_the_fly: if set True, will cache in memroy during training
|
cache_on_the_fly: if set True, will cache in memroy during training
|
||||||
|
delimeter: delimeter of txt file
|
||||||
|
parse_fn: parse function to parse item of source iterator
|
||||||
|
num_processes: number of processes to parse samples
|
||||||
"""
|
"""
|
||||||
self.cache_on_the_fly = cache_on_the_fly
|
|
||||||
self.cache_at_init = cache_at_init
|
|
||||||
self.delimeter = delimeter
|
|
||||||
|
|
||||||
|
self.delimeter = delimeter
|
||||||
self.img_root_path = img_root_path
|
self.img_root_path = img_root_path
|
||||||
self.label_root_path = label_root_path
|
self.label_root_path = label_root_path
|
||||||
|
|
||||||
|
parse_fn = functools.partial(parse_fn, delimeter=delimeter)
|
||||||
|
|
||||||
|
super(DetSourceRaw, self).__init__(
|
||||||
|
classes=classes,
|
||||||
|
cache_at_init=cache_at_init,
|
||||||
|
cache_on_the_fly=cache_on_the_fly,
|
||||||
|
parse_fn=parse_fn,
|
||||||
|
num_processes=num_processes)
|
||||||
|
|
||||||
|
def get_source_iterator(self):
|
||||||
self.img_files = [
|
self.img_files = [
|
||||||
os.path.join(self.img_root_path, i)
|
os.path.join(self.img_root_path, i)
|
||||||
for i in io.listdir(self.img_root_path, recursive=True)
|
for i in io.listdir(self.img_root_path, recursive=True)
|
||||||
|
@ -90,48 +130,11 @@ class DetSourceRaw(DetSourceVOC):
|
||||||
assert len(
|
assert len(
|
||||||
self.img_files) > 0, 'No samples found in %s' % self.img_root_path
|
self.img_files) > 0, 'No samples found in %s' % self.img_root_path
|
||||||
|
|
||||||
# TODO: filter bad sample
|
return list(zip(self.img_files, self.label_files))
|
||||||
self.samples_list = self.build_samples(
|
|
||||||
list(zip(self.img_files, self.label_files)))
|
|
||||||
|
|
||||||
def get_source_info(self, img_and_label):
|
def post_process_fn(self, result_dict):
|
||||||
img_path = img_and_label[0]
|
result_dict = super(DetSourceRaw, self).post_process_fn(result_dict)
|
||||||
label_path = img_and_label[1]
|
|
||||||
|
|
||||||
source_info = {'filename': img_path}
|
|
||||||
|
|
||||||
with io.open(label_path, 'r') as f:
|
|
||||||
labels_and_boxes = np.array(
|
|
||||||
[line.split(self.delimeter) for line in f.read().splitlines()])
|
|
||||||
|
|
||||||
if not len(labels_and_boxes):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
labels = labels_and_boxes[:, 0]
|
|
||||||
bboxes = labels_and_boxes[:, 1:]
|
|
||||||
|
|
||||||
source_info.update({
|
|
||||||
'gt_bboxes': np.array(bboxes, dtype=np.float32),
|
|
||||||
'gt_labels': labels.astype(np.int64)
|
|
||||||
})
|
|
||||||
|
|
||||||
return source_info
|
|
||||||
|
|
||||||
def _build_sample_from_source_info(self, source_info):
|
|
||||||
if 'filename' not in source_info:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
result_dict = source_info
|
|
||||||
|
|
||||||
img_info = self.load_image(source_info['filename'])
|
|
||||||
result_dict.update(img_info)
|
|
||||||
|
|
||||||
result_dict.update({
|
|
||||||
'img_fields': ['img'],
|
|
||||||
'bbox_fields': ['gt_bboxes']
|
|
||||||
})
|
|
||||||
# shape: h, w
|
|
||||||
result_dict['gt_bboxes'] = batched_cxcywh2xyxy_with_shape(
|
result_dict['gt_bboxes'] = batched_cxcywh2xyxy_with_shape(
|
||||||
result_dict['gt_bboxes'], shape=img_info['img_shape'][:2])
|
result_dict['gt_bboxes'], shape=result_dict['img_shape'][:2])
|
||||||
|
|
||||||
return result_dict
|
return result_dict
|
||||||
|
|
|
@ -1,24 +1,20 @@
|
||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
from multiprocessing import Pool, cpu_count
|
from multiprocessing import cpu_count
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mmcv.runner.dist_utils import get_dist_info
|
|
||||||
from PIL import Image
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
|
from easycv.datasets.detection.data_sources.base import DetSourceBase
|
||||||
from easycv.datasets.registry import DATASOURCES
|
from easycv.datasets.registry import DATASOURCES
|
||||||
from easycv.file import io
|
from easycv.file import io
|
||||||
from easycv.utils.constant import MAX_READ_IMAGE_TRY_TIMES
|
|
||||||
|
|
||||||
img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.tiff', '.dng']
|
img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.tiff', '.dng']
|
||||||
|
|
||||||
|
|
||||||
def parse_xml(xml_path, classes):
|
def parse_xml(source_item, classes):
|
||||||
|
img_path, xml_path = source_item
|
||||||
with io.open(xml_path, 'r') as f:
|
with io.open(xml_path, 'r') as f:
|
||||||
tree = ET.parse(f)
|
tree = ET.parse(f)
|
||||||
root = tree.getroot()
|
root = tree.getroot()
|
||||||
|
@ -51,14 +47,15 @@ def parse_xml(xml_path, classes):
|
||||||
|
|
||||||
img_info = {
|
img_info = {
|
||||||
'gt_bboxes': np.array(gt_bboxes, dtype=np.float32),
|
'gt_bboxes': np.array(gt_bboxes, dtype=np.float32),
|
||||||
'gt_labels': np.array(gt_labels, dtype=np.int64)
|
'gt_labels': np.array(gt_labels, dtype=np.int64),
|
||||||
|
'filename': img_path,
|
||||||
}
|
}
|
||||||
|
|
||||||
return img_info
|
return img_info
|
||||||
|
|
||||||
|
|
||||||
@DATASOURCES.register_module
|
@DATASOURCES.register_module
|
||||||
class DetSourceVOC(object):
|
class DetSourceVOC(DetSourceBase):
|
||||||
"""
|
"""
|
||||||
data dir is as follows:
|
data dir is as follows:
|
||||||
```
|
```
|
||||||
|
@ -98,6 +95,8 @@ class DetSourceVOC(object):
|
||||||
cache_on_the_fly=False,
|
cache_on_the_fly=False,
|
||||||
img_suffix='.jpg',
|
img_suffix='.jpg',
|
||||||
label_suffix='.xml',
|
label_suffix='.xml',
|
||||||
|
parse_fn=parse_xml,
|
||||||
|
num_processes=int(cpu_count() / 2),
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
@ -111,16 +110,24 @@ class DetSourceVOC(object):
|
||||||
cache_on_the_fly: if set True, will cache in memroy during training
|
cache_on_the_fly: if set True, will cache in memroy during training
|
||||||
img_suffix: suffix of image file
|
img_suffix: suffix of image file
|
||||||
label_suffix: suffix of label file
|
label_suffix: suffix of label file
|
||||||
|
parse_fn: parse function to parse item of source iterator
|
||||||
|
num_processes: number of processes to parse samples
|
||||||
"""
|
"""
|
||||||
self.CLASSES = classes
|
|
||||||
self.rank, self.world_size = get_dist_info()
|
|
||||||
self.path = path
|
self.path = path
|
||||||
self.img_root_path = img_root_path
|
self.img_root_path = img_root_path
|
||||||
self.label_root_path = label_root_path
|
self.label_root_path = label_root_path
|
||||||
self.cache_at_init = cache_at_init
|
self.img_suffix = img_suffix
|
||||||
self.cache_on_the_fly = cache_on_the_fly
|
self.label_suffix = label_suffix
|
||||||
|
super(DetSourceVOC, self).__init__(
|
||||||
|
classes=classes,
|
||||||
|
cache_at_init=cache_at_init,
|
||||||
|
cache_on_the_fly=cache_on_the_fly,
|
||||||
|
parse_fn=parse_fn,
|
||||||
|
num_processes=num_processes)
|
||||||
|
|
||||||
if not img_root_path:
|
def get_source_iterator(self):
|
||||||
|
if not self.img_root_path:
|
||||||
self.img_root_path = os.path.join(
|
self.img_root_path = os.path.join(
|
||||||
self.path.split('ImageSets/Main')[0], 'JPEGImages')
|
self.path.split('ImageSets/Main')[0], 'JPEGImages')
|
||||||
if not self.label_root_path:
|
if not self.label_root_path:
|
||||||
|
@ -134,128 +141,10 @@ class DetSourceVOC(object):
|
||||||
for id_line in id_lines:
|
for id_line in id_lines:
|
||||||
img_id = id_line.strip().split(' ')[0]
|
img_id = id_line.strip().split(' ')[0]
|
||||||
img_path = os.path.join(self.img_root_path,
|
img_path = os.path.join(self.img_root_path,
|
||||||
img_id + img_suffix)
|
img_id + self.img_suffix)
|
||||||
imgs_path_list.append(img_path)
|
imgs_path_list.append(img_path)
|
||||||
|
|
||||||
label_path = os.path.join(self.label_root_path,
|
label_path = os.path.join(self.label_root_path,
|
||||||
img_id + label_suffix)
|
img_id + self.label_suffix)
|
||||||
labels_path_list.append(label_path)
|
labels_path_list.append(label_path)
|
||||||
|
|
||||||
# TODO: filter bad sample
|
return list(zip(imgs_path_list, labels_path_list))
|
||||||
self.samples_list = self.build_samples(
|
|
||||||
list(zip(imgs_path_list, labels_path_list)))
|
|
||||||
|
|
||||||
def get_source_info(self, img_and_label):
|
|
||||||
img_path = img_and_label[0]
|
|
||||||
label_path = img_and_label[1]
|
|
||||||
source_info = parse_xml(label_path, self.CLASSES)
|
|
||||||
source_info.update({'filename': img_path})
|
|
||||||
|
|
||||||
return source_info
|
|
||||||
|
|
||||||
def _build_sample_from_source_info(self, source_info):
|
|
||||||
if 'filename' not in source_info:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
result_dict = source_info
|
|
||||||
|
|
||||||
img_info = self.load_image(source_info['filename'])
|
|
||||||
result_dict.update(img_info)
|
|
||||||
|
|
||||||
result_dict.update({
|
|
||||||
'img_fields': ['img'],
|
|
||||||
'bbox_fields': ['gt_bboxes']
|
|
||||||
})
|
|
||||||
|
|
||||||
return result_dict
|
|
||||||
|
|
||||||
def build_sample(self, data):
|
|
||||||
result_dict = self.get_source_info(data)
|
|
||||||
|
|
||||||
if self.cache_at_init:
|
|
||||||
result_dict = self._build_sample_from_source_info(result_dict)
|
|
||||||
|
|
||||||
return result_dict
|
|
||||||
|
|
||||||
def build_samples(self, iterable):
|
|
||||||
samples_list = []
|
|
||||||
proc_num = int(cpu_count() / 2)
|
|
||||||
with Pool(processes=proc_num) as p:
|
|
||||||
with tqdm(total=len(iterable), desc='Scanning images') as pbar:
|
|
||||||
for _, result_dict in enumerate(
|
|
||||||
p.imap_unordered(self.build_sample, iterable)):
|
|
||||||
if result_dict:
|
|
||||||
samples_list.append(result_dict)
|
|
||||||
pbar.update()
|
|
||||||
|
|
||||||
return samples_list
|
|
||||||
|
|
||||||
def load_image(self, img_path):
|
|
||||||
result = {}
|
|
||||||
try_cnt = 0
|
|
||||||
img = None
|
|
||||||
while try_cnt < MAX_READ_IMAGE_TRY_TIMES:
|
|
||||||
try:
|
|
||||||
with io.open(img_path, 'rb') as infile:
|
|
||||||
# cv2.imdecode may corrupt when the img is broken
|
|
||||||
image = Image.open(infile)
|
|
||||||
img = cv2.cvtColor(
|
|
||||||
np.asarray(image, dtype=np.uint8), cv2.COLOR_RGB2BGR)
|
|
||||||
assert img is not None, 'Image load error, try %s : %s' % (
|
|
||||||
try_cnt, img_path)
|
|
||||||
break
|
|
||||||
except:
|
|
||||||
time.sleep(2)
|
|
||||||
try_cnt += 1
|
|
||||||
|
|
||||||
if img is None:
|
|
||||||
raise ValueError('Read Image Times Out: ' + img_path)
|
|
||||||
|
|
||||||
result['img'] = img.astype(np.float32)
|
|
||||||
result['img_shape'] = img.shape # h, w, c
|
|
||||||
result['ori_img_shape'] = img.shape
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def get_length(self):
|
|
||||||
return len(self.samples_list)
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return self.get_length()
|
|
||||||
|
|
||||||
def get_ann_info(self, idx):
|
|
||||||
"""
|
|
||||||
Get raw annotation info, include bounding boxes, labels and so on.
|
|
||||||
`bboxes` format is as [x1, y1, x2, y2] without normalization.
|
|
||||||
"""
|
|
||||||
sample_info = self.samples_list[idx]
|
|
||||||
if sample_info.get('gt_labels', None) is None:
|
|
||||||
sample_info = self._build_sample_from_source_info(sample_info)
|
|
||||||
if self.cache_at_init or self.cache_on_the_fly:
|
|
||||||
self.samples_list[idx] = sample_info
|
|
||||||
|
|
||||||
annotations = {
|
|
||||||
'bboxes': sample_info['gt_bboxes'],
|
|
||||||
'labels': sample_info['gt_labels'],
|
|
||||||
'groundtruth_is_crowd': np.zeros_like(sample_info['gt_labels'])
|
|
||||||
}
|
|
||||||
|
|
||||||
return annotations
|
|
||||||
|
|
||||||
def get_sample(self, idx):
|
|
||||||
result_dict = self.samples_list[idx]
|
|
||||||
try:
|
|
||||||
if result_dict.get('img', None) is None:
|
|
||||||
result_dict = self._build_sample_from_source_info(result_dict)
|
|
||||||
if self.cache_at_init or self.cache_on_the_fly:
|
|
||||||
self.samples_list[idx] = result_dict
|
|
||||||
except Exception as e:
|
|
||||||
logging.warning(e)
|
|
||||||
|
|
||||||
if not result_dict:
|
|
||||||
logging.warning(
|
|
||||||
'Something wrong with current sample %s,Try load next sample...'
|
|
||||||
% result_dict.get('filename', ''))
|
|
||||||
result_dict = self.get_sample(idx + 1)
|
|
||||||
|
|
||||||
return result_dict
|
|
||||||
|
|
|
@ -25,6 +25,9 @@ class SourceConcat(object):
|
||||||
def get_length(self):
|
def get_length(self):
|
||||||
return self.cumsum_length_list[-1]
|
return self.cumsum_length_list[-1]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.get_length()
|
||||||
|
|
||||||
def get_sample(self, idx):
|
def get_sample(self, idx):
|
||||||
dataset_idx = bisect.bisect_right(self.cumsum_length_list, idx)
|
dataset_idx = bisect.bisect_right(self.cumsum_length_list, idx)
|
||||||
if dataset_idx == 0:
|
if dataset_idx == 0:
|
||||||
|
|
|
@ -42,7 +42,7 @@ class DetSourceRawTest(unittest.TestCase):
|
||||||
data_source.samples_list[exclude_idx[i]])
|
data_source.samples_list[exclude_idx[i]])
|
||||||
|
|
||||||
length = data_source.get_length()
|
length = data_source.get_length()
|
||||||
self.assertEqual(length, 126)
|
self.assertEqual(length, 128)
|
||||||
|
|
||||||
exists = False
|
exists = False
|
||||||
for idx in range(length):
|
for idx in range(length):
|
||||||
|
|
|
@ -90,6 +90,25 @@ class DetSourceVOCTest(unittest.TestCase):
|
||||||
cache_on_the_fly=True)
|
cache_on_the_fly=True)
|
||||||
self._base_test(data_source)
|
self._base_test(data_source)
|
||||||
|
|
||||||
|
def test_max_retry_num(self):
|
||||||
|
data_root = DET_DATA_SMALL_VOC_LOCAL
|
||||||
|
data_source = DetSourceVOC(
|
||||||
|
path=os.path.join(data_root, 'ImageSets/Main/train_20.txt'),
|
||||||
|
classes=VOC_CLASSES,
|
||||||
|
img_root_path=os.path.join(data_root, 'fault_path'),
|
||||||
|
label_root_path=os.path.join(data_root, 'Annotations'))
|
||||||
|
data_source._max_retry_num = 2
|
||||||
|
num_samples = data_source.num_samples
|
||||||
|
with self.assertRaises(ValueError) as cm:
|
||||||
|
for idx in range(num_samples - 1, num_samples * 2):
|
||||||
|
_ = data_source.get_sample(idx)
|
||||||
|
|
||||||
|
exception = cm.exception
|
||||||
|
|
||||||
|
self.assertEqual(num_samples, 20)
|
||||||
|
self.assertEqual(data_source._retry_count, 2)
|
||||||
|
self.assertEqual(exception.args[0], 'All samples failed to load!')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -15,7 +15,7 @@ sys.path.append(
|
||||||
osp.join(os.path.dirname(os.path.dirname(__file__)), '../')))
|
osp.join(os.path.dirname(os.path.dirname(__file__)), '../')))
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
import cv2
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
from mmcv.runner import init_dist
|
from mmcv.runner import init_dist
|
||||||
|
@ -33,6 +33,9 @@ from easycv.utils.config_tools import traverse_replace
|
||||||
from easycv.utils.config_tools import (CONFIG_TEMPLATE_ZOO,
|
from easycv.utils.config_tools import (CONFIG_TEMPLATE_ZOO,
|
||||||
mmcv_config_fromfile, rebuild_config)
|
mmcv_config_fromfile, rebuild_config)
|
||||||
|
|
||||||
|
# refer to: https://github.com/open-mmlab/mmdetection/pull/6867
|
||||||
|
cv2.setNumThreads(0)
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description='Train a model')
|
parser = argparse.ArgumentParser(description='Train a model')
|
||||||
|
|
Loading…
Reference in New Issue