added utility to plot cmc curve

pull/589/head
Deepankar Sharma 2025-02-06 11:37:48 +05:30 committed by GitHub
parent 566a56a2cb
commit fca41b23d5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 34 additions and 1 deletions

View File

@ -11,10 +11,16 @@ from torch.utils.tensorboard import SummaryWriter
from torchreid import metrics
from torchreid.utils import (
MetricMeter, AverageMeter, re_ranking, open_all_layers, save_checkpoint,
open_specified_layers, visualize_ranked_results
open_specified_layers, visualize_ranked_results, tools
)
from torchreid.losses import DeepSupervision
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
class Engine(object):
r"""A generic base Engine class for both image- and video-reid.
@ -412,6 +418,9 @@ class Engine(object):
g_camids,
use_metric_cuhk03=use_metric_cuhk03
)
print('Plotting CMC ...')
tools.plot_cmc(cmc, max_rank=max(ranks), save_path=f"{save_dir}/cmc_curve.png")
print('** Results **')
print('mAP: {:.1%}'.format(mAP))

View File

@ -5,6 +5,7 @@ import json
import time
import errno
import numpy as np
import matplotlib.pyplot as plt
import random
import os.path as osp
import warnings
@ -141,3 +142,26 @@ def listdir_nohidden(path, sort=False):
if sort:
items.sort()
return items
def plot_cmc(cmc, max_rank=50, save_path="curve.png"):
"""Plots the CMC curve and saves it as an image.
Args:
cmc (numpy.ndarray): CMC values computed from the evaluation.
max_rank (int): Maximum rank to display.
save_path (str): Path to save the CMC curve image.
"""
ranks = np.arange(1, len(cmc) + 1)
plt.figure(figsize=(8, 6))
plt.plot(ranks[:max_rank], cmc[:max_rank], marker='o', linestyle='-', color='b', label="CMC Curve")
plt.xlabel("Rank")
plt.ylabel("Matching Rate")
plt.title("Cumulative Matching Characteristics (CMC) Curve")
plt.legend()
plt.grid()
# Save the plot
plt.savefig(save_path)
print(f"CMC curve saved to {save_path}")