Merge remote-tracking branch 'origin/main' into dev
commit
d35c778a6f
|
@ -0,0 +1,68 @@
|
|||
_base_ = '../_base_/default_runtime.py'
|
||||
|
||||
# data settings
|
||||
data_preprocessor = dict(
|
||||
type='MultiModalDataPreprocessor',
|
||||
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
|
||||
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
|
||||
to_rgb=False,
|
||||
)
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='Resize', scale=(224, 224), interpolation='bicubic'),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['text'],
|
||||
meta_keys=['image_id', 'scale_factor'],
|
||||
),
|
||||
]
|
||||
|
||||
train_dataloader = None
|
||||
test_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=8,
|
||||
dataset=dict(
|
||||
type='CIFAR100',
|
||||
data_root='data/cifar100',
|
||||
split='test',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
test_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
# schedule settings
|
||||
train_cfg = None
|
||||
val_cfg = None
|
||||
test_cfg = dict()
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type='CLIPZeroShot',
|
||||
vision_backbone=dict(
|
||||
type='VisionTransformer',
|
||||
arch='base',
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
drop_rate=0.,
|
||||
layer_cfgs=dict(act_cfg=dict(type='QuickGELU')),
|
||||
pre_norm=True,
|
||||
),
|
||||
projection=dict(type='CLIPProjection', in_channels=768, out_channels=512),
|
||||
text_backbone=dict(
|
||||
type='CLIPTransformer',
|
||||
width=512,
|
||||
layers=12,
|
||||
heads=8,
|
||||
attn_mask=True,
|
||||
),
|
||||
tokenizer=dict(
|
||||
type='AutoTokenizer',
|
||||
name_or_path='openai/clip-vit-base-patch16',
|
||||
use_fast=False),
|
||||
vocab_size=49408,
|
||||
transformer_width=512,
|
||||
proj_dim=512,
|
||||
text_prototype='cifar100',
|
||||
text_prompt='openai_cifar100',
|
||||
context_length=77,
|
||||
)
|
|
@ -0,0 +1,69 @@
|
|||
_base_ = '../_base_/default_runtime.py'
|
||||
|
||||
# data settings
|
||||
data_preprocessor = dict(
|
||||
type='MultiModalDataPreprocessor',
|
||||
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
|
||||
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=(224, 224), interpolation='bicubic'),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['text'],
|
||||
meta_keys=['image_id', 'scale_factor'],
|
||||
),
|
||||
]
|
||||
|
||||
train_dataloader = None
|
||||
test_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=8,
|
||||
dataset=dict(
|
||||
type='ImageNet',
|
||||
data_root='data/imagenet',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
test_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
# schedule settings
|
||||
train_cfg = None
|
||||
val_cfg = None
|
||||
test_cfg = dict()
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type='CLIPZeroShot',
|
||||
vision_backbone=dict(
|
||||
type='VisionTransformer',
|
||||
arch='base',
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
drop_rate=0.,
|
||||
layer_cfgs=dict(act_cfg=dict(type='QuickGELU')),
|
||||
pre_norm=True,
|
||||
),
|
||||
projection=dict(type='CLIPProjection', in_channels=768, out_channels=512),
|
||||
text_backbone=dict(
|
||||
type='CLIPTransformer',
|
||||
width=512,
|
||||
layers=12,
|
||||
heads=8,
|
||||
attn_mask=True,
|
||||
),
|
||||
tokenizer=dict(
|
||||
type='AutoTokenizer',
|
||||
name_or_path='openai/clip-vit-base-patch16',
|
||||
use_fast=False),
|
||||
vocab_size=49408,
|
||||
transformer_width=512,
|
||||
proj_dim=512,
|
||||
text_prototype='imagenet',
|
||||
text_prompt='openai_imagenet_sub', # openai_imagenet, openai_imagenet_sub
|
||||
context_length=77,
|
||||
)
|
|
@ -0,0 +1,68 @@
|
|||
_base_ = '../_base_/default_runtime.py'
|
||||
|
||||
# data settings
|
||||
data_preprocessor = dict(
|
||||
type='MultiModalDataPreprocessor',
|
||||
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
|
||||
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
|
||||
to_rgb=False,
|
||||
)
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='Resize', scale=(224, 224), interpolation='bicubic'),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['text'],
|
||||
meta_keys=['image_id', 'scale_factor'],
|
||||
),
|
||||
]
|
||||
|
||||
train_dataloader = None
|
||||
test_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=8,
|
||||
dataset=dict(
|
||||
type='CIFAR100',
|
||||
data_root='data/cifar100',
|
||||
split='test',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
test_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
# schedule settings
|
||||
train_cfg = None
|
||||
val_cfg = None
|
||||
test_cfg = dict()
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type='CLIPZeroShot',
|
||||
vision_backbone=dict(
|
||||
type='VisionTransformer',
|
||||
arch='large',
|
||||
img_size=224,
|
||||
patch_size=14,
|
||||
drop_rate=0.,
|
||||
layer_cfgs=dict(act_cfg=dict(type='QuickGELU')),
|
||||
pre_norm=True,
|
||||
),
|
||||
projection=dict(type='CLIPProjection', in_channels=1024, out_channels=768),
|
||||
text_backbone=dict(
|
||||
type='CLIPTransformer',
|
||||
width=768,
|
||||
layers=12,
|
||||
heads=12,
|
||||
attn_mask=True,
|
||||
),
|
||||
tokenizer=dict(
|
||||
type='AutoTokenizer',
|
||||
name_or_path='openai/clip-vit-large-patch14',
|
||||
use_fast=False),
|
||||
vocab_size=49408,
|
||||
transformer_width=768,
|
||||
proj_dim=768,
|
||||
text_prototype='cifar100',
|
||||
text_prompt='openai_cifar100',
|
||||
context_length=77,
|
||||
)
|
|
@ -0,0 +1,69 @@
|
|||
_base_ = '../_base_/default_runtime.py'
|
||||
|
||||
# data settings
|
||||
data_preprocessor = dict(
|
||||
type='MultiModalDataPreprocessor',
|
||||
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
|
||||
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=(224, 224), interpolation='bicubic'),
|
||||
dict(
|
||||
type='PackInputs',
|
||||
algorithm_keys=['text'],
|
||||
meta_keys=['image_id', 'scale_factor'],
|
||||
),
|
||||
]
|
||||
|
||||
train_dataloader = None
|
||||
test_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=8,
|
||||
dataset=dict(
|
||||
type='ImageNet',
|
||||
data_root='data/imagenet',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
test_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
# schedule settings
|
||||
train_cfg = None
|
||||
val_cfg = None
|
||||
test_cfg = dict()
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type='CLIPZeroShot',
|
||||
vision_backbone=dict(
|
||||
type='VisionTransformer',
|
||||
arch='large',
|
||||
img_size=224,
|
||||
patch_size=14,
|
||||
drop_rate=0.,
|
||||
layer_cfgs=dict(act_cfg=dict(type='QuickGELU')),
|
||||
pre_norm=True,
|
||||
),
|
||||
projection=dict(type='CLIPProjection', in_channels=1024, out_channels=768),
|
||||
text_backbone=dict(
|
||||
type='CLIPTransformer',
|
||||
width=768,
|
||||
layers=12,
|
||||
heads=12,
|
||||
attn_mask=True,
|
||||
),
|
||||
tokenizer=dict(
|
||||
type='AutoTokenizer',
|
||||
name_or_path='openai/clip-vit-large-patch14',
|
||||
use_fast=False),
|
||||
vocab_size=49408,
|
||||
transformer_width=768,
|
||||
proj_dim=768,
|
||||
text_prototype='imagenet',
|
||||
text_prompt='openai_imagenet_sub', # openai_imagenet, openai_imagenet_sub
|
||||
context_length=77,
|
||||
)
|
|
@ -108,6 +108,7 @@ class ImageRetrievalInferencer(BaseInferencer):
|
|||
# A config of dataset
|
||||
from mmpretrain.registry import DATASETS
|
||||
test_pipeline = [dict(type='LoadImageFromFile'), self.pipeline]
|
||||
prototype.setdefault('pipeline', test_pipeline)
|
||||
dataset = DATASETS.build(prototype)
|
||||
dataloader = build_dataloader(dataset)
|
||||
elif isinstance(prototype, DataLoader):
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.dataset import DefaultSampler
|
||||
|
||||
from mmpretrain.datasets import CIFAR10, PackInputs, RandomCrop, RandomFlip
|
||||
from mmpretrain.evaluation import Accuracy
|
||||
|
||||
# dataset settings
|
||||
dataset_type = CIFAR10
|
||||
data_preprocessor = dict(
|
||||
num_classes=10,
|
||||
# RGB format normalization parameters
|
||||
mean=[125.307, 122.961, 113.8575],
|
||||
std=[51.5865, 50.847, 51.255],
|
||||
# loaded images are already RGB format
|
||||
to_rgb=False)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type=RandomCrop, crop_size=32, padding=4),
|
||||
dict(type=RandomFlip, prob=0.5, direction='horizontal'),
|
||||
dict(type=PackInputs),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type=PackInputs),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=2,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/cifar10',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type=DefaultSampler, shuffle=True),
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=2,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/cifar10/',
|
||||
split='test',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type=DefaultSampler, shuffle=False),
|
||||
)
|
||||
val_evaluator = dict(type=Accuracy, topk=(1, ))
|
||||
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -0,0 +1,75 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.dataset import DefaultSampler
|
||||
|
||||
from mmpretrain.datasets import (AutoAugment, CenterCrop, ImageNet,
|
||||
LoadImageFromFile, PackInputs, RandomErasing,
|
||||
RandomFlip, RandomResizedCrop, ResizeEdge)
|
||||
from mmpretrain.evaluation import Accuracy
|
||||
|
||||
# dataset settings
|
||||
dataset_type = ImageNet
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
bgr_mean = data_preprocessor['mean'][::-1]
|
||||
bgr_std = data_preprocessor['std'][::-1]
|
||||
|
||||
train_pipeline = [
|
||||
dict(type=LoadImageFromFile),
|
||||
dict(type=RandomResizedCrop, scale=224, backend='pillow'),
|
||||
dict(type=RandomFlip, prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type=AutoAugment,
|
||||
policies='imagenet',
|
||||
hparams=dict(pad_val=[round(x) for x in bgr_mean])),
|
||||
dict(
|
||||
type=RandomErasing,
|
||||
erase_prob=0.2,
|
||||
mode='rand',
|
||||
min_area_ratio=0.02,
|
||||
max_area_ratio=1 / 3,
|
||||
fill_color=bgr_mean,
|
||||
fill_std=bgr_std),
|
||||
dict(type=PackInputs),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type=LoadImageFromFile),
|
||||
dict(type=ResizeEdge, scale=256, edge='short', backend='pillow'),
|
||||
dict(type=CenterCrop, crop_size=224),
|
||||
dict(type=PackInputs),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=128,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type=DefaultSampler, shuffle=True),
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=128,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type=DefaultSampler, shuffle=False),
|
||||
)
|
||||
val_evaluator = dict(type=Accuracy, topk=(1, 5))
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -0,0 +1,60 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.dataset import DefaultSampler
|
||||
|
||||
from mmpretrain.datasets import (CenterCrop, ImageNet, LoadImageFromFile,
|
||||
PackInputs, RandomFlip, RandomResizedCrop,
|
||||
ResizeEdge)
|
||||
from mmpretrain.evaluation import Accuracy
|
||||
|
||||
# dataset settings
|
||||
dataset_type = ImageNet
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type=LoadImageFromFile),
|
||||
dict(type=RandomResizedCrop, scale=224, backend='pillow'),
|
||||
dict(type=RandomFlip, prob=0.5, direction='horizontal'),
|
||||
dict(type=PackInputs),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type=LoadImageFromFile),
|
||||
dict(type=ResizeEdge, scale=256, edge='short', backend='pillow'),
|
||||
dict(type=CenterCrop, crop_size=224),
|
||||
dict(type=PackInputs),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type=DefaultSampler, shuffle=True),
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type=DefaultSampler, shuffle=False),
|
||||
)
|
||||
val_evaluator = dict(type=Accuracy, topk=(1, 5))
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -0,0 +1,17 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmpretrain.models import (CrossEntropyLoss, GlobalAveragePooling,
|
||||
ImageClassifier, LinearClsHead, MobileNetV2)
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type=ImageClassifier,
|
||||
backbone=dict(type=MobileNetV2, widen_factor=1.0),
|
||||
neck=dict(type=GlobalAveragePooling),
|
||||
head=dict(
|
||||
type=LinearClsHead,
|
||||
num_classes=1000,
|
||||
in_channels=1280,
|
||||
loss=dict(type=CrossEntropyLoss, loss_weight=1.0),
|
||||
topk=(1, 5),
|
||||
))
|
|
@ -0,0 +1,25 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.model.weight_init import NormalInit
|
||||
from torch.nn.modules.activation import Hardswish
|
||||
|
||||
from mmpretrain.models import (CrossEntropyLoss, GlobalAveragePooling,
|
||||
ImageClassifier, MobileNetV3,
|
||||
StackedLinearClsHead)
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type=ImageClassifier,
|
||||
backbone=dict(type=MobileNetV3, arch='small'),
|
||||
neck=dict(type=GlobalAveragePooling),
|
||||
head=dict(
|
||||
type=StackedLinearClsHead,
|
||||
num_classes=1000,
|
||||
in_channels=576,
|
||||
mid_channels=[1024],
|
||||
dropout_rate=0.2,
|
||||
act_cfg=dict(type=Hardswish),
|
||||
loss=dict(type=CrossEntropyLoss, loss_weight=1.0),
|
||||
init_cfg=dict(
|
||||
type=NormalInit, layer='Linear', mean=0., std=0.01, bias=0.),
|
||||
topk=(1, 5)))
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.optim import MultiStepLR
|
||||
from torch.optim import SGD
|
||||
|
||||
# optimizer
|
||||
optim_wrapper = dict(
|
||||
optimizer=dict(type=SGD, lr=0.1, momentum=0.9, weight_decay=0.0001))
|
||||
# learning policy
|
||||
param_scheduler = dict(
|
||||
type=MultiStepLR, by_epoch=True, milestones=[100, 150], gamma=0.1)
|
||||
|
||||
# train, val, test setting
|
||||
train_cfg = dict(by_epoch=True, max_epochs=200, val_interval=1)
|
||||
val_cfg = dict()
|
||||
test_cfg = dict()
|
||||
|
||||
# NOTE: `auto_scale_lr` is for automatically scaling LR
|
||||
# based on the actual training batch size.
|
||||
auto_scale_lr = dict(base_batch_size=128)
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.optim import StepLR
|
||||
from torch.optim import SGD
|
||||
|
||||
# optimizer
|
||||
optim_wrapper = dict(
|
||||
optimizer=dict(type=SGD, lr=0.045, momentum=0.9, weight_decay=0.00004))
|
||||
|
||||
# learning policy
|
||||
param_scheduler = dict(type=StepLR, by_epoch=True, step_size=1, gamma=0.98)
|
||||
|
||||
# train, val, test setting
|
||||
train_cfg = dict(by_epoch=True, max_epochs=300, val_interval=1)
|
||||
val_cfg = dict()
|
||||
test_cfg = dict()
|
||||
|
||||
# NOTE: `auto_scale_lr` is for automatically scaling LR,
|
||||
# based on the actual training batch size.
|
||||
auto_scale_lr = dict(base_batch_size=256)
|
|
@ -0,0 +1,9 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs32_pil_resize import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.mobilenet_v2_1x import *
|
||||
from .._base_.schedules.imagenet_bs256_epochstep import *
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
|
||||
# Refers to https://pytorch.org/blog/ml-models-torchvision-v0.9/#classification
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .._base_.models.mobilenet_v3_small import *
|
||||
from .._base_.datasets.imagenet_bs128_mbv3 import *
|
||||
from .._base_.default_runtime import *
|
||||
|
||||
from mmengine.optim import StepLR
|
||||
from torch.optim import RMSprop
|
||||
|
||||
# model settings
|
||||
model.merge(
|
||||
dict(
|
||||
backbone=dict(arch='large'),
|
||||
head=dict(in_channels=960, mid_channels=[1280]),
|
||||
))
|
||||
# schedule settings
|
||||
optim_wrapper = dict(
|
||||
optimizer=dict(
|
||||
type=RMSprop,
|
||||
lr=0.064,
|
||||
alpha=0.9,
|
||||
momentum=0.9,
|
||||
eps=0.0316,
|
||||
weight_decay=1e-5))
|
||||
|
||||
param_scheduler = dict(type=StepLR, by_epoch=True, step_size=2, gamma=0.973)
|
||||
|
||||
train_cfg = dict(by_epoch=True, max_epochs=600, val_interval=1)
|
||||
val_cfg = dict()
|
||||
test_cfg = dict()
|
||||
|
||||
# NOTE: `auto_scale_lr` is for automatically scaling LR
|
||||
# based on the actual training batch size.
|
||||
# base_batch_size = (8 GPUs) x (128 samples per GPU)
|
||||
auto_scale_lr = dict(base_batch_size=1024)
|
|
@ -0,0 +1,85 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
# Refers to https://pytorch.org/blog/ml-models-torchvision-v0.9/#classification
|
||||
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .._base_.models.mobilenet_v3_small import *
|
||||
from .._base_.datasets.imagenet_bs128_mbv3 import *
|
||||
from .._base_.default_runtime import *
|
||||
|
||||
from mmengine.optim import StepLR
|
||||
from torch.nn.modules.batchnorm import BatchNorm2d
|
||||
from torch.optim import RMSprop
|
||||
|
||||
# model settings
|
||||
model.merge(
|
||||
dict(
|
||||
backbone=dict(
|
||||
arch='small_050',
|
||||
norm_cfg=dict(type=BatchNorm2d, eps=1e-5, momentum=0.1)),
|
||||
head=dict(in_channels=288),
|
||||
))
|
||||
|
||||
train_pipeline = [
|
||||
dict(type=LoadImageFromFile),
|
||||
dict(
|
||||
type=RandomResizedCrop,
|
||||
scale=224,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type=RandomFlip, prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type=AutoAugment,
|
||||
policies='imagenet',
|
||||
hparams=dict(pad_val=[round(x) for x in [103.53, 116.28, 123.675]])),
|
||||
dict(
|
||||
type=RandomErasing,
|
||||
erase_prob=0.2,
|
||||
mode='rand',
|
||||
min_area_ratio=0.02,
|
||||
max_area_ratio=1 / 3,
|
||||
fill_color=[103.53, 116.28, 123.675],
|
||||
fill_std=[57.375, 57.12, 58.395]),
|
||||
dict(type=PackInputs),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type=LoadImageFromFile),
|
||||
dict(
|
||||
type=ResizeEdge,
|
||||
scale=256,
|
||||
edge='short',
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type=CenterCrop, crop_size=224),
|
||||
dict(type=PackInputs),
|
||||
]
|
||||
|
||||
train_dataloader.merge(dict(dataset=dict(pipeline=train_pipeline)))
|
||||
|
||||
val_dataloader.merge(dict(dataset=dict(pipeline=test_pipeline)))
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(
|
||||
optimizer=dict(
|
||||
type=RMSprop,
|
||||
lr=0.064,
|
||||
alpha=0.9,
|
||||
momentum=0.9,
|
||||
eps=0.0316,
|
||||
weight_decay=1e-5))
|
||||
|
||||
param_scheduler = dict(type=StepLR, by_epoch=True, step_size=2, gamma=0.973)
|
||||
|
||||
train_cfg = dict(by_epoch=True, max_epochs=600, val_interval=10)
|
||||
val_cfg = dict()
|
||||
test_cfg = dict()
|
||||
|
||||
# NOTE: `auto_scale_lr` is for automatically scaling LR
|
||||
# based on the actual training batch size.
|
||||
# base_batch_size = (8 GPUs) x (128 samples per GPU)
|
||||
auto_scale_lr = dict(base_batch_size=1024)
|
|
@ -0,0 +1,83 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
# Refers to https://pytorch.org/blog/ml-models-torchvision-v0.9/#classification
|
||||
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .._base_.models.mobilenet_v3_small import *
|
||||
from .._base_.datasets.imagenet_bs128_mbv3 import *
|
||||
from .._base_.default_runtime import *
|
||||
|
||||
from mmengine.optim import StepLR
|
||||
from torch.nn.modules.batchnorm import BatchNorm2d
|
||||
from torch.optim import RMSprop
|
||||
|
||||
# model settings
|
||||
model.merge(
|
||||
dict(
|
||||
backbone=dict(
|
||||
arch='small_075',
|
||||
norm_cfg=dict(type=BatchNorm2d, eps=1e-5, momentum=0.1)),
|
||||
head=dict(in_channels=432),
|
||||
))
|
||||
|
||||
train_pipeline = [
|
||||
dict(type=LoadImageFromFile),
|
||||
dict(
|
||||
type=RandomResizedCrop,
|
||||
scale=224,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type=RandomFlip, prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type=AutoAugment,
|
||||
policies='imagenet',
|
||||
hparams=dict(pad_val=[round(x) for x in [103.53, 116.28, 123.675]])),
|
||||
dict(
|
||||
type=RandomErasing,
|
||||
erase_prob=0.2,
|
||||
mode='rand',
|
||||
min_area_ratio=0.02,
|
||||
max_area_ratio=1 / 3,
|
||||
fill_color=[103.53, 116.28, 123.675],
|
||||
fill_std=[57.375, 57.12, 58.395]),
|
||||
dict(type=PackInputs),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type=LoadImageFromFile),
|
||||
dict(
|
||||
type=ResizeEdge,
|
||||
scale=256,
|
||||
edge='short',
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type=CenterCrop, crop_size=224),
|
||||
dict(type=PackInputs),
|
||||
]
|
||||
|
||||
train_dataloader.merge(dict(dataset=dict(pipeline=train_pipeline)))
|
||||
val_dataloader.merge(dict(dataset=dict(pipeline=test_pipeline)))
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(
|
||||
optimizer=dict(
|
||||
type=RMSprop,
|
||||
lr=0.064,
|
||||
alpha=0.9,
|
||||
momentum=0.9,
|
||||
eps=0.0316,
|
||||
weight_decay=1e-5))
|
||||
|
||||
param_scheduler = dict(type=StepLR, by_epoch=True, step_size=2, gamma=0.973)
|
||||
|
||||
train_cfg = dict(by_epoch=True, max_epochs=600, val_interval=10)
|
||||
val_cfg = dict()
|
||||
test_cfg = dict()
|
||||
|
||||
# NOTE: `auto_scale_lr` is for automatically scaling LR
|
||||
# based on the actual training batch size.
|
||||
# base_batch_size = (8 GPUs) x (128 samples per GPU)
|
||||
auto_scale_lr = dict(base_batch_size=1024)
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
# Refers to https://pytorch.org/blog/ml-models-torchvision-v0.9/#classification
|
||||
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .._base_.models.mobilenet_v3_small import *
|
||||
from .._base_.datasets.imagenet_bs128_mbv3 import *
|
||||
from .._base_.default_runtime import *
|
||||
|
||||
from mmengine.optim import StepLR
|
||||
from torch.optim import RMSprop
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(
|
||||
optimizer=dict(
|
||||
type=RMSprop,
|
||||
lr=0.064,
|
||||
alpha=0.9,
|
||||
momentum=0.9,
|
||||
eps=0.0316,
|
||||
weight_decay=1e-5))
|
||||
|
||||
param_scheduler = dict(type=StepLR, by_epoch=True, step_size=2, gamma=0.973)
|
||||
|
||||
train_cfg = dict(by_epoch=True, max_epochs=600, val_interval=1)
|
||||
val_cfg = dict()
|
||||
test_cfg = dict()
|
||||
|
||||
# NOTE: `auto_scale_lr` is for automatically scaling LR
|
||||
# based on the actual training batch size.
|
||||
# base_batch_size = (8 GPUs) x (128 samples per GPU)
|
||||
auto_scale_lr = dict(base_batch_size=1024)
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .._base_.models.mobilenet_v3_small import *
|
||||
from .._base_.datasets.cifar10_bs16 import *
|
||||
from .._base_.schedules.cifar10_bs128 import *
|
||||
from .._base_.default_runtime import *
|
||||
|
||||
from mmengine.optim import MultiStepLR
|
||||
|
||||
# model settings
|
||||
model.merge(
|
||||
dict(
|
||||
head=dict(
|
||||
_delete_=True,
|
||||
type=StackedLinearClsHead,
|
||||
num_classes=10,
|
||||
in_channels=576,
|
||||
mid_channels=[1280],
|
||||
act_cfg=dict(type=Hardswish),
|
||||
loss=dict(type=CrossEntropyLoss, loss_weight=1.0),
|
||||
topk=(1, 5))))
|
||||
# schedule settings
|
||||
param_scheduler.merge(
|
||||
dict(
|
||||
type=MultiStepLR,
|
||||
by_epoch=True,
|
||||
milestones=[120, 170],
|
||||
gamma=0.1,
|
||||
))
|
||||
|
||||
train_cfg.merge(dict(by_epoch=True, max_epochs=200))
|
|
@ -1438,3 +1438,224 @@ CIFAR100_CATEGORIES_CN = (
|
|||
'海豹', '鲨鱼', '尖嘴小鼠', '臭鼬', '摩天大楼', '蜗牛', '蛇', '蜘蛛', '松鼠', '电车', '向日葵', '甜椒',
|
||||
'桌子', '坦克', '电话', '电视', '老虎', '拖拉机', '火车', '鳟鱼', '郁金香', '乌龟', '衣柜', '鲸鱼',
|
||||
'柳树', '狼', '女人', '蠕虫')
|
||||
|
||||
IMAGENET_SIMPLE_CATEGORIES = (
|
||||
'tench', 'goldfish', 'great white shark', 'tiger shark',
|
||||
'hammerhead shark', 'electric ray', 'stingray', 'rooster', 'hen',
|
||||
'ostrich', 'brambling', 'goldfinch', 'house finch', 'junco',
|
||||
'indigo bunting', 'American robin', 'bulbul', 'jay', 'magpie', 'chickadee',
|
||||
'American dipper', 'kite (bird of prey)', 'bald eagle', 'vulture',
|
||||
'great grey owl', 'fire salamander', 'smooth newt', 'newt',
|
||||
'spotted salamander', 'axolotl', 'American bullfrog', 'tree frog',
|
||||
'tailed frog', 'loggerhead sea turtle', 'leatherback sea turtle',
|
||||
'mud turtle', 'terrapin', 'box turtle', 'banded gecko', 'green iguana',
|
||||
'Carolina anole', 'desert grassland whiptail lizard', 'agama',
|
||||
'frilled-necked lizard', 'alligator lizard', 'Gila monster',
|
||||
'European green lizard', 'chameleon', 'Komodo dragon', 'Nile crocodile',
|
||||
'American alligator', 'triceratops', 'worm snake', 'ring-necked snake',
|
||||
'eastern hog-nosed snake', 'smooth green snake', 'kingsnake',
|
||||
'garter snake', 'water snake', 'vine snake', 'night snake',
|
||||
'boa constrictor', 'African rock python', 'Indian cobra', 'green mamba',
|
||||
'sea snake', 'Saharan horned viper', 'eastern diamondback rattlesnake',
|
||||
'sidewinder rattlesnake', 'trilobite', 'harvestman', 'scorpion',
|
||||
'yellow garden spider', 'barn spider', 'European garden spider',
|
||||
'southern black widow', 'tarantula', 'wolf spider', 'tick', 'centipede',
|
||||
'black grouse', 'ptarmigan', 'ruffed grouse', 'prairie grouse', 'peafowl',
|
||||
'quail', 'partridge', 'african grey parrot', 'macaw',
|
||||
'sulphur-crested cockatoo', 'lorikeet', 'coucal', 'bee eater', 'hornbill',
|
||||
'hummingbird', 'jacamar', 'toucan', 'duck', 'red-breasted merganser',
|
||||
'goose', 'black swan', 'tusker', 'echidna', 'platypus', 'wallaby', 'koala',
|
||||
'wombat', 'jellyfish', 'sea anemone', 'brain coral', 'flatworm',
|
||||
'nematode', 'conch', 'snail', 'slug', 'sea slug', 'chiton',
|
||||
'chambered nautilus', 'Dungeness crab', 'rock crab', 'fiddler crab',
|
||||
'red king crab', 'American lobster', 'spiny lobster', 'crayfish',
|
||||
'hermit crab', 'isopod', 'white stork', 'black stork', 'spoonbill',
|
||||
'flamingo', 'little blue heron', 'great egret', 'bittern bird',
|
||||
'crane bird', 'limpkin', 'common gallinule', 'American coot', 'bustard',
|
||||
'ruddy turnstone', 'dunlin', 'common redshank', 'dowitcher',
|
||||
'oystercatcher', 'pelican', 'king penguin', 'albatross', 'grey whale',
|
||||
'killer whale', 'dugong', 'sea lion', 'Chihuahua', 'Japanese Chin',
|
||||
'Maltese', 'Pekingese', 'Shih Tzu', 'King Charles Spaniel', 'Papillon',
|
||||
'toy terrier', 'Rhodesian Ridgeback', 'Afghan Hound', 'Basset Hound',
|
||||
'Beagle', 'Bloodhound', 'Bluetick Coonhound', 'Black and Tan Coonhound',
|
||||
'Treeing Walker Coonhound', 'English foxhound', 'Redbone Coonhound',
|
||||
'borzoi', 'Irish Wolfhound', 'Italian Greyhound', 'Whippet',
|
||||
'Ibizan Hound', 'Norwegian Elkhound', 'Otterhound', 'Saluki',
|
||||
'Scottish Deerhound', 'Weimaraner', 'Staffordshire Bull Terrier',
|
||||
'American Staffordshire Terrier', 'Bedlington Terrier', 'Border Terrier',
|
||||
'Kerry Blue Terrier', 'Irish Terrier', 'Norfolk Terrier',
|
||||
'Norwich Terrier', 'Yorkshire Terrier', 'Wire Fox Terrier',
|
||||
'Lakeland Terrier', 'Sealyham Terrier', 'Airedale Terrier',
|
||||
'Cairn Terrier', 'Australian Terrier', 'Dandie Dinmont Terrier',
|
||||
'Boston Terrier', 'Miniature Schnauzer', 'Giant Schnauzer',
|
||||
'Standard Schnauzer', 'Scottish Terrier', 'Tibetan Terrier',
|
||||
'Australian Silky Terrier', 'Soft-coated Wheaten Terrier',
|
||||
'West Highland White Terrier', 'Lhasa Apso', 'Flat-Coated Retriever',
|
||||
'Curly-coated Retriever', 'Golden Retriever', 'Labrador Retriever',
|
||||
'Chesapeake Bay Retriever', 'German Shorthaired Pointer', 'Vizsla',
|
||||
'English Setter', 'Irish Setter', 'Gordon Setter', 'Brittany dog',
|
||||
'Clumber Spaniel', 'English Springer Spaniel', 'Welsh Springer Spaniel',
|
||||
'Cocker Spaniel', 'Sussex Spaniel', 'Irish Water Spaniel', 'Kuvasz',
|
||||
'Schipperke', 'Groenendael dog', 'Malinois', 'Briard', 'Australian Kelpie',
|
||||
'Komondor', 'Old English Sheepdog', 'Shetland Sheepdog', 'collie',
|
||||
'Border Collie', 'Bouvier des Flandres dog', 'Rottweiler',
|
||||
'German Shepherd Dog', 'Dobermann', 'Miniature Pinscher',
|
||||
'Greater Swiss Mountain Dog', 'Bernese Mountain Dog',
|
||||
'Appenzeller Sennenhund', 'Entlebucher Sennenhund', 'Boxer', 'Bullmastiff',
|
||||
'Tibetan Mastiff', 'French Bulldog', 'Great Dane', 'St. Bernard', 'husky',
|
||||
'Alaskan Malamute', 'Siberian Husky', 'Dalmatian', 'Affenpinscher',
|
||||
'Basenji', 'pug', 'Leonberger', 'Newfoundland dog', 'Great Pyrenees dog',
|
||||
'Samoyed', 'Pomeranian', 'Chow Chow', 'Keeshond', 'brussels griffon',
|
||||
'Pembroke Welsh Corgi', 'Cardigan Welsh Corgi', 'Toy Poodle',
|
||||
'Miniature Poodle', 'Standard Poodle',
|
||||
'Mexican hairless dog (xoloitzcuintli)', 'grey wolf',
|
||||
'Alaskan tundra wolf', 'red wolf or maned wolf', 'coyote', 'dingo',
|
||||
'dhole', 'African wild dog', 'hyena', 'red fox', 'kit fox', 'Arctic fox',
|
||||
'grey fox', 'tabby cat', 'tiger cat', 'Persian cat', 'Siamese cat',
|
||||
'Egyptian Mau', 'cougar', 'lynx', 'leopard', 'snow leopard', 'jaguar',
|
||||
'lion', 'tiger', 'cheetah', 'brown bear', 'American black bear',
|
||||
'polar bear', 'sloth bear', 'mongoose', 'meerkat', 'tiger beetle',
|
||||
'ladybug', 'ground beetle', 'longhorn beetle', 'leaf beetle',
|
||||
'dung beetle', 'rhinoceros beetle', 'weevil', 'fly', 'bee', 'ant',
|
||||
'grasshopper', 'cricket insect', 'stick insect', 'cockroach',
|
||||
'praying mantis', 'cicada', 'leafhopper', 'lacewing', 'dragonfly',
|
||||
'damselfly', 'red admiral butterfly', 'ringlet butterfly',
|
||||
'monarch butterfly', 'small white butterfly', 'sulphur butterfly',
|
||||
'gossamer-winged butterfly', 'starfish', 'sea urchin', 'sea cucumber',
|
||||
'cottontail rabbit', 'hare', 'Angora rabbit', 'hamster', 'porcupine',
|
||||
'fox squirrel', 'marmot', 'beaver', 'guinea pig', 'common sorrel horse',
|
||||
'zebra', 'pig', 'wild boar', 'warthog', 'hippopotamus', 'ox',
|
||||
'water buffalo', 'bison', 'ram (adult male sheep)', 'bighorn sheep',
|
||||
'Alpine ibex', 'hartebeest', 'impala (antelope)', 'gazelle',
|
||||
'arabian camel', 'llama', 'weasel', 'mink', 'European polecat',
|
||||
'black-footed ferret', 'otter', 'skunk', 'badger', 'armadillo',
|
||||
'three-toed sloth', 'orangutan', 'gorilla', 'chimpanzee', 'gibbon',
|
||||
'siamang', 'guenon', 'patas monkey', 'baboon', 'macaque', 'langur',
|
||||
'black-and-white colobus', 'proboscis monkey', 'marmoset',
|
||||
'white-headed capuchin', 'howler monkey', 'titi monkey',
|
||||
"Geoffroy's spider monkey", 'common squirrel monkey', 'ring-tailed lemur',
|
||||
'indri', 'Asian elephant', 'African bush elephant', 'red panda',
|
||||
'giant panda', 'snoek fish', 'eel', 'silver salmon', 'rock beauty fish',
|
||||
'clownfish', 'sturgeon', 'gar fish', 'lionfish', 'pufferfish', 'abacus',
|
||||
'abaya', 'academic gown', 'accordion', 'acoustic guitar',
|
||||
'aircraft carrier', 'airliner', 'airship', 'altar', 'ambulance',
|
||||
'amphibious vehicle', 'analog clock', 'apiary', 'apron', 'trash can',
|
||||
'assault rifle', 'backpack', 'bakery', 'balance beam', 'balloon',
|
||||
'ballpoint pen', 'Band-Aid', 'banjo', 'baluster / handrail', 'barbell',
|
||||
'barber chair', 'barbershop', 'barn', 'barometer', 'barrel', 'wheelbarrow',
|
||||
'baseball', 'basketball', 'bassinet', 'bassoon', 'swimming cap',
|
||||
'bath towel', 'bathtub', 'station wagon', 'lighthouse', 'beaker',
|
||||
'military hat (bearskin or shako)', 'beer bottle', 'beer glass',
|
||||
'bell tower', 'baby bib', 'tandem bicycle', 'bikini', 'ring binder',
|
||||
'binoculars', 'birdhouse', 'boathouse', 'bobsleigh', 'bolo tie',
|
||||
'poke bonnet', 'bookcase', 'bookstore', 'bottle cap', 'hunting bow',
|
||||
'bow tie', 'brass memorial plaque', 'bra', 'breakwater', 'breastplate',
|
||||
'broom', 'bucket', 'buckle', 'bulletproof vest', 'high-speed train',
|
||||
'butcher shop', 'taxicab', 'cauldron', 'candle', 'cannon', 'canoe',
|
||||
'can opener', 'cardigan', 'car mirror', 'carousel', 'tool kit',
|
||||
'cardboard box / carton', 'car wheel', 'automated teller machine',
|
||||
'cassette', 'cassette player', 'castle', 'catamaran', 'CD player', 'cello',
|
||||
'mobile phone', 'chain', 'chain-link fence', 'chain mail', 'chainsaw',
|
||||
'storage chest', 'chiffonier', 'bell or wind chime', 'china cabinet',
|
||||
'Christmas stocking', 'church', 'movie theater', 'cleaver',
|
||||
'cliff dwelling', 'cloak', 'clogs', 'cocktail shaker', 'coffee mug',
|
||||
'coffeemaker', 'spiral or coil', 'combination lock', 'computer keyboard',
|
||||
'candy store', 'container ship', 'convertible', 'corkscrew', 'cornet',
|
||||
'cowboy boot', 'cowboy hat', 'cradle', 'construction crane',
|
||||
'crash helmet', 'crate', 'infant bed', 'Crock Pot', 'croquet ball',
|
||||
'crutch', 'cuirass', 'dam', 'desk', 'desktop computer',
|
||||
'rotary dial telephone', 'diaper', 'digital clock', 'digital watch',
|
||||
'dining table', 'dishcloth', 'dishwasher', 'disc brake', 'dock',
|
||||
'dog sled', 'dome', 'doormat', 'drilling rig', 'drum', 'drumstick',
|
||||
'dumbbell', 'Dutch oven', 'electric fan', 'electric guitar',
|
||||
'electric locomotive', 'entertainment center', 'envelope',
|
||||
'espresso machine', 'face powder', 'feather boa', 'filing cabinet',
|
||||
'fireboat', 'fire truck', 'fire screen', 'flagpole', 'flute',
|
||||
'folding chair', 'football helmet', 'forklift', 'fountain', 'fountain pen',
|
||||
'four-poster bed', 'freight car', 'French horn', 'frying pan', 'fur coat',
|
||||
'garbage truck', 'gas mask or respirator', 'gas pump', 'goblet', 'go-kart',
|
||||
'golf ball', 'golf cart', 'gondola', 'gong', 'gown', 'grand piano',
|
||||
'greenhouse', 'radiator grille', 'grocery store', 'guillotine',
|
||||
'hair clip', 'hair spray', 'half-track', 'hammer', 'hamper', 'hair dryer',
|
||||
'hand-held computer', 'handkerchief', 'hard disk drive', 'harmonica',
|
||||
'harp', 'combine harvester', 'hatchet', 'holster', 'home theater',
|
||||
'honeycomb', 'hook', 'hoop skirt', 'gymnastic horizontal bar',
|
||||
'horse-drawn vehicle', 'hourglass', 'iPod', 'clothes iron',
|
||||
'carved pumpkin', 'jeans', 'jeep', 'T-shirt', 'jigsaw puzzle', 'rickshaw',
|
||||
'joystick', 'kimono', 'knee pad', 'knot', 'lab coat', 'ladle', 'lampshade',
|
||||
'laptop computer', 'lawn mower', 'lens cap', 'letter opener', 'library',
|
||||
'lifeboat', 'lighter', 'limousine', 'ocean liner', 'lipstick',
|
||||
'slip-on shoe', 'lotion', 'music speaker', 'loupe magnifying glass',
|
||||
'sawmill', 'magnetic compass', 'messenger bag', 'mailbox', 'tights',
|
||||
'one-piece bathing suit', 'manhole cover', 'maraca', 'marimba', 'mask',
|
||||
'matchstick', 'maypole', 'maze', 'measuring cup', 'medicine cabinet',
|
||||
'megalith', 'microphone', 'microwave oven', 'military uniform', 'milk can',
|
||||
'minibus', 'miniskirt', 'minivan', 'missile', 'mitten', 'mixing bowl',
|
||||
'mobile home', 'ford model t', 'modem', 'monastery', 'monitor', 'moped',
|
||||
'mortar and pestle', 'graduation cap', 'mosque', 'mosquito net', 'vespa',
|
||||
'mountain bike', 'tent', 'computer mouse', 'mousetrap', 'moving van',
|
||||
'muzzle', 'metal nail', 'neck brace', 'necklace', 'baby pacifier',
|
||||
'notebook computer', 'obelisk', 'oboe', 'ocarina', 'odometer',
|
||||
'oil filter', 'pipe organ', 'oscilloscope', 'overskirt', 'bullock cart',
|
||||
'oxygen mask', 'product packet / packaging', 'paddle', 'paddle wheel',
|
||||
'padlock', 'paintbrush', 'pajamas', 'palace', 'pan flute', 'paper towel',
|
||||
'parachute', 'parallel bars', 'park bench', 'parking meter',
|
||||
'railroad car', 'patio', 'payphone', 'pedestal', 'pencil case',
|
||||
'pencil sharpener', 'perfume', 'Petri dish', 'photocopier', 'plectrum',
|
||||
'Pickelhaube', 'picket fence', 'pickup truck', 'pier', 'piggy bank',
|
||||
'pill bottle', 'pillow', 'ping-pong ball', 'pinwheel', 'pirate ship',
|
||||
'drink pitcher', 'block plane', 'planetarium', 'plastic bag', 'plate rack',
|
||||
'farm plow', 'plunger', 'Polaroid camera', 'pole', 'police van', 'poncho',
|
||||
'pool table', 'soda bottle', 'plant pot', "potter's wheel", 'power drill',
|
||||
'prayer rug', 'printer', 'prison', 'missile', 'projector', 'hockey puck',
|
||||
'punching bag', 'purse', 'quill', 'quilt', 'race car', 'racket',
|
||||
'radiator', 'radio', 'radio telescope', 'rain barrel',
|
||||
'recreational vehicle', 'fishing casting reel', 'reflex camera',
|
||||
'refrigerator', 'remote control', 'restaurant', 'revolver', 'rifle',
|
||||
'rocking chair', 'rotisserie', 'eraser', 'rugby ball',
|
||||
'ruler measuring stick', 'sneaker', 'safe', 'safety pin', 'salt shaker',
|
||||
'sandal', 'sarong', 'saxophone', 'scabbard', 'weighing scale',
|
||||
'school bus', 'schooner', 'scoreboard', 'CRT monitor', 'screw',
|
||||
'screwdriver', 'seat belt', 'sewing machine', 'shield', 'shoe store',
|
||||
'shoji screen / room divider', 'shopping basket', 'shopping cart',
|
||||
'shovel', 'shower cap', 'shower curtain', 'ski', 'balaclava ski mask',
|
||||
'sleeping bag', 'slide rule', 'sliding door', 'slot machine', 'snorkel',
|
||||
'snowmobile', 'snowplow', 'soap dispenser', 'soccer ball', 'sock',
|
||||
'solar thermal collector', 'sombrero', 'soup bowl', 'keyboard space bar',
|
||||
'space heater', 'space shuttle', 'spatula', 'motorboat', 'spider web',
|
||||
'spindle', 'sports car', 'spotlight', 'stage', 'steam locomotive',
|
||||
'through arch bridge', 'steel drum', 'stethoscope', 'scarf', 'stone wall',
|
||||
'stopwatch', 'stove', 'strainer', 'tram', 'stretcher', 'couch', 'stupa',
|
||||
'submarine', 'suit', 'sundial', 'sunglasses', 'sunglasses', 'sunscreen',
|
||||
'suspension bridge', 'mop', 'sweatshirt', 'swim trunks / shorts', 'swing',
|
||||
'electrical switch', 'syringe', 'table lamp', 'tank', 'tape player',
|
||||
'teapot', 'teddy bear', 'television', 'tennis ball', 'thatched roof',
|
||||
'front curtain', 'thimble', 'threshing machine', 'throne', 'tile roof',
|
||||
'toaster', 'tobacco shop', 'toilet seat', 'torch', 'totem pole',
|
||||
'tow truck', 'toy store', 'tractor', 'semi-trailer truck', 'tray',
|
||||
'trench coat', 'tricycle', 'trimaran', 'tripod', 'triumphal arch',
|
||||
'trolleybus', 'trombone', 'hot tub', 'turnstile', 'typewriter keyboard',
|
||||
'umbrella', 'unicycle', 'upright piano', 'vacuum cleaner', 'vase',
|
||||
'vaulted or arched ceiling', 'velvet fabric', 'vending machine',
|
||||
'vestment', 'viaduct', 'violin', 'volleyball', 'waffle iron', 'wall clock',
|
||||
'wallet', 'wardrobe', 'military aircraft', 'sink', 'washing machine',
|
||||
'water bottle', 'water jug', 'water tower', 'whiskey jug', 'whistle',
|
||||
'hair wig', 'window screen', 'window shade', 'Windsor tie', 'wine bottle',
|
||||
'airplane wing', 'wok', 'wooden spoon', 'wool', 'split-rail fence',
|
||||
'shipwreck', 'sailboat', 'yurt', 'website', 'comic book', 'crossword',
|
||||
'traffic or street sign', 'traffic light', 'dust jacket', 'menu', 'plate',
|
||||
'guacamole', 'consomme', 'hot pot', 'trifle', 'ice cream', 'popsicle',
|
||||
'baguette', 'bagel', 'pretzel', 'cheeseburger', 'hot dog',
|
||||
'mashed potatoes', 'cabbage', 'broccoli', 'cauliflower', 'zucchini',
|
||||
'spaghetti squash', 'acorn squash', 'butternut squash', 'cucumber',
|
||||
'artichoke', 'bell pepper', 'cardoon', 'mushroom', 'Granny Smith apple',
|
||||
'strawberry', 'orange', 'lemon', 'fig', 'pineapple', 'banana', 'jackfruit',
|
||||
'cherimoya (custard apple)', 'pomegranate', 'hay', 'carbonara',
|
||||
'chocolate syrup', 'dough', 'meatloaf', 'pizza', 'pot pie', 'burrito',
|
||||
'red wine', 'espresso', 'tea cup', 'eggnog', 'mountain', 'bubble', 'cliff',
|
||||
'coral reef', 'geyser', 'lakeshore', 'promontory', 'sandbar', 'beach',
|
||||
'valley', 'volcano', 'baseball player', 'bridegroom', 'scuba diver',
|
||||
'rapeseed', 'daisy', "yellow lady's slipper", 'corn', 'acorn', 'rose hip',
|
||||
'horse chestnut seed', 'coral fungus', 'agaric', 'gyromitra',
|
||||
'stinkhorn mushroom', 'earth star fungus', 'hen of the woods mushroom',
|
||||
'bolete', 'corn cob', 'toilet paper')
|
||||
|
|
|
@ -1,18 +1,45 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import json
|
||||
import os.path as osp
|
||||
from collections import OrderedDict
|
||||
from typing import List
|
||||
from os import PathLike
|
||||
from typing import List, Sequence, Union
|
||||
|
||||
from mmengine import get_file_backend
|
||||
|
||||
from mmpretrain.registry import DATASETS
|
||||
from mmpretrain.registry import DATASETS, TRANSFORMS
|
||||
from .base_dataset import BaseDataset
|
||||
|
||||
|
||||
def expanduser(data_prefix):
|
||||
if isinstance(data_prefix, (str, PathLike)):
|
||||
return osp.expanduser(data_prefix)
|
||||
else:
|
||||
return data_prefix
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class COCORetrieval(BaseDataset):
|
||||
"""COCO Retrieval dataset.
|
||||
|
||||
COCO (Common Objects in Context): The COCO dataset contains more than
|
||||
330K images,each of which has approximately 5 descriptive annotations.
|
||||
This dataset was releasedin collaboration between Microsoft and Carnegie
|
||||
Mellon University
|
||||
|
||||
COCO_2014 dataset directory: ::
|
||||
|
||||
COCO_2014
|
||||
├── val2014
|
||||
├── train2014
|
||||
├── annotations
|
||||
├── instances_train2014.json
|
||||
├── instances_val2014.json
|
||||
├── person_keypoints_train2014.json
|
||||
├── person_keypoints_val2014.json
|
||||
├── captions_train2014.json
|
||||
├── captions_val2014.json
|
||||
|
||||
Args:
|
||||
ann_file (str): Annotation file path.
|
||||
test_mode (bool): Whether dataset is used for evaluation. This will
|
||||
|
@ -23,8 +50,52 @@ class COCORetrieval(BaseDataset):
|
|||
data_prefix (str | dict): Prefix for training data. Defaults to ''.
|
||||
pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
|
||||
**kwargs: Other keyword arguments in :class:`BaseDataset`.
|
||||
|
||||
Examples:
|
||||
>>> from mmpretrain.datasets import COCORetrieval
|
||||
>>> train_dataset=COCORetrieval(data_root='coco2014/')
|
||||
>>> train_dataset
|
||||
Dataset COCORetrieval
|
||||
Number of samples: 414113
|
||||
Annotation file: /coco2014/annotations/captions_train2014.json
|
||||
Prefix of images: /coco2014/
|
||||
>>> from mmpretrain.datasets import COCORetrieval
|
||||
>>> val_dataset = COCORetrieval(data_root='coco2014/')
|
||||
>>> val_dataset
|
||||
Dataset COCORetrieval
|
||||
Number of samples: 202654
|
||||
Annotation file: /coco2014/annotations/captions_val2014.json
|
||||
Prefix of images: /coco2014/
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
ann_file: str,
|
||||
test_mode: bool = False,
|
||||
data_prefix: Union[str, dict] = '',
|
||||
data_root: str = '',
|
||||
pipeline: Sequence = (),
|
||||
**kwargs):
|
||||
|
||||
if isinstance(data_prefix, str):
|
||||
data_prefix = dict(img_path=expanduser(data_prefix))
|
||||
|
||||
ann_file = expanduser(ann_file)
|
||||
transforms = []
|
||||
for transform in pipeline:
|
||||
if isinstance(transform, dict):
|
||||
transforms.append(TRANSFORMS.build(transform))
|
||||
else:
|
||||
transforms.append(transform)
|
||||
|
||||
super().__init__(
|
||||
data_root=data_root,
|
||||
data_prefix=data_prefix,
|
||||
test_mode=test_mode,
|
||||
pipeline=transforms,
|
||||
ann_file=ann_file,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def load_data_list(self) -> List[dict]:
|
||||
"""Load data list."""
|
||||
# get file backend
|
||||
|
|
|
@ -5,6 +5,7 @@ if WITH_MULTIMODAL:
|
|||
from .blip import * # noqa: F401,F403
|
||||
from .blip2 import * # noqa: F401,F403
|
||||
from .chinese_clip import * # noqa: F401, F403
|
||||
from .clip import * # noqa: F401, F403
|
||||
from .flamingo import * # noqa: F401, F403
|
||||
from .llava import * # noqa: F401, F403
|
||||
from .minigpt4 import * # noqa: F401, F403
|
||||
|
@ -17,5 +18,6 @@ else:
|
|||
register_multimodal_placeholder([
|
||||
'Blip2Caption', 'Blip2Retrieval', 'Blip2VQA', 'BlipCaption',
|
||||
'BlipNLVR', 'BlipRetrieval', 'BlipGrounding', 'BlipVQA', 'Flamingo',
|
||||
'OFA', 'ChineseCLIP', 'MiniGPT4', 'Llava', 'Otter'
|
||||
'OFA', 'ChineseCLIP', 'MiniGPT4', 'Llava', 'Otter', 'CLIP',
|
||||
'CLIPZeroShot'
|
||||
], MODELS)
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from ..clip.clip import CLIP, CLIPZeroShot
|
||||
from ..clip.clip_transformer import CLIPProjection, CLIPTransformer
|
||||
|
||||
__all__ = ['CLIP', 'CLIPZeroShot', 'CLIPTransformer', 'CLIPProjection']
|
|
@ -0,0 +1,364 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import abstractmethod
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from mmengine.model import BaseModel
|
||||
from torch import nn
|
||||
|
||||
from mmpretrain.datasets.categories import (CIFAR100_CATEGORIES,
|
||||
IMAGENET_SIMPLE_CATEGORIES)
|
||||
from mmpretrain.registry import MODELS, TOKENIZER
|
||||
from mmpretrain.structures import DataSample
|
||||
from mmpretrain.utils import track_on_main_process
|
||||
from .utils import (OPENAI_CIFAR100_PROMPT, OPENAI_IMAGENET_PROMPT,
|
||||
OPENAI_IMAGENET_PROMPT_SUB)
|
||||
|
||||
CIFAR100_CATEGORIES = [' '.join(c.split('_')) for c in CIFAR100_CATEGORIES]
|
||||
PROTOTYPE_MAP = {
|
||||
'imagenet': IMAGENET_SIMPLE_CATEGORIES,
|
||||
'cifar100': CIFAR100_CATEGORIES,
|
||||
}
|
||||
PROMPT_MAP = {
|
||||
'openai_imagenet': OPENAI_IMAGENET_PROMPT,
|
||||
'openai_cifar100': OPENAI_CIFAR100_PROMPT,
|
||||
'vanilla': [lambda c: f'a photo of a {c}'],
|
||||
'openai_imagenet_sub': OPENAI_IMAGENET_PROMPT_SUB
|
||||
}
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
"""Subclass torch's LayerNorm to handle fp16."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward function."""
|
||||
orig_type = x.dtype
|
||||
ret = super().forward(x.type(torch.float32))
|
||||
return ret.type(orig_type)
|
||||
|
||||
|
||||
class CLIP(BaseModel):
|
||||
"""The implementation of `CLIP <https://arxiv.org/abs/2103.00020>`_.
|
||||
|
||||
Args:
|
||||
vision_backbone (dict): Config dict for vision backbone.
|
||||
text_backbone (dict): Config dict for text backbone.
|
||||
tokenizer (dict): Config dict for text tokenizer.
|
||||
proj_dim (int): Projection dimension for similarity computation.
|
||||
text_prototype (str): Text prototype, which can be a key in
|
||||
`PROTOTYPE_MAP` or list of text.
|
||||
text_prompt (str): The prompt for text prototype.
|
||||
Defaults to 'vanilla',which refers to "a photo of {cls}".
|
||||
context_length (int): The context length to use. Defaults to 77.
|
||||
data_preprocessor (Union[dict, nn.Module], optional): The config for
|
||||
preprocessing input data. If None or no specified type, it will use
|
||||
"MultiModalDataPreprocessor" as type.
|
||||
See :class:`MultiModalDataPreprocessor` for more details.
|
||||
Defaults to None.
|
||||
init_cfg (dict, optional): The config to control the initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
vision_backbone: dict,
|
||||
projection: dict,
|
||||
text_backbone: dict,
|
||||
tokenizer: dict,
|
||||
vocab_size: int,
|
||||
transformer_width: int,
|
||||
proj_dim: int,
|
||||
context_length: int = 77,
|
||||
data_preprocessor: Optional[dict] = None,
|
||||
init_cfg: Optional[dict] = None):
|
||||
if data_preprocessor is None:
|
||||
data_preprocessor = {}
|
||||
data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
|
||||
data_preprocessor = MODELS.build(data_preprocessor)
|
||||
|
||||
super().__init__(
|
||||
data_preprocessor=data_preprocessor, init_cfg=init_cfg)
|
||||
|
||||
self.context_length = context_length
|
||||
|
||||
# build the vision transformer
|
||||
self.visual = MODELS.build(vision_backbone)
|
||||
|
||||
# build the visual projection
|
||||
self.visual_proj = MODELS.build(projection)
|
||||
|
||||
# build attn_mask for casual-attn
|
||||
text_backbone['attn_mask'] = self.build_attention_mask()
|
||||
|
||||
# build the text transformer
|
||||
self.transformer = MODELS.build(text_backbone)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
||||
self.positional_embedding = nn.Parameter(
|
||||
torch.empty(self.context_length, transformer_width))
|
||||
self.ln_final = LayerNorm(transformer_width)
|
||||
|
||||
self.text_projection = nn.Parameter(
|
||||
torch.empty(transformer_width, proj_dim))
|
||||
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
||||
|
||||
self.initialize_parameters()
|
||||
|
||||
self.tokenizer = TOKENIZER.build(tokenizer)
|
||||
|
||||
self.tokenizer.vocab = self.tokenizer.get_vocab(
|
||||
) # CLIPTokenizer has no attribute named 'vocab', so manually
|
||||
|
||||
def initialize_parameters(self) -> None:
|
||||
"""Initialize the parameters.
|
||||
|
||||
The pretrained weight will override the initialized parameters by this
|
||||
function.
|
||||
"""
|
||||
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
||||
nn.init.normal_(self.positional_embedding, std=0.01)
|
||||
|
||||
proj_std = (self.transformer.width**-0.5) * (
|
||||
(2 * self.transformer.layers)**-0.5)
|
||||
attn_std = self.transformer.width**-0.5
|
||||
fc_std = (2 * self.transformer.width)**-0.5
|
||||
for block in self.transformer.resblocks:
|
||||
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
||||
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
||||
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
||||
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
||||
|
||||
if self.text_projection is not None:
|
||||
nn.init.normal_(
|
||||
self.text_projection, std=self.transformer.width**-0.5)
|
||||
|
||||
def build_attention_mask(self):
|
||||
# lazily create causal attention mask,
|
||||
# with full attention between the vision tokens
|
||||
# pytorch uses additive attention mask; fill with -inf
|
||||
mask = torch.empty(self.context_length, self.context_length)
|
||||
mask.fill_(float('-inf'))
|
||||
mask.triu_(1) # zero out the lower diagonal
|
||||
return mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
images: torch.Tensor,
|
||||
data_samples: Optional[list] = None,
|
||||
mode: str = 'predict',
|
||||
**kwargs,
|
||||
):
|
||||
"""The unified entry for a forward process in both training and test.
|
||||
The method accepts the following modes:
|
||||
|
||||
- "predict": Forward and return a list of data samples contain the
|
||||
predict results.
|
||||
|
||||
Args:
|
||||
images (torch.Tensor): the preprocessed image tensor of shape
|
||||
``(N, C, H, W)``.
|
||||
data_samples (List[DataSample], optional): The annotation data
|
||||
of every samples. Defaults to None.
|
||||
mode (str): Return what kind of value. Defaults to 'predict'.
|
||||
"""
|
||||
if mode == 'predict':
|
||||
return self.predict(images, data_samples, **kwargs)
|
||||
else:
|
||||
raise RuntimeError(f'Invalid mode "{mode}".')
|
||||
|
||||
def extract_image_feat(self, images: torch.Tensor) -> torch.Tensor:
|
||||
"""The function to extract image latent features."""
|
||||
return self.visual_proj(self.visual(images))[0]
|
||||
|
||||
def extract_text_feat(self, texts: torch.Tensor) -> torch.Tensor:
|
||||
"""The function to extract text latent features."""
|
||||
x = self.token_embedding(texts) # [batch_size, n_ctx, d_model]
|
||||
|
||||
x = x + self.positional_embedding
|
||||
x = x.permute(1, 0, 2) # NLD -> LND
|
||||
x = self.transformer(x)[0]
|
||||
|
||||
x = x.permute(1, 0, 2) # LND -> NLD
|
||||
x = self.ln_final(x)
|
||||
|
||||
# x.shape = [batch_size, n_ctx, transformer.width]
|
||||
# take features from the eot embedding
|
||||
# (eot_token is the highest number in each sequence)
|
||||
x = x[torch.arange(x.shape[0]),
|
||||
texts.argmax(dim=-1)] @ self.text_projection
|
||||
|
||||
return x
|
||||
|
||||
def extract_feat(
|
||||
self, images: torch.Tensor,
|
||||
texts: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
|
||||
"""The function to extract image and text latent features, the input
|
||||
image or text can not both be None."""
|
||||
|
||||
assert images is not None or texts is not None, \
|
||||
'text and image cannot both be None!'
|
||||
if images is None:
|
||||
return self.extract_text_feat(texts)
|
||||
elif texts is None:
|
||||
return self.extract_image_feat(images)
|
||||
|
||||
image_features = self.extract_image_feat(images)
|
||||
text_features = self.extract_text_feat(texts)
|
||||
|
||||
image_features = image_features / image_features.norm(
|
||||
dim=-1, keepdim=True)
|
||||
text_features = text_features / text_features.norm(
|
||||
dim=-1, keepdim=True)
|
||||
|
||||
return image_features, text_features
|
||||
|
||||
def compute_similarity(self, images, texts):
|
||||
"""Extract images and texts features and compute cosine similarity."""
|
||||
image_features, text_features = self.extract_feat(
|
||||
images=images, texts=texts)
|
||||
|
||||
# cosine similarity as logits
|
||||
logit_scale = self.logit_scale.exp()
|
||||
logits_per_image = logit_scale * image_features @ text_features.t()
|
||||
logits_per_text = logits_per_image.t()
|
||||
|
||||
# shape (N, N)
|
||||
return logits_per_image, logits_per_text
|
||||
|
||||
@abstractmethod
|
||||
def predict(self,
|
||||
images: torch.Tensor,
|
||||
data_samples: DataSample = None) -> DataSample:
|
||||
raise NotImplementedError
|
||||
|
||||
def tokenize(self, texts: Union[str, List[str]]) -> torch.LongTensor:
|
||||
"""Returns the tokenized representation of given input string(s)
|
||||
|
||||
Args:
|
||||
texts (Union[str, List[str]]): An input string or a list of input
|
||||
strings to tokenize
|
||||
context_length (int): The context length to use. Defaults to 52.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Resulting tokens.
|
||||
"""
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
|
||||
all_tokens = []
|
||||
for text in texts:
|
||||
# adapt the text to Chinese BERT vocab
|
||||
# text = text.lower().replace('“', "\"").replace('”', "\"")
|
||||
|
||||
# add special tokens
|
||||
all_tokens.append(
|
||||
[self.tokenizer.vocab['<|startoftext|>']
|
||||
] + # <|startoftext|>代表[CLS] token
|
||||
self.tokenizer.convert_tokens_to_ids(
|
||||
self.tokenizer.tokenize(text))[:self.context_length - 2] +
|
||||
[self.tokenizer.vocab['<|endoftext|>']])
|
||||
|
||||
result = torch.zeros(
|
||||
len(all_tokens), self.context_length, dtype=torch.long)
|
||||
|
||||
for i, tokens in enumerate(all_tokens):
|
||||
assert len(tokens) <= self.context_length
|
||||
result[i, :len(tokens)] = torch.tensor(tokens)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class CLIPZeroShot(CLIP):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_backbone: dict,
|
||||
projection: dict,
|
||||
text_backbone: dict,
|
||||
tokenizer: dict,
|
||||
vocab_size: int,
|
||||
transformer_width: int,
|
||||
proj_dim: int,
|
||||
context_length: int = 77,
|
||||
data_preprocessor: Optional[dict] = None,
|
||||
init_cfg: Optional[dict] = None,
|
||||
text_prototype: Union[str, List[str]] = 'imagenet',
|
||||
text_prompt: str = 'vanilla',
|
||||
):
|
||||
super(CLIPZeroShot,
|
||||
self).__init__(vision_backbone, projection, text_backbone,
|
||||
tokenizer, vocab_size, transformer_width,
|
||||
proj_dim, context_length, data_preprocessor,
|
||||
init_cfg)
|
||||
|
||||
# for zero-shot classification
|
||||
if isinstance(text_prototype,
|
||||
str) and text_prototype in PROTOTYPE_MAP.keys():
|
||||
self.prototype = PROTOTYPE_MAP[text_prototype]
|
||||
else:
|
||||
self.prototype = text_prototype
|
||||
self.text_prototype_embeds = None
|
||||
|
||||
self.prompt = PROMPT_MAP[text_prompt]
|
||||
|
||||
def predict(self,
|
||||
images: torch.Tensor,
|
||||
data_samples: DataSample = None) -> DataSample:
|
||||
"""Predict the classes of the input images.
|
||||
|
||||
The prediction is for zero-shot classification and the text prototypes
|
||||
will be prepared in thisfunction.
|
||||
|
||||
Args:
|
||||
images (torch.Tensor): The input images.
|
||||
data_samples (DataSample): The data samples with information from
|
||||
dataset.
|
||||
|
||||
Returns:
|
||||
DataSample: The results of prediction.
|
||||
"""
|
||||
|
||||
if self.text_prototype_embeds is None:
|
||||
self.prepare_text_prototype(device=images.device)
|
||||
|
||||
image_features = self.extract_image_feat(images=images)
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
# cosine similarity as logits
|
||||
logits_per_image = image_features @ self.text_prototype_embeds.to(
|
||||
image_features.device) * self.logit_scale.exp()
|
||||
|
||||
pred_scores = F.softmax(logits_per_image, dim=1)
|
||||
pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach()
|
||||
|
||||
out_data_samples = []
|
||||
if data_samples is None:
|
||||
data_samples = [None for _ in range(pred_scores.size(0))]
|
||||
|
||||
for data_sample, score, label in zip(data_samples, pred_scores,
|
||||
pred_labels):
|
||||
if data_sample is None:
|
||||
data_sample = DataSample()
|
||||
|
||||
data_sample.set_pred_score(score).set_pred_label(label)
|
||||
out_data_samples.append(data_sample)
|
||||
return out_data_samples
|
||||
|
||||
def prepare_text_prototype(self, device) -> None:
|
||||
"""The function to prepare text prototypes with prompt."""
|
||||
class_embeddings = []
|
||||
for classname in track_on_main_process(self.prototype,
|
||||
'Prepare text prototype...'):
|
||||
# format with class
|
||||
texts = [prompt(classname) for prompt in self.prompt]
|
||||
tokenized_texts = self.tokenize(texts)
|
||||
class_features = self.extract_text_feat(tokenized_texts.to(device))
|
||||
class_features /= class_features.norm(dim=-1, keepdim=True)
|
||||
class_feature = class_features.mean(dim=0)
|
||||
class_feature /= class_feature.norm()
|
||||
class_embeddings.append(class_feature)
|
||||
self.text_prototype_embeds = torch.stack(
|
||||
class_embeddings, dim=1).to(device)
|
|
@ -0,0 +1,99 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# Modified from https://github.com/zejiangh/MILAN
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from mmengine.model import BaseModule
|
||||
from torch import nn
|
||||
|
||||
from mmpretrain.models.utils.clip_generator_helper import \
|
||||
ResidualAttentionBlock
|
||||
from mmpretrain.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class CLIPTransformer(nn.Module):
|
||||
"""Transformer.
|
||||
|
||||
Both visual and text branches use this transformer.
|
||||
|
||||
Args:
|
||||
width (int): The feature dimension.
|
||||
layers (int): The number of layers.
|
||||
heads (int): The number of attention heads.
|
||||
attn_mask (torch.Tensor, optional): The attention mask.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
width: int,
|
||||
layers: int,
|
||||
heads: int,
|
||||
attn_mask: Optional[torch.Tensor] = None) -> None:
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.layers = layers
|
||||
self.resblocks = nn.ModuleList()
|
||||
for _ in range(layers - 1):
|
||||
self.resblocks.append(
|
||||
ResidualAttentionBlock(width, heads, attn_mask))
|
||||
self.resblocks.append(
|
||||
ResidualAttentionBlock(
|
||||
width, heads, attn_mask, return_attention=True))
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Forward function."""
|
||||
z = []
|
||||
for idx, blk in enumerate(self.resblocks):
|
||||
if idx < self.layers - 1:
|
||||
x = blk(x)
|
||||
z.append(x.permute(1, 0, 2))
|
||||
else:
|
||||
x, attention = blk(x)
|
||||
z.append(x.permute(1, 0, 2))
|
||||
return x, attention, z
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class CLIPProjection(BaseModule):
|
||||
"""Neck with CLIP Projection.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of channels in the input.
|
||||
out_channels (int): Number of channels in the output.
|
||||
init_cfg (dict | list[dict], optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
init_cfg: Optional[dict] = None):
|
||||
super(CLIPProjection, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
scale = in_channels**-0.5
|
||||
self.proj = nn.Parameter(scale *
|
||||
torch.randn(in_channels, out_channels))
|
||||
|
||||
def forward(self, inputs: Tuple) -> Tuple[torch.Tensor]:
|
||||
"""forward function.
|
||||
|
||||
Args:
|
||||
inputs (Tuple): The features extracted from
|
||||
the backbone. Multiple stage inputs are acceptable but only
|
||||
the last stage will be used.
|
||||
Returns:
|
||||
Tuple(torch.Tensor)): A tuple of reducted features.
|
||||
"""
|
||||
if isinstance(inputs, tuple):
|
||||
inputs = inputs[-1]
|
||||
out = inputs @ self.proj
|
||||
elif isinstance(inputs, torch.Tensor):
|
||||
out = inputs @ self.proj
|
||||
else:
|
||||
raise TypeError(
|
||||
'`CLIPProjection` neck inputs should be tuple or torch.tensor')
|
||||
return (out, )
|
|
@ -0,0 +1,115 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
OPENAI_CIFAR100_PROMPT = [
|
||||
lambda c: f'a photo of a {c}.',
|
||||
lambda c: f'a blurry photo of a {c}.',
|
||||
lambda c: f'a black and white photo of a {c}.',
|
||||
lambda c: f'a low contrast photo of a {c}.',
|
||||
lambda c: f'a high contrast photo of a {c}.',
|
||||
lambda c: f'a bad photo of a {c}.',
|
||||
lambda c: f'a good photo of a {c}.',
|
||||
lambda c: f'a photo of a small {c}.',
|
||||
lambda c: f'a photo of a big {c}.',
|
||||
lambda c: f'a photo of the {c}.',
|
||||
lambda c: f'a blurry photo of the {c}.',
|
||||
lambda c: f'a black and white photo of the {c}.',
|
||||
lambda c: f'a low contrast photo of the {c}.',
|
||||
lambda c: f'a high contrast photo of the {c}.',
|
||||
lambda c: f'a bad photo of the {c}.',
|
||||
lambda c: f'a good photo of the {c}.',
|
||||
lambda c: f'a photo of the small {c}.',
|
||||
lambda c: f'a photo of the big {c}.',
|
||||
]
|
||||
|
||||
OPENAI_IMAGENET_PROMPT_SUB = [
|
||||
lambda c: f'itap of a {c}.',
|
||||
lambda c: f'a bad photo of the {c}.',
|
||||
lambda c: f'a origami {c}.',
|
||||
lambda c: f'a photo of the large {c}.',
|
||||
lambda c: f'a {c} in a video game.',
|
||||
lambda c: f'art of the {c}.',
|
||||
lambda c: f'a photo of the small {c}.',
|
||||
]
|
||||
|
||||
OPENAI_IMAGENET_PROMPT = [
|
||||
lambda c: f'a bad photo of a {c}.',
|
||||
lambda c: f'a photo of many {c}.',
|
||||
lambda c: f'a sculpture of a {c}.',
|
||||
lambda c: f'a photo of the hard to see {c}.',
|
||||
lambda c: f'a low resolution photo of the {c}.',
|
||||
lambda c: f'a rendering of a {c}.',
|
||||
lambda c: f'graffiti of a {c}.',
|
||||
lambda c: f'a bad photo of the {c}.',
|
||||
lambda c: f'a cropped photo of the {c}.',
|
||||
lambda c: f'a tattoo of a {c}.',
|
||||
lambda c: f'the embroidered {c}.',
|
||||
lambda c: f'a photo of a hard to see {c}.',
|
||||
lambda c: f'a bright photo of a {c}.',
|
||||
lambda c: f'a photo of a clean {c}.',
|
||||
lambda c: f'a photo of a dirty {c}.',
|
||||
lambda c: f'a dark photo of the {c}.',
|
||||
lambda c: f'a drawing of a {c}.',
|
||||
lambda c: f'a photo of my {c}.',
|
||||
lambda c: f'the plastic {c}.',
|
||||
lambda c: f'a photo of the cool {c}.',
|
||||
lambda c: f'a close-up photo of a {c}.',
|
||||
lambda c: f'a black and white photo of the {c}.',
|
||||
lambda c: f'a painting of the {c}.',
|
||||
lambda c: f'a painting of a {c}.',
|
||||
lambda c: f'a pixelated photo of the {c}.',
|
||||
lambda c: f'a sculpture of the {c}.',
|
||||
lambda c: f'a bright photo of the {c}.',
|
||||
lambda c: f'a cropped photo of a {c}.',
|
||||
lambda c: f'a plastic {c}.',
|
||||
lambda c: f'a photo of the dirty {c}.',
|
||||
lambda c: f'a jpeg corrupted photo of a {c}.',
|
||||
lambda c: f'a blurry photo of the {c}.',
|
||||
lambda c: f'a photo of the {c}.',
|
||||
lambda c: f'a good photo of the {c}.',
|
||||
lambda c: f'a rendering of the {c}.',
|
||||
lambda c: f'a {c} in a video game.',
|
||||
lambda c: f'a photo of one {c}.',
|
||||
lambda c: f'a doodle of a {c}.',
|
||||
lambda c: f'a close-up photo of the {c}.',
|
||||
lambda c: f'a photo of a {c}.',
|
||||
lambda c: f'the origami {c}.',
|
||||
lambda c: f'the {c} in a video game.',
|
||||
lambda c: f'a sketch of a {c}.',
|
||||
lambda c: f'a doodle of the {c}.',
|
||||
lambda c: f'a origami {c}.',
|
||||
lambda c: f'a low resolution photo of a {c}.',
|
||||
lambda c: f'the toy {c}.',
|
||||
lambda c: f'a rendition of the {c}.',
|
||||
lambda c: f'a photo of the clean {c}.',
|
||||
lambda c: f'a photo of a large {c}.',
|
||||
lambda c: f'a rendition of a {c}.',
|
||||
lambda c: f'a photo of a nice {c}.',
|
||||
lambda c: f'a photo of a weird {c}.',
|
||||
lambda c: f'a blurry photo of a {c}.',
|
||||
lambda c: f'a cartoon {c}.',
|
||||
lambda c: f'art of a {c}.',
|
||||
lambda c: f'a sketch of the {c}.',
|
||||
lambda c: f'a embroidered {c}.',
|
||||
lambda c: f'a pixelated photo of a {c}.',
|
||||
lambda c: f'itap of the {c}.',
|
||||
lambda c: f'a jpeg corrupted photo of the {c}.',
|
||||
lambda c: f'a good photo of a {c}.',
|
||||
lambda c: f'a plushie {c}.',
|
||||
lambda c: f'a photo of the nice {c}.',
|
||||
lambda c: f'a photo of the small {c}.',
|
||||
lambda c: f'a photo of the weird {c}.',
|
||||
lambda c: f'the cartoon {c}.',
|
||||
lambda c: f'art of the {c}.',
|
||||
lambda c: f'a drawing of the {c}.',
|
||||
lambda c: f'a photo of the large {c}.',
|
||||
lambda c: f'a black and white photo of a {c}.',
|
||||
lambda c: f'the plushie {c}.',
|
||||
lambda c: f'a dark photo of a {c}.',
|
||||
lambda c: f'itap of a {c}.',
|
||||
lambda c: f'graffiti of the {c}.',
|
||||
lambda c: f'a toy {c}.',
|
||||
lambda c: f'itap of my {c}.',
|
||||
lambda c: f'a photo of a cool {c}.',
|
||||
lambda c: f'a photo of a small {c}.',
|
||||
lambda c: f'a tattoo of the {c}.',
|
||||
]
|
|
@ -1301,6 +1301,7 @@ class OFAEncoderDecoder(BaseModule, GenerationMixin):
|
|||
Defaults to an empty dict.
|
||||
init_cfg (dict, optional): The initialization config. Defaults to None.
|
||||
"""
|
||||
base_model_prefix = ''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
# Implementation for DINO
|
||||
|
||||
**NOTE**: We only guarantee correctness of the forward pass, not responsible for full reimplementation.
|
||||
|
||||
First, ensure you are in the root directory of MMPretrain, then you have two choices
|
||||
to play with DINO in MMPretrain:
|
||||
|
||||
## Slurm
|
||||
|
||||
If you are using a cluster managed by Slurm, you can use the following command to
|
||||
start your job:
|
||||
|
||||
```shell
|
||||
GPUS_PER_NODE=8 GPUS=8 CPUS_PER_TASK=16 bash projects/dino/tools/slurm_train.sh mm_model dino projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py --amp
|
||||
```
|
||||
|
||||
The above command will pre-train the model on a single node with 8 GPUs.
|
||||
|
||||
## PyTorch
|
||||
|
||||
If you are using a single machine, without any cluster management software, you can use the following command
|
||||
|
||||
```shell
|
||||
NNODES=1 bash projects/dino/tools/dist_train.sh projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py 8
|
||||
--amp
|
||||
```
|
|
@ -0,0 +1,104 @@
|
|||
model = dict(
|
||||
type='DINO',
|
||||
data_preprocessor=dict(
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
bgr_to_rgb=True),
|
||||
backbone=dict(
|
||||
type='mmpretrain.VisionTransformer', arch='b', patch_size=16),
|
||||
neck=dict(
|
||||
type='DINONeck',
|
||||
in_channels=768,
|
||||
out_channels=65536,
|
||||
hidden_channels=2048,
|
||||
bottleneck_channels=256),
|
||||
head=dict(
|
||||
type='DINOHead',
|
||||
out_channels=65536,
|
||||
num_crops=10,
|
||||
student_temp=0.1,
|
||||
center_momentum=0.9))
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='DINOMultiCrop',
|
||||
global_crops_scale=(0.4, 1.0),
|
||||
local_crops_scale=(0.05, 0.4),
|
||||
local_crops_number=8),
|
||||
dict(type='PackInputs')
|
||||
]
|
||||
train_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=16,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
collate_fn=dict(type='default_collate'),
|
||||
dataset=dict(
|
||||
type='mmpretrain.ImageNet',
|
||||
data_root='/data/imagenet/',
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix=dict(img_path='train/'),
|
||||
pipeline=train_pipeline,
|
||||
))
|
||||
optimizer = dict(type='AdamW', lr=0.0024, betas=(0.9, 0.95), weight_decay=0.05)
|
||||
optim_wrapper = dict(
|
||||
type='AmpOptimWrapper',
|
||||
optimizer=dict(
|
||||
type='AdamW', lr=0.0024, betas=(0.9, 0.95), weight_decay=0.05),
|
||||
paramwise_cfg=dict(
|
||||
custom_keys=dict(
|
||||
ln=dict(decay_mult=0.0),
|
||||
bias=dict(decay_mult=0.0),
|
||||
pos_embed=dict(decay_mult=0.0),
|
||||
mask_token=dict(decay_mult=0.0),
|
||||
cls_token=dict(decay_mult=0.0))),
|
||||
loss_scale='dynamic')
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=1e-09,
|
||||
by_epoch=True,
|
||||
begin=0,
|
||||
end=10,
|
||||
convert_to_iter_based=True),
|
||||
dict(
|
||||
type='CosineAnnealingLR',
|
||||
T_max=90,
|
||||
by_epoch=True,
|
||||
begin=10,
|
||||
end=100,
|
||||
convert_to_iter_based=True)
|
||||
]
|
||||
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100)
|
||||
default_scope = 'mmpretrain'
|
||||
default_hooks = dict(
|
||||
runtime_info=dict(type='RuntimeInfoHook'),
|
||||
timer=dict(type='IterTimerHook'),
|
||||
logger=dict(type='LoggerHook', interval=100),
|
||||
param_scheduler=dict(type='ParamSchedulerHook'),
|
||||
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=1),
|
||||
sampler_seed=dict(type='DistSamplerSeedHook'))
|
||||
env_cfg = dict(
|
||||
cudnn_benchmark=False,
|
||||
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
|
||||
dist_cfg=dict(backend='nccl'))
|
||||
log_processor = dict(
|
||||
window_size=10,
|
||||
custom_cfg=[dict(data_src='', method='mean', window_size='global')])
|
||||
vis_backends = [dict(type='LocalVisBackend')]
|
||||
visualizer = dict(
|
||||
type='UniversalVisualizer',
|
||||
vis_backends=[dict(type='LocalVisBackend')],
|
||||
name='visualizer')
|
||||
log_level = 'INFO'
|
||||
load_from = None
|
||||
resume = True
|
||||
randomness = dict(seed=2, diff_rank_seed=True)
|
||||
custom_hooks = [
|
||||
dict(
|
||||
type='DINOTeacherTempWarmupHook',
|
||||
warmup_teacher_temp=0.04,
|
||||
teacher_temp=0.04,
|
||||
teacher_temp_warmup_epochs=0,
|
||||
max_epochs=100)
|
||||
]
|
|
@ -0,0 +1 @@
|
|||
from .transform import * # noqa: F401,F403
|
|
@ -0,0 +1,3 @@
|
|||
from .processing import DINOMultiCrop
|
||||
|
||||
__all__ = ['DINOMultiCrop']
|
|
@ -0,0 +1,91 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import random
|
||||
|
||||
from mmcv.transforms import RandomApply # noqa: E501
|
||||
from mmcv.transforms import BaseTransform, Compose, RandomFlip, RandomGrayscale
|
||||
|
||||
from mmpretrain.datasets.transforms import (ColorJitter, GaussianBlur,
|
||||
RandomResizedCrop, Solarize)
|
||||
from mmpretrain.registry import TRANSFORMS
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class DINOMultiCrop(BaseTransform):
|
||||
"""Multi-crop transform for DINO.
|
||||
|
||||
This module applies the multi-crop transform for DINO.
|
||||
|
||||
Args:
|
||||
global_crops_scale (int): Scale of global crops.
|
||||
local_crops_scale (int): Scale of local crops.
|
||||
local_crops_number (int): Number of local crops.
|
||||
"""
|
||||
|
||||
def __init__(self, global_crops_scale: int, local_crops_scale: int,
|
||||
local_crops_number: int) -> None:
|
||||
super().__init__()
|
||||
self.global_crops_scale = global_crops_scale
|
||||
self.local_crops_scale = local_crops_scale
|
||||
|
||||
flip_and_color_jitter = Compose([
|
||||
RandomFlip(prob=0.5, direction='horizontal'),
|
||||
RandomApply([
|
||||
ColorJitter(
|
||||
brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)
|
||||
],
|
||||
prob=0.8),
|
||||
RandomGrayscale(
|
||||
prob=0.2,
|
||||
keep_channels=True,
|
||||
channel_weights=(0.114, 0.587, 0.2989),
|
||||
)
|
||||
])
|
||||
|
||||
self.global_transform_1 = Compose([
|
||||
RandomResizedCrop(
|
||||
224,
|
||||
crop_ratio_range=global_crops_scale,
|
||||
interpolation='bicubic'),
|
||||
flip_and_color_jitter,
|
||||
GaussianBlur(prob=1.0, radius=random.uniform(0.1, 2.0)),
|
||||
])
|
||||
|
||||
self.global_transform_2 = Compose([
|
||||
RandomResizedCrop(
|
||||
224,
|
||||
crop_ratio_range=global_crops_scale,
|
||||
interpolation='bicubic'),
|
||||
flip_and_color_jitter,
|
||||
GaussianBlur(prob=1.0, radius=random.uniform(0.1, 2.0)),
|
||||
Solarize(thr=128, prob=0.2),
|
||||
])
|
||||
|
||||
self.local_crops_number = local_crops_number
|
||||
self.local_transform = Compose([
|
||||
RandomResizedCrop(
|
||||
96,
|
||||
crop_ratio_range=local_crops_scale,
|
||||
interpolation='bicubic'),
|
||||
flip_and_color_jitter,
|
||||
GaussianBlur(prob=1.0, radius=random.uniform(0.1, 2.0)),
|
||||
])
|
||||
|
||||
def transform(self, results: dict) -> dict:
|
||||
ori_img = results['img']
|
||||
crops = []
|
||||
results['img'] = ori_img
|
||||
crops.append(self.global_transform_1(results)['img'])
|
||||
results['img'] = ori_img
|
||||
crops.append(self.global_transform_2(results)['img'])
|
||||
for _ in range(self.local_crops_number):
|
||||
results['img'] = ori_img
|
||||
crops.append(self.local_transform(results)['img'])
|
||||
results['img'] = crops
|
||||
return results
|
||||
|
||||
def __repr__(self) -> str:
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(global_crops_scale = {self.global_crops_scale}, '
|
||||
repr_str += f'local_crops_scale = {self.local_crops_scale}, '
|
||||
repr_str += f'local_crop_number = {self.local_crops_number})'
|
||||
return repr_str
|
|
@ -0,0 +1 @@
|
|||
from .hooks import * # noqa
|
|
@ -0,0 +1,3 @@
|
|||
from .dino_teacher_temp_warmup_hook import DINOTeacherTempWarmupHook
|
||||
|
||||
__all__ = ['DINOTeacherTempWarmupHook']
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import numpy as np
|
||||
from mmengine.hooks import Hook
|
||||
|
||||
from mmpretrain.registry import HOOKS
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
class DINOTeacherTempWarmupHook(Hook):
|
||||
"""Warmup teacher temperature for DINO.
|
||||
|
||||
This hook warmups the temperature for teacher to stabilize the training
|
||||
process.
|
||||
|
||||
Args:
|
||||
warmup_teacher_temp (float): Warmup temperature for teacher.
|
||||
teacher_temp (float): Temperature for teacher.
|
||||
teacher_temp_warmup_epochs (int): Warmup epochs for teacher
|
||||
temperature.
|
||||
max_epochs (int): Maximum epochs for training.
|
||||
"""
|
||||
|
||||
def __init__(self, warmup_teacher_temp: float, teacher_temp: float,
|
||||
teacher_temp_warmup_epochs: int, max_epochs: int) -> None:
|
||||
super().__init__()
|
||||
self.teacher_temps = np.concatenate(
|
||||
(np.linspace(warmup_teacher_temp, teacher_temp,
|
||||
teacher_temp_warmup_epochs),
|
||||
np.ones(max_epochs - teacher_temp_warmup_epochs) * teacher_temp))
|
||||
|
||||
def before_train_epoch(self, runner) -> None:
|
||||
runner.model.module.head.teacher_temp = self.teacher_temps[
|
||||
runner.epoch]
|
|
@ -0,0 +1,3 @@
|
|||
from .algorithm import * # noqa
|
||||
from .head import * # noqa
|
||||
from .neck import * # noqa
|
|
@ -0,0 +1,3 @@
|
|||
from .dino import DINO
|
||||
|
||||
__all__ = ['DINO']
|
|
@ -0,0 +1,82 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from mmpretrain.models import BaseSelfSupervisor, CosineEMA
|
||||
from mmpretrain.registry import MODELS
|
||||
from mmpretrain.structures import DataSample
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DINO(BaseSelfSupervisor):
|
||||
"""Implementation for DINO.
|
||||
|
||||
This module is proposed in `DINO: Emerging Properties in Self-Supervised
|
||||
Vision Transformers <https://arxiv.org/abs/2104.14294>`_.
|
||||
|
||||
Args:
|
||||
backbone (dict): Config for backbone.
|
||||
neck (dict): Config for neck.
|
||||
head (dict): Config for head.
|
||||
pretrained (str, optional): Path for pretrained model.
|
||||
Defaults to None.
|
||||
base_momentum (float, optional): Base momentum for momentum update.
|
||||
Defaults to 0.99.
|
||||
data_preprocessor (dict, optional): Config for data preprocessor.
|
||||
Defaults to None.
|
||||
init_cfg (list[dict] | dict, optional): Config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
backbone: dict,
|
||||
neck: dict,
|
||||
head: dict,
|
||||
pretrained: Optional[str] = None,
|
||||
base_momentum: float = 0.99,
|
||||
data_preprocessor: Optional[dict] = None,
|
||||
init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
|
||||
super().__init__(
|
||||
backbone=backbone,
|
||||
neck=neck,
|
||||
head=head,
|
||||
pretrained=pretrained,
|
||||
data_preprocessor=data_preprocessor,
|
||||
init_cfg=init_cfg)
|
||||
|
||||
# create momentum model
|
||||
self.teacher = CosineEMA(
|
||||
nn.Sequential(self.backbone, self.neck), momentum=base_momentum)
|
||||
# weight normalization layer
|
||||
self.neck.last_layer = nn.utils.weight_norm(self.neck.last_layer)
|
||||
self.neck.last_layer.weight_g.data.fill_(1)
|
||||
self.neck.last_layer.weight_g.requires_grad = False
|
||||
self.teacher.module[1].last_layer = nn.utils.weight_norm(
|
||||
self.teacher.module[1].last_layer)
|
||||
self.teacher.module[1].last_layer.weight_g.data.fill_(1)
|
||||
self.teacher.module[1].last_layer.weight_g.requires_grad = False
|
||||
|
||||
def loss(self, inputs: torch.Tensor,
|
||||
data_samples: List[DataSample]) -> dict:
|
||||
global_crops = torch.cat(inputs[:2])
|
||||
local_crops = torch.cat(inputs[2:])
|
||||
# teacher forward
|
||||
teacher_output = self.teacher(global_crops)
|
||||
|
||||
# student forward global
|
||||
student_output_global = self.backbone(global_crops)
|
||||
student_output_global = self.neck(student_output_global)
|
||||
|
||||
# student forward local
|
||||
student_output_local = self.backbone(local_crops)
|
||||
student_output_local = self.neck(student_output_local)
|
||||
|
||||
student_output = torch.cat(
|
||||
(student_output_global, student_output_local))
|
||||
|
||||
# compute loss
|
||||
loss = self.head(student_output, teacher_output)
|
||||
|
||||
return dict(loss=loss)
|
|
@ -0,0 +1,3 @@
|
|||
from .dino_head import DINOHead
|
||||
|
||||
__all__ = ['DINOHead']
|
|
@ -0,0 +1,69 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from mmengine.dist import all_reduce, get_world_size
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmpretrain.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DINOHead(BaseModule):
|
||||
"""Implementation for DINO head.
|
||||
|
||||
This module is proposed in `DINO: Emerging Properties in Self-Supervised
|
||||
Vision Transformers <https://arxiv.org/abs/2104.14294>`_.
|
||||
|
||||
Args:
|
||||
out_channels (int): Output channels of the head.
|
||||
num_crops (int): Number of crops.
|
||||
student_temp (float): Temperature for student output.
|
||||
center_momentum (float): Momentum for center update.
|
||||
"""
|
||||
|
||||
def __init__(self, out_channels: int, num_crops: int, student_temp: float,
|
||||
center_momentum: float) -> None:
|
||||
super().__init__()
|
||||
self.student_temp = student_temp
|
||||
self.teacher_temp = 0
|
||||
self.center_momentum = center_momentum
|
||||
self.num_crops = num_crops
|
||||
self.register_buffer('center', torch.zeros(1, out_channels))
|
||||
|
||||
def forward(self, student_output: torch.Tensor,
|
||||
teacher_output: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
current_teacher_output = teacher_output
|
||||
student_output = student_output / self.student_temp
|
||||
student_output = student_output.chunk(self.num_crops, dim=0)
|
||||
|
||||
# teacher centering and sharpening
|
||||
teacher_output = F.softmax(
|
||||
(teacher_output - self.center) / self.teacher_temp, dim=-1)
|
||||
teacher_output = teacher_output.detach().chunk(2, dim=0)
|
||||
|
||||
total_loss = 0
|
||||
n_loss_terms = 0
|
||||
|
||||
for i in range(len(teacher_output)):
|
||||
for j in range(len(student_output)):
|
||||
if i == j:
|
||||
continue
|
||||
total_loss += (-teacher_output[i] *
|
||||
student_output[j].log_softmax(dim=-1)).sum(
|
||||
dim=-1).mean()
|
||||
n_loss_terms += 1
|
||||
total_loss /= n_loss_terms
|
||||
self.update_center(current_teacher_output)
|
||||
return total_loss
|
||||
|
||||
@torch.no_grad()
|
||||
def update_center(self, teacher_output: torch.Tensor) -> None:
|
||||
|
||||
batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
|
||||
all_reduce(batch_center)
|
||||
batch_center = batch_center / (len(teacher_output) * get_world_size())
|
||||
|
||||
# ema update batch center
|
||||
self.center = self.center * self.center_momentum + batch_center * (
|
||||
1 - self.center_momentum)
|
|
@ -0,0 +1,3 @@
|
|||
from .dino_neck import DINONeck
|
||||
|
||||
__all__ = ['DINONeck']
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from mmengine.model import BaseModule
|
||||
from torch import nn
|
||||
|
||||
from mmpretrain.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DINONeck(BaseModule):
|
||||
"""Implementation for DINO neck.
|
||||
|
||||
This module is proposed in `DINO: Emerging Properties in Self-Supervised
|
||||
Vision Transformers <https://arxiv.org/abs/2104.14294>`_.
|
||||
|
||||
Args:
|
||||
in_channels (int): Input channels.
|
||||
hidden_channels (int): Hidden channels.
|
||||
out_channels (int): Output channels.
|
||||
bottleneck_channels (int): Bottleneck channels.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, hidden_channels: int,
|
||||
out_channels: int, bottleneck_channels: int) -> None:
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(*[
|
||||
nn.Linear(in_channels, hidden_channels),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_channels, hidden_channels),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_channels, bottleneck_channels),
|
||||
])
|
||||
|
||||
self.last_layer = nn.Linear(
|
||||
bottleneck_channels, out_channels, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.mlp(x[0])
|
||||
x = nn.functional.normalize(x, dim=-1, p=2)
|
||||
x = self.last_layer(x)
|
||||
return x
|
|
@ -0,0 +1,19 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
CONFIG=$1
|
||||
GPUS=$2
|
||||
NNODES=${NNODES:-1}
|
||||
NODE_RANK=${NODE_RANK:-0}
|
||||
PORT=${PORT:-29500}
|
||||
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
|
||||
|
||||
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
|
||||
python -m torch.distributed.launch \
|
||||
--nnodes=$NNODES \
|
||||
--node_rank=$NODE_RANK \
|
||||
--master_addr=$MASTER_ADDR \
|
||||
--nproc_per_node=$GPUS \
|
||||
--master_port=$PORT \
|
||||
$(dirname "$0")/train.py \
|
||||
$CONFIG \
|
||||
--launcher pytorch ${@:3}
|
|
@ -0,0 +1,23 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
set -x
|
||||
|
||||
PARTITION=$1
|
||||
JOB_NAME=$2
|
||||
CONFIG=$3
|
||||
GPUS=${GPUS:-8}
|
||||
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
|
||||
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
|
||||
SRUN_ARGS=${SRUN_ARGS:-""}
|
||||
PY_ARGS=${@:4}
|
||||
|
||||
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
|
||||
srun -p ${PARTITION} \
|
||||
--job-name=${JOB_NAME} \
|
||||
--gres=gpu:${GPUS_PER_NODE} \
|
||||
--ntasks=${GPUS} \
|
||||
--ntasks-per-node=${GPUS_PER_NODE} \
|
||||
--cpus-per-task=${CPUS_PER_TASK} \
|
||||
--kill-on-bad-exit=1 \
|
||||
${SRUN_ARGS} \
|
||||
python -u projects/dino/tools/train.py ${CONFIG} --launcher="slurm" ${PY_ARGS}
|
|
@ -0,0 +1,104 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
import os
|
||||
import os.path as osp
|
||||
|
||||
from dataset import * # noqa: F401,F403
|
||||
from engine import * # noqa: F401,F403
|
||||
from mmengine.config import Config, DictAction
|
||||
from mmengine.runner import Runner
|
||||
from models.algorithm import * # noqa: F401,F403
|
||||
from models.head import * # noqa: F401,F403
|
||||
from models.neck import * # noqa: F401,F403
|
||||
|
||||
from mmpretrain.utils import register_all_modules
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Train a model')
|
||||
parser.add_argument('config', help='train config file path')
|
||||
parser.add_argument('--work-dir', help='the dir to save logs and models')
|
||||
parser.add_argument(
|
||||
'--resume',
|
||||
nargs='?',
|
||||
type=str,
|
||||
const='auto',
|
||||
help='If specify checkpint path, resume from it, while if not '
|
||||
'specify, try to auto resume from the latest checkpoint '
|
||||
'in the work directory.')
|
||||
parser.add_argument(
|
||||
'--amp',
|
||||
action='store_true',
|
||||
help='enable automatic-mixed-precision training')
|
||||
parser.add_argument(
|
||||
'--cfg-options',
|
||||
nargs='+',
|
||||
action=DictAction,
|
||||
help='override some settings in the used config, the key-value pair '
|
||||
'in xxx=yyy format will be merged into config file. If the value to '
|
||||
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
|
||||
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
|
||||
'Note that the quotation marks are necessary and that no white space '
|
||||
'is allowed.')
|
||||
parser.add_argument(
|
||||
'--launcher',
|
||||
choices=['none', 'pytorch', 'slurm', 'mpi'],
|
||||
default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
if 'LOCAL_RANK' not in os.environ:
|
||||
os.environ['LOCAL_RANK'] = str(args.local_rank)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# register all modules in mmpretrain into the registries
|
||||
# do not init the default scope here because it will be init in the runner
|
||||
register_all_modules(init_default_scope=False)
|
||||
|
||||
# load config
|
||||
cfg = Config.fromfile(args.config)
|
||||
cfg.launcher = args.launcher
|
||||
if args.cfg_options is not None:
|
||||
cfg.merge_from_dict(args.cfg_options)
|
||||
|
||||
# work_dir is determined in this priority: CLI > segment in file > filename
|
||||
if args.work_dir is not None:
|
||||
# update configs according to CLI args if args.work_dir is not None
|
||||
cfg.work_dir = args.work_dir
|
||||
elif cfg.get('work_dir', None) is None:
|
||||
# use config filename as default work_dir if cfg.work_dir is None
|
||||
work_type = args.config.split('/')[1]
|
||||
cfg.work_dir = osp.join('./work_dirs', work_type,
|
||||
osp.splitext(osp.basename(args.config))[0])
|
||||
|
||||
# enable automatic-mixed-precision training
|
||||
if args.amp is True:
|
||||
optim_wrapper = cfg.optim_wrapper.get('type', 'OptimWrapper')
|
||||
assert optim_wrapper in ['OptimWrapper', 'AmpOptimWrapper'], \
|
||||
'`--amp` is not supported custom optimizer wrapper type ' \
|
||||
f'`{optim_wrapper}.'
|
||||
cfg.optim_wrapper.type = 'AmpOptimWrapper'
|
||||
cfg.optim_wrapper.setdefault('loss_scale', 'dynamic')
|
||||
|
||||
# resume training
|
||||
if args.resume == 'auto':
|
||||
cfg.resume = True
|
||||
cfg.load_from = None
|
||||
elif args.resume is not None:
|
||||
cfg.resume = True
|
||||
cfg.load_from = args.resume
|
||||
|
||||
# build the runner from config
|
||||
runner = Runner.from_cfg(cfg)
|
||||
|
||||
# start training
|
||||
runner.train()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,77 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
import os.path as osp
|
||||
from collections import OrderedDict
|
||||
|
||||
import mmengine
|
||||
import torch
|
||||
from mmengine.runner import CheckpointLoader
|
||||
|
||||
|
||||
def convert_clip(ckpt):
|
||||
new_ckpt = OrderedDict()
|
||||
|
||||
for k, v in list(ckpt.items()):
|
||||
new_v = v
|
||||
if k.startswith('visual.conv1'):
|
||||
new_k = k.replace('conv1', 'patch_embed.projection')
|
||||
elif k.startswith('visual.positional_embedding'):
|
||||
new_k = k.replace('positional_embedding', 'pos_embed')
|
||||
new_v = v.unsqueeze(dim=0)
|
||||
elif k.startswith('visual.class_embedding'):
|
||||
new_k = k.replace('class_embedding', 'cls_token')
|
||||
new_v = v.unsqueeze(dim=0).unsqueeze(dim=0)
|
||||
elif k.startswith('visual.ln_pre'):
|
||||
new_k = k.replace('ln_pre', 'pre_norm')
|
||||
elif k.startswith('visual.transformer.resblocks'):
|
||||
new_k = k.replace('transformer.resblocks', 'layers')
|
||||
if 'ln_1' in k:
|
||||
new_k = new_k.replace('ln_1', 'ln1')
|
||||
elif 'ln_2' in k:
|
||||
new_k = new_k.replace('ln_2', 'ln2')
|
||||
elif 'mlp.c_fc' in k:
|
||||
new_k = new_k.replace('mlp.c_fc', 'ffn.layers.0.0')
|
||||
elif 'mlp.c_proj' in k:
|
||||
new_k = new_k.replace('mlp.c_proj', 'ffn.layers.1')
|
||||
elif 'attn.in_proj_weight' in k:
|
||||
new_k = new_k.replace('in_proj_weight', 'qkv.weight')
|
||||
elif 'attn.in_proj_bias' in k:
|
||||
new_k = new_k.replace('in_proj_bias', 'qkv.bias')
|
||||
elif 'attn.out_proj' in k:
|
||||
new_k = new_k.replace('out_proj', 'proj')
|
||||
elif k.startswith('visual.ln_post'):
|
||||
new_k = k.replace('ln_post', 'ln1')
|
||||
elif k.startswith('visual.proj'):
|
||||
new_k = k.replace('visual.proj', 'visual_proj.proj')
|
||||
else:
|
||||
new_k = k
|
||||
|
||||
new_ckpt[new_k] = new_v
|
||||
return new_ckpt
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Convert keys in pretrained clip '
|
||||
'models to mmpretrain style.')
|
||||
parser.add_argument('src', help='src model path or url')
|
||||
# The dst path must be a full path of the new checkpoint.
|
||||
parser.add_argument('dst', help='save path')
|
||||
args = parser.parse_args()
|
||||
|
||||
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
|
||||
|
||||
if 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
|
||||
weight = convert_clip(state_dict)
|
||||
mmengine.mkdir_or_exist(osp.dirname(args.dst))
|
||||
torch.save(weight, args.dst)
|
||||
|
||||
print('Done!!')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -91,10 +91,6 @@ def merge_args(cfg, args):
|
|||
|
||||
# enable automatic-mixed-precision training
|
||||
if args.amp is True:
|
||||
optim_wrapper = cfg.optim_wrapper.get('type', 'OptimWrapper')
|
||||
assert optim_wrapper in ['OptimWrapper', 'AmpOptimWrapper'], \
|
||||
'`--amp` is not supported custom optimizer wrapper type ' \
|
||||
f'`{optim_wrapper}.'
|
||||
cfg.optim_wrapper.type = 'AmpOptimWrapper'
|
||||
cfg.optim_wrapper.setdefault('loss_scale', 'dynamic')
|
||||
|
||||
|
|
Loading…
Reference in New Issue