added utility to plot cmc curve

pull/589/head
Deepankar Sharma 2025-02-06 11:37:48 +05:30 committed by GitHub
parent 566a56a2cb
commit fca41b23d5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 34 additions and 1 deletions

View File

@ -11,10 +11,16 @@ from torch.utils.tensorboard import SummaryWriter
from torchreid import metrics from torchreid import metrics
from torchreid.utils import ( from torchreid.utils import (
MetricMeter, AverageMeter, re_ranking, open_all_layers, save_checkpoint, MetricMeter, AverageMeter, re_ranking, open_all_layers, save_checkpoint,
open_specified_layers, visualize_ranked_results open_specified_layers, visualize_ranked_results, tools
) )
from torchreid.losses import DeepSupervision from torchreid.losses import DeepSupervision
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
class Engine(object): class Engine(object):
r"""A generic base Engine class for both image- and video-reid. r"""A generic base Engine class for both image- and video-reid.
@ -413,6 +419,9 @@ class Engine(object):
use_metric_cuhk03=use_metric_cuhk03 use_metric_cuhk03=use_metric_cuhk03
) )
print('Plotting CMC ...')
tools.plot_cmc(cmc, max_rank=max(ranks), save_path=f"{save_dir}/cmc_curve.png")
print('** Results **') print('** Results **')
print('mAP: {:.1%}'.format(mAP)) print('mAP: {:.1%}'.format(mAP))
print('CMC curve') print('CMC curve')

View File

@ -5,6 +5,7 @@ import json
import time import time
import errno import errno
import numpy as np import numpy as np
import matplotlib.pyplot as plt
import random import random
import os.path as osp import os.path as osp
import warnings import warnings
@ -141,3 +142,26 @@ def listdir_nohidden(path, sort=False):
if sort: if sort:
items.sort() items.sort()
return items return items
def plot_cmc(cmc, max_rank=50, save_path="curve.png"):
"""Plots the CMC curve and saves it as an image.
Args:
cmc (numpy.ndarray): CMC values computed from the evaluation.
max_rank (int): Maximum rank to display.
save_path (str): Path to save the CMC curve image.
"""
ranks = np.arange(1, len(cmc) + 1)
plt.figure(figsize=(8, 6))
plt.plot(ranks[:max_rank], cmc[:max_rank], marker='o', linestyle='-', color='b', label="CMC Curve")
plt.xlabel("Rank")
plt.ylabel("Matching Rate")
plt.title("Cumulative Matching Characteristics (CMC) Curve")
plt.legend()
plt.grid()
# Save the plot
plt.savefig(save_path)
print(f"CMC curve saved to {save_path}")