46 lines
1.2 KiB
Python
46 lines
1.2 KiB
Python
_base_ = 'mmcls::_base_/default_runtime.py'
|
|
|
|
model = dict(
|
|
_scope_='mmcls',
|
|
type='ImageClassifier',
|
|
backbone=dict(
|
|
type='VisionTransformer',
|
|
arch='base',
|
|
img_size=224,
|
|
patch_size=16,
|
|
out_indices=-1,
|
|
drop_path_rate=0.1,
|
|
avg_token=False,
|
|
output_cls_token=False,
|
|
final_norm=False),
|
|
neck=None,
|
|
head=dict(
|
|
type='LinearClsHead',
|
|
num_classes=1000,
|
|
in_channels=768,
|
|
loss=dict(
|
|
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
|
|
init_cfg=[dict(type='TruncNormal', layer='Linear', std=2e-5)]),
|
|
)
|
|
|
|
dataset_type = 'mmcls.ImageNet'
|
|
data_root = 'data/imagenet/'
|
|
file_client_args = dict(backend='disk')
|
|
extract_pipeline = [
|
|
dict(type='LoadImageFromFile', file_client_args=file_client_args),
|
|
dict(type='mmcls.ResizeEdge', scale=256, edge='short'),
|
|
dict(type='CenterCrop', crop_size=224),
|
|
dict(type='mmcls.PackClsInputs'),
|
|
]
|
|
extract_dataloader = dict(
|
|
batch_size=8,
|
|
num_workers=4,
|
|
dataset=dict(
|
|
type=dataset_type,
|
|
data_root='data/imagenet',
|
|
ann_file='meta/val.txt',
|
|
data_prefix='val',
|
|
pipeline=extract_pipeline),
|
|
sampler=dict(type='DefaultSampler', shuffle=False),
|
|
)
|