mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
Thanks for your contribution; we appreciate it a lot. The following instructions will make your pull request healthier and help it get feedback more easily. If you do not understand some items, don't worry: just make the pull request and seek help from the maintainers. ## Motivation Support the depth estimation algorithm [VPD](https://github.com/wl-zhao/VPD). ## Modification 1. Add the VPD backbone. 2. Add the VPD decoder head for depth estimation. 3. Add a new segmentor, `DepthEstimator`, based on `EncoderDecoder`, for depth estimation. 4. Add an integrated metric that calculates the common metrics used in depth estimation. 5. Add the SiLog loss for depth estimation. 6. Add a config for VPD. ## BC-breaking (Optional) Does the modification introduce changes that break the backward compatibility of the downstream repos? If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR. ## Use cases (Optional) If this PR introduces a new feature, it is better to list some use cases here and update the documentation. ## Checklist 1. Pre-commit or other linting tools are used to fix potential lint issues. 2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure correctness. 3. If the modification has a potential influence on downstream projects, this PR should be tested with downstream projects, like MMDet or MMDet3D. 4. The documentation has been modified accordingly, like docstrings or example tutorials.
384 lines
14 KiB
Python
384 lines
14 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
# ------------------------------------------------------------------------------
|
|
# Adapted from https://github.com/wl-zhao/VPD/blob/main/vpd/models.py
|
|
# Original licence: MIT License
|
|
# ------------------------------------------------------------------------------
|
|
|
|
import math
|
|
from typing import List, Optional, Union
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from ldm.modules.diffusionmodules.util import timestep_embedding
|
|
from ldm.util import instantiate_from_config
|
|
from mmengine.model import BaseModule
|
|
from mmengine.runner import CheckpointLoader, load_checkpoint
|
|
|
|
from mmseg.registry import MODELS
|
|
from mmseg.utils import ConfigType, OptConfigType
|
|
|
|
|
|
def register_attention_control(model, controller):
    """Patch the ``CrossAttention`` layers of a UNet to report attention.

    Every module whose class is named ``CrossAttention`` and that lives under
    a child of ``model.diffusion_model`` whose name contains
    ``input_blocks``, ``output_blocks`` or ``middle_block`` gets its
    ``forward`` replaced. The replacement computes multi-head attention and
    hands the head-averaged attention probabilities to ``controller``.

    Args:
        model: Object exposing a ``diffusion_model`` module.
        controller: Callable invoked as
            ``controller(attn, is_cross, place_in_unet)``. Its
            ``num_att_layers`` attribute is set to the number of patched
            layers.
    """

    def ca_forward(self, place_in_unet):
        """Build the replacement forward for one attention layer.

        Args:
            self: The attention module being patched. It must provide
                ``heads``, ``scale``, ``to_q``, ``to_k``, ``to_v`` and
                ``to_out``.
            place_in_unet: The location in UNet (down/mid/up).

        Returns:
            The replacement forward function.
        """

        def forward(x, context=None, mask=None):
            h = self.heads
            is_cross = context is not None
            # ``context or x`` would evaluate ``bool(context)``, which raises
            # for a tensor holding more than one element.
            if context is None:
                context = x

            def split_heads(tensor):
                # (b, n, h * d) -> (b * h, n, d). The channel axis is split
                # into heads; a bare ``view`` would split the token axis
                # instead.
                b, n, c = tensor.shape
                tensor = tensor.reshape(b, n, h, c // h).permute(0, 2, 1, 3)
                return tensor.reshape(b * h, n, c // h)

            q = split_heads(self.to_q(x))
            k = split_heads(self.to_k(context))
            v = split_heads(self.to_v(context))

            sim = torch.matmul(q, k.transpose(-2, -1)) * self.scale

            if mask is not None:
                # One copy of each sample's mask per head, in the same
                # (batch, head) order as the rows of ``sim``.
                mask = mask.flatten(1).unsqueeze(1).repeat_interleave(
                    h, dim=0)
                max_neg_value = -torch.finfo(sim.dtype).max
                sim.masked_fill_(~mask, max_neg_value)

            attn = sim.softmax(dim=-1)
            batch = attn.shape[0] // h
            # average the attention probabilities over the heads
            attn_mean = attn.reshape(batch, h, *attn.shape[1:]).mean(1)
            controller(attn_mean, is_cross, place_in_unet)

            out = torch.matmul(attn, v)
            # (b * h, n, d) -> (b, n, h * d)
            out = out.reshape(batch, h, *out.shape[1:]).permute(0, 2, 1, 3)
            out = out.reshape(batch, out.shape[1], -1)
            return self.to_out(out)

        return forward

    def register_recr(net_, count, place_in_unet):
        """Recursively patch all ``CrossAttention`` layers below ``net_``.

        Args:
            net_: The network layer currently being processed.
            count: The current count of layers processed.
            place_in_unet: The location in UNet (down/mid/up).

        Returns:
            ``count`` plus the number of layers patched below ``net_``.
        """
        if net_.__class__.__name__ == 'CrossAttention':
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        if hasattr(net_, 'children'):
            return count + sum(
                register_recr(child, 0, place_in_unet)
                for child in net_.children())
        return count

    # Checked in this order; the first matching substring decides the place.
    places = (('input_blocks', 'down'), ('output_blocks', 'up'),
              ('middle_block', 'mid'))

    cross_att_count = 0
    for name, child in model.diffusion_model.named_children():
        for key, place in places:
            if key in name:
                # Recurse into the whole child. Indexing it with ``[1]``
                # would only visit its second element.
                cross_att_count += register_recr(child, 0, place)
                break

    controller.num_att_layers = cross_att_count
|
|
|
|
|
|
class AttentionStore:
    """Collects the attention maps reported by patched attention layers.

    Maps are grouped under ``'<place>_<kind>'`` keys, where ``place`` is
    ``down``/``mid``/``up`` and ``kind`` is ``cross``/``self``.

    Attributes:
        base_size (int): Base size for storing attention information.
        max_size (int): Maps whose second dimension exceeds ``max_size ** 2``
            are not stored.
    """

    def __init__(self, base_size=64, max_size=None):
        """Set up the sizes and start from an empty store.

        ``max_size`` falls back to ``base_size // 2`` when it is falsy.
        """
        self.base_size = base_size
        self.max_size = max_size if max_size else base_size // 2
        self.num_att_layers = -1
        self.reset()

    @staticmethod
    def get_empty_store():
        """Return a fresh dict with one empty list per attention group."""
        group_names = ('down_cross', 'mid_cross', 'up_cross', 'down_self',
                       'mid_self', 'up_self')
        store = {}
        for name in group_names:
            store[name] = []
        return store

    def reset(self):
        """Clear the counters, the per-step store and the accumulated store."""
        self.cur_step = 0
        self.cur_att_layer = 0
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Record one attention map in the per-step store.

        Args:
            attn: The attention tensor.
            is_cross (bool): Whether it's cross attention.
            place_in_unet (str): The location in UNet (down/mid/up).

        Returns:
            ``attn`` itself, untouched.
        """
        kind = 'cross' if is_cross else 'self'
        small_enough = attn.shape[1] <= self.max_size**2
        if small_enough:
            self.step_store[f'{place_in_unet}_{kind}'].append(attn)
        return attn

    def between_steps(self):
        """Fold the per-step store into the accumulated store.

        The first call adopts the per-step store as the accumulated store;
        later calls add the new maps element-wise to the stored ones. The
        per-step store is emptied afterwards.
        """
        if self.attention_store:
            for key, stored_maps in self.attention_store.items():
                step_maps = self.step_store[key]
                self.attention_store[key] = [
                    old + new for old, new in zip(stored_maps, step_maps)
                ]
        else:
            self.attention_store = self.step_store
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Return a copy of the per-step store (new lists, same items)."""
        return {key: list(maps) for key, maps in self.step_store.items()}

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Forward the call to :meth:`forward`."""
        return self.forward(attn, is_cross, place_in_unet)

    @property
    def num_uncond_att_layers(self):
        """Number of unconditional attention layers; always 0."""
        return 0

    def step_callback(self, x_t):
        """Placeholder step callback that hands ``x_t`` back unchanged."""
        return x_t
|
|
|
|
|
|
class UNetWrapper(nn.Module):
    """A wrapper for UNet with optional attention mechanisms.

    The wrapper runs the wrapped diffusion UNet and returns intermediate
    decoder features. When attention is enabled, the stored attention maps
    are concatenated to those features along the channel dimension.

    Args:
        unet (nn.Module): The UNet model to wrap. It must expose a
            ``diffusion_model`` attribute.
        use_attn (bool): Whether to use attention. Defaults to True
        base_size (int): Base size for the attention store. Defaults to 512
        max_attn_size (int, optional): Maximum size for the attention store.
            Defaults to None
        attn_selector (str): The types of attention to use, joined by ``+``.
            Defaults to 'up_cross+down_cross'
    """

    def __init__(self,
                 unet,
                 use_attn=True,
                 base_size=512,
                 max_attn_size=None,
                 attn_selector='up_cross+down_cross'):
        super().__init__()
        self.unet = unet
        self.attention_store = AttentionStore(
            base_size=base_size // 8, max_size=max_attn_size)
        self.attn_selector = attn_selector.split('+')
        self.use_attn = use_attn
        self.init_sizes(base_size)
        if self.use_attn:
            register_attention_control(unet, self.attention_store)

    def init_sizes(self, base_size):
        """Initialize sizes based on the base size."""
        self.size16 = base_size // 32
        self.size32 = base_size // 16
        self.size64 = base_size // 8

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """Forward pass through the model.

        Returns:
            list: The collected feature maps in reverse collection order,
            i.e. the final UNet output comes first.
        """
        diffusion_model = self.unet.diffusion_model
        if self.use_attn:
            # start every pass with an empty attention store
            self.attention_store.reset()
        hs, emb, out_list = self._unet_forward(x, timesteps, context, y,
                                               diffusion_model)
        if self.use_attn:
            self._append_attn_to_output(out_list)
        return out_list[::-1]

    def _unet_forward(self, x, timesteps, context, y, diffusion_model):
        """Run the UNet blocks and collect intermediate decoder outputs.

        The outputs of output blocks 1, 4 and 7 are collected, followed by
        the final output cast back to the dtype of ``x``. ``y`` is accepted
        but not used.

        Returns:
            tuple: ``(hs, emb, out_list)`` - the remaining skip features,
            the time embedding and the collected outputs.
        """
        hs = []
        t_emb = timestep_embedding(
            timesteps, diffusion_model.model_channels, repeat_only=False)
        emb = diffusion_model.time_embed(t_emb)
        h = x.type(diffusion_model.dtype)
        for module in diffusion_model.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = diffusion_model.middle_block(h, emb, context)
        out_list = []
        for i_out, module in enumerate(diffusion_model.output_blocks):
            # skip connection from the matching input block
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
            if i_out in [1, 4, 7]:
                out_list.append(h)
        h = h.type(x.dtype)
        out_list.append(h)
        return hs, emb, out_list

    def _append_attn_to_output(self, out_list):
        """Concatenate stored attention maps to ``out_list`` in place.

        Each selected map of shape ``(b, size * size, c)`` is turned into
        ``(b, c, size, size)``, grouped by ``size`` and averaged per group.
        The ``size16`` and ``size32`` groups are appended to ``out_list[1]``
        and ``out_list[2]``; the ``size64`` group is appended to
        ``out_list[3]`` only when it is non-empty.
        """
        avg_attn = self.attention_store.get_average_attention()
        attns = {self.size16: [], self.size32: [], self.size64: []}
        for k in self.attn_selector:
            for up_attn in avg_attn[k]:
                size = int(math.sqrt(up_attn.shape[1]))
                # Move channels ahead of the tokens first, then unfold the
                # tokens into a square. The target shape must come from the
                # transposed tensor, not from the original one.
                channel_first = up_attn.transpose(-1, -2)
                up_attn = channel_first.reshape(*channel_first.shape[:2],
                                                size, -1)
                attns[size].append(up_attn)
        attn16 = torch.stack(attns[self.size16]).mean(0)
        attn32 = torch.stack(attns[self.size32]).mean(0)
        attn64 = torch.stack(attns[self.size64]).mean(0) if len(
            attns[self.size64]) > 0 else None
        out_list[1] = torch.cat([out_list[1], attn16], dim=1)
        out_list[2] = torch.cat([out_list[2], attn32], dim=1)
        if attn64 is not None:
            out_list[3] = torch.cat([out_list[3], attn64], dim=1)
|
|
|
|
|
|
class TextAdapter(nn.Module):
    """Residual MLP adapter for text embeddings.

    A two-layer MLP (Linear, GELU, Linear) refines the embeddings, and the
    refinement is added back to the input scaled by ``gamma``.
    """

    def __init__(self, text_dim=768):
        super().__init__()
        layers = [
            nn.Linear(text_dim, text_dim),
            nn.GELU(),
            nn.Linear(text_dim, text_dim),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, texts, gamma):
        """Return ``texts + gamma * fc(texts)``."""
        residual = self.fc(texts)
        return texts + gamma * residual
|
|
|
|
|
|
@MODELS.register_module()
class VPD(BaseModule):
    """VPD (Visual Perception Diffusion) model.

    .. _`VPD`: https://arxiv.org/abs/2303.02153

    Args:
        diffusion_cfg (dict): Configuration for diffusion model. An optional
            ``checkpoint`` entry names the weights to load into the
            instantiated model. The passed config is not modified.
        class_embed_path (str): Path for class embeddings.
        unet_cfg (dict, optional): Keyword arguments forwarded to
            ``UNetWrapper``. ``None`` is treated as an empty dict.
        gamma (float, optional): Gamma for text adaptation. Defaults to 1e-4.
        class_embed_select (bool, optional): If True, enables class embedding
            selection. Defaults to False.
        pad_shape (Optional[Union[int, List[int]]], optional): Padding shape
            as ``(height, width)``; a single int is used for both.
            Defaults to None.
        pad_val (Union[int, List[int]], optional): Padding value.
            Defaults to 0.
        init_cfg (dict, optional): Configuration for network initialization.
    """

    def __init__(self,
                 diffusion_cfg: ConfigType,
                 class_embed_path: str,
                 unet_cfg: OptConfigType = dict(),
                 gamma: float = 1e-4,
                 class_embed_select=False,
                 pad_shape: Optional[Union[int, List[int]]] = None,
                 pad_val: Union[int, List[int]] = 0,
                 init_cfg: OptConfigType = None):

        super().__init__(init_cfg=init_cfg)

        if pad_shape is not None:
            if not isinstance(pad_shape, (list, tuple)):
                pad_shape = (pad_shape, pad_shape)

        self.pad_shape = pad_shape
        self.pad_val = pad_val

        # diffusion model
        # Pop ``checkpoint`` from a shallow copy so the caller's config keeps
        # the entry and can be dumped or used to build the model again.
        diffusion_cfg = diffusion_cfg.copy()
        diffusion_checkpoint = diffusion_cfg.pop('checkpoint', None)
        sd_model = instantiate_from_config(diffusion_cfg)
        if diffusion_checkpoint is not None:
            load_checkpoint(sd_model, diffusion_checkpoint, strict=False)

        self.encoder_vq = sd_model.first_stage_model
        # ``unet_cfg`` is optional, so guard against an explicit ``None``
        self.unet = UNetWrapper(sd_model.model, **(unet_cfg or {}))

        # class embeddings & text adapter
        class_embeddings = CheckpointLoader.load_checkpoint(class_embed_path)
        text_dim = class_embeddings.size(-1)
        self.text_adapter = TextAdapter(text_dim=text_dim)
        self.class_embed_select = class_embed_select
        if class_embed_select:
            # Append the mean embedding as the last row; index -1 selects it
            # when no class ids are given.
            class_embeddings = torch.cat(
                (class_embeddings, class_embeddings.mean(dim=0,
                                                         keepdim=True)),
                dim=0)
        self.register_buffer('class_embeddings', class_embeddings)
        self.gamma = nn.Parameter(torch.ones(text_dim) * gamma)

    def forward(self, x):
        """Extract features from images.

        Args:
            x: The input images. When ``class_embed_select`` is True, ``x``
                may instead be a tuple or list whose first two items are the
                images and the class ids.

        Returns:
            The output of the wrapped UNet.
        """

        # calculate cross-attn map
        if self.class_embed_select:
            if isinstance(x, (tuple, list)):
                x, class_ids = x[:2]
                class_ids = class_ids.tolist()
            else:
                # no ids given: use the last (mean) embedding for each sample
                class_ids = [-1] * x.size(0)
            class_embeddings = self.class_embeddings[class_ids]
            c_crossattn = self.text_adapter(class_embeddings, self.gamma)
            c_crossattn = c_crossattn.unsqueeze(1)
        else:
            class_embeddings = self.class_embeddings
            c_crossattn = self.text_adapter(class_embeddings, self.gamma)
            c_crossattn = c_crossattn.unsqueeze(0).repeat(x.size(0), 1, 1)

        # pad to required input shape for pretrained diffusion model
        if self.pad_shape is not None:
            pad_width = max(0, self.pad_shape[1] - x.shape[-1])
            pad_height = max(0, self.pad_shape[0] - x.shape[-2])
            # pad only on the right and bottom edges
            x = F.pad(x, (0, pad_width, 0, pad_height), value=self.pad_val)

        # forward the denoising model
        with torch.no_grad():
            # the image encoder runs without gradient tracking
            latents = self.encoder_vq.encode(x).mode().detach()
        t = torch.ones((x.shape[0], ), device=x.device).long()
        outs = self.unet(latents, t, context=c_crossattn)

        return outs
|