# Mirrored from https://github.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once.git
# (last synced 2025-06-03 14:50:11 +08:00)
import torch
import torch.nn as nn

from .registry import register_model
from .vlpencoder import LanguageEncoder


class FixLanguageEncoder(LanguageEncoder):
    """LanguageEncoder whose text-encoding paths run without gradient tracking
    and whose logit scale is a fixed, non-trainable scalar."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Replace the learnable logit scale with a frozen scalar.
        self.logit_scale = nn.Parameter(torch.ones([]), requires_grad=False)

    # Every text path is wrapped in no_grad so the language branch
    # contributes no gradients during training.
    @torch.no_grad()
    def get_text_embeddings(self, *args, **kwargs):
        return super().get_text_embeddings(*args, **kwargs)

    @torch.no_grad()
    def get_text_token_embeddings(self, *args, **kwargs):
        return super().get_text_token_embeddings(*args, **kwargs)

    @torch.no_grad()
    def forward_language(self, *args, **kwargs):
        return super().forward_language(*args, **kwargs)

    @torch.no_grad()
    def forward_language_token(self, *args, **kwargs):
        return super().forward_language_token(*args, **kwargs)


@register_model
def get_language_model(cfg, **kwargs):
    return FixLanguageEncoder(cfg)
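

# ---------------------------------------------------------------------------
# Minimal sketch, not part of the original module: it replays the same
# freezing pattern on a hypothetical stand-in module (`ToyEncoder`) so the
# effect of requires_grad=False plus @torch.no_grad() is easy to verify in
# isolation. Run via `python -m <package>.fixvlpencoder` (path assumed).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class ToyEncoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(4, 4)
            # Same trick as FixLanguageEncoder: a non-trainable scalar scale.
            self.logit_scale = nn.Parameter(torch.ones([]), requires_grad=False)

        @torch.no_grad()
        def forward(self, x):
            # No autograd graph is recorded here, mirroring the wrapped text paths above.
            return self.proj(x) * self.logit_scale.exp()

    enc = ToyEncoder()
    out = enc(torch.randn(2, 4))
    # Both values print False: the output carries no grad and the scale is frozen.
    print(out.requires_grad, enc.logit_scale.requires_grad)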