[Feature] Support fp16 training (#178)
* change mmcls fp16 to mmcv hook * support fp16 * clean unnessary stuffpull/183/head
parent
8eb845e718
commit
e76c5a368d
|
@ -0,0 +1,14 @@
|
||||||
|
# Mixed Precision Training
|
||||||
|
|
||||||
|
## Introduction
|
||||||
|
|
||||||
|
[OTHERS]
|
||||||
|
|
||||||
|
```latex
|
||||||
|
@article{micikevicius2017mixed,
|
||||||
|
title={Mixed precision training},
|
||||||
|
author={Micikevicius, Paulius and Narang, Sharan and Alben, Jonah and Diamos, Gregory and Elsen, Erich and Garcia, David and Ginsburg, Boris and Houston, Michael and Kuchaiev, Oleksii and Venkatesh, Ganesh and others},
|
||||||
|
journal={arXiv preprint arXiv:1710.03740},
|
||||||
|
year={2017}
|
||||||
|
}
|
||||||
|
```
|
|
@ -0,0 +1,4 @@
|
||||||
|
_base_ = ['../resnet/resnet50_b32x8_imagenet.py']
|
||||||
|
|
||||||
|
# fp16 settings
|
||||||
|
fp16 = dict(loss_scale='dynamic')
|
|
@ -0,0 +1,4 @@
|
||||||
|
_base_ = ['../resnet/resnet50_b32x8_imagenet.py']
|
||||||
|
|
||||||
|
# fp16 settings
|
||||||
|
fp16 = dict(loss_scale=512.)
|
|
@ -6,11 +6,18 @@ import torch
|
||||||
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
||||||
from mmcv.runner import DistSamplerSeedHook, build_optimizer, build_runner
|
from mmcv.runner import DistSamplerSeedHook, build_optimizer, build_runner
|
||||||
|
|
||||||
from mmcls.core import (DistEvalHook, DistOptimizerHook, EvalHook,
|
from mmcls.core import DistEvalHook, DistOptimizerHook, EvalHook
|
||||||
Fp16OptimizerHook)
|
|
||||||
from mmcls.datasets import build_dataloader, build_dataset
|
from mmcls.datasets import build_dataloader, build_dataset
|
||||||
from mmcls.utils import get_root_logger
|
from mmcls.utils import get_root_logger
|
||||||
|
|
||||||
|
# TODO import optimizer hook from mmcv and delete them from mmcls
|
||||||
|
try:
|
||||||
|
from mmcv.runner import Fp16OptimizerHook
|
||||||
|
except ImportError:
|
||||||
|
warnings.warn('FP16OptimizerHook from mmcls will be deprecated.'
|
||||||
|
'Please install mmcv>=1.1.4.')
|
||||||
|
from mmcls.core import Fp16OptimizerHook
|
||||||
|
|
||||||
|
|
||||||
def set_random_seed(seed, deterministic=False):
|
def set_random_seed(seed, deterministic=False):
|
||||||
"""Set random seed.
|
"""Set random seed.
|
||||||
|
|
|
@ -10,12 +10,21 @@ import torch.nn as nn
|
||||||
from mmcv import color_val
|
from mmcv import color_val
|
||||||
from mmcv.utils import print_log
|
from mmcv.utils import print_log
|
||||||
|
|
||||||
|
# TODO import `auto_fp16` from mmcv and delete them from mmcls
|
||||||
|
try:
|
||||||
|
from mmcv.runner import auto_fp16
|
||||||
|
except ImportError:
|
||||||
|
warnings.warn('auto_fp16 from mmcls will be deprecated.'
|
||||||
|
'Please install mmcv>=1.1.4.')
|
||||||
|
from mmcls.core import auto_fp16
|
||||||
|
|
||||||
|
|
||||||
class BaseClassifier(nn.Module, metaclass=ABCMeta):
|
class BaseClassifier(nn.Module, metaclass=ABCMeta):
|
||||||
"""Base class for classifiers"""
|
"""Base class for classifiers"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(BaseClassifier, self).__init__()
|
super(BaseClassifier, self).__init__()
|
||||||
|
self.fp16_enabled = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def with_neck(self):
|
def with_neck(self):
|
||||||
|
@ -70,6 +79,7 @@ class BaseClassifier(nn.Module, metaclass=ABCMeta):
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('aug_test has not been implemented')
|
raise NotImplementedError('aug_test has not been implemented')
|
||||||
|
|
||||||
|
@auto_fp16(apply_to=('img', ))
|
||||||
def forward(self, img, return_loss=True, **kwargs):
|
def forward(self, img, return_loss=True, **kwargs):
|
||||||
"""
|
"""
|
||||||
Calls either forward_train or forward_test depending on whether
|
Calls either forward_train or forward_test depending on whether
|
||||||
|
|
|
@ -10,10 +10,17 @@ from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
||||||
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
|
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
|
||||||
|
|
||||||
from mmcls.apis import multi_gpu_test, single_gpu_test
|
from mmcls.apis import multi_gpu_test, single_gpu_test
|
||||||
from mmcls.core import wrap_fp16_model
|
|
||||||
from mmcls.datasets import build_dataloader, build_dataset
|
from mmcls.datasets import build_dataloader, build_dataset
|
||||||
from mmcls.models import build_classifier
|
from mmcls.models import build_classifier
|
||||||
|
|
||||||
|
# TODO import `wrap_fp16_model` from mmcv and delete them from mmcls
|
||||||
|
try:
|
||||||
|
from mmcv.runner import wrap_fp16_model
|
||||||
|
except ImportError:
|
||||||
|
warnings.warn('wrap_fp16_model from mmcls will be deprecated.'
|
||||||
|
'Please install mmcv>=1.1.4.')
|
||||||
|
from mmcls.core import wrap_fp16_model
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description='mmcls test model')
|
parser = argparse.ArgumentParser(description='mmcls test model')
|
||||||
|
|
Loading…
Reference in New Issue