diff --git a/mmcv/parallel/collate.py b/mmcv/parallel/collate.py index 2e3a861fe..1bcab66c4 100644 --- a/mmcv/parallel/collate.py +++ b/mmcv/parallel/collate.py @@ -1,5 +1,5 @@ # Copyright (c) Open-MMLab. All rights reserved. -import collections +from collections.abc import Mapping, Sequence import torch import torch.nn.functional as F @@ -20,7 +20,7 @@ def collate(batch, samples_per_gpu=1): 3. cpu_only = False, stack = False, e.g., gt bboxes """ - if not isinstance(batch, collections.Sequence): + if not isinstance(batch, Sequence): raise TypeError(f'{batch.dtype} is not supported.') if isinstance(batch[0], DataContainer): @@ -73,10 +73,10 @@ def collate(batch, samples_per_gpu=1): stacked.append( [sample.data for sample in batch[i:i + samples_per_gpu]]) return DataContainer(stacked, batch[0].stack, batch[0].padding_value) - elif isinstance(batch[0], collections.Sequence): + elif isinstance(batch[0], Sequence): transposed = zip(*batch) return [collate(samples, samples_per_gpu) for samples in transposed] - elif isinstance(batch[0], collections.Mapping): + elif isinstance(batch[0], Mapping): return { key: collate([d[key] for d in batch], samples_per_gpu) for key in batch[0]