mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix] fix import error raised by ldm (#3338)
This commit is contained in:
parent
56a40d78ab
commit
9c45a94cee
@ -10,14 +10,19 @@ from typing import List, Optional, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ldm.modules.diffusionmodules.util import timestep_embedding
|
||||
from ldm.util import instantiate_from_config
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.runner import CheckpointLoader, load_checkpoint
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import ConfigType, OptConfigType
|
||||
|
||||
try:
|
||||
from ldm.modules.diffusionmodules.util import timestep_embedding
|
||||
from ldm.util import instantiate_from_config
|
||||
has_ldm = True
|
||||
except ImportError:
|
||||
has_ldm = False
|
||||
|
||||
|
||||
def register_attention_control(model, controller):
|
||||
"""Registers a control function to manage attention within a model.
|
||||
@ -205,6 +210,10 @@ class UNetWrapper(nn.Module):
|
||||
max_attn_size=None,
|
||||
attn_selector='up_cross+down_cross'):
|
||||
super().__init__()
|
||||
|
||||
assert has_ldm, 'To use UNetWrapper, please install required ' \
|
||||
'packages via `pip install -r requirements/optional.txt`.'
|
||||
|
||||
self.unet = unet
|
||||
self.attention_store = AttentionStore(
|
||||
base_size=base_size // 8, max_size=max_attn_size)
|
||||
@ -321,6 +330,9 @@ class VPD(BaseModule):
|
||||
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
assert has_ldm, 'To use VPD model, please install required packages' \
|
||||
' via `pip install -r requirements/optional.txt`.'
|
||||
|
||||
if pad_shape is not None:
|
||||
if not isinstance(pad_shape, (list, tuple)):
|
||||
pad_shape = (pad_shape, pad_shape)
|
||||
|
@ -9,12 +9,12 @@ from argparse import ArgumentParser
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from mmengine import Config
|
||||
from mmengine.model import revert_sync_batchnorm
|
||||
from PIL import Image
|
||||
from pytorch_grad_cam import GradCAM, LayerCAM, XGradCAM, GradCAMPlusPlus, EigenCAM, EigenGradCAM
|
||||
from pytorch_grad_cam import GradCAM
|
||||
from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image
|
||||
|
||||
from mmengine import Config
|
||||
from mmseg.apis import inference_model, init_model, show_result_pyplot
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
@ -56,21 +56,15 @@ def main():
|
||||
default='prediction.png',
|
||||
help='Path to output prediction file')
|
||||
parser.add_argument(
|
||||
'--cam-file',
|
||||
default='vis_cam.png',
|
||||
help='Path to output cam file')
|
||||
'--cam-file', default='vis_cam.png', help='Path to output cam file')
|
||||
parser.add_argument(
|
||||
'--target-layers',
|
||||
default='backbone.layer4[2]',
|
||||
help='Target layers to visualize CAM')
|
||||
parser.add_argument(
|
||||
'--category-index',
|
||||
default='7',
|
||||
help='Category to visualize CAM')
|
||||
'--category-index', default='7', help='Category to visualize CAM')
|
||||
parser.add_argument(
|
||||
'--device',
|
||||
default='cuda:0',
|
||||
help='Device used for inference')
|
||||
'--device', default='cuda:0', help='Device used for inference')
|
||||
args = parser.parse_args()
|
||||
|
||||
# build the model from a config file and a checkpoint file
|
||||
@ -116,8 +110,7 @@ def main():
|
||||
# Grad CAM(Class Activation Maps)
|
||||
# Can also be LayerCAM, XGradCAM, GradCAMPlusPlus, EigenCAM, EigenGradCAM
|
||||
targets = [
|
||||
SemanticSegmentationTarget(category, mask_float,
|
||||
(height, width))
|
||||
SemanticSegmentationTarget(category, mask_float, (height, width))
|
||||
]
|
||||
with GradCAM(
|
||||
model=model,
|
||||
|
Loading…
x
Reference in New Issue
Block a user