Source code for torchreid.utils.tools

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

__all__ = ['mkdir_if_missing', 'check_isfile', 'read_json', 'write_json',
           'set_random_seed', 'download_url', 'read_image']

import sys
import os
import os.path as osp
import time
import errno
import json
from collections import OrderedDict
import warnings
import random
import numpy as np
from PIL import Image

import torch


def mkdir_if_missing(dirname):
    """Create the directory ``dirname`` (including parents) unless it exists.

    If the directory appears between the existence check and the creation
    call (``errno.EEXIST``), that is treated as success; any other
    ``OSError`` raised by ``os.makedirs`` propagates to the caller.
    """
    if osp.exists(dirname):
        return
    try:
        os.makedirs(dirname)
    except OSError as err:
        # Someone else created it in the meantime: nothing left to do.
        if err.errno == errno.EEXIST:
            return
        raise
def check_isfile(fpath):
    """Report whether ``fpath`` refers to an existing regular file.

    A warning is emitted via ``warnings.warn`` when it does not.

    Args:
        fpath (str): file path.

    Returns:
        bool
    """
    found = osp.isfile(fpath)
    if found:
        return found
    warnings.warn('No file found at "{}"'.format(fpath))
    return found
def read_json(fpath):
    """Parse the JSON document stored at ``fpath`` and return it."""
    with open(fpath, 'r') as handle:
        return json.load(handle)
def write_json(obj, fpath):
    """Writes ``obj`` to a json file at ``fpath`` (indented by 4 spaces).

    The parent directory of ``fpath`` is created first if it is missing.
    """
    dirname = osp.dirname(fpath)
    # A bare filename such as 'out.json' has an empty dirname (the current
    # directory). os.makedirs('') raises, so only create a real directory.
    if dirname:
        mkdir_if_missing(dirname)
    with open(fpath, 'w') as f:
        json.dump(obj, f, indent=4, separators=(',', ': '))
def set_random_seed(seed):
    """Seed Python, NumPy and PyTorch random number generators.

    The same ``seed`` is passed, in order, to ``random.seed``,
    ``np.random.seed``, ``torch.manual_seed`` and
    ``torch.cuda.manual_seed_all``.

    Args:
        seed (int): seed value.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
def download_url(url, dst):
    """Downloads file from a url to a destination.

    While downloading, a single progress line (percent, MB downloaded,
    KB/s, elapsed seconds) is rewritten on stdout.

    Args:
        url (str): url to download file.
        dst (str): destination path.
    """
    from six.moves import urllib
    print('* url="{}"'.format(url))
    print('* destination="{}"'.format(dst))

    # Per-call state instead of a module-level ``global``; a dict (rather
    # than ``nonlocal``) keeps this compatible with Python 2.
    state = {'start_time': time.time()}

    def _reporthook(count, block_size, total_size):
        if count == 0:
            state['start_time'] = time.time()
            return
        duration = time.time() - state['start_time']
        progress_size = int(count * block_size)
        # Early callbacks can arrive within the clock's resolution, giving a
        # zero duration; avoid dividing by it.
        speed = int(progress_size / (1024 * duration)) if duration > 0 else 0
        # urlretrieve passes total_size = -1 when the size is unknown, and the
        # last block may overshoot the real size, so guard and clamp.
        if total_size > 0:
            percent = min(int(progress_size * 100 / total_size), 100)
        else:
            percent = 0
        sys.stdout.write('\r...%d%%, %d MB, %d KB/s, %d seconds passed' %
                         (percent, progress_size / (1024 * 1024), speed, duration))
        sys.stdout.flush()

    try:
        urllib.request.urlretrieve(url, dst, _reporthook)
    finally:
        # Terminate the progress line even if the download fails.
        sys.stdout.write('\n')
def read_image(path, max_retries=None):
    """Reads image from path using ``PIL.Image``.

    Whenever PIL raises ``IOError`` while opening/decoding, a message is
    printed and the read is attempted again.

    Args:
        path (str): path to an image.
        max_retries (int, optional): maximum number of retries after the
            first failed attempt. ``None`` (default) retries indefinitely,
            as before; pass an int to bound retries on unreadable files.

    Returns:
        PIL image converted to RGB.

    Raises:
        IOError: if ``path`` does not exist, or if reading still fails after
            ``max_retries`` retries.
    """
    if not osp.exists(path):
        raise IOError('"{}" does not exist'.format(path))
    retries = 0
    while True:
        try:
            # ``with`` closes the underlying file; convert() returns a new,
            # fully loaded image that does not depend on the open handle.
            with Image.open(path) as raw:
                return raw.convert('RGB')
        except IOError:
            if max_retries is not None and retries >= max_retries:
                raise
            retries += 1
            print('IOError incurred when reading "{}". Will redo. Don\'t worry. Just chill.'.format(path))