Merge branch 'xiexinchen/fix_analyze_log_script' into 'refactor_dev'

[Refactor] Fix analyze log script

See merge request openmmlab-enterprise/openmmlab-ce/mmsegmentation!66
pull/1801/head
zhengmiao 2022-07-14 07:28:44 +00:00
commit 4717c01fc2
1 changed files with 14 additions and 12 deletions

View File

@ -43,14 +43,13 @@ def plot_curve(log_dicts, args):
plot_values.append(epoch_logs[metric][0]) plot_values.append(epoch_logs[metric][0])
else: else:
for idx in range(len(epoch_logs[metric])): for idx in range(len(epoch_logs[metric])):
if epoch_logs['mode'][idx] == 'train': plot_iters.append(epoch_logs['step'][idx])
plot_iters.append(epoch_logs['iter'][idx]) plot_values.append(epoch_logs[metric][idx])
plot_values.append(epoch_logs[metric][idx])
ax = plt.gca() ax = plt.gca()
label = legend[i * num_metrics + j] label = legend[i * num_metrics + j]
if metric in ['mIoU', 'mAcc', 'aAcc']: if metric in ['mIoU', 'mAcc', 'aAcc']:
ax.set_xticks(plot_epochs) ax.set_xticks(plot_epochs)
plt.xlabel('epoch') plt.xlabel('step')
plt.plot(plot_epochs, plot_values, label=label, marker='o') plt.plot(plot_epochs, plot_values, label=label, marker='o')
else: else:
plt.xlabel('iter') plt.xlabel('iter')
@ -96,22 +95,25 @@ def parse_args():
def load_json_logs(json_logs): def load_json_logs(json_logs):
# load and convert json_logs to log_dict, key is epoch, value is a sub dict # load and convert json_logs to log_dict, key is step, value is a sub dict
# keys of sub dict is different metrics # keys of sub dict is different metrics
# value of sub dict is a list of corresponding values of all iterations # value of sub dict is a list of corresponding values of all iterations
log_dicts = [dict() for _ in json_logs] log_dicts = [dict() for _ in json_logs]
prev_step = 0
for json_log, log_dict in zip(json_logs, log_dicts): for json_log, log_dict in zip(json_logs, log_dicts):
with open(json_log, 'r') as log_file: with open(json_log, 'r') as log_file:
for line in log_file: for line in log_file:
log = json.loads(line.strip()) log = json.loads(line.strip())
# skip lines without `epoch` field # the final step in json file is 0.
if 'epoch' not in log: if 'step' in log and log['step'] != 0:
continue step = log['step']
epoch = log.pop('epoch') prev_step = step
if epoch not in log_dict: else:
log_dict[epoch] = defaultdict(list) step = prev_step
if step not in log:
log_dict[step] = defaultdict(list)
for k, v in log.items(): for k, v in log.items():
log_dict[epoch][k].append(v) log_dict[step][k].append(v)
return log_dicts return log_dicts