mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Fix the incorrect device of inputs in get_model_complexity_info (#1130)
This commit is contained in:
parent
2085046d22
commit
83d76abc7f
@ -725,14 +725,15 @@ def get_model_complexity_info(
|
||||
raise ValueError('"input_shape" and "inputs" cannot be both set.')
|
||||
|
||||
if inputs is None:
|
||||
device = next(model.parameters()).device
|
||||
if is_tuple_of(input_shape, int): # tuple of int, construct one tensor
|
||||
inputs = (torch.randn(1, *input_shape), )
|
||||
inputs = (torch.randn(1, *input_shape).to(device), )
|
||||
elif is_tuple_of(input_shape, tuple) and all([
|
||||
is_tuple_of(one_input_shape, int)
|
||||
for one_input_shape in input_shape # type: ignore
|
||||
]): # tuple of tuple of int, construct multiple tensors
|
||||
inputs = tuple([
|
||||
torch.randn(1, *one_input_shape)
|
||||
torch.randn(1, *one_input_shape).to(device)
|
||||
for one_input_shape in input_shape # type: ignore
|
||||
])
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user