diff --git a/tools/analysis_tools/analyze_logs.py b/tools/analysis_tools/analyze_logs.py index 64b91c23c..fb6de805a 100644 --- a/tools/analysis_tools/analyze_logs.py +++ b/tools/analysis_tools/analyze_logs.py @@ -57,11 +57,9 @@ def plot_curve(log_dicts, args): f'{args.json_logs[i]} does not contain metric {metric} ' f'in train mode') - if 'mAP' in metric: - xs = np.arange(1, max(epochs) + 1) - ys = [] - for epoch in epochs: - ys += log_dict[epoch][metric] + if any(m in metric for m in ('mAP', 'accuracy')): + xs = epochs + ys = [log_dict[e][metric] for e in xs] ax = plt.gca() ax.set_xticks(xs) plt.xlabel('epoch') @@ -74,6 +72,9 @@ def plot_curve(log_dicts, args): iters = log_dict[epoch]['iter'] if log_dict[epoch]['mode'][-1] == 'val': iters = iters[:-1] + assert len(iters) > 0, ( + 'The training log is empty, please try to reduce the ' + 'interval of log in config file.') xs.append( np.array(iters) + (epoch - 1) * num_iters_per_epoch) ys.append(np.array(log_dict[epoch][metric][:len(iters)]))