gulou 36a3c45efa
add more data source for auto download (#229)
* add caltech, flower, mnist data source

* add det lvis data source

* add pose crowdPose data source

* add pose of OC Human data source

* add pose of mpii data source

* add Seg of voc data source

* add Seg of coco data source

* add Det of wider person datasource

* add Det of african wildlife datasource

* add Det of fruit datasource

* add Det of pet datasource

* add Det of artaxor and tiny person datasource

* add Det of wider face datasource

* add Det of crowd human datasource

* add Det of object365 datasource

* add Seg of coco stuff 10k and 164k datasource

Co-authored-by: Cathy0908 <30484308+Cathy0908@users.noreply.github.com>
2022-12-02 10:57:23 +08:00

78 lines
2.4 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import glob
import os
import xml.etree.ElementTree as ET
from multiprocessing import cpu_count
from easycv.datasets.registry import DATASOURCES
from .base import DetSourceBase
from .voc import parse_xml
@DATASOURCES.register_module
class DetSourceFruit(DetSourceBase):
    """Detection data source reading image/annotation pairs from one directory.

    Every image in ``path`` must have an annotation file next to it that
    shares the same base name and differs only in its suffix::

        |- data
            |-banana_2.jpg
            |-banana_2.xml
            |-...

    Example:
        data_source = DetSourceFruit(
            path='/your/data/',
            classes=${CLASSES},
        )
    """
    CLASSES = ['apple', 'banana', 'orange']

    def __init__(self,
                 path,
                 classes=CLASSES,
                 cache_at_init=False,
                 cache_on_the_fly=False,
                 img_suffix='.jpg',
                 label_suffix='.xml',
                 parse_fn=parse_xml,
                 num_processes=int(cpu_count() / 2),
                 **kwargs):
        """
        Args:
            path: directory that holds both the images and their label files
            classes: classes list
            cache_at_init: if set True, will cache in memory in __init__ for faster training
            cache_on_the_fly: if set True, will cache in memory during training
            img_suffix: suffix of image file
            label_suffix: suffix of label file
            parse_fn: parse function to parse item of source iterator
            num_processes: number of processes to parse samples
            **kwargs: accepted but not forwarded to the base class
        """
        # These attributes are read by get_source_iterator(); they are set
        # before the base initializer runs.
        # NOTE(review): presumably the base class calls get_source_iterator()
        # during __init__ - confirm before reordering.
        self.path = path
        self.img_suffix = img_suffix
        self.label_suffix = label_suffix

        super(DetSourceFruit, self).__init__(
            classes=classes,
            cache_at_init=cache_at_init,
            cache_on_the_fly=cache_on_the_fly,
            parse_fn=parse_fn,
            num_processes=num_processes)

    def get_source_iterator(self):
        """Collect ``(image_path, label_path)`` pairs found in ``self.path``.

        Returns:
            list of ``(image_path, label_path)`` tuples, sorted by image path.

        Raises:
            AssertionError: if ``self.path`` does not exist, or an image has
                no matching label file.
        """
        assert os.path.exists(self.path), f'{self.path} is not exists'

        pattern = os.path.join(self.path, '*' + self.img_suffix)
        suffix_len = len(self.img_suffix)

        source_items = []
        # glob yields matches in arbitrary (filesystem dependent) order; sort
        # them so the sample order is reproducible across runs and machines.
        for img_path in sorted(glob.glob(pattern)):
            # Swap only the trailing suffix. str.replace() would also rewrite
            # any earlier occurrence of the suffix inside a directory or file
            # name (e.g. '/data.jpg_v2/a.jpg') and produce a wrong label path.
            stem = img_path[:len(img_path) - suffix_len]
            label_path = stem + self.label_suffix
            assert os.path.exists(label_path), f'{label_path} is not exists'
            source_items.append((img_path, label_path))

        return source_items