diff --git a/configs/fp16/README.md b/configs/fp16/README.md new file mode 100644 index 00000000..6b3c6175 --- /dev/null +++ b/configs/fp16/README.md @@ -0,0 +1,14 @@ +# Mixed Precision Training + +## Introduction + +[OTHERS] + +```latex +@article{micikevicius2017mixed, + title={Mixed precision training}, + author={Micikevicius, Paulius and Narang, Sharan and Alben, Jonah and Diamos, Gregory and Elsen, Erich and Garcia, David and Ginsburg, Boris and Houston, Michael and Kuchaiev, Oleksii and Venkatesh, Ganesh and others}, + journal={arXiv preprint arXiv:1710.03740}, + year={2017} +} +``` diff --git a/configs/fp16/resnet50_b32x8_fp16_dynamic_imagenet.py b/configs/fp16/resnet50_b32x8_fp16_dynamic_imagenet.py new file mode 100644 index 00000000..35b4ff54 --- /dev/null +++ b/configs/fp16/resnet50_b32x8_fp16_dynamic_imagenet.py @@ -0,0 +1,4 @@ +_base_ = ['../resnet/resnet50_b32x8_imagenet.py'] + +# fp16 settings +fp16 = dict(loss_scale='dynamic') diff --git a/configs/fp16/resnet50_b32x8_fp16_imagenet.py b/configs/fp16/resnet50_b32x8_fp16_imagenet.py new file mode 100644 index 00000000..fbab0cc1 --- /dev/null +++ b/configs/fp16/resnet50_b32x8_fp16_imagenet.py @@ -0,0 +1,4 @@ +_base_ = ['../resnet/resnet50_b32x8_imagenet.py'] + +# fp16 settings +fp16 = dict(loss_scale=512.) diff --git a/mmcls/apis/train.py b/mmcls/apis/train.py index 43689efb..77df6f8f 100644 --- a/mmcls/apis/train.py +++ b/mmcls/apis/train.py @@ -6,11 +6,18 @@ import torch from mmcv.parallel import MMDataParallel, MMDistributedDataParallel from mmcv.runner import DistSamplerSeedHook, build_optimizer, build_runner -from mmcls.core import (DistEvalHook, DistOptimizerHook, EvalHook, - Fp16OptimizerHook) +from mmcls.core import DistEvalHook, DistOptimizerHook, EvalHook from mmcls.datasets import build_dataloader, build_dataset from mmcls.utils import get_root_logger +# TODO import optimizer hook from mmcv and delete them from mmcls +try: + from mmcv.runner import Fp16OptimizerHook +except ImportError: + warnings.warn('FP16OptimizerHook from mmcls will be deprecated.' + 'Please install mmcv>=1.1.4.') + from mmcls.core import Fp16OptimizerHook + def set_random_seed(seed, deterministic=False): """Set random seed. diff --git a/mmcls/models/classifiers/base.py b/mmcls/models/classifiers/base.py index 9427d10e..d4d72bdf 100644 --- a/mmcls/models/classifiers/base.py +++ b/mmcls/models/classifiers/base.py @@ -10,12 +10,21 @@ import torch.nn as nn from mmcv import color_val from mmcv.utils import print_log +# TODO import `auto_fp16` from mmcv and delete them from mmcls +try: + from mmcv.runner import auto_fp16 +except ImportError: + warnings.warn('auto_fp16 from mmcls will be deprecated.' + 'Please install mmcv>=1.1.4.') + from mmcls.core import auto_fp16 + class BaseClassifier(nn.Module, metaclass=ABCMeta): """Base class for classifiers""" def __init__(self): super(BaseClassifier, self).__init__() + self.fp16_enabled = False @property def with_neck(self): @@ -70,6 +79,7 @@ class BaseClassifier(nn.Module, metaclass=ABCMeta): else: raise NotImplementedError('aug_test has not been implemented') + @auto_fp16(apply_to=('img', )) def forward(self, img, return_loss=True, **kwargs): """ Calls either forward_train or forward_test depending on whether diff --git a/tools/test.py b/tools/test.py index 1dd7f937..f11473af 100644 --- a/tools/test.py +++ b/tools/test.py @@ -10,10 +10,17 @@ from mmcv.parallel import MMDataParallel, MMDistributedDataParallel from mmcv.runner import get_dist_info, init_dist, load_checkpoint from mmcls.apis import multi_gpu_test, single_gpu_test -from mmcls.core import wrap_fp16_model from mmcls.datasets import build_dataloader, build_dataset from mmcls.models import build_classifier +# TODO import `wrap_fp16_model` from mmcv and delete them from mmcls +try: + from mmcv.runner import wrap_fp16_model +except ImportError: + warnings.warn('wrap_fp16_model from mmcls will be deprecated.' + 'Please install mmcv>=1.1.4.') + from mmcls.core import wrap_fp16_model + def parse_args(): parser = argparse.ArgumentParser(description='mmcls test model')