Source code for torchreid.utils.tools

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

# Names exported by ``from torchreid.utils.tools import *``.
__all__ = ['mkdir_if_missing', 'check_isfile', 'read_json', 'write_json',
           'set_random_seed', 'download_url', 'read_image', 'collect_env_info']

import sys
import os
import os.path as osp
import time
import errno
import json
import warnings
import random
import numpy as np
import PIL
from PIL import Image

import torch


def mkdir_if_missing(dirname):
    """Creates ``dirname`` (and any missing parents) if it is missing.

    Args:
        dirname (str): directory path. An empty string denotes the current
            working directory and is treated as a no-op.

    Raises:
        OSError: if the directory cannot be created for any reason other
            than it already existing.
    """
    if dirname == '':
        # ``osp.dirname('file.json')`` is '' and ``os.makedirs('')`` raises,
        # so a bare file name would otherwise make callers fail.
        return
    if osp.exists(dirname):
        return
    try:
        os.makedirs(dirname)
    except OSError as e:
        # EEXIST means the path appeared after the exists() check above;
        # every other error is a genuine failure and is propagated.
        if e.errno != errno.EEXIST:
            raise
def check_isfile(fpath):
    """Tells whether ``fpath`` points to a regular file.

    A warning is issued (via ``warnings.warn``) when it does not.

    Args:
        fpath (str): file path.

    Returns:
        bool
    """
    found = osp.isfile(fpath)
    if found:
        return found
    warnings.warn('No file found at "{}"'.format(fpath))
    return found
def read_json(fpath):
    """Loads and returns the object stored in a json file.

    Args:
        fpath (str): path to the json file.

    Returns:
        The decoded json content.
    """
    with open(fpath, 'r') as handle:
        return json.load(handle)
def write_json(obj, fpath):
    """Serializes ``obj`` as json into the file at ``fpath``.

    The directory part of ``fpath`` is passed to ``mkdir_if_missing`` before
    the file is opened for writing. Output is indented by four spaces.

    Args:
        obj: json-serializable object.
        fpath (str): destination file path.
    """
    parent_dir = osp.dirname(fpath)
    mkdir_if_missing(parent_dir)
    with open(fpath, 'w') as out_file:
        json.dump(obj, out_file, indent=4, separators=(',', ': '))
def set_random_seed(seed):
    """Seeds the ``random``, ``numpy`` and ``torch`` random number generators.

    The same ``seed`` is passed, in order, to ``random.seed``,
    ``np.random.seed``, ``torch.manual_seed`` and
    ``torch.cuda.manual_seed_all``.

    Args:
        seed (int): seed value.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def download_url(url, dst):
    """Downloads file from a url to a destination.

    The url and destination are printed first, then a single-line progress
    report is rewritten on stdout after every block received.

    Args:
        url (str): url to download file.
        dst (str): destination path.
    """
    try:
        from six.moves import urllib
    except ImportError:
        # six is not installed; on Python 3 the stdlib offers the same
        # ``urllib.request.urlretrieve`` entry point.
        import urllib.request

    print('* url="{}"'.format(url))
    print('* destination="{}"'.format(dst))

    # Mutable holder keeps the start time local to this call (no module-level
    # global) while staying valid on interpreters without ``nonlocal``.
    timer = {'start': None}

    def _reporthook(count, block_size, total_size):
        if count == 0 or timer['start'] is None:
            timer['start'] = time.time()
            return
        duration = time.time() - timer['start']
        progress_size = int(count * block_size)
        if total_size > 0:
            # The last block is usually partial, so count * block_size can
            # overshoot the real size; clamp to keep the report <= 100%.
            progress_size = min(progress_size, total_size)
            prefix = '\r...%d%%, ' % int(progress_size * 100 / total_size)
        else:
            # Size unknown (urlretrieve reports -1): no percentage available.
            prefix = '\r...'
        # A coarse clock can yield a zero duration on very fast transfers.
        speed = int(progress_size / (1024 * duration)) if duration > 0 else 0
        sys.stdout.write(
            prefix + '%d MB, %d KB/s, %d seconds passed' %
            (progress_size / (1024 * 1024), speed, duration)
        )
        sys.stdout.flush()

    urllib.request.urlretrieve(url, dst, _reporthook)
    sys.stdout.write('\n')
def read_image(path):
    """Opens the image at ``path`` with ``PIL.Image`` and converts it to RGB.

    If opening or converting raises ``IOError`` a message is printed and the
    read is attempted again.

    NOTE(review): the retry has no attempt limit or delay, so a file that
    always fails to decode keeps this function looping -- confirm this is
    intended.

    Args:
        path (str): path to an image.

    Returns:
        PIL image

    Raises:
        IOError: if ``path`` does not exist.
    """
    if not osp.exists(path):
        raise IOError('"{}" does not exist'.format(path))

    while True:
        try:
            return Image.open(path).convert('RGB')
        except IOError:
            print(
                'IOError incurred when reading "{}". Will redo. Don\'t worry. Just chill.'
                .format(path)
            )
def collect_env_info():
    """Builds a string describing the runtime environment.

    The text comes from ``torch.utils.collect_env.get_pretty_env_info`` with
    one extra line appended that reports the Pillow version.

    Code source: github.com/facebookresearch/maskrcnn-benchmark

    Returns:
        str
    """
    from torch.utils.collect_env import get_pretty_env_info

    return get_pretty_env_info() + '\n Pillow ({})'.format(PIL.__version__)