Mirror of https://github.com/open-mmlab/mmsegmentation.git
[Fix] fix import error raised by ldm (#3338)
commit 9c45a94cee
parent 56a40d78ab
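The error reported in #3338 comes from `ldm` (the latent/stable-diffusion package) being imported unconditionally at module level, so any environment without it failed as soon as the VPD code was imported. The change below moves those imports into a try/except that records availability in a module-level `has_ldm` flag and defers the hard failure to the constructors of the classes that actually need the package. A minimal sketch of the same guard pattern, with the hypothetical names `some_optional_pkg`, `has_pkg` and `NeedsPkg` standing in for the real ones:

try:
    # optional dependency: only needed by one feature of the library
    from some_optional_pkg import something
    has_pkg = True
except ImportError:
    # keep this module importable even when the extra package is absent
    has_pkg = False


class NeedsPkg:

    def __init__(self):
        # fail late, with an actionable message, only when the feature is used
        assert has_pkg, ('To use NeedsPkg, please install the optional '
                         'requirements first.')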
@@ -10,14 +10,19 @@ from typing import List, Optional, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from ldm.modules.diffusionmodules.util import timestep_embedding
-from ldm.util import instantiate_from_config
 from mmengine.model import BaseModule
 from mmengine.runner import CheckpointLoader, load_checkpoint
 
 from mmseg.registry import MODELS
 from mmseg.utils import ConfigType, OptConfigType
 
+try:
+    from ldm.modules.diffusionmodules.util import timestep_embedding
+    from ldm.util import instantiate_from_config
+    has_ldm = True
+except ImportError:
+    has_ldm = False
+
 
 def register_attention_control(model, controller):
     """Registers a control function to manage attention within a model.
@@ -205,6 +210,10 @@ class UNetWrapper(nn.Module):
                  max_attn_size=None,
                  attn_selector='up_cross+down_cross'):
         super().__init__()
+
+        assert has_ldm, 'To use UNetWrapper, please install required ' \
+            'packages via `pip install -r requirements/optional.txt`.'
+
         self.unet = unet
         self.attention_store = AttentionStore(
             base_size=base_size // 8, max_size=max_attn_size)
@@ -321,6 +330,9 @@ class VPD(BaseModule):
 
         super().__init__(init_cfg=init_cfg)
 
+        assert has_ldm, 'To use VPD model, please install required packages' \
+            ' via `pip install -r requirements/optional.txt`.'
+
         if pad_shape is not None:
             if not isinstance(pad_shape, (list, tuple)):
                 pad_shape = (pad_shape, pad_shape)
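With the guard in place, importing the backbone module no longer requires `ldm`; only constructing `UNetWrapper` or `VPD` does. A quick check of the new behaviour, assuming the guarded module lives at `mmseg.models.backbones.vpd` (the file path is not shown in this view):

# should import cleanly even in an environment without the `ldm` package
from mmseg.models.backbones import vpd

print(vpd.has_ldm)  # False when `ldm` is missing; building VPD then asserts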
@@ -9,12 +9,12 @@ from argparse import ArgumentParser
 import numpy as np
 import torch
 import torch.nn.functional as F
+from mmengine import Config
 from mmengine.model import revert_sync_batchnorm
 from PIL import Image
-from pytorch_grad_cam import GradCAM, LayerCAM, XGradCAM, GradCAMPlusPlus, EigenCAM, EigenGradCAM
+from pytorch_grad_cam import GradCAM
 from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image
 
-from mmengine import Config
 from mmseg.apis import inference_model, init_model, show_result_pyplot
 from mmseg.utils import register_all_modules
 
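The import in the CAM script is trimmed to `GradCAM`, the only method the script instantiates, and `from mmengine import Config` moves up next to the other third-party imports. The comment later in the file still lists the alternative CAM methods; to use one of them, import it and swap it into the `with` block. A sketch, assuming pytorch_grad_cam's usual interface (constructor takes `model` and `target_layers`, the object is callable with `input_tensor` and `targets`); it reuses the script's `model` and `targets`, while `target_layers` and `input_tensor` are illustrative names for the layer list and preprocessed image the script feeds to the CAM object:

from pytorch_grad_cam import LayerCAM  # or XGradCAM, GradCAMPlusPlus, EigenCAM, EigenGradCAM

with LayerCAM(model=model, target_layers=target_layers) as cam:
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]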
@@ -56,21 +56,15 @@ def main():
         default='prediction.png',
         help='Path to output prediction file')
     parser.add_argument(
-        '--cam-file',
-        default='vis_cam.png',
-        help='Path to output cam file')
+        '--cam-file', default='vis_cam.png', help='Path to output cam file')
     parser.add_argument(
         '--target-layers',
         default='backbone.layer4[2]',
         help='Target layers to visualize CAM')
     parser.add_argument(
-        '--category-index',
-        default='7',
-        help='Category to visualize CAM')
+        '--category-index', default='7', help='Category to visualize CAM')
     parser.add_argument(
-        '--device',
-        default='cuda:0',
-        help='Device used for inference')
+        '--device', default='cuda:0', help='Device used for inference')
     args = parser.parse_args()
 
     # build the model from a config file and a checkpoint file
@@ -116,8 +110,7 @@ def main():
     # Grad CAM(Class Activation Maps)
     # Can also be LayerCAM, XGradCAM, GradCAMPlusPlus, EigenCAM, EigenGradCAM
     targets = [
-        SemanticSegmentationTarget(category, mask_float,
-                                   (height, width))
+        SemanticSegmentationTarget(category, mask_float, (height, width))
     ]
     with GradCAM(
             model=model,
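For completeness, the grayscale map produced inside the `with` block is typically overlaid on the input and written to the path from `--cam-file` roughly like this (a sketch, assuming `image` is the input as a float RGB array in [0, 1] and `grayscale_cam` is the (H, W) map returned by the CAM call; both names are illustrative):

from PIL import Image
from pytorch_grad_cam.utils.image import show_cam_on_image

# blend the heatmap with the original image and save the visualisation
cam_image = show_cam_on_image(image, grayscale_cam, use_rgb=True)
Image.fromarray(cam_image).save(args.cam_file)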