mirror of https://github.com/JDAI-CV/fast-reid.git
fix: fix bugs about import of torch version
parent
f4551a128b
commit
d9d6e19b2c
|
@ -8,7 +8,8 @@ import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch._six import container_abcs, string_classes, int_classes
|
from torch._six import string_classes
|
||||||
|
from collections import Mapping
|
||||||
|
|
||||||
from fastreid.config import configurable
|
from fastreid.config import configurable
|
||||||
from fastreid.utils import comm
|
from fastreid.utils import comm
|
||||||
|
@ -175,12 +176,12 @@ def fast_batch_collator(batched_inputs):
|
||||||
out[i] += tensor
|
out[i] += tensor
|
||||||
return out
|
return out
|
||||||
|
|
||||||
elif isinstance(elem, container_abcs.Mapping):
|
elif isinstance(elem, Mapping):
|
||||||
return {key: fast_batch_collator([d[key] for d in batched_inputs]) for key in elem}
|
return {key: fast_batch_collator([d[key] for d in batched_inputs]) for key in elem}
|
||||||
|
|
||||||
elif isinstance(elem, float):
|
elif isinstance(elem, float):
|
||||||
return torch.tensor(batched_inputs, dtype=torch.float64)
|
return torch.tensor(batched_inputs, dtype=torch.float64)
|
||||||
elif isinstance(elem, int_classes):
|
elif isinstance(elem, int):
|
||||||
return torch.tensor(batched_inputs)
|
return torch.tensor(batched_inputs)
|
||||||
elif isinstance(elem, string_classes):
|
elif isinstance(elem, string_classes):
|
||||||
return batched_inputs
|
return batched_inputs
|
||||||
|
|
Loading…
Reference in New Issue