diff --git a/tools/visualize_actmap.py b/tools/visualize_actmap.py index 278d25e..e148050 100644 --- a/tools/visualize_actmap.py +++ b/tools/visualize_actmap.py @@ -133,6 +133,9 @@ def main(): num_classes=datamanager.num_train_pids, use_gpu=use_gpu ) + + if use_gpu: + model = model.cuda() if args.weights and check_isfile(args.weights): load_pretrained_weights(model, args.weights)