mirror of https://github.com/open-mmlab/mmcv.git
[Feature]: Register CrissCrossAttention into plugin layers (#1189)
parent
96c4b70ccb
commit
44e19ff68c
mmcv/ops
|
@ -3,7 +3,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
from torch.autograd.function import once_differentiable
|
||||
|
||||
from mmcv.cnn import Scale
|
||||
from mmcv.cnn import PLUGIN_LAYERS, Scale
|
||||
from ..utils import ext_loader
|
||||
|
||||
ext_module = ext_loader.load_ext(
|
||||
|
@ -66,6 +66,7 @@ ca_weight = CAWeightFunction.apply
|
|||
ca_map = CAMapFunction.apply
|
||||
|
||||
|
||||
@PLUGIN_LAYERS.register_module()
|
||||
class CrissCrossAttention(nn.Module):
|
||||
"""Criss-Cross Attention Module."""
|
||||
|
||||
|
|
Loading…
Reference in New Issue