diff --git a/tools/analyze_logs.py b/tools/analyze_logs.py index c3a468b55..fb017efaa 100644 --- a/tools/analyze_logs.py +++ b/tools/analyze_logs.py @@ -30,6 +30,9 @@ def plot_curve(log_dicts, args): plot_epochs = [] plot_iters = [] plot_values = [] + # In some log files, iters number is not correct, `pre_iter` is + # used to prevent generate wrong lines. + pre_iter = -1 for epoch in epochs: epoch_logs = log_dict[epoch] if metric not in epoch_logs.keys(): @@ -39,6 +42,9 @@ def plot_curve(log_dicts, args): plot_values.append(epoch_logs[metric][0]) else: for idx in range(len(epoch_logs[metric])): + if pre_iter > epoch_logs['iter'][idx]: + continue + pre_iter = epoch_logs['iter'][idx] plot_iters.append(epoch_logs['iter'][idx]) plot_values.append(epoch_logs[metric][idx]) ax = plt.gca()