[Fix]fix classification configs (#488)
parent
14afb8e302
commit
2dacc14d4a
|
@ -1,49 +0,0 @@
|
|||
# dataset settings
|
||||
data_source = 'CIFAR10'
|
||||
dataset_type = 'SingleViewDataset'
|
||||
img_norm_cfg = dict(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.201])
|
||||
train_pipeline = [
|
||||
dict(type='RandomCrop', size=32, padding=4),
|
||||
dict(type='RandomHorizontalFlip'),
|
||||
]
|
||||
test_pipeline = []
|
||||
|
||||
# prefetch
|
||||
prefetch = False
|
||||
if not prefetch:
|
||||
train_pipeline.extend(
|
||||
[dict(type='ToTensor'),
|
||||
dict(type='Normalize', **img_norm_cfg)])
|
||||
test_pipeline.extend(
|
||||
[dict(type='ToTensor'),
|
||||
dict(type='Normalize', **img_norm_cfg)])
|
||||
|
||||
# dataset summary
|
||||
data = dict(
|
||||
samples_per_gpu=128,
|
||||
workers_per_gpu=2,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_source=dict(
|
||||
type=data_source,
|
||||
data_prefix='data/cifar10',
|
||||
),
|
||||
pipeline=train_pipeline,
|
||||
prefetch=prefetch),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_source=dict(
|
||||
type=data_source,
|
||||
data_prefix='data/cifar10',
|
||||
),
|
||||
pipeline=test_pipeline,
|
||||
prefetch=prefetch),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
data_source=dict(
|
||||
type=data_source,
|
||||
data_prefix='data/cifar10',
|
||||
),
|
||||
pipeline=test_pipeline,
|
||||
prefetch=prefetch))
|
||||
evaluation = dict(interval=10, topk=(1, 5))
|
|
@ -1,5 +1,6 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_root = 'data/imagenet/'
|
||||
preprocess_cfg = dict(
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
|
@ -57,7 +58,7 @@ train_dataloader = dict(
|
|||
num_workers=8,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
data_root=data_root,
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix='train',
|
||||
pipeline=train_pipeline),
|
||||
|
@ -70,7 +71,7 @@ val_dataloader = dict(
|
|||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
data_root=data_root,
|
||||
ann_file='meta/val.txt',
|
||||
data_prefix='val',
|
||||
pipeline=test_pipeline),
|
||||
|
|
|
@ -21,7 +21,7 @@ train_dataloader = dict(
|
|||
num_workers=4,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
data_root=data_root,
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix='train',
|
||||
pipeline=train_pipeline),
|
||||
|
@ -34,7 +34,7 @@ val_dataloader = dict(
|
|||
num_workers=4,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
data_root=data_root,
|
||||
ann_file='meta/val.txt',
|
||||
data_prefix='val',
|
||||
pipeline=test_pipeline),
|
||||
|
|
|
@ -57,7 +57,7 @@ train_dataloader = dict(
|
|||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
data_root=data_root,
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix='train',
|
||||
pipeline=train_pipeline),
|
||||
|
@ -70,7 +70,7 @@ val_dataloader = dict(
|
|||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
data_root=data_root,
|
||||
ann_file='meta/val.txt',
|
||||
data_prefix='val',
|
||||
pipeline=test_pipeline),
|
||||
|
|
|
@ -1,15 +1,20 @@
|
|||
_base_ = [
|
||||
'../_base_/models/resnet50.py',
|
||||
'../_base_/datasets/cifar10.py',
|
||||
'mmcls::_base_/datasets/cifar10_bs16.py',
|
||||
'../_base_/schedules/sgd_steplr-100e.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
# dataset settings
|
||||
train_dataloader = dict(batch_size=128)
|
||||
val_dataloader = dict(batch_size=128)
|
||||
|
||||
# model settings
|
||||
model = dict(head=dict(num_classes=10))
|
||||
|
||||
# optimizer
|
||||
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=5e-4)
|
||||
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer)
|
||||
|
||||
# learning rate scheduler
|
||||
param_scheduler = [
|
||||
|
@ -17,5 +22,6 @@ param_scheduler = [
|
|||
]
|
||||
|
||||
# runtime settings
|
||||
runner = dict(type='EpochBasedRunner', max_epochs=350)
|
||||
checkpoint_config = dict(interval=50)
|
||||
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=350)
|
||||
default_hooks = dict(
|
||||
checkpoint=dict(type='CheckpointHook', interval=50, max_keep_ckpts=3))
|
||||
|
|
Loading…
Reference in New Issue