# mmpretrain/configs/_base_/models/resnet34_cifar.py
#
# 17 lines
# 406 B
# Python

# model settings
# ResNet-34 (CIFAR variant) backbone -> global average pooling -> linear head.
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet_CIFAR',
        depth=34,
        num_stages=4,
        # Select only the last of the 4 stages (0-based index 3) as output.
        out_indices=(3,),
        style='pytorch',
    ),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        # NOTE(review): 10 classes presumably targets CIFAR-10; override in a
        # derived config for other datasets.
        num_classes=10,
        # Presumably must match the channel width of the backbone's final
        # stage for depth=34 — keep in sync if `depth` changes (TODO confirm).
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
    ),
)