1
0
mirror of https://github.com/KaiyangZhou/deep-person-reid.git synced 2025-06-03 14:53:23 +08:00

Merge fca41b23d51dcffe212148fefdd8c60b5b4509f4 into 566a56a2cb255f59ba75aa817032621784df546a

This commit is contained in:
Deepankar Sharma 2025-02-06 11:40:04 +05:30 committed by GitHub
commit cd525bf2f1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 34 additions and 1 deletions
torchreid

@ -11,10 +11,16 @@ from torch.utils.tensorboard import SummaryWriter
from torchreid import metrics from torchreid import metrics
from torchreid.utils import ( from torchreid.utils import (
MetricMeter, AverageMeter, re_ranking, open_all_layers, save_checkpoint, MetricMeter, AverageMeter, re_ranking, open_all_layers, save_checkpoint,
open_specified_layers, visualize_ranked_results open_specified_layers, visualize_ranked_results, tools
) )
from torchreid.losses import DeepSupervision from torchreid.losses import DeepSupervision
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
class Engine(object): class Engine(object):
r"""A generic base Engine class for both image- and video-reid. r"""A generic base Engine class for both image- and video-reid.
@ -413,6 +419,9 @@ class Engine(object):
use_metric_cuhk03=use_metric_cuhk03 use_metric_cuhk03=use_metric_cuhk03
) )
print('Plotting CMC ...')
tools.plot_cmc(cmc, max_rank=max(ranks), save_path=f"{save_dir}/cmc_curve.png")
print('** Results **') print('** Results **')
print('mAP: {:.1%}'.format(mAP)) print('mAP: {:.1%}'.format(mAP))
print('CMC curve') print('CMC curve')

@ -5,6 +5,7 @@ import json
import time import time
import errno import errno
import numpy as np import numpy as np
import matplotlib.pyplot as plt
import random import random
import os.path as osp import os.path as osp
import warnings import warnings
@ -141,3 +142,26 @@ def listdir_nohidden(path, sort=False):
if sort: if sort:
items.sort() items.sort()
return items return items
def plot_cmc(cmc, max_rank=50, save_path="curve.png"):
"""Plots the CMC curve and saves it as an image.
Args:
cmc (numpy.ndarray): CMC values computed from the evaluation.
max_rank (int): Maximum rank to display.
save_path (str): Path to save the CMC curve image.
"""
ranks = np.arange(1, len(cmc) + 1)
plt.figure(figsize=(8, 6))
plt.plot(ranks[:max_rank], cmc[:max_rank], marker='o', linestyle='-', color='b', label="CMC Curve")
plt.xlabel("Rank")
plt.ylabel("Matching Rate")
plt.title("Cumulative Matching Characteristics (CMC) Curve")
plt.legend()
plt.grid()
# Save the plot
plt.savefig(save_path)
print(f"CMC curve saved to {save_path}")