from collections import abc

import numpy as np
import torch


def cast_tensor_type(inputs, src_type, dst_type):
    """Recursively cast tensors nested inside ``inputs`` to ``dst_type``.

    Containers are walked recursively and rebuilt with the same container
    type; anything that is not a tensor or a container is returned unchanged.

    Args:
        inputs: A ``torch.Tensor``, or an arbitrarily nested structure of
            mappings / iterables that may contain tensors.
        src_type: When this is a ``torch.dtype``, only tensors whose dtype
            equals it are cast; tensors of any other dtype are returned
            untouched. When it is ``None`` (or anything that is not a
            ``torch.dtype``) every tensor is passed to ``Tensor.to``.
        dst_type: Target forwarded verbatim to ``Tensor.to``.

    Returns:
        A structure mirroring ``inputs`` with the selected tensors cast.
        Strings, bytes and ``numpy.ndarray`` objects are returned as-is.
        One-shot iterators (e.g. generators) yield a lazy generator of
        cast items, since their concrete type cannot be re-instantiated.
    """
    if isinstance(inputs, torch.Tensor):
        # Honour ``src_type``: leave tensors of another dtype alone so that
        # e.g. integer index tensors are not turned into floating point.
        if isinstance(src_type, torch.dtype) and inputs.dtype != src_type:
            return inputs
        return inputs.to(dst_type)

    # These are iterable but must never be traversed element by element.
    if isinstance(inputs, (str, bytes, np.ndarray)):
        return inputs

    if isinstance(inputs, abc.Mapping):
        return type(inputs)({
            key: cast_tensor_type(value, src_type, dst_type)
            for key, value in inputs.items()
        })

    if isinstance(inputs, abc.Iterator):
        # Generator/iterator types cannot be constructed from an iterable,
        # so hand back a lazy stream of cast items instead of raising.
        return (
            cast_tensor_type(item, src_type, dst_type) for item in inputs)

    if isinstance(inputs, abc.Iterable):
        items = (
            cast_tensor_type(item, src_type, dst_type) for item in inputs)
        if isinstance(inputs, tuple) and hasattr(inputs, '_fields'):
            # namedtuple constructors take positional fields, not one
            # iterable argument.
            return type(inputs)(*items)
        return type(inputs)(items)

    return inputs