From fffc87578f37273961bc4a4e6c454989df88ca57 Mon Sep 17 00:00:00 2001 From: "Y. Xiong" Date: Tue, 27 Apr 2021 13:55:59 +0800 Subject: [PATCH] [Feature]: Support auto_fp16 using torch.cuda.amp when PyTorch >= 1.6.0 (#951) * add torch.cuda.amp to fp16_utils and optimizers * use with context manager for autocast * add doc to explain the behavior differences between real amp and ours * fix docstring --- mmcv/runner/fp16_utils.py | 47 ++++- mmcv/runner/hooks/optimizer.py | 335 ++++++++++++++++++++++----------- 2 files changed, 265 insertions(+), 117 deletions(-) diff --git a/mmcv/runner/fp16_utils.py b/mmcv/runner/fp16_utils.py index b9ab9a13a..2f958fae1 100644 --- a/mmcv/runner/fp16_utils.py +++ b/mmcv/runner/fp16_utils.py @@ -7,8 +7,18 @@ import numpy as np import torch import torch.nn as nn +from mmcv.utils import TORCH_VERSION from .dist_utils import allreduce_grads as _allreduce_grads +try: + # If PyTorch version >= 1.6.0, torch.cuda.amp.autocast would be imported + # and used; otherwise, auto fp16 will adopt mmcv's implementation. + # Note that when PyTorch >= 1.6.0, we still cast tensor types to fp16 + # manually, so the behavior may not be consistant with real amp. + from torch.cuda.amp import autocast +except ImportError: + pass + def cast_tensor_type(inputs, src_type, dst_type): """Recursively convert Tensor in inputs from src_type to dst_type. @@ -45,7 +55,8 @@ def auto_fp16(apply_to=None, out_fp32=False): This decorator is useful when you write custom modules and want to support mixed precision training. If inputs arguments are fp32 tensors, they will be converted to fp16 automatically. Arguments other than fp32 tensors are - ignored. + ignored. If you are using PyTorch >= 1.6, torch.cuda.amp is used as the + backend, otherwise, original mmcv implementation will be adopted. Args: apply_to (Iterable, optional): The argument names to be converted. @@ -82,6 +93,7 @@ def auto_fp16(apply_to=None, out_fp32=False): 'method of nn.Module') if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled): return old_func(*args, **kwargs) + # get the arg spec of the decorated method args_info = getfullargspec(old_func) # get the argument names to be casted @@ -107,7 +119,11 @@ def auto_fp16(apply_to=None, out_fp32=False): else: new_kwargs[arg_name] = arg_value # apply converted arguments to the decorated method - output = old_func(*new_args, **new_kwargs) + if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0': + with autocast(enabled=True): + output = old_func(*new_args, **new_kwargs) + else: + output = old_func(*new_args, **new_kwargs) # cast the results back to fp32 if necessary if out_fp32: output = cast_tensor_type(output, torch.half, torch.float) @@ -125,7 +141,9 @@ def force_fp32(apply_to=None, out_fp16=False): mixed precision training. If there are some inputs that must be processed in fp32 mode, then this decorator can handle it. If inputs arguments are fp16 tensors, they will be converted to fp32 automatically. Arguments other - than fp16 tensors are ignored. + than fp16 tensors are ignored. If you are using PyTorch >= 1.6, + torch.cuda.amp is used as the backend, otherwise, original mmcv + implementation will be adopted. Args: apply_to (Iterable, optional): The argument names to be converted. @@ -186,7 +204,11 @@ def force_fp32(apply_to=None, out_fp16=False): else: new_kwargs[arg_name] = arg_value # apply converted arguments to the decorated method - output = old_func(*new_args, **new_kwargs) + if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0': + with autocast(enabled=False): + output = old_func(*new_args, **new_kwargs) + else: + output = old_func(*new_args, **new_kwargs) # cast the results back to fp32 if necessary if out_fp16: output = cast_tensor_type(output, torch.float, torch.half) @@ -207,16 +229,25 @@ def allreduce_grads(params, coalesce=True, bucket_size_mb=-1): def wrap_fp16_model(model): """Wrap the FP32 model to FP16. + If you are using PyTorch >= 1.6, torch.cuda.amp is used as the + backend, otherwise, original mmcv implementation will be adopted. + + For PyTorch >= 1.6, this function will + 1. Set fp16 flag inside the model to True. + + Otherwise: 1. Convert FP32 model to FP16. 2. Remain some necessary layers to be FP32, e.g., normalization layers. + 3. Set `fp16_enabled` flag inside the model to True. Args: model (nn.Module): Model in FP32. """ - # convert model to fp16 - model.half() - # patch the normalization layers to make it work in fp32 mode - patch_norm_fp32(model) + if TORCH_VERSION == 'parrots' or TORCH_VERSION < '1.6.0': + # convert model to fp16 + model.half() + # patch the normalization layers to make it work in fp32 mode + patch_norm_fp32(model) # set `fp16_enabled` flag for m in model.modules(): if hasattr(m, 'fp16_enabled'): diff --git a/mmcv/runner/hooks/optimizer.py b/mmcv/runner/hooks/optimizer.py index 4f27844a4..ca21c703b 100644 --- a/mmcv/runner/hooks/optimizer.py +++ b/mmcv/runner/hooks/optimizer.py @@ -5,10 +5,18 @@ from itertools import chain from torch.nn.utils import clip_grad +from mmcv.utils import TORCH_VERSION from ..dist_utils import allreduce_grads from ..fp16_utils import LossScaler, wrap_fp16_model from .hook import HOOKS, Hook +try: + # If PyTorch version >= 1.6.0, torch.cuda.amp.GradScaler would be imported + # and used; otherwise, auto fp16 will adopt mmcv's implementation. + from torch.cuda.amp import GradScaler +except ImportError: + pass + @HOOKS.register_module() class OptimizerHook(Hook): @@ -34,128 +42,237 @@ class OptimizerHook(Hook): runner.optimizer.step() -@HOOKS.register_module() -class Fp16OptimizerHook(OptimizerHook): - """FP16 optimizer hook. +if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0': - The steps of fp16 optimizer is as follows. - 1. Scale the loss value. - 2. BP in the fp16 model. - 2. Copy gradients from fp16 model to fp32 weights. - 3. Update fp32 weights. - 4. Copy updated parameters from fp32 weights to fp16 model. + @HOOKS.register_module() + class Fp16OptimizerHook(OptimizerHook): + """FP16 optimizer hook (using PyTorch's implementation). - Refer to https://arxiv.org/abs/1710.03740 for more details. + If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend, + to take care of the optimization procedure. - Args: - loss_scale (float | str | dict): Scale factor multiplied with loss. - If loss_scale is a float, static loss scaling will be used with - the specified scale. If loss_scale is a string, it must be - 'dynamic', then dynamic loss scaling will be used. - It can also be a dict containing arguments of LossScaler. - Defaults to 512. - """ + Args: + loss_scale (float | str | dict): Scale factor configuration. + If loss_scale is a float, static loss scaling will be used with + the specified scale. If loss_scale is a string, it must be + 'dynamic', then dynamic loss scaling will be used. + It can also be a dict containing arguments of GradScalar. + Defaults to 512. For Pytorch >= 1.6, mmcv uses official + implementation of GradScaler. If you use a dict version of + loss_scale to create GradScaler, plese refer to: + https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler + for the parameters. - def __init__(self, - grad_clip=None, - coalesce=True, - bucket_size_mb=-1, - loss_scale=512., - distributed=True): - self.grad_clip = grad_clip - self.coalesce = coalesce - self.bucket_size_mb = bucket_size_mb - self.distributed = distributed - if loss_scale == 'dynamic': - self.loss_scaler = LossScaler(mode='dynamic') - elif isinstance(loss_scale, float): - self.loss_scaler = LossScaler(init_scale=loss_scale, mode='static') - elif isinstance(loss_scale, dict): - self.loss_scaler = LossScaler(**loss_scale) - else: - raise ValueError('loss_scale must be of type float, dict, or ' - f'"dynamic", got {loss_scale}') - - def before_run(self, runner): - """Preparing steps before Mixed Precision Training. - - 1. Make a master copy of fp32 weights for optimization. - 2. Convert the main model from fp32 to fp16. + Examples: + >>> loss_scale = dict( + ... init_scale=65536.0, + ... growth_factor=2.0, + ... backoff_factor=0.5, + ... growth_interval=2000 + ... ) + >>> optimizer = Fp16OptimizerHook(loss_scale=loss_scale) """ - # keep a copy of fp32 weights - old_groups = runner.optimizer.param_groups - runner.optimizer.param_groups = copy.deepcopy( - runner.optimizer.param_groups) - state = defaultdict(dict) - p_map = { - old_p: p - for old_p, p in zip( - chain(*(g['params'] for g in old_groups)), - chain(*(g['params'] for g in runner.optimizer.param_groups))) - } - for k, v in runner.optimizer.state.items(): - state[p_map[k]] = v - runner.optimizer.state = state - # convert model to fp16 - wrap_fp16_model(runner.model) - def copy_grads_to_fp32(self, fp16_net, fp32_weights): - """Copy gradients from fp16 model to fp32 weight copy.""" - for fp32_param, fp16_param in zip(fp32_weights, fp16_net.parameters()): - if fp16_param.grad is not None: - if fp32_param.grad is None: - fp32_param.grad = fp32_param.data.new(fp32_param.size()) - fp32_param.grad.copy_(fp16_param.grad) + def __init__(self, + grad_clip=None, + coalesce=True, + bucket_size_mb=-1, + loss_scale=512., + distributed=True): + self.grad_clip = grad_clip + self.coalesce = coalesce + self.bucket_size_mb = bucket_size_mb + self.distributed = distributed + self._scale_update_param = None + if loss_scale == 'dynamic': + self.loss_scaler = GradScaler() + elif isinstance(loss_scale, float): + self._scale_update_param = loss_scale + self.loss_scaler = GradScaler(init_scale=loss_scale) + elif isinstance(loss_scale, dict): + self.loss_scaler = GradScaler(**loss_scale) + else: + raise ValueError('loss_scale must be of type float, dict, or ' + f'"dynamic", got {loss_scale}') - def copy_params_to_fp16(self, fp16_net, fp32_weights): - """Copy updated params from fp32 weight copy to fp16 model.""" - for fp16_param, fp32_param in zip(fp16_net.parameters(), fp32_weights): - fp16_param.data.copy_(fp32_param.data) + def before_run(self, runner): + """Preparing steps before Mixed Precision Training.""" + # wrap model mode to fp16 + wrap_fp16_model(runner.model) - def after_train_iter(self, runner): - """Backward optimization steps for Mixed Precision Training. For - dynamic loss scaling, please refer `loss_scalar.py` + def copy_grads_to_fp32(self, fp16_net, fp32_weights): + """Copy gradients from fp16 model to fp32 weight copy.""" + for fp32_param, fp16_param in zip(fp32_weights, + fp16_net.parameters()): + if fp16_param.grad is not None: + if fp32_param.grad is None: + fp32_param.grad = fp32_param.data.new( + fp32_param.size()) + fp32_param.grad.copy_(fp16_param.grad) - 1. Scale the loss by a scale factor. - 2. Backward the loss to obtain the gradients (fp16). - 3. Copy gradients from the model to the fp32 weight copy. - 4. Scale the gradients back and update the fp32 weight copy. - 5. Copy back the params from fp32 weight copy to the fp16 model. - """ - # clear grads of last iteration - runner.model.zero_grad() - runner.optimizer.zero_grad() - # scale the loss value - scaled_loss = runner.outputs['loss'] * self.loss_scaler.loss_scale - scaled_loss.backward() - # copy fp16 grads in the model to fp32 params in the optimizer + def copy_params_to_fp16(self, fp16_net, fp32_weights): + """Copy updated params from fp32 weight copy to fp16 model.""" + for fp16_param, fp32_param in zip(fp16_net.parameters(), + fp32_weights): + fp16_param.data.copy_(fp32_param.data) - fp32_weights = [] - for param_group in runner.optimizer.param_groups: - fp32_weights += param_group['params'] - self.copy_grads_to_fp32(runner.model, fp32_weights) - # allreduce grads - if self.distributed: - allreduce_grads(fp32_weights, self.coalesce, self.bucket_size_mb) + def after_train_iter(self, runner): + """Backward optimization steps for Mixed Precision Training. For + dynamic loss scaling, please refer to + https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler. - has_overflow = self.loss_scaler.has_overflow(fp32_weights) - # if has overflow, skip this iteration - if not has_overflow: - # scale the gradients back - for param in fp32_weights: - if param.grad is not None: - param.grad.div_(self.loss_scaler.loss_scale) + 1. Scale the loss by a scale factor. + 2. Backward the loss to obtain the gradients. + 3. Unscale the optimizer’s gradient tensors. + 4. Call optimizer.step() and update scale factor. + """ + # clear grads of last iteration + runner.model.zero_grad() + runner.optimizer.zero_grad() + + self.loss_scaler.scale(runner.outputs['loss']).backward() + self.loss_scaler.unscale_(runner.optimizer) + # grad clip if self.grad_clip is not None: - grad_norm = self.clip_grads(fp32_weights) + grad_norm = self.clip_grads(runner.model.parameters()) if grad_norm is not None: # Add grad norm to the logger runner.log_buffer.update({'grad_norm': float(grad_norm)}, runner.outputs['num_samples']) - # update fp32 params - runner.optimizer.step() - # copy fp32 params to the fp16 model - self.copy_params_to_fp16(runner.model, fp32_weights) - self.loss_scaler.update_scale(has_overflow) - if has_overflow: - runner.logger.warning('Check overflow, downscale loss scale ' - f'to {self.loss_scaler.cur_scale}') + # backward and update scaler + self.loss_scaler.step(runner.optimizer) + self.loss_scaler.update(self._scale_update_param) +else: + + @HOOKS.register_module() + class Fp16OptimizerHook(OptimizerHook): + """FP16 optimizer hook (mmcv's implementation). + + The steps of fp16 optimizer is as follows. + 1. Scale the loss value. + 2. BP in the fp16 model. + 2. Copy gradients from fp16 model to fp32 weights. + 3. Update fp32 weights. + 4. Copy updated parameters from fp32 weights to fp16 model. + + Refer to https://arxiv.org/abs/1710.03740 for more details. + + Args: + loss_scale (float | str | dict): Scale factor configuration. + If loss_scale is a float, static loss scaling will be used with + the specified scale. If loss_scale is a string, it must be + 'dynamic', then dynamic loss scaling will be used. + It can also be a dict containing arguments of LossScaler. + Defaults to 512. + """ + + def __init__(self, + grad_clip=None, + coalesce=True, + bucket_size_mb=-1, + loss_scale=512., + distributed=True): + self.grad_clip = grad_clip + self.coalesce = coalesce + self.bucket_size_mb = bucket_size_mb + self.distributed = distributed + if loss_scale == 'dynamic': + self.loss_scaler = LossScaler(mode='dynamic') + elif isinstance(loss_scale, float): + self.loss_scaler = LossScaler( + init_scale=loss_scale, mode='static') + elif isinstance(loss_scale, dict): + self.loss_scaler = LossScaler(**loss_scale) + else: + raise ValueError('loss_scale must be of type float, dict, or ' + f'"dynamic", got {loss_scale}') + + def before_run(self, runner): + """Preparing steps before Mixed Precision Training. + + 1. Make a master copy of fp32 weights for optimization. + 2. Convert the main model from fp32 to fp16. + """ + # keep a copy of fp32 weights + old_groups = runner.optimizer.param_groups + runner.optimizer.param_groups = copy.deepcopy( + runner.optimizer.param_groups) + state = defaultdict(dict) + p_map = { + old_p: p + for old_p, p in zip( + chain(*(g['params'] for g in old_groups)), + chain(*(g['params'] + for g in runner.optimizer.param_groups))) + } + for k, v in runner.optimizer.state.items(): + state[p_map[k]] = v + runner.optimizer.state = state + # convert model to fp16 + wrap_fp16_model(runner.model) + + def copy_grads_to_fp32(self, fp16_net, fp32_weights): + """Copy gradients from fp16 model to fp32 weight copy.""" + for fp32_param, fp16_param in zip(fp32_weights, + fp16_net.parameters()): + if fp16_param.grad is not None: + if fp32_param.grad is None: + fp32_param.grad = fp32_param.data.new( + fp32_param.size()) + fp32_param.grad.copy_(fp16_param.grad) + + def copy_params_to_fp16(self, fp16_net, fp32_weights): + """Copy updated params from fp32 weight copy to fp16 model.""" + for fp16_param, fp32_param in zip(fp16_net.parameters(), + fp32_weights): + fp16_param.data.copy_(fp32_param.data) + + def after_train_iter(self, runner): + """Backward optimization steps for Mixed Precision Training. For + dynamic loss scaling, please refer `loss_scalar.py` + + 1. Scale the loss by a scale factor. + 2. Backward the loss to obtain the gradients (fp16). + 3. Copy gradients from the model to the fp32 weight copy. + 4. Scale the gradients back and update the fp32 weight copy. + 5. Copy back the params from fp32 weight copy to the fp16 model. + """ + # clear grads of last iteration + runner.model.zero_grad() + runner.optimizer.zero_grad() + # scale the loss value + scaled_loss = runner.outputs['loss'] * self.loss_scaler.loss_scale + scaled_loss.backward() + # copy fp16 grads in the model to fp32 params in the optimizer + + fp32_weights = [] + for param_group in runner.optimizer.param_groups: + fp32_weights += param_group['params'] + self.copy_grads_to_fp32(runner.model, fp32_weights) + # allreduce grads + if self.distributed: + allreduce_grads(fp32_weights, self.coalesce, + self.bucket_size_mb) + + has_overflow = self.loss_scaler.has_overflow(fp32_weights) + # if has overflow, skip this iteration + if not has_overflow: + # scale the gradients back + for param in fp32_weights: + if param.grad is not None: + param.grad.div_(self.loss_scaler.loss_scale) + if self.grad_clip is not None: + grad_norm = self.clip_grads(fp32_weights) + if grad_norm is not None: + # Add grad norm to the logger + runner.log_buffer.update( + {'grad_norm': float(grad_norm)}, + runner.outputs['num_samples']) + # update fp32 params + runner.optimizer.step() + # copy fp32 params to the fp16 model + self.copy_params_to_fp16(runner.model, fp32_weights) + self.loss_scaler.update_scale(has_overflow) + if has_overflow: + runner.logger.warning('Check overflow, downscale loss scale ' + f'to {self.loss_scaler.cur_scale}')