Modify test tools and add some new tools (#322)
* Refactor the tools folder structure.
* Modify tools/test.py and add eval_metric.py to analyze test output.
* Add new tools `analyze_logs.py` and `print_config.py`.
* Add comments for the analysis_tools functions.
parent bee0ac6b56
commit aad796ae6f

@ -14,6 +14,6 @@ line_length = 79
 multi_line_output = 0
 known_standard_library = pkg_resources,setuptools
 known_first_party = mmcls
-known_third_party = PIL,cv2,matplotlib,mmcv,numpy,onnxruntime,pytest,torch,torchvision,ts
+known_third_party = PIL,cv2,matplotlib,mmcv,numpy,onnxruntime,pytest,seaborn,torch,torchvision,ts
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY

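The only change above registers seaborn, which the new analyze_logs.py imports, as a known third-party package for isort. As a purely illustrative sketch (not part of the commit), this is the import grouping the setting enforces in the new scripts: standard library first, then third-party including seaborn, then first-party mmcls code.

# standard library block
import argparse
import json

# third-party block (seaborn is now sorted here)
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# first-party block
from mmcls.datasets import build_dataset
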
@ -0,0 +1,182 @@
import argparse
import json
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def cal_train_time(log_dicts, args):
    """Compute the average time per training iteration."""
    for i, log_dict in enumerate(log_dicts):
        print(f'{"-" * 5}Analyze train time of {args.json_logs[i]}{"-" * 5}')
        all_times = []
        for epoch in log_dict.keys():
            if args.include_outliers:
                all_times.append(log_dict[epoch]['time'])
            else:
                all_times.append(log_dict[epoch]['time'][1:])
        all_times = np.array(all_times)
        epoch_ave_time = all_times.mean(-1)
        slowest_epoch = epoch_ave_time.argmax()
        fastest_epoch = epoch_ave_time.argmin()
        std_over_epoch = epoch_ave_time.std()
        print(f'slowest epoch {slowest_epoch + 1}, '
              f'average time is {epoch_ave_time[slowest_epoch]:.4f}')
        print(f'fastest epoch {fastest_epoch + 1}, '
              f'average time is {epoch_ave_time[fastest_epoch]:.4f}')
        print(f'time std over epochs is {std_over_epoch:.4f}')
        print(f'average iter time: {np.mean(all_times):.4f} s/iter')
        print()


def plot_curve(log_dicts, args):
    """Plot train metric-iter graph."""
    if args.backend is not None:
        plt.switch_backend(args.backend)
    sns.set_style(args.style)
    # if legend is None, use {filename}_{key} as legend
    legend = args.legend
    if legend is None:
        legend = []
        for json_log in args.json_logs:
            for metric in args.keys:
                legend.append(f'{json_log}_{metric}')
    assert len(legend) == (len(args.json_logs) * len(args.keys))
    metrics = args.keys

    num_metrics = len(metrics)
    for i, log_dict in enumerate(log_dicts):
        epochs = list(log_dict.keys())
        for j, metric in enumerate(metrics):
            print(f'plot curve of {args.json_logs[i]}, metric is {metric}')
            if metric not in log_dict[epochs[0]]:
                raise KeyError(
                    f'{args.json_logs[i]} does not contain metric {metric} '
                    f'in train mode')

            if 'mAP' in metric:
                xs = np.arange(1, max(epochs) + 1)
                ys = []
                for epoch in epochs:
                    ys += log_dict[epoch][metric]
                ax = plt.gca()
                ax.set_xticks(xs)
                plt.xlabel('epoch')
                plt.plot(xs, ys, label=legend[i * num_metrics + j], marker='o')
            else:
                xs = []
                ys = []
                num_iters_per_epoch = log_dict[epochs[0]]['iter'][-1]
                for epoch in epochs:
                    iters = log_dict[epoch]['iter']
                    if log_dict[epoch]['mode'][-1] == 'val':
                        iters = iters[:-1]
                    xs.append(
                        np.array(iters) + (epoch - 1) * num_iters_per_epoch)
                    ys.append(np.array(log_dict[epoch][metric][:len(iters)]))
                xs = np.concatenate(xs)
                ys = np.concatenate(ys)
                plt.xlabel('iter')
                plt.plot(
                    xs, ys, label=legend[i * num_metrics + j], linewidth=0.5)
            plt.legend()
        if args.title is not None:
            plt.title(args.title)
    if args.out is None:
        plt.show()
    else:
        print(f'save curve to: {args.out}')
        plt.savefig(args.out)
        plt.cla()


def add_plot_parser(subparsers):
    parser_plt = subparsers.add_parser(
        'plot_curve', help='parser for plotting curves')
    parser_plt.add_argument(
        'json_logs',
        type=str,
        nargs='+',
        help='path of train log in json format')
    parser_plt.add_argument(
        '--keys',
        type=str,
        nargs='+',
        default=['loss'],
        help='the metric that you want to plot')
    parser_plt.add_argument('--title', type=str, help='title of figure')
    parser_plt.add_argument(
        '--legend',
        type=str,
        nargs='+',
        default=None,
        help='legend of each plot')
    parser_plt.add_argument(
        '--backend', type=str, default=None, help='backend of plt')
    parser_plt.add_argument(
        '--style', type=str, default='dark', help='style of plt')
    parser_plt.add_argument('--out', type=str, default=None)


def add_time_parser(subparsers):
    parser_time = subparsers.add_parser(
        'cal_train_time',
        help='parser for computing the average time per training iteration')
    parser_time.add_argument(
        'json_logs',
        type=str,
        nargs='+',
        help='path of train log in json format')
    parser_time.add_argument(
        '--include-outliers',
        action='store_true',
        help='include the first value of every epoch when computing '
        'the average time')


def parse_args():
    parser = argparse.ArgumentParser(description='Analyze Json Log')
    # currently only supports plotting curves and computing average train time
    subparsers = parser.add_subparsers(dest='task', help='task parser')
    add_plot_parser(subparsers)
    add_time_parser(subparsers)
    args = parser.parse_args()
    return args


def load_json_logs(json_logs):
    # Load and convert json_logs to log_dicts. The key is the epoch and the
    # value is a sub dict whose keys are the different metrics, e.g. memory,
    # bbox_mAP, and whose values are lists of the corresponding values of
    # all iterations.
    log_dicts = [dict() for _ in json_logs]
    for json_log, log_dict in zip(json_logs, log_dicts):
        with open(json_log, 'r') as log_file:
            for line in log_file:
                log = json.loads(line.strip())
                # skip lines without `epoch` field
                if 'epoch' not in log:
                    continue
                epoch = log.pop('epoch')
                if epoch not in log_dict:
                    log_dict[epoch] = defaultdict(list)
                for k, v in log.items():
                    log_dict[epoch][k].append(v)
    return log_dicts


def main():
    args = parse_args()

    json_logs = args.json_logs
    for json_log in json_logs:
        assert json_log.endswith('.json')

    log_dicts = load_json_logs(json_logs)

    eval(args.task)(log_dicts, args)


if __name__ == '__main__':
    main()

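As a quick, purely illustrative check of how the log format feeds these helpers (not part of the commit), the snippet below writes a few fake JSON log lines and runs them through load_json_logs and cal_train_time. The metric values and the temporary file are invented, and the snippet assumes the functions above are in scope, e.g. pasted at the bottom of analyze_logs.py.

import json
import tempfile
from argparse import Namespace

# Fake log lines in the one-JSON-object-per-line format the script parses;
# all values are made up for illustration.
fake_lines = [
    {'env_info': 'dummy'},  # no `epoch` field, skipped by load_json_logs
    {'mode': 'train', 'epoch': 1, 'iter': 10, 'loss': 2.3, 'time': 0.50},
    {'mode': 'train', 'epoch': 1, 'iter': 20, 'loss': 2.1, 'time': 0.40},
    {'mode': 'train', 'epoch': 2, 'iter': 10, 'loss': 1.9, 'time': 0.45},
    {'mode': 'train', 'epoch': 2, 'iter': 20, 'loss': 1.8, 'time': 0.40},
]

with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as f:
    f.write('\n'.join(json.dumps(line) for line in fake_lines))
    log_path = f.name

log_dicts = load_json_logs([log_path])
print(sorted(log_dicts[0].keys()))  # [1, 2]
print(log_dicts[0][1]['loss'])      # [2.3, 2.1]

cal_train_time(log_dicts,
               Namespace(json_logs=[log_path], include_outliers=False))
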
@ -0,0 +1,71 @@
import argparse

import mmcv
from mmcv import Config, DictAction

from mmcls.datasets import build_dataset


def parse_args():
    parser = argparse.ArgumentParser(description='Evaluate metric of the '
                                     'results saved in pkl format')
    parser.add_argument('config', help='Config of the model')
    parser.add_argument('pkl_results', help='Results in pickle format')
    parser.add_argument(
        '--metrics',
        type=str,
        nargs='+',
        help='Evaluation metrics, which depends on the dataset, e.g., '
        '"accuracy", "precision", "recall" and "support".')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into the config file. If the value '
        'to be overwritten is a list, it should be like key="[a,b]" or '
        'key=a,b. It also allows nested list/tuple values, e.g. '
        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
        'and that no white space is allowed.')
    parser.add_argument(
        '--eval-options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for the dataset.evaluate() function')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    assert args.metrics, (
        'Please specify at least one metric with the "--metrics" argument.')

    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    cfg.data.test.test_mode = True

    dataset = build_dataset(cfg.data.test)
    outputs = mmcv.load(args.pkl_results)
    pred_score = outputs['class_scores']

    kwargs = {} if args.eval_options is None else args.eval_options
    eval_kwargs = cfg.get('evaluation', {}).copy()
    # hard-coded way to remove EvalHook args
    for key in [
            'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best', 'rule'
    ]:
        eval_kwargs.pop(key, None)
    eval_kwargs.update(dict(metric=args.metrics, **kwargs))
    print(dataset.evaluate(pred_score, **eval_kwargs))


if __name__ == '__main__':
    main()

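To see the structure this script expects, here is a purely illustrative sketch (not part of the commit) of a stand-in for the pickle file written by tools/test.py --out; the path, sample count and class count are all invented.

import mmcv
import numpy as np

# Fake results file: one row of class scores per test sample, stored under
# the 'class_scores' key that eval_metric.py reads.
fake_results = {'class_scores': np.random.rand(10, 5)}  # 10 samples, 5 classes
mmcv.dump(fake_results, 'fake_results.pkl')

# eval_metric.py loads the same structure back and passes it to
# dataset.evaluate() together with the requested metrics.
pred_score = mmcv.load('fake_results.pkl')['class_scores']
print(pred_score.shape)  # (10, 5)
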
@ -0,0 +1,54 @@
import argparse
import warnings

from mmcv import Config, DictAction


def parse_args():
    parser = argparse.ArgumentParser(description='Print the whole config')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into the config file (deprecated), '
        'change to --cfg-options instead.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into the config file. If the value '
        'to be overwritten is a list, it should be like key="[a,b]" or '
        'key=a,b. It also allows nested list/tuple values, e.g. '
        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
        'and that no white space is allowed.')
    args = parser.parse_args()

    if args.options and args.cfg_options:
        raise ValueError(
            '--options and --cfg-options cannot both be '
            'specified, --options is deprecated in favor of --cfg-options')
    if args.options:
        warnings.warn('--options is deprecated in favor of --cfg-options')
        args.cfg_options = args.options

    return args


def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    print(f'Config:\n{cfg.pretty_text}')


if __name__ == '__main__':
    main()

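For a quick feel of what the --cfg-options overrides do before they reach print_config.py, here is a purely illustrative sketch (not part of the commit) using an in-memory mmcv Config; the keys and values are invented.

from mmcv import Config

# A throwaway config; real ones come from the mmcls config files.
cfg = Config(dict(optimizer=dict(type='SGD', lr=0.1), total_epochs=100))

# Roughly what `--cfg-options optimizer.lr=0.01 total_epochs=50` becomes
# after DictAction parsing: a dict with dotted keys merged into the config.
cfg.merge_from_dict({'optimizer.lr': 0.01, 'total_epochs': 50})
print(f'Config:\n{cfg.pretty_text}')
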
@ -92,6 +92,9 @@ def main():
     cfg.model.pretrained = None
     cfg.data.test.test_mode = True
 
+    assert args.metrics or args.out, \
+        'Please specify at least one of output path and evaluation metrics.'
+
     # init distributed env first, since logger depends on the dist info.
     if args.launcher == 'none':
         distributed = False

@ -145,31 +148,26 @@ def main():
 
     rank, _ = get_dist_info()
     if rank == 0:
+        results = {}
         if args.metrics:
-            results = dataset.evaluate(outputs, args.metrics,
-                                       args.metric_options)
-            for k, v in results.items():
+            eval_results = dataset.evaluate(outputs, args.metrics,
+                                            args.metric_options)
+            results.update(eval_results)
+            for k, v in eval_results.items():
                 print(f'\n{k} : {v:.2f}')
         else:
+            warnings.warn('Evaluation metrics are not specified.')
+        if args.out:
             scores = np.vstack(outputs)
             pred_score = np.max(scores, axis=1)
             pred_label = np.argmax(scores, axis=1)
             pred_class = [CLASSES[lb] for lb in pred_label]
-            results = {
+            results.update({
+                'class_scores': scores,
                 'pred_score': pred_score,
                 'pred_label': pred_label,
                 'pred_class': pred_class
-            }
-            if not args.out:
-                print('\nthe predicted result for the first element is '
-                      f'pred_score = {pred_score[0]:.2f}, '
-                      f'pred_label = {pred_label[0]} '
-                      f'and pred_class = {pred_class[0]}. '
-                      'Specify --out to save all results to files.')
-    if args.out and rank == 0:
-        print(f'\nwriting results to {args.out}')
-        mmcv.dump(results, args.out)
+            })
+            print(f'\ndumping results to {args.out}')
+            mmcv.dump(results, args.out)
 
 
 if __name__ == '__main__':
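Not part of the diff, just a purely illustrative sketch of the dict that the updated test.py now dumps with --out and that eval_metric.py consumes; the class names, scores, metric key and output path are all invented.

import mmcv
import numpy as np

# Stand-in for the combined results dict built on rank 0 after this change.
CLASSES = ['cat', 'dog', 'bird']
scores = np.random.rand(4, len(CLASSES))
pred_label = np.argmax(scores, axis=1)
results = {
    'accuracy': 87.5,                        # example metric from evaluate()
    'class_scores': scores,                  # consumed later by eval_metric.py
    'pred_score': np.max(scores, axis=1),
    'pred_label': pred_label,
    'pred_class': [CLASSES[lb] for lb in pred_label],
}
mmcv.dump(results, 'results_example.pkl')    # hypothetical output path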