From 83d76abc7fe2e075e36313a7355d8759a38160a9 Mon Sep 17 00:00:00 2001 From: CescMessi Date: Sat, 6 May 2023 10:55:17 +0800 Subject: [PATCH] [Fix] Fix the incorrect device of inputs in get_model_complexity_info (#1130) --- mmengine/analysis/print_helper.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mmengine/analysis/print_helper.py b/mmengine/analysis/print_helper.py index cd6092e9..3b87d423 100644 --- a/mmengine/analysis/print_helper.py +++ b/mmengine/analysis/print_helper.py @@ -725,14 +725,15 @@ def get_model_complexity_info( raise ValueError('"input_shape" and "inputs" cannot be both set.') if inputs is None: + device = next(model.parameters()).device if is_tuple_of(input_shape, int): # tuple of int, construct one tensor - inputs = (torch.randn(1, *input_shape), ) + inputs = (torch.randn(1, *input_shape).to(device), ) elif is_tuple_of(input_shape, tuple) and all([ is_tuple_of(one_input_shape, int) for one_input_shape in input_shape # type: ignore ]): # tuple of tuple of int, construct multiple tensors inputs = tuple([ - torch.randn(1, *one_input_shape) + torch.randn(1, *one_input_shape).to(device) for one_input_shape in input_shape # type: ignore ]) else: