88 lines
2.9 KiB
Python
88 lines
2.9 KiB
Python
# -*- coding: utf-8 -*-
|
||
# @Time : 2019/12/4 13:12
|
||
# @Author : zhoujun
|
||
import copy
import logging

from paddle.io import Dataset

from data_loader.modules import *
|
||
|
||
|
||
class BaseDataSet(Dataset):
    """Base dataset for text-detection samples read from image files.

    Subclasses implement ``load_data`` to return a list of per-sample dicts.
    Each dict must contain at least ``'img_path'``, ``'img_name'``,
    ``'text_polys'``, ``'texts'`` and ``'ignore_tags'``. ``__getitem__`` reads
    the image, runs the configured pre-processing pipeline, applies
    ``transform`` to the image and drops any keys listed in ``filter_keys``.
    """

    # Accepted values for ``img_mode``. 'BRG' is kept only as a legacy alias
    # of 'BGR' (it was the historical spelling); both load OpenCV's native
    # BGR order because only 'RGB' triggers a channel conversion.
    _VALID_IMG_MODES = ('RGB', 'BGR', 'BRG', 'GRAY')

    # Keys every sample dict returned by ``load_data`` must provide.
    _REQUIRED_ITEM_KEYS = (
        'img_path', 'img_name', 'text_polys', 'texts', 'ignore_tags'
    )

    # How many random replacement samples ``__getitem__`` tries after the
    # requested one fails before giving up with a RuntimeError.
    max_load_retries = 50

    def __init__(self,
                 data_path: str,
                 img_mode,
                 pre_processes,
                 filter_keys,
                 ignore_tags,
                 transform=None,
                 target_transform=None):
        """Load the sample list and build the pre-processing pipeline.

        :param data_path: folder or file the subclass's ``load_data`` reads.
        :param img_mode: one of 'RGB', 'BGR' (legacy alias 'BRG') or 'GRAY'.
        :param pre_processes: iterable of ``{'type': ..., 'args': ...}``
            configs (or None) instantiated by ``_init_pre_processes``.
        :param filter_keys: keys removed from each returned sample dict;
            None or empty keeps every key.
        :param ignore_tags: stored on the instance before ``load_data`` runs,
            so subclasses may read it there.
        :param transform: optional callable applied to the image array.
        :param target_transform: stored but not used by this base class.
        :raises ValueError: on an unknown ``img_mode``, an empty sample list
            or samples missing required keys.
        """
        if img_mode not in self._VALID_IMG_MODES:
            raise ValueError('img_mode must be one of {}, got {!r}'.format(
                list(self._VALID_IMG_MODES), img_mode))
        # Set before load_data(): subclasses may consult it while loading.
        self.ignore_tags = ignore_tags
        self.data_list = self.load_data(data_path)
        if not self.data_list:
            raise ValueError(
                'load_data returned no samples for {!r}'.format(data_path))
        # Only the first sample is checked, as a cheap schema sanity check.
        missing = [k for k in self._REQUIRED_ITEM_KEYS
                   if k not in self.data_list[0]]
        if missing:
            raise ValueError(
                'data_list from load_data must contain {}; missing {}'.format(
                    list(self._REQUIRED_ITEM_KEYS), missing))
        self.img_mode = img_mode
        self.filter_keys = filter_keys
        self.transform = transform
        self.target_transform = target_transform
        self._init_pre_processes(pre_processes)

    def _init_pre_processes(self, pre_processes):
        """Instantiate the augmentation pipeline into ``self.aug``.

        Each config entry needs a ``'type'`` naming a class visible in this
        module's globals (presumably exported by ``data_loader.modules`` via
        the star import). An optional ``'args'`` entry is passed as keyword
        arguments when it is a dict, otherwise as one positional argument.
        """
        self.aug = []
        if pre_processes is None:
            return
        for aug in pre_processes:
            args = aug['args'] if 'args' in aug else {}
            # NOTE(security): eval() executes arbitrary expressions from the
            # config; only use this with trusted configuration files.
            aug_cls = eval(aug['type'])
            if isinstance(args, dict):
                self.aug.append(aug_cls(**args))
            else:
                self.aug.append(aug_cls(args))

    def load_data(self, data_path: str) -> list:
        """Load the dataset as a list of per-sample dicts.

        :param data_path: folder or file that stores the data.
        :return: a list of dicts, each containing 'img_path', 'img_name',
            'text_polys', 'texts' and 'ignore_tags'.
        """
        raise NotImplementedError

    def apply_pre_processes(self, data):
        """Run ``data`` through every augmentation in order and return it."""
        for aug in self.aug:
            data = aug(data)
        return data

    def _load_sample(self, index):
        """Read, pre-process, transform and filter one sample (may raise)."""
        # Deep copy so augmentations cannot mutate the cached sample list.
        data = copy.deepcopy(self.data_list[index])
        # cv2 comes from the star import above; flag 1 = colour, 0 = grayscale.
        im = cv2.imread(data['img_path'], 1 if self.img_mode != 'GRAY' else 0)
        if im is None:
            # cv2.imread signals unreadable/missing files by returning None.
            raise IOError('cannot read image {!r}'.format(data['img_path']))
        if self.img_mode == 'RGB':
            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        data['img'] = im
        data['shape'] = [im.shape[0], im.shape[1]]
        data = self.apply_pre_processes(data)

        if self.transform:
            data['img'] = self.transform(data['img'])
        # Assumes the pre-processing pipeline leaves 'text_polys' as an
        # object with .tolist() (e.g. a numpy array) — TODO confirm.
        data['text_polys'] = data['text_polys'].tolist()
        if self.filter_keys:
            return {k: v for k, v in data.items() if k not in self.filter_keys}
        return data

    def __getitem__(self, index):
        """Return the processed sample at ``index``.

        If the sample cannot be loaded or processed, a warning is logged and a
        randomly chosen sample is tried instead, up to ``max_load_retries``
        times.

        :raises RuntimeError: if every attempt fails; the last underlying
            error is chained as the cause.
        """
        last_error = None
        for _ in range(self.max_load_retries + 1):
            try:
                return self._load_sample(index)
            except Exception as err:  # best-effort: skip bad samples
                last_error = err
                logging.getLogger(__name__).warning(
                    'Failed to load sample %s (%r); retrying with a random '
                    'sample.', index, err)
                # np comes from the star import above.
                index = np.random.randint(len(self))
        raise RuntimeError('failed to load a sample after {} attempts'.format(
            self.max_load_retries + 1)) from last_error

    def __len__(self):
        """Return the number of samples produced by ``load_data``."""
        return len(self.data_list)
|