Hbsun/feature iss205 (#210)

* fix #205: remove act2fn

* fix pytest
Hongbin Sun 2021-05-18 15:15:35 +08:00 committed by GitHub
parent 06b75780a0
commit 673aadc355
6 changed files with 33 additions and 47 deletions

View File

@@ -16,6 +16,7 @@ Welcome to MMOCR's documentation!
    textdet_models.md
    textrecog_models.md
    kie_models.md
+   ner_models.md

 .. toctree::
    :maxdepth: 2

View File

@@ -25,7 +25,7 @@ class BertEncoder(nn.Module):
         attention_probs_dropout_prob (float): The dropout probability
             of attention.
         intermediate_size (int): The size of intermediate layer.
-        hidden_act (str): Hidden layer activation.
+        hidden_act_cfg (dict): Hidden layer activation.
     """

     def __init__(self,
@@ -42,7 +42,7 @@ class BertEncoder(nn.Module):
                  num_attention_heads=12,
                  attention_probs_dropout_prob=0.1,
                  intermediate_size=3072,
-                 hidden_act='gelu_new',
+                 hidden_act_cfg=dict(type='GeluNew'),
                  pretrained=None):
         super().__init__()
         self.bert = BertModel(
@@ -59,7 +59,7 @@ class BertEncoder(nn.Module):
             num_attention_heads=num_attention_heads,
             attention_probs_dropout_prob=attention_probs_dropout_prob,
             intermediate_size=intermediate_size,
-            hidden_act=hidden_act)
+            hidden_act_cfg=hidden_act_cfg)
         self.init_weights(pretrained=pretrained)

     def forward(self, results):
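With this change, callers select the encoder's activation through an mmcv-style config dict instead of a string key. A minimal sketch of the caller-side difference, assuming the import path below and that the remaining constructor arguments keep their defaults:

```python
# Sketch of the caller-side change; the import path is an assumption and the
# other constructor arguments are left at their defaults.
from mmocr.models.ner.encoders.bert_encoder import BertEncoder

# Before this commit the activation was selected by a string key:
#   encoder = BertEncoder(hidden_act='gelu_new')
# After, it is an mmcv activation config resolved through ACTIVATION_LAYERS:
encoder = BertEncoder(hidden_act_cfg=dict(type='GeluNew'))
```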

View File

@@ -1,4 +1,4 @@
-from .activations import ACT2FN
+from .activations import GeluNew
 from .bert import BertModel

-__all__ = ['BertModel', 'ACT2FN']
+__all__ = ['BertModel', 'GeluNew']

View File

@@ -6,34 +6,26 @@
 import math
 import torch
-from mmcv.cnn import Swish
+import torch.nn as nn
+from mmcv.cnn import ACTIVATION_LAYERS


-def gelu(x):
-    """Original Implementation of the gelu activation function in Google Bert
-    repo when initially created. For information: OpenAI GPT's gelu is slightly
-    different (and gives slightly different results):
-    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) *
-    (x + 0.044715 * torch.pow(x, 3))))
-    Also see https://arxiv.org/abs/1606.08415
-    """
-    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
-
-
-def gelu_new(x):
+@ACTIVATION_LAYERS.register_module()
+class GeluNew(nn.Module):
     """Implementation of the gelu activation function currently in Google Bert
     repo (identical to OpenAI GPT).
     Also see https://arxiv.org/abs/1606.08415
     """
-    return 0.5 * x * (1 + torch.tanh(
-        math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
-
-ACT2FN = {
-    'gelu': gelu,
-    'relu': torch.nn.functional.relu,
-    'swish': Swish,
-    'gelu_new': gelu_new
-}
+    def forward(self, x):
+        """Forward function.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: Activated tensor.
+        """
+        return 0.5 * x * (1 + torch.tanh(
+            math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
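The `gelu`/`gelu_new` helpers and the `ACT2FN` table are replaced by a single `GeluNew` module registered in mmcv's `ACTIVATION_LAYERS`, so it can be built from a config dict like any other activation. A short sketch of the new lookup path:

```python
# Sketch: the activation is now resolved through mmcv's registry instead of
# the removed ACT2FN dict.
import torch
from mmcv.cnn import build_activation_layer

from mmocr.models.ner.utils import GeluNew  # importing this runs the registration

act = build_activation_layer(dict(type='GeluNew'))
assert isinstance(act, GeluNew)
out = act(torch.randn(4))  # behaves like any other nn.Module activation
```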

View File

@@ -7,8 +7,7 @@
 import math
 import torch
 import torch.nn as nn
-from mmocr.models.ner.utils.activations import ACT2FN
+from mmcv.cnn import build_activation_layer


 class BertModel(nn.Module):
@@ -31,7 +30,7 @@ class BertModel(nn.Module):
             for the attention probabilities normalized from
             the attention scores.
         intermediate_size (int): The size of intermediate layer.
-        hidden_act (str): hidden layer activation
+        hidden_act_cfg (dict): Hidden layer activation.
     """

     def __init__(self,
@@ -48,7 +47,7 @@ class BertModel(nn.Module):
                  num_attention_heads=12,
                  attention_probs_dropout_prob=0.1,
                  intermediate_size=3072,
-                 hidden_act='gelu_new'):
+                 hidden_act_cfg=dict(type='GeluNew')):
         super().__init__()
         self.embeddings = BertEmbeddings(
             vocab_size=vocab_size,
@@ -67,7 +66,7 @@ class BertModel(nn.Module):
            layer_norm_eps=layer_norm_eps,
            hidden_dropout_prob=hidden_dropout_prob,
            intermediate_size=intermediate_size,
-           hidden_act=hidden_act)
+           hidden_act_cfg=hidden_act_cfg)
         self.pooler = BertPooler(hidden_size=hidden_size)

         self.num_hidden_layers = num_hidden_layers
         self.initializer_range = initializer_range
@@ -205,7 +204,7 @@ class BertEncoder(nn.Module):
                  layer_norm_eps=1e-12,
                  hidden_dropout_prob=0.1,
                  intermediate_size=3072,
-                 hidden_act='gelu_new'):
+                 hidden_act_cfg=dict(type='GeluNew')):
         super().__init__()
         self.output_attentions = output_attentions
         self.output_hidden_states = output_hidden_states
@@ -218,7 +217,8 @@ class BertEncoder(nn.Module):
                 layer_norm_eps=layer_norm_eps,
                 hidden_dropout_prob=hidden_dropout_prob,
                 intermediate_size=intermediate_size,
-                hidden_act=hidden_act) for _ in range(num_hidden_layers)
+                hidden_act_cfg=hidden_act_cfg)
+            for _ in range(num_hidden_layers)
         ])

     def forward(self, hidden_states, attention_mask=None, head_mask=None):
@@ -278,7 +278,7 @@ class BertLayer(nn.Module):
                  layer_norm_eps=1e-12,
                  hidden_dropout_prob=0.1,
                  intermediate_size=3072,
-                 hidden_act='gelu_new'):
+                 hidden_act_cfg=dict(type='GeluNew')):
         super().__init__()
         self.attention = BertAttention(
             hidden_size=hidden_size,
@@ -290,7 +290,7 @@ class BertLayer(nn.Module):
         self.intermediate = BertIntermediate(
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
-            hidden_act=hidden_act)
+            hidden_act_cfg=hidden_act_cfg)
         self.output = BertOutput(
             intermediate_size=intermediate_size,
             hidden_size=hidden_size,
@@ -448,13 +448,11 @@ class BertIntermediate(nn.Module):
     def __init__(self,
                  hidden_size=768,
                  intermediate_size=3072,
-                 hidden_act='gelu_new'):
+                 hidden_act_cfg=dict(type='GeluNew')):
         super().__init__()
         self.dense = nn.Linear(hidden_size, intermediate_size)
-        if isinstance(hidden_act, str):
-            self.intermediate_act_fn = ACT2FN[hidden_act]
-        else:
-            self.intermediate_act_fn = hidden_act
+        self.intermediate_act_fn = build_activation_layer(hidden_act_cfg)

     def forward(self, hidden_states):
         hidden_states = self.dense(hidden_states)
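Because `BertIntermediate` now calls `build_activation_layer`, the activation is no longer limited to the keys that lived in `ACT2FN`: any activation registered in mmcv's `ACTIVATION_LAYERS` can be selected from config. A hedged sketch of the pattern; the class name is an illustrative stand-in, and `ReLU` is used only as an example of another registered type:

```python
# Sketch: the string->function lookup becomes a registry build, so swapping the
# activation is a config change rather than a code change.
import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer


class IntermediateSketch(nn.Module):
    """Illustrative stand-in for BertIntermediate after this change."""

    def __init__(self, hidden_size=768, intermediate_size=3072,
                 hidden_act_cfg=dict(type='GeluNew')):
        super().__init__()
        self.dense = nn.Linear(hidden_size, intermediate_size)
        self.intermediate_act_fn = build_activation_layer(hidden_act_cfg)

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))


# Any registered activation config can now be injected, e.g. ReLU:
relu_block = IntermediateSketch(hidden_act_cfg=dict(type='ReLU'))
out = relu_block(torch.randn(2, 768))
```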

View File

@@ -6,7 +6,6 @@
 import pytest
 import torch

 from mmocr.models import build_detector
-from mmocr.models.ner.utils.activations import gelu, gelu_new


 def _create_dummy_vocab_file(vocab_file):
@@ -35,7 +34,7 @@ def _get_detector_cfg(fname):
 @pytest.mark.parametrize(
     'cfg_file', ['configs/ner/bert_softmax/bert_softmax_cluener_18e.py'])
-def test_encoder_decoder_pipeline(cfg_file):
+def test_bert_softmax(cfg_file):
     # prepare data
     texts = [''] * 47
     img = [31] * 47
@@ -77,7 +76,3 @@ def test_encoder_decoder_pipeline(cfg_file):
     batch_results = []
     result = detector.forward(None, img_metas, return_loss=False)
     batch_results.append(result)
-
-    # Test activations
-    gelu(torch.tensor(0.5))
-    gelu_new(torch.tensor(0.5))
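Removing the `gelu`/`gelu_new` calls leaves the new activation without a direct unit test in this file. A possible follow-up, not part of this commit, could exercise `GeluNew` against the tanh-approximation formula it implements; the test name below is hypothetical:

```python
# Hypothetical follow-up test: verify GeluNew matches the tanh-based GELU
# approximation used in its forward pass.
import math

import torch

from mmocr.models.ner.utils import GeluNew


def test_gelu_new_activation():
    act = GeluNew()
    x = torch.tensor([-1.0, 0.0, 0.5, 2.0])
    expected = 0.5 * x * (1 + torch.tanh(
        math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    assert torch.allclose(act(x), expected)
```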