[Feature] Support fp16 training (#178)
* change mmcls fp16 to mmcv hook * support fp16 * clean unnecessary stuff (pull/183/head)
parent
8eb845e718
commit
e76c5a368d
|
@ -0,0 +1,14 @@
|
|||
# Mixed Precision Training
|
||||
|
||||
## Introduction
|
||||
|
||||
[OTHERS]
|
||||
|
||||
```latex
|
||||
@article{micikevicius2017mixed,
|
||||
title={Mixed precision training},
|
||||
author={Micikevicius, Paulius and Narang, Sharan and Alben, Jonah and Diamos, Gregory and Elsen, Erich and Garcia, David and Ginsburg, Boris and Houston, Michael and Kuchaiev, Oleksii and Venkatesh, Ganesh and others},
|
||||
journal={arXiv preprint arXiv:1710.03740},
|
||||
year={2017}
|
||||
}
|
||||
```
|
|
@ -0,0 +1,4 @@
|
|||
_base_ = ['../resnet/resnet50_b32x8_imagenet.py']
|
||||
|
||||
# fp16 settings
|
||||
fp16 = dict(loss_scale='dynamic')
|
|
@ -0,0 +1,4 @@
|
|||
_base_ = ['../resnet/resnet50_b32x8_imagenet.py']
|
||||
|
||||
# fp16 settings
|
||||
fp16 = dict(loss_scale=512.)
|
|
@ -6,11 +6,18 @@ import torch
|
|||
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
||||
from mmcv.runner import DistSamplerSeedHook, build_optimizer, build_runner
|
||||
|
||||
from mmcls.core import (DistEvalHook, DistOptimizerHook, EvalHook,
|
||||
Fp16OptimizerHook)
|
||||
from mmcls.core import DistEvalHook, DistOptimizerHook, EvalHook
|
||||
from mmcls.datasets import build_dataloader, build_dataset
|
||||
from mmcls.utils import get_root_logger
|
||||
|
||||
# TODO import optimizer hook from mmcv and delete them from mmcls
# Prefer the upstream mmcv implementation; fall back to the bundled mmcls copy
# (with a deprecation warning) when the installed mmcv predates 1.1.4.
try:
    from mmcv.runner import Fp16OptimizerHook
except ImportError:
    # Fixed: message used the wrong class name ('FP16OptimizerHook') and the
    # implicit string concatenation lacked a separating space.
    warnings.warn('Fp16OptimizerHook from mmcls will be deprecated. '
                  'Please install mmcv>=1.1.4.')
    from mmcls.core import Fp16OptimizerHook
|
||||
|
||||
|
||||
def set_random_seed(seed, deterministic=False):
|
||||
"""Set random seed.
|
||||
|
|
|
@ -10,12 +10,21 @@ import torch.nn as nn
|
|||
from mmcv import color_val
|
||||
from mmcv.utils import print_log
|
||||
|
||||
# TODO import `auto_fp16` from mmcv and delete them from mmcls
# Prefer the upstream mmcv implementation; fall back to the bundled mmcls copy
# (with a deprecation warning) when the installed mmcv predates 1.1.4.
try:
    from mmcv.runner import auto_fp16
except ImportError:
    # Fixed: implicit string concatenation lacked a separating space, which
    # rendered as 'deprecated.Please install ...'.
    warnings.warn('auto_fp16 from mmcls will be deprecated. '
                  'Please install mmcv>=1.1.4.')
    from mmcls.core import auto_fp16
|
||||
|
||||
|
||||
class BaseClassifier(nn.Module, metaclass=ABCMeta):
|
||||
"""Base class for classifiers"""
|
||||
|
||||
def __init__(self):
|
||||
super(BaseClassifier, self).__init__()
|
||||
self.fp16_enabled = False
|
||||
|
||||
@property
|
||||
def with_neck(self):
|
||||
|
@ -70,6 +79,7 @@ class BaseClassifier(nn.Module, metaclass=ABCMeta):
|
|||
else:
|
||||
raise NotImplementedError('aug_test has not been implemented')
|
||||
|
||||
@auto_fp16(apply_to=('img', ))
|
||||
def forward(self, img, return_loss=True, **kwargs):
|
||||
"""
|
||||
Calls either forward_train or forward_test depending on whether
|
||||
|
|
|
@ -10,10 +10,17 @@ from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
|||
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
|
||||
|
||||
from mmcls.apis import multi_gpu_test, single_gpu_test
|
||||
from mmcls.core import wrap_fp16_model
|
||||
from mmcls.datasets import build_dataloader, build_dataset
|
||||
from mmcls.models import build_classifier
|
||||
|
||||
# TODO import `wrap_fp16_model` from mmcv and delete them from mmcls
# Prefer the upstream mmcv implementation; fall back to the bundled mmcls copy
# (with a deprecation warning) when the installed mmcv predates 1.1.4.
try:
    from mmcv.runner import wrap_fp16_model
except ImportError:
    # Fixed: implicit string concatenation lacked a separating space, which
    # rendered as 'deprecated.Please install ...'.
    warnings.warn('wrap_fp16_model from mmcls will be deprecated. '
                  'Please install mmcv>=1.1.4.')
    from mmcls.core import wrap_fp16_model
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='mmcls test model')
|
||||
|
|
Loading…
Reference in New Issue