mirror of https://github.com/open-mmlab/mmcv.git
[Feature]: Register CrissCrossAttention into plugin layers (#1189)
parent
96c4b70ccb
commit
44e19ff68c
|
@ -3,7 +3,7 @@ import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.autograd.function import once_differentiable
|
from torch.autograd.function import once_differentiable
|
||||||
|
|
||||||
from mmcv.cnn import Scale
|
from mmcv.cnn import PLUGIN_LAYERS, Scale
|
||||||
from ..utils import ext_loader
|
from ..utils import ext_loader
|
||||||
|
|
||||||
ext_module = ext_loader.load_ext(
|
ext_module = ext_loader.load_ext(
|
||||||
|
@ -66,6 +66,7 @@ ca_weight = CAWeightFunction.apply
|
||||||
ca_map = CAMapFunction.apply
|
ca_map = CAMapFunction.apply
|
||||||
|
|
||||||
|
|
||||||
|
@PLUGIN_LAYERS.register_module()
|
||||||
class CrissCrossAttention(nn.Module):
|
class CrissCrossAttention(nn.Module):
|
||||||
"""Criss-Cross Attention Module."""
|
"""Criss-Cross Attention Module."""
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue