# Copyright (c) OpenMMLab. All rights reserved. import argparse import os import re from itertools import groupby import matplotlib.pyplot as plt import numpy as np from mmpretrain.utils import load_json_log def cal_train_time(log_dicts, args): """Compute the average time per training iteration.""" for i, log_dict in enumerate(log_dicts): print(f'{"-" * 5}Analyze train time of {args.json_logs[i]}{"-" * 5}') train_logs = log_dict['train'] if 'epoch' in train_logs[0]: epoch_ave_times = [] for _, logs in groupby(train_logs, lambda log: log['epoch']): if args.include_outliers: all_time = np.array([log['time'] for log in logs]) else: all_time = np.array([log['time'] for log in logs])[1:] epoch_ave_times.append(all_time.mean()) epoch_ave_times = np.array(epoch_ave_times) slowest_epoch = epoch_ave_times.argmax() fastest_epoch = epoch_ave_times.argmin() std_over_epoch = epoch_ave_times.std() print(f'slowest epoch {slowest_epoch + 1}, ' f'average time is {epoch_ave_times[slowest_epoch]:.4f}') print(f'fastest epoch {fastest_epoch + 1}, ' f'average time is {epoch_ave_times[fastest_epoch]:.4f}') print(f'time std over epochs is {std_over_epoch:.4f}') avg_iter_time = np.array([log['time'] for log in train_logs]).mean() print(f'average iter time: {avg_iter_time:.4f} s/iter') print() def get_legends(args): """if legend is None, use {filename}_{key} as legend.""" legend = args.legend if legend is None: legend = [] for json_log in args.json_logs: for metric in args.keys: # remove '.json' in the end of log names basename = os.path.basename(json_log)[:-5] if basename.endswith('.log'): basename = basename[:-4] legend.append(f'{basename}_{metric}') assert len(legend) == (len(args.json_logs) * len(args.keys)) return legend def plot_phase_train(metric, train_logs, curve_label): """plot phase of train curve.""" xs = np.array([log['step'] for log in train_logs]) ys = np.array([log[metric] for log in train_logs]) if 'epoch' in train_logs[0]: scale_factor = train_logs[-1]['step'] / train_logs[-1]['epoch'] xs = xs / scale_factor plt.xlabel('Epochs') else: plt.xlabel('Iters') plt.plot(xs, ys, label=curve_label, linewidth=0.75) def plot_phase_val(metric, val_logs, curve_label): """plot phase of val curve.""" xs = np.array([log['step'] for log in val_logs]) ys = np.array([log[metric] for log in val_logs]) plt.xlabel('Steps') plt.plot(xs, ys, label=curve_label, linewidth=0.75) def plot_curve_helper(log_dicts, metrics, args, legend): """plot curves from log_dicts by metrics.""" num_metrics = len(metrics) for i, log_dict in enumerate(log_dicts): for j, key in enumerate(metrics): json_log = args.json_logs[i] print(f'plot curve of {json_log}, metric is {key}') curve_label = legend[i * num_metrics + j] train_keys = {} if len(log_dict['train']) == 0 else set( log_dict['train'][0].keys()) - {'step', 'epoch'} val_keys = {} if len(log_dict['val']) == 0 else set( log_dict['val'][0].keys()) - {'step'} if key in val_keys: plot_phase_val(key, log_dict['val'], curve_label) elif key in train_keys: plot_phase_train(key, log_dict['train'], curve_label) else: raise ValueError( f'Invalid key "{key}", please choose from ' f'{set.union(set(train_keys), set(val_keys))}.') plt.legend() def plot_curve(log_dicts, args): """Plot train metric-iter graph.""" # set style try: import seaborn as sns sns.set_style(args.style) except ImportError: pass # set plot window size wind_w, wind_h = args.window_size.split('*') wind_w, wind_h = int(wind_w), int(wind_h) plt.figure(figsize=(wind_w, wind_h)) # get legends and metrics legends = get_legends(args) metrics = args.keys # plot curves from log_dicts by metrics plot_curve_helper(log_dicts, metrics, args, legends) # set title and show or save if args.title is not None: plt.title(args.title) if args.out is None: plt.show() else: print(f'save curve to: {args.out}') plt.savefig(args.out) plt.cla() def add_plot_parser(subparsers): parser_plt = subparsers.add_parser( 'plot_curve', help='parser for plotting curves') parser_plt.add_argument( 'json_logs', type=str, nargs='+', help='path of train log in json format') parser_plt.add_argument( '--keys', type=str, nargs='+', default=['loss'], help='the metric that you want to plot') parser_plt.add_argument('--title', type=str, help='title of figure') parser_plt.add_argument( '--legend', type=str, nargs='+', default=None, help='legend of each plot') parser_plt.add_argument( '--style', type=str, default='whitegrid', help='style of the figure, need `seaborn` package.') parser_plt.add_argument('--out', type=str, default=None) parser_plt.add_argument( '--window-size', default='12*7', help='size of the window to display images, in format of "$W*$H".') def add_time_parser(subparsers): parser_time = subparsers.add_parser( 'cal_train_time', help='parser for computing the average time per training iteration') parser_time.add_argument( 'json_logs', type=str, nargs='+', help='path of train log in json format') parser_time.add_argument( '--include-outliers', action='store_true', help='include the first value of every epoch when computing ' 'the average time') def parse_args(): parser = argparse.ArgumentParser(description='Analyze Json Log') # currently only support plot curve and calculate average train time subparsers = parser.add_subparsers(dest='task', help='task parser') add_plot_parser(subparsers) add_time_parser(subparsers) args = parser.parse_args() if hasattr(args, 'window_size') and args.window_size != '': assert re.match(r'\d+\*\d+', args.window_size), \ "'window-size' must be in format 'W*H'." return args def main(): args = parse_args() json_logs = args.json_logs for json_log in json_logs: assert json_log.endswith('.json') log_dicts = [load_json_log(json_log) for json_log in json_logs] if args.task == 'cal_train_time': cal_train_time(log_dicts, args) elif args.task == 'plot_curve': plot_curve(log_dicts, args) if __name__ == '__main__': main()