# EasyCV/easycv/datasets/shared/odps_reader.py
# Copyright (c) Alibaba, Inc. and its affiliates.
import base64
import os
import time
from io import BytesIO
from random import randint

import requests
from mmcv.runner import get_dist_info
from PIL import Image, ImageFile

from easycv.datasets.registry import DATASOURCES
from easycv.file import io

ImageFile.LOAD_TRUNCATED_IMAGES = True

data_cache = {}
SUPPORT_IMAGE_TYPE = ['url', 'base64']
DATALOADER_WORKID = -1
DATALOADER_WORKNUM = 1


def set_dataloader_workid(value):
    global DATALOADER_WORKID
    DATALOADER_WORKID = value


def set_dataloader_worknum(value):
    global DATALOADER_WORKNUM
    DATALOADER_WORKNUM = value
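

# A minimal sketch (not part of the original module) of how the two setters
# above are meant to be wired into a PyTorch DataLoader: worker_init_fn runs
# inside each worker process, so the per-worker id/count become visible to
# OdpsReader.__getitem__. The function name is hypothetical.
def odps_worker_init_fn(worker_id):
    import torch.utils.data
    worker_info = torch.utils.data.get_worker_info()
    set_dataloader_worknum(worker_info.num_workers)
    set_dataloader_workid(worker_id)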


def get_dist_image(img_url, max_try=10):
    """Read an image from an http(s) or oss url, retrying up to max_try times."""
    img = None
    try_idx = 0
    while try_idx < max_try:
        try:
            # http url
            if img_url.startswith('http'):
                img = Image.open(requests.get(img_url,
                                              stream=True).raw).convert('RGB')
            # oss url
            else:
                img = Image.open(io.open(img_url, 'rb')).convert('RGB')
        except Exception:
            print('Failed to read file %s, retrying' % img_url)
            time.sleep(1)
            img = None
        try_idx += 1
        if img is not None:
            break
    return img


@DATASOURCES.register_module
class OdpsReader(object):

    def __init__(self,
                 table_name,
                 selected_cols=[],
                 excluded_cols=[],
                 random_start=False,
                 odps_io_config=None,
                 image_col=['url_image'],
                 image_type=['url']):
        """Init odps reader and set up the datasource to load data from an odps table.

        Args:
            table_name (str): odps table to load
            selected_cols (list(str)): columns to select
            excluded_cols (list(str)): columns to exclude
            random_start (bool): start reading from a random row of the odps table
            odps_io_config (dict): odps config containing access_id, access_key, end_point
            image_col (list(str)): image column names
            image_type (list(str)): image column types (supports url/base64),
                must be the same length as image_col

        Returns:
            None
        """
        assert (odps_io_config
                is not None), 'odps_io_config should be set for OdpsReader!'

        # set odps config
        if odps_io_config is not None:
            assert 'access_id' in odps_io_config.keys(
            ), 'odps_io_config should contain access_id'
            assert 'access_key' in odps_io_config.keys(
            ), 'odps_io_config should contain access_key'
            assert 'end_point' in odps_io_config.keys(
            ), 'odps_io_config should contain end_point'

            # In a distributed env (especially on PAI-Studio) several
            # processes may race to write the config file, so retry a few
            # times and ignore write failures.
            if not os.path.exists('.odps_io_config'):
                write_idx = 0
                while not os.path.exists(
                        '.odps_io_config') and write_idx < 10:
                    write_idx += 1
                    try:
                        with open('.odps_io_config', 'w') as f:
                            f.write('access_id=%s\n' %
                                    (odps_io_config['access_id']))
                            f.write('access_key=%s\n' %
                                    (odps_io_config['access_key']))
                            f.write('end_point=%s\n' %
                                    (odps_io_config['end_point']))
                    except Exception:
                        pass

            # point the odps runtime at the config file written above
            os.environ['ODPS_CONFIG_FILE_PATH'] = '.odps_io_config'
        # set up distributed read
        rank, world_size = get_dist_info()

        # There are two levels of multiprocessing for this dataset: first the
        # multi-gpu (DDP) workers, second the dataloader worker processes
        # spawned per gpu.
        self.dataloader_init = False

        # keep input args
        assert (
            type(image_type) == list and type(image_col) == list
        ), 'image_col, image_type for OdpsReader must be set as list of (column name), list of (image type)'
        assert (len(image_type) == len(image_col))

        self.selected_cols = selected_cols
        self.excluded_cols = excluded_cols
        self.rank = rank
        self.ddp_world_size = world_size
        self.table_name = table_name
        self.random_start = random_start

        # init the reader; common_io is only available in PAI environments
        import common_io
        self.reader = common_io.table.TableReader(
            self.table_name,
            slice_id=self.rank,
            slice_count=self.ddp_world_size,
            selected_cols=','.join(self.selected_cols),
            excluded_cols=','.join(self.excluded_cols),
        )
        self.length = self.reader.get_row_count()
        self.world_size = self.ddp_world_size

        if self.random_start:
            # randint is inclusive on both ends, so stay within [0, length - 1]
            self.idx = randint(0, self.length - 1)
            self.reader.seek(self.idx)
        else:
            self.idx = 0

        # init for finding image columns
        self.schema = self.reader.get_schema()
        self.schema_name = [i[0] for i in self.schema]

        # locate url / base64 image columns in the odps schema
        self.base64_image_idx = []
        self.url_image_idx = []
        for idx, s in enumerate(self.schema):
            if s[0] in image_col:
                assert (
                    s[1] == 'string'
                ), 'ODPS image column must be string type, %s is %s!' % (
                    s[0], s[1])
                idx_type = image_type[image_col.index(s[0])]
                assert (
                    idx_type in SUPPORT_IMAGE_TYPE
                ), 'image_type must be one of the supported types: url / base64'
                if idx_type == 'url':
                    self.url_image_idx.append(idx)
                if idx_type == 'base64':
                    self.base64_image_idx.append(idx)

        # the open reader cannot be pickled by the dataloader workers, so
        # drop it here; __getitem__ re-creates it inside each worker process
        delattr(self, 'reader')

    def __len__(self):
        return self.length * self.world_size

    def reset_reader(self, dataloader_workid, dataloader_worknum):
        """Re-slice the table so that each (gpu rank, dataloader worker) pair
        reads a distinct shard, e.g. rank 1 / worker 2 with 4 workers per gpu
        reads slice 1 * 4 + 2 = 6 of 8."""
        import common_io
        self.reader = common_io.table.TableReader(
            self.table_name,
            slice_id=self.rank * dataloader_worknum + dataloader_workid,
            slice_count=self.ddp_world_size * dataloader_worknum,
            selected_cols=','.join(self.selected_cols),
            excluded_cols=','.join(self.excluded_cols),
        )
        self.length = self.reader.get_row_count()
        self.world_size = self.ddp_world_size * dataloader_worknum
        if self.random_start:
            self.idx = randint(0, self.length - 1)
            self.reader.seek(self.idx)
        else:
            self.idx = 0

    def __getitem__(self, idx):
        global DATALOADER_WORKID
        global DATALOADER_WORKNUM

        # the reader is deleted before dataloader workers fork (see __init__),
        # so re-create it lazily here inside each worker process
        if not hasattr(self, 'reader'):
            import common_io
            self.reader = common_io.table.TableReader(
                self.table_name,
                slice_id=self.rank,
                slice_count=self.ddp_world_size,
                selected_cols=','.join(self.selected_cols),
                excluded_cols=','.join(self.excluded_cols),
            )

        if not self.dataloader_init:
            # DATALOADER_WORKNUM == 1 means we should not split the reader
            if DATALOADER_WORKNUM == 1:
                self.dataloader_init = True
            elif DATALOADER_WORKNUM < 1:
                print('num_per_gpu for OdpsReader should be >= 1')
            else:
                assert (
                    DATALOADER_WORKID > -1
                ), 'num_per_gpu for OdpsReader > 1, but DATALOADER_WORKID was not set by the dataloader worker_init_fn'
                self.reset_reader(DATALOADER_WORKID, DATALOADER_WORKNUM)
                self.dataloader_init = True

        # read rows sequentially, wrapping around at the end of this slice
        self.idx += 1
        t = self.reader.read()[0]
        if self.idx == self.length:
            self.reader.seek(-self.length)
            self.idx = 0

        # oss io must be configured beforehand for oss urls
        return_dict = {}
        for idx, m in enumerate(t):
            if idx in self.base64_image_idx:
                # assume the base64 column stores an encoded image file
                # (e.g. jpeg/png) and decode it through PIL
                return_dict[self.schema_name[idx]] = Image.open(
                    BytesIO(self.b64_decode(m)))
            elif idx in self.url_image_idx:
                return_dict[self.schema_name[idx]] = get_dist_image(m, 5)
            else:
                return_dict[self.schema_name[idx]] = m
        return return_dict

    def b64_decode(self, string):
        return base64.decodebytes(string.encode())
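

# A minimal usage sketch (not from the original file): the table name,
# endpoint and credentials below are placeholders, and `common_io` is only
# available inside Alibaba PAI environments.
#
#     data_source = OdpsReader(
#         table_name='odps://my_project/tables/my_table',
#         odps_io_config=dict(
#             access_id='xxx',
#             access_key='xxx',
#             end_point='http://service.odps.aliyun.com/api'),
#         image_col=['url_image'],
#         image_type=['url'])
#     sample = data_source[0]  # dict keyed by ODPS column names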