Fix bug and optimize MNIST config (#98)
* add simple_test to ClsHead * optimize lenet training config * recover path settingpull/102/head
parent
c0e7512969
commit
44bbc71e14
|
@ -15,12 +15,13 @@ train_pipeline = [
|
|||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='ToTensor', keys=['gt_label']),
|
||||
dict(type='Collect', keys=['img', 'gt_label']),
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='Resize', size=32),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='ToTensor', keys=['gt_label']),
|
||||
dict(type='Collect', keys=['img']),
|
||||
]
|
||||
data = dict(
|
||||
samples_per_gpu=128,
|
||||
|
@ -28,9 +29,9 @@ data = dict(
|
|||
train=dict(
|
||||
type=dataset_type, data_prefix='data/mnist', pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type, data_prefix='data/mnist', pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type=dataset_type, data_prefix='data/mnist', pipeline=test_pipeline))
|
||||
evaluation = dict(
|
||||
interval=5, metric='accuracy', metric_options={'topk': (1, )})
|
||||
# optimizer
|
||||
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
|
||||
optimizer_config = dict(grad_clip=None)
|
||||
|
@ -40,14 +41,14 @@ lr_config = dict(policy='step', step=[15])
|
|||
checkpoint_config = dict(interval=1)
|
||||
# yapf:disable
|
||||
log_config = dict(
|
||||
interval=100,
|
||||
interval=150,
|
||||
hooks=[
|
||||
dict(type='TextLoggerHook'),
|
||||
# dict(type='TensorboardLoggerHook')
|
||||
])
|
||||
# yapf:enable
|
||||
# runtime settings
|
||||
runner = dict(type='EpochBasedRunner', max_epochs=20)
|
||||
runner = dict(type='EpochBasedRunner', max_epochs=5)
|
||||
dist_params = dict(backend='nccl')
|
||||
log_level = 'INFO'
|
||||
work_dir = './work_dirs/mnist/'
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmcls.models.losses import Accuracy
|
||||
from ..builder import HEADS, build_loss
|
||||
from .base_head import BaseHead
|
||||
|
@ -43,3 +46,13 @@ class ClsHead(BaseHead):
|
|||
def forward_train(self, cls_score, gt_label):
|
||||
losses = self.loss(cls_score, gt_label)
|
||||
return losses
|
||||
|
||||
def simple_test(self, cls_score):
|
||||
"""Test without augmentation."""
|
||||
if isinstance(cls_score, list):
|
||||
cls_score = sum(cls_score) / float(len(cls_score))
|
||||
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
|
||||
if torch.onnx.is_in_onnx_export():
|
||||
return pred
|
||||
pred = list(pred.detach().cpu().numpy())
|
||||
return pred
|
||||
|
|
Loading…
Reference in New Issue