mirror of https://github.com/open-mmlab/mmocr.git
parent 06b75780a0
commit 673aadc355

@@ -16,6 +16,7 @@ Welcome to MMOCR's documentation!
    textdet_models.md
    textrecog_models.md
    kie_models.md
+   ner_models.md

 .. toctree::
    :maxdepth: 2

@@ -25,7 +25,7 @@ class BertEncoder(nn.Module):
         attention_probs_dropout_prob (float): The dropout probability
             of attention.
         intermediate_size (int): The size of intermediate layer.
-        hidden_act (str): Hidden layer activation.
+        hidden_act_cfg (dict): Hidden layer activation config.
     """

     def __init__(self,

@@ -42,7 +42,7 @@ class BertEncoder(nn.Module):
                  num_attention_heads=12,
                  attention_probs_dropout_prob=0.1,
                  intermediate_size=3072,
-                 hidden_act='gelu_new',
+                 hidden_act_cfg=dict(type='GeluNew'),
                  pretrained=None):
         super().__init__()
         self.bert = BertModel(

@@ -59,7 +59,7 @@ class BertEncoder(nn.Module):
             num_attention_heads=num_attention_heads,
             attention_probs_dropout_prob=attention_probs_dropout_prob,
             intermediate_size=intermediate_size,
-            hidden_act=hidden_act)
+            hidden_act_cfg=hidden_act_cfg)
         self.init_weights(pretrained=pretrained)

     def forward(self, results):

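For callers of BertEncoder, the string-valued hidden_act argument becomes a config dict. A minimal construction sketch using only the defaults visible in the hunks above (the import path is an assumption and is not part of this diff):

from mmocr.models.ner.encoders import BertEncoder  # module path assumed, for illustration

encoder = BertEncoder(
    num_attention_heads=12,
    attention_probs_dropout_prob=0.1,
    intermediate_size=3072,
    hidden_act_cfg=dict(type='GeluNew'),  # was: hidden_act='gelu_new'
    pretrained=None)
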
@@ -1,4 +1,4 @@
-from .activations import ACT2FN
+from .activations import GeluNew
 from .bert import BertModel

-__all__ = ['BertModel', 'ACT2FN']
+__all__ = ['BertModel', 'GeluNew']

@@ -6,34 +6,26 @@
 import math

 import torch
-from mmcv.cnn import Swish
+import torch.nn as nn
+from mmcv.cnn import ACTIVATION_LAYERS


-def gelu(x):
-    """Original Implementation of the gelu activation function in Google Bert
-    repo when initially created. For information: OpenAI GPT's gelu is slightly
-    different (and gives slightly different results):
-
-        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) *
-                                  (x + 0.044715 * torch.pow(x, 3))))
-    Also see https://arxiv.org/abs/1606.08415
-    """
-    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
-
-
-def gelu_new(x):
+@ACTIVATION_LAYERS.register_module()
+class GeluNew(nn.Module):
     """Implementation of the gelu activation function currently in Google Bert
     repo (identical to OpenAI GPT).

     Also see https://arxiv.org/abs/1606.08415
     """
-    return 0.5 * x * (1 + torch.tanh(
-        math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
-
-
-ACT2FN = {
-    'gelu': gelu,
-    'relu': torch.nn.functional.relu,
-    'swish': Swish,
-    'gelu_new': gelu_new
-}
+
+    def forward(self, x):
+        """Forward function.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: Activated tensor.
+        """
+        return 0.5 * x * (1 + torch.tanh(
+            math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

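Registering GeluNew in mmcv's ACTIVATION_LAYERS replaces the hand-rolled ACT2FN table: the activation is now built from a config dict. A minimal usage sketch, assuming GeluNew has been imported (and therefore registered) via mmocr.models.ner.utils:

import torch
from mmcv.cnn import build_activation_layer

# Importing GeluNew runs the @ACTIVATION_LAYERS.register_module() decorator.
from mmocr.models.ner.utils import GeluNew  # noqa: F401

act = build_activation_layer(dict(type='GeluNew'))
y = act(torch.randn(2, 4))  # tanh-based GELU approximation, same formula as before

# mmcv pre-registers the standard torch activations, so a config such as
# dict(type='ReLU') is accepted by the same builder.
relu = build_activation_layer(dict(type='ReLU'))
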
@@ -7,8 +7,7 @@ import math

 import torch
 import torch.nn as nn
-
-from mmocr.models.ner.utils.activations import ACT2FN
+from mmcv.cnn import build_activation_layer


 class BertModel(nn.Module):

@@ -31,7 +30,7 @@ class BertModel(nn.Module):
             for the attention probabilities normalized from
             the attention scores.
         intermediate_size (int): The size of intermediate layer.
-        hidden_act (str): hidden layer activation
+        hidden_act_cfg (dict): Hidden layer activation config.
     """

     def __init__(self,

@@ -48,7 +47,7 @@ class BertModel(nn.Module):
                  num_attention_heads=12,
                  attention_probs_dropout_prob=0.1,
                  intermediate_size=3072,
-                 hidden_act='gelu_new'):
+                 hidden_act_cfg=dict(type='GeluNew')):
         super().__init__()
         self.embeddings = BertEmbeddings(
             vocab_size=vocab_size,

@@ -67,7 +66,7 @@ class BertModel(nn.Module):
             layer_norm_eps=layer_norm_eps,
             hidden_dropout_prob=hidden_dropout_prob,
             intermediate_size=intermediate_size,
-            hidden_act=hidden_act)
+            hidden_act_cfg=hidden_act_cfg)
         self.pooler = BertPooler(hidden_size=hidden_size)
         self.num_hidden_layers = num_hidden_layers
         self.initializer_range = initializer_range

@@ -205,7 +204,7 @@ class BertEncoder(nn.Module):
                  layer_norm_eps=1e-12,
                  hidden_dropout_prob=0.1,
                  intermediate_size=3072,
-                 hidden_act='gelu_new'):
+                 hidden_act_cfg=dict(type='GeluNew')):
         super().__init__()
         self.output_attentions = output_attentions
         self.output_hidden_states = output_hidden_states

@@ -218,7 +217,8 @@ class BertEncoder(nn.Module):
                 layer_norm_eps=layer_norm_eps,
                 hidden_dropout_prob=hidden_dropout_prob,
                 intermediate_size=intermediate_size,
-                hidden_act=hidden_act) for _ in range(num_hidden_layers)
+                hidden_act_cfg=hidden_act_cfg)
+            for _ in range(num_hidden_layers)
         ])

     def forward(self, hidden_states, attention_mask=None, head_mask=None):

@@ -278,7 +278,7 @@ class BertLayer(nn.Module):
                  layer_norm_eps=1e-12,
                  hidden_dropout_prob=0.1,
                  intermediate_size=3072,
-                 hidden_act='gelu_new'):
+                 hidden_act_cfg=dict(type='GeluNew')):
         super().__init__()
         self.attention = BertAttention(
             hidden_size=hidden_size,

@@ -290,7 +290,7 @@ class BertLayer(nn.Module):
         self.intermediate = BertIntermediate(
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
-            hidden_act=hidden_act)
+            hidden_act_cfg=hidden_act_cfg)
         self.output = BertOutput(
             intermediate_size=intermediate_size,
             hidden_size=hidden_size,

@@ -448,13 +448,11 @@ class BertIntermediate(nn.Module):
     def __init__(self,
                  hidden_size=768,
                  intermediate_size=3072,
-                 hidden_act='gelu_new'):
+                 hidden_act_cfg=dict(type='GeluNew')):
         super().__init__()
+
         self.dense = nn.Linear(hidden_size, intermediate_size)
-        if isinstance(hidden_act, str):
-            self.intermediate_act_fn = ACT2FN[hidden_act]
-        else:
-            self.intermediate_act_fn = hidden_act
+        self.intermediate_act_fn = build_activation_layer(hidden_act_cfg)

     def forward(self, hidden_states):
         hidden_states = self.dense(hidden_states)

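In BertIntermediate the if/else string dispatch collapses into a single build_activation_layer call. Roughly, that call resolves the config through the registry as sketched below (simplified; the real mmcv helper also validates the config and forwards extra keys as keyword arguments):

from mmcv.cnn import ACTIVATION_LAYERS

cfg = dict(type='GeluNew')
act_cls = ACTIVATION_LAYERS.get(cfg['type'])  # class registered by the decorator
intermediate_act_fn = act_cls()               # instantiated as a regular nn.Module

The lookup only succeeds once the module defining GeluNew has been imported, which mmocr.models.ner.utils now does in its __init__.py.
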
@@ -6,7 +6,6 @@ import pytest
 import torch

 from mmocr.models import build_detector
-from mmocr.models.ner.utils.activations import gelu, gelu_new


 def _create_dummy_vocab_file(vocab_file):

@@ -35,7 +34,7 @@ def _get_detector_cfg(fname):

 @pytest.mark.parametrize(
     'cfg_file', ['configs/ner/bert_softmax/bert_softmax_cluener_18e.py'])
-def test_encoder_decoder_pipeline(cfg_file):
+def test_bert_softmax(cfg_file):
     # prepare data
     texts = ['中'] * 47
     img = [31] * 47

@@ -77,7 +76,3 @@ def test_encoder_decoder_pipeline(cfg_file):
     batch_results = []
     result = detector.forward(None, img_metas, return_loss=False)
     batch_results.append(result)
-
-    # Test activations
-    gelu(torch.tensor(0.5))
-    gelu_new(torch.tensor(0.5))

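The dropped lines exercised the deleted gelu/gelu_new helpers. A comparable check against the registered module could look like the sketch below (test name and placement are assumptions; the reference expression is the formula from activations.py above):

import math

import torch
from mmcv.cnn import build_activation_layer

from mmocr.models.ner.utils import GeluNew  # noqa: F401  (registers the activation)


def test_gelu_new_activation():
    x = torch.randn(8)
    act = build_activation_layer(dict(type='GeluNew'))
    expected = 0.5 * x * (1 + torch.tanh(
        math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    assert torch.allclose(act(x), expected)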