[Feature] Implement of Zero-Shot CLIP Classifier (#1737)
* zero-shot CLIP * modify zero-shot clip config * add in1k_sub_prompt(8 prompts) for improvement * add some annotations doc * clip base class & clip_zs sub-class * some modifications of details after review * convert into and use mmpretrain-vit * modify names of some files and directoriespull/1906/head
parent
845b462190
commit
bb59c9ad82
|
@ -0,0 +1,68 @@
|
|||
_base_ = '../_base_/default_runtime.py'
|
||||
|
||||
# data settings
|
||||
data_preprocessor = dict(
|
||||
type='MultiModalDataPreprocessor',
|
||||
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
|
||||
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
|
||||
to_rgb=False,
|
||||
)
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='Resize', scale=(224, 224), interpolation='bicubic'),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['text'],
|
||||
meta_keys=['image_id', 'scale_factor'],
|
||||
),
|
||||
]
|
||||
|
||||
train_dataloader = None
|
||||
test_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=8,
|
||||
dataset=dict(
|
||||
type='CIFAR100',
|
||||
data_root='data/cifar100',
|
||||
split='test',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
test_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
# schedule settings
|
||||
train_cfg = None
|
||||
val_cfg = None
|
||||
test_cfg = dict()
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type='CLIPZeroShot',
|
||||
vision_backbone=dict(
|
||||
type='VisionTransformer',
|
||||
arch='base',
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
drop_rate=0.,
|
||||
layer_cfgs=dict(act_cfg=dict(type='QuickGELU')),
|
||||
pre_norm=True,
|
||||
),
|
||||
projection=dict(type='CLIPProjection', in_channels=768, out_channels=512),
|
||||
text_backbone=dict(
|
||||
type='CLIPTransformer',
|
||||
width=512,
|
||||
layers=12,
|
||||
heads=8,
|
||||
attn_mask=True,
|
||||
),
|
||||
tokenizer=dict(
|
||||
type='AutoTokenizer',
|
||||
name_or_path='openai/clip-vit-base-patch16',
|
||||
use_fast=False),
|
||||
vocab_size=49408,
|
||||
transformer_width=512,
|
||||
proj_dim=512,
|
||||
text_prototype='cifar100',
|
||||
text_prompt='openai_cifar100',
|
||||
context_length=77,
|
||||
)
|
|
@ -0,0 +1,69 @@
|
|||
_base_ = '../_base_/default_runtime.py'
|
||||
|
||||
# data settings
|
||||
data_preprocessor = dict(
|
||||
type='MultiModalDataPreprocessor',
|
||||
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
|
||||
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=(224, 224), interpolation='bicubic'),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['text'],
|
||||
meta_keys=['image_id', 'scale_factor'],
|
||||
),
|
||||
]
|
||||
|
||||
train_dataloader = None
|
||||
test_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=8,
|
||||
dataset=dict(
|
||||
type='ImageNet',
|
||||
data_root='data/imagenet',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
test_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
# schedule settings
|
||||
train_cfg = None
|
||||
val_cfg = None
|
||||
test_cfg = dict()
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type='CLIPZeroShot',
|
||||
vision_backbone=dict(
|
||||
type='VisionTransformer',
|
||||
arch='base',
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
drop_rate=0.,
|
||||
layer_cfgs=dict(act_cfg=dict(type='QuickGELU')),
|
||||
pre_norm=True,
|
||||
),
|
||||
projection=dict(type='CLIPProjection', in_channels=768, out_channels=512),
|
||||
text_backbone=dict(
|
||||
type='CLIPTransformer',
|
||||
width=512,
|
||||
layers=12,
|
||||
heads=8,
|
||||
attn_mask=True,
|
||||
),
|
||||
tokenizer=dict(
|
||||
type='AutoTokenizer',
|
||||
name_or_path='openai/clip-vit-base-patch16',
|
||||
use_fast=False),
|
||||
vocab_size=49408,
|
||||
transformer_width=512,
|
||||
proj_dim=512,
|
||||
text_prototype='imagenet',
|
||||
text_prompt='openai_imagenet_sub', # openai_imagenet, openai_imagenet_sub
|
||||
context_length=77,
|
||||
)
|
|
@ -0,0 +1,68 @@
|
|||
_base_ = '../_base_/default_runtime.py'
|
||||
|
||||
# data settings
|
||||
data_preprocessor = dict(
|
||||
type='MultiModalDataPreprocessor',
|
||||
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
|
||||
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
|
||||
to_rgb=False,
|
||||
)
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='Resize', scale=(224, 224), interpolation='bicubic'),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['text'],
|
||||
meta_keys=['image_id', 'scale_factor'],
|
||||
),
|
||||
]
|
||||
|
||||
train_dataloader = None
|
||||
test_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=8,
|
||||
dataset=dict(
|
||||
type='CIFAR100',
|
||||
data_root='data/cifar100',
|
||||
split='test',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
test_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
# schedule settings
|
||||
train_cfg = None
|
||||
val_cfg = None
|
||||
test_cfg = dict()
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type='CLIPZeroShot',
|
||||
vision_backbone=dict(
|
||||
type='VisionTransformer',
|
||||
arch='large',
|
||||
img_size=224,
|
||||
patch_size=14,
|
||||
drop_rate=0.,
|
||||
layer_cfgs=dict(act_cfg=dict(type='QuickGELU')),
|
||||
pre_norm=True,
|
||||
),
|
||||
projection=dict(type='CLIPProjection', in_channels=1024, out_channels=768),
|
||||
text_backbone=dict(
|
||||
type='CLIPTransformer',
|
||||
width=768,
|
||||
layers=12,
|
||||
heads=12,
|
||||
attn_mask=True,
|
||||
),
|
||||
tokenizer=dict(
|
||||
type='AutoTokenizer',
|
||||
name_or_path='openai/clip-vit-large-patch14',
|
||||
use_fast=False),
|
||||
vocab_size=49408,
|
||||
transformer_width=768,
|
||||
proj_dim=768,
|
||||
text_prototype='cifar100',
|
||||
text_prompt='openai_cifar100',
|
||||
context_length=77,
|
||||
)
|
|
@ -0,0 +1,69 @@
|
|||
_base_ = '../_base_/default_runtime.py'
|
||||
|
||||
# data settings
|
||||
data_preprocessor = dict(
|
||||
type='MultiModalDataPreprocessor',
|
||||
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
|
||||
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=(224, 224), interpolation='bicubic'),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['text'],
|
||||
meta_keys=['image_id', 'scale_factor'],
|
||||
),
|
||||
]
|
||||
|
||||
train_dataloader = None
|
||||
test_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=8,
|
||||
dataset=dict(
|
||||
type='ImageNet',
|
||||
data_root='data/imagenet',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
test_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
# schedule settings
|
||||
train_cfg = None
|
||||
val_cfg = None
|
||||
test_cfg = dict()
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type='CLIPZeroShot',
|
||||
vision_backbone=dict(
|
||||
type='VisionTransformer',
|
||||
arch='large',
|
||||
img_size=224,
|
||||
patch_size=14,
|
||||
drop_rate=0.,
|
||||
layer_cfgs=dict(act_cfg=dict(type='QuickGELU')),
|
||||
pre_norm=True,
|
||||
),
|
||||
projection=dict(type='CLIPProjection', in_channels=1024, out_channels=768),
|
||||
text_backbone=dict(
|
||||
type='CLIPTransformer',
|
||||
width=768,
|
||||
layers=12,
|
||||
heads=12,
|
||||
attn_mask=True,
|
||||
),
|
||||
tokenizer=dict(
|
||||
type='AutoTokenizer',
|
||||
name_or_path='openai/clip-vit-large-patch14',
|
||||
use_fast=False),
|
||||
vocab_size=49408,
|
||||
transformer_width=768,
|
||||
proj_dim=768,
|
||||
text_prototype='imagenet',
|
||||
text_prompt='openai_imagenet_sub', # openai_imagenet, openai_imagenet_sub
|
||||
context_length=77,
|
||||
)
|
|
@ -1438,3 +1438,224 @@ CIFAR100_CATEGORIES_CN = (
|
|||
'海豹', '鲨鱼', '尖嘴小鼠', '臭鼬', '摩天大楼', '蜗牛', '蛇', '蜘蛛', '松鼠', '电车', '向日葵', '甜椒',
|
||||
'桌子', '坦克', '电话', '电视', '老虎', '拖拉机', '火车', '鳟鱼', '郁金香', '乌龟', '衣柜', '鲸鱼',
|
||||
'柳树', '狼', '女人', '蠕虫')
|
||||
|
||||
IMAGENET_SIMPLE_CATEGORIES = (
|
||||
'tench', 'goldfish', 'great white shark', 'tiger shark',
|
||||
'hammerhead shark', 'electric ray', 'stingray', 'rooster', 'hen',
|
||||
'ostrich', 'brambling', 'goldfinch', 'house finch', 'junco',
|
||||
'indigo bunting', 'American robin', 'bulbul', 'jay', 'magpie', 'chickadee',
|
||||
'American dipper', 'kite (bird of prey)', 'bald eagle', 'vulture',
|
||||
'great grey owl', 'fire salamander', 'smooth newt', 'newt',
|
||||
'spotted salamander', 'axolotl', 'American bullfrog', 'tree frog',
|
||||
'tailed frog', 'loggerhead sea turtle', 'leatherback sea turtle',
|
||||
'mud turtle', 'terrapin', 'box turtle', 'banded gecko', 'green iguana',
|
||||
'Carolina anole', 'desert grassland whiptail lizard', 'agama',
|
||||
'frilled-necked lizard', 'alligator lizard', 'Gila monster',
|
||||
'European green lizard', 'chameleon', 'Komodo dragon', 'Nile crocodile',
|
||||
'American alligator', 'triceratops', 'worm snake', 'ring-necked snake',
|
||||
'eastern hog-nosed snake', 'smooth green snake', 'kingsnake',
|
||||
'garter snake', 'water snake', 'vine snake', 'night snake',
|
||||
'boa constrictor', 'African rock python', 'Indian cobra', 'green mamba',
|
||||
'sea snake', 'Saharan horned viper', 'eastern diamondback rattlesnake',
|
||||
'sidewinder rattlesnake', 'trilobite', 'harvestman', 'scorpion',
|
||||
'yellow garden spider', 'barn spider', 'European garden spider',
|
||||
'southern black widow', 'tarantula', 'wolf spider', 'tick', 'centipede',
|
||||
'black grouse', 'ptarmigan', 'ruffed grouse', 'prairie grouse', 'peafowl',
|
||||
'quail', 'partridge', 'african grey parrot', 'macaw',
|
||||
'sulphur-crested cockatoo', 'lorikeet', 'coucal', 'bee eater', 'hornbill',
|
||||
'hummingbird', 'jacamar', 'toucan', 'duck', 'red-breasted merganser',
|
||||
'goose', 'black swan', 'tusker', 'echidna', 'platypus', 'wallaby', 'koala',
|
||||
'wombat', 'jellyfish', 'sea anemone', 'brain coral', 'flatworm',
|
||||
'nematode', 'conch', 'snail', 'slug', 'sea slug', 'chiton',
|
||||
'chambered nautilus', 'Dungeness crab', 'rock crab', 'fiddler crab',
|
||||
'red king crab', 'American lobster', 'spiny lobster', 'crayfish',
|
||||
'hermit crab', 'isopod', 'white stork', 'black stork', 'spoonbill',
|
||||
'flamingo', 'little blue heron', 'great egret', 'bittern bird',
|
||||
'crane bird', 'limpkin', 'common gallinule', 'American coot', 'bustard',
|
||||
'ruddy turnstone', 'dunlin', 'common redshank', 'dowitcher',
|
||||
'oystercatcher', 'pelican', 'king penguin', 'albatross', 'grey whale',
|
||||
'killer whale', 'dugong', 'sea lion', 'Chihuahua', 'Japanese Chin',
|
||||
'Maltese', 'Pekingese', 'Shih Tzu', 'King Charles Spaniel', 'Papillon',
|
||||
'toy terrier', 'Rhodesian Ridgeback', 'Afghan Hound', 'Basset Hound',
|
||||
'Beagle', 'Bloodhound', 'Bluetick Coonhound', 'Black and Tan Coonhound',
|
||||
'Treeing Walker Coonhound', 'English foxhound', 'Redbone Coonhound',
|
||||
'borzoi', 'Irish Wolfhound', 'Italian Greyhound', 'Whippet',
|
||||
'Ibizan Hound', 'Norwegian Elkhound', 'Otterhound', 'Saluki',
|
||||
'Scottish Deerhound', 'Weimaraner', 'Staffordshire Bull Terrier',
|
||||
'American Staffordshire Terrier', 'Bedlington Terrier', 'Border Terrier',
|
||||
'Kerry Blue Terrier', 'Irish Terrier', 'Norfolk Terrier',
|
||||
'Norwich Terrier', 'Yorkshire Terrier', 'Wire Fox Terrier',
|
||||
'Lakeland Terrier', 'Sealyham Terrier', 'Airedale Terrier',
|
||||
'Cairn Terrier', 'Australian Terrier', 'Dandie Dinmont Terrier',
|
||||
'Boston Terrier', 'Miniature Schnauzer', 'Giant Schnauzer',
|
||||
'Standard Schnauzer', 'Scottish Terrier', 'Tibetan Terrier',
|
||||
'Australian Silky Terrier', 'Soft-coated Wheaten Terrier',
|
||||
'West Highland White Terrier', 'Lhasa Apso', 'Flat-Coated Retriever',
|
||||
'Curly-coated Retriever', 'Golden Retriever', 'Labrador Retriever',
|
||||
'Chesapeake Bay Retriever', 'German Shorthaired Pointer', 'Vizsla',
|
||||
'English Setter', 'Irish Setter', 'Gordon Setter', 'Brittany dog',
|
||||
'Clumber Spaniel', 'English Springer Spaniel', 'Welsh Springer Spaniel',
|
||||
'Cocker Spaniel', 'Sussex Spaniel', 'Irish Water Spaniel', 'Kuvasz',
|
||||
'Schipperke', 'Groenendael dog', 'Malinois', 'Briard', 'Australian Kelpie',
|
||||
'Komondor', 'Old English Sheepdog', 'Shetland Sheepdog', 'collie',
|
||||
'Border Collie', 'Bouvier des Flandres dog', 'Rottweiler',
|
||||
'German Shepherd Dog', 'Dobermann', 'Miniature Pinscher',
|
||||
'Greater Swiss Mountain Dog', 'Bernese Mountain Dog',
|
||||
'Appenzeller Sennenhund', 'Entlebucher Sennenhund', 'Boxer', 'Bullmastiff',
|
||||
'Tibetan Mastiff', 'French Bulldog', 'Great Dane', 'St. Bernard', 'husky',
|
||||
'Alaskan Malamute', 'Siberian Husky', 'Dalmatian', 'Affenpinscher',
|
||||
'Basenji', 'pug', 'Leonberger', 'Newfoundland dog', 'Great Pyrenees dog',
|
||||
'Samoyed', 'Pomeranian', 'Chow Chow', 'Keeshond', 'brussels griffon',
|
||||
'Pembroke Welsh Corgi', 'Cardigan Welsh Corgi', 'Toy Poodle',
|
||||
'Miniature Poodle', 'Standard Poodle',
|
||||
'Mexican hairless dog (xoloitzcuintli)', 'grey wolf',
|
||||
'Alaskan tundra wolf', 'red wolf or maned wolf', 'coyote', 'dingo',
|
||||
'dhole', 'African wild dog', 'hyena', 'red fox', 'kit fox', 'Arctic fox',
|
||||
'grey fox', 'tabby cat', 'tiger cat', 'Persian cat', 'Siamese cat',
|
||||
'Egyptian Mau', 'cougar', 'lynx', 'leopard', 'snow leopard', 'jaguar',
|
||||
'lion', 'tiger', 'cheetah', 'brown bear', 'American black bear',
|
||||
'polar bear', 'sloth bear', 'mongoose', 'meerkat', 'tiger beetle',
|
||||
'ladybug', 'ground beetle', 'longhorn beetle', 'leaf beetle',
|
||||
'dung beetle', 'rhinoceros beetle', 'weevil', 'fly', 'bee', 'ant',
|
||||
'grasshopper', 'cricket insect', 'stick insect', 'cockroach',
|
||||
'praying mantis', 'cicada', 'leafhopper', 'lacewing', 'dragonfly',
|
||||
'damselfly', 'red admiral butterfly', 'ringlet butterfly',
|
||||
'monarch butterfly', 'small white butterfly', 'sulphur butterfly',
|
||||
'gossamer-winged butterfly', 'starfish', 'sea urchin', 'sea cucumber',
|
||||
'cottontail rabbit', 'hare', 'Angora rabbit', 'hamster', 'porcupine',
|
||||
'fox squirrel', 'marmot', 'beaver', 'guinea pig', 'common sorrel horse',
|
||||
'zebra', 'pig', 'wild boar', 'warthog', 'hippopotamus', 'ox',
|
||||
'water buffalo', 'bison', 'ram (adult male sheep)', 'bighorn sheep',
|
||||
'Alpine ibex', 'hartebeest', 'impala (antelope)', 'gazelle',
|
||||
'arabian camel', 'llama', 'weasel', 'mink', 'European polecat',
|
||||
'black-footed ferret', 'otter', 'skunk', 'badger', 'armadillo',
|
||||
'three-toed sloth', 'orangutan', 'gorilla', 'chimpanzee', 'gibbon',
|
||||
'siamang', 'guenon', 'patas monkey', 'baboon', 'macaque', 'langur',
|
||||
'black-and-white colobus', 'proboscis monkey', 'marmoset',
|
||||
'white-headed capuchin', 'howler monkey', 'titi monkey',
|
||||
"Geoffroy's spider monkey", 'common squirrel monkey', 'ring-tailed lemur',
|
||||
'indri', 'Asian elephant', 'African bush elephant', 'red panda',
|
||||
'giant panda', 'snoek fish', 'eel', 'silver salmon', 'rock beauty fish',
|
||||
'clownfish', 'sturgeon', 'gar fish', 'lionfish', 'pufferfish', 'abacus',
|
||||
'abaya', 'academic gown', 'accordion', 'acoustic guitar',
|
||||
'aircraft carrier', 'airliner', 'airship', 'altar', 'ambulance',
|
||||
'amphibious vehicle', 'analog clock', 'apiary', 'apron', 'trash can',
|
||||
'assault rifle', 'backpack', 'bakery', 'balance beam', 'balloon',
|
||||
'ballpoint pen', 'Band-Aid', 'banjo', 'baluster / handrail', 'barbell',
|
||||
'barber chair', 'barbershop', 'barn', 'barometer', 'barrel', 'wheelbarrow',
|
||||
'baseball', 'basketball', 'bassinet', 'bassoon', 'swimming cap',
|
||||
'bath towel', 'bathtub', 'station wagon', 'lighthouse', 'beaker',
|
||||
'military hat (bearskin or shako)', 'beer bottle', 'beer glass',
|
||||
'bell tower', 'baby bib', 'tandem bicycle', 'bikini', 'ring binder',
|
||||
'binoculars', 'birdhouse', 'boathouse', 'bobsleigh', 'bolo tie',
|
||||
'poke bonnet', 'bookcase', 'bookstore', 'bottle cap', 'hunting bow',
|
||||
'bow tie', 'brass memorial plaque', 'bra', 'breakwater', 'breastplate',
|
||||
'broom', 'bucket', 'buckle', 'bulletproof vest', 'high-speed train',
|
||||
'butcher shop', 'taxicab', 'cauldron', 'candle', 'cannon', 'canoe',
|
||||
'can opener', 'cardigan', 'car mirror', 'carousel', 'tool kit',
|
||||
'cardboard box / carton', 'car wheel', 'automated teller machine',
|
||||
'cassette', 'cassette player', 'castle', 'catamaran', 'CD player', 'cello',
|
||||
'mobile phone', 'chain', 'chain-link fence', 'chain mail', 'chainsaw',
|
||||
'storage chest', 'chiffonier', 'bell or wind chime', 'china cabinet',
|
||||
'Christmas stocking', 'church', 'movie theater', 'cleaver',
|
||||
'cliff dwelling', 'cloak', 'clogs', 'cocktail shaker', 'coffee mug',
|
||||
'coffeemaker', 'spiral or coil', 'combination lock', 'computer keyboard',
|
||||
'candy store', 'container ship', 'convertible', 'corkscrew', 'cornet',
|
||||
'cowboy boot', 'cowboy hat', 'cradle', 'construction crane',
|
||||
'crash helmet', 'crate', 'infant bed', 'Crock Pot', 'croquet ball',
|
||||
'crutch', 'cuirass', 'dam', 'desk', 'desktop computer',
|
||||
'rotary dial telephone', 'diaper', 'digital clock', 'digital watch',
|
||||
'dining table', 'dishcloth', 'dishwasher', 'disc brake', 'dock',
|
||||
'dog sled', 'dome', 'doormat', 'drilling rig', 'drum', 'drumstick',
|
||||
'dumbbell', 'Dutch oven', 'electric fan', 'electric guitar',
|
||||
'electric locomotive', 'entertainment center', 'envelope',
|
||||
'espresso machine', 'face powder', 'feather boa', 'filing cabinet',
|
||||
'fireboat', 'fire truck', 'fire screen', 'flagpole', 'flute',
|
||||
'folding chair', 'football helmet', 'forklift', 'fountain', 'fountain pen',
|
||||
'four-poster bed', 'freight car', 'French horn', 'frying pan', 'fur coat',
|
||||
'garbage truck', 'gas mask or respirator', 'gas pump', 'goblet', 'go-kart',
|
||||
'golf ball', 'golf cart', 'gondola', 'gong', 'gown', 'grand piano',
|
||||
'greenhouse', 'radiator grille', 'grocery store', 'guillotine',
|
||||
'hair clip', 'hair spray', 'half-track', 'hammer', 'hamper', 'hair dryer',
|
||||
'hand-held computer', 'handkerchief', 'hard disk drive', 'harmonica',
|
||||
'harp', 'combine harvester', 'hatchet', 'holster', 'home theater',
|
||||
'honeycomb', 'hook', 'hoop skirt', 'gymnastic horizontal bar',
|
||||
'horse-drawn vehicle', 'hourglass', 'iPod', 'clothes iron',
|
||||
'carved pumpkin', 'jeans', 'jeep', 'T-shirt', 'jigsaw puzzle', 'rickshaw',
|
||||
'joystick', 'kimono', 'knee pad', 'knot', 'lab coat', 'ladle', 'lampshade',
|
||||
'laptop computer', 'lawn mower', 'lens cap', 'letter opener', 'library',
|
||||
'lifeboat', 'lighter', 'limousine', 'ocean liner', 'lipstick',
|
||||
'slip-on shoe', 'lotion', 'music speaker', 'loupe magnifying glass',
|
||||
'sawmill', 'magnetic compass', 'messenger bag', 'mailbox', 'tights',
|
||||
'one-piece bathing suit', 'manhole cover', 'maraca', 'marimba', 'mask',
|
||||
'matchstick', 'maypole', 'maze', 'measuring cup', 'medicine cabinet',
|
||||
'megalith', 'microphone', 'microwave oven', 'military uniform', 'milk can',
|
||||
'minibus', 'miniskirt', 'minivan', 'missile', 'mitten', 'mixing bowl',
|
||||
'mobile home', 'ford model t', 'modem', 'monastery', 'monitor', 'moped',
|
||||
'mortar and pestle', 'graduation cap', 'mosque', 'mosquito net', 'vespa',
|
||||
'mountain bike', 'tent', 'computer mouse', 'mousetrap', 'moving van',
|
||||
'muzzle', 'metal nail', 'neck brace', 'necklace', 'baby pacifier',
|
||||
'notebook computer', 'obelisk', 'oboe', 'ocarina', 'odometer',
|
||||
'oil filter', 'pipe organ', 'oscilloscope', 'overskirt', 'bullock cart',
|
||||
'oxygen mask', 'product packet / packaging', 'paddle', 'paddle wheel',
|
||||
'padlock', 'paintbrush', 'pajamas', 'palace', 'pan flute', 'paper towel',
|
||||
'parachute', 'parallel bars', 'park bench', 'parking meter',
|
||||
'railroad car', 'patio', 'payphone', 'pedestal', 'pencil case',
|
||||
'pencil sharpener', 'perfume', 'Petri dish', 'photocopier', 'plectrum',
|
||||
'Pickelhaube', 'picket fence', 'pickup truck', 'pier', 'piggy bank',
|
||||
'pill bottle', 'pillow', 'ping-pong ball', 'pinwheel', 'pirate ship',
|
||||
'drink pitcher', 'block plane', 'planetarium', 'plastic bag', 'plate rack',
|
||||
'farm plow', 'plunger', 'Polaroid camera', 'pole', 'police van', 'poncho',
|
||||
'pool table', 'soda bottle', 'plant pot', "potter's wheel", 'power drill',
|
||||
'prayer rug', 'printer', 'prison', 'missile', 'projector', 'hockey puck',
|
||||
'punching bag', 'purse', 'quill', 'quilt', 'race car', 'racket',
|
||||
'radiator', 'radio', 'radio telescope', 'rain barrel',
|
||||
'recreational vehicle', 'fishing casting reel', 'reflex camera',
|
||||
'refrigerator', 'remote control', 'restaurant', 'revolver', 'rifle',
|
||||
'rocking chair', 'rotisserie', 'eraser', 'rugby ball',
|
||||
'ruler measuring stick', 'sneaker', 'safe', 'safety pin', 'salt shaker',
|
||||
'sandal', 'sarong', 'saxophone', 'scabbard', 'weighing scale',
|
||||
'school bus', 'schooner', 'scoreboard', 'CRT monitor', 'screw',
|
||||
'screwdriver', 'seat belt', 'sewing machine', 'shield', 'shoe store',
|
||||
'shoji screen / room divider', 'shopping basket', 'shopping cart',
|
||||
'shovel', 'shower cap', 'shower curtain', 'ski', 'balaclava ski mask',
|
||||
'sleeping bag', 'slide rule', 'sliding door', 'slot machine', 'snorkel',
|
||||
'snowmobile', 'snowplow', 'soap dispenser', 'soccer ball', 'sock',
|
||||
'solar thermal collector', 'sombrero', 'soup bowl', 'keyboard space bar',
|
||||
'space heater', 'space shuttle', 'spatula', 'motorboat', 'spider web',
|
||||
'spindle', 'sports car', 'spotlight', 'stage', 'steam locomotive',
|
||||
'through arch bridge', 'steel drum', 'stethoscope', 'scarf', 'stone wall',
|
||||
'stopwatch', 'stove', 'strainer', 'tram', 'stretcher', 'couch', 'stupa',
|
||||
'submarine', 'suit', 'sundial', 'sunglasses', 'sunglasses', 'sunscreen',
|
||||
'suspension bridge', 'mop', 'sweatshirt', 'swim trunks / shorts', 'swing',
|
||||
'electrical switch', 'syringe', 'table lamp', 'tank', 'tape player',
|
||||
'teapot', 'teddy bear', 'television', 'tennis ball', 'thatched roof',
|
||||
'front curtain', 'thimble', 'threshing machine', 'throne', 'tile roof',
|
||||
'toaster', 'tobacco shop', 'toilet seat', 'torch', 'totem pole',
|
||||
'tow truck', 'toy store', 'tractor', 'semi-trailer truck', 'tray',
|
||||
'trench coat', 'tricycle', 'trimaran', 'tripod', 'triumphal arch',
|
||||
'trolleybus', 'trombone', 'hot tub', 'turnstile', 'typewriter keyboard',
|
||||
'umbrella', 'unicycle', 'upright piano', 'vacuum cleaner', 'vase',
|
||||
'vaulted or arched ceiling', 'velvet fabric', 'vending machine',
|
||||
'vestment', 'viaduct', 'violin', 'volleyball', 'waffle iron', 'wall clock',
|
||||
'wallet', 'wardrobe', 'military aircraft', 'sink', 'washing machine',
|
||||
'water bottle', 'water jug', 'water tower', 'whiskey jug', 'whistle',
|
||||
'hair wig', 'window screen', 'window shade', 'Windsor tie', 'wine bottle',
|
||||
'airplane wing', 'wok', 'wooden spoon', 'wool', 'split-rail fence',
|
||||
'shipwreck', 'sailboat', 'yurt', 'website', 'comic book', 'crossword',
|
||||
'traffic or street sign', 'traffic light', 'dust jacket', 'menu', 'plate',
|
||||
'guacamole', 'consomme', 'hot pot', 'trifle', 'ice cream', 'popsicle',
|
||||
'baguette', 'bagel', 'pretzel', 'cheeseburger', 'hot dog',
|
||||
'mashed potatoes', 'cabbage', 'broccoli', 'cauliflower', 'zucchini',
|
||||
'spaghetti squash', 'acorn squash', 'butternut squash', 'cucumber',
|
||||
'artichoke', 'bell pepper', 'cardoon', 'mushroom', 'Granny Smith apple',
|
||||
'strawberry', 'orange', 'lemon', 'fig', 'pineapple', 'banana', 'jackfruit',
|
||||
'cherimoya (custard apple)', 'pomegranate', 'hay', 'carbonara',
|
||||
'chocolate syrup', 'dough', 'meatloaf', 'pizza', 'pot pie', 'burrito',
|
||||
'red wine', 'espresso', 'tea cup', 'eggnog', 'mountain', 'bubble', 'cliff',
|
||||
'coral reef', 'geyser', 'lakeshore', 'promontory', 'sandbar', 'beach',
|
||||
'valley', 'volcano', 'baseball player', 'bridegroom', 'scuba diver',
|
||||
'rapeseed', 'daisy', "yellow lady's slipper", 'corn', 'acorn', 'rose hip',
|
||||
'horse chestnut seed', 'coral fungus', 'agaric', 'gyromitra',
|
||||
'stinkhorn mushroom', 'earth star fungus', 'hen of the woods mushroom',
|
||||
'bolete', 'corn cob', 'toilet paper')
|
||||
|
|
|
@ -5,6 +5,7 @@ if WITH_MULTIMODAL:
|
|||
from .blip import * # noqa: F401,F403
|
||||
from .blip2 import * # noqa: F401,F403
|
||||
from .chinese_clip import * # noqa: F401, F403
|
||||
from .clip import * # noqa: F401, F403
|
||||
from .flamingo import * # noqa: F401, F403
|
||||
from .llava import * # noqa: F401, F403
|
||||
from .minigpt4 import * # noqa: F401, F403
|
||||
|
@ -17,5 +18,6 @@ else:
|
|||
register_multimodal_placeholder([
|
||||
'Blip2Caption', 'Blip2Retrieval', 'Blip2VQA', 'BlipCaption',
|
||||
'BlipNLVR', 'BlipRetrieval', 'BlipGrounding', 'BlipVQA', 'Flamingo',
|
||||
'OFA', 'ChineseCLIP', 'MiniGPT4', 'Llava', 'Otter'
|
||||
'OFA', 'ChineseCLIP', 'MiniGPT4', 'Llava', 'Otter', 'CLIP',
|
||||
'CLIPZeroShot'
|
||||
], MODELS)
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from ..clip.clip import CLIP, CLIPZeroShot
|
||||
from ..clip.clip_transformer import CLIPProjection, CLIPTransformer
|
||||
|
||||
__all__ = ['CLIP', 'CLIPZeroShot', 'CLIPTransformer', 'CLIPProjection']
|
|
@ -0,0 +1,364 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import abstractmethod
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from mmengine.model import BaseModel
|
||||
from torch import nn
|
||||
|
||||
from mmpretrain.datasets.categories import (CIFAR100_CATEGORIES,
|
||||
IMAGENET_SIMPLE_CATEGORIES)
|
||||
from mmpretrain.registry import MODELS, TOKENIZER
|
||||
from mmpretrain.structures import DataSample
|
||||
from mmpretrain.utils import track_on_main_process
|
||||
from .utils import (OPENAI_CIFAR100_PROMPT, OPENAI_IMAGENET_PROMPT,
|
||||
OPENAI_IMAGENET_PROMPT_SUB)
|
||||
|
||||
CIFAR100_CATEGORIES = [' '.join(c.split('_')) for c in CIFAR100_CATEGORIES]
|
||||
PROTOTYPE_MAP = {
|
||||
'imagenet': IMAGENET_SIMPLE_CATEGORIES,
|
||||
'cifar100': CIFAR100_CATEGORIES,
|
||||
}
|
||||
PROMPT_MAP = {
|
||||
'openai_imagenet': OPENAI_IMAGENET_PROMPT,
|
||||
'openai_cifar100': OPENAI_CIFAR100_PROMPT,
|
||||
'vanilla': [lambda c: f'a photo of a {c}'],
|
||||
'openai_imagenet_sub': OPENAI_IMAGENET_PROMPT_SUB
|
||||
}
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
"""Subclass torch's LayerNorm to handle fp16."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward function."""
|
||||
orig_type = x.dtype
|
||||
ret = super().forward(x.type(torch.float32))
|
||||
return ret.type(orig_type)
|
||||
|
||||
|
||||
class CLIP(BaseModel):
|
||||
"""The implementation of `CLIP <https://arxiv.org/abs/2103.00020>`_.
|
||||
|
||||
Args:
|
||||
vision_backbone (dict): Config dict for vision backbone.
|
||||
text_backbone (dict): Config dict for text backbone.
|
||||
tokenizer (dict): Config dict for text tokenizer.
|
||||
proj_dim (int): Projection dimension for similarity computation.
|
||||
text_prototype (str): Text prototype, which can be a key in
|
||||
`PROTOTYPE_MAP` or list of text.
|
||||
text_prompt (str): The prompt for text prototype.
|
||||
Defaults to 'vanilla',which refers to "a photo of {cls}".
|
||||
context_length (int): The context length to use. Defaults to 77.
|
||||
data_preprocessor (Union[dict, nn.Module], optional): The config for
|
||||
preprocessing input data. If None or no specified type, it will use
|
||||
"MultiModalDataPreprocessor" as type.
|
||||
See :class:`MultiModalDataPreprocessor` for more details.
|
||||
Defaults to None.
|
||||
init_cfg (dict, optional): The config to control the initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
vision_backbone: dict,
|
||||
projection: dict,
|
||||
text_backbone: dict,
|
||||
tokenizer: dict,
|
||||
vocab_size: int,
|
||||
transformer_width: int,
|
||||
proj_dim: int,
|
||||
context_length: int = 77,
|
||||
data_preprocessor: Optional[dict] = None,
|
||||
init_cfg: Optional[dict] = None):
|
||||
if data_preprocessor is None:
|
||||
data_preprocessor = {}
|
||||
data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
|
||||
data_preprocessor = MODELS.build(data_preprocessor)
|
||||
|
||||
super().__init__(
|
||||
data_preprocessor=data_preprocessor, init_cfg=init_cfg)
|
||||
|
||||
self.context_length = context_length
|
||||
|
||||
# build the vision transformer
|
||||
self.visual = MODELS.build(vision_backbone)
|
||||
|
||||
# build the visual projection
|
||||
self.visual_proj = MODELS.build(projection)
|
||||
|
||||
# build attn_mask for casual-attn
|
||||
text_backbone['attn_mask'] = self.build_attention_mask()
|
||||
|
||||
# build the text transformer
|
||||
self.transformer = MODELS.build(text_backbone)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
||||
self.positional_embedding = nn.Parameter(
|
||||
torch.empty(self.context_length, transformer_width))
|
||||
self.ln_final = LayerNorm(transformer_width)
|
||||
|
||||
self.text_projection = nn.Parameter(
|
||||
torch.empty(transformer_width, proj_dim))
|
||||
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
||||
|
||||
self.initialize_parameters()
|
||||
|
||||
self.tokenizer = TOKENIZER.build(tokenizer)
|
||||
|
||||
self.tokenizer.vocab = self.tokenizer.get_vocab(
|
||||
) # CLIPTokenizer has no attribute named 'vocab', so manually
|
||||
|
||||
def initialize_parameters(self) -> None:
|
||||
"""Initialize the parameters.
|
||||
|
||||
The pretrained weight will override the initialized parameters by this
|
||||
function.
|
||||
"""
|
||||
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
||||
nn.init.normal_(self.positional_embedding, std=0.01)
|
||||
|
||||
proj_std = (self.transformer.width**-0.5) * (
|
||||
(2 * self.transformer.layers)**-0.5)
|
||||
attn_std = self.transformer.width**-0.5
|
||||
fc_std = (2 * self.transformer.width)**-0.5
|
||||
for block in self.transformer.resblocks:
|
||||
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
||||
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
||||
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
||||
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
||||
|
||||
if self.text_projection is not None:
|
||||
nn.init.normal_(
|
||||
self.text_projection, std=self.transformer.width**-0.5)
|
||||
|
||||
def build_attention_mask(self):
|
||||
# lazily create causal attention mask,
|
||||
# with full attention between the vision tokens
|
||||
# pytorch uses additive attention mask; fill with -inf
|
||||
mask = torch.empty(self.context_length, self.context_length)
|
||||
mask.fill_(float('-inf'))
|
||||
mask.triu_(1) # zero out the lower diagonal
|
||||
return mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
images: torch.Tensor,
|
||||
data_samples: Optional[list] = None,
|
||||
mode: str = 'predict',
|
||||
**kwargs,
|
||||
):
|
||||
"""The unified entry for a forward process in both training and test.
|
||||
The method accepts the following modes:
|
||||
|
||||
- "predict": Forward and return a list of data samples contain the
|
||||
predict results.
|
||||
|
||||
Args:
|
||||
images (torch.Tensor): the preprocessed image tensor of shape
|
||||
``(N, C, H, W)``.
|
||||
data_samples (List[DataSample], optional): The annotation data
|
||||
of every samples. Defaults to None.
|
||||
mode (str): Return what kind of value. Defaults to 'predict'.
|
||||
"""
|
||||
if mode == 'predict':
|
||||
return self.predict(images, data_samples, **kwargs)
|
||||
else:
|
||||
raise RuntimeError(f'Invalid mode "{mode}".')
|
||||
|
||||
def extract_image_feat(self, images: torch.Tensor) -> torch.Tensor:
|
||||
"""The function to extract image latent features."""
|
||||
return self.visual_proj(self.visual(images))[0]
|
||||
|
||||
def extract_text_feat(self, texts: torch.Tensor) -> torch.Tensor:
|
||||
"""The function to extract text latent features."""
|
||||
x = self.token_embedding(texts) # [batch_size, n_ctx, d_model]
|
||||
|
||||
x = x + self.positional_embedding
|
||||
x = x.permute(1, 0, 2) # NLD -> LND
|
||||
x = self.transformer(x)[0]
|
||||
|
||||
x = x.permute(1, 0, 2) # LND -> NLD
|
||||
x = self.ln_final(x)
|
||||
|
||||
# x.shape = [batch_size, n_ctx, transformer.width]
|
||||
# take features from the eot embedding
|
||||
# (eot_token is the highest number in each sequence)
|
||||
x = x[torch.arange(x.shape[0]),
|
||||
texts.argmax(dim=-1)] @ self.text_projection
|
||||
|
||||
return x
|
||||
|
||||
def extract_feat(
|
||||
self, images: torch.Tensor,
|
||||
texts: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
|
||||
"""The function to extract image and text latent features, the input
|
||||
image or text can not both be None."""
|
||||
|
||||
assert images is not None or texts is not None, \
|
||||
'text and image cannot both be None!'
|
||||
if images is None:
|
||||
return self.extract_text_feat(texts)
|
||||
elif texts is None:
|
||||
return self.extract_image_feat(images)
|
||||
|
||||
image_features = self.extract_image_feat(images)
|
||||
text_features = self.extract_text_feat(texts)
|
||||
|
||||
image_features = image_features / image_features.norm(
|
||||
dim=-1, keepdim=True)
|
||||
text_features = text_features / text_features.norm(
|
||||
dim=-1, keepdim=True)
|
||||
|
||||
return image_features, text_features
|
||||
|
||||
def compute_similarity(self, images, texts):
|
||||
"""Extract images and texts features and compute cosine similarity."""
|
||||
image_features, text_features = self.extract_feat(
|
||||
images=images, texts=texts)
|
||||
|
||||
# cosine similarity as logits
|
||||
logit_scale = self.logit_scale.exp()
|
||||
logits_per_image = logit_scale * image_features @ text_features.t()
|
||||
logits_per_text = logits_per_image.t()
|
||||
|
||||
# shape (N, N)
|
||||
return logits_per_image, logits_per_text
|
||||
|
||||
@abstractmethod
|
||||
def predict(self,
|
||||
images: torch.Tensor,
|
||||
data_samples: DataSample = None) -> DataSample:
|
||||
raise NotImplementedError
|
||||
|
||||
def tokenize(self, texts: Union[str, List[str]]) -> torch.LongTensor:
|
||||
"""Returns the tokenized representation of given input string(s)
|
||||
|
||||
Args:
|
||||
texts (Union[str, List[str]]): An input string or a list of input
|
||||
strings to tokenize
|
||||
context_length (int): The context length to use. Defaults to 52.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Resulting tokens.
|
||||
"""
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
|
||||
all_tokens = []
|
||||
for text in texts:
|
||||
# adapt the text to Chinese BERT vocab
|
||||
# text = text.lower().replace('“', "\"").replace('”', "\"")
|
||||
|
||||
# add special tokens
|
||||
all_tokens.append(
|
||||
[self.tokenizer.vocab['<|startoftext|>']
|
||||
] + # <|startoftext|>代表[CLS] token
|
||||
self.tokenizer.convert_tokens_to_ids(
|
||||
self.tokenizer.tokenize(text))[:self.context_length - 2] +
|
||||
[self.tokenizer.vocab['<|endoftext|>']])
|
||||
|
||||
result = torch.zeros(
|
||||
len(all_tokens), self.context_length, dtype=torch.long)
|
||||
|
||||
for i, tokens in enumerate(all_tokens):
|
||||
assert len(tokens) <= self.context_length
|
||||
result[i, :len(tokens)] = torch.tensor(tokens)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class CLIPZeroShot(CLIP):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_backbone: dict,
|
||||
projection: dict,
|
||||
text_backbone: dict,
|
||||
tokenizer: dict,
|
||||
vocab_size: int,
|
||||
transformer_width: int,
|
||||
proj_dim: int,
|
||||
context_length: int = 77,
|
||||
data_preprocessor: Optional[dict] = None,
|
||||
init_cfg: Optional[dict] = None,
|
||||
text_prototype: Union[str, List[str]] = 'imagenet',
|
||||
text_prompt: str = 'vanilla',
|
||||
):
|
||||
super(CLIPZeroShot,
|
||||
self).__init__(vision_backbone, projection, text_backbone,
|
||||
tokenizer, vocab_size, transformer_width,
|
||||
proj_dim, context_length, data_preprocessor,
|
||||
init_cfg)
|
||||
|
||||
# for zero-shot classification
|
||||
if isinstance(text_prototype,
|
||||
str) and text_prototype in PROTOTYPE_MAP.keys():
|
||||
self.prototype = PROTOTYPE_MAP[text_prototype]
|
||||
else:
|
||||
self.prototype = text_prototype
|
||||
self.text_prototype_embeds = None
|
||||
|
||||
self.prompt = PROMPT_MAP[text_prompt]
|
||||
|
||||
def predict(self,
|
||||
images: torch.Tensor,
|
||||
data_samples: DataSample = None) -> DataSample:
|
||||
"""Predict the classes of the input images.
|
||||
|
||||
The prediction is for zero-shot classification and the text prototypes
|
||||
will be prepared in thisfunction.
|
||||
|
||||
Args:
|
||||
images (torch.Tensor): The input images.
|
||||
data_samples (DataSample): The data samples with information from
|
||||
dataset.
|
||||
|
||||
Returns:
|
||||
DataSample: The results of prediction.
|
||||
"""
|
||||
|
||||
if self.text_prototype_embeds is None:
|
||||
self.prepare_text_prototype(device=images.device)
|
||||
|
||||
image_features = self.extract_image_feat(images=images)
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
# cosine similarity as logits
|
||||
logits_per_image = image_features @ self.text_prototype_embeds.to(
|
||||
image_features.device) * self.logit_scale.exp()
|
||||
|
||||
pred_scores = F.softmax(logits_per_image, dim=1)
|
||||
pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach()
|
||||
|
||||
out_data_samples = []
|
||||
if data_samples is None:
|
||||
data_samples = [None for _ in range(pred_scores.size(0))]
|
||||
|
||||
for data_sample, score, label in zip(data_samples, pred_scores,
|
||||
pred_labels):
|
||||
if data_sample is None:
|
||||
data_sample = DataSample()
|
||||
|
||||
data_sample.set_pred_score(score).set_pred_label(label)
|
||||
out_data_samples.append(data_sample)
|
||||
return out_data_samples
|
||||
|
||||
def prepare_text_prototype(self, device) -> None:
|
||||
"""The function to prepare text prototypes with prompt."""
|
||||
class_embeddings = []
|
||||
for classname in track_on_main_process(self.prototype,
|
||||
'Prepare text prototype...'):
|
||||
# format with class
|
||||
texts = [prompt(classname) for prompt in self.prompt]
|
||||
tokenized_texts = self.tokenize(texts)
|
||||
class_features = self.extract_text_feat(tokenized_texts.to(device))
|
||||
class_features /= class_features.norm(dim=-1, keepdim=True)
|
||||
class_feature = class_features.mean(dim=0)
|
||||
class_feature /= class_feature.norm()
|
||||
class_embeddings.append(class_feature)
|
||||
self.text_prototype_embeds = torch.stack(
|
||||
class_embeddings, dim=1).to(device)
|
|
@ -0,0 +1,99 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# Modified from https://github.com/zejiangh/MILAN
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from mmengine.model import BaseModule
|
||||
from torch import nn
|
||||
|
||||
from mmpretrain.models.utils.clip_generator_helper import \
|
||||
ResidualAttentionBlock
|
||||
from mmpretrain.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class CLIPTransformer(nn.Module):
|
||||
"""Transformer.
|
||||
|
||||
Both visual and text branches use this transformer.
|
||||
|
||||
Args:
|
||||
width (int): The feature dimension.
|
||||
layers (int): The number of layers.
|
||||
heads (int): The number of attention heads.
|
||||
attn_mask (torch.Tensor, optional): The attention mask.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
width: int,
|
||||
layers: int,
|
||||
heads: int,
|
||||
attn_mask: Optional[torch.Tensor] = None) -> None:
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.layers = layers
|
||||
self.resblocks = nn.ModuleList()
|
||||
for _ in range(layers - 1):
|
||||
self.resblocks.append(
|
||||
ResidualAttentionBlock(width, heads, attn_mask))
|
||||
self.resblocks.append(
|
||||
ResidualAttentionBlock(
|
||||
width, heads, attn_mask, return_attention=True))
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Forward function."""
|
||||
z = []
|
||||
for idx, blk in enumerate(self.resblocks):
|
||||
if idx < self.layers - 1:
|
||||
x = blk(x)
|
||||
z.append(x.permute(1, 0, 2))
|
||||
else:
|
||||
x, attention = blk(x)
|
||||
z.append(x.permute(1, 0, 2))
|
||||
return x, attention, z
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class CLIPProjection(BaseModule):
|
||||
"""Neck with CLIP Projection.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of channels in the input.
|
||||
out_channels (int): Number of channels in the output.
|
||||
init_cfg (dict | list[dict], optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
init_cfg: Optional[dict] = None):
|
||||
super(CLIPProjection, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
scale = in_channels**-0.5
|
||||
self.proj = nn.Parameter(scale *
|
||||
torch.randn(in_channels, out_channels))
|
||||
|
||||
def forward(self, inputs: Tuple) -> Tuple[torch.Tensor]:
|
||||
"""forward function.
|
||||
|
||||
Args:
|
||||
inputs (Tuple): The features extracted from
|
||||
the backbone. Multiple stage inputs are acceptable but only
|
||||
the last stage will be used.
|
||||
Returns:
|
||||
Tuple(torch.Tensor)): A tuple of reducted features.
|
||||
"""
|
||||
if isinstance(inputs, tuple):
|
||||
inputs = inputs[-1]
|
||||
out = inputs @ self.proj
|
||||
elif isinstance(inputs, torch.Tensor):
|
||||
out = inputs @ self.proj
|
||||
else:
|
||||
raise TypeError(
|
||||
'`CLIPProjection` neck inputs should be tuple or torch.tensor')
|
||||
return (out, )
|
|
@ -0,0 +1,115 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
OPENAI_CIFAR100_PROMPT = [
|
||||
lambda c: f'a photo of a {c}.',
|
||||
lambda c: f'a blurry photo of a {c}.',
|
||||
lambda c: f'a black and white photo of a {c}.',
|
||||
lambda c: f'a low contrast photo of a {c}.',
|
||||
lambda c: f'a high contrast photo of a {c}.',
|
||||
lambda c: f'a bad photo of a {c}.',
|
||||
lambda c: f'a good photo of a {c}.',
|
||||
lambda c: f'a photo of a small {c}.',
|
||||
lambda c: f'a photo of a big {c}.',
|
||||
lambda c: f'a photo of the {c}.',
|
||||
lambda c: f'a blurry photo of the {c}.',
|
||||
lambda c: f'a black and white photo of the {c}.',
|
||||
lambda c: f'a low contrast photo of the {c}.',
|
||||
lambda c: f'a high contrast photo of the {c}.',
|
||||
lambda c: f'a bad photo of the {c}.',
|
||||
lambda c: f'a good photo of the {c}.',
|
||||
lambda c: f'a photo of the small {c}.',
|
||||
lambda c: f'a photo of the big {c}.',
|
||||
]
|
||||
|
||||
OPENAI_IMAGENET_PROMPT_SUB = [
|
||||
lambda c: f'itap of a {c}.',
|
||||
lambda c: f'a bad photo of the {c}.',
|
||||
lambda c: f'a origami {c}.',
|
||||
lambda c: f'a photo of the large {c}.',
|
||||
lambda c: f'a {c} in a video game.',
|
||||
lambda c: f'art of the {c}.',
|
||||
lambda c: f'a photo of the small {c}.',
|
||||
]
|
||||
|
||||
OPENAI_IMAGENET_PROMPT = [
|
||||
lambda c: f'a bad photo of a {c}.',
|
||||
lambda c: f'a photo of many {c}.',
|
||||
lambda c: f'a sculpture of a {c}.',
|
||||
lambda c: f'a photo of the hard to see {c}.',
|
||||
lambda c: f'a low resolution photo of the {c}.',
|
||||
lambda c: f'a rendering of a {c}.',
|
||||
lambda c: f'graffiti of a {c}.',
|
||||
lambda c: f'a bad photo of the {c}.',
|
||||
lambda c: f'a cropped photo of the {c}.',
|
||||
lambda c: f'a tattoo of a {c}.',
|
||||
lambda c: f'the embroidered {c}.',
|
||||
lambda c: f'a photo of a hard to see {c}.',
|
||||
lambda c: f'a bright photo of a {c}.',
|
||||
lambda c: f'a photo of a clean {c}.',
|
||||
lambda c: f'a photo of a dirty {c}.',
|
||||
lambda c: f'a dark photo of the {c}.',
|
||||
lambda c: f'a drawing of a {c}.',
|
||||
lambda c: f'a photo of my {c}.',
|
||||
lambda c: f'the plastic {c}.',
|
||||
lambda c: f'a photo of the cool {c}.',
|
||||
lambda c: f'a close-up photo of a {c}.',
|
||||
lambda c: f'a black and white photo of the {c}.',
|
||||
lambda c: f'a painting of the {c}.',
|
||||
lambda c: f'a painting of a {c}.',
|
||||
lambda c: f'a pixelated photo of the {c}.',
|
||||
lambda c: f'a sculpture of the {c}.',
|
||||
lambda c: f'a bright photo of the {c}.',
|
||||
lambda c: f'a cropped photo of a {c}.',
|
||||
lambda c: f'a plastic {c}.',
|
||||
lambda c: f'a photo of the dirty {c}.',
|
||||
lambda c: f'a jpeg corrupted photo of a {c}.',
|
||||
lambda c: f'a blurry photo of the {c}.',
|
||||
lambda c: f'a photo of the {c}.',
|
||||
lambda c: f'a good photo of the {c}.',
|
||||
lambda c: f'a rendering of the {c}.',
|
||||
lambda c: f'a {c} in a video game.',
|
||||
lambda c: f'a photo of one {c}.',
|
||||
lambda c: f'a doodle of a {c}.',
|
||||
lambda c: f'a close-up photo of the {c}.',
|
||||
lambda c: f'a photo of a {c}.',
|
||||
lambda c: f'the origami {c}.',
|
||||
lambda c: f'the {c} in a video game.',
|
||||
lambda c: f'a sketch of a {c}.',
|
||||
lambda c: f'a doodle of the {c}.',
|
||||
lambda c: f'a origami {c}.',
|
||||
lambda c: f'a low resolution photo of a {c}.',
|
||||
lambda c: f'the toy {c}.',
|
||||
lambda c: f'a rendition of the {c}.',
|
||||
lambda c: f'a photo of the clean {c}.',
|
||||
lambda c: f'a photo of a large {c}.',
|
||||
lambda c: f'a rendition of a {c}.',
|
||||
lambda c: f'a photo of a nice {c}.',
|
||||
lambda c: f'a photo of a weird {c}.',
|
||||
lambda c: f'a blurry photo of a {c}.',
|
||||
lambda c: f'a cartoon {c}.',
|
||||
lambda c: f'art of a {c}.',
|
||||
lambda c: f'a sketch of the {c}.',
|
||||
lambda c: f'a embroidered {c}.',
|
||||
lambda c: f'a pixelated photo of a {c}.',
|
||||
lambda c: f'itap of the {c}.',
|
||||
lambda c: f'a jpeg corrupted photo of the {c}.',
|
||||
lambda c: f'a good photo of a {c}.',
|
||||
lambda c: f'a plushie {c}.',
|
||||
lambda c: f'a photo of the nice {c}.',
|
||||
lambda c: f'a photo of the small {c}.',
|
||||
lambda c: f'a photo of the weird {c}.',
|
||||
lambda c: f'the cartoon {c}.',
|
||||
lambda c: f'art of the {c}.',
|
||||
lambda c: f'a drawing of the {c}.',
|
||||
lambda c: f'a photo of the large {c}.',
|
||||
lambda c: f'a black and white photo of a {c}.',
|
||||
lambda c: f'the plushie {c}.',
|
||||
lambda c: f'a dark photo of a {c}.',
|
||||
lambda c: f'itap of a {c}.',
|
||||
lambda c: f'graffiti of the {c}.',
|
||||
lambda c: f'a toy {c}.',
|
||||
lambda c: f'itap of my {c}.',
|
||||
lambda c: f'a photo of a cool {c}.',
|
||||
lambda c: f'a photo of a small {c}.',
|
||||
lambda c: f'a tattoo of the {c}.',
|
||||
]
|
|
@ -0,0 +1,77 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
import os.path as osp
|
||||
from collections import OrderedDict
|
||||
|
||||
import mmengine
|
||||
import torch
|
||||
from mmengine.runner import CheckpointLoader
|
||||
|
||||
|
||||
def convert_clip(ckpt):
|
||||
new_ckpt = OrderedDict()
|
||||
|
||||
for k, v in list(ckpt.items()):
|
||||
new_v = v
|
||||
if k.startswith('visual.conv1'):
|
||||
new_k = k.replace('conv1', 'patch_embed.projection')
|
||||
elif k.startswith('visual.positional_embedding'):
|
||||
new_k = k.replace('positional_embedding', 'pos_embed')
|
||||
new_v = v.unsqueeze(dim=0)
|
||||
elif k.startswith('visual.class_embedding'):
|
||||
new_k = k.replace('class_embedding', 'cls_token')
|
||||
new_v = v.unsqueeze(dim=0).unsqueeze(dim=0)
|
||||
elif k.startswith('visual.ln_pre'):
|
||||
new_k = k.replace('ln_pre', 'pre_norm')
|
||||
elif k.startswith('visual.transformer.resblocks'):
|
||||
new_k = k.replace('transformer.resblocks', 'layers')
|
||||
if 'ln_1' in k:
|
||||
new_k = new_k.replace('ln_1', 'ln1')
|
||||
elif 'ln_2' in k:
|
||||
new_k = new_k.replace('ln_2', 'ln2')
|
||||
elif 'mlp.c_fc' in k:
|
||||
new_k = new_k.replace('mlp.c_fc', 'ffn.layers.0.0')
|
||||
elif 'mlp.c_proj' in k:
|
||||
new_k = new_k.replace('mlp.c_proj', 'ffn.layers.1')
|
||||
elif 'attn.in_proj_weight' in k:
|
||||
new_k = new_k.replace('in_proj_weight', 'qkv.weight')
|
||||
elif 'attn.in_proj_bias' in k:
|
||||
new_k = new_k.replace('in_proj_bias', 'qkv.bias')
|
||||
elif 'attn.out_proj' in k:
|
||||
new_k = new_k.replace('out_proj', 'proj')
|
||||
elif k.startswith('visual.ln_post'):
|
||||
new_k = k.replace('ln_post', 'ln1')
|
||||
elif k.startswith('visual.proj'):
|
||||
new_k = k.replace('visual.proj', 'visual_proj.proj')
|
||||
else:
|
||||
new_k = k
|
||||
|
||||
new_ckpt[new_k] = new_v
|
||||
return new_ckpt
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Convert keys in pretrained clip '
|
||||
'models to mmpretrain style.')
|
||||
parser.add_argument('src', help='src model path or url')
|
||||
# The dst path must be a full path of the new checkpoint.
|
||||
parser.add_argument('dst', help='save path')
|
||||
args = parser.parse_args()
|
||||
|
||||
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
|
||||
|
||||
if 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
|
||||
weight = convert_clip(state_dict)
|
||||
mmengine.mkdir_or_exist(osp.dirname(args.dst))
|
||||
torch.save(weight, args.dst)
|
||||
|
||||
print('Done!!')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in New Issue