[Fix] fix import error raised by ldm (#3338)

This commit is contained in:
Peng Lu 2023-09-20 12:45:05 +08:00 committed by GitHub
parent 56a40d78ab
commit 9c45a94cee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 15 deletions

View File

@@ -10,14 +10,19 @@ from typing import List, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ldm.modules.diffusionmodules.util import timestep_embedding
from ldm.util import instantiate_from_config
from mmengine.model import BaseModule
from mmengine.runner import CheckpointLoader, load_checkpoint
from mmseg.registry import MODELS
from mmseg.utils import ConfigType, OptConfigType
try:
from ldm.modules.diffusionmodules.util import timestep_embedding
from ldm.util import instantiate_from_config
has_ldm = True
except ImportError:
has_ldm = False
def register_attention_control(model, controller):
"""Registers a control function to manage attention within a model.
@@ -205,6 +210,10 @@ class UNetWrapper(nn.Module):
max_attn_size=None,
attn_selector='up_cross+down_cross'):
super().__init__()
assert has_ldm, 'To use UNetWrapper, please install required ' \
'packages via `pip install -r requirements/optional.txt`.'
self.unet = unet
self.attention_store = AttentionStore(
base_size=base_size // 8, max_size=max_attn_size)
@@ -321,6 +330,9 @@ class VPD(BaseModule):
super().__init__(init_cfg=init_cfg)
assert has_ldm, 'To use VPD model, please install required packages' \
' via `pip install -r requirements/optional.txt`.'
if pad_shape is not None:
if not isinstance(pad_shape, (list, tuple)):
pad_shape = (pad_shape, pad_shape)

View File

@@ -9,12 +9,12 @@ from argparse import ArgumentParser
import numpy as np
import torch
import torch.nn.functional as F
from mmengine import Config
from mmengine.model import revert_sync_batchnorm
from PIL import Image
from pytorch_grad_cam import GradCAM, LayerCAM, XGradCAM, GradCAMPlusPlus, EigenCAM, EigenGradCAM
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image
from mmengine import Config
from mmseg.apis import inference_model, init_model, show_result_pyplot
from mmseg.utils import register_all_modules
@@ -56,21 +56,15 @@ def main():
default='prediction.png',
help='Path to output prediction file')
parser.add_argument(
'--cam-file',
default='vis_cam.png',
help='Path to output cam file')
'--cam-file', default='vis_cam.png', help='Path to output cam file')
parser.add_argument(
'--target-layers',
default='backbone.layer4[2]',
help='Target layers to visualize CAM')
parser.add_argument(
'--category-index',
default='7',
help='Category to visualize CAM')
'--category-index', default='7', help='Category to visualize CAM')
parser.add_argument(
'--device',
default='cuda:0',
help='Device used for inference')
'--device', default='cuda:0', help='Device used for inference')
args = parser.parse_args()
# build the model from a config file and a checkpoint file
@@ -116,8 +110,7 @@ def main():
# Grad CAM(Class Activation Maps)
# Can also be LayerCAM, XGradCAM, GradCAMPlusPlus, EigenCAM, EigenGradCAM
targets = [
SemanticSegmentationTarget(category, mask_float,
(height, width))
SemanticSegmentationTarget(category, mask_float, (height, width))
]
with GradCAM(
model=model,