diff --git a/mmcv/ops/cc_attention.py b/mmcv/ops/cc_attention.py index 6f59d29fd..53efff9e6 100644 --- a/mmcv/ops/cc_attention.py +++ b/mmcv/ops/cc_attention.py @@ -3,7 +3,7 @@ import torch.nn as nn import torch.nn.functional as F from torch.autograd.function import once_differentiable -from mmcv.cnn import Scale +from mmcv.cnn import PLUGIN_LAYERS, Scale from ..utils import ext_loader ext_module = ext_loader.load_ext( @@ -66,6 +66,7 @@ ca_weight = CAWeightFunction.apply ca_map = CAMapFunction.apply +@PLUGIN_LAYERS.register_module() class CrissCrossAttention(nn.Module): """Criss-Cross Attention Module."""