[Refactor] Unify the `--out` and `--dump` in `tools/test.py`. (#1307)
parent 58cefa5c0f
commit 7ec6062415
@@ -141,8 +141,8 @@ CUDA_VISIBLE_DEVICES=-1 python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [
 | `CONFIG_FILE` | The path to the config file. |
 | `CHECKPOINT_FILE` | The path to the checkpoint file (It can be a http link, and you can find checkpoints [here](https://mmclassification.readthedocs.io/en/1.x/modelzoo_statistics.html)). |
 | `--work-dir WORK_DIR` | The directory to save the file containing evaluation metrics. |
-| `--out OUT` | The path to save the file containing evaluation metrics. |
-| `--dump DUMP` | The path to dump all outputs of the model for offline evaluation. |
+| `--out OUT` | The path to save the file containing test results. |
+| `--out-item OUT_ITEM` | To specify the content of the test results file, and it can be "pred" or "metrics". If "pred", save the outputs of the model for offline evaluation. If "metrics", save the evaluation metrics. Defaults to "pred". |
 | `--cfg-options CFG_OPTIONS` | Override some settings in the used config, the key-value pair in xxx=yyy format will be merged into the config file. If the value to be overwritten is a list, it should be of the form of either `key="[a,b]"` or `key=a,b`. The argument also allows nested list/tuple values, e.g. `key="[(a,b),(c,d)]"`. Note that the quotation marks are necessary and that no white space is allowed. |
 | `--show-dir SHOW_DIR` | The directory to save the result visualization images. |
 | `--show` | Visualize the prediction result in a window. |
@@ -139,8 +139,8 @@ CUDA_VISIBLE_DEVICES=-1 python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [
 | `CONFIG_FILE` | The path to the config file. |
 | `CHECKPOINT_FILE` | The path to the checkpoint file (http links are supported; you can find the checkpoints you need [here](https://mmclassification.readthedocs.io/en/1.x/modelzoo_statistics.html)). |
 | `--work-dir WORK_DIR` | The directory to save the test metric results. |
-| `--out OUT` | The file to save the test metric results. |
-| `--dump DUMP` | The file to save all model outputs, which can be used for offline evaluation. |
+| `--out OUT` | The file to save the test outputs. |
+| `--out-item OUT_ITEM` | To specify the content of the test output file, which can be "pred" or "metrics". "pred" saves all model outputs for offline evaluation; "metrics" saves the test metrics. Defaults to "pred". |
 | `--cfg-options CFG_OPTIONS` | Override some settings in the config file. Key-value pairs in `xxx=yyy` format will be merged into the config loaded from the config file. You can specify list values in the format `key="[a,b]"` or `key=a,b`, and nesting is supported, e.g. `key="[(a,b),(c,d)]"`; the quotation marks here are required. Also, no whitespace is allowed inside each override item. |
 | `--show-dir SHOW_DIR` | The directory to save visualized prediction result images. |
 | `--show` | Show the prediction result images in a window. |
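With this unification, what used to be `--dump DUMP` is now written as `--out OUT --out-item pred` (the default), while `--out OUT --out-item metrics` keeps the old behaviour of `--out`, i.e. saving the evaluation metrics, for example `python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} --out results.pkl --out-item pred`. The sketch below shows one way to inspect such an output file afterwards; it is only an illustration, the file name `results.pkl` is an example, and the exact structure of the dumped predictions depends on the model's data samples.

```python
import mmengine

# Load whatever `--out results.pkl` produced (the file name is an example).
results = mmengine.load('results.pkl')

# With `--out-item metrics`, this is the dict of evaluation metrics returned
# by `runner.test()`, e.g. {'accuracy/top1': ...}.
# With `--out-item pred` (the default), it is the list of per-sample model
# outputs collected by DumpResults, ready for offline evaluation.
if isinstance(results, dict):
    for name, value in results.items():
        print(name, value)
else:
    print(f'{len(results)} prediction records')
```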
@@ -6,7 +6,7 @@ from copy import deepcopy
 import mmengine
 from mmengine.config import Config, ConfigDict, DictAction
-from mmengine.hooks import Hook
+from mmengine.evaluator import DumpResults
 from mmengine.runner import Runner
@@ -18,11 +18,12 @@ def parse_args():
     parser.add_argument(
         '--work-dir',
         help='the directory to save the file containing evaluation metrics')
-    parser.add_argument('--out', help='the file to save metric results.')
+    parser.add_argument('--out', help='the file to output results.')
     parser.add_argument(
-        '--dump',
-        type=str,
-        help='dump predictions to a pickle file for offline evaluation')
+        '--out-item',
+        choices=['metrics', 'pred'],
+        help='To output whether metrics or predictions. '
+        'Defaults to output predictions.')
     parser.add_argument(
         '--cfg-options',
         nargs='+',
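As a quick, standalone illustration of the new interface (not part of the patch): `--out-item` is optional and restricted to the two choices above, so when it is omitted `args.out_item` is `None`, which `main()` below treats the same as `"pred"`.

```python
import argparse

# Minimal reproduction of the two new arguments, copied from the "+" lines
# above; everything else about the real parser is omitted.
parser = argparse.ArgumentParser()
parser.add_argument('--out', help='the file to output results.')
parser.add_argument(
    '--out-item',
    choices=['metrics', 'pred'],
    help='To output whether metrics or predictions. '
    'Defaults to output predictions.')

args = parser.parse_args(['--out', 'results.pkl'])
print(args.out, args.out_item)  # results.pkl None -> treated as "pred"
```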
@@ -100,17 +101,6 @@ def merge_args(cfg, args):
         cfg.default_hooks.visualization.out_dir = args.show_dir
         cfg.default_hooks.visualization.interval = args.interval
 
-    # -------------------- Dump predictions --------------------
-    if args.dump is not None:
-        assert args.dump.endswith(('.pkl', '.pickle')), \
-            'The dump file must be a pkl file.'
-        dump_metric = dict(type='DumpResults', out_file_path=args.dump)
-        if isinstance(cfg.test_evaluator, (list, tuple)):
-            cfg.test_evaluator = list(cfg.test_evaluator)
-            cfg.test_evaluator.append(dump_metric)
-        else:
-            cfg.test_evaluator = [cfg.test_evaluator, dump_metric]
-
     # -------------------- TTA related args --------------------
     if args.tta:
         if 'tta_model' not in cfg:
@@ -157,6 +147,10 @@ def merge_args(cfg, args):
 def main():
     args = parse_args()
 
+    if args.out is None and args.out_item is not None:
+        raise ValueError('Please use `--out` argument to specify the '
+                         'path of the output file before using `--out-item`.')
+
     # load config
     cfg = Config.fromfile(args.config)
 
@@ -166,18 +160,15 @@ def main():
     # build the runner from config
     runner = Runner.from_cfg(cfg)
 
-    if args.out:
-
-        class SaveMetricHook(Hook):
-
-            def after_test_epoch(self, _, metrics=None):
-                if metrics is not None:
-                    mmengine.dump(metrics, args.out)
-
-        runner.register_hook(SaveMetricHook(), 'LOWEST')
+    if args.out and args.out_item in ['pred', None]:
+        runner.test_evaluator.metrics.append(
+            DumpResults(out_file_path=args.out))
 
     # start testing
-    runner.test()
+    metrics = runner.test()
+
+    if args.out and args.out_item == 'metrics':
+        mmengine.dump(metrics, args.out)
 
 
 if __name__ == '__main__':
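Putting the pieces together, the refactored `main()` routes a single `--out` path to either predictions or metrics depending on `--out-item`. The function below is a hedged, standalone restatement of that dispatch logic for readability; the name `decide_output` and its return values are illustrative only, and the real script attaches `DumpResults` to the runner's test evaluator rather than returning a string.

```python
# Standalone restatement of the dispatch logic in the new main();
# `decide_output` is an illustrative name, not part of tools/test.py.
def decide_output(out, out_item):
    if out is None and out_item is not None:
        raise ValueError('`--out-item` requires `--out` to be set.')
    if out and out_item in ('pred', None):
        # In tools/test.py: DumpResults(out_file_path=out) is appended to
        # runner.test_evaluator.metrics before runner.test().
        return 'predictions'
    if out and out_item == 'metrics':
        # In tools/test.py: mmengine.dump(metrics, out) after runner.test().
        return 'metrics'
    return None  # no --out given: nothing is written


assert decide_output('results.pkl', None) == 'predictions'
assert decide_output('results.pkl', 'pred') == 'predictions'
assert decide_output('results.pkl', 'metrics') == 'metrics'
assert decide_output(None, None) is None
```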