[Fix] Fix ndarray metainfo check in ConcatDataset (#1333)

This commit is contained in:
lizuoxin-nreal 2023-09-01 16:40:26 +08:00 committed by GitHub
parent 8a7e80e9e0
commit 762c9a25b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -6,6 +6,7 @@ import math
from collections import defaultdict
from typing import List, Sequence, Tuple, Union
import numpy as np
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
from mmengine.logging import print_log
@ -73,7 +74,17 @@ class ConcatDataset(_ConcatDataset):
raise ValueError(
f'{key} does not in the meta information of '
f'the {i}-th dataset')
if self._metainfo[key] != dataset.metainfo[key]:
first_type = type(self._metainfo[key])
cur_type = type(dataset.metainfo[key])
if first_type is not cur_type: # type: ignore
raise TypeError(
f'The type {cur_type} of {key} in the {i}-th dataset '
'should be the same with the first dataset '
f'{first_type}')
if (isinstance(self._metainfo[key], np.ndarray)
and not np.array_equal(self._metainfo[key],
dataset.metainfo[key])
or self._metainfo[key] != dataset.metainfo[key]):
raise ValueError(
f'The meta information of the {i}-th dataset does not '
'match meta information of the first dataset')