added utility to plot cmc curve
parent
566a56a2cb
commit
fca41b23d5
|
@ -11,10 +11,16 @@ from torch.utils.tensorboard import SummaryWriter
|
||||||
from torchreid import metrics
|
from torchreid import metrics
|
||||||
from torchreid.utils import (
|
from torchreid.utils import (
|
||||||
MetricMeter, AverageMeter, re_ranking, open_all_layers, save_checkpoint,
|
MetricMeter, AverageMeter, re_ranking, open_all_layers, save_checkpoint,
|
||||||
open_specified_layers, visualize_ranked_results
|
open_specified_layers, visualize_ranked_results, tools
|
||||||
)
|
)
|
||||||
from torchreid.losses import DeepSupervision
|
from torchreid.losses import DeepSupervision
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class Engine(object):
|
class Engine(object):
|
||||||
r"""A generic base Engine class for both image- and video-reid.
|
r"""A generic base Engine class for both image- and video-reid.
|
||||||
|
@ -413,6 +419,9 @@ class Engine(object):
|
||||||
use_metric_cuhk03=use_metric_cuhk03
|
use_metric_cuhk03=use_metric_cuhk03
|
||||||
)
|
)
|
||||||
|
|
||||||
|
print('Plotting CMC ...')
|
||||||
|
tools.plot_cmc(cmc, max_rank=max(ranks), save_path=f"{save_dir}/cmc_curve.png")
|
||||||
|
|
||||||
print('** Results **')
|
print('** Results **')
|
||||||
print('mAP: {:.1%}'.format(mAP))
|
print('mAP: {:.1%}'.format(mAP))
|
||||||
print('CMC curve')
|
print('CMC curve')
|
||||||
|
|
|
@ -5,6 +5,7 @@ import json
|
||||||
import time
|
import time
|
||||||
import errno
|
import errno
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
import random
|
import random
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import warnings
|
import warnings
|
||||||
|
@ -141,3 +142,26 @@ def listdir_nohidden(path, sort=False):
|
||||||
if sort:
|
if sort:
|
||||||
items.sort()
|
items.sort()
|
||||||
return items
|
return items
|
||||||
|
|
||||||
|
def plot_cmc(cmc, max_rank=50, save_path="curve.png"):
|
||||||
|
"""Plots the CMC curve and saves it as an image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cmc (numpy.ndarray): CMC values computed from the evaluation.
|
||||||
|
max_rank (int): Maximum rank to display.
|
||||||
|
save_path (str): Path to save the CMC curve image.
|
||||||
|
"""
|
||||||
|
ranks = np.arange(1, len(cmc) + 1)
|
||||||
|
|
||||||
|
plt.figure(figsize=(8, 6))
|
||||||
|
plt.plot(ranks[:max_rank], cmc[:max_rank], marker='o', linestyle='-', color='b', label="CMC Curve")
|
||||||
|
plt.xlabel("Rank")
|
||||||
|
plt.ylabel("Matching Rate")
|
||||||
|
plt.title("Cumulative Matching Characteristics (CMC) Curve")
|
||||||
|
plt.legend()
|
||||||
|
plt.grid()
|
||||||
|
|
||||||
|
# Save the plot
|
||||||
|
plt.savefig(save_path)
|
||||||
|
print(f"CMC curve saved to {save_path}")
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue