mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-03 21:53:39 +08:00
* Add custom detection and recognition model usage instructions in re * update * Add custom detection and recognition model usage instructions in re * add db net for benchmark * rename benckmark to PaddleOCR_benchmark * add addict to req * rename
107 lines
3.5 KiB
Python
107 lines
3.5 KiB
Python
# -*- coding: utf-8 -*-
|
||
# @Time : 2019/8/23 21:52
|
||
# @Author : zhoujun
|
||
import copy
|
||
|
||
import PIL
|
||
import numpy as np
|
||
import paddle
|
||
from paddle.io import DataLoader, DistributedBatchSampler, BatchSampler
|
||
|
||
from paddle.vision import transforms
|
||
|
||
|
||
def get_dataset(data_path, module_name, transform, dataset_args):
    """Instantiate a dataset class looked up on the sibling ``dataset`` module.

    :param data_path: list of annotation files, passed through unchanged to
        the dataset class (the author's original notes describe each file as
        lines of ``path/to/img`` and ``label`` separated by a tab; confirm
        against the dataset class).
    :param module_name: name of the class to fetch from the ``dataset``
        module (the original notes mention ``ImageDataset`` as the only one
        supported at the time).
    :param transform: transform pipeline handed to the dataset (may be None).
    :param dataset_args: extra keyword arguments forwarded to the class.
    :return: the constructed dataset instance.
    :raises AttributeError: if ``dataset`` has no attribute ``module_name``.
    """
    from . import dataset

    dataset_cls = getattr(dataset, module_name)
    return dataset_cls(transform=transform, data_path=data_path, **dataset_args)
|
||
|
||
|
||
def get_transforms(transforms_config):
    """Build a ``transforms.Compose`` pipeline from a list of config entries.

    :param transforms_config: iterable of entries, each with a ``'type'`` key
        naming an attribute of ``paddle.vision.transforms`` and an optional
        ``'args'`` mapping of constructor keyword arguments.
    :return: ``transforms.Compose`` over the instantiated transforms, in the
        same order as the config.
    """
    steps = []
    for entry in transforms_config:
        kwargs = entry['args'] if 'args' in entry else {}
        steps.append(getattr(transforms, entry['type'])(**kwargs))
    return transforms.Compose(steps)
|
||
|
||
|
||
class ICDARCollectFN:
    """Collate function that merges a list of sample dicts into one dict.

    Every key seen in any sample maps to a list of that key's values, in
    sample order. Keys for which at least one sample held a numpy array,
    ``paddle.Tensor`` or PIL image are then replaced by ``paddle.stack`` of
    their list along axis 0.
    """

    def __init__(self, *args, **kwargs):
        # No configuration is used; any arguments are accepted and ignored.
        pass

    def __call__(self, batch):
        """Collate ``batch`` (an iterable of dicts) into a dict of batches."""
        merged = {}
        # Keys whose values get stacked, kept in first-seen order.
        stack_keys = []
        for sample in batch:
            for key, value in sample.items():
                merged.setdefault(key, []).append(value)
                # NOTE(review): ``import PIL`` alone does not guarantee that the
                # ``PIL.Image`` submodule is loaded; this presumably works
                # because another import (e.g. paddle.vision) loads it.
                if (isinstance(value, (np.ndarray, paddle.Tensor, PIL.Image.Image))
                        and key not in stack_keys):
                    stack_keys.append(key)
        for key in stack_keys:
            # NOTE(review): paddle.stack presumably expects paddle.Tensor
            # inputs; numpy/PIL values may need converting first -- confirm.
            merged[key] = paddle.stack(merged[key], 0)
        return merged
|
||
|
||
|
||
def get_dataloader(module_config, distributed=False):
    """Build a ``paddle.io.DataLoader`` from a dataset/loader config dict.

    The config is deep-copied first, so the caller's dict is never mutated.

    :param module_config: dict with a ``'dataset'`` section (``'type'`` plus an
        ``'args'`` dict holding ``'data_path'`` and optionally
        ``'transforms'``) and a ``'loader'`` section (``'batch_size'``;
        optional ``'shuffle'``, ``'drop_last'``, ``'collate_fn'``; any other
        keys are forwarded to ``DataLoader``). May be None.
    :param distributed: use ``DistributedBatchSampler`` instead of
        ``BatchSampler``.
    :return: the DataLoader, or None when ``module_config`` is None or no
        non-None data path is configured.
    """
    if module_config is None:
        return None
    config = copy.deepcopy(module_config)
    dataset_args = config['dataset']['args']
    if 'transforms' in dataset_args:
        img_transforms = get_transforms(dataset_args.pop('transforms'))
    else:
        img_transforms = None

    # Build the dataset.
    dataset_name = config['dataset']['type']
    data_path = dataset_args.pop('data_path')
    if data_path is None:
        return None
    # A single path given as a plain string would otherwise be iterated
    # character by character by the filter below.
    if isinstance(data_path, str):
        data_path = [data_path]
    data_path = [x for x in data_path if x is not None]
    if not data_path:
        return None

    loader_args = config['loader']
    collate_name = loader_args.get('collate_fn')
    if collate_name:
        # NOTE(review): eval() executes arbitrary text from the config; only
        # load configs from trusted sources. The expression is resolved in
        # this module's namespace (e.g. 'ICDARCollectFN').
        loader_args['collate_fn'] = eval(collate_name)()
    else:
        loader_args['collate_fn'] = None

    _dataset = get_dataset(
        data_path=data_path,
        module_name=dataset_name,
        transform=img_transforms,
        dataset_args=dataset_args)

    # paddle.io.DataLoader does not accept batch_size/shuffle/drop_last
    # together with a batch_sampler, so the sampler consumes all three.
    sampler_cls = DistributedBatchSampler if distributed else BatchSampler
    batch_sampler = sampler_cls(
        dataset=_dataset,
        batch_size=loader_args.pop('batch_size'),
        shuffle=loader_args.pop('shuffle', False),
        drop_last=loader_args.pop('drop_last', False))
    return DataLoader(
        dataset=_dataset, batch_sampler=batch_sampler, **loader_args)
|