[Feature] Support fp16 training (#178)

* change mmcls fp16 to mmcv hook

* support fp16

* clean unnessary stuff
pull/183/head
LXXXXR 2021-03-17 15:53:55 +08:00 committed by GitHub
parent 8eb845e718
commit e76c5a368d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 49 additions and 3 deletions

View File

@ -0,0 +1,14 @@
# Mixed Precision Training
## Introduction
[OTHERS]
```latex
@article{micikevicius2017mixed,
title={Mixed precision training},
author={Micikevicius, Paulius and Narang, Sharan and Alben, Jonah and Diamos, Gregory and Elsen, Erich and Garcia, David and Ginsburg, Boris and Houston, Michael and Kuchaiev, Oleksii and Venkatesh, Ganesh and others},
journal={arXiv preprint arXiv:1710.03740},
year={2017}
}
```

View File

@ -0,0 +1,4 @@
_base_ = ['../resnet/resnet50_b32x8_imagenet.py']
# fp16 settings
fp16 = dict(loss_scale='dynamic')

View File

@ -0,0 +1,4 @@
_base_ = ['../resnet/resnet50_b32x8_imagenet.py']
# fp16 settings
fp16 = dict(loss_scale=512.)

View File

@ -6,11 +6,18 @@ import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import DistSamplerSeedHook, build_optimizer, build_runner
from mmcls.core import (DistEvalHook, DistOptimizerHook, EvalHook,
Fp16OptimizerHook)
from mmcls.core import DistEvalHook, DistOptimizerHook, EvalHook
from mmcls.datasets import build_dataloader, build_dataset
from mmcls.utils import get_root_logger
# TODO import optimizer hook from mmcv and delete them from mmcls
try:
from mmcv.runner import Fp16OptimizerHook
except ImportError:
warnings.warn('FP16OptimizerHook from mmcls will be deprecated.'
'Please install mmcv>=1.1.4.')
from mmcls.core import Fp16OptimizerHook
def set_random_seed(seed, deterministic=False):
"""Set random seed.

View File

@ -10,12 +10,21 @@ import torch.nn as nn
from mmcv import color_val
from mmcv.utils import print_log
# TODO import `auto_fp16` from mmcv and delete them from mmcls
try:
from mmcv.runner import auto_fp16
except ImportError:
warnings.warn('auto_fp16 from mmcls will be deprecated.'
'Please install mmcv>=1.1.4.')
from mmcls.core import auto_fp16
class BaseClassifier(nn.Module, metaclass=ABCMeta):
"""Base class for classifiers"""
def __init__(self):
super(BaseClassifier, self).__init__()
self.fp16_enabled = False
@property
def with_neck(self):
@ -70,6 +79,7 @@ class BaseClassifier(nn.Module, metaclass=ABCMeta):
else:
raise NotImplementedError('aug_test has not been implemented')
@auto_fp16(apply_to=('img', ))
def forward(self, img, return_loss=True, **kwargs):
"""
Calls either forward_train or forward_test depending on whether

View File

@ -10,10 +10,17 @@ from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
from mmcls.apis import multi_gpu_test, single_gpu_test
from mmcls.core import wrap_fp16_model
from mmcls.datasets import build_dataloader, build_dataset
from mmcls.models import build_classifier
# TODO import `wrap_fp16_model` from mmcv and delete them from mmcls
try:
from mmcv.runner import wrap_fp16_model
except ImportError:
warnings.warn('wrap_fp16_model from mmcls will be deprecated.'
'Please install mmcv>=1.1.4.')
from mmcls.core import wrap_fp16_model
def parse_args():
parser = argparse.ArgumentParser(description='mmcls test model')