ram init commit

pull/1802/head
HustQBW 2023-09-25 01:42:49 +08:00
parent bb59c9ad82
commit 51a2a15f1e
16 changed files with 2203 additions and 4 deletions

View File

@@ -11,13 +11,29 @@ if WITH_MULTIMODAL:
     from .minigpt4 import *  # noqa: F401, F403
     from .ofa import *  # noqa: F401, F403
     from .otter import *  # noqa: F401, F403
+    from .ram import *  # noqa: F401, F403
 else:
     from mmpretrain.registry import MODELS
     from mmpretrain.utils.dependency import register_multimodal_placeholder
     register_multimodal_placeholder([
-        'Blip2Caption', 'Blip2Retrieval', 'Blip2VQA', 'BlipCaption',
-        'BlipNLVR', 'BlipRetrieval', 'BlipGrounding', 'BlipVQA', 'Flamingo',
-        'OFA', 'ChineseCLIP', 'MiniGPT4', 'Llava', 'Otter', 'CLIP',
-        'CLIPZeroShot'
+        'Blip2Caption',
+        'Blip2Retrieval',
+        'Blip2VQA',
+        'BlipCaption',
+        'BlipNLVR',
+        'BlipRetrieval',
+        'BlipGrounding',
+        'BlipVQA',
+        'Flamingo',
+        'OFA',
+        'ChineseCLIP',
+        'MiniGPT4',
+        'Llava',
+        'Otter',
+        'CLIP',
+        'CLIPZeroShot',
+        'RAM',
+        'RAMNormal',
+        'RAMOpenset',
     ], MODELS)

View File

@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ram import RAM, RAMNormal, RAMOpenset
__all__ = ['RAM', 'RAMNormal', 'RAMOpenset']

File diff suppressed because it is too large

View File

@@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.

View File

@@ -0,0 +1,93 @@
# Copyright (c) OpenMMLab. All rights reserved.
# data settings
test_transforms_cfg = [
dict(type='Resize', scale=(384, 384), interpolation='bicubic'),
dict(
type='mmpretrain.PackInputs',
algorithm_keys=['text'],
meta_keys=['image_id', 'scale_factor'],
),
]
def get_ram_cfg(mode='normal'):
    assert mode in ['normal', 'openset'], 'mode must be "normal" or "openset"'
model_type = 'RAMNormal' if mode == 'normal' else 'RAMOpenset'
model_cfg = dict(
type=model_type,
tokenizer=dict(
type='BertTokenizer',
            name_or_path='bert-base-uncased',
use_fast=False),
vision_backbone=dict(
type='SwinTransformer',
arch='large',
img_size=384,
window_size=12,
),
tag_encoder={
'architectures': ['BertModel'],
'attention_probs_dropout_prob': 0.1,
'hidden_act': 'gelu',
'hidden_dropout_prob': 0.1,
'hidden_size': 768,
'initializer_range': 0.02,
'intermediate_size': 3072,
'layer_norm_eps': 1e-12,
'max_position_embeddings': 512,
'model_type': 'bert',
'num_attention_heads': 12,
'num_hidden_layers': 12,
'pad_token_id': 0,
'type_vocab_size': 2,
'vocab_size': 30524,
'encoder_width': 512,
'add_cross_attention': True
},
text_decoder={
'architectures': ['BertModel'],
'attention_probs_dropout_prob': 0.1,
'hidden_act': 'gelu',
'hidden_dropout_prob': 0.1,
'hidden_size': 768,
'initializer_range': 0.02,
'intermediate_size': 3072,
'layer_norm_eps': 1e-12,
'max_position_embeddings': 512,
'model_type': 'bert',
'num_attention_heads': 12,
'num_hidden_layers': 12,
'pad_token_id': 0,
'type_vocab_size': 2,
'vocab_size': 30524,
'encoder_width': 768,
'add_cross_attention': True
},
tagging_head={
'architectures': ['BertModel'],
'attention_probs_dropout_prob': 0.1,
'hidden_act': 'gelu',
'hidden_dropout_prob': 0.1,
'hidden_size': 768,
'initializer_range': 0.02,
'intermediate_size': 3072,
'layer_norm_eps': 1e-12,
'max_position_embeddings': 512,
'model_type': 'bert',
'num_attention_heads': 4,
'num_hidden_layers': 2,
'pad_token_id': 0,
'type_vocab_size': 2,
'vocab_size': 30522,
'encoder_width': 512,
'add_cross_attention': True,
'add_tag_cross_attention': False
},
data_preprocessor=dict(
type='MultiModalDataPreprocessor',
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
to_rgb=False,
),
)
return model_cfg
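For reference, a minimal usage sketch (not part of this commit) of building the model from this config through the registry; the checkpoint path is a placeholder for weights already converted to MMPretrain format:
# sketch only: the checkpoint path below is a placeholder
import torch
from mmpretrain.registry import MODELS
model = MODELS.build(get_ram_cfg(mode='normal'))
state_dict = torch.load('ram_swin_large_14m_mmpretrain.pth', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()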

View File

@@ -0,0 +1,109 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import gradio as gr
import torch
from mmpretrain.registry import MODELS, TRANSFORMS
from .config.ram_swin_large_14m import get_ram_cfg, test_transforms_cfg
from .run.inference import inference
parser = argparse.ArgumentParser(
    description='RAM (Recognize Anything Model) demo')
parser.add_argument(
'ram_ckpt', type=str, help='pretrained file for ram (absolute path)')
parser.add_argument(
'clip_ckpt',
type=str,
help='clip vit-base-p16 pretrained file (absolute path)')
args = parser.parse_args()
if torch.cuda.is_available():
devices = [
torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())
]
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
devices = [torch.device('mps')]
else:
devices = [torch.device('cpu')]
def get_free_device():
if hasattr(torch.cuda, 'mem_get_info'):
free = [torch.cuda.mem_get_info(gpu)[0] for gpu in devices]
select = max(zip(free, range(len(free))))[1]
else:
import random
select = random.randint(0, len(devices) - 1)
return devices[select]
device = get_free_device()
def ram_inference(image, tag_list, mode, threshold):
test_transforms = TRANSFORMS.get('Compose')(transforms=test_transforms_cfg)
model = MODELS.build(get_ram_cfg(mode=mode))
model.load_state_dict(torch.load(args.ram_ckpt))
model.device = device
if mode == 'openset':
categories = tag_list
if categories != '':
categories = categories.strip().split()
else:
categories = None
model.set_openset(
categories=categories,
clip_ckpt=args.clip_ckpt,
threshold=threshold)
sample = dict(img=image)
result = inference(sample, model, test_transforms, mode=mode)
tag, tag_chinese, logits = \
result.get('tag_output')[0][0], result.get('tag_output')[1][0],\
result.get('logits_output')[0]
def wrap(tags, logits):
if tags is None:
return 'Openset mode has no tag_en'
tag_lst = tags.split('|')
rt_lst = []
for i, tag in enumerate(tag_lst):
tag = tag.strip()
rt_lst.append(tag + f': {logits[i]:.2f}')
return ' | '.join(rt_lst)
return [wrap(tag, logits), wrap(tag_chinese, logits)]
def build_gradio():
inputs = [
gr.components.Image(label='image'),
gr.components.Textbox(
lines=2,
label='tag_list',
            placeholder='Enter the categories, separated by spaces',
value=''),
gr.components.Radio(['normal', 'openset'],
label='mode',
value='normal'),
gr.components.Slider(
minimum=0, maximum=1, value=0.68, step=0.01, label='threshold')
]
return gr.Interface(
fn=ram_inference,
inputs=inputs,
outputs=[
gr.components.Textbox(),
            gr.components.Textbox(info='Translated from the English tags')
])
def main():
build_gradio().launch()
if __name__ == '__main__':
main()
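Assuming this demo script is saved as gradio_demo.py (the file name is not shown in this diff), it is launched with the RAM and CLIP checkpoint paths as positional arguments, for example:
# hypothetical invocation; script name and paths are placeholders
python gradio_demo.py /abs/path/to/ram_swin_large_14m_mmpretrain.pth /abs/path/to/clip-vit-base-p16.pth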

View File

@@ -0,0 +1,212 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmpretrain.registry import MODELS
def article(name):
return 'an' if name[0] in 'aeiou' else 'a'
def processed_name(name, rm_dot=False):
# _ for lvis
# / for obj365
res = name.replace('_', ' ').replace('/', ' or ').lower()
if rm_dot:
res = res.rstrip('.')
return res
single_template = ['a photo of a {}.']
multiple_templates = [
'There is {article} {} in the scene.',
'There is the {} in the scene.',
'a photo of {article} {} in the scene.',
'a photo of the {} in the scene.',
'a photo of one {} in the scene.',
'itap of {article} {}.',
'itap of my {}.', # itap: I took a picture of
'itap of the {}.',
'a photo of {article} {}.',
'a photo of my {}.',
'a photo of the {}.',
'a photo of one {}.',
'a photo of many {}.',
'a good photo of {article} {}.',
'a good photo of the {}.',
'a bad photo of {article} {}.',
'a bad photo of the {}.',
'a photo of a nice {}.',
'a photo of the nice {}.',
'a photo of a cool {}.',
'a photo of the cool {}.',
'a photo of a weird {}.',
'a photo of the weird {}.',
'a photo of a small {}.',
'a photo of the small {}.',
'a photo of a large {}.',
'a photo of the large {}.',
'a photo of a clean {}.',
'a photo of the clean {}.',
'a photo of a dirty {}.',
'a photo of the dirty {}.',
'a bright photo of {article} {}.',
'a bright photo of the {}.',
'a dark photo of {article} {}.',
'a dark photo of the {}.',
'a photo of a hard to see {}.',
'a photo of the hard to see {}.',
'a low resolution photo of {article} {}.',
'a low resolution photo of the {}.',
'a cropped photo of {article} {}.',
'a cropped photo of the {}.',
'a close-up photo of {article} {}.',
'a close-up photo of the {}.',
'a jpeg corrupted photo of {article} {}.',
'a jpeg corrupted photo of the {}.',
'a blurry photo of {article} {}.',
'a blurry photo of the {}.',
'a pixelated photo of {article} {}.',
'a pixelated photo of the {}.',
'a black and white photo of the {}.',
'a black and white photo of {article} {}.',
'a plastic {}.',
'the plastic {}.',
'a toy {}.',
'the toy {}.',
'a plushie {}.',
'the plushie {}.',
'a cartoon {}.',
'the cartoon {}.',
'an embroidered {}.',
'the embroidered {}.',
'a painting of the {}.',
'a painting of a {}.',
]
openimages_rare_unseen = [
'Aerial photography', 'Aircraft engine', 'Ale', 'Aloe', 'Amphibian',
'Angling', 'Anole', 'Antique car', 'Arcade game', 'Arthropod',
'Assault rifle', 'Athletic shoe', 'Auto racing', 'Backlighting',
'Bagpipes', 'Ball game', 'Barbecue chicken', 'Barechested', 'Barquentine',
'Beef tenderloin', 'Billiard room', 'Billiards', 'Bird of prey',
'Black swan', 'Black-and-white', 'Blond', 'Boating', 'Bonbon',
'Bottled water', 'Bouldering', 'Bovine', 'Bratwurst', 'Breadboard',
'Briefs', 'Brisket', 'Brochette', 'Calabaza', 'Camera operator', 'Canola',
'Childbirth', 'Chordophone', 'Church bell', 'Classical sculpture',
'Close-up', 'Cobblestone', 'Coca-cola', 'Combat sport', 'Comics',
'Compact car', 'Computer speaker', 'Cookies and crackers',
'Coral reef fish', 'Corn on the cob', 'Cosmetics', 'Crocodilia',
'Digital camera', 'Dishware', 'Divemaster', 'Dobermann', 'Dog walking',
'Domestic rabbit', 'Domestic short-haired cat', 'Double-decker bus',
'Drums', 'Electric guitar', 'Electric piano', 'Electronic instrument',
'Equestrianism', 'Equitation', 'Erinaceidae', 'Extreme sport', 'Falafel',
'Figure skating', 'Filling station', 'Fire apparatus', 'Firearm',
'Flatbread', 'Floristry', 'Forklift truck', 'Freight transport',
'Fried food', 'Fried noodles', 'Frigate', 'Frozen yogurt', 'Frying',
'Full moon', 'Galleon', 'Glacial landform', 'Gliding', 'Go-kart', 'Goats',
'Grappling', 'Great white shark', 'Gumbo', 'Gun turret', 'Hair coloring',
'Halter', 'Headphones', 'Heavy cruiser', 'Herding', 'High-speed rail',
'Holding hands', 'Horse and buggy', 'Horse racing', 'Hound',
'Hunting knife', 'Hurdling', 'Inflatable', 'Jackfruit', 'Jeans', 'Jiaozi',
'Junk food', 'Khinkali', 'Kitesurfing', 'Lawn game', 'Leaf vegetable',
'Lechon', 'Lifebuoy', 'Locust', 'Lumpia', 'Luxury vehicle', 'Machine tool',
'Medical imaging', 'Melee weapon', 'Microcontroller', 'Middle ages',
'Military person', 'Military vehicle', 'Milky way', 'Miniature Poodle',
'Modern dance', 'Molluscs', 'Monoplane', 'Motorcycling', 'Musical theatre',
'Narcissus', 'Nest box', 'Newsagent\'s shop', 'Nile crocodile',
'Nordic skiing', 'Nuclear power plant', 'Orator', 'Outdoor shoe',
'Parachuting', 'Pasta salad', 'Peafowl', 'Pelmeni', 'Perching bird',
'Performance car', 'Personal water craft', 'Pit bull', 'Plant stem',
'Pork chop', 'Portrait photography', 'Primate', 'Procyonidae',
'Prosciutto', 'Public speaking', 'Racewalking', 'Ramen',
'Rear-view mirror', 'Residential area', 'Ribs', 'Rice ball',
'Road cycling', 'Roller skating', 'Roman temple', 'Rowing', 'Rural area',
'Sailboat racing', 'Scaled reptile', 'Scuba diving', 'Senior citizen',
'Shallot', 'Shinto shrine', 'Shooting range', 'Siberian husky', 'Sledding',
'Soba', 'Solar energy', 'Sport climbing', 'Sport utility vehicle',
'Steamed rice', 'Stemware', 'Sumo', 'Surfing Equipment', 'Team sport',
'Touring car', 'Toy block', 'Trampolining', 'Underwater diving',
'Vegetarian food', 'Wallaby', 'Water polo', 'Watercolor paint', 'Whiskers',
'Wind wave', 'Woodwind instrument', 'Yakitori', 'Zeppelin'
]
def get_clip_model():
model = dict(
type='CLIPZeroShot',
vision_backbone=dict(
type='VisionTransformer',
arch='base',
img_size=224,
patch_size=16,
drop_rate=0.,
layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
pre_norm=True,
),
projection=dict(
type='CLIPProjection', in_channels=768, out_channels=512),
text_backbone=dict(
type='CLIPTransformer',
width=512,
layers=12,
heads=8,
attn_mask=True,
),
tokenizer=dict(
type='AutoTokenizer',
name_or_path='openai/clip-vit-base-patch16',
use_fast=False),
vocab_size=49408,
transformer_width=512,
proj_dim=512,
context_length=77,
data_preprocessor=dict(
type='MultiModalDataPreprocessor',
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
to_rgb=False,
),
)
return MODELS.build(model)
def build_openset_label_embedding(categories=None, clip_ckpt_path=''):
if categories is None:
        print('No categories provided; falling back to the OpenImages rare/unseen categories')
categories = openimages_rare_unseen
model = get_clip_model()
model.load_state_dict(torch.load(clip_ckpt_path))
templates = multiple_templates
run_on_gpu = torch.cuda.is_available()
with torch.no_grad():
openset_label_embedding = []
for category in categories:
texts = [
template.format(
processed_name(category, rm_dot=True),
article=article(category)) for template in templates
]
texts = [
'This is ' + text
if text.startswith('a') or text.startswith('the') else text
for text in texts
]
texts = model.tokenize(texts) # tokenize
if run_on_gpu:
texts = texts.cuda()
model = model.cuda()
text_embeddings = model.extract_text_feat(texts)
text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
text_embedding = text_embeddings.mean(dim=0)
text_embedding /= text_embedding.norm()
openset_label_embedding.append(text_embedding)
openset_label_embedding = torch.stack(openset_label_embedding, dim=1)
if run_on_gpu:
openset_label_embedding = openset_label_embedding.cuda()
openset_label_embedding = openset_label_embedding.t()
return openset_label_embedding, categories
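A small usage sketch (category names and the CLIP checkpoint path are placeholders): building prompt-averaged CLIP text embeddings for a custom open-set vocabulary.
# sketch only; clip_ckpt_path is a placeholder
embeddings, categories = build_openset_label_embedding(
    categories=['corgi', 'skateboard', 'espresso'],
    clip_ckpt_path='/abs/path/to/clip-vit-base-p16.pth')
print(embeddings.shape)  # (len(categories), 512): one text embedding per category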

View File

@@ -0,0 +1,332 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import pickle
from abc import abstractmethod
from typing import List, Optional
import numpy as np
import torch
import torch.nn as nn
from mmengine.model import BaseModel
from mmpretrain.registry import MODELS, TOKENIZER
from mmpretrain.structures import DataSample
from .bert import BertConfig, BertLMHeadModel, BertModel
from .openset_utils import build_openset_label_embedding
from .utils import tie_encoder_decoder_weights
def get_path(path):
    file_path = os.path.abspath(os.path.dirname(__file__))
    if not os.path.isabs(path):
        return os.path.join(file_path, path)
    return path
class RAM(BaseModel):
"""The implementation of `RAM <https://arxiv.org/abs/2306.03514>`_."""
def __init__(self,
tokenizer: dict,
vision_backbone: dict,
tag_encoder: dict,
tagging_head: dict,
text_decoder: dict,
device: str = 'cpu',
vision_width: int = 1536,
prompt='a picture of ',
threshold=0.68,
delete_tag_index=[],
tag_list='./data/ram_tag_list.pickle',
tag_list_chinese='./data/ram_tag_list_chinese.pickle',
data_preprocessor: Optional[dict] = None,
init_cfg: Optional[dict] = None):
if data_preprocessor is None:
data_preprocessor = {}
data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
data_preprocessor = MODELS.build(data_preprocessor)
super().__init__(
data_preprocessor=data_preprocessor, init_cfg=init_cfg)
self.device = device
# build the visual encoder
self.visual_encoder = MODELS.build(vision_backbone)
# build the tokenizer
self.tokenizer = TOKENIZER.build(tokenizer)
self.tokenizer.add_special_tokens({'bos_token': '[DEC]'})
self.tokenizer.add_special_tokens(
{'additional_special_tokens': ['[ENC]']})
self.tokenizer.enc_token_id = \
self.tokenizer.additional_special_tokens_ids[0]
# build the tag encoder
# encoder_config = BertConfig.from_json_file(med_config)
# encoder_config.encoder_width = 512
encoder_config = BertConfig.from_dict(tag_encoder)
self.tag_encoder = BertModel(
config=encoder_config, add_pooling_layer=False)
# build image-tag-text decoder
# decoder_config = BertConfig.from_json_file(med_config)
decoder_config = BertConfig.from_dict(text_decoder)
self.text_decoder = BertLMHeadModel(config=decoder_config)
self.delete_tag_index = delete_tag_index
self.prompt = prompt
self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1
# load tag list
self.tag_list = self.load_tag_list(get_path(tag_list))
self.tag_list_chinese = self.load_tag_list(get_path(tag_list_chinese))
# create image-tag recognition decoder
self.threshold = threshold
self.num_class = len(self.tag_list)
# q2l_config = \
# BertConfig.from_json_file(f'{CONFIG_PATH}/configs/q2l_config.json')
# q2l_config.encoder_width = 512
q2l_config = BertConfig.from_dict(tagging_head)
self.tagging_head = BertModel(
config=q2l_config, add_pooling_layer=False)
self.tagging_head.resize_token_embeddings(len(self.tokenizer))
self.label_embed = nn.Parameter(
torch.zeros(self.num_class, q2l_config.encoder_width))
if q2l_config.hidden_size != 512:
self.wordvec_proj = nn.Linear(512, q2l_config.hidden_size)
else:
self.wordvec_proj = nn.Identity()
self.fc = nn.Linear(q2l_config.hidden_size, 1)
self.del_selfattention()
        # share weights of the lowest 2 layers of the
        # "image-tag interaction encoder" with the
        # "image-tag recognition decoder"
tie_encoder_decoder_weights(self.tag_encoder, self.tagging_head, '',
' ')
self.image_proj = nn.Linear(vision_width, 512)
# self.label_embed = nn.Parameter(torch.load(
# f'{CONFIG_PATH}/data/textual_label_embedding.pth',
# map_location='cpu').float())
# adjust thresholds for some tags
self.class_threshold = torch.ones(self.num_class) * self.threshold
ram_class_threshold_path = get_path(
'./data/ram_tag_list_threshold.pickle')
with open(ram_class_threshold_path, 'rb') as f:
ram_class_threshold = pickle.load(f)
for key, value in enumerate(ram_class_threshold):
self.class_threshold[key] = value
def load_tag_list(self, tag_list_file):
with open(tag_list_file, 'rb') as f:
tag_list = pickle.load(f)
tag_list = np.array(tag_list)
return tag_list
    # delete the self-attention layers of the image-tag recognition
    # decoder to reduce computation, following Query2Label
def del_selfattention(self):
del self.tagging_head.embeddings
for layer in self.tagging_head.encoder.layer:
del layer.attention
def get_label_embed(self):
return torch.nn.functional.relu(self.wordvec_proj(self.label_embed))
def extract_visual_feature(self, images):
image_embeds = self.visual_encoder(images)[0]
image_embeds = image_embeds.flatten(2, 3)
attn_pool = nn.AdaptiveAvgPool1d(1)
cls_token = attn_pool(image_embeds).permute(0, 2, 1).contiguous()
image_embeds = image_embeds.permute(0, 2, 1).contiguous()
image_embeds = torch.cat([cls_token, image_embeds], dim=1)
image_embeds = self.image_proj(image_embeds)
image_atts = torch.ones(
image_embeds.size()[:-1], dtype=torch.long).to(images.device)
return image_embeds, image_atts
def image2tag(self, label_embed, image_embeds, image_atts):
        # recognize image tags using the image-tag recognition decoder
# image_cls_embeds = image_embeds[:, 0, :]
image_spatial_embeds = image_embeds[:, 1:, :]
bs = image_spatial_embeds.shape[0]
label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1)
tagging_embed = self.tagging_head(
encoder_embeds=label_embed,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=False,
mode='tagging',
)
logits = self.fc(tagging_embed[0]).squeeze(-1)
return logits
def forward(
self,
images: torch.Tensor,
data_samples: Optional[list] = None,
mode: str = 'predict',
**kwargs,
):
if mode == 'predict':
return self.predict(images, data_samples, **kwargs)
else:
raise RuntimeError(f'Invalid mode "{mode}".')
@abstractmethod
def predict(self,
images: torch.Tensor,
data_samples: DataSample = None) -> DataSample:
raise NotImplementedError
@MODELS.register_module()
class RAMNormal(RAM):
def __init__(self,
tokenizer: dict,
vision_backbone: dict,
tag_encoder: dict,
tagging_head: dict,
text_decoder: dict,
device: str = 'cpu',
vision_width: int = 1536,
prompt='a picture of ',
threshold=0.68,
delete_tag_index=[],
tag_list='./data/ram_tag_list.pickle',
tag_list_chinese='./data/ram_tag_list_chinese.pickle',
data_preprocessor: Optional[dict] = None,
init_cfg: Optional[dict] = None):
super().__init__(
tokenizer,
vision_backbone,
tag_encoder,
tagging_head,
text_decoder,
device,
vision_width,
prompt,
threshold,
delete_tag_index,
tag_list,
tag_list_chinese,
data_preprocessor,
init_cfg,
)
def tag_process(self, logits):
targets = torch.where(
torch.sigmoid(logits) > self.class_threshold.to(logits.device),
torch.tensor(1.0).to(logits.device),
torch.zeros(self.num_class).to(logits.device))
tag = targets.cpu().numpy()
tag[:, self.delete_tag_index] = 0
tag_output = []
tag_output_chinese = []
logits_output = []
bs = logits.shape[0]
for b in range(bs):
index = np.argwhere(tag[b] == 1)
token = self.tag_list[index].squeeze(axis=1)
logits_output.append(
torch.sigmoid(logits)[b][index[:, 0]].cpu().numpy())
tag_output.append(' | '.join(token))
token_chinese = self.tag_list_chinese[index].squeeze(axis=1)
tag_output_chinese.append(' | '.join(token_chinese))
return [(tag_output, tag_output_chinese), logits_output]
def predict(self,
images: torch.Tensor,
data_samples: DataSample = None) -> DataSample:
self.eval()
self.to(self.device)
images = images.to(self.device)
label_embed = self.get_label_embed()
image_embeds, image_atts = self.extract_visual_feature(images)
logits = self.image2tag(label_embed, image_embeds, image_atts)
tag_output, logits_output = self.tag_process(logits)
data_samples.set_field(logits_output, 'logits_output')
data_samples.set_field(tag_output, 'tag_output')
return data_samples
@MODELS.register_module()
class RAMOpenset(RAMNormal):
def __init__(self,
tokenizer: dict,
vision_backbone: dict,
tag_encoder: dict,
tagging_head: dict,
text_decoder: dict,
device: str = 'cpu',
vision_width: int = 1536,
prompt='a picture of ',
threshold=0.68,
delete_tag_index=[],
tag_list='./data/ram_tag_list.pickle',
tag_list_chinese='./data/ram_tag_list_chinese.pickle',
data_preprocessor: Optional[dict] = None,
init_cfg: Optional[dict] = None):
super().__init__(
tokenizer,
vision_backbone,
tag_encoder,
tagging_head,
text_decoder,
device,
vision_width,
prompt,
threshold,
delete_tag_index,
tag_list,
tag_list_chinese,
data_preprocessor,
init_cfg,
)
def set_openset(self,
categories: List[str] = None,
clip_ckpt: str = '',
threshold: float = 0.68):
openset_label_embedding, openset_categories = \
build_openset_label_embedding(
categories, clip_ckpt
)
self.tag_list = np.array(openset_categories)
self.label_embed = nn.Parameter(openset_label_embedding.float())
self.num_class = len(openset_categories)
# the threshold for unseen categories is often lower
self.class_threshold = torch.ones(self.num_class) * threshold
def tag_process(self, logits):
targets = torch.where(
torch.sigmoid(logits) > self.class_threshold.to(logits.device),
torch.tensor(1.0).to(logits.device),
torch.zeros(self.num_class).to(logits.device))
tag = targets.cpu().numpy()
tag[:, self.delete_tag_index] = 0
bs = logits.shape[0]
tag_output = []
logits_output = []
for b in range(bs):
index = np.argwhere(tag[b] == 1)
token = self.tag_list[index].squeeze(axis=1)
logits_output.append(
torch.sigmoid(logits)[b][index[:, 0]].cpu().numpy())
tag_output.append(' | '.join(token))
return [(tag_output, [None]), logits_output]
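A short sketch of switching a built RAMOpenset model to a custom category list; checkpoint paths are placeholders, and get_ram_cfg is the config helper added in this commit (assumed importable here):
# sketch only; checkpoint paths are placeholders
import torch
from mmpretrain.registry import MODELS
model = MODELS.build(get_ram_cfg(mode='openset'))
model.load_state_dict(torch.load('ram_swin_large_14m_mmpretrain.pth', map_location='cpu'))
model.set_openset(
    categories=['corgi', 'skateboard', 'espresso'],
    clip_ckpt='/abs/path/to/clip-vit-base-p16.pth',
    threshold=0.68)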

View File

@@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.

View File

@@ -0,0 +1,29 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
def inference_ram(sample, model):
with torch.no_grad():
result = model.test_step(sample)
return result
def inference_ram_openset(sample, model):
with torch.no_grad():
result = model.test_step(sample)
return result
def inference(sample, model, transforms, mode='normal'):
sample = transforms(sample)
if sample['inputs'].ndim == 3:
sample['inputs'] = sample['inputs'].unsqueeze(dim=0)
    assert mode in ['normal', 'openset'], \
        'mode of inference must be "normal" or "openset"'
if mode == 'normal':
return inference_ram(sample, model)
else:
return inference_ram_openset(sample, model)
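A minimal sketch of running normal-mode tagging on a single image with this helper; paths are placeholders, and get_ram_cfg and test_transforms_cfg are assumed importable from the config module added in this commit:
# sketch only; paths are placeholders
import mmcv
import torch
from mmpretrain.registry import MODELS, TRANSFORMS
transforms = TRANSFORMS.get('Compose')(transforms=test_transforms_cfg)
model = MODELS.build(get_ram_cfg(mode='normal'))
model.load_state_dict(torch.load('ram_swin_large_14m_mmpretrain.pth', map_location='cpu'))
sample = dict(img=mmcv.imread('demo.jpg', channel_order='rgb'))
result = inference(sample, model, transforms, mode='normal')
print(result.get('tag_output')[0][0])  # English tags joined by ' | '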

View File

@@ -0,0 +1,87 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
from torch import nn
def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module,
base_model_prefix: str, skip_key: str):
uninitialized_encoder_weights: List[str] = []
if decoder.__class__ != encoder.__class__:
print(f'''{decoder.__class__} and {encoder.__class__} are not equal.
In this case make sure that
all encoder weights are correctly initialized.''')
def tie_encoder_to_decoder_recursively(
decoder_pointer: nn.Module,
encoder_pointer: nn.Module,
module_name: str,
uninitialized_encoder_weights: List[str],
skip_key: str,
depth=0,
):
assert isinstance(decoder_pointer, nn.Module) and isinstance(
encoder_pointer, nn.Module
        ), f'{decoder_pointer} and {encoder_pointer} ' + \
            'have to be of type torch.nn.Module'
if hasattr(decoder_pointer, 'weight') and skip_key not in module_name:
assert hasattr(encoder_pointer, 'weight')
encoder_pointer.weight = decoder_pointer.weight
if hasattr(decoder_pointer, 'bias'):
assert hasattr(encoder_pointer, 'bias')
encoder_pointer.bias = decoder_pointer.bias
print(module_name + ' is tied')
return
encoder_modules = encoder_pointer._modules
decoder_modules = decoder_pointer._modules
if len(decoder_modules) > 0:
assert (len(encoder_modules) >
0), f'''Encoder module {encoder_pointer}
does not match decoder module {decoder_pointer}'''
all_encoder_weights = set([
module_name + '/' + sub_name
for sub_name in encoder_modules.keys()
])
encoder_layer_pos = 0
for name, module in decoder_modules.items():
if name.isdigit():
encoder_name = str(int(name) + encoder_layer_pos)
decoder_name = name
if not isinstance(
decoder_modules[decoder_name],
type(encoder_modules[encoder_name])) and len(
encoder_modules) != len(decoder_modules):
                        # this can happen if the name corresponds to
                        # a position in a ModuleList of layers;
                        # in this case the decoder has added a
                        # cross-attention layer that the encoder
                        # does not have, so skip this step and
                        # subtract one layer position from the encoder
continue
elif name not in encoder_modules:
continue
elif depth > 500:
raise ValueError(
'''Max depth of recursive function `tie_encoder_to_decoder` reached.
It seems that there is a circular dependency
between two or more `nn.Modules` of your model.''')
else:
decoder_name = encoder_name = name
tie_encoder_to_decoder_recursively(
decoder_modules[decoder_name],
encoder_modules[encoder_name],
module_name + '/' + name,
uninitialized_encoder_weights,
skip_key,
depth=depth + 1,
)
all_encoder_weights.remove(module_name + '/' + encoder_name)
uninitialized_encoder_weights += list(all_encoder_weights)
# tie weights recursively
tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix,
uninitialized_encoder_weights, skip_key)

View File

@@ -12,6 +12,7 @@ from .huggingface import register_hf_tokenizer
register_hf_tokenizer(AutoTokenizer)
register_hf_tokenizer(LlamaTokenizer)
register_hf_tokenizer(BertTokenizer)
@register_hf_tokenizer()

View File

@@ -0,0 +1,117 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
from copy import deepcopy
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_swin(ckpt):
new_ckpt = OrderedDict()
convert_mapping = dict()
def correct_unfold_reduction_order(x):
out_channel, in_channel = x.shape
x = x.reshape(out_channel, 4, in_channel // 4)
x = x[:, [0, 2, 1, 3], :].transpose(1,
2).reshape(out_channel, in_channel)
return x
def correct_unfold_norm_order(x):
in_channel = x.shape[0]
x = x.reshape(4, in_channel // 4)
x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
return x
for k, v in ckpt.items():
if 'attn_mask' in k:
continue
if k.startswith('head'):
continue
elif k.startswith('layers'):
new_v = v
if 'attn.' in k:
new_k = k.replace('attn.', 'attn.w_msa.')
elif 'mlp.' in k:
if 'mlp.fc1.' in k:
new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.')
elif 'mlp.fc2.' in k:
new_k = k.replace('mlp.fc2.', 'ffn.layers.1.')
else:
new_k = k.replace('mlp.', 'ffn.')
elif 'downsample' in k:
new_k = k
if 'reduction.' in k:
new_v = correct_unfold_reduction_order(v)
elif 'norm.' in k:
new_v = correct_unfold_norm_order(v)
else:
new_k = k
new_k = new_k.replace('layers', 'stages', 1)
elif k.startswith('patch_embed'):
new_v = v
if 'proj' in k:
new_k = k.replace('proj', 'projection')
else:
new_k = k
elif k.startswith('norm'):
new_v = v
new_k = k.replace('norm', 'norm3')
else:
new_v = v
new_k = k
new_ckpt[new_k] = new_v
convert_mapping[k] = new_k
return new_ckpt, convert_mapping
def main():
parser = argparse.ArgumentParser(
        description='Convert keys in official pretrained RAM models to '
        'MMPretrain style.')
parser.add_argument('src', help='src model path or url')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
args = parser.parse_args()
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
state_dict = checkpoint['model']
else:
state_dict = checkpoint
visual_ckpt = OrderedDict()
for key in state_dict:
if key.startswith('visual_encoder.'):
new_key = key.replace('visual_encoder.', '')
visual_ckpt[new_key] = state_dict[key]
new_visual_ckpt, convert_mapping = convert_swin(visual_ckpt)
new_ckpt = deepcopy(state_dict)
for key in state_dict:
if key.startswith('visual_encoder.'):
if 'attn_mask' in key:
del new_ckpt[key]
continue
del new_ckpt[key]
old_key = key.replace('visual_encoder.', '')
new_ckpt[key.replace(old_key,
convert_mapping[old_key])] = deepcopy(
new_visual_ckpt[key.replace(
old_key,
convert_mapping[old_key]).replace(
'visual_encoder.', '')])
mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(new_ckpt, args.dst)
if __name__ == '__main__':
main()
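Assuming the converter is saved as ram2mmpretrain.py (the file name is not shown in this diff), the official RAM checkpoint would be converted with something like:
# hypothetical invocation; file names are placeholders
python ram2mmpretrain.py ram_swin_large_14m.pth ram_swin_large_14m_mmpretrain.pth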