# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Alibaba, Inc. and its affiliates.
import mmcv


def extract_result_dict(results, key):
    """Extract and return the data corresponding to key in result dict.

    ``results`` is a dict output from `pipeline(input_dict)`, which is the
    loaded data from ``Dataset`` class.
    The data terms inside may be wrapped in list, tuple and DataContainer, so
    this function essentially extracts data from these wrappers.

    Args:
        results (dict): Data loaded using pipeline.
        key (str): Key of the desired data.

    Returns:
        np.ndarray | torch.Tensor | None: Data term. ``None`` is returned
        when ``key`` is absent from ``results`` or when ``results[key]`` is
        an empty list/tuple.
    """
    if key not in results:
        return None
    # results[key] may be data or list[data] or tuple[data];
    # only the first element of a sequence is used.
    data = results[key]
    if isinstance(data, (list, tuple)):
        if not data:
            # Nothing to extract; treat like a missing key instead of
            # raising IndexError on data[0].
            return None
        data = data[0]
    # data may be wrapped inside DataContainer; unwrap via the public
    # ``data`` property rather than the private ``_data`` attribute.
    if isinstance(data, mmcv.parallel.DataContainer):
        data = data.data
    return data