diff --git a/mmengine/analysis/print_helper.py b/mmengine/analysis/print_helper.py index cd6092e9..3b87d423 100644 --- a/mmengine/analysis/print_helper.py +++ b/mmengine/analysis/print_helper.py @@ -725,14 +725,15 @@ def get_model_complexity_info( raise ValueError('"input_shape" and "inputs" cannot be both set.') if inputs is None: + device = next(model.parameters()).device if is_tuple_of(input_shape, int): # tuple of int, construct one tensor - inputs = (torch.randn(1, *input_shape), ) + inputs = (torch.randn(1, *input_shape).to(device), ) elif is_tuple_of(input_shape, tuple) and all([ is_tuple_of(one_input_shape, int) for one_input_shape in input_shape # type: ignore ]): # tuple of tuple of int, construct multiple tensors inputs = tuple([ - torch.randn(1, *one_input_shape) + torch.randn(1, *one_input_shape).to(device) for one_input_shape in input_shape # type: ignore ]) else: