Fix bug and optimize MNIST config (#98)

* add simple_test to ClsHead

* optimize lenet training config

* recover path setting
pull/102/head
WRH 2020-11-26 15:27:04 +08:00 committed by GitHub
parent c0e7512969
commit 44bbc71e14
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 5 deletions

View File

@ -15,12 +15,13 @@ train_pipeline = [
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label']),
]
test_pipeline = [
dict(type='Resize', size=32),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img']),
]
data = dict(
samples_per_gpu=128,
@ -28,9 +29,9 @@ data = dict(
train=dict(
type=dataset_type, data_prefix='data/mnist', pipeline=train_pipeline),
val=dict(
type=dataset_type, data_prefix='data/mnist', pipeline=test_pipeline),
test=dict(
type=dataset_type, data_prefix='data/mnist', pipeline=test_pipeline))
evaluation = dict(
interval=5, metric='accuracy', metric_options={'topk': (1, )})
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
@ -40,14 +41,14 @@ lr_config = dict(policy='step', step=[15])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=100,
interval=150,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
runner = dict(type='EpochBasedRunner', max_epochs=20)
runner = dict(type='EpochBasedRunner', max_epochs=5)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/mnist/'

View File

@ -1,3 +1,6 @@
import torch
import torch.nn.functional as F
from mmcls.models.losses import Accuracy
from ..builder import HEADS, build_loss
from .base_head import BaseHead
@ -43,3 +46,13 @@ class ClsHead(BaseHead):
def forward_train(self, cls_score, gt_label):
losses = self.loss(cls_score, gt_label)
return losses
def simple_test(self, cls_score):
"""Test without augmentation."""
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
if torch.onnx.is_in_onnx_export():
return pred
pred = list(pred.detach().cpu().numpy())
return pred