From 3e5a9513be323f33ee20831299c9cdcd969f1d0c Mon Sep 17 00:00:00 2001 From: LXXXXR <73265258+LXXXXR@users.noreply.github.com> Date: Mon, 25 Jan 2021 20:29:28 +0800 Subject: [PATCH] [Bug] Fix bug in Collect (#149) * fix bug in Collect * add metakeys --- mmcls/datasets/pipelines/formating.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/mmcls/datasets/pipelines/formating.py b/mmcls/datasets/pipelines/formating.py index e0d78f220..63559c03b 100644 --- a/mmcls/datasets/pipelines/formating.py +++ b/mmcls/datasets/pipelines/formating.py @@ -3,6 +3,7 @@ from collections.abc import Sequence import mmcv import numpy as np import torch +from mmcv.parallel import DataContainer as DC from PIL import Image from ..builder import PIPELINES @@ -110,13 +111,34 @@ class Collect(object): This is usually the last stage of the data loader pipeline. Typically keys is set to some subset of "img" and "gt_label". + + Args: + keys (Sequence[str]): Keys of results to be collected in ``data``. + meta_keys (Sequence[str], optional): Meta keys to be converted to + ``mmcv.DataContainer`` and collected in ``data[img_metas]``. + Default: ``('filename', 'ori_shape', 'img_shape', 'flip', + 'flip_direction', 'img_norm_cfg')`` + + Returns: + dict: The result dict contains the following keys + - keys in``self.keys`` + - ``img_metas`` if avaliable """ - def __init__(self, keys): + def __init__(self, + keys, + meta_keys=('filename', 'ori_shape', 'img_shape', 'flip', + 'flip_direction', 'img_norm_cfg')): self.keys = keys + self.meta_keys = meta_keys def __call__(self, results): data = {} + img_meta = {} + for key in self.meta_keys: + if key in results: + img_meta[key] = results[key] + data['img_metas'] = DC(img_meta, cpu_only=True) for key in self.keys: data[key] = results[key] return data