144 lines
3.4 KiB
Python
144 lines
3.4 KiB
Python
from __future__ import division, print_function, absolute_import
|
|
import os
|
|
import sys
|
|
import json
|
|
import time
|
|
import errno
|
|
import numpy as np
|
|
import random
|
|
import os.path as osp
|
|
import warnings
|
|
import PIL
|
|
import torch
|
|
from PIL import Image
|
|
|
|
# Public names exported by a star-import of this module.
__all__ = [
    'mkdir_if_missing',
    'check_isfile',
    'read_json',
    'write_json',
    'set_random_seed',
    'download_url',
    'read_image',
    'collect_env_info',
    'listdir_nohidden',
]
|
|
|
|
|
|
def mkdir_if_missing(dirname):
    """Creates dirname if it is missing.

    Args:
        dirname (str): directory path. An empty path (for example the
            result of ``osp.dirname('file.json')``) is ignored because
            there is nothing to create.
    """
    # os.makedirs('') raises ENOENT, so an empty path must be skipped here.
    if not dirname or osp.exists(dirname):
        return
    try:
        os.makedirs(dirname)
    except OSError as e:
        # The directory may have been created by someone else between the
        # existence check and makedirs; only that case is tolerated.
        if e.errno != errno.EEXIST:
            raise
|
|
|
|
|
|
def check_isfile(fpath):
    """Tells whether ``fpath`` is an existing file.

    A warning is issued when no file is found at ``fpath``.

    Args:
        fpath (str): file path.

    Returns:
        bool
    """
    if osp.isfile(fpath):
        return True
    warnings.warn('No file found at "{}"'.format(fpath))
    return False
|
|
|
|
|
|
def read_json(fpath):
    """Loads a json file and returns the decoded object."""
    with open(fpath, 'r') as json_file:
        return json.load(json_file)
|
|
|
|
|
|
def write_json(obj, fpath):
    """Writes to a json file.

    Args:
        obj: object passed to ``json.dump``.
        fpath (str): destination path. Its parent directory is created
            when missing.
    """
    parent = osp.dirname(fpath)
    # A bare file name has an empty dirname, which os.makedirs rejects.
    # Nothing needs creating then: the file is opened relative to the
    # current directory.
    if parent:
        mkdir_if_missing(parent)
    with open(fpath, 'w') as f:
        json.dump(obj, f, indent=4, separators=(',', ': '))
|
|
|
|
|
|
def set_random_seed(seed):
    """Passes the same seed to every random number generator in use.

    Calls, in order, ``random.seed``, ``np.random.seed``,
    ``torch.manual_seed`` and ``torch.cuda.manual_seed_all``.

    Args:
        seed (int): seed value.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
|
|
|
|
|
|
def download_url(url, dst):
    """Downloads file from a url to a destination.

    Progress (size, speed, elapsed time and, when the total size is known,
    percentage) is written to stdout while the download runs.

    Args:
        url (str): url to download file.
        dst (str): destination path.
    """
    from six.moves import urllib
    print('* url="{}"'.format(url))
    print('* destination="{}"'.format(dst))

    # Mutable holder so the hook can remember its start time without
    # writing to a module-level global (``nonlocal`` does not exist on
    # Python 2, which this file still supports via ``six``).
    timer = {'start': None}

    def _reporthook(count, block_size, total_size):
        # The first call only starts the clock.
        if count == 0 or timer['start'] is None:
            timer['start'] = time.time()
            return
        duration = time.time() - timer['start']
        progress_size = int(count * block_size)
        # The clock may not have advanced yet; avoid dividing by zero.
        if duration > 0:
            speed = int(progress_size / (1024*duration))
        else:
            speed = 0
        size_mb = progress_size / (1024*1024)
        if total_size > 0:
            # count * block_size can overshoot on the last, partial block.
            percent = min(100, int(progress_size * 100 / total_size))
            sys.stdout.write(
                '\r...%d%%, %d MB, %d KB/s, %d seconds passed' %
                (percent, size_mb, speed, duration)
            )
        else:
            # Total size unknown (urlretrieve passes -1 in that case), so
            # a percentage would be meaningless.
            sys.stdout.write(
                '\r...%d MB, %d KB/s, %d seconds passed' %
                (size_mb, speed, duration)
            )
        sys.stdout.flush()

    urllib.request.urlretrieve(url, dst, _reporthook)
    sys.stdout.write('\n')
|
|
|
|
|
|
def read_image(path):
    """Reads image from path using ``PIL.Image``.

    The image is opened and converted to RGB. Whenever that raises
    ``IOError`` a message is printed and the read is attempted again,
    with no limit on the number of attempts.

    Args:
        path (str): path to an image.

    Returns:
        PIL image

    Raises:
        IOError: if ``path`` does not exist.
    """
    if not osp.exists(path):
        raise IOError('"{}" does not exist'.format(path))
    while True:
        try:
            return Image.open(path).convert('RGB')
        except IOError:
            print(
                'IOError incurred when reading "{}". Will redo. Don\'t worry. Just chill.'
                .format(path)
            )
|
|
|
|
|
|
def collect_env_info():
    """Returns env info as a string.

    The string is the output of torch's ``get_pretty_env_info`` followed
    by a line with the Pillow version.

    Code source: github.com/facebookresearch/maskrcnn-benchmark
    """
    from torch.utils.collect_env import get_pretty_env_info
    return get_pretty_env_info() + '\n Pillow ({})'.format(PIL.__version__)
|
|
|
|
|
|
def listdir_nohidden(path, sort=False):
    """List non-hidden items in a directory.

    Entries whose name starts with ``'.'`` are left out.

    Args:
        path (str): directory path.
        sort (bool): sort the items.

    Returns:
        list
    """
    visible = []
    for entry in os.listdir(path):
        if entry.startswith('.'):
            continue
        visible.append(entry)
    return sorted(visible) if sort else visible
|