From 9c45a94cee11be3e71dbc01ea18b71b6f29330d5 Mon Sep 17 00:00:00 2001
From: Peng Lu
Date: Wed, 20 Sep 2023 12:45:05 +0800
Subject: [PATCH] [Fix] fix import error raised by ldm (#3338)

---
 mmseg/models/backbones/vpd.py             | 16 ++++++++++++++--
 tools/analysis_tools/visualization_cam.py | 19 ++++++-------------
 2 files changed, 20 insertions(+), 15 deletions(-)

diff --git a/mmseg/models/backbones/vpd.py b/mmseg/models/backbones/vpd.py
index 8b57be39b..e0536d31c 100644
--- a/mmseg/models/backbones/vpd.py
+++ b/mmseg/models/backbones/vpd.py
@@ -10,14 +10,19 @@ from typing import List, Optional, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from ldm.modules.diffusionmodules.util import timestep_embedding
-from ldm.util import instantiate_from_config
 from mmengine.model import BaseModule
 from mmengine.runner import CheckpointLoader, load_checkpoint
 
 from mmseg.registry import MODELS
 from mmseg.utils import ConfigType, OptConfigType
 
+try:
+    from ldm.modules.diffusionmodules.util import timestep_embedding
+    from ldm.util import instantiate_from_config
+    has_ldm = True
+except ImportError:
+    has_ldm = False
+
 
 def register_attention_control(model, controller):
     """Registers a control function to manage attention within a model.
@@ -205,6 +210,10 @@ class UNetWrapper(nn.Module):
                  max_attn_size=None,
                  attn_selector='up_cross+down_cross'):
         super().__init__()
+
+        assert has_ldm, 'To use UNetWrapper, please install required ' \
+            'packages via `pip install -r requirements/optional.txt`.'
+
         self.unet = unet
         self.attention_store = AttentionStore(
             base_size=base_size // 8, max_size=max_attn_size)
@@ -321,6 +330,9 @@ class VPD(BaseModule):
 
         super().__init__(init_cfg=init_cfg)
 
+        assert has_ldm, 'To use VPD model, please install required packages' \
+            ' via `pip install -r requirements/optional.txt`.'
+
         if pad_shape is not None:
             if not isinstance(pad_shape, (list, tuple)):
                 pad_shape = (pad_shape, pad_shape)
diff --git a/tools/analysis_tools/visualization_cam.py b/tools/analysis_tools/visualization_cam.py
index 334de4adf..00cdb3e04 100644
--- a/tools/analysis_tools/visualization_cam.py
+++ b/tools/analysis_tools/visualization_cam.py
@@ -9,12 +9,12 @@ from argparse import ArgumentParser
 import numpy as np
 import torch
 import torch.nn.functional as F
+from mmengine import Config
 from mmengine.model import revert_sync_batchnorm
 from PIL import Image
-from pytorch_grad_cam import GradCAM, LayerCAM, XGradCAM, GradCAMPlusPlus, EigenCAM, EigenGradCAM
+from pytorch_grad_cam import GradCAM
 from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image
-from mmengine import Config
 
 from mmseg.apis import inference_model, init_model, show_result_pyplot
 from mmseg.utils import register_all_modules
@@ -56,21 +56,15 @@ def main():
         default='prediction.png',
         help='Path to output prediction file')
     parser.add_argument(
-        '--cam-file',
-        default='vis_cam.png',
-        help='Path to output cam file')
+        '--cam-file', default='vis_cam.png', help='Path to output cam file')
     parser.add_argument(
         '--target-layers',
         default='backbone.layer4[2]',
         help='Target layers to visualize CAM')
     parser.add_argument(
-        '--category-index',
-        default='7',
-        help='Category to visualize CAM')
+        '--category-index', default='7', help='Category to visualize CAM')
     parser.add_argument(
-        '--device',
-        default='cuda:0',
-        help='Device used for inference')
+        '--device', default='cuda:0', help='Device used for inference')
     args = parser.parse_args()
 
     # build the model from a config file and a checkpoint file
@@ -116,8 +110,7 @@ def main():
     # Grad CAM(Class Activation Maps)
     # Can also be LayerCAM, XGradCAM, GradCAMPlusPlus, EigenCAM, EigenGradCAM
     targets = [
-        SemanticSegmentationTarget(category, mask_float,
-                                   (height, width))
+        SemanticSegmentationTarget(category, mask_float, (height, width))
    ]
    with GradCAM(
        model=model,
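
Note on the pattern: the vpd.py change defers the hard dependency on `ldm` from import time to use time. The `try`/`except ImportError` records availability in a module-level `has_ldm` flag, and only the classes that actually need the package (`UNetWrapper`, `VPD`) assert on it in their constructors, so importing mmseg.models keeps working for users who never touch VPD. Below is a minimal standalone sketch of the same guard; the package and class names (`some_optional_dep`, `make_backbone`, `FeatureUsingDep`) are illustrations, not part of this patch:

import torch.nn as nn

# Guard the optional import once, at module import time. A missing package
# sets a flag instead of letting the ImportError escape.
try:
    from some_optional_dep import make_backbone  # hypothetical optional dep
    has_dep = True
except ImportError:
    has_dep = False


class FeatureUsingDep(nn.Module):
    """Only this class needs the optional package, so only its constructor
    checks the flag; importing this module stays safe without the package."""

    def __init__(self):
        super().__init__()
        # Fail with an actionable message at construction time, mirroring
        # the asserts added to UNetWrapper and VPD above.
        assert has_dep, ('To use FeatureUsingDep, please install the '
                         'required optional package first.')
        self.backbone = make_backbone()

One caveat worth knowing: `assert` statements are stripped when Python runs with `-O`, so raising an ImportError explicitly at construction time is the stricter variant of the same idea.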