PaddleOCR/benchmark/PaddleOCR_DBNet/base/base_dataset.py

87 lines
2.8 KiB
Python
Raw Normal View History

# -*- coding: utf-8 -*-
# @Time : 2019/12/4 13:12
# @Author : zhoujun
import copy
from paddle.io import Dataset
from data_loader.modules import *
class BaseDataSet(Dataset):
    """Base class for datasets that yield one image plus its text annotations.

    Subclasses implement :meth:`load_data`, which must return a non-empty list
    of dicts, each containing the keys ``img_path``, ``img_name``,
    ``text_polys``, ``texts`` and ``ignore_tags``. ``__getitem__`` reads the
    image with OpenCV, runs the configured pre-processing steps, applies the
    optional image ``transform`` and finally drops every key listed in
    ``filter_keys``.
    """

    # Upper bound on consecutive failed sample loads in ``__getitem__`` before
    # giving up, so a completely broken dataset fails with a clear error
    # instead of exhausting the interpreter stack.
    _MAX_LOAD_ATTEMPTS = 1000

    def __init__(
        self,
        data_path: str,
        img_mode,
        pre_processes,
        filter_keys,
        ignore_tags,
        transform=None,
        target_transform=None,
    ):
        """Load the annotation list and build the pre-processing pipeline.

        :param data_path: passed unchanged to :meth:`load_data`.
        :param img_mode: one of ``"RGB"``, ``"BGR"`` or ``"GRAY"``. The
            historical misspelling ``"BRG"`` is still accepted and behaves
            like ``"BGR"`` (the image is left in OpenCV's channel order).
        :param pre_processes: iterable of ``{"type": ..., "args": ...}``
            entries, or ``None`` for no pre-processing.
        :param filter_keys: keys removed from every returned sample; may be
            empty or ``None`` to return the sample unfiltered.
        :param ignore_tags: stored on the instance as ``self.ignore_tags``.
        :param transform: optional callable applied to ``data["img"]``.
        :param target_transform: stored as ``self.target_transform``; it is
            not used anywhere in this class.
        """
        assert img_mode in ["RGB", "BGR", "BRG", "GRAY"]
        self.ignore_tags = ignore_tags
        self.data_list = self.load_data(data_path)
        item_keys = ["img_path", "img_name", "text_polys", "texts", "ignore_tags"]
        for item in item_keys:
            assert (
                item in self.data_list[0]
            ), "data_list from load_data must contains {}".format(item_keys)
        self.img_mode = img_mode
        self.filter_keys = filter_keys
        self.transform = transform
        self.target_transform = target_transform
        self._init_pre_processes(pre_processes)

    def _init_pre_processes(self, pre_processes):
        """Instantiate every configured pre-processing step into ``self.aug``.

        Each entry's ``"type"`` is resolved to a callable and constructed with
        ``"args"``: a dict is expanded as keyword arguments, anything else is
        passed as a single positional argument. A missing ``"args"`` means no
        arguments.
        """
        self.aug = []
        if pre_processes is None:
            return
        for aug in pre_processes:
            args = aug["args"] if "args" in aug else {}
            # NOTE(security): eval() executes arbitrary expressions taken from
            # the configuration. Only use configuration files you trust.
            factory = eval(aug["type"])
            if isinstance(args, dict):
                cls = factory(**args)
            else:
                cls = factory(args)
            self.aug.append(cls)

    def load_data(self, data_path: str) -> list:
        """Load the annotations into a list; must be overridden by subclasses.

        :param data_path: folder or file that stores the data.
        :return: a list of dicts, each containing 'img_path', 'img_name',
            'text_polys', 'texts' and 'ignore_tags'.
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def apply_pre_processes(self, data):
        """Run ``data`` through every step in ``self.aug`` in order."""
        for aug in self.aug:
            data = aug(data)
        return data

    def _load_sample(self, record):
        """Build the sample dict for one annotation record.

        Works on a deep copy so the cached entry in ``self.data_list`` is not
        mutated by the pre-processing steps.
        """
        data = copy.deepcopy(record)
        im = cv2.imread(data["img_path"], 1 if self.img_mode != "GRAY" else 0)
        if im is None:
            # cv2.imread signals an unreadable or missing file by returning
            # None rather than raising.
            raise OSError("failed to read image: {}".format(data["img_path"]))
        if self.img_mode == "RGB":
            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        data["img"] = im
        data["shape"] = [im.shape[0], im.shape[1]]
        data = self.apply_pre_processes(data)
        if self.transform:
            data["img"] = self.transform(data["img"])
        data["text_polys"] = data["text_polys"].tolist()
        if not self.filter_keys:
            return data
        return {k: v for k, v in data.items() if k not in self.filter_keys}

    def __getitem__(self, index):
        """Return the sample at ``index``, or a random one if it fails to load.

        A sample that cannot be read or pre-processed is replaced by a randomly
        chosen sample (best effort, so one corrupt image does not abort
        training). After ``_MAX_LOAD_ATTEMPTS`` consecutive failures a
        ``RuntimeError`` chained to the last underlying error is raised.

        :raises IndexError: if ``index`` is out of range. This is deliberately
            not retried so that plain iteration over the dataset terminates.
        """
        last_error = None
        for _ in range(self._MAX_LOAD_ATTEMPTS):
            # Looked up outside the try block so an invalid index propagates.
            record = self.data_list[index]
            try:
                return self._load_sample(record)
            except Exception as err:
                last_error = err
                index = np.random.randint(self.__len__())
        raise RuntimeError(
            "could not load any sample after {} attempts".format(
                self._MAX_LOAD_ATTEMPTS
            )
        ) from last_error

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.data_list)