mirror of
https://github.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once.git
synced 2025-06-03 14:50:11 +08:00
35 lines
1.0 KiB
Python
Executable File
35 lines
1.0 KiB
Python
Executable File
from importlib.metadata import requires
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from .registry import register_model
|
|
from .vlpencoder import LanguageEncoder
|
|
|
|
class FixLanguageEncoder(LanguageEncoder):
    """``LanguageEncoder`` subclass whose text paths run without autograd.

    The four embedding/forward entry points defined here delegate unchanged
    to the parent implementation, but execute inside ``torch.no_grad()`` so
    no gradient graph is recorded for them. ``logit_scale`` is replaced by a
    scalar parameter equal to one that is created with
    ``requires_grad=False``.
    """

    def __init__(self, *args, **kwargs):
        """Build the parent encoder, then install the fixed ``logit_scale``.

        All positional and keyword arguments are passed through to
        ``LanguageEncoder.__init__`` untouched.
        """
        super().__init__(*args, **kwargs)
        # 0-dim tensor holding 1.0; wrapped as a parameter with gradients off.
        unit_scale = torch.ones([])
        self.logit_scale = nn.Parameter(unit_scale, requires_grad=False)

    def get_text_embeddings(self, *args, **kwargs):
        """Call the parent ``get_text_embeddings`` with gradients disabled."""
        with torch.no_grad():
            return super().get_text_embeddings(*args, **kwargs)

    def get_text_token_embeddings(self, *args, **kwargs):
        """Call the parent ``get_text_token_embeddings`` with gradients disabled."""
        with torch.no_grad():
            return super().get_text_token_embeddings(*args, **kwargs)

    def forward_language(self, *args, **kwargs):
        """Call the parent ``forward_language`` with gradients disabled."""
        with torch.no_grad():
            return super().forward_language(*args, **kwargs)

    def forward_language_token(self, *args, **kwargs):
        """Call the parent ``forward_language_token`` with gradients disabled."""
        with torch.no_grad():
            return super().forward_language_token(*args, **kwargs)
|
|
|
|
|
|
@register_model
def get_language_model(cfg, **kwargs):
    """Construct a ``FixLanguageEncoder`` from ``cfg``.

    Extra keyword arguments are accepted but are not forwarded to the
    encoder constructor; only ``cfg`` is used.
    """
    encoder = FixLanguageEncoder(cfg)
    return encoder