[Bug] Fix bug in Collect (#149)

* fix bug in Collect

* add metakeys
pull/152/head
LXXXXR 2021-01-25 20:29:28 +08:00 committed by GitHub
parent fc82c31b2f
commit 3e5a9513be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 23 additions and 1 deletions

View File

@ -3,6 +3,7 @@ from collections.abc import Sequence
import mmcv
import numpy as np
import torch
from mmcv.parallel import DataContainer as DC
from PIL import Image
from ..builder import PIPELINES
@ -110,13 +111,34 @@ class Collect(object):
This is usually the last stage of the data loader pipeline. Typically keys
is set to some subset of "img" and "gt_label".
Args:
keys (Sequence[str]): Keys of results to be collected in ``data``.
meta_keys (Sequence[str], optional): Meta keys to be converted to
``mmcv.DataContainer`` and collected in ``data[img_metas]``.
Default: ``('filename', 'ori_shape', 'img_shape', 'flip',
'flip_direction', 'img_norm_cfg')``
Returns:
dict: The result dict contains the following keys
- keys in``self.keys``
- ``img_metas`` if avaliable
"""
def __init__(self, keys):
def __init__(self,
keys,
meta_keys=('filename', 'ori_shape', 'img_shape', 'flip',
'flip_direction', 'img_norm_cfg')):
self.keys = keys
self.meta_keys = meta_keys
def __call__(self, results):
data = {}
img_meta = {}
for key in self.meta_keys:
if key in results:
img_meta[key] = results[key]
data['img_metas'] = DC(img_meta, cpu_only=True)
for key in self.keys:
data[key] = results[key]
return data