deep-person-reid/torchreid/utils/tools.py

130 lines
3.1 KiB
Python
Raw Normal View History

2019-12-01 10:35:44 +08:00
from __future__ import division, print_function, absolute_import
2019-03-20 01:26:08 +08:00
import os
2019-12-01 10:35:44 +08:00
import sys
import json
2019-03-20 01:26:08 +08:00
import time
import errno
import numpy as np
2019-12-01 10:35:44 +08:00
import random
import os.path as osp
import warnings
2019-05-24 22:34:27 +08:00
import PIL
2019-12-01 10:35:44 +08:00
import torch
2019-03-20 01:26:08 +08:00
from PIL import Image
2019-12-01 10:35:44 +08:00
# Names exported by ``from <this module> import *``; every entry is a
# function defined further down in this file.
__all__ = [
'mkdir_if_missing', 'check_isfile', 'read_json', 'write_json',
'set_random_seed', 'download_url', 'read_image', 'collect_env_info'
]
2019-03-20 01:26:08 +08:00
def mkdir_if_missing(dirname):
    """Creates ``dirname`` (including missing parents) if it is missing.

    Args:
        dirname (str): directory path. An empty string is treated as
            "no directory to create" and the call does nothing.

    Raises:
        OSError: if the directory cannot be created for a reason other
            than it already existing.
    """
    # ``osp.dirname('file.json')`` is '' and ``os.makedirs('')`` raises, so
    # an empty path is handled as a no-op instead of an error.
    if not dirname or osp.exists(dirname):
        return
    try:
        os.makedirs(dirname)
    except OSError as e:
        # The directory may have been created by someone else between the
        # existence check and ``makedirs``; only that case is ignored.
        if e.errno != errno.EEXIST:
            raise
def check_isfile(fpath):
    """Tells whether ``fpath`` is an existing file.

    A warning is issued when no file is found at the given path.

    Args:
        fpath (str): file path.

    Returns:
        bool
    """
    if osp.isfile(fpath):
        return True
    warnings.warn('No file found at "{}"'.format(fpath))
    return False
def read_json(fpath):
    """Loads the json file at ``fpath`` and returns the decoded object."""
    with open(fpath, 'r') as json_file:
        return json.load(json_file)
def write_json(obj, fpath):
    """Writes ``obj`` to a json file, creating the parent directory if needed.

    Args:
        obj: json-serializable object.
        fpath (str): destination file path.
    """
    parent = osp.dirname(fpath)
    # A bare file name has an empty dirname; there is no directory to create
    # in that case (and creating '' would raise).
    if parent:
        mkdir_if_missing(parent)
    with open(fpath, 'w') as f:
        json.dump(obj, f, indent=4, separators=(',', ': '))
def set_random_seed(seed):
    """Seeds the ``random``, ``numpy.random`` and ``torch`` generators.

    The same value is passed, in order, to ``random.seed``,
    ``np.random.seed``, ``torch.manual_seed`` and
    ``torch.cuda.manual_seed_all``.

    Args:
        seed: seed value handed to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def download_url(url, dst):
    """Downloads file from a url to a destination.

    While the transfer runs, a single progress line (percentage, size,
    speed and elapsed time) is rewritten on stdout.

    Args:
        url (str): url to download file.
        dst (str): destination path.
    """
    # Standard-library ``urlretrieve`` under either module layout, so no
    # third-party compatibility shim is needed.
    try:
        from urllib.request import urlretrieve
    except ImportError:
        from urllib import urlretrieve

    print('* url="{}"'.format(url))
    print('* destination="{}"'.format(dst))

    # The start time is kept in a closure-local dict rather than a module
    # global, so nothing leaks out of (or stale into) this call.
    state = {'start_time': None}

    def _reporthook(count, block_size, total_size):
        if count == 0 or state['start_time'] is None:
            state['start_time'] = time.time()
            return
        duration = time.time() - state['start_time']
        progress_size = int(count * block_size)
        # The clock may not have advanced yet; avoid dividing by zero.
        speed = int(progress_size / (1024*duration)) if duration > 0 else 0
        if total_size > 0:
            # The last block usually overshoots the real size; cap at 100.
            percent = min(int(progress_size * 100 / total_size), 100)
        else:
            # Total size unknown (reported as -1) or zero.
            percent = 0
        sys.stdout.write(
            '\r...%d%%, %d MB, %d KB/s, %d seconds passed' %
            (percent, progress_size / (1024*1024), speed, duration)
        )
        sys.stdout.flush()

    urlretrieve(url, dst, _reporthook)
    sys.stdout.write('\n')
def read_image(path, max_retries=None):
    """Reads image from path using ``PIL.Image``.

    A failed read is retried. By default the retries are unlimited, so a
    file that can never be decoded keeps the call looping; pass
    ``max_retries`` to bound it.

    Args:
        path (str): path to an image.
        max_retries (int, optional): number of retries allowed after a
            failed read before the error is re-raised. Default is None
            (retry without limit).

    Returns:
        PIL image

    Raises:
        IOError: if ``path`` does not exist, or if reading still fails
            once ``max_retries`` retries have been used.
    """
    if not osp.exists(path):
        raise IOError('"{}" does not exist'.format(path))

    failures = 0
    while True:
        try:
            # The context manager closes the underlying file once the
            # converted copy has been produced.
            with Image.open(path) as raw:
                return raw.convert('RGB')
        except IOError:
            if max_retries is not None and failures >= max_retries:
                raise
            failures += 1
            print(
                'IOError incurred when reading "{}". Will redo. Don\'t worry. Just chill.'
                .format(path)
            )
def collect_env_info():
    """Builds a string describing the environment.

    The text comes from ``torch.utils.collect_env.get_pretty_env_info``
    with the Pillow version appended on a new line.

    Code source: github.com/facebookresearch/maskrcnn-benchmark
    """
    from torch.utils.collect_env import get_pretty_env_info

    pillow_line = '\n Pillow ({})'.format(PIL.__version__)
    return get_pretty_env_info() + pillow_line