Merge fca41b23d5
into 566a56a2cb
commit
cd525bf2f1
|
@ -11,10 +11,16 @@ from torch.utils.tensorboard import SummaryWriter
|
|||
from torchreid import metrics
|
||||
from torchreid.utils import (
|
||||
MetricMeter, AverageMeter, re_ranking, open_all_layers, save_checkpoint,
|
||||
open_specified_layers, visualize_ranked_results
|
||||
open_specified_layers, visualize_ranked_results, tools
|
||||
)
|
||||
from torchreid.losses import DeepSupervision
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Engine(object):
|
||||
r"""A generic base Engine class for both image- and video-reid.
|
||||
|
@ -412,6 +418,9 @@ class Engine(object):
|
|||
g_camids,
|
||||
use_metric_cuhk03=use_metric_cuhk03
|
||||
)
|
||||
|
||||
print('Plotting CMC ...')
|
||||
tools.plot_cmc(cmc, max_rank=max(ranks), save_path=f"{save_dir}/cmc_curve.png")
|
||||
|
||||
print('** Results **')
|
||||
print('mAP: {:.1%}'.format(mAP))
|
||||
|
|
|
@ -5,6 +5,7 @@ import json
|
|||
import time
|
||||
import errno
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import random
|
||||
import os.path as osp
|
||||
import warnings
|
||||
|
@ -141,3 +142,26 @@ def listdir_nohidden(path, sort=False):
|
|||
if sort:
|
||||
items.sort()
|
||||
return items
|
||||
|
||||
def plot_cmc(cmc, max_rank=50, save_path="curve.png"):
|
||||
"""Plots the CMC curve and saves it as an image.
|
||||
|
||||
Args:
|
||||
cmc (numpy.ndarray): CMC values computed from the evaluation.
|
||||
max_rank (int): Maximum rank to display.
|
||||
save_path (str): Path to save the CMC curve image.
|
||||
"""
|
||||
ranks = np.arange(1, len(cmc) + 1)
|
||||
|
||||
plt.figure(figsize=(8, 6))
|
||||
plt.plot(ranks[:max_rank], cmc[:max_rank], marker='o', linestyle='-', color='b', label="CMC Curve")
|
||||
plt.xlabel("Rank")
|
||||
plt.ylabel("Matching Rate")
|
||||
plt.title("Cumulative Matching Characteristics (CMC) Curve")
|
||||
plt.legend()
|
||||
plt.grid()
|
||||
|
||||
# Save the plot
|
||||
plt.savefig(save_path)
|
||||
print(f"CMC curve saved to {save_path}")
|
||||
|
||||
|
|
Loading…
Reference in New Issue