import copy
import warnings
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks import Linear
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.utils import ConfigDict
from torch import Tensor

from easycv.framework.errors import RuntimeError
from easycv.models.builder import (build_attention, build_feedforward_network,
                                   build_transformer_layer)
from easycv.models.registry import (ATTENTION, FEEDFORWARD_NETWORK,
                                    TRANSFORMER_LAYER,
                                    TRANSFORMER_LAYER_SEQUENCE)
from easycv.models.utils.activation import build_activation_layer
from easycv.models.utils.norm import build_norm_layer


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN).

    Stacks ``num_layers`` ``nn.Linear`` layers. ReLU is applied after every
    layer except the last one.

    Args:
        input_dim (int): Input feature size of the first layer.
        hidden_dim (int): Feature size of every intermediate layer.
        output_dim (int): Output feature size of the last layer.
        num_layers (int): Total number of linear layers.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            # No activation after the final projection.
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


class Mlp(nn.Module):
    """Two-layer perceptron: fc1 -> activation -> dropout -> fc2 -> dropout.

    Args:
        in_features (int): Input feature size.
        hidden_features (int, optional): Hidden feature size. Falls back to
            ``in_features`` when not given.
        out_features (int, optional): Output feature size. Falls back to
            ``in_features`` when not given.
        act_layer: Callable building the activation module, default
            ``nn.GELU``.
        drop (float): Dropout probability used after both linear layers.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class ConvMlp(nn.Module):
    """Same layout as :class:`Mlp`, but built from 1x1 ``nn.Conv2d`` layers.

    Args:
        in_features (int): Number of input channels.
        hidden_features (int, optional): Number of hidden channels. Falls
            back to ``in_features`` when not given.
        out_features (int, optional): Number of output channels. Falls back
            to ``in_features`` when not given.
        act_layer: Callable building the activation module, default
            ``nn.GELU``.
        drop (float): Dropout probability used after both convolutions.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Randomly zero whole samples of ``x`` and rescale the survivors.

    Args:
        x (Tensor): Input tensor; the first dimension is the one that is
            dropped per entry.
        drop_prob (float | None): Probability of dropping an entry. ``None``
            and ``0`` both disable dropping.
        training (bool): Dropping is only active when this is True.

    Returns:
        Tensor: ``x`` itself when dropping is disabled, otherwise
        ``x / keep_prob`` multiplied by a per-entry binary mask.
    """
    # ``DropPath`` defaults ``drop_prob`` to None, so None must be accepted
    # here instead of failing on ``1 - None``.
    if drop_prob is None or drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # One mask value per entry of dim 0, broadcast over the remaining dims:
    # work with diff dim tensors, not just 2D ConvNets
    shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(
        shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path
    of residual blocks).

    Args:
        drop_prob (float | None): Probability handed to :func:`drop_path`.
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self):
        return 'p={}'.format(self.drop_prob)


class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` clones of ``encoder_layer``.

    Args:
        encoder_layer (nn.Module): Layer that is deep-copied ``num_layers``
            times.
        num_layers (int): Number of layers in the stack.
        norm (nn.Module, optional): Module applied to the final output.
        d_model (int): Feature size used to build the query-scale MLP.
        query_scale_type (str, optional): When ``'cond_elewise'``, a 2-layer
            :class:`MLP` computes a scale from the current output which is
            multiplied onto ``pos`` before every layer.
    """

    def __init__(self,
                 encoder_layer,
                 num_layers,
                 norm=None,
                 d_model=256,
                 query_scale_type=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.query_scale_type = query_scale_type
        if query_scale_type == 'cond_elewise':
            self.query_scale = MLP(d_model, d_model, d_model, 2)
        self.norm = norm

    def forward(self,
                src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        """Run ``src`` through every layer, then through ``norm`` if set.

        Args:
            src (Tensor): Input of the first layer.
            mask (Tensor, optional): Forwarded to each layer as ``src_mask``.
            src_key_padding_mask (Tensor, optional): Forwarded to each layer.
            pos (Tensor, optional): Positional encoding. May be None, in
                which case None is forwarded to each layer.

        Returns:
            Tensor: Output of the last layer (normalized if ``norm`` is set).
        """
        output = src
        for layer in self.layers:
            # rescale the content and pos sim
            if self.query_scale_type == 'cond_elewise':
                pos_scales = self.query_scale(output)
            else:
                pos_scales = 1
            # ``pos`` is optional; only scale it when it was provided.
            scaled_pos = None if pos is None else pos * pos_scales
            output = layer(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=scaled_pos)

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerEncoderLayer(nn.Module):
    """Self-attention followed by a feed-forward network, each with a
    residual connection and a ``LayerNorm`` applied after the addition.

    Args:
        d_model (int): Feature size.
        nhead (int): Number of attention heads.
        dim_feedforward (int): Hidden size of the feed-forward network.
        dropout (float): Dropout probability used in attention, in the
            feed-forward network and on both residual branches.
        activation (str): Name resolved by :func:`_get_activation_fn`.
        normalize_before (bool): Stored on the instance.
            NOTE(review): ``forward`` does not read this flag; normalization
            is always applied after the residual addition.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward=2048,
                 dropout=0.1,
                 activation='relu',
                 normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` unchanged if pos is None."""
        return tensor if pos is None else tensor + pos

    def forward(self,
                src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        """Apply self-attention and the feed-forward network to ``src``.

        The positional encoding is added to query and key only; the value
        is the raw ``src``.
        """
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(
            q,
            k,
            value=src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src


def _get_clones(module, N, layer_share=False):
    """Return an ``nn.ModuleList`` holding ``module`` ``N`` times.

    With ``layer_share=True`` the very same instance is repeated (weights
    are shared); otherwise each entry is an independent deep copy.
    """
    if layer_share:
        return nn.ModuleList([module for _ in range(N)])
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def _get_activation_fn(activation):
    """Return an activation function given a string.

    Supported names: ``relu``, ``gelu``, ``glu``, ``prelu``, ``selu``.
    ``prelu`` yields a new ``nn.PReLU`` module, the others yield the
    matching function from ``torch.nn.functional``.

    Raises:
        RuntimeError: If ``activation`` is not one of the supported names.
    """
    if activation == 'relu':
        return F.relu
    if activation == 'gelu':
        return F.gelu
    if activation == 'glu':
        return F.glu
    if activation == 'prelu':
        return nn.PReLU()
    if activation == 'selu':
        return F.selu

    raise RuntimeError(
        f'activation should be relu/gelu/glu/prelu/selu, not {activation}.')


@FEEDFORWARD_NETWORK.register_module()
class FFN(BaseModule):
    """Implements feed-forward networks (FFNs) with identity connection.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`. Defaults: 256.
        feedforward_channels (int): The hidden dimension of FFNs.
            Defaults: 1024.
        num_fcs (int, optional): The number of fully-connected layers in
            FFNs. Default: 2.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='ReLU', inplace=True)
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default 0.0.
        add_identity (bool, optional): Whether to add the
            identity connection. Default: `True`.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims=256,
                 feedforward_channels=1024,
                 num_fcs=2,
                 act_cfg=dict(type='ReLU', inplace=True),
                 ffn_drop=0.,
                 dropout_layer=None,
                 add_identity=True,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg)
        assert num_fcs >= 2, 'num_fcs should be no less ' \
            f'than 2. got {num_fcs}.'
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.num_fcs = num_fcs
        self.act_cfg = act_cfg
        self.activate = build_activation_layer(act_cfg)

        layers = []
        in_channels = embed_dims
        # ``num_fcs - 1`` hidden blocks (Linear -> activation -> dropout) ...
        for _ in range(num_fcs - 1):
            layers.append(
                Sequential(
                    Linear(in_channels, feedforward_channels), self.activate,
                    nn.Dropout(ffn_drop)))
            in_channels = feedforward_channels
        # ... followed by the projection back to ``embed_dims``.
        layers.append(Linear(feedforward_channels, embed_dims))
        layers.append(nn.Dropout(ffn_drop))
        self.layers = Sequential(*layers)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else torch.nn.Identity()
        self.add_identity = add_identity

    def forward(self, x, identity=None):
        """Forward function for `FFN`.

        When ``add_identity`` is True the result is added to ``identity``;
        ``x`` is used as the identity if ``identity`` is None.
        """
        out = self.layers(x)
        if not self.add_identity:
            return self.dropout_layer(out)
        if identity is None:
            identity = x
        return identity + self.dropout_layer(out)


@TRANSFORMER_LAYER.register_module()
class BaseTransformerLayer(BaseModule):
    """Base `TransformerLayer` for vision transformer.

    It can be built from `mmcv.ConfigDict` and support more flexible
    customization, for example, using any number of `FFN or LN ` and
    use different kinds of `attention` by specifying a list of `ConfigDict`
    named `attn_cfgs`. It is worth mentioning that it supports `prenorm`
    when you specifying `norm` as the first element of `operation_order`.
    More details about the `prenorm`: `On Layer Normalization in the
    Transformer Architecture <https://arxiv.org/abs/2002.04745>`_ .

    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
            Configs for `self_attention` or `cross_attention` modules,
            The order of the configs in the list should be consistent with
            corresponding attentions in operation_order.
            If it is a dict, all of the attention modules in operation_order
            will be built with this config. Default: None.
        ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
            Configs for FFN, The order of the configs in the list should be
            consistent with corresponding ffn in operation_order.
            If it is a dict, all of the ffn modules in operation_order
            will be built with this config.
        operation_order (tuple[str]): The execution order of operation
            in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
            Support `prenorm` when you specifying first element as `norm`.
            Default: None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): Key, Query and Value are shape
            of (batch, n, embed_dim)
            or (n, batch, embed_dim). Default to False.
    """

    def __init__(self,
                 attn_cfgs=None,
                 ffn_cfgs=dict(
                     type='FFN',
                     embed_dims=256,
                     feedforward_channels=1024,
                     num_fcs=2,
                     ffn_drop=0.,
                     act_cfg=dict(type='ReLU', inplace=True),
                 ),
                 operation_order=None,
                 norm_cfg=dict(type='LN'),
                 init_cfg=None,
                 batch_first=False,
                 **kwargs):

        # The deprecated-argument handling below writes into ``ffn_cfgs``.
        # Work on a private copy so that neither the shared default dict of
        # this signature nor a dict owned by the caller is modified.
        if isinstance(ffn_cfgs, dict):
            ffn_cfgs = copy.deepcopy(ffn_cfgs)

        deprecated_args = dict(
            feedforward_channels='feedforward_channels',
            ffn_dropout='ffn_drop',
            ffn_num_fcs='num_fcs')
        for ori_name, new_name in deprecated_args.items():
            if ori_name in kwargs:
                warnings.warn(
                    f'The arguments `{ori_name}` in BaseTransformerLayer '
                    f'has been deprecated, now you should set `{new_name}` '
                    f'and other FFN related arguments '
                    f'to a dict named `ffn_cfgs`. ', DeprecationWarning)
                ffn_cfgs[new_name] = kwargs[ori_name]

        super().__init__(init_cfg)

        self.batch_first = batch_first

        # Every entry of ``operation_order`` must be one of the four known
        # operation names.
        assert set(operation_order) & {
            'self_attn', 'norm', 'ffn', 'cross_attn'} == \
            set(operation_order), f'The operation_order of' \
            f' {self.__class__.__name__} should ' \
            f'contains all four operation type ' \
            f"{['self_attn', 'norm', 'ffn', 'cross_attn']}"

        num_attn = operation_order.count('self_attn') + operation_order.count(
            'cross_attn')
        if isinstance(attn_cfgs, dict):
            attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
        else:
            assert num_attn == len(attn_cfgs), f'The length ' \
                f'of attn_cfg {num_attn} is ' \
                f'not consistent with the number of attention ' \
                f'in operation_order {operation_order}.'

        self.num_attn = num_attn
        self.operation_order = operation_order
        self.norm_cfg = norm_cfg
        self.pre_norm = operation_order[0] == 'norm'
        self.attentions = ModuleList()

        index = 0
        for operation_name in operation_order:
            if operation_name in ['self_attn', 'cross_attn']:
                # The layer-level ``batch_first`` must agree with (or is
                # written into) every attention config.
                if 'batch_first' in attn_cfgs[index]:
                    assert self.batch_first == attn_cfgs[index]['batch_first']
                else:
                    attn_cfgs[index]['batch_first'] = self.batch_first
                attention = build_attention(attn_cfgs[index])
                # Some custom attentions used as `self_attn`
                # or `cross_attn` can have different behavior.
                attention.operation_name = operation_name
                self.attentions.append(attention)
                index += 1

        self.embed_dims = self.attentions[0].embed_dims

        self.ffns = ModuleList()
        num_ffns = operation_order.count('ffn')
        if isinstance(ffn_cfgs, dict):
            ffn_cfgs = ConfigDict(ffn_cfgs)
        if isinstance(ffn_cfgs, dict):
            ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
        assert len(ffn_cfgs) == num_ffns
        for ffn_index in range(num_ffns):
            if 'embed_dims' not in ffn_cfgs[ffn_index]:
                ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
            else:
                assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
            self.ffns.append(
                build_feedforward_network(ffn_cfgs[ffn_index],
                                          dict(type='FFN')))

        self.norms = ModuleList()
        num_norms = operation_order.count('norm')
        for _ in range(num_norms):
            self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])

    def forward(self,
                query,
                key=None,
                value=None,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `TransformerDecoderLayer`.

        **kwargs contains some specific arguments of attentions.

        Args:
            query (Tensor): The input query with shape
                [num_queries, bs, embed_dims] if
                self.batch_first is False, else
                [bs, num_queries embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims] .
            value (Tensor): The value tensor with same shape as `key`.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            attn_masks (List[Tensor] | None): 2D Tensor used in
                calculation of corresponding attention. The length of
                it should equal to the number of `attention` in
                `operation_order`. A single Tensor is copied for every
                attention. Default: None.
            query_key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_queries]. Only used in `self_attn` layer.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor for `key`, with
                shape [bs, num_keys]. Default: None.

        Returns:
            Tensor: forwarded results with shape [num_queries, bs, embed_dims].
        """

        norm_index = 0
        attn_index = 0
        ffn_index = 0
        identity = query
        if attn_masks is None:
            attn_masks = [None for _ in range(self.num_attn)]
        elif isinstance(attn_masks, torch.Tensor):
            attn_masks = [
                copy.deepcopy(attn_masks) for _ in range(self.num_attn)
            ]
            warnings.warn(f'Use same attn_mask in all attentions in '
                          f'{self.__class__.__name__} ')
        else:
            assert len(attn_masks) == self.num_attn, f'The length of ' \
                f'attn_masks {len(attn_masks)} must be equal ' \
                f'to the number of attention in ' \
                f'operation_order {self.num_attn}'

        for layer in self.operation_order:
            if layer == 'self_attn':
                # Self-attention: key and value are the query itself and
                # the query positional encoding doubles as key encoding.
                temp_key = temp_value = query
                query = self.attentions[attn_index](
                    query,
                    temp_key,
                    temp_value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=query_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=query_key_padding_mask,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'norm':
                query = self.norms[norm_index](query)
                norm_index += 1

            elif layer == 'cross_attn':
                query = self.attentions[attn_index](
                    query=query,
                    key=key,
                    value=value,
                    identity=identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=key_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=key_padding_mask,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'ffn':
                query = self.ffns[ffn_index](
                    query, identity if self.pre_norm else None)
                ffn_index += 1

        return query


@TRANSFORMER_LAYER_SEQUENCE.register_module()
class TransformerLayerSequence(BaseModule):
    """Base class for TransformerEncoder and TransformerDecoder in vision
    transformer.

    As base-class of Encoder and Decoder in vision transformer.
    Support customization such as specifying different kind
    of `transformer_layer` in `transformer_coder`.

    Args:
        transformerlayers (list[obj:`mmcv.ConfigDict`] |
            obj:`mmcv.ConfigDict`): Config of transformerlayer
            in TransformerCoder. If it is obj:`mmcv.ConfigDict`,
            it would be repeated `num_layer` times to a
            list[`mmcv.ConfigDict`]. Default: None.
        num_layers (int): The number of `TransformerLayer`. Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
        super().__init__(init_cfg)
        if isinstance(transformerlayers, dict):
            transformerlayers = [
                copy.deepcopy(transformerlayers) for _ in range(num_layers)
            ]
        else:
            assert isinstance(transformerlayers, list) and \
                   len(transformerlayers) == num_layers
        self.num_layers = num_layers
        self.layers = ModuleList()
        for i in range(num_layers):
            self.layers.append(build_transformer_layer(transformerlayers[i]))
        # Both attributes are taken from the first layer of the sequence.
        self.embed_dims = self.layers[0].embed_dims
        self.pre_norm = self.layers[0].pre_norm

    def forward(self,
                query,
                key,
                value,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `TransformerCoder`.

        Args:
            query (Tensor): Input query with shape
                `(num_queries, bs, embed_dims)`.
            key (Tensor): The key tensor with shape
                `(num_keys, bs, embed_dims)`.
            value (Tensor): The value tensor with shape
                `(num_keys, bs, embed_dims)`.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            attn_masks (List[Tensor], optional): Each element is 2D Tensor
                which is used in calculation of corresponding attention in
                operation_order. Default: None.
            query_key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_queries]. Only used in self-attention
                Default: None.
            key_padding_mask (Tensor): ByteTensor for `key`, with
                shape [bs, num_keys]. Default: None.

        Returns:
            Tensor:  results with shape [num_queries, bs, embed_dims].
        """
        # Only ``query`` is updated; every layer sees the same key/value.
        for layer in self.layers:
            query = layer(
                query,
                key,
                value,
                query_pos=query_pos,
                key_pos=key_pos,
                attn_masks=attn_masks,
                query_key_padding_mask=query_key_padding_mask,
                key_padding_mask=key_padding_mask,
                **kwargs)
        return query


@ATTENTION.register_module()
class MultiheadAttention(BaseModule):
    """A wrapper for ``torch.nn.MultiheadAttention``.

    This module implements MultiheadAttention with identity connection,
    and positional encoding  is also passed as input.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Default: 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): When it is True,  Key, Query and Value are shape of
            (batch, n, embed_dim), otherwise (n, batch, embed_dim).
             Default to False.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=dict(type='Dropout', drop_prob=0.),
                 init_cfg=None,
                 batch_first=False,
                 **kwargs):
        super().__init__(init_cfg)
        if 'dropout' in kwargs:
            warnings.warn(
                'The arguments `dropout` in MultiheadAttention '
                'has been deprecated, now you can separately '
                'set `attn_drop`(float), proj_drop(float), '
                'and `dropout_layer`(dict) ', DeprecationWarning)
            attn_drop = kwargs['dropout']
            # Copy before writing: ``dropout_layer`` may be the shared
            # default dict of this signature or a dict owned by the caller.
            dropout_layer = copy.deepcopy(dropout_layer)
            dropout_layer['drop_prob'] = kwargs.pop('dropout')

        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.batch_first = batch_first

        self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
                                          **kwargs)

        self.proj_drop = nn.Dropout(proj_drop)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else nn.Identity()

    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_pos=None,
                attn_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `MultiheadAttention`.

        **kwargs allow passing a more general data flow when combining
        with other operations in `transformerlayer`.

        Args:
            query (Tensor): The input query with shape [num_queries, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_queries embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims] .
                If None, the ``query`` will be used. Defaults to None.
            value (Tensor): The value tensor with same shape as `key`.
                Same in `nn.MultiheadAttention.forward`. Defaults to None.
                If None, the `key` will be used.
            identity (Tensor): This tensor, with the same shape as x,
                will be used for the identity link.
                If None, `x` will be used. Defaults to None.
            query_pos (Tensor): The positional encoding for query, with
                the same shape as `x`. If not None, it will
                be added to `x` before forward function. Defaults to None.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`. Defaults to None. If not None, it will
                be added to `key` before forward function. If None, and
                `query_pos` has the same shape as `key`, then `query_pos`
                will be used for `key_pos`. Defaults to None.
            attn_mask (Tensor): ByteTensor mask with shape [num_queries,
                num_keys]. Same in `nn.MultiheadAttention.forward`.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
                Defaults to None.

        Returns:
            Tensor: forwarded results with shape
            [num_queries, bs, embed_dims]
            if self.batch_first is False, else
            [bs, num_queries embed_dims].
        """

        # Defaults are resolved before any positional encoding is added, so
        # ``value`` and ``identity`` stay free of positional information.
        if key is None:
            key = query
        if value is None:
            value = key
        if identity is None:
            identity = query
        if key_pos is None:
            if query_pos is not None:
                # use query_pos if key_pos is not available
                if query_pos.shape == key.shape:
                    key_pos = query_pos
                else:
                    warnings.warn(f'position encoding of key is '
                                  f'missing in {self.__class__.__name__}.')
        if query_pos is not None:
            query = query + query_pos
        if key_pos is not None:
            key = key + key_pos

        # Because the dataflow('key', 'query', 'value') of
        # ``torch.nn.MultiheadAttention`` is (num_query, batch,
        # embed_dims), We should adjust the shape of dataflow from
        # batch_first (batch, num_query, embed_dims) to num_query_first
        # (num_query ,batch, embed_dims), and recover ``attn_output``
        # from num_query_first to batch_first.
        if self.batch_first:
            query = query.transpose(0, 1)
            key = key.transpose(0, 1)
            value = value.transpose(0, 1)

        out = self.attn(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask)[0]

        if self.batch_first:
            out = out.transpose(0, 1)

        return identity + self.dropout_layer(self.proj_drop(out))