# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------------
# Adapted from https://github.com/wl-zhao/VPD/blob/main/vpd/models.py
# Original licence: MIT License
# ------------------------------------------------------------------------------

import math
from typing import List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from ldm.modules.diffusionmodules.util import timestep_embedding
from ldm.util import instantiate_from_config
from mmengine.model import BaseModule
from mmengine.runner import CheckpointLoader, load_checkpoint

from mmseg.registry import MODELS
from mmseg.utils import ConfigType, OptConfigType


def register_attention_control(model, controller):
    """Registers a control function to manage attention within a model.

    Args:
        model: The model to which attention is to be registered.
        controller: The control function responsible for managing attention.
    """

    def ca_forward(self, place_in_unet):
        """Custom forward method for attention.

        Args:
            self: Reference to the current object.
            place_in_unet: The location in UNet (down/mid/up).

        Returns:
            The modified forward method.
        """

        def forward(x, context=None, mask=None):
            h = self.heads
            is_cross = context is not None
            # for self-attention (no context given), attend over the input
            context = x if context is None else context

            q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
            q, k, v = (
                tensor.view(tensor.shape[0] * h, tensor.shape[1],
                            tensor.shape[2] // h) for tensor in [q, k, v])

            sim = torch.matmul(q, k.transpose(-2, -1)) * self.scale

            if mask is not None:
                mask = mask.flatten(1).unsqueeze(1).repeat(h, 1, 1)
                max_neg_value = -torch.finfo(sim.dtype).max
                sim.masked_fill_(~mask, max_neg_value)

            attn = sim.softmax(dim=-1)
            # average over heads before handing the map to the controller
            attn_mean = attn.view(h, attn.shape[0] // h,
                                  *attn.shape[1:]).mean(0)
            controller(attn_mean, is_cross, place_in_unet)

            out = torch.matmul(attn, v)
            out = out.view(out.shape[0] // h, out.shape[1], out.shape[2] * h)
            return self.to_out(out)

        return forward

    def register_recr(net_, count, place_in_unet):
        """Recursive function to register the custom forward method to all
        CrossAttention layers.

        Args:
            net_: The network layer currently being processed.
            count: The current count of layers processed.
            place_in_unet: The location in UNet (down/mid/up).

        Returns:
            The updated count of layers processed.
        """
        if net_.__class__.__name__ == 'CrossAttention':
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        if hasattr(net_, 'children'):
            return sum(
                register_recr(child, 0, place_in_unet)
                for child in net_.children())
        return count

    # hook every CrossAttention layer in the down/mid/up blocks and count them
    cross_att_count = 0
    for name, child in model.diffusion_model.named_children():
        if 'input_blocks' in name:
            cross_att_count += register_recr(child, 0, 'down')
        elif 'output_blocks' in name:
            cross_att_count += register_recr(child, 0, 'up')
        elif 'middle_block' in name:
            cross_att_count += register_recr(child, 0, 'mid')

    controller.num_att_layers = cross_att_count


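# A minimal usage sketch for ``register_attention_control`` (illustrative
# only, not part of the upstream API). It assumes ``sd_unet`` is a Stable
# Diffusion UNet wrapper exposing ``.diffusion_model``, as expected above,
# and uses the ``AttentionStore`` defined below as the controller:
#
#     store = AttentionStore(base_size=64)
#     register_attention_control(sd_unet, store)
#     _ = sd_unet.diffusion_model(latents, timesteps, context)
#     # the hooked layers have now pushed their cross-/self-attention maps
#     # into ``store.step_store``.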
class AttentionStore:
    """A class for storing attention information in the UNet model.

    Attributes:
        base_size (int): Base size for storing attention information.
        max_size (int): Maximum size for storing attention information.
    """

    def __init__(self, base_size=64, max_size=None):
        """Initialize AttentionStore with default or custom sizes."""
        self.reset()
        self.base_size = base_size
        self.max_size = max_size or (base_size // 2)
        self.num_att_layers = -1

    @staticmethod
    def get_empty_store():
        """Returns an empty store for holding attention values."""
        return {
            key: []
            for key in [
                'down_cross', 'mid_cross', 'up_cross', 'down_self',
                'mid_self', 'up_self'
            ]
        }

    def reset(self):
        """Resets the step and attention stores to their initial states."""
        self.cur_step = 0
        self.cur_att_layer = 0
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Processes a single forward step, storing the attention.

        Args:
            attn: The attention tensor.
            is_cross (bool): Whether it's cross attention.
            place_in_unet (str): The location in UNet (down/mid/up).

        Returns:
            The unmodified attention tensor.
        """
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[1] <= self.max_size**2:
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        """Processes and stores attention information between steps."""
        if not self.attention_store:
            self.attention_store = self.step_store
        else:
            for key in self.attention_store:
                self.attention_store[key] = [
                    stored + step for stored, step in zip(
                        self.attention_store[key], self.step_store[key])
                ]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Returns the attention maps currently held in ``step_store``, keyed
        by location and attention type."""
        return {
            key: [item for item in self.step_store[key]]
            for key in self.step_store
        }

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Allows the class instance to be callable."""
        return self.forward(attn, is_cross, place_in_unet)

    @property
    def num_uncond_att_layers(self):
        """Returns the number of unconditional attention layers (default is
        0)."""
        return 0

    def step_callback(self, x_t):
        """A placeholder for a step callback.

        Returns the input unchanged.
        """
        return x_t


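# Illustrative sketch of the AttentionStore lifecycle (assumptions: a single
# denoising step, layers already hooked via ``register_attention_control``):
#
#     store = AttentionStore(base_size=64, max_size=32)
#     store.reset()                        # clear the per-step buffers
#     # ... run the UNet; each hooked CrossAttention calls ``store(...)`` ...
#     maps = store.get_average_attention()
#     # maps['up_cross'] / maps['down_cross'] hold the per-layer attention
#     # tensors that UNetWrapper later reshapes and concatenates.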
class UNetWrapper(nn.Module):
    """A wrapper for UNet with optional attention mechanisms.

    Args:
        unet (nn.Module): The UNet model to wrap.
        use_attn (bool): Whether to use attention. Defaults to True.
        base_size (int): Base size for the attention store. Defaults to 512.
        max_attn_size (int, optional): Maximum size for the attention store.
            Defaults to None.
        attn_selector (str): The types of attention to use.
            Defaults to 'up_cross+down_cross'.
    """

    def __init__(self,
                 unet,
                 use_attn=True,
                 base_size=512,
                 max_attn_size=None,
                 attn_selector='up_cross+down_cross'):
        super().__init__()
        self.unet = unet
        self.attention_store = AttentionStore(
            base_size=base_size // 8, max_size=max_attn_size)
        self.attn_selector = attn_selector.split('+')
        self.use_attn = use_attn
        self.init_sizes(base_size)
        if self.use_attn:
            register_attention_control(unet, self.attention_store)

    def init_sizes(self, base_size):
        """Initialize sizes based on the base size."""
        self.size16 = base_size // 32
        self.size32 = base_size // 16
        self.size64 = base_size // 8

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """Forward pass through the model."""
        diffusion_model = self.unet.diffusion_model
        if self.use_attn:
            self.attention_store.reset()
        hs, emb, out_list = self._unet_forward(x, timesteps, context, y,
                                               diffusion_model)
        if self.use_attn:
            self._append_attn_to_output(out_list)
        return out_list[::-1]

    def _unet_forward(self, x, timesteps, context, y, diffusion_model):
        hs = []
        t_emb = timestep_embedding(
            timesteps, diffusion_model.model_channels, repeat_only=False)
        emb = diffusion_model.time_embed(t_emb)
        h = x.type(diffusion_model.dtype)
        for module in diffusion_model.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = diffusion_model.middle_block(h, emb, context)
        out_list = []
        for i_out, module in enumerate(diffusion_model.output_blocks):
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
            if i_out in [1, 4, 7]:
                out_list.append(h)
        h = h.type(x.dtype)
        out_list.append(h)
        return hs, emb, out_list

    def _append_attn_to_output(self, out_list):
        avg_attn = self.attention_store.get_average_attention()
        attns = {self.size16: [], self.size32: [], self.size64: []}
        for k in self.attn_selector:
            for up_attn in avg_attn[k]:
                size = int(math.sqrt(up_attn.shape[1]))
                up_attn = up_attn.transpose(-1, -2).reshape(
                    *up_attn.shape[:2], size, -1)
                attns[size].append(up_attn)
        attn16 = torch.stack(attns[self.size16]).mean(0)
        attn32 = torch.stack(attns[self.size32]).mean(0)
        attn64 = torch.stack(attns[self.size64]).mean(0) if len(
            attns[self.size64]) > 0 else None
        out_list[1] = torch.cat([out_list[1], attn16], dim=1)
        out_list[2] = torch.cat([out_list[2], attn32], dim=1)
        if attn64 is not None:
            out_list[3] = torch.cat([out_list[3], attn64], dim=1)


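# Minimal usage sketch for ``UNetWrapper`` (illustrative; ``sd_model`` is an
# assumed Stable Diffusion LatentDiffusion instance created elsewhere, and
# ``latents`` / ``text_embeddings`` are placeholders):
#
#     wrapper = UNetWrapper(sd_model.model, use_attn=True, base_size=512)
#     t = torch.ones((latents.shape[0], ), device=latents.device).long()
#     feats = wrapper(latents, t, context=text_embeddings)
#     # ``feats`` is a coarse-to-fine list of multi-scale UNet feature maps,
#     # with cross-attention maps concatenated to the matching scales.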
class TextAdapter(nn.Module):
    """A PyTorch Module that serves as a text adapter.

    This module takes text embeddings and adjusts them based on a scaling
    factor gamma.
    """

    def __init__(self, text_dim=768):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(text_dim, text_dim), nn.GELU(),
            nn.Linear(text_dim, text_dim))

    def forward(self, texts, gamma):
        texts_after = self.fc(texts)
        texts = texts + gamma * texts_after
        return texts


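# Sketch of the residual text adaptation performed above (shapes are
# illustrative assumptions, e.g. per-class text embeddings of dim 768):
#
#     adapter = TextAdapter(text_dim=768)
#     gamma = nn.Parameter(torch.ones(768) * 1e-4)
#     texts = torch.randn(150, 768)
#     adapted = adapter(texts, gamma)      # texts + gamma * MLP(texts)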
@MODELS.register_module()
class VPD(BaseModule):
    """VPD (Visual Perception Diffusion) model.

    .. _`VPD`: https://arxiv.org/abs/2303.02153

    Args:
        diffusion_cfg (dict): Configuration for diffusion model.
        class_embed_path (str): Path for class embeddings.
        unet_cfg (dict, optional): Configuration for U-Net.
        gamma (float, optional): Gamma for text adaptation. Defaults to 1e-4.
        class_embed_select (bool, optional): If True, enables class embedding
            selection. Defaults to False.
        pad_shape (Optional[Union[int, List[int]]], optional): Padding shape.
            Defaults to None.
        pad_val (Union[int, List[int]], optional): Padding value.
            Defaults to 0.
        init_cfg (dict, optional): Configuration for network initialization.
    """

    def __init__(self,
                 diffusion_cfg: ConfigType,
                 class_embed_path: str,
                 unet_cfg: OptConfigType = dict(),
                 gamma: float = 1e-4,
                 class_embed_select=False,
                 pad_shape: Optional[Union[int, List[int]]] = None,
                 pad_val: Union[int, List[int]] = 0,
                 init_cfg: OptConfigType = None):

        super().__init__(init_cfg=init_cfg)

        if pad_shape is not None:
            if not isinstance(pad_shape, (list, tuple)):
                pad_shape = (pad_shape, pad_shape)

        self.pad_shape = pad_shape
        self.pad_val = pad_val

        # diffusion model
        diffusion_checkpoint = diffusion_cfg.pop('checkpoint', None)
        sd_model = instantiate_from_config(diffusion_cfg)
        if diffusion_checkpoint is not None:
            load_checkpoint(sd_model, diffusion_checkpoint, strict=False)

        self.encoder_vq = sd_model.first_stage_model
        self.unet = UNetWrapper(sd_model.model, **unet_cfg)

        # class embeddings & text adapter
        class_embeddings = CheckpointLoader.load_checkpoint(class_embed_path)
        text_dim = class_embeddings.size(-1)
        self.text_adapter = TextAdapter(text_dim=text_dim)
        self.class_embed_select = class_embed_select
        if class_embed_select:
            class_embeddings = torch.cat(
                (class_embeddings, class_embeddings.mean(dim=0,
                                                         keepdims=True)),
                dim=0)
        self.register_buffer('class_embeddings', class_embeddings)
        self.gamma = nn.Parameter(torch.ones(text_dim) * gamma)

    def forward(self, x):
        """Extract features from images."""

        # calculate cross-attn map
        if self.class_embed_select:
            if isinstance(x, (tuple, list)):
                x, class_ids = x[:2]
                class_ids = class_ids.tolist()
            else:
                class_ids = [-1] * x.size(0)
            class_embeddings = self.class_embeddings[class_ids]
            c_crossattn = self.text_adapter(class_embeddings, self.gamma)
            c_crossattn = c_crossattn.unsqueeze(1)
        else:
            class_embeddings = self.class_embeddings
            c_crossattn = self.text_adapter(class_embeddings, self.gamma)
            c_crossattn = c_crossattn.unsqueeze(0).repeat(x.size(0), 1, 1)

        # pad to required input shape for pretrained diffusion model
        if self.pad_shape is not None:
            pad_width = max(0, self.pad_shape[1] - x.shape[-1])
            pad_height = max(0, self.pad_shape[0] - x.shape[-2])
            x = F.pad(x, (0, pad_width, 0, pad_height), value=self.pad_val)

        # forward the denoising model
        with torch.no_grad():
            latents = self.encoder_vq.encode(x).mode().detach()
        t = torch.ones((x.shape[0], ), device=x.device).long()
        outs = self.unet(latents, t, context=c_crossattn)

        return outs
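
# Illustrative config sketch for building VPD through the mmseg registry
# (all paths and the diffusion ``target``/params are placeholder assumptions,
# not values shipped with mmseg):
#
#     backbone = dict(
#         type='VPD',
#         diffusion_cfg=dict(
#             target='ldm.models.diffusion.ddpm.LatentDiffusion',
#             checkpoint='path/to/stable_diffusion.ckpt',
#             params=dict(...)),
#         class_embed_path='path/to/class_embeddings.pth',
#         unet_cfg=dict(use_attn=True, base_size=512),
#         gamma=1e-4,
#         pad_shape=512)
#     model = MODELS.build(backbone)
#     feats = model(torch.randn(1, 3, 512, 512))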