# mmpretrain/mmcls/core/fp16/utils.py
#
# 24 lines
# 664 B
# Python

from collections import abc
import numpy as np
import torch
def cast_tensor_type(inputs, src_type, dst_type):
    """Recursively cast tensors of dtype ``src_type`` to ``dst_type``.

    Args:
        inputs: Object to convert. May be a tensor, a mapping, an iterable,
            or any other object; containers are traversed recursively.
        src_type (torch.dtype): Only tensors whose dtype equals this value
            are cast.
        dst_type (torch.dtype): Dtype that matching tensors are cast to.

    Returns:
        An object with the same structure as ``inputs``. Tensors of dtype
        ``src_type`` are replaced by their ``dst_type`` counterparts;
        modules, strings, numpy arrays, tensors of any other dtype and
        non-container objects are returned unchanged.

    Note:
        Mappings and iterables are rebuilt with ``type(inputs)(...)``, so
        the container type must be constructible from a single dict
        (mappings) or a single iterable (other containers).
    """
    # Modules such as ``Sequential``/``ModuleList`` are iterable; pass them
    # through untouched instead of trying to rebuild them from a generator.
    if isinstance(inputs, torch.nn.Module):
        return inputs
    if isinstance(inputs, torch.Tensor):
        # Leave tensors of other dtypes (e.g. integer labels, bool masks)
        # alone; only tensors that actually have ``src_type`` are converted.
        return inputs.to(dst_type) if inputs.dtype == src_type else inputs
    # Strings are iterable and would otherwise recurse character by
    # character; numpy arrays are deliberately left as they are.
    if isinstance(inputs, (str, np.ndarray)):
        return inputs
    if isinstance(inputs, abc.Mapping):
        return type(inputs)({
            k: cast_tensor_type(v, src_type, dst_type)
            for k, v in inputs.items()
        })
    if isinstance(inputs, abc.Iterable):
        return type(inputs)(
            cast_tensor_type(item, src_type, dst_type) for item in inputs)
    return inputs