Hbsun/feature iss205 (#210)

* fix #205: remove act2fn

* fix pytest
Hongbin Sun 2021-05-18 15:15:35 +08:00 committed by GitHub
parent 06b75780a0
commit 673aadc355
6 changed files with 33 additions and 47 deletions
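The refactor replaces the hand-rolled ACT2FN string-to-function lookup with mmcv's ACTIVATION_LAYERS registry, so activation choices are expressed as config dicts and instantiated with build_activation_layer. A minimal, self-contained sketch of that mmcv (1.x) pattern; the MyAct class below is a toy example, not part of this PR:

import torch
import torch.nn as nn
from mmcv.cnn import ACTIVATION_LAYERS, build_activation_layer


# Any nn.Module registered like this becomes selectable by name
# through a config dict such as dict(type='MyAct').
@ACTIVATION_LAYERS.register_module()
class MyAct(nn.Module):

    def forward(self, x):
        return torch.clamp(x, min=0)  # toy activation for illustration


act = build_activation_layer(dict(type='MyAct'))
print(act(torch.tensor([-1.0, 2.0])))  # tensor([0., 2.])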

View File

@@ -16,6 +16,7 @@ Welcome to MMOCR's documentation!
    textdet_models.md
    textrecog_models.md
    kie_models.md
+   ner_models.md

 .. toctree::
    :maxdepth: 2

View File

@ -25,7 +25,7 @@ class BertEncoder(nn.Module):
attention_probs_dropout_prob (float): The dropout probability attention_probs_dropout_prob (float): The dropout probability
of attention. of attention.
intermediate_size (int): The size of intermediate layer. intermediate_size (int): The size of intermediate layer.
hidden_act (str): Hidden layer activation. hidden_act_cfg (dict): Hidden layer activation.
""" """
def __init__(self, def __init__(self,
@ -42,7 +42,7 @@ class BertEncoder(nn.Module):
num_attention_heads=12, num_attention_heads=12,
attention_probs_dropout_prob=0.1, attention_probs_dropout_prob=0.1,
intermediate_size=3072, intermediate_size=3072,
hidden_act='gelu_new', hidden_act_cfg=dict(type='GeluNew'),
pretrained=None): pretrained=None):
super().__init__() super().__init__()
self.bert = BertModel( self.bert = BertModel(
@ -59,7 +59,7 @@ class BertEncoder(nn.Module):
num_attention_heads=num_attention_heads, num_attention_heads=num_attention_heads,
attention_probs_dropout_prob=attention_probs_dropout_prob, attention_probs_dropout_prob=attention_probs_dropout_prob,
intermediate_size=intermediate_size, intermediate_size=intermediate_size,
hidden_act=hidden_act) hidden_act_cfg=hidden_act_cfg)
self.init_weights(pretrained=pretrained) self.init_weights(pretrained=pretrained)
def forward(self, results): def forward(self, results):
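With this change, callers of BertEncoder pass a config dict rather than a string key for the hidden activation. A minimal before/after sketch; the import path is inferred from this diff and may not match the actual package layout:

from mmocr.models.ner.encoders.bert_encoder import BertEncoder  # assumed path

# before this PR: encoder = BertEncoder(hidden_act='gelu_new')
encoder = BertEncoder(hidden_act_cfg=dict(type='GeluNew'))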

View File

@@ -1,4 +1,4 @@
-from .activations import ACT2FN
+from .activations import GeluNew
 from .bert import BertModel

-__all__ = ['BertModel', 'ACT2FN']
+__all__ = ['BertModel', 'GeluNew']

View File

@@ -6,34 +6,26 @@
 import math

 import torch
-from mmcv.cnn import Swish
+import torch.nn as nn
+from mmcv.cnn import ACTIVATION_LAYERS


-def gelu(x):
-    """Original Implementation of the gelu activation function in Google Bert
-    repo when initially created. For information: OpenAI GPT's gelu is slightly
-    different (and gives slightly different results):
-    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) *
-    (x + 0.044715 * torch.pow(x, 3))))
-    Also see https://arxiv.org/abs/1606.08415
-    """
-    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
-
-
-def gelu_new(x):
+@ACTIVATION_LAYERS.register_module()
+class GeluNew(nn.Module):
     """Implementation of the gelu activation function currently in Google Bert
     repo (identical to OpenAI GPT).
     Also see https://arxiv.org/abs/1606.08415
     """
-    return 0.5 * x * (1 + torch.tanh(
-        math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
-
-
-ACT2FN = {
-    'gelu': gelu,
-    'relu': torch.nn.functional.relu,
-    'swish': Swish,
-    'gelu_new': gelu_new
-}
+
+    def forward(self, x):
+        """Forward function.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: Activated tensor.
+        """
+        return 0.5 * x * (1 + torch.tanh(
+            math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
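Once registered, the new GeluNew module can be built by name from a config dict. A short usage sketch, assuming that importing mmocr.models.ner.utils is enough to run the registry decorator:

import torch
from mmcv.cnn import build_activation_layer

from mmocr.models.ner.utils import GeluNew  # noqa: F401  # triggers registration

act = build_activation_layer(dict(type='GeluNew'))
print(act(torch.tensor([0.5])))  # tanh-approximated GELU, roughly 0.3457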

View File

@@ -7,8 +7,7 @@ import math
 import torch
 import torch.nn as nn
-from mmocr.models.ner.utils.activations import ACT2FN
+from mmcv.cnn import build_activation_layer


 class BertModel(nn.Module):
@@ -31,7 +30,7 @@ class BertModel(nn.Module):
             for the attention probabilities normalized from
             the attention scores.
         intermediate_size (int): The size of intermediate layer.
-        hidden_act (str): hidden layer activation
+        hidden_act_cfg (dict): Hidden layer activation.
     """

     def __init__(self,
@@ -48,7 +47,7 @@ class BertModel(nn.Module):
                  num_attention_heads=12,
                  attention_probs_dropout_prob=0.1,
                  intermediate_size=3072,
-                 hidden_act='gelu_new'):
+                 hidden_act_cfg=dict(type='GeluNew')):
         super().__init__()
         self.embeddings = BertEmbeddings(
             vocab_size=vocab_size,
@@ -67,7 +66,7 @@ class BertModel(nn.Module):
             layer_norm_eps=layer_norm_eps,
             hidden_dropout_prob=hidden_dropout_prob,
             intermediate_size=intermediate_size,
-            hidden_act=hidden_act)
+            hidden_act_cfg=hidden_act_cfg)
         self.pooler = BertPooler(hidden_size=hidden_size)
         self.num_hidden_layers = num_hidden_layers
         self.initializer_range = initializer_range
@@ -205,7 +204,7 @@ class BertEncoder(nn.Module):
                  layer_norm_eps=1e-12,
                  hidden_dropout_prob=0.1,
                  intermediate_size=3072,
-                 hidden_act='gelu_new'):
+                 hidden_act_cfg=dict(type='GeluNew')):
         super().__init__()
         self.output_attentions = output_attentions
         self.output_hidden_states = output_hidden_states
@@ -218,7 +217,8 @@ class BertEncoder(nn.Module):
                 layer_norm_eps=layer_norm_eps,
                 hidden_dropout_prob=hidden_dropout_prob,
                 intermediate_size=intermediate_size,
-                hidden_act=hidden_act) for _ in range(num_hidden_layers)
+                hidden_act_cfg=hidden_act_cfg)
+            for _ in range(num_hidden_layers)
         ])

     def forward(self, hidden_states, attention_mask=None, head_mask=None):
@@ -278,7 +278,7 @@ class BertLayer(nn.Module):
                  layer_norm_eps=1e-12,
                  hidden_dropout_prob=0.1,
                  intermediate_size=3072,
-                 hidden_act='gelu_new'):
+                 hidden_act_cfg=dict(type='GeluNew')):
         super().__init__()
         self.attention = BertAttention(
             hidden_size=hidden_size,
@@ -290,7 +290,7 @@ class BertLayer(nn.Module):
         self.intermediate = BertIntermediate(
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
-            hidden_act=hidden_act)
+            hidden_act_cfg=hidden_act_cfg)
         self.output = BertOutput(
             intermediate_size=intermediate_size,
             hidden_size=hidden_size,
@@ -448,13 +448,11 @@ class BertIntermediate(nn.Module):
     def __init__(self,
                  hidden_size=768,
                  intermediate_size=3072,
-                 hidden_act='gelu_new'):
+                 hidden_act_cfg=dict(type='GeluNew')):
         super().__init__()
         self.dense = nn.Linear(hidden_size, intermediate_size)
-        if isinstance(hidden_act, str):
-            self.intermediate_act_fn = ACT2FN[hidden_act]
-        else:
-            self.intermediate_act_fn = hidden_act
+        self.intermediate_act_fn = build_activation_layer(hidden_act_cfg)

     def forward(self, hidden_states):
         hidden_states = self.dense(hidden_states)
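Because BertIntermediate now calls build_activation_layer directly, the old isinstance branching disappears and any activation in mmcv's registry can be selected, including ones with constructor arguments that a plain string key could not carry. A generic mmcv sketch; LeakyReLU is used only to illustrate the cfg-with-kwargs form, not as a suggestion for BERT:

import torch
from mmcv.cnn import build_activation_layer

# Extra keys in the cfg dict are forwarded to the layer's constructor.
act = build_activation_layer(dict(type='LeakyReLU', negative_slope=0.1))
print(act(torch.tensor([-1.0, 1.0])))  # tensor([-0.1000, 1.0000])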

View File

@@ -6,7 +6,6 @@ import pytest
 import torch

 from mmocr.models import build_detector
-from mmocr.models.ner.utils.activations import gelu, gelu_new


 def _create_dummy_vocab_file(vocab_file):
@@ -35,7 +34,7 @@ def _get_detector_cfg(fname):

 @pytest.mark.parametrize(
     'cfg_file', ['configs/ner/bert_softmax/bert_softmax_cluener_18e.py'])
-def test_encoder_decoder_pipeline(cfg_file):
+def test_bert_softmax(cfg_file):
     # prepare data
     texts = [''] * 47
     img = [31] * 47
@@ -77,7 +76,3 @@ def test_encoder_decoder_pipeline(cfg_file):
     batch_results = []
     result = detector.forward(None, img_metas, return_loss=False)
     batch_results.append(result)
-
-    # Test activations
-    gelu(torch.tensor(0.5))
-    gelu_new(torch.tensor(0.5))
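The removed gelu/gelu_new smoke checks are not replaced in this diff; if equivalent coverage for the new class is wanted, a hedged sketch could look like the following (the test name and tolerance are mine, not part of the PR):

import torch

from mmocr.models.ner.utils import GeluNew


def test_gelu_new_matches_exact_gelu():
    x = torch.linspace(-3, 3, steps=101)
    out = GeluNew()(x)
    # The tanh formulation closely approximates the exact erf-based GELU.
    assert torch.allclose(out, torch.nn.functional.gelu(x), atol=1e-3)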