24 lines
664 B
Python
24 lines
664 B
Python
from collections import abc, defaultdict
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
|
|
def cast_tensor_type(inputs, src_type, dst_type):
    """Recursively convert every tensor found in ``inputs`` with ``.to(dst_type)``.

    Containers are rebuilt with their original type: mappings keep their keys
    and have their values converted; other iterables have each element
    converted. Strings and NumPy arrays are returned unchanged, as is any
    object that is neither a tensor nor iterable.

    Args:
        inputs: A tensor, or an arbitrarily nested mapping/iterable that may
            contain tensors.
        src_type: Accepted for API compatibility but not used. Every tensor
            is converted, whatever its current dtype.
        dst_type: Target passed to ``torch.Tensor.to`` (e.g. a torch dtype).

    Returns:
        An object with the same nesting structure as ``inputs`` in which all
        tensors have been converted.

    Raises:
        TypeError: If an iterable's type cannot be rebuilt from an iterable of
            its converted elements (for example, a generator object).
    """
    if isinstance(inputs, torch.Tensor):
        return inputs.to(dst_type)
    if isinstance(inputs, (str, np.ndarray)):
        # Both are iterable, but rebuilding them from their elements would not
        # round-trip: str(<generator>) is the generator's repr, and
        # np.ndarray(...) interprets its argument as a shape.
        return inputs
    if isinstance(inputs, abc.Mapping):
        converted = {
            k: cast_tensor_type(v, src_type, dst_type)
            for k, v in inputs.items()
        }
        if isinstance(inputs, defaultdict):
            # defaultdict's first positional argument is default_factory, so
            # it cannot be constructed from a plain dict alone.
            return type(inputs)(inputs.default_factory, converted)
        return type(inputs)(converted)
    if isinstance(inputs, abc.Iterable):
        items = (cast_tensor_type(item, src_type, dst_type) for item in inputs)
        if isinstance(inputs, tuple) and hasattr(inputs, '_fields'):
            # namedtuple constructors take one positional argument per field
            # instead of a single iterable.
            return type(inputs)(*items)
        return type(inputs)(items)
    return inputs
|