mirror of https://github.com/alibaba/EasyCV.git
162 lines
5.1 KiB
Python
162 lines
5.1 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import json
|
|
import logging
|
|
from multiprocessing import cpu_count
|
|
|
|
import numpy as np
|
|
|
|
from easycv.datasets.detection.data_sources.base import DetSourceBase
|
|
from easycv.datasets.registry import DATASOURCES
|
|
from easycv.file import io
|
|
|
|
|
|
def get_prior_task_id(keys):
|
|
""""The task id ends with `check` is the highest priority.
|
|
"""
|
|
k_list = []
|
|
check_k_list = []
|
|
verify_k_list = []
|
|
for k in keys:
|
|
if k.startswith('label-'):
|
|
if k.endswith('check'):
|
|
check_k_list.append(k)
|
|
elif k.endswith('verify'):
|
|
verify_k_list.append(k)
|
|
else:
|
|
k_list.append(k)
|
|
|
|
if len(check_k_list):
|
|
return check_k_list
|
|
if len(k_list):
|
|
return k_list
|
|
if len(verify_k_list):
|
|
return verify_k_list
|
|
|
|
return []
|
|
|
|
|
|
def is_itag_v2(row):
|
|
"""
|
|
The keyword of the data source is `picUrl` in v1, but is `source` in v2
|
|
"""
|
|
if 'source' in row['data']:
|
|
return True
|
|
return False
|
|
|
|
|
|
def parser_manifest_row_str(row_str, classes):
|
|
row = json.loads(row_str.strip())
|
|
_is_itag_v2 = is_itag_v2(row)
|
|
|
|
parse_results = {}
|
|
|
|
# check img_url
|
|
img_url = row['data']['source']
|
|
if img_url.startswith(('http://', 'https://')):
|
|
logging.warning(
|
|
'Not support http url, only support `oss://`, skip the sample: %s!'
|
|
% img_url)
|
|
return parse_results
|
|
|
|
# check task ids
|
|
if _is_itag_v2:
|
|
task_ids = get_prior_task_id(row.keys())
|
|
else:
|
|
task_ids = [
|
|
row_k for row_k in row.keys() if row_k.startswith('label-')
|
|
]
|
|
if len(task_ids) > 1:
|
|
raise NotImplementedError('Not support multi label task ids: %s!' %
|
|
task_ids)
|
|
if not len(task_ids):
|
|
logging.warning('Not find label task id in sample: %s, skip it!' %
|
|
img_url)
|
|
return parse_results
|
|
|
|
ann_json = row[task_ids[0]]
|
|
if not ann_json:
|
|
return parse_results
|
|
|
|
bboxes, gt_labels = [], []
|
|
for result in ann_json['results']:
|
|
if result['type'] != 'image':
|
|
continue
|
|
|
|
objs_list = result['data']
|
|
|
|
for obj in objs_list:
|
|
if _is_itag_v2:
|
|
if obj['type'] != 'image/polygon':
|
|
logging.warning(
|
|
'Result type should be `image/polygon`, but get %s, skip object %s in %s'
|
|
% (obj['type'], obj, img_url))
|
|
continue
|
|
sort_points = sorted(obj['value'], key=sum)
|
|
(x0, y0, x1, y1) = np.concatenate(
|
|
(sort_points[0], sort_points[-1]), axis=0)
|
|
bboxes.append([x0, y0, x1, y1])
|
|
class_name = list(obj['labels'].values())
|
|
if len(class_name) > 1:
|
|
raise ValueError(
|
|
'Not support multi label, get class name %s!' %
|
|
class_name)
|
|
gt_labels.append(classes.index(class_name[0]))
|
|
else:
|
|
if obj['type'] != 'image/rectangleLabel':
|
|
logging.warning(
|
|
'result type [%s] in %s is not image/rectangleLabel, skip it!'
|
|
% (obj['type'], img_url))
|
|
continue
|
|
value = obj['value']
|
|
x, y, w, h = value['x'], value['y'], value['width'], value[
|
|
'height']
|
|
bnd = [x, y, x + w, y + h]
|
|
class_name = obj['labels'][0]
|
|
bboxes.append(bnd)
|
|
gt_labels.append(classes.index(class_name))
|
|
break
|
|
|
|
parse_results['filename'] = img_url
|
|
parse_results['gt_bboxes'] = np.array(bboxes, dtype=np.float32)
|
|
parse_results['gt_labels'] = np.array(gt_labels, dtype=np.int64)
|
|
|
|
return parse_results
|
|
|
|
|
|
@DATASOURCES.register_module
|
|
class DetSourcePAI(DetSourceBase):
|
|
"""
|
|
data format please refer to: https://help.aliyun.com/document_detail/311173.html
|
|
"""
|
|
|
|
def __init__(self,
|
|
path,
|
|
classes=[],
|
|
cache_at_init=False,
|
|
cache_on_the_fly=False,
|
|
parse_fn=parser_manifest_row_str,
|
|
num_processes=int(cpu_count() / 2),
|
|
**kwargs):
|
|
"""
|
|
Args:
|
|
path: Path of manifest path with pai label format
|
|
classes: classes list
|
|
cache_at_init: if set True, will cache in memory in __init__ for faster training
|
|
cache_on_the_fly: if set True, will cache in memroy during training
|
|
parse_fn: parse function to parse item of source iterator
|
|
num_processes: number of processes to parse samples
|
|
"""
|
|
|
|
self.manifest_path = path
|
|
super(DetSourcePAI, self).__init__(
|
|
classes=classes,
|
|
cache_at_init=cache_at_init,
|
|
cache_on_the_fly=cache_on_the_fly,
|
|
parse_fn=parse_fn,
|
|
num_processes=num_processes)
|
|
|
|
def get_source_iterator(self):
|
|
with io.open(self.manifest_path, 'r') as f:
|
|
rows = f.read().splitlines()
|
|
return rows
|