mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Fix the incorrect device of inputs in get_model_complexity_info (#1130)
This commit is contained in:
parent
2085046d22
commit
83d76abc7f
@ -725,14 +725,15 @@ def get_model_complexity_info(
|
|||||||
raise ValueError('"input_shape" and "inputs" cannot be both set.')
|
raise ValueError('"input_shape" and "inputs" cannot be both set.')
|
||||||
|
|
||||||
if inputs is None:
|
if inputs is None:
|
||||||
|
device = next(model.parameters()).device
|
||||||
if is_tuple_of(input_shape, int): # tuple of int, construct one tensor
|
if is_tuple_of(input_shape, int): # tuple of int, construct one tensor
|
||||||
inputs = (torch.randn(1, *input_shape), )
|
inputs = (torch.randn(1, *input_shape).to(device), )
|
||||||
elif is_tuple_of(input_shape, tuple) and all([
|
elif is_tuple_of(input_shape, tuple) and all([
|
||||||
is_tuple_of(one_input_shape, int)
|
is_tuple_of(one_input_shape, int)
|
||||||
for one_input_shape in input_shape # type: ignore
|
for one_input_shape in input_shape # type: ignore
|
||||||
]): # tuple of tuple of int, construct multiple tensors
|
]): # tuple of tuple of int, construct multiple tensors
|
||||||
inputs = tuple([
|
inputs = tuple([
|
||||||
torch.randn(1, *one_input_shape)
|
torch.randn(1, *one_input_shape).to(device)
|
||||||
for one_input_shape in input_shape # type: ignore
|
for one_input_shape in input_shape # type: ignore
|
||||||
])
|
])
|
||||||
else:
|
else:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user