deep-person-reid/torchreid/utils/tools.py

168 lines
4.2 KiB
Python
Raw Normal View History

2019-12-01 10:35:44 +08:00
from __future__ import division, print_function, absolute_import
2019-03-20 01:26:08 +08:00
import os
2019-12-01 10:35:44 +08:00
import sys
import json
2019-03-20 01:26:08 +08:00
import time
import errno
import numpy as np
2025-02-06 14:07:48 +08:00
import matplotlib.pyplot as plt
2019-12-01 10:35:44 +08:00
import random
import os.path as osp
import warnings
2019-05-24 22:34:27 +08:00
import PIL
2019-12-01 10:35:44 +08:00
import torch
2019-03-20 01:26:08 +08:00
from PIL import Image
2019-12-01 10:35:44 +08:00
# Public API of this module; ``from ... import *`` exports exactly these names.
__all__ = [
    'mkdir_if_missing', 'check_isfile', 'read_json', 'write_json',
    'set_random_seed', 'download_url', 'read_image', 'collect_env_info',
    'listdir_nohidden', 'plot_cmc'
]
2019-03-20 01:26:08 +08:00
def mkdir_if_missing(dirname):
    """Creates dirname if it is missing.

    Intermediate directories are created as needed. An empty ``dirname``
    (what ``osp.dirname`` returns for a bare filename, i.e. the current
    directory) is treated as already existing, so the call is a no-op.

    Args:
        dirname (str): directory path to create.
    """
    if not dirname:
        # os.makedirs('') raises FileNotFoundError, but the current
        # directory always exists, so there is nothing to create.
        return
    if not osp.exists(dirname):
        try:
            os.makedirs(dirname)
        except OSError as e:
            # Another process may create the directory between the exists()
            # check and makedirs(); only that case is tolerated.
            if e.errno != errno.EEXIST:
                raise
def check_isfile(fpath):
    """Returns whether ``fpath`` points to an existing regular file.

    A warning is emitted when no such file exists.

    Args:
        fpath (str): file path.

    Returns:
        bool
    """
    found = osp.isfile(fpath)
    if found:
        return found
    warnings.warn('No file found at "{}"'.format(fpath))
    return found
def read_json(fpath):
    """Reads json file from a path.

    Args:
        fpath (str): path to the json file.

    Returns:
        The deserialized Python object.
    """
    with open(fpath, 'r') as handle:
        return json.load(handle)
def write_json(obj, fpath):
    """Writes ``obj`` to a json file, creating parent directories if needed.

    The output is indented by 4 spaces.

    Args:
        obj: a json-serializable object.
        fpath (str): destination path.
    """
    parent = osp.dirname(fpath)
    # A bare filename has an empty dirname; creating '' would fail, and the
    # current directory already exists, so skip directory creation then.
    if parent:
        mkdir_if_missing(parent)
    with open(fpath, 'w') as f:
        json.dump(obj, f, indent=4, separators=(',', ': '))
