[Fix] Fix the incorrect device of inputs in get_model_complexity_info (#1130)

This commit is contained in:
CescMessi 2023-05-06 10:55:17 +08:00 committed by GitHub
parent 2085046d22
commit 83d76abc7f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -725,14 +725,15 @@ def get_model_complexity_info(
raise ValueError('"input_shape" and "inputs" cannot be both set.')
if inputs is None:
device = next(model.parameters()).device
if is_tuple_of(input_shape, int): # tuple of int, construct one tensor
inputs = (torch.randn(1, *input_shape), )
inputs = (torch.randn(1, *input_shape).to(device), )
elif is_tuple_of(input_shape, tuple) and all([
is_tuple_of(one_input_shape, int)
for one_input_shape in input_shape # type: ignore
]): # tuple of tuple of int, construct multiple tensors
inputs = tuple([
torch.randn(1, *one_input_shape)
torch.randn(1, *one_input_shape).to(device)
for one_input_shape in input_shape # type: ignore
])
else: