fix: fix bugs about import of torch version

pull/583/head
l1aoxingyu 2021-09-21 21:42:00 +08:00
parent f4551a128b
commit d9d6e19b2c
1 changed files with 4 additions and 3 deletions

View File

@ -8,7 +8,8 @@ import logging
import os import os
import torch import torch
from torch._six import container_abcs, string_classes, int_classes from torch._six import string_classes
from collections import Mapping
from fastreid.config import configurable from fastreid.config import configurable
from fastreid.utils import comm from fastreid.utils import comm
@ -175,12 +176,12 @@ def fast_batch_collator(batched_inputs):
out[i] += tensor out[i] += tensor
return out return out
elif isinstance(elem, container_abcs.Mapping): elif isinstance(elem, Mapping):
return {key: fast_batch_collator([d[key] for d in batched_inputs]) for key in elem} return {key: fast_batch_collator([d[key] for d in batched_inputs]) for key in elem}
elif isinstance(elem, float): elif isinstance(elem, float):
return torch.tensor(batched_inputs, dtype=torch.float64) return torch.tensor(batched_inputs, dtype=torch.float64)
elif isinstance(elem, int_classes): elif isinstance(elem, int):
return torch.tensor(batched_inputs) return torch.tensor(batched_inputs)
elif isinstance(elem, string_classes): elif isinstance(elem, string_classes):
return batched_inputs return batched_inputs