def set_random_seed(seed):
    """Seeds the Python, NumPy and PyTorch (CPU and all CUDA devices)
    random number generators with the same value.

    Args:
        seed (int): seed value.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def download_url(url, dst):
    """Downloads file from a url to a destination.

    A single-line progress indicator (percentage, MB received, speed and
    elapsed time) is written to stdout while the transfer runs.

    Args:
        url (str): url to download file.
        dst (str): destination path.
    """
    import urllib.request

    print('* url="{}"'.format(url))
    print('* destination="{}"'.format(dst))

    # Per-call progress state lives in this closure rather than in a
    # module-level global, so separate downloads do not share it.
    state = {'start_time': None}

    def _reporthook(count, block_size, total_size):
        if count == 0:
            state['start_time'] = time.time()
            return
        duration = time.time() - state['start_time']
        progress_size = int(count * block_size)
        if total_size > 0:
            # The final block is usually partial, so count * block_size
            # overshoots the real size; clamp so the display tops out at 100%.
            progress_size = min(progress_size, total_size)
            percent = '%d%%' % int(progress_size * 100 / total_size)
        else:
            # urlretrieve may pass -1 when the total size is unknown, so no
            # meaningful percentage can be computed.
            percent = '?%'
        speed = 0 if duration == 0 else int(progress_size / (1024 * duration))
        sys.stdout.write(
            '\r...%s, %d MB, %d KB/s, %d seconds passed' %
            (percent, progress_size / (1024 * 1024), speed, duration)
        )
        sys.stdout.flush()

    urllib.request.urlretrieve(url, dst, _reporthook)
    sys.stdout.write('\n')
def read_image(path, max_retries=None):
    """Reads image from path using ``PIL.Image``.

    Reading is retried whenever ``IOError`` is raised, e.g. on a transient
    failure of a network file system. Pillow reports files it cannot decode
    with an ``OSError`` subclass as well, so with the default
    ``max_retries=None`` such a file makes this function retry forever;
    pass a finite value to get the error raised instead.

    Args:
        path (str): path to an image.
        max_retries (int or None): number of additional attempts allowed
            after the first failed read. ``None`` (default) retries
            indefinitely.

    Returns:
        PIL image in ``'RGB'`` mode.

    Raises:
        IOError: if ``path`` does not exist, or if reading still fails after
            ``max_retries`` retries.
    """
    if not osp.exists(path):
        raise IOError('"{}" does not exist'.format(path))
    retries = 0
    while True:
        try:
            # The context manager closes the underlying file handle;
            # convert() returns a converted copy that outlives it.
            with Image.open(path) as raw:
                return raw.convert('RGB')
        except IOError:
            if max_retries is not None and retries >= max_retries:
                raise
            retries += 1
            print(
                'IOError incurred when reading "{}". Will redo. Don\'t worry. Just chill.'
                .format(path)
            )
def collect_env_info():
    """Returns a human-readable summary of the runtime environment.

    The text is PyTorch's ``get_pretty_env_info()`` report followed by a
    line with the installed Pillow version.
    Code source: github.com/facebookresearch/maskrcnn-benchmark
    """
    from torch.utils.collect_env import get_pretty_env_info
    report = get_pretty_env_info()
    return report + '\n Pillow ({})'.format(PIL.__version__)
2021-04-28 16:49:35 +08:00
def listdir_nohidden(path, sort=False):
    """Lists the entries of a directory, skipping hidden ones.

    An entry counts as hidden when its name starts with ``'.'``.

    Args:
        path (str): directory path.
        sort (bool): if True, return the names in ascending order;
            otherwise keep the order produced by ``os.listdir``.

    Returns:
        list of str: entry names (not full paths).
    """
    visible = [name for name in os.listdir(path) if name[:1] != '.']
    return sorted(visible) if sort else visible
2025-02-06 14:07:48 +08:00
def plot_cmc(cmc, max_rank=50, save_path="curve.png"):
    """Plots the CMC curve and saves it as an image.

    Only the first ``max_rank`` ranks are drawn. Missing parent directories
    of ``save_path`` are created. The figure is closed after saving, so
    repeated calls do not accumulate open pyplot figures.

    Args:
        cmc (numpy.ndarray): CMC values computed from the evaluation;
            ``cmc[i]`` is plotted as the matching rate at rank ``i + 1``.
        max_rank (int): Maximum rank to display.
        save_path (str): Path to save the CMC curve image.
    """
    ranks = np.arange(1, len(cmc) + 1)
    save_dir = osp.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.plot(ranks[:max_rank], cmc[:max_rank], marker='o', linestyle='-',
                color='b', label="CMC Curve")
        ax.set_xlabel("Rank")
        ax.set_ylabel("Matching Rate")
        ax.set_title("Cumulative Matching Characteristics (CMC) Curve")
        ax.legend()
        # Explicit True: a bare grid() call toggles visibility, which would
        # hide the grid under a style that already enables it.
        ax.grid(True)
        fig.savefig(save_path)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)
    print(f"CMC curve saved to {save_path}")