[Feature]: Register CrissCrossAttention into plugin layers (#1199)

pull/1199/head
Ye Liu 2021-07-13 14:08:57 +08:00 committed by GitHub
parent 96c4b70ccb
commit 44e19ff68c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 2 additions and 1 deletion

View File

@@ -3,7 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.function import once_differentiable
from mmcv.cnn import Scale
from mmcv.cnn import PLUGIN_LAYERS, Scale
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
@@ -66,6 +66,7 @@ ca_weight = CAWeightFunction.apply
ca_map = CAMapFunction.apply
@PLUGIN_LAYERS.register_module()
class CrissCrossAttention(nn.Module):
"""Criss-Cross Attention Module."""