Feat: use local transformer model

<Detail>:

<Footer>:
SkalskiP-bump-supervision-version
Liu, Hao 2023-05-22 15:10:04 +08:00 committed by GitHub
parent 39b1472457
commit 427aebd59a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 5 additions and 2 deletions
groundingdino/util

View File

@ -1,5 +1,5 @@
from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast
import os
def get_tokenlizer(text_encoder_type):
if not isinstance(text_encoder_type, str):
@ -8,6 +8,8 @@ def get_tokenlizer(text_encoder_type):
text_encoder_type = text_encoder_type.text_encoder_type
elif text_encoder_type.get("text_encoder_type", False):
text_encoder_type = text_encoder_type.get("text_encoder_type")
elif os.path.isdir(text_encoder_type) and os.path.exists(text_encoder_type):
pass
else:
raise ValueError(
"Unknown type of text_encoder_type: {}".format(type(text_encoder_type))
@ -19,8 +21,9 @@ def get_tokenlizer(text_encoder_type):
def get_pretrained_language_model(text_encoder_type):
    """Load a pretrained language model for the text encoder.

    Parameters
    ----------
    text_encoder_type : str
        Either a known hub identifier ("bert-base-uncased",
        "roberta-base") or a path to a local directory containing a
        saved BERT checkpoint.

    Returns
    -------
    BertModel or RobertaModel
        The pretrained model loaded via ``from_pretrained``.

    Raises
    ------
    ValueError
        If ``text_encoder_type`` is neither a recognized identifier
        nor an existing local directory.
    """
    # A local directory is treated as a BERT checkpoint — presumably the
    # only locally-exported encoder this project supports; verify against
    # callers if RoBERTa directories are ever expected.
    # NOTE(review): os.path.isdir() already implies existence, so the
    # original `and os.path.exists(...)` was redundant and is dropped.
    if text_encoder_type == "bert-base-uncased" or os.path.isdir(text_encoder_type):
        return BertModel.from_pretrained(text_encoder_type)
    if text_encoder_type == "roberta-base":
        return RobertaModel.from_pretrained(text_encoder_type)
    raise ValueError("Unknown text_encoder_type {}".format(text_encoder_type))