diff --git a/demo/resources/log_analysis_demo.png b/demo/resources/log_analysis_demo.png
new file mode 100644
index 00000000..fe795f33
Binary files /dev/null and b/demo/resources/log_analysis_demo.png differ
diff --git a/docs/en/tools.md b/docs/en/tools.md
index f42cef24..5402c6d1 100644
--- a/docs/en/tools.md
+++ b/docs/en/tools.md
@@ -30,3 +30,68 @@ For example,
 ```bash
 python tools/data/utils/txt2lmdb.py -i data/mixture/Syn90k/label.txt -o data/mixture/Syn90k/label.lmdb
 ```
+
+
+## Log Analysis
+
+You can use `tools/analyze_logs.py` to plot loss and hmean curves from a training log file. Run `pip install seaborn` first to install the dependency.
+
+![](../../demo/resources/log_analysis_demo.png)
+
+```shell
+python tools/analyze_logs.py plot_curve ${JSON_LOGS} [--keys ${KEYS}] [--title ${TITLE}] [--legend ${LEGEND}] [--backend ${BACKEND}] [--style ${STYLE}] [--out ${OUT_FILE}]
+```
+
+| Arguments   | Type | Description                                                                                                      |
+| ----------- | ---- | ---------------------------------------------------------------------------------------------------------------- |
+| `--keys`    | str  | The metrics you want to plot. Defaults to `loss`.                                                                |
+| `--title`   | str  | Title of the figure.                                                                                             |
+| `--legend`  | str  | Legend of each plot.                                                                                             |
+| `--backend` | str  | Backend of the plot. [more info](https://matplotlib.org/stable/users/explain/backends.html)                      |
+| `--style`   | str  | Style of the plot. Defaults to `dark`. [more info](https://seaborn.pydata.org/generated/seaborn.set_style.html)  |
+| `--out`     | str  | Path of the output figure.                                                                                       |
+
+**Examples:**
+
+Download the following DBNet and CRNN training logs to run the demos below.
+```shell
+wget https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.log.json -O DBNet_log.json
+
+wget https://download.openmmlab.com/mmocr/textrecog/crnn/20210326_111035.log.json -O CRNN_log.json
+```
+
+Specify an output path with `--out` if you are running the code on a system without a GUI.
+
+- Plot the loss metric.
+
+  ```shell
+  python tools/analyze_logs.py plot_curve DBNet_log.json --keys loss --legend loss
+  ```
+
+- Plot the hmean-iou:hmean metric of text detection.
+
+  ```shell
+  python tools/analyze_logs.py plot_curve DBNet_log.json --keys hmean-iou:hmean --legend hmean-iou:hmean
+  ```
+
+- Plot the 0_1-N.E.D metric of text recognition.
+
+  ```shell
+  python tools/analyze_logs.py plot_curve CRNN_log.json --keys 0_1-N.E.D --legend 0_1-N.E.D
+  ```
+
+- Compute the average training speed.
+
+  ```shell
+  python tools/analyze_logs.py cal_train_time CRNN_log.json --include-outliers
+  ```
+
+  The output is expected to look like the following.
+
+  ```text
+  -----Analyze train time of CRNN_log.json-----
+  slowest epoch 4, average time is 0.3464
+  fastest epoch 5, average time is 0.2365
+  time std over epochs is 0.0356
+  average iter time: 0.2906 s/iter
+  ```
diff --git a/tools/analyze_logs.py b/tools/analyze_logs.py
new file mode 100644
index 00000000..c0714c50
--- /dev/null
+++ b/tools/analyze_logs.py
@@ -0,0 +1,190 @@
+# Copyright (c) OpenMMLab. All rights reserved.
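+#
+# The input is expected to be mmcv-style JSON training logs, one JSON
+# object per line; for example (illustrative values, not from a real run):
+#   {"mode": "train", "epoch": 1, "iter": 50, "lr": 0.007, "loss": 1.25, "time": 0.29}
+#   {"mode": "val", "epoch": 1, "iter": 1250, "0_word_acc": 0.83, "0_1-N.E.D": 0.91}
+# Lines without an `epoch` field (such as the leading metadata line) are
+# skipped by `load_json_logs` below.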
+"""Modified from https://github.com/open- +mmlab/mmdetection/blob/master/tools/analysis_tools/analyze_logs.py.""" +import argparse +import json +from collections import defaultdict + +import matplotlib.pyplot as plt +import numpy as np +import seaborn as sns + + +def cal_train_time(log_dicts, args): + for i, log_dict in enumerate(log_dicts): + print(f'{"-" * 5}Analyze train time of {args.json_logs[i]}{"-" * 5}') + all_times = [] + for epoch in log_dict.keys(): + if args.include_outliers: + all_times.append(log_dict[epoch]['time']) + else: + all_times.append(log_dict[epoch]['time'][1:]) + all_times = np.array(all_times) + epoch_ave_time = all_times.mean(-1) + slowest_epoch = epoch_ave_time.argmax() + fastest_epoch = epoch_ave_time.argmin() + std_over_epoch = epoch_ave_time.std() + print(f'slowest epoch {slowest_epoch + 1}, ' + f'average time is {epoch_ave_time[slowest_epoch]:.4f}') + print(f'fastest epoch {fastest_epoch + 1}, ' + f'average time is {epoch_ave_time[fastest_epoch]:.4f}') + print(f'time std over epochs is {std_over_epoch:.4f}') + print(f'average iter time: {np.mean(all_times):.4f} s/iter') + print() + + +def plot_curve(log_dicts, args): + if args.backend is not None: + plt.switch_backend(args.backend) + sns.set_style(args.style) + # if legend is None, use {filename}_{key} as legend + legend = args.legend + if legend is None: + legend = [] + for json_log in args.json_logs: + for metric in args.keys: + legend.append(f'{json_log}_{metric}') + assert len(legend) == (len(args.json_logs) * len(args.keys)) + metrics = args.keys + + num_metrics = len(metrics) + for i, log_dict in enumerate(log_dicts): + epochs = list(log_dict.keys()) + for j, metric in enumerate(metrics): + print(f'Plot curve of {args.json_logs[i]}, metric is {metric}') + + epoch_based_metrics = [ + 'hmean', 'word_acc', 'word_acc_ignore_case', + 'word_acc_ignore_case_symbol', 'char_recall', 'char_precision', + '1-N.E.D', 'macro_f1' + ] + if any(metric in m for m in epoch_based_metrics): + # determine whether it is a epoch-plotted metric + # e.g. 
+                xs = []
+                ys = []
+                for epoch in epochs:
+                    ys += log_dict[epoch][metric]
+                    if 'val' in log_dict[epoch]['mode']:
+                        xs.append(epoch)
+                plt.xlabel('epoch')
+                plt.plot(xs, ys, label=legend[i * num_metrics + j], marker='o')
+            else:
+                xs = []
+                ys = []
+                if log_dict[epochs[0]]['mode'][-1] == 'val':
+                    num_iters_per_epoch = log_dict[epochs[0]]['iter'][-2]
+                else:
+                    num_iters_per_epoch = log_dict[epochs[0]]['iter'][-1]
+                for epoch in epochs:
+                    iters = log_dict[epoch]['iter']
+                    if log_dict[epoch]['mode'][-1] == 'val':
+                        iters = iters[:-1]
+                    xs.append(
+                        np.array(iters) + (epoch - 1) * num_iters_per_epoch)
+                    ys.append(np.array(log_dict[epoch][metric][:len(iters)]))
+                xs = np.concatenate(xs)
+                ys = np.concatenate(ys)
+                plt.xlabel('iter')
+                plt.plot(
+                    xs, ys, label=legend[i * num_metrics + j], linewidth=0.5)
+            plt.grid()
+            plt.legend()
+        if args.title is not None:
+            plt.title(args.title)
+    if args.out is None:
+        plt.show()
+    else:
+        print(f'Save curve to: {args.out}')
+        plt.savefig(args.out)
+        plt.cla()
+
+
+def add_plot_parser(subparsers):
+    parser_plt = subparsers.add_parser(
+        'plot_curve', help='Parser for plotting curves')
+    parser_plt.add_argument(
+        'json_logs',
+        type=str,
+        nargs='+',
+        help='Path of train log in json format')
+    parser_plt.add_argument(
+        '--keys',
+        type=str,
+        nargs='+',
+        default=['loss'],
+        help='The metrics that you want to plot')
+    parser_plt.add_argument('--title', type=str, help='Title of figure')
+    parser_plt.add_argument(
+        '--legend',
+        type=str,
+        nargs='+',
+        default=None,
+        help='Legend of each plot')
+    parser_plt.add_argument(
+        '--backend', type=str, default=None, help='Backend of plt')
+    parser_plt.add_argument(
+        '--style', type=str, default='dark', help='Style of plt')
+    parser_plt.add_argument(
+        '--out', type=str, default=None, help='Path of output figure')
+
+
+def add_time_parser(subparsers):
+    parser_time = subparsers.add_parser(
+        'cal_train_time',
+        help='Parser for computing the average time per training iteration')
+    parser_time.add_argument(
+        'json_logs',
+        type=str,
+        nargs='+',
+        help='Path of train log in json format')
+    parser_time.add_argument(
+        '--include-outliers',
+        action='store_true',
+        help='Include the first value of every epoch when computing '
+        'the average time')
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Analyze Json Log')
+    # currently only support plot curve and calculate average train time
+    subparsers = parser.add_subparsers(dest='task', help='task parser')
+    add_plot_parser(subparsers)
+    add_time_parser(subparsers)
+    args = parser.parse_args()
+    return args
+
+
+def load_json_logs(json_logs):
+    # load and convert json_logs to log_dicts, where keys are epochs and
+    # values are sub dicts; the keys of a sub dict are different metrics,
+    # e.g. memory, loss, and its values are lists of the metric values of
+    # all iterations in that epoch
+    log_dicts = [dict() for _ in json_logs]
+    for json_log, log_dict in zip(json_logs, log_dicts):
+        with open(json_log, 'r') as log_file:
+            for line in log_file:
+                log = json.loads(line.strip())
+                # skip lines without the `epoch` field
+                if 'epoch' not in log:
+                    continue
+                epoch = log.pop('epoch')
+                if epoch not in log_dict:
+                    log_dict[epoch] = defaultdict(list)
+                for k, v in log.items():
+                    log_dict[epoch][k].append(v)
+    return log_dicts
+
+
+def main():
+    args = parse_args()
+
+    json_logs = args.json_logs
+    for json_log in json_logs:
+        assert json_log.endswith('.json'), \
+            f'{json_log} is not a json log file'
+
+    log_dicts = load_json_logs(json_logs)
+
+    # dispatch explicitly instead of calling eval() on user input
+    tasks = {'plot_curve': plot_curve, 'cal_train_time': cal_train_time}
+    tasks[args.task](log_dicts, args)
+
+
+if __name__ == '__main__':
+    main()
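+
+# Usage sketch (log file names follow the examples in docs/en/tools.md);
+# several logs can be compared on a single figure:
+#   python tools/analyze_logs.py plot_curve DBNet_log.json CRNN_log.json \
+#       --keys loss --out loss_compare.png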