diff --git a/easycv/datasets/detection/data_sources/base.py b/easycv/datasets/detection/data_sources/base.py
new file mode 100644
index 00000000..82e77a30
--- /dev/null
+++ b/easycv/datasets/detection/data_sources/base.py
@@ -0,0 +1,188 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import functools
+import logging
+import time
+from abc import abstractmethod
+from multiprocessing import Pool, cpu_count
+
+import cv2
+import numpy as np
+from mmcv.runner.dist_utils import get_dist_info
+from PIL import Image
+from tqdm import tqdm
+
+from easycv.file import io
+from easycv.utils.constant import MAX_READ_IMAGE_TRY_TIMES
+
+
+def load_image(img_path):
+    result = {}
+    try_cnt = 0
+    img = None
+    while try_cnt < MAX_READ_IMAGE_TRY_TIMES:
+        try:
+            with io.open(img_path, 'rb') as infile:
+                # decode with PIL; cv2.imdecode may fail when the image is broken
+                image = Image.open(infile)
+                img = cv2.cvtColor(
+                    np.asarray(image, dtype=np.uint8), cv2.COLOR_RGB2BGR)
+                assert img is not None, 'Image load error, try %s : %s' % (
+                    try_cnt, img_path)
+                break
+        except Exception:
+            time.sleep(2)
+        try_cnt += 1
+
+    if img is None:
+        raise ValueError('Read Image Times Out: ' + img_path)
+
+    result['img'] = img.astype(np.float32)
+    result['img_shape'] = img.shape  # h, w, c
+    result['ori_img_shape'] = img.shape
+
+    return result
+
+
+def build_sample(source_item, classes, parse_fn, load_img):
+    """Build sample info from source item.
+    Args:
+        source_item: item of source iterator
+        classes: classes list
+        parse_fn: parse function to parse source_item, only accepts two params: source_item and classes
+        load_img: load image or not, if True, cache all images in memory at init
+    """
+    result_dict = parse_fn(source_item, classes)
+
+    if load_img:
+        result_dict.update(load_image(result_dict['filename']))
+
+    return result_dict
+
+
+class DetSourceBase(object):
+
+    def __init__(self,
+                 classes=[],
+                 cache_at_init=False,
+                 cache_on_the_fly=False,
+                 parse_fn=None,
+                 num_processes=int(cpu_count() / 2),
+                 **kwargs):
+        """
+        Args:
+            classes: classes list
+            cache_at_init: if set True, will cache in memory in __init__ for faster training
+            cache_on_the_fly: if set True, will cache in memory during training
+            parse_fn: parse function to parse source iterator, parse_fn should return dict containing:
+                gt_bboxes(np.ndarray): Float32 numpy array of shape [num_boxes, 4] and
+                    format [xmin, ymin, xmax, ymax] in absolute image coordinates.
+                gt_labels(np.ndarray): Integer numpy array of shape [num_boxes]
+                    containing 0-indexed detection classes for the boxes.
+                filename(str): absolute file path.
+            num_processes: number of processes to parse samples
+        """
+        self.CLASSES = classes
+        self.rank, self.world_size = get_dist_info()
+        self.cache_at_init = cache_at_init
+        self.cache_on_the_fly = cache_on_the_fly
+        self.num_processes = num_processes
+
+        if self.cache_at_init and self.cache_on_the_fly:
+            raise ValueError(
+                'Only one of `cache_on_the_fly` and `cache_at_init` can be True!'
+            )
+        source_iter = self.get_source_iterator()
+
+        process_fn = functools.partial(
+            build_sample,
+            parse_fn=parse_fn,
+            classes=self.CLASSES,
+            load_img=cache_at_init,
+        )
+        self.samples_list = self.build_samples(
+            source_iter, process_fn=process_fn)
+        self.num_samples = self.get_length()
+        # An error will be raised if loading fails _max_retry_num times in a row
+        self._max_retry_num = self.num_samples
+        self._retry_count = 0
+
+    @abstractmethod
+    def get_source_iterator(self):
+        """Return the data list iterator. Each item of the iterator will be
+        passed to parse_fn together with the classes list, so each item must
+        carry whatever parse_fn needs for parsing.
+        """
+        raise NotImplementedError
+
+    def build_samples(self, iterable, process_fn):
+        samples_list = []
+        with Pool(processes=self.num_processes) as p:
+            with tqdm(total=len(iterable), desc='Scanning images') as pbar:
+                for _, result_dict in enumerate(
+                        p.imap_unordered(process_fn, iterable)):
+                    if result_dict:
+                        samples_list.append(result_dict)
+                    pbar.update()
+
+        return samples_list
+
+    def get_length(self):
+        return len(self.samples_list)
+
+    def __len__(self):
+        return self.get_length()
+
+    def get_ann_info(self, idx):
+        """
+        Get raw annotation info, including bounding boxes, labels and so on.
+        `bboxes` format is as [x1, y1, x2, y2] without normalization.
+        """
+        sample_info = self.samples_list[idx]
+
+        groundtruth_is_crowd = sample_info.get('groundtruth_is_crowd', None)
+        if groundtruth_is_crowd is None:
+            groundtruth_is_crowd = np.zeros_like(sample_info['gt_labels'])
+
+        annotations = {
+            'bboxes': sample_info['gt_bboxes'],
+            'labels': sample_info['gt_labels'],
+            'groundtruth_is_crowd': groundtruth_is_crowd
+        }
+
+        return annotations
+
+    def post_process_fn(self, result_dict):
+        if result_dict.get('img_fields', None) is None:
+            result_dict['img_fields'] = ['img']
+        if result_dict.get('bbox_fields', None) is None:
+            result_dict['bbox_fields'] = ['gt_bboxes']
+
+        return result_dict
+
+    def get_sample(self, idx):
+        result_dict = self.samples_list[idx]
+        load_success = True
+        try:
+            if not self.cache_at_init and result_dict.get('img', None) is None:
+                result_dict.update(load_image(result_dict['filename']))
+                if self.cache_on_the_fly:
+                    self.samples_list[idx] = result_dict
+            result_dict = self.post_process_fn(result_dict)
+            # load succeeded, reset the retry counter
+            self._retry_count = 0
+        except Exception as e:
+            logging.error(e)
+            load_success = False
+
+        if not load_success:
+            logging.warning(
+                'Something wrong with current sample %s, try to load the next sample...'
+                % result_dict.get('filename', ''))
+            self._retry_count += 1
+            if self._retry_count >= self._max_retry_num:
+                raise ValueError('All samples failed to load!')
+
+            result_dict = self.get_sample((idx + 1) % self.num_samples)
+
+        return result_dict
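The concrete sources below all follow the same recipe: store their own paths, call `DetSourceBase.__init__`, and implement `get_source_iterator`. A minimal sketch of a new source built on this contract (the `MyDetSource` class, `my_parse_fn`, and its annotation layout are hypothetical illustrations, not part of this patch):

```python
import numpy as np

from easycv.datasets.detection.data_sources.base import DetSourceBase
from easycv.datasets.registry import DATASOURCES


def my_parse_fn(source_item, classes):
    # parse_fn receives one item of the source iterator plus the classes
    # list, and must return a dict with `filename`, `gt_bboxes` ([N, 4],
    # xyxy, absolute pixels) and `gt_labels` ([N], class indices).
    img_path, label = source_item
    return {
        'filename': img_path,
        'gt_bboxes': np.array(label['bboxes'], dtype=np.float32),
        'gt_labels': np.array(
            [classes.index(name) for name in label['names']],
            dtype=np.int64),
    }


@DATASOURCES.register_module
class MyDetSource(DetSourceBase):

    def __init__(self, ann_list, classes=[], **kwargs):
        self.ann_list = ann_list
        super(MyDetSource, self).__init__(
            classes=classes, parse_fn=my_parse_fn, **kwargs)

    def get_source_iterator(self):
        # Each yielded item is passed to parse_fn as `source_item`;
        # parse_fn must be a module-level function so Pool can pickle it.
        return self.ann_list
```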
diff --git a/easycv/datasets/detection/data_sources/pai_format.py b/easycv/datasets/detection/data_sources/pai_format.py
index d66bcd35..5af874c5 100644
--- a/easycv/datasets/detection/data_sources/pai_format.py
+++ b/easycv/datasets/detection/data_sources/pai_format.py
@@ -1,13 +1,13 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import json
 import logging
+from multiprocessing import cpu_count
 
 import numpy as np
-from mmcv.runner.dist_utils import get_dist_info
 
+from easycv.datasets.detection.data_sources.base import DetSourceBase
 from easycv.datasets.registry import DATASOURCES
 from easycv.file import io
-from .voc import DetSourceVOC
 
 
 def get_prior_task_id(keys):
@@ -44,7 +44,7 @@ def is_itag_v2(row):
     return False
 
 
-def parser_manifest_row_str(row_str):
+def parser_manifest_row_str(row_str, classes):
     row = json.loads(row_str.strip())
     _is_itag_v2 = is_itag_v2(row)
 
@@ -77,7 +77,7 @@ def parser_manifest_row_str(row_str):
     if not ann_json:
         return parse_results
 
-    bboxes, class_names = [], []
+    bboxes, gt_labels = [], []
     for result in ann_json['results']:
         if result['type'] != 'image':
             continue
@@ -100,7 +100,7 @@ def parser_manifest_row_str(row_str):
                     raise ValueError(
                         'Not support multi label, get class name %s!' %
                         class_name)
-                class_names.append(class_name[0])
+                gt_labels.append(classes.index(class_name[0]))
             else:
                 if obj['type'] != 'image/rectangleLabel':
                     logging.warning(
@@ -113,18 +113,18 @@ def parser_manifest_row_str(row_str):
                 bnd = [x, y, x + w, y + h]
                 class_name = obj['labels'][0]
                 bboxes.append(bnd)
-                class_names.append(class_name)
+                gt_labels.append(classes.index(class_name))
         break
 
-    parse_results['gt_bboxes'] = bboxes
-    parse_results['class_names'] = class_names
     parse_results['filename'] = img_url
+    parse_results['gt_bboxes'] = np.array(bboxes, dtype=np.float32)
+    parse_results['gt_labels'] = np.array(gt_labels, dtype=np.int64)
 
     return parse_results
 
 
 @DATASOURCES.register_module
-class DetSourcePAI(DetSourceVOC):
+class DetSourcePAI(DetSourceBase):
     """
    data format please refer to: https://help.aliyun.com/document_detail/311173.html
    """
@@ -134,6 +134,8 @@ class DetSourcePAI(DetSourceVOC):
                  classes=[],
                  cache_at_init=False,
                  cache_on_the_fly=False,
+                 parse_fn=parser_manifest_row_str,
+                 num_processes=int(cpu_count() / 2),
                  **kwargs):
         """
         Args:
@@ -141,30 +143,19 @@ class DetSourcePAI(DetSourceVOC):
             classes: classes list
             cache_at_init: if set True, will cache in memory in __init__ for faster training
             cache_on_the_fly: if set True, will cache in memroy during training
+            parse_fn: parse function to parse item of source iterator
+            num_processes: number of processes to parse samples
         """
-        self.CLASSES = classes
-        self.rank, self.world_size = get_dist_info()
-        self.manifest_path = path
-        self.cache_at_init = cache_at_init
-        self.cache_on_the_fly = cache_on_the_fly
-        if self.cache_at_init and self.cache_on_the_fly:
-            raise ValueError(
-                'Only one of `cache_on_the_fly` and `cache_at_init` can be True!'
-            )
+        self.manifest_path = path
+        super(DetSourcePAI, self).__init__(
+            classes=classes,
+            cache_at_init=cache_at_init,
+            cache_on_the_fly=cache_on_the_fly,
+            parse_fn=parse_fn,
+            num_processes=num_processes)
+
+    def get_source_iterator(self):
         with io.open(self.manifest_path, 'r') as f:
             rows = f.read().splitlines()
-
-        self.samples_list = self.build_samples(rows)
-
-    def get_source_info(self, row_str):
-        source_info = parser_manifest_row_str(row_str)
-        source_info['gt_bboxes'] = np.array(
-            source_info['gt_bboxes'], dtype=np.float32)
-        source_info['gt_labels'] = np.array([
-            self.CLASSES.index(class_name)
-            for class_name in source_info['class_names']
-        ],
-                                            dtype=np.int64)
-
-        return source_info
+        return rows
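Since `parser_manifest_row_str` now resolves class names to indices itself, `classes` must list every label that occurs in the manifest; a missing label makes `classes.index()` raise `ValueError` during the parallel scan. A usage sketch (the manifest path and class names are placeholders):

```python
from easycv.datasets.detection.data_sources.pai_format import DetSourcePAI

# Placeholder manifest path and class list.
source = DetSourcePAI(
    path='data/train.manifest',
    classes=['cat', 'dog'])
print(len(source))                   # number of parsed samples
ann = source.get_ann_info(0)         # does not trigger image loading
print(ann['bboxes'], ann['labels'])  # float32 [N, 4] xyxy, int64 [N]
```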
diff --git a/easycv/datasets/detection/data_sources/raw.py b/easycv/datasets/detection/data_sources/raw.py
index d05e9bd9..5d24730c 100644
--- a/easycv/datasets/detection/data_sources/raw.py
+++ b/easycv/datasets/detection/data_sources/raw.py
@@ -1,20 +1,45 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import functools
 import logging
 import os
+from multiprocessing import cpu_count
 
 import numpy as np
 
 from easycv.datasets.registry import DATASOURCES
 from easycv.file import io
 from easycv.utils.bbox_util import batched_cxcywh2xyxy_with_shape
-from .voc import DetSourceVOC
+from .base import DetSourceBase
 
 img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.tiff', '.dng']
 label_formats = ['.txt']
 
 
+def parse_raw(source_iter, classes=None, delimeter=' '):
+    img_path, label_path = source_iter
+
+    source_info = {'filename': img_path}
+
+    with io.open(label_path, 'r') as f:
+        labels_and_boxes = np.array(
+            [line.split(delimeter) for line in f.read().splitlines()])
+
+    if not len(labels_and_boxes):
+        return source_info
+
+    labels = labels_and_boxes[:, 0]
+    bboxes = labels_and_boxes[:, 1:]
+
+    source_info.update({
+        'gt_bboxes': np.array(bboxes, dtype=np.float32),
+        'gt_labels': labels.astype(np.int64)
+    })
+
+    return source_info
+
+
 @DATASOURCES.register_module
-class DetSourceRaw(DetSourceVOC):
+class DetSourceRaw(DetSourceBase):
     """
     data dir is as follows:
     ```
@@ -45,24 +70,39 @@ class DetSourceRaw(DetSourceVOC):
     def __init__(self,
                  img_root_path,
                  label_root_path,
+                 classes=[],
                  cache_at_init=False,
                  cache_on_the_fly=False,
                  delimeter=' ',
+                 parse_fn=parse_raw,
+                 num_processes=int(cpu_count() / 2),
                  **kwargs):
         """
         Args:
             img_root_path: images dir path
            label_root_path: labels dir path
+            classes(list, optional): classes list
            cache_at_init: if set True, will cache in memory in __init__ for faster training
            cache_on_the_fly: if set True, will cache in memroy during training
+            delimeter: delimiter of txt label file
+            parse_fn: parse function to parse item of source iterator
+            num_processes: number of processes to parse samples
         """
-        self.cache_on_the_fly = cache_on_the_fly
-        self.cache_at_init = cache_at_init
-        self.delimeter = delimeter
+        self.delimeter = delimeter
         self.img_root_path = img_root_path
         self.label_root_path = label_root_path
 
+        parse_fn = functools.partial(parse_fn, delimeter=delimeter)
+
+        super(DetSourceRaw, self).__init__(
+            classes=classes,
+            cache_at_init=cache_at_init,
+            cache_on_the_fly=cache_on_the_fly,
+            parse_fn=parse_fn,
+            num_processes=num_processes)
+
+    def get_source_iterator(self):
         self.img_files = [
             os.path.join(self.img_root_path, i)
             for i in io.listdir(self.img_root_path, recursive=True)
@@ -90,48 +130,11 @@ class DetSourceRaw(DetSourceVOC):
         assert len(
             self.img_files) > 0, 'No samples found in %s' % self.img_root_path
 
-        # TODO: filter bad sample
-
-        self.samples_list = self.build_samples(
-            list(zip(self.img_files, self.label_files)))
+        return list(zip(self.img_files, self.label_files))
 
-    def get_source_info(self, img_and_label):
-        img_path = img_and_label[0]
-        label_path = img_and_label[1]
+    def post_process_fn(self, result_dict):
+        result_dict = super(DetSourceRaw, self).post_process_fn(result_dict)
 
-        source_info = {'filename': img_path}
-
-        with io.open(label_path, 'r') as f:
-            labels_and_boxes = np.array(
-                [line.split(self.delimeter) for line in f.read().splitlines()])
-
-        if not len(labels_and_boxes):
-            return {}
-
-        labels = labels_and_boxes[:, 0]
-        bboxes = labels_and_boxes[:, 1:]
-
-        source_info.update({
-            'gt_bboxes': np.array(bboxes, dtype=np.float32),
-            'gt_labels': labels.astype(np.int64)
-        })
-
-        return source_info
-
-    def _build_sample_from_source_info(self, source_info):
-        if 'filename' not in source_info:
-            return {}
-
-        result_dict = source_info
-
-        img_info = self.load_image(source_info['filename'])
-        result_dict.update(img_info)
-
-        result_dict.update({
-            'img_fields': ['img'],
-            'bbox_fields': ['gt_bboxes']
-        })
-
         # shape: h, w
         result_dict['gt_bboxes'] = batched_cxcywh2xyxy_with_shape(
-            result_dict['gt_bboxes'], shape=img_info['img_shape'][:2])
-
+            result_dict['gt_bboxes'], shape=result_dict['img_shape'][:2])
         return result_dict
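`parse_raw` keeps the boxes exactly as they appear in the txt label file; the cxcywh-to-xyxy conversion happens later in `post_process_fn`, once `img_shape` is known. The arithmetic, assuming YOLO-style normalized center coordinates (which is what delegating to `batched_cxcywh2xyxy_with_shape` implies; the sample values below are made up):

```python
# One label row: "<class_id> <cx> <cy> <w> <h>", whitespace-delimited.
row = '3 0.50 0.40 0.20 0.10'
cls_id, cx, cy, w, h = map(float, row.split(' '))

img_h, img_w = 480, 640  # result_dict['img_shape'][:2]

# Equivalent of batched_cxcywh2xyxy_with_shape for a single box:
x1 = (cx - w / 2) * img_w  # 256.0
y1 = (cy - h / 2) * img_h  # 168.0
x2 = (cx + w / 2) * img_w  # 384.0
y2 = (cy + h / 2) * img_h  # 216.0
```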
diff --git a/easycv/datasets/detection/data_sources/voc.py b/easycv/datasets/detection/data_sources/voc.py
index 28970d3a..e5465e34 100644
--- a/easycv/datasets/detection/data_sources/voc.py
+++ b/easycv/datasets/detection/data_sources/voc.py
@@ -1,24 +1,20 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import logging
 import os
-import time
 import xml.etree.ElementTree as ET
-from multiprocessing import Pool, cpu_count
+from multiprocessing import cpu_count
 
-import cv2
 import numpy as np
-from mmcv.runner.dist_utils import get_dist_info
-from PIL import Image
-from tqdm import tqdm
 
+from easycv.datasets.detection.data_sources.base import DetSourceBase
 from easycv.datasets.registry import DATASOURCES
 from easycv.file import io
-from easycv.utils.constant import MAX_READ_IMAGE_TRY_TIMES
 
 img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.tiff', '.dng']
 
 
-def parse_xml(xml_path, classes):
+def parse_xml(source_item, classes):
+    img_path, xml_path = source_item
     with io.open(xml_path, 'r') as f:
         tree = ET.parse(f)
         root = tree.getroot()
@@ -51,14 +47,15 @@ def parse_xml(xml_path, classes):
 
     img_info = {
         'gt_bboxes': np.array(gt_bboxes, dtype=np.float32),
-        'gt_labels': np.array(gt_labels, dtype=np.int64)
+        'gt_labels': np.array(gt_labels, dtype=np.int64),
+        'filename': img_path,
     }
 
     return img_info
 
 
 @DATASOURCES.register_module
-class DetSourceVOC(object):
+class DetSourceVOC(DetSourceBase):
     """
     data dir is as follows:
     ```
@@ -98,6 +95,8 @@ class DetSourceVOC(object):
                  cache_on_the_fly=False,
                  img_suffix='.jpg',
                  label_suffix='.xml',
+                 parse_fn=parse_xml,
+                 num_processes=int(cpu_count() / 2),
                  **kwargs):
         """
         Args:
@@ -111,16 +110,24 @@ class DetSourceVOC(object):
             cache_on_the_fly: if set True, will cache in memroy during training
             img_suffix: suffix of image file
             label_suffix: suffix of label file
+            parse_fn: parse function to parse item of source iterator
+            num_processes: number of processes to parse samples
         """
-        self.CLASSES = classes
-        self.rank, self.world_size = get_dist_info()
+        self.path = path
         self.img_root_path = img_root_path
         self.label_root_path = label_root_path
-        self.cache_at_init = cache_at_init
-        self.cache_on_the_fly = cache_on_the_fly
+        self.img_suffix = img_suffix
+        self.label_suffix = label_suffix
+        super(DetSourceVOC, self).__init__(
+            classes=classes,
+            cache_at_init=cache_at_init,
+            cache_on_the_fly=cache_on_the_fly,
+            parse_fn=parse_fn,
+            num_processes=num_processes)
 
-        if not img_root_path:
+    def get_source_iterator(self):
+        if not self.img_root_path:
             self.img_root_path = os.path.join(
                 self.path.split('ImageSets/Main')[0], 'JPEGImages')
         if not self.label_root_path:
@@ -134,128 +141,10 @@ class DetSourceVOC(object):
         for id_line in id_lines:
             img_id = id_line.strip().split(' ')[0]
             img_path = os.path.join(self.img_root_path,
-                                    img_id + img_suffix)
+                                    img_id + self.img_suffix)
             imgs_path_list.append(img_path)
             label_path = os.path.join(self.label_root_path,
-                                      img_id + label_suffix)
+                                      img_id + self.label_suffix)
             labels_path_list.append(label_path)
 
-        # TODO: filter bad sample
-        self.samples_list = self.build_samples(
-            list(zip(imgs_path_list, labels_path_list)))
-
-    def get_source_info(self, img_and_label):
-        img_path = img_and_label[0]
-        label_path = img_and_label[1]
-        source_info = parse_xml(label_path, self.CLASSES)
-        source_info.update({'filename': img_path})
-
-        return source_info
-
-    def _build_sample_from_source_info(self, source_info):
-        if 'filename' not in source_info:
-            return {}
-
-        result_dict = source_info
-
-        img_info = self.load_image(source_info['filename'])
-        result_dict.update(img_info)
-
-        result_dict.update({
-            'img_fields': ['img'],
-            'bbox_fields': ['gt_bboxes']
-        })
-
-        return result_dict
-
-    def build_sample(self, data):
-        result_dict = self.get_source_info(data)
-
-        if self.cache_at_init:
-            result_dict = self._build_sample_from_source_info(result_dict)
-
-        return result_dict
-
-    def build_samples(self, iterable):
-        samples_list = []
-        proc_num = int(cpu_count() / 2)
-        with Pool(processes=proc_num) as p:
-            with tqdm(total=len(iterable), desc='Scanning images') as pbar:
-                for _, result_dict in enumerate(
-                        p.imap_unordered(self.build_sample, iterable)):
-                    if result_dict:
-                        samples_list.append(result_dict)
-                    pbar.update()
-
-        return samples_list
-
-    def load_image(self, img_path):
-        result = {}
-        try_cnt = 0
-        img = None
-        while try_cnt < MAX_READ_IMAGE_TRY_TIMES:
-            try:
-                with io.open(img_path, 'rb') as infile:
-                    # cv2.imdecode may corrupt when the img is broken
-                    image = Image.open(infile)
-                    img = cv2.cvtColor(
-                        np.asarray(image, dtype=np.uint8), cv2.COLOR_RGB2BGR)
-                    assert img is not None, 'Image load error, try %s : %s' % (
-                        try_cnt, img_path)
-                    break
-            except:
-                time.sleep(2)
-            try_cnt += 1
-
-        if img is None:
-            raise ValueError('Read Image Times Out: ' + img_path)
-
-        result['img'] = img.astype(np.float32)
-        result['img_shape'] = img.shape  # h, w, c
-        result['ori_img_shape'] = img.shape
-
-        return result
-
-    def get_length(self):
-        return len(self.samples_list)
-
-    def __len__(self):
-        return self.get_length()
-
-    def get_ann_info(self, idx):
-        """
-        Get raw annotation info, include bounding boxes, labels and so on.
-        `bboxes` format is as [x1, y1, x2, y2] without normalization.
-        """
-        sample_info = self.samples_list[idx]
-        if sample_info.get('gt_labels', None) is None:
-            sample_info = self._build_sample_from_source_info(sample_info)
-            if self.cache_at_init or self.cache_on_the_fly:
-                self.samples_list[idx] = sample_info
-
-        annotations = {
-            'bboxes': sample_info['gt_bboxes'],
-            'labels': sample_info['gt_labels'],
-            'groundtruth_is_crowd': np.zeros_like(sample_info['gt_labels'])
-        }
-
-        return annotations
-
-    def get_sample(self, idx):
-        result_dict = self.samples_list[idx]
-        try:
-            if result_dict.get('img', None) is None:
-                result_dict = self._build_sample_from_source_info(result_dict)
-                if self.cache_at_init or self.cache_on_the_fly:
-                    self.samples_list[idx] = result_dict
-        except Exception as e:
-            logging.warning(e)
-
-        if not result_dict:
-            logging.warning(
-                'Something wrong with current sample %s,Try load next sample...'
-                % result_dict.get('filename', ''))
-            result_dict = self.get_sample(idx + 1)
-
-        return result_dict
+        return list(zip(imgs_path_list, labels_path_list))
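`DetSourceVOC.get_source_iterator` is called from the base `__init__`, so the `JPEGImages`/`Annotations` fallbacks derived from the `ImageSets/Main` path still run before the parallel scan. A usage sketch (the dataset root and class names are placeholders):

```python
from easycv.datasets.detection.data_sources.voc import DetSourceVOC

# Placeholder dataset root; img_root_path and label_root_path are derived
# from the ImageSets/Main path when omitted.
source = DetSourceVOC(
    path='data/VOCdevkit/VOC2007/ImageSets/Main/train.txt',
    classes=['aeroplane', 'bicycle', 'bird'],
    cache_on_the_fly=True)
ann = source.get_ann_info(0)
print(ann['bboxes'])  # [x1, y1, x2, y2], absolute, no normalization
```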
- """ - sample_info = self.samples_list[idx] - if sample_info.get('gt_labels', None) is None: - sample_info = self._build_sample_from_source_info(sample_info) - if self.cache_at_init or self.cache_on_the_fly: - self.samples_list[idx] = sample_info - - annotations = { - 'bboxes': sample_info['gt_bboxes'], - 'labels': sample_info['gt_labels'], - 'groundtruth_is_crowd': np.zeros_like(sample_info['gt_labels']) - } - - return annotations - - def get_sample(self, idx): - result_dict = self.samples_list[idx] - try: - if result_dict.get('img', None) is None: - result_dict = self._build_sample_from_source_info(result_dict) - if self.cache_at_init or self.cache_on_the_fly: - self.samples_list[idx] = result_dict - except Exception as e: - logging.warning(e) - - if not result_dict: - logging.warning( - 'Something wrong with current sample %s,Try load next sample...' - % result_dict.get('filename', '')) - result_dict = self.get_sample(idx + 1) - - return result_dict + return list(zip(imgs_path_list, labels_path_list)) diff --git a/easycv/datasets/shared/data_sources/concat.py b/easycv/datasets/shared/data_sources/concat.py index 2f1b95c6..a1eb77f6 100644 --- a/easycv/datasets/shared/data_sources/concat.py +++ b/easycv/datasets/shared/data_sources/concat.py @@ -25,6 +25,9 @@ class SourceConcat(object): def get_length(self): return self.cumsum_length_list[-1] + def __len__(self): + return self.get_length() + def get_sample(self, idx): dataset_idx = bisect.bisect_right(self.cumsum_length_list, idx) if dataset_idx == 0: diff --git a/tests/datasets/detection/data_sources/test_raw.py b/tests/datasets/detection/data_sources/test_raw.py index dcafda4b..1fae2851 100644 --- a/tests/datasets/detection/data_sources/test_raw.py +++ b/tests/datasets/detection/data_sources/test_raw.py @@ -42,7 +42,7 @@ class DetSourceRawTest(unittest.TestCase): data_source.samples_list[exclude_idx[i]]) length = data_source.get_length() - self.assertEqual(length, 126) + self.assertEqual(length, 128) exists = False for idx in range(length): diff --git a/tests/datasets/detection/data_sources/test_voc.py b/tests/datasets/detection/data_sources/test_voc.py index 484113a5..d8428cf9 100644 --- a/tests/datasets/detection/data_sources/test_voc.py +++ b/tests/datasets/detection/data_sources/test_voc.py @@ -90,6 +90,25 @@ class DetSourceVOCTest(unittest.TestCase): cache_on_the_fly=True) self._base_test(data_source) + def test_max_retry_num(self): + data_root = DET_DATA_SMALL_VOC_LOCAL + data_source = DetSourceVOC( + path=os.path.join(data_root, 'ImageSets/Main/train_20.txt'), + classes=VOC_CLASSES, + img_root_path=os.path.join(data_root, 'fault_path'), + label_root_path=os.path.join(data_root, 'Annotations')) + data_source._max_retry_num = 2 + num_samples = data_source.num_samples + with self.assertRaises(ValueError) as cm: + for idx in range(num_samples - 1, num_samples * 2): + _ = data_source.get_sample(idx) + + exception = cm.exception + + self.assertEqual(num_samples, 20) + self.assertEqual(data_source._retry_count, 2) + self.assertEqual(exception.args[0], 'All samples failed to load!') + if __name__ == '__main__': unittest.main() diff --git a/tools/train.py b/tools/train.py index 56b3a894..edf143b0 100644 --- a/tools/train.py +++ b/tools/train.py @@ -15,7 +15,7 @@ sys.path.append( osp.join(os.path.dirname(os.path.dirname(__file__)), '../'))) import time - +import cv2 import requests import torch from mmcv.runner import init_dist @@ -33,6 +33,9 @@ from easycv.utils.config_tools import traverse_replace from easycv.utils.config_tools import 
                                        mmcv_config_fromfile, rebuild_config)
 
+# refer to: https://github.com/open-mmlab/mmdetection/pull/6867
+cv2.setNumThreads(0)
+
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Train a model')
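The `cv2.setNumThreads(0)` guard mirrors the referenced mmdetection fix: OpenCV's internal thread pool can oversubscribe CPUs or deadlock after `fork()` when images are decoded inside dataloader workers. The same guard can also be applied per worker instead of at import time (a sketch; the `worker_init_fn` wiring is an illustration, not part of this patch):

```python
import cv2


def worker_init_fn(worker_id):
    # Pin OpenCV to a single thread in each forked dataloader worker to
    # avoid thread-pool oversubscription and post-fork deadlocks.
    cv2.setNumThreads(0)

# e.g. torch.utils.data.DataLoader(dataset, num_workers=4,
#                                  worker_init_fn=worker_init_fn)
```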