# mmpretrain/configs/_base_/models/resnet101_cifar.py
#
# (source listing metadata: 17 lines, 408 B, Python)

# model settings
# Base model config: an image classifier assembled from registry components,
# each given as a `type` name plus its keyword arguments. Component
# descriptions below follow the registry type names and should be confirmed
# against the framework's documentation.
model = dict(
    type='ImageClassifier',
    # Backbone: CIFAR-variant ResNet at depth 101 split into 4 stages;
    # out_indices=(3, ) presumably selects only the last stage's output.
    backbone=dict(
        type='ResNet_CIFAR',
        depth=101,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch',
    ),
    # Neck: pools the backbone feature map (global average pooling by name).
    neck=dict(type='GlobalAveragePooling'),
    # Head: linear classification head with a cross-entropy loss at weight 1.0.
    head=dict(
        type='LinearClsHead',
        # 10 classes -- presumably CIFAR-10; adjust for other datasets.
        num_classes=10,
        # NOTE(review): must equal the backbone's output channel width;
        # 2048 is the expected last-stage width of a bottleneck ResNet-101.
        # Re-check this value if `depth` is changed.
        in_channels=2048,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
    ),
)