diff --git a/torchreid/engine/engine.py b/torchreid/engine/engine.py index bbc01e0..85872fa 100644 --- a/torchreid/engine/engine.py +++ b/torchreid/engine/engine.py @@ -11,10 +11,16 @@ from torch.utils.tensorboard import SummaryWriter from torchreid import metrics from torchreid.utils import ( MetricMeter, AverageMeter, re_ranking, open_all_layers, save_checkpoint, - open_specified_layers, visualize_ranked_results + open_specified_layers, visualize_ranked_results, tools ) from torchreid.losses import DeepSupervision +import matplotlib.pyplot as plt +import numpy as np + +import matplotlib.pyplot as plt +import numpy as np + class Engine(object): r"""A generic base Engine class for both image- and video-reid. @@ -412,6 +418,9 @@ class Engine(object): g_camids, use_metric_cuhk03=use_metric_cuhk03 ) + + print('Plotting CMC ...') + tools.plot_cmc(cmc, max_rank=max(ranks), save_path=f"{save_dir}/cmc_curve.png") print('** Results **') print('mAP: {:.1%}'.format(mAP)) diff --git a/torchreid/utils/tools.py b/torchreid/utils/tools.py index 2ccf63a..a596146 100644 --- a/torchreid/utils/tools.py +++ b/torchreid/utils/tools.py @@ -5,6 +5,7 @@ import json import time import errno import numpy as np +import matplotlib.pyplot as plt import random import os.path as osp import warnings @@ -141,3 +142,26 @@ def listdir_nohidden(path, sort=False): if sort: items.sort() return items + +def plot_cmc(cmc, max_rank=50, save_path="curve.png"): + """Plots the CMC curve and saves it as an image. + + Args: + cmc (numpy.ndarray): CMC values computed from the evaluation. + max_rank (int): Maximum rank to display. + save_path (str): Path to save the CMC curve image. + """ + ranks = np.arange(1, len(cmc) + 1) + + plt.figure(figsize=(8, 6)) + plt.plot(ranks[:max_rank], cmc[:max_rank], marker='o', linestyle='-', color='b', label="CMC Curve") + plt.xlabel("Rank") + plt.ylabel("Matching Rate") + plt.title("Cumulative Matching Characteristics (CMC) Curve") + plt.legend() + plt.grid() + + # Save the plot + plt.savefig(save_path) + print(f"CMC curve saved to {save_path}") +