91 lines
2.3 KiB
Python
91 lines
2.3 KiB
Python
|
from __future__ import absolute_import
|
||
|
from __future__ import print_function
|
||
|
from __future__ import division
|
||
|
|
||
|
# Public names exported by ``from <this module> import *``.
__all__ = [
    'mkdir_if_missing',
    'check_isfile',
    'read_json',
    'write_json',
    'set_random_seed',
    'download_url',
    'read_image',
]
|
||
|
|
||
|
import sys
|
||
|
import os
|
||
|
import os.path as osp
|
||
|
import time
|
||
|
import errno
|
||
|
import json
|
||
|
from collections import OrderedDict
|
||
|
import warnings
|
||
|
import random
|
||
|
import numpy as np
|
||
|
from PIL import Image
|
||
|
|
||
|
import torch
|
||
|
|
||
|
|
||
|
def mkdir_if_missing(directory):
    """Create ``directory`` (and any missing parents) if it does not exist.

    An empty string denotes the current working directory, which always
    exists, so it is treated as a no-op. This is what ``osp.dirname`` returns
    for a bare file name such as ``'out.json'``.

    Args:
        directory (str): path of the directory to create.

    Raises:
        OSError: if creation fails for any reason other than the directory
            already existing.
    """
    if not directory:
        # os.makedirs('') raises OSError(ENOENT); nothing needs creating.
        return
    if not osp.exists(directory):
        try:
            os.makedirs(directory)
        except OSError as e:
            # Another process may create the directory between the exists()
            # check and makedirs(); only that EEXIST race is tolerated.
            if e.errno != errno.EEXIST:
                raise
|
||
|
|
||
|
|
||
|
def check_isfile(path):
    """Report whether ``path`` is an existing regular file.

    Emits a ``UserWarning`` (via ``warnings.warn``) when it is not, instead
    of raising.

    Args:
        path (str): file path to test.

    Returns:
        bool: ``True`` if ``path`` is a file, ``False`` otherwise.
    """
    found = osp.isfile(path)
    if found:
        return found
    warnings.warn('No file found at "{}"'.format(path))
    return found
|
||
|
|
||
|
|
||
|
def read_json(fpath):
    """Parse the JSON file at ``fpath`` and return the decoded object.

    Args:
        fpath (str): path to a JSON file.

    Returns:
        The object produced by ``json.load``.
    """
    with open(fpath, 'r') as handle:
        return json.load(handle)
|
||
|
|
||
|
|
||
|
def write_json(obj, fpath):
    """Write ``obj`` to ``fpath`` as JSON indented by 4 spaces.

    The parent directory is created first when ``fpath`` has one. A bare
    file name (empty ``osp.dirname``) is written to the current directory.

    Args:
        obj: a JSON-serializable object.
        fpath (str): destination path; an existing file is overwritten.
    """
    parent_dir = osp.dirname(fpath)
    if parent_dir:
        # osp.dirname('out.json') == '' and creating '' would fail.
        mkdir_if_missing(parent_dir)
    with open(fpath, 'w') as f:
        # Explicit separators presumably avoid the trailing spaces that older
        # Pythons (< 3.4) emitted after commas when indent was set.
        json.dump(obj, f, indent=4, separators=(',', ': '))
|
||
|
|
||
|
|
||
|
def set_random_seed(seed):
    """Seed the Python, NumPy and torch global random number generators.

    The generators are seeded in this order: ``random``, ``np.random``, the
    torch CPU generator, and the torch generators of all CUDA devices.

    Args:
        seed (int): value passed unchanged to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
|
||
|
|
||
|
|
||
|
def download_url(url, dst):
    """Download ``url`` to the local file ``dst``, printing progress to stdout.

    Args:
        url (str): address of the resource to fetch.
        dst (str): destination file path, written by ``urlretrieve``.

    Raises:
        Whatever ``urlretrieve`` raises when the download fails
        (e.g. ``URLError`` / ``IOError``).
    """
    try:
        from six.moves import urllib
    except ImportError:
        # six is only needed on Python 2; Python 3 ships urllib.request.
        import urllib.request

    print('* url="{}"'.format(url))
    print('* destination="{}"'.format(dst))

    # Per-call hook state. A dict is used instead of a module-level `global`
    # (which leaked `start_time` into the module and was shared between
    # calls) and instead of `nonlocal` (unavailable on Python 2).
    state = {'start_time': time.time()}

    def _reporthook(count, block_size, total_size):
        if count == 0:
            state['start_time'] = time.time()
            return
        duration = time.time() - state['start_time']
        progress_size = int(count * block_size)
        if total_size > 0:
            # The last block is usually only partially filled.
            progress_size = min(progress_size, total_size)
        # The first block can arrive within the clock's resolution, so guard
        # against a zero duration.
        speed = int(progress_size / (1024 * max(duration, 1e-6)))
        if total_size > 0:
            percent = int(progress_size * 100 / total_size)
            sys.stdout.write('\r...%d%%, %d MB, %d KB/s, %d seconds passed' %
                             (percent, progress_size / (1024 * 1024), speed, duration))
        else:
            # urlretrieve reports total_size as -1 when the server sends no
            # Content-Length, so no percentage can be computed.
            sys.stdout.write('\r...%d MB, %d KB/s, %d seconds passed' %
                             (progress_size / (1024 * 1024), speed, duration))
        sys.stdout.flush()

    urllib.request.urlretrieve(url, dst, _reporthook)
    sys.stdout.write('\n')
|
||
|
|
||
|
|
||
|
def read_image(path, max_retries=10):
    """Read an image file and return it converted to RGB.

    Opening is retried on ``IOError``, presumably to ride out transient
    filesystem errors. Retries are capped at ``max_retries`` attempts so that
    a permanently unreadable file cannot hang the caller.

    Args:
        path (str): path to the image file.
        max_retries (int): maximum number of read attempts (at least one
            attempt is always made).

    Returns:
        PIL.Image.Image: the image in ``'RGB'`` mode.

    Raises:
        IOError: if ``path`` does not exist, or if the final attempt fails.
    """
    if not osp.exists(path):
        raise IOError('"{}" does not exist'.format(path))
    attempts = max(1, max_retries)
    for attempt in range(1, attempts + 1):
        try:
            # `with` closes the source file; convert() returns a new image.
            with Image.open(path) as img:
                return img.convert('RGB')
        except IOError:
            if attempt == attempts:
                raise
            print('IOError incurred when reading "{}". Will redo. Don\'t worry. Just chill.'.format(path))
|