From 4a044c646663231905c5659d2f83d0abe60e0ed8 Mon Sep 17 00:00:00 2001
From: Jerry Jiarui XU
Date: Fri, 17 Jul 2020 11:14:07 +0800
Subject: [PATCH] Support FP16 (#428)

* Support FP16

* update ci
---
 .github/workflows/build.yml    |   2 +-
 mmcv/runner/fp16_utils.py      | 303 +++++++++++++++++++++++++++++++++
 mmcv/runner/hooks/optimizer.py |  91 ++++++++++
 tests/test_fp16.py             | 300 ++++++++++++++++++++++++++++++++
 4 files changed, 695 insertions(+), 1 deletion(-)
 create mode 100644 mmcv/runner/fp16_utils.py
 create mode 100644 tests/test_fp16.py

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index da860ea12..ceb9674f1 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -184,7 +184,7 @@ jobs:
         run: pip install Pillow
       - name: Run unittests and generate coverage report
         run: |
-          pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_registry.py
+          pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_registry.py --ignore=tests/test_fp16.py
 
   build_macos:
     runs-on: macos-latest
diff --git a/mmcv/runner/fp16_utils.py b/mmcv/runner/fp16_utils.py
new file mode 100644
index 000000000..0a6daf1dc
--- /dev/null
+++ b/mmcv/runner/fp16_utils.py
@@ -0,0 +1,303 @@
+import functools
+from collections import OrderedDict, abc
+from inspect import getfullargspec
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch._utils import (_flatten_dense_tensors, _take_tensors,
+                          _unflatten_dense_tensors)
+
+
+def cast_tensor_type(inputs, src_type, dst_type):
+    """Recursively convert Tensors in inputs from src_type to dst_type.
+
+    Args:
+        inputs: The inputs to be cast.
+        src_type (torch.dtype): Source type.
+        dst_type (torch.dtype): Destination type.
+
+    Returns:
+        The same container type as inputs, with all contained Tensors cast
+        to dst_type.
+    """
+    if isinstance(inputs, torch.Tensor):
+        return inputs.to(dst_type)
+    elif isinstance(inputs, str):
+        return inputs
+    elif isinstance(inputs, np.ndarray):
+        return inputs
+    elif isinstance(inputs, abc.Mapping):
+        return type(inputs)({
+            k: cast_tensor_type(v, src_type, dst_type)
+            for k, v in inputs.items()
+        })
+    elif isinstance(inputs, abc.Iterable):
+        return type(inputs)(
+            cast_tensor_type(item, src_type, dst_type) for item in inputs)
+    else:
+        return inputs
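A quick interactive sketch of the recursion above (the batch dict is a made-up input, not part of the patch): tensors are cast, while strings and numpy arrays pass through untouched.

    >>> import numpy as np
    >>> import torch
    >>> from mmcv.runner.fp16_utils import cast_tensor_type
    >>> batch = {'img': torch.ones(2, 3), 'meta': ['a.jpg', np.zeros(4)]}
    >>> out = cast_tensor_type(batch, torch.float, torch.half)
    >>> out['img'].dtype
    torch.float16
    >>> out['meta'][0]  # str (and np.ndarray) entries are returned unchanged
    'a.jpg'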
+
+
+def auto_fp16(apply_to=None, out_fp32=False):
+    """Decorator to enable fp16 training automatically.
+
+    This decorator is useful when you write custom modules and want to
+    support mixed precision training. If the input arguments are fp32
+    tensors, they will be converted to fp16 automatically. Arguments other
+    than fp32 tensors are ignored.
+
+    Args:
+        apply_to (Iterable, optional): The argument names to be converted.
+            `None` indicates all arguments.
+        out_fp32 (bool): Whether to convert the output back to fp32.
+
+    Example:
+
+        >>> import torch.nn as nn
+        >>> class MyModule1(nn.Module):
+        >>>
+        >>>     # Convert x and y to fp16
+        >>>     @auto_fp16()
+        >>>     def forward(self, x, y):
+        >>>         pass
+
+        >>> import torch.nn as nn
+        >>> class MyModule2(nn.Module):
+        >>>
+        >>>     # convert pred to fp16
+        >>>     @auto_fp16(apply_to=('pred', ))
+        >>>     def do_something(self, pred, others):
+        >>>         pass
+    """
+
+    def auto_fp16_wrapper(old_func):
+
+        @functools.wraps(old_func)
+        def new_func(*args, **kwargs):
+            # check if the module has set the attribute `fp16_enabled`; if
+            # not, just fall back to the original method.
+            if not isinstance(args[0], torch.nn.Module):
+                raise TypeError('@auto_fp16 can only be used to decorate '
+                                'the method of nn.Module')
+            if not (hasattr(args[0], 'fp16_enabled') and
+                    args[0].fp16_enabled):
+                return old_func(*args, **kwargs)
+            # get the arg spec of the decorated method
+            args_info = getfullargspec(old_func)
+            # get the argument names to be cast
+            args_to_cast = args_info.args if apply_to is None else apply_to
+            # convert the args that need to be processed
+            new_args = []
+            # NOTE: default args are not taken into consideration
+            if args:
+                arg_names = args_info.args[:len(args)]
+                for i, arg_name in enumerate(arg_names):
+                    if arg_name in args_to_cast:
+                        new_args.append(
+                            cast_tensor_type(args[i], torch.float,
+                                             torch.half))
+                    else:
+                        new_args.append(args[i])
+            # convert the kwargs that need to be processed
+            new_kwargs = {}
+            if kwargs:
+                for arg_name, arg_value in kwargs.items():
+                    if arg_name in args_to_cast:
+                        new_kwargs[arg_name] = cast_tensor_type(
+                            arg_value, torch.float, torch.half)
+                    else:
+                        new_kwargs[arg_name] = arg_value
+            # apply converted arguments to the decorated method
+            output = old_func(*new_args, **new_kwargs)
+            # cast the results back to fp32 if necessary
+            if out_fp32:
+                output = cast_tensor_type(output, torch.half, torch.float)
+            return output
+
+        return new_func
+
+    return auto_fp16_wrapper
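A runnable sketch of the decorator's gating behavior (the Toy module is hypothetical): no casting happens until the module's fp16_enabled flag is set, which wrap_fp16_model below flips for modules that define the attribute.

    >>> import torch
    >>> import torch.nn as nn
    >>> from mmcv.runner.fp16_utils import auto_fp16
    >>> class Toy(nn.Module):
    >>>     @auto_fp16()
    >>>     def forward(self, x):
    >>>         return x
    >>> m = Toy()
    >>> m(torch.ones(1)).dtype  # flag unset: input passes through as fp32
    torch.float32
    >>> m.fp16_enabled = True
    >>> m(torch.ones(1)).dtype  # flag set: fp32 input is cast to fp16
    torch.float16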
+
+
+def force_fp32(apply_to=None, out_fp16=False):
+    """Decorator to convert input arguments to fp32 by force.
+
+    This decorator is useful when you write custom modules and want to
+    support mixed precision training. If there are some inputs that must be
+    processed in fp32 mode, then this decorator can handle it. If the input
+    arguments are fp16 tensors, they will be converted to fp32
+    automatically. Arguments other than fp16 tensors are ignored.
+
+    Args:
+        apply_to (Iterable, optional): The argument names to be converted.
+            `None` indicates all arguments.
+        out_fp16 (bool): Whether to convert the output back to fp16.
+
+    Example:
+
+        >>> import torch.nn as nn
+        >>> class MyModule1(nn.Module):
+        >>>
+        >>>     # Convert x and y to fp32
+        >>>     @force_fp32()
+        >>>     def loss(self, x, y):
+        >>>         pass
+
+        >>> import torch.nn as nn
+        >>> class MyModule2(nn.Module):
+        >>>
+        >>>     # convert pred to fp32
+        >>>     @force_fp32(apply_to=('pred', ))
+        >>>     def post_process(self, pred, others):
+        >>>         pass
+    """
+
+    def force_fp32_wrapper(old_func):
+
+        @functools.wraps(old_func)
+        def new_func(*args, **kwargs):
+            # check if the module has set the attribute `fp16_enabled`; if
+            # not, just fall back to the original method.
+            if not isinstance(args[0], torch.nn.Module):
+                raise TypeError('@force_fp32 can only be used to decorate '
+                                'the method of nn.Module')
+            if not (hasattr(args[0], 'fp16_enabled') and
+                    args[0].fp16_enabled):
+                return old_func(*args, **kwargs)
+            # get the arg spec of the decorated method
+            args_info = getfullargspec(old_func)
+            # get the argument names to be cast
+            args_to_cast = args_info.args if apply_to is None else apply_to
+            # convert the args that need to be processed
+            new_args = []
+            if args:
+                arg_names = args_info.args[:len(args)]
+                for i, arg_name in enumerate(arg_names):
+                    if arg_name in args_to_cast:
+                        new_args.append(
+                            cast_tensor_type(args[i], torch.half,
+                                             torch.float))
+                    else:
+                        new_args.append(args[i])
+            # convert the kwargs that need to be processed
+            new_kwargs = dict()
+            if kwargs:
+                for arg_name, arg_value in kwargs.items():
+                    if arg_name in args_to_cast:
+                        new_kwargs[arg_name] = cast_tensor_type(
+                            arg_value, torch.half, torch.float)
+                    else:
+                        new_kwargs[arg_name] = arg_value
+            # apply converted arguments to the decorated method
+            output = old_func(*new_args, **new_kwargs)
+            # cast the results back to fp16 if necessary
+            if out_fp16:
+                output = cast_tensor_type(output, torch.float, torch.half)
+            return output
+
+        return new_func
+
+    return force_fp32_wrapper
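The mirror image of auto_fp16, sketched with a hypothetical head module: inside an fp16-enabled module, fp16 inputs are promoted back to fp32, the usual pattern for numerically sensitive steps such as loss computation.

    >>> import torch
    >>> import torch.nn as nn
    >>> from mmcv.runner.fp16_utils import force_fp32
    >>> class ToyHead(nn.Module):
    >>>     @force_fp32()
    >>>     def loss(self, pred):
    >>>         return pred
    >>> head = ToyHead()
    >>> head.fp16_enabled = True
    >>> head.loss(torch.ones(1, dtype=torch.half)).dtype  # promoted to fp32
    torch.float32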
+ """ + if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)): + module.float() + if isinstance(module, nn.GroupNorm) or torch.__version__ < '1.3': + module.forward = patch_forward_method(module.forward, torch.half, + torch.float) + for child in module.children(): + patch_norm_fp32(child) + return module + + +def patch_forward_method(func, src_type, dst_type, convert_output=True): + """Patch the forward method of a module. + + Args: + func (callable): The original forward method. + src_type (torch.dtype): Type of input arguments to be converted from. + dst_type (torch.dtype): Type of input arguments to be converted to. + convert_output (bool): Whether to convert the output back to src_type. + + Returns: + callable: The patched forward method. + """ + + def new_forward(*args, **kwargs): + output = func(*cast_tensor_type(args, src_type, dst_type), + **cast_tensor_type(kwargs, src_type, dst_type)) + if convert_output: + output = cast_tensor_type(output, dst_type, src_type) + return output + + return new_forward diff --git a/mmcv/runner/hooks/optimizer.py b/mmcv/runner/hooks/optimizer.py index d6390f515..77d8cc867 100644 --- a/mmcv/runner/hooks/optimizer.py +++ b/mmcv/runner/hooks/optimizer.py @@ -1,6 +1,9 @@ # Copyright (c) Open-MMLab. All rights reserved. +import copy + from torch.nn.utils import clip_grad +from ..fp16_utils import allreduce_grads, wrap_fp16_model from .hook import HOOKS, Hook @@ -26,3 +29,91 @@ class OptimizerHook(Hook): runner.log_buffer.update({'grad_norm': float(grad_norm)}, runner.outputs['num_samples']) runner.optimizer.step() + + +class Fp16OptimizerHook(OptimizerHook): + """FP16 optimizer hook. + + The steps of fp16 optimizer is as follows. + 1. Scale the loss value. + 2. BP in the fp16 model. + 2. Copy gradients from fp16 model to fp32 weights. + 3. Update fp32 weights. + 4. Copy updated parameters from fp32 weights to fp16 model. + + Refer to https://arxiv.org/abs/1710.03740 for more details. + + Args: + loss_scale (float): Scale factor multiplied with loss. + """ + + def __init__(self, + grad_clip=None, + coalesce=True, + bucket_size_mb=-1, + loss_scale=512., + distributed=True): + self.grad_clip = grad_clip + self.coalesce = coalesce + self.bucket_size_mb = bucket_size_mb + self.loss_scale = loss_scale + self.distributed = distributed + + def before_run(self, runner): + """Preparing steps before Mixed Precision Training. + + 1. Make a master copy of fp32 weights for optimization. + 2. Convert the main model from fp32 to fp16. + """ + # keep a copy of fp32 weights + runner.optimizer.param_groups = copy.deepcopy( + runner.optimizer.param_groups) + # convert model to fp16 + wrap_fp16_model(runner.model) + + def copy_grads_to_fp32(self, fp16_net, fp32_weights): + """Copy gradients from fp16 model to fp32 weight copy.""" + for fp32_param, fp16_param in zip(fp32_weights, fp16_net.parameters()): + if fp16_param.grad is not None: + if fp32_param.grad is None: + fp32_param.grad = fp32_param.data.new(fp32_param.size()) + fp32_param.grad.copy_(fp16_param.grad) + + def copy_params_to_fp16(self, fp16_net, fp32_weights): + """Copy updated params from fp32 weight copy to fp16 model.""" + for fp16_param, fp32_param in zip(fp16_net.parameters(), fp32_weights): + fp16_param.data.copy_(fp32_param.data) + + def after_train_iter(self, runner): + """Backward optimization steps for Mixed Precision Training. + + 1. Scale the loss by a scale factor. + 2. Backward the loss to obtain the gradients (fp16). + 3. Copy gradients from the model to the fp32 weight copy. + 4. 
diff --git a/mmcv/runner/hooks/optimizer.py b/mmcv/runner/hooks/optimizer.py
index d6390f515..77d8cc867 100644
--- a/mmcv/runner/hooks/optimizer.py
+++ b/mmcv/runner/hooks/optimizer.py
@@ -1,6 +1,9 @@
 # Copyright (c) Open-MMLab. All rights reserved.
+import copy
+
 from torch.nn.utils import clip_grad
 
+from ..fp16_utils import allreduce_grads, wrap_fp16_model
 from .hook import HOOKS, Hook
 
@@ -26,3 +29,91 @@ class OptimizerHook(Hook):
             runner.log_buffer.update({'grad_norm': float(grad_norm)},
                                      runner.outputs['num_samples'])
         runner.optimizer.step()
+
+
+class Fp16OptimizerHook(OptimizerHook):
+    """FP16 optimizer hook.
+
+    The steps of the fp16 optimizer are as follows.
+    1. Scale the loss value.
+    2. Backpropagate in the fp16 model.
+    3. Copy gradients from the fp16 model to the fp32 weight copy.
+    4. Update the fp32 weights.
+    5. Copy the updated parameters from the fp32 weights back to the fp16
+       model.
+
+    Refer to https://arxiv.org/abs/1710.03740 for more details.
+
+    Args:
+        grad_clip (dict, optional): Arguments for gradient clipping.
+        coalesce (bool): Whether to coalesce gradients during allreduce.
+        bucket_size_mb (int): Size of the allreduce bucket, in MB.
+        loss_scale (float): Scale factor multiplied with the loss.
+        distributed (bool): Whether to allreduce gradients across processes.
+    """
+
+    def __init__(self,
+                 grad_clip=None,
+                 coalesce=True,
+                 bucket_size_mb=-1,
+                 loss_scale=512.,
+                 distributed=True):
+        self.grad_clip = grad_clip
+        self.coalesce = coalesce
+        self.bucket_size_mb = bucket_size_mb
+        self.loss_scale = loss_scale
+        self.distributed = distributed
+
+    def before_run(self, runner):
+        """Preparation steps before Mixed Precision Training.
+
+        1. Make a master copy of fp32 weights for optimization.
+        2. Convert the main model from fp32 to fp16.
+        """
+        # keep a copy of fp32 weights
+        runner.optimizer.param_groups = copy.deepcopy(
+            runner.optimizer.param_groups)
+        # convert model to fp16
+        wrap_fp16_model(runner.model)
+
+    def copy_grads_to_fp32(self, fp16_net, fp32_weights):
+        """Copy gradients from the fp16 model to the fp32 weight copy."""
+        for fp32_param, fp16_param in zip(fp32_weights,
+                                          fp16_net.parameters()):
+            if fp16_param.grad is not None:
+                if fp32_param.grad is None:
+                    fp32_param.grad = fp32_param.data.new(fp32_param.size())
+                fp32_param.grad.copy_(fp16_param.grad)
+
+    def copy_params_to_fp16(self, fp16_net, fp32_weights):
+        """Copy updated params from the fp32 weight copy to the fp16 model."""
+        for fp16_param, fp32_param in zip(fp16_net.parameters(),
+                                          fp32_weights):
+            fp16_param.data.copy_(fp32_param.data)
+
+    def after_train_iter(self, runner):
+        """Backward optimization steps for Mixed Precision Training.
+
+        1. Scale the loss by a scale factor.
+        2. Backward the loss to obtain the gradients (fp16).
+        3. Copy gradients from the model to the fp32 weight copy.
+        4. Scale the gradients back and update the fp32 weight copy.
+        5. Copy back the params from the fp32 weight copy to the fp16 model.
+        """
+        # clear grads of the last iteration
+        runner.model.zero_grad()
+        runner.optimizer.zero_grad()
+        # scale the loss value
+        scaled_loss = runner.outputs['loss'] * self.loss_scale
+        scaled_loss.backward()
+        # copy fp16 grads in the model to fp32 params in the optimizer
+        fp32_weights = []
+        for param_group in runner.optimizer.param_groups:
+            fp32_weights += param_group['params']
+        self.copy_grads_to_fp32(runner.model, fp32_weights)
+        # allreduce grads
+        if self.distributed:
+            allreduce_grads(fp32_weights, self.coalesce, self.bucket_size_mb)
+        # scale the gradients back
+        for param in fp32_weights:
+            if param.grad is not None:
+                param.grad.div_(self.loss_scale)
+        if self.grad_clip is not None:
+            self.clip_grads(fp32_weights)
+        # update fp32 params
+        runner.optimizer.step()
+        # copy fp32 params to the fp16 model
+        self.copy_params_to_fp16(runner.model, fp32_weights)
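A sketch of wiring the hook into a runner (the runner itself is assumed to be built elsewhere with mmcv.runner.Runner; on a single GPU, distributed=False skips the allreduce step):

    from mmcv.runner.hooks.optimizer import Fp16OptimizerHook

    fp16_hook = Fp16OptimizerHook(
        loss_scale=512.,  # static loss scale, per the linked paper's recipe
        grad_clip=dict(max_norm=35, norm_type=2),
        distributed=False)
    runner.register_hook(fp16_hook)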
diff --git a/tests/test_fp16.py b/tests/test_fp16.py
new file mode 100644
index 000000000..fb1c78814
--- /dev/null
+++ b/tests/test_fp16.py
@@ -0,0 +1,300 @@
+import numpy as np
+import pytest
+import torch
+import torch.nn as nn
+
+from mmcv.runner.fp16_utils import auto_fp16, cast_tensor_type, force_fp32
+
+
+def test_cast_tensor_type():
+    inputs = torch.FloatTensor([5.])
+    src_type = torch.float32
+    dst_type = torch.int32
+    outputs = cast_tensor_type(inputs, src_type, dst_type)
+    assert isinstance(outputs, torch.Tensor)
+    assert outputs.dtype == dst_type
+
+    inputs = 'tensor'
+    src_type = str
+    dst_type = str
+    outputs = cast_tensor_type(inputs, src_type, dst_type)
+    assert isinstance(outputs, str)
+
+    inputs = np.array([5.])
+    src_type = np.ndarray
+    dst_type = np.ndarray
+    outputs = cast_tensor_type(inputs, src_type, dst_type)
+    assert isinstance(outputs, np.ndarray)
+
+    inputs = dict(
+        tensor_a=torch.FloatTensor([1.]), tensor_b=torch.FloatTensor([2.]))
+    src_type = torch.float32
+    dst_type = torch.int32
+    outputs = cast_tensor_type(inputs, src_type, dst_type)
+    assert isinstance(outputs, dict)
+    assert outputs['tensor_a'].dtype == dst_type
+    assert outputs['tensor_b'].dtype == dst_type
+
+    inputs = [torch.FloatTensor([1.]), torch.FloatTensor([2.])]
+    src_type = torch.float32
+    dst_type = torch.int32
+    outputs = cast_tensor_type(inputs, src_type, dst_type)
+    assert isinstance(outputs, list)
+    assert outputs[0].dtype == dst_type
+    assert outputs[1].dtype == dst_type
+
+    inputs = 5
+    outputs = cast_tensor_type(inputs, None, None)
+    assert isinstance(outputs, int)
+
+
+def test_auto_fp16():
+
+    with pytest.raises(TypeError):
+        # ExampleObject is not a subclass of nn.Module
+
+        class ExampleObject(object):
+
+            @auto_fp16()
+            def __call__(self, x):
+                return x
+
+        model = ExampleObject()
+        input_x = torch.ones(1, dtype=torch.float32)
+        model(input_x)
+
+    # apply to all input args
+    class ExampleModule(nn.Module):
+
+        @auto_fp16()
+        def forward(self, x, y):
+            return x, y
+
+    model = ExampleModule()
+    input_x = torch.ones(1, dtype=torch.float32)
+    input_y = torch.ones(1, dtype=torch.float32)
+    output_x, output_y = model(input_x, input_y)
+    assert output_x.dtype == torch.float32
+    assert output_y.dtype == torch.float32
+
+    model.fp16_enabled = True
+    output_x, output_y = model(input_x, input_y)
+    assert output_x.dtype == torch.half
+    assert output_y.dtype == torch.half
+
+    if torch.cuda.is_available():
+        model.cuda()
+        output_x, output_y = model(input_x.cuda(), input_y.cuda())
+        assert output_x.dtype == torch.half
+        assert output_y.dtype == torch.half
+
+    # apply to specified input args
+    class ExampleModule(nn.Module):
+
+        @auto_fp16(apply_to=('x', ))
+        def forward(self, x, y):
+            return x, y
+
+    model = ExampleModule()
+    input_x = torch.ones(1, dtype=torch.float32)
+    input_y = torch.ones(1, dtype=torch.float32)
+    output_x, output_y = model(input_x, input_y)
+    assert output_x.dtype == torch.float32
+    assert output_y.dtype == torch.float32
+
+    model.fp16_enabled = True
+    output_x, output_y = model(input_x, input_y)
+    assert output_x.dtype == torch.half
+    assert output_y.dtype == torch.float32
+
+    if torch.cuda.is_available():
+        model.cuda()
+        output_x, output_y = model(input_x.cuda(), input_y.cuda())
+        assert output_x.dtype == torch.half
+        assert output_y.dtype == torch.float32
+
+    # apply to optional input args
+    class ExampleModule(nn.Module):
+
+        @auto_fp16(apply_to=('x', 'y'))
+        def forward(self, x, y=None, z=None):
+            return x, y, z
+
+    model = ExampleModule()
+    input_x = torch.ones(1, dtype=torch.float32)
+    input_y = torch.ones(1, dtype=torch.float32)
+    input_z = torch.ones(1, dtype=torch.float32)
+    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
+    assert output_x.dtype == torch.float32
+    assert output_y.dtype == torch.float32
+    assert output_z.dtype == torch.float32
+
+    model.fp16_enabled = True
+    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
+    assert output_x.dtype == torch.half
+    assert output_y.dtype == torch.half
+    assert output_z.dtype == torch.float32
+
+    if torch.cuda.is_available():
+        model.cuda()
+        output_x, output_y, output_z = model(
+            input_x.cuda(), y=input_y.cuda(), z=input_z.cuda())
+        assert output_x.dtype == torch.half
+        assert output_y.dtype == torch.half
+        assert output_z.dtype == torch.float32
+
+    # out_fp32=True
+    class ExampleModule(nn.Module):
+
+        @auto_fp16(apply_to=('x', 'y'), out_fp32=True)
+        def forward(self, x, y=None, z=None):
+            return x, y, z
+
+    model = ExampleModule()
+    input_x = torch.ones(1, dtype=torch.half)
+    input_y = torch.ones(1, dtype=torch.float32)
+    input_z = torch.ones(1, dtype=torch.float32)
+    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
+    assert output_x.dtype == torch.half
+    assert output_y.dtype == torch.float32
+    assert output_z.dtype == torch.float32
+
+    model.fp16_enabled = True
+    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
+    assert output_x.dtype == torch.float32
+    assert output_y.dtype == torch.float32
+    assert output_z.dtype == torch.float32
+
+    if torch.cuda.is_available():
+        model.cuda()
+        output_x, output_y, output_z = model(
+            input_x.cuda(), y=input_y.cuda(), z=input_z.cuda())
+        assert output_x.dtype == torch.float32
+        assert output_y.dtype == torch.float32
+        assert output_z.dtype == torch.float32
+
+
+def test_force_fp32():
+
+    with pytest.raises(TypeError):
+        # ExampleObject is not a subclass of nn.Module
+
+        class ExampleObject(object):
+
+            @force_fp32()
+            def __call__(self, x):
+                return x
+
+        model = ExampleObject()
+        input_x = torch.ones(1, dtype=torch.float32)
+        model(input_x)
+
+    # apply to all input args
+    class ExampleModule(nn.Module):
+
+        @force_fp32()
+        def forward(self, x, y):
+            return x, y
+
+    model = ExampleModule()
+    input_x = torch.ones(1, dtype=torch.half)
+    input_y = torch.ones(1, dtype=torch.half)
+    output_x, output_y = model(input_x, input_y)
+    assert output_x.dtype == torch.half
+    assert output_y.dtype == torch.half
+
+    model.fp16_enabled = True
+    output_x, output_y = model(input_x, input_y)
+    assert output_x.dtype == torch.float32
+    assert output_y.dtype == torch.float32
+
+    if torch.cuda.is_available():
+        model.cuda()
+        output_x, output_y = model(input_x.cuda(), input_y.cuda())
+        assert output_x.dtype == torch.float32
+        assert output_y.dtype == torch.float32
+
+    # apply to specified input args
+    class ExampleModule(nn.Module):
+
+        @force_fp32(apply_to=('x', ))
+        def forward(self, x, y):
+            return x, y
+
+    model = ExampleModule()
+    input_x = torch.ones(1, dtype=torch.half)
+    input_y = torch.ones(1, dtype=torch.half)
+    output_x, output_y = model(input_x, input_y)
+    assert output_x.dtype == torch.half
+    assert output_y.dtype == torch.half
+
+    model.fp16_enabled = True
+    output_x, output_y = model(input_x, input_y)
+    assert output_x.dtype == torch.float32
+    assert output_y.dtype == torch.half
+
+    if torch.cuda.is_available():
+        model.cuda()
+        output_x, output_y = model(input_x.cuda(), input_y.cuda())
+        assert output_x.dtype == torch.float32
+        assert output_y.dtype == torch.half
+
+    # apply to optional input args
+    class ExampleModule(nn.Module):
+
+        @force_fp32(apply_to=('x', 'y'))
+        def forward(self, x, y=None, z=None):
+            return x, y, z
+
+    model = ExampleModule()
+    input_x = torch.ones(1, dtype=torch.half)
+    input_y = torch.ones(1, dtype=torch.half)
+    input_z = torch.ones(1, dtype=torch.half)
+    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
+    assert output_x.dtype == torch.half
+    assert output_y.dtype == torch.half
+    assert output_z.dtype == torch.half
+
+    model.fp16_enabled = True
+    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
+    assert output_x.dtype == torch.float32
+    assert output_y.dtype == torch.float32
+    assert output_z.dtype == torch.half
+
+    if torch.cuda.is_available():
+        model.cuda()
+        output_x, output_y, output_z = model(
+            input_x.cuda(), y=input_y.cuda(), z=input_z.cuda())
+        assert output_x.dtype == torch.float32
+        assert output_y.dtype == torch.float32
+        assert output_z.dtype == torch.half
+
+    # out_fp16=True
+    class ExampleModule(nn.Module):
+
+        @force_fp32(apply_to=('x', 'y'), out_fp16=True)
+        def forward(self, x, y=None, z=None):
+            return x, y, z
+
+    model = ExampleModule()
+    input_x = torch.ones(1, dtype=torch.float32)
+    input_y = torch.ones(1, dtype=torch.half)
+    input_z = torch.ones(1, dtype=torch.half)
+    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
+    assert output_x.dtype == torch.float32
+    assert output_y.dtype == torch.half
+    assert output_z.dtype == torch.half
+
+    model.fp16_enabled = True
+    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
+    assert output_x.dtype == torch.half
+    assert output_y.dtype == torch.half
+    assert output_z.dtype == torch.half
+
+    if torch.cuda.is_available():
+        model.cuda()
+        output_x, output_y, output_z = model(
+            input_x.cuda(), y=input_y.cuda(), z=input_z.cuda())
+        assert output_x.dtype == torch.half
+        assert output_y.dtype == torch.half
+        assert output_z.dtype == torch.half
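Note that the CI change at the top of the patch excludes these tests from the CPU-only workflow. Locally they can be run directly; the CUDA-dependent branches simply do not execute when no GPU is available:

    pytest tests/test_fp16.py -v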