Mirror of https://github.com/open-mmlab/mmcv.git
Commit 4a044c6466 (parent 7cfc839ea5)
@@ -184,7 +184,7 @@ jobs:
        run: pip install Pillow
      - name: Run unittests and generate coverage report
        run: |
-         pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_registry.py
+         pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_registry.py --ignore=tests/test_fp16.py

  build_macos:
    runs-on: macos-latest
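
Side note on the CI change above (not part of the diff): the hunk only adds the new fp16 test file to this job's ignore list. The tests themselves, added at the end of this commit, can still be run directly with standard pytest usage, for example:

    pytest tests/test_fp16.py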

mmcv/runner/fp16_utils.py (new file)
@@ -0,0 +1,303 @@
import functools
from collections import OrderedDict, abc
from inspect import getfullargspec

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from torch._utils import (_flatten_dense_tensors, _take_tensors,
                          _unflatten_dense_tensors)


def cast_tensor_type(inputs, src_type, dst_type):
    """Recursively convert Tensors in inputs from src_type to dst_type.

    Args:
        inputs: Inputs to be cast.
        src_type (torch.dtype): Source type.
        dst_type (torch.dtype): Destination type.

    Returns:
        The same type as inputs, but all contained Tensors have been cast.
    """
    if isinstance(inputs, torch.Tensor):
        return inputs.to(dst_type)
    elif isinstance(inputs, str):
        return inputs
    elif isinstance(inputs, np.ndarray):
        return inputs
    elif isinstance(inputs, abc.Mapping):
        return type(inputs)({
            k: cast_tensor_type(v, src_type, dst_type)
            for k, v in inputs.items()
        })
    elif isinstance(inputs, abc.Iterable):
        return type(inputs)(
            cast_tensor_type(item, src_type, dst_type) for item in inputs)
    else:
        return inputs


def auto_fp16(apply_to=None, out_fp32=False):
    """Decorator to enable fp16 training automatically.

    This decorator is useful when you write custom modules and want to support
    mixed precision training. If input arguments are fp32 tensors, they will
    be converted to fp16 automatically. Arguments other than fp32 tensors are
    ignored.

    Args:
        apply_to (Iterable, optional): The argument names to be converted.
            `None` indicates all arguments.
        out_fp32 (bool): Whether to convert the output back to fp32.

    Example:

        >>> import torch.nn as nn
        >>> class MyModule1(nn.Module):
        >>>
        >>>     # Convert x and y to fp16
        >>>     @auto_fp16()
        >>>     def forward(self, x, y):
        >>>         pass

        >>> import torch.nn as nn
        >>> class MyModule2(nn.Module):
        >>>
        >>>     # convert pred to fp16
        >>>     @auto_fp16(apply_to=('pred', ))
        >>>     def do_something(self, pred, others):
        >>>         pass
    """

    def auto_fp16_wrapper(old_func):

        @functools.wraps(old_func)
        def new_func(*args, **kwargs):
            # check if the module has set the attribute `fp16_enabled`, if not,
            # just fallback to the original method.
            if not isinstance(args[0], torch.nn.Module):
                raise TypeError('@auto_fp16 can only be used to decorate the '
                                'method of nn.Module')
            if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
                return old_func(*args, **kwargs)
            # get the arg spec of the decorated method
            args_info = getfullargspec(old_func)
            # get the argument names to be casted
            args_to_cast = args_info.args if apply_to is None else apply_to
            # convert the args that need to be processed
            new_args = []
            # NOTE: default args are not taken into consideration
            if args:
                arg_names = args_info.args[:len(args)]
                for i, arg_name in enumerate(arg_names):
                    if arg_name in args_to_cast:
                        new_args.append(
                            cast_tensor_type(args[i], torch.float, torch.half))
                    else:
                        new_args.append(args[i])
            # convert the kwargs that need to be processed
            new_kwargs = {}
            if kwargs:
                for arg_name, arg_value in kwargs.items():
                    if arg_name in args_to_cast:
                        new_kwargs[arg_name] = cast_tensor_type(
                            arg_value, torch.float, torch.half)
                    else:
                        new_kwargs[arg_name] = arg_value
            # apply converted arguments to the decorated method
            output = old_func(*new_args, **new_kwargs)
            # cast the results back to fp32 if necessary
            if out_fp32:
                output = cast_tensor_type(output, torch.half, torch.float)
            return output

        return new_func

    return auto_fp16_wrapper


def force_fp32(apply_to=None, out_fp16=False):
    """Decorator to forcibly convert input arguments to fp32.

    This decorator is useful when you write custom modules and want to support
    mixed precision training. If there are some inputs that must be processed
    in fp32 mode, this decorator can handle it. If input arguments are fp16
    tensors, they will be converted to fp32 automatically. Arguments other
    than fp16 tensors are ignored.

    Args:
        apply_to (Iterable, optional): The argument names to be converted.
            `None` indicates all arguments.
        out_fp16 (bool): Whether to convert the output back to fp16.

    Example:

        >>> import torch.nn as nn
        >>> class MyModule1(nn.Module):
        >>>
        >>>     # Convert x and y to fp32
        >>>     @force_fp32()
        >>>     def loss(self, x, y):
        >>>         pass

        >>> import torch.nn as nn
        >>> class MyModule2(nn.Module):
        >>>
        >>>     # convert pred to fp32
        >>>     @force_fp32(apply_to=('pred', ))
        >>>     def post_process(self, pred, others):
        >>>         pass
    """

    def force_fp32_wrapper(old_func):

        @functools.wraps(old_func)
        def new_func(*args, **kwargs):
            # check if the module has set the attribute `fp16_enabled`, if not,
            # just fallback to the original method.
            if not isinstance(args[0], torch.nn.Module):
                raise TypeError('@force_fp32 can only be used to decorate the '
                                'method of nn.Module')
            if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
                return old_func(*args, **kwargs)
            # get the arg spec of the decorated method
            args_info = getfullargspec(old_func)
            # get the argument names to be casted
            args_to_cast = args_info.args if apply_to is None else apply_to
            # convert the args that need to be processed
            new_args = []
            if args:
                arg_names = args_info.args[:len(args)]
                for i, arg_name in enumerate(arg_names):
                    if arg_name in args_to_cast:
                        new_args.append(
                            cast_tensor_type(args[i], torch.half, torch.float))
                    else:
                        new_args.append(args[i])
            # convert the kwargs that need to be processed
            new_kwargs = dict()
            if kwargs:
                for arg_name, arg_value in kwargs.items():
                    if arg_name in args_to_cast:
                        new_kwargs[arg_name] = cast_tensor_type(
                            arg_value, torch.half, torch.float)
                    else:
                        new_kwargs[arg_name] = arg_value
            # apply converted arguments to the decorated method
            output = old_func(*new_args, **new_kwargs)
            # cast the results back to fp16 if necessary
            if out_fp16:
                output = cast_tensor_type(output, torch.float, torch.half)
            return output

        return new_func

    return force_fp32_wrapper


def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
    if bucket_size_mb > 0:
        bucket_size_bytes = bucket_size_mb * 1024 * 1024
        buckets = _take_tensors(tensors, bucket_size_bytes)
    else:
        buckets = OrderedDict()
        for tensor in tensors:
            tp = tensor.type()
            if tp not in buckets:
                buckets[tp] = []
            buckets[tp].append(tensor)
        buckets = buckets.values()

    for bucket in buckets:
        flat_tensors = _flatten_dense_tensors(bucket)
        dist.all_reduce(flat_tensors)
        flat_tensors.div_(world_size)
        for tensor, synced in zip(
                bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
            tensor.copy_(synced)


def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
    """Allreduce gradients.

    Args:
        params (list[torch.nn.Parameter]): List of parameters of a model.
        coalesce (bool, optional): Whether to allreduce parameters as a whole.
            Defaults to True.
        bucket_size_mb (int, optional): Size of bucket, the unit is MB.
            Defaults to -1.
    """
    grads = [
        param.grad.data for param in params
        if param.requires_grad and param.grad is not None
    ]
    world_size = dist.get_world_size()
    if coalesce:
        _allreduce_coalesced(grads, world_size, bucket_size_mb)
    else:
        for tensor in grads:
            dist.all_reduce(tensor.div_(world_size))


def wrap_fp16_model(model):
    """Wrap the FP32 model to FP16.

    1. Convert the FP32 model to FP16.
    2. Keep some necessary layers in FP32, e.g., normalization layers.

    Args:
        model (nn.Module): Model in FP32.
    """
    # convert model to fp16
    model.half()
    # patch the normalization layers to make them work in fp32 mode
    patch_norm_fp32(model)
    # set `fp16_enabled` flag
    for m in model.modules():
        if hasattr(m, 'fp16_enabled'):
            m.fp16_enabled = True


def patch_norm_fp32(module):
    """Recursively convert normalization layers from FP16 to FP32.

    Args:
        module (nn.Module): The FP16 module whose normalization layers will be
            converted to FP32.

    Returns:
        nn.Module: The converted module, whose normalization layers have been
            converted to FP32.
    """
    if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)):
        module.float()
        if isinstance(module, nn.GroupNorm) or torch.__version__ < '1.3':
            module.forward = patch_forward_method(module.forward, torch.half,
                                                  torch.float)
    for child in module.children():
        patch_norm_fp32(child)
    return module


def patch_forward_method(func, src_type, dst_type, convert_output=True):
    """Patch the forward method of a module.

    Args:
        func (callable): The original forward method.
        src_type (torch.dtype): Type of input arguments to be converted from.
        dst_type (torch.dtype): Type of input arguments to be converted to.
        convert_output (bool): Whether to convert the output back to src_type.

    Returns:
        callable: The patched forward method.
    """

    def new_forward(*args, **kwargs):
        output = func(*cast_tensor_type(args, src_type, dst_type),
                      **cast_tensor_type(kwargs, src_type, dst_type))
        if convert_output:
            output = cast_tensor_type(output, dst_type, src_type)
        return output

    return new_forward
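
To make the intent of these utilities concrete, a minimal usage sketch follows (illustrative, not part of this commit). It assumes the new file is importable as mmcv.runner.fp16_utils, which matches the imports in the test file added below; ToyModel and the tensor shapes are made up.

import torch
import torch.nn as nn

from mmcv.runner.fp16_utils import auto_fp16, wrap_fp16_model


class ToyModel(nn.Module):
    """Made-up module, for illustration only."""

    def __init__(self):
        super().__init__()
        # the flag must exist so that wrap_fp16_model() can flip it to True
        self.fp16_enabled = False
        self.gn = nn.GroupNorm(2, 4)

    @auto_fp16()
    def forward(self, x):
        # once fp16_enabled is True, x arrives here as fp16; the patched
        # GroupNorm casts it to fp32 internally and returns fp16 again
        return self.gn(x)


model = ToyModel()
wrap_fp16_model(model)  # model.half() + keep the norm layer in fp32
out = model(torch.randn(2, 4))  # the fp32 input is cast to fp16 by @auto_fp16
assert out.dtype == torch.half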

mmcv/runner/hooks/optimizer.py
@@ -1,6 +1,9 @@
# Copyright (c) Open-MMLab. All rights reserved.
import copy

from torch.nn.utils import clip_grad

from ..fp16_utils import allreduce_grads, wrap_fp16_model
from .hook import HOOKS, Hook

@@ -26,3 +29,91 @@ class OptimizerHook(Hook):
                runner.log_buffer.update({'grad_norm': float(grad_norm)},
                                         runner.outputs['num_samples'])
        runner.optimizer.step()


class Fp16OptimizerHook(OptimizerHook):
    """FP16 optimizer hook.

    The steps of the fp16 optimizer are as follows.
    1. Scale the loss value.
    2. BP in the fp16 model.
    3. Copy gradients from the fp16 model to the fp32 weights.
    4. Update the fp32 weights.
    5. Copy updated parameters from the fp32 weights to the fp16 model.

    Refer to https://arxiv.org/abs/1710.03740 for more details.

    Args:
        grad_clip (dict, optional): Config for gradient clipping, forwarded to
            clip_grad_norm_ by the parent OptimizerHook. None disables
            clipping. Defaults to None.
        coalesce (bool): Whether to coalesce gradients before allreduce.
            Defaults to True.
        bucket_size_mb (int): Size of the allreduce bucket, the unit is MB.
            Defaults to -1.
        loss_scale (float): Scale factor multiplied with loss.
            Defaults to 512.
        distributed (bool): Whether gradients should be allreduced across
            processes. Defaults to True.
    """

    def __init__(self,
                 grad_clip=None,
                 coalesce=True,
                 bucket_size_mb=-1,
                 loss_scale=512.,
                 distributed=True):
        self.grad_clip = grad_clip
        self.coalesce = coalesce
        self.bucket_size_mb = bucket_size_mb
        self.loss_scale = loss_scale
        self.distributed = distributed

    def before_run(self, runner):
        """Preparing steps before Mixed Precision Training.

        1. Make a master copy of fp32 weights for optimization.
        2. Convert the main model from fp32 to fp16.
        """
        # keep a copy of fp32 weights
        runner.optimizer.param_groups = copy.deepcopy(
            runner.optimizer.param_groups)
        # convert model to fp16
        wrap_fp16_model(runner.model)

    def copy_grads_to_fp32(self, fp16_net, fp32_weights):
        """Copy gradients from the fp16 model to the fp32 weight copy."""
        for fp32_param, fp16_param in zip(fp32_weights, fp16_net.parameters()):
            if fp16_param.grad is not None:
                if fp32_param.grad is None:
                    fp32_param.grad = fp32_param.data.new(fp32_param.size())
                fp32_param.grad.copy_(fp16_param.grad)

    def copy_params_to_fp16(self, fp16_net, fp32_weights):
        """Copy updated params from the fp32 weight copy to the fp16 model."""
        for fp16_param, fp32_param in zip(fp16_net.parameters(), fp32_weights):
            fp16_param.data.copy_(fp32_param.data)

    def after_train_iter(self, runner):
        """Backward optimization steps for Mixed Precision Training.

        1. Scale the loss by a scale factor.
        2. Backward the loss to obtain the gradients (fp16).
        3. Copy gradients from the model to the fp32 weight copy.
        4. Scale the gradients back and update the fp32 weight copy.
        5. Copy back the params from the fp32 weight copy to the fp16 model.
        """
        # clear grads of last iteration
        runner.model.zero_grad()
        runner.optimizer.zero_grad()
        # scale the loss value
        scaled_loss = runner.outputs['loss'] * self.loss_scale
        scaled_loss.backward()
        # copy fp16 grads in the model to fp32 params in the optimizer
        fp32_weights = []
        for param_group in runner.optimizer.param_groups:
            fp32_weights += param_group['params']
        self.copy_grads_to_fp32(runner.model, fp32_weights)
        # allreduce grads
        if self.distributed:
            allreduce_grads(fp32_weights, self.coalesce, self.bucket_size_mb)
        # scale the gradients back
        for param in fp32_weights:
            if param.grad is not None:
                param.grad.div_(self.loss_scale)
        if self.grad_clip is not None:
            self.clip_grads(fp32_weights)
        # update fp32 params
        runner.optimizer.step()
        # copy fp32 params to the fp16 model
        self.copy_params_to_fp16(runner.model, fp32_weights)
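
For orientation, a short sketch of constructing the new hook with explicit arguments (the values are placeholders; the import path is inferred from this diff's package layout, and the hook may additionally be re-exported elsewhere, which is not shown here):

from mmcv.runner.hooks.optimizer import Fp16OptimizerHook  # path assumed

optimizer_config = Fp16OptimizerHook(
    grad_clip=dict(max_norm=35, norm_type=2),  # forwarded to clip_grad_norm_
    coalesce=True,      # flatten gradients into buckets before allreduce
    bucket_size_mb=-1,  # -1: one bucket per tensor type, no fixed-size buckets
    loss_scale=512.,    # static scale applied to the loss before backward()
    distributed=False)  # skip allreduce_grads() for single-process training

A runner would register this in place of the plain OptimizerHook: before_run() keeps an fp32 master copy of the optimizer's param_groups and converts the model with wrap_fp16_model(), and after_train_iter() performs the scaled backward and fp32 update described in its docstring.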

tests/test_fp16.py (new file)
@@ -0,0 +1,300 @@
import numpy as np
import pytest
import torch
import torch.nn as nn

from mmcv.runner.fp16_utils import auto_fp16, cast_tensor_type, force_fp32


def test_cast_tensor_type():
    inputs = torch.FloatTensor([5.])
    src_type = torch.float32
    dst_type = torch.int32
    outputs = cast_tensor_type(inputs, src_type, dst_type)
    assert isinstance(outputs, torch.Tensor)
    assert outputs.dtype == dst_type

    inputs = 'tensor'
    src_type = str
    dst_type = str
    outputs = cast_tensor_type(inputs, src_type, dst_type)
    assert isinstance(outputs, str)

    inputs = np.array([5.])
    src_type = np.ndarray
    dst_type = np.ndarray
    outputs = cast_tensor_type(inputs, src_type, dst_type)
    assert isinstance(outputs, np.ndarray)

    inputs = dict(
        tensor_a=torch.FloatTensor([1.]), tensor_b=torch.FloatTensor([2.]))
    src_type = torch.float32
    dst_type = torch.int32
    outputs = cast_tensor_type(inputs, src_type, dst_type)
    assert isinstance(outputs, dict)
    assert outputs['tensor_a'].dtype == dst_type
    assert outputs['tensor_b'].dtype == dst_type

    inputs = [torch.FloatTensor([1.]), torch.FloatTensor([2.])]
    src_type = torch.float32
    dst_type = torch.int32
    outputs = cast_tensor_type(inputs, src_type, dst_type)
    assert isinstance(outputs, list)
    assert outputs[0].dtype == dst_type
    assert outputs[1].dtype == dst_type

    inputs = 5
    outputs = cast_tensor_type(inputs, None, None)
    assert isinstance(outputs, int)


def test_auto_fp16():

    with pytest.raises(TypeError):
        # ExampleObject is not a subclass of nn.Module

        class ExampleObject(object):

            @auto_fp16()
            def __call__(self, x):
                return x

        model = ExampleObject()
        input_x = torch.ones(1, dtype=torch.float32)
        model(input_x)

    # apply to all input args
    class ExampleModule(nn.Module):

        @auto_fp16()
        def forward(self, x, y):
            return x, y

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.float32)
    input_y = torch.ones(1, dtype=torch.float32)
    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.float32

    model.fp16_enabled = True
    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.half

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y = model(input_x.cuda(), input_y.cuda())
        assert output_x.dtype == torch.half
        assert output_y.dtype == torch.half

    # apply to specified input args
    class ExampleModule(nn.Module):

        @auto_fp16(apply_to=('x', ))
        def forward(self, x, y):
            return x, y

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.float32)
    input_y = torch.ones(1, dtype=torch.float32)
    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.float32

    model.fp16_enabled = True
    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.float32

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y = model(input_x.cuda(), input_y.cuda())
        assert output_x.dtype == torch.half
        assert output_y.dtype == torch.float32

    # apply to optional input args
    class ExampleModule(nn.Module):

        @auto_fp16(apply_to=('x', 'y'))
        def forward(self, x, y=None, z=None):
            return x, y, z

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.float32)
    input_y = torch.ones(1, dtype=torch.float32)
    input_z = torch.ones(1, dtype=torch.float32)
    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.float32
    assert output_z.dtype == torch.float32

    model.fp16_enabled = True
    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.half
    assert output_z.dtype == torch.float32

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y, output_z = model(
            input_x.cuda(), y=input_y.cuda(), z=input_z.cuda())
        assert output_x.dtype == torch.half
        assert output_y.dtype == torch.half
        assert output_z.dtype == torch.float32

    # out_fp32=True
    class ExampleModule(nn.Module):

        @auto_fp16(apply_to=('x', 'y'), out_fp32=True)
        def forward(self, x, y=None, z=None):
            return x, y, z

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.half)
    input_y = torch.ones(1, dtype=torch.float32)
    input_z = torch.ones(1, dtype=torch.float32)
    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.float32
    assert output_z.dtype == torch.float32

    model.fp16_enabled = True
    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.float32
    assert output_z.dtype == torch.float32

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y, output_z = model(
            input_x.cuda(), y=input_y.cuda(), z=input_z.cuda())
        assert output_x.dtype == torch.float32
        assert output_y.dtype == torch.float32
        assert output_z.dtype == torch.float32


def test_force_fp32():

    with pytest.raises(TypeError):
        # ExampleObject is not a subclass of nn.Module

        class ExampleObject(object):

            @force_fp32()
            def __call__(self, x):
                return x

        model = ExampleObject()
        input_x = torch.ones(1, dtype=torch.float32)
        model(input_x)

    # apply to all input args
    class ExampleModule(nn.Module):

        @force_fp32()
        def forward(self, x, y):
            return x, y

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.half)
    input_y = torch.ones(1, dtype=torch.half)
    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.half

    model.fp16_enabled = True
    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.float32

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y = model(input_x.cuda(), input_y.cuda())
        assert output_x.dtype == torch.float32
        assert output_y.dtype == torch.float32

    # apply to specified input args
    class ExampleModule(nn.Module):

        @force_fp32(apply_to=('x', ))
        def forward(self, x, y):
            return x, y

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.half)
    input_y = torch.ones(1, dtype=torch.half)
    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.half

    model.fp16_enabled = True
    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.half

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y = model(input_x.cuda(), input_y.cuda())
        assert output_x.dtype == torch.float32
        assert output_y.dtype == torch.half

    # apply to optional input args
    class ExampleModule(nn.Module):

        @force_fp32(apply_to=('x', 'y'))
        def forward(self, x, y=None, z=None):
            return x, y, z

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.half)
    input_y = torch.ones(1, dtype=torch.half)
    input_z = torch.ones(1, dtype=torch.half)
    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.half
    assert output_z.dtype == torch.half

    model.fp16_enabled = True
    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.float32
    assert output_z.dtype == torch.half

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y, output_z = model(
            input_x.cuda(), y=input_y.cuda(), z=input_z.cuda())
        assert output_x.dtype == torch.float32
        assert output_y.dtype == torch.float32
        assert output_z.dtype == torch.half

    # out_fp16=True
    class ExampleModule(nn.Module):

        @force_fp32(apply_to=('x', 'y'), out_fp16=True)
        def forward(self, x, y=None, z=None):
            return x, y, z

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.float32)
    input_y = torch.ones(1, dtype=torch.half)
    input_z = torch.ones(1, dtype=torch.half)
    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.half
    assert output_z.dtype == torch.half

    model.fp16_enabled = True
    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.half
    assert output_z.dtype == torch.half

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y, output_z = model(
            input_x.cuda(), y=input_y.cuda(), z=input_z.cuda())
        assert output_x.dtype == torch.half
        assert output_y.dtype == torch.half
        assert output_z.dtype == torch.half