diff --git a/tools/analyze_logs.py b/tools/analyze_logs.py index e2127d4d6..abf998297 100644 --- a/tools/analyze_logs.py +++ b/tools/analyze_logs.py @@ -43,14 +43,13 @@ def plot_curve(log_dicts, args): plot_values.append(epoch_logs[metric][0]) else: for idx in range(len(epoch_logs[metric])): - if epoch_logs['mode'][idx] == 'train': - plot_iters.append(epoch_logs['iter'][idx]) - plot_values.append(epoch_logs[metric][idx]) + plot_iters.append(epoch_logs['step'][idx]) + plot_values.append(epoch_logs[metric][idx]) ax = plt.gca() label = legend[i * num_metrics + j] if metric in ['mIoU', 'mAcc', 'aAcc']: ax.set_xticks(plot_epochs) - plt.xlabel('epoch') + plt.xlabel('step') plt.plot(plot_epochs, plot_values, label=label, marker='o') else: plt.xlabel('iter') @@ -96,22 +95,25 @@ def parse_args(): def load_json_logs(json_logs): - # load and convert json_logs to log_dict, key is epoch, value is a sub dict + # load and convert json_logs to log_dict, key is step, value is a sub dict # keys of sub dict is different metrics # value of sub dict is a list of corresponding values of all iterations log_dicts = [dict() for _ in json_logs] + prev_step = 0 for json_log, log_dict in zip(json_logs, log_dicts): with open(json_log, 'r') as log_file: for line in log_file: log = json.loads(line.strip()) - # skip lines without `epoch` field - if 'epoch' not in log: - continue - epoch = log.pop('epoch') - if epoch not in log_dict: - log_dict[epoch] = defaultdict(list) + # the final step in json file is 0. + if 'step' in log and log['step'] != 0: + step = log['step'] + prev_step = step + else: + step = prev_step + if step not in log: + log_dict[step] = defaultdict(list) for k, v in log.items(): - log_dict[epoch][k].append(v) + log_dict[step][k].append(v) return log_dicts