PaddleOCR/ppocr/utils/network.py

154 lines
5.2 KiB
Python

# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import time
import shutil
import tarfile
import requests
import os.path as osp
import paddle.distributed as dist
from tqdm import tqdm
from ppocr.utils.logging import get_logger
# Directory (under the user's home) into which `maybe_download_params`
# stores files it downloads.
MODELS_DIR = os.path.expanduser("~/.paddleocr/models/")
# Upper bound on the number of attempts `_download` makes before raising.
DOWNLOAD_RETRY_LIMIT = 3
def download_with_progressbar(url, save_path):
    """
    Make sure ``save_path`` exists, downloading it from ``url`` if needed.

    When ``save_path`` is already present the call only logs a message and
    returns. Otherwise the process whose ``dist.get_rank()`` is 0 performs
    the download, and every other process polls once per second until the
    file shows up at ``save_path``.

    Args:
        url (str): download url
        save_path (str): destination path of the downloaded file
    """
    log = get_logger()

    if save_path and os.path.exists(save_path):
        log.info(f"Path {save_path} already exists. Skipping...")
        return

    # In multi-process / multi-machine runs only one process should fetch the
    # data; the remaining ones simply wait for the file to appear.
    if dist.get_rank() != 0:
        while not os.path.exists(save_path):
            time.sleep(1)
        return

    _download(url, save_path)
def _download(url, save_path):
    """
    Download from url, save to path.

    The response body is streamed into ``save_path + ".tmp"`` and moved to
    ``save_path`` only after it has been written completely, so an
    interrupted transfer never leaves a partial file at ``save_path``.
    Nothing is downloaded when ``save_path`` already exists.

    Args:
        url (str): download url
        save_path (str): download to given path

    Returns:
        str: ``save_path``.

    Raises:
        RuntimeError: if the server responds with a status code other than
            200, or if the request could not be issued within
            ``DOWNLOAD_RETRY_LIMIT`` attempts. Errors raised while streaming
            the body propagate unchanged, after the partial temporary file
            has been removed.
    """
    logger = get_logger()

    fname = osp.split(url)[-1]
    tmp_file = save_path + ".tmp"
    retry_cnt = 0

    while not osp.exists(save_path):
        if retry_cnt >= DOWNLOAD_RETRY_LIMIT:
            raise RuntimeError(
                "Download from {} failed. Retry limit reached".format(url)
            )
        retry_cnt += 1

        try:
            req = requests.get(url, stream=True)
        except Exception as e:  # typically requests.exceptions.ConnectionError
            # retry_cnt already counts the attempt that just failed.
            logger.info(
                "Downloading {} from {} failed {} times with exception {}".format(
                    fname, url, retry_cnt, str(e)
                )
            )
            time.sleep(1)
            continue

        try:
            if req.status_code != 200:
                raise RuntimeError(
                    "Downloading from {} failed with code "
                    "{}!".format(url, req.status_code)
                )

            # To protect against an interrupted download, write to tmp_file
            # first and move it to save_path once the download has finished.
            total_size = req.headers.get("content-length")
            try:
                with open(tmp_file, "wb") as f:
                    if total_size:
                        # Progress is counted in 1 KiB chunks, rounded up.
                        with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
                            for chunk in req.iter_content(chunk_size=1024):
                                f.write(chunk)
                                pbar.update(1)
                    else:
                        for chunk in req.iter_content(chunk_size=1024):
                            if chunk:
                                f.write(chunk)
                shutil.move(tmp_file, save_path)
            except BaseException:
                # Do not leave a half-written temporary file behind.
                if osp.exists(tmp_file):
                    os.remove(tmp_file)
                raise
        finally:
            # A streamed response keeps its connection open until closed.
            req.close()

    return save_path
def maybe_download(model_storage_directory, url):
    """
    Fetch and unpack an inference model into ``model_storage_directory``.

    Nothing happens when both ``inference.pdiparams`` and
    ``inference.pdmodel`` already exist in the directory. Otherwise the tar
    archive at ``url`` is downloaded into the directory, every member whose
    name ends with ``.pdiparams``, ``.pdiparams.info`` or ``.pdmodel`` is
    written out as ``inference<suffix>``, and the archive is deleted again.

    Args:
        model_storage_directory (str): directory holding the model files
        url (str): download url of a ``.tar`` archive

    Raises:
        AssertionError: if a download is required and ``url`` does not end
            with ``.tar``.
    """
    # using custom model
    tar_file_name_list = [".pdiparams", ".pdiparams.info", ".pdmodel"]
    params_file = os.path.join(model_storage_directory, "inference.pdiparams")
    model_file = os.path.join(model_storage_directory, "inference.pdmodel")
    if os.path.exists(params_file) and os.path.exists(model_file):
        return

    # Raised explicitly (same exception type as the former ``assert``) so the
    # check is not stripped when Python runs with -O.
    if not url.endswith(".tar"):
        raise AssertionError("Only supports tar compressed package")

    tmp_path = os.path.join(model_storage_directory, url.split("/")[-1])
    print("download {} to {}".format(url, tmp_path))
    os.makedirs(model_storage_directory, exist_ok=True)
    download_with_progressbar(url, tmp_path)
    try:
        with tarfile.open(tmp_path, "r") as tar_obj:
            for member in tar_obj.getmembers():
                suffix = next(
                    (s for s in tar_file_name_list if member.name.endswith(s)),
                    None,
                )
                if suffix is None:
                    continue
                src = tar_obj.extractfile(member)
                if src is None:
                    # extractfile() returns None for members that are not
                    # regular files (e.g. directories).
                    continue
                target = os.path.join(model_storage_directory, "inference" + suffix)
                with src, open(target, "wb") as dst:
                    # Stream the member instead of loading it into memory.
                    shutil.copyfileobj(src, dst)
    finally:
        # Always drop the archive: a corrupt file left in place would be
        # treated as "already downloaded" on the next call.
        os.remove(tmp_path)
def maybe_download_params(model_path):
    """
    Resolve ``model_path`` to a local path, downloading it when it is a URL.

    A path that exists locally, or a value that ``is_link`` does not treat
    as a URL, is returned unchanged. Otherwise the file is downloaded into
    ``MODELS_DIR`` under the last ``/``-separated component of the URL.

    Args:
        model_path (str): local path or download url

    Returns:
        str: ``model_path`` itself, or the path of the downloaded file.
    """
    if os.path.exists(model_path):
        return model_path
    if not is_link(model_path):
        return model_path

    remote = model_path
    local_target = os.path.join(MODELS_DIR, remote.split("/")[-1])
    print("download {} to {}".format(remote, local_target))
    os.makedirs(MODELS_DIR, exist_ok=True)
    download_with_progressbar(remote, local_target)
    return local_target
def is_link(s):
    """
    Tell whether ``s`` is an HTTP(S) URL.

    Only the ``http://`` and ``https://`` schemes count, so a local path
    that merely begins with the letters "http" (e.g. ``http_models/det``)
    is not mistaken for a download link.

    Args:
        s (str | None): candidate string

    Returns:
        bool: True if ``s`` starts with ``http://`` or ``https://``;
            False otherwise, including when ``s`` is None.
    """
    return s is not None and s.startswith(("http://", "https://"))
def confirm_model_dir_url(model_dir, default_model_dir, default_url):
    """
    Decide which model directory and download url to use.

    If ``model_dir`` is a value that ``is_link`` does not treat as a URL and
    is not None, it is returned as is together with ``default_url``.
    Otherwise the url is ``model_dir`` (when it is a link) or
    ``default_url``, and the directory becomes ``default_model_dir`` joined
    with the last ``/``-separated url component minus its final four
    characters (presumably the ``.tar`` extension).

    Args:
        model_dir (str | None): local directory, download url, or None
        default_model_dir (str): base directory for downloaded models
        default_url (str): url used when ``model_dir`` is not a link

    Returns:
        tuple: ``(model_dir, url)``.
    """
    # An explicit local directory wins; the default url is passed through.
    if model_dir is not None and not is_link(model_dir):
        return model_dir, default_url

    url = model_dir if is_link(model_dir) else default_url
    archive_name = url.split("/")[-1]
    resolved_dir = os.path.join(default_model_dir, archive_name[:-4])
    return resolved_dir, url