88 lines
2.9 KiB
Python
88 lines
2.9 KiB
Python
|
# -*- coding: utf-8 -*-
|
|||
|
# @Time : 2019/12/4 13:12
|
|||
|
# @Author : zhoujun
|
|||
|
import copy
|
|||
|
from paddle.io import Dataset
|
|||
|
from data_loader.modules import *
|
|||
|
|
|||
|
|
|||
|
class BaseDataSet(Dataset):
    """Base dataset that loads one image per sample and runs a pre-processing pipeline.

    Subclasses must implement :meth:`load_data`. Every record it returns must
    contain the keys ``'img_path'``, ``'img_name'``, ``'text_polys'``,
    ``'texts'`` and ``'ignore_tags'`` (the first record is checked in
    ``__init__``).
    """

    # Maximum number of random replacement samples ``__getitem__`` tries after
    # a failure before giving up. Bounding the retries keeps a fully broken
    # dataset from retrying forever.
    _MAX_RETRIES = 100

    def __init__(self,
                 data_path: str,
                 img_mode,
                 pre_processes,
                 filter_keys,
                 ignore_tags,
                 transform=None,
                 target_transform=None):
        """Load the record list and build the pre-processing pipeline.

        :param data_path: location handed to :meth:`load_data`.
        :param img_mode: ``'RGB'``, ``'BGR'`` or ``'GRAY'``; the historical
            misspelling ``'BRG'`` is still accepted and behaves like ``'BGR'``.
        :param pre_processes: iterable of configs with a ``'type'`` entry and
            an optional ``'args'`` entry, or ``None`` for no pre-processing.
        :param filter_keys: keys removed from each sample before it is
            returned; empty or ``None`` returns the sample unfiltered.
        :param ignore_tags: stored on the instance as ``self.ignore_tags``.
        :param transform: optional callable applied to ``data['img']``.
        :param target_transform: stored on the instance; not used by this class.
        """
        super().__init__()
        # 'BRG' is kept so existing configs that use the misspelling still work.
        assert img_mode in ['RGB', 'BRG', 'BGR', 'GRAY']
        self.ignore_tags = ignore_tags
        self.data_list = self.load_data(data_path)
        item_keys = [
            'img_path', 'img_name', 'text_polys', 'texts', 'ignore_tags'
        ]
        for item in item_keys:
            assert item in self.data_list[
                0], 'data_list from load_data must contains {}'.format(
                    item_keys)
        self.img_mode = img_mode
        self.filter_keys = filter_keys
        self.transform = transform
        self.target_transform = target_transform
        self._init_pre_processes(pre_processes)

    def _init_pre_processes(self, pre_processes):
        """Instantiate every configured pre-processing op into ``self.aug``.

        Dict ``args`` are expanded as keyword arguments; any other value is
        passed as a single positional argument.
        """
        self.aug = []
        if pre_processes is None:
            return
        for aug in pre_processes:
            args = aug['args'] if 'args' in aug else {}
            # NOTE(security): eval() resolves the op class from the name given
            # in the config. Only use configs from trusted sources.
            op_cls = eval(aug['type'])
            if isinstance(args, dict):
                op = op_cls(**args)
            else:
                op = op_cls(args)
            self.aug.append(op)

    def load_data(self, data_path: str) -> list:
        """Load the dataset records into a list.

        :param data_path: folder or file that stores the data.
        :return: a list of dicts, each containing 'img_path', 'img_name',
            'text_polys', 'texts' and 'ignore_tags'.
        """
        raise NotImplementedError

    def apply_pre_processes(self, data):
        """Run ``data`` through every op in ``self.aug`` in order and return the result."""
        for aug in self.aug:
            data = aug(data)
        return data

    def _load_sample(self, index):
        """Read, pre-process and filter the sample at ``index``; raises on any failure."""
        # Deep copy so the pipeline never mutates the cached record.
        data = copy.deepcopy(self.data_list[index])
        # imread flag 1 loads a 3-channel colour image, 0 loads grayscale.
        im = cv2.imread(data['img_path'], 1
                        if self.img_mode != 'GRAY' else 0)
        if im is None:
            # cv2.imread signals an unreadable file by returning None.
            raise IOError('failed to read image: {}'.format(data['img_path']))
        if self.img_mode == 'RGB':
            # OpenCV loads colour images in BGR channel order.
            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        data['img'] = im
        data['shape'] = [im.shape[0], im.shape[1]]
        data = self.apply_pre_processes(data)

        if self.transform:
            data['img'] = self.transform(data['img'])
        data['text_polys'] = data['text_polys'].tolist()
        # Truthiness also covers filter_keys=None, for which len() would raise.
        if self.filter_keys:
            return {
                k: v
                for k, v in data.items() if k not in self.filter_keys
            }
        return data

    def __getitem__(self, index):
        """Return the sample at ``index``, or a random other sample if it fails.

        Any ``Exception`` while loading (including an out-of-range index)
        triggers a retry with a randomly drawn index, up to ``_MAX_RETRIES``
        times. ``KeyboardInterrupt`` and ``SystemExit`` are not intercepted.

        :raises RuntimeError: if all attempts fail; the last failure is
            chained as the cause.
        """
        last_error = None
        for _ in range(self._MAX_RETRIES + 1):
            try:
                return self._load_sample(index)
            except Exception as err:
                last_error = err
                index = np.random.randint(self.__len__())
        raise RuntimeError(
            'failed to load a valid sample after {} retries'.format(
                self._MAX_RETRIES)) from last_error

    def __len__(self):
        """Return the number of records produced by :meth:`load_data`."""
        return len(self.data_list)
|