mirror of
https://github.com/alibaba/EasyCV.git
synced 2025-06-03 14:49:00 +08:00
31 lines
1009 B
Python
31 lines
1009 B
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import mmcv
|
|
|
|
|
|
def extract_result_dict(results, key):
    """Extract and return the data corresponding to key in result dict.

    ``results`` is a dict output from `pipeline(input_dict)`, which is the
    loaded data from ``Dataset`` class.
    The data terms inside may be wrapped in list, tuple and DataContainer, so
    this function essentially extracts data from these wrappers. When the
    value is a list or tuple, only its first element is used.

    Args:
        results (dict): Data loaded using pipeline.
        key (str): Key of the desired data.

    Returns:
        np.ndarray | torch.Tensor | None: Data term, or ``None`` when ``key``
            is absent from ``results`` or maps to an empty list/tuple.
    """
    if key not in results:
        return None

    # results[key] may be data or list[data] or tuple[data]
    # data may be wrapped inside DataContainer
    data = results[key]
    if isinstance(data, (list, tuple)):
        if not data:
            # An empty wrapper holds no data term; report it the same way as
            # a missing key instead of raising IndexError on data[0].
            return None
        data = data[0]
    if isinstance(data, mmcv.parallel.DataContainer):
        data = data._data
    return data
|