[Enhance] Use graph transform to deal with more general cases for efficient_conv_bn_eval (#1259)
parent c8a1264568
commit ee742da254
@@ -1,5 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from functools import partial
 from operator import attrgetter
 from typing import List, Union
 
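The `functools.partial` import goes away because the old approach patched `bn.forward` and `conv.forward` at module level, while the new approach rewrites the traced FX graph instead (see the hunks below). For context, here is a minimal sketch of the eval-mode folding identity that `efficient_conv_bn_eval_forward` computes on the fly; the helper name `fold_conv_bn_eval_weights` is illustrative and not part of this patch:

import torch
import torch.nn as nn

def fold_conv_bn_eval_weights(conv: nn.Conv2d, bn: nn.BatchNorm2d):
    # per-channel scale from the frozen bn statistics
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    weight = conv.weight * scale.reshape(-1, 1, 1, 1)
    bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    bias = bn.bias + (bias - bn.running_mean) * scale
    return weight, bias

conv, bn = nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)
bn.eval()
x = torch.randn(1, 3, 16, 16)
w, b = fold_conv_bn_eval_weights(conv, bn)
# conv with folded parameters == bn(conv(x)) when bn is in eval mode
assert torch.allclose(bn(conv(x)), conv._conv_forward(x, w, b), atol=1e-5)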
@@ -58,48 +57,32 @@ def efficient_conv_bn_eval_forward(bn: nn.modules.batchnorm._BatchNorm,
     return conv._conv_forward(x, weight_on_the_fly, bias_on_the_fly)
 
 
-def bn_once_identity_forward(bn: nn.modules.batchnorm._BatchNorm,
-                             x: torch.Tensor):
-    """The forward function is an identity function.
-
-    The magic is that after one call, the `bn.forward` will be restored to what
-    it used to be.
-    """
-    bn.__dict__.pop('forward')
-    return x
-
-
 def efficient_conv_bn_eval_control(bn: nn.modules.batchnorm._BatchNorm,
                                    conv: nn.modules.conv._ConvNd,
                                    x: torch.Tensor):
     """This function controls whether to use `efficient_conv_bn_eval_forward`.
 
     If the following `bn` is in `eval` mode, then we turn on the special
-    `efficient_conv_bn_eval_forward` and let the following call of `bn.forward`
-    to be identity. Note that this `bn.forward` modification only works for one
-    call. After the call, `bn.forward` will be restored to the default
-    function. This is to deal with the case where one `bn` module is used in
-    multiple places.
+    `efficient_conv_bn_eval_forward`.
     """
     if not bn.training:
         # bn in eval mode
         output = efficient_conv_bn_eval_forward(bn, conv, x)
-        bn.forward = partial(bn_once_identity_forward, bn)
         return output
     else:
-        return conv._conv_forward(x, conv.weight, conv.bias)
+        conv_out = conv._conv_forward(x, conv.weight, conv.bias)
+        return bn(conv_out)
 
 
-def turn_on_efficient_conv_bn_eval_for_single_model(model: torch.nn.Module):
-    # optimize consecutive conv+bn by modifying forward function
-    # Symbolically trace the input model to create an FX GraphModule
-    import torch.fx as fx
-    fx_model: fx.GraphModule = fx.symbolic_trace(model)
+def efficient_conv_bn_eval_graph_transform(fx_model):
+    """Find consecutive conv+bn calls in the graph, inplace modify the graph
+    with the fused operation."""
     modules = dict(fx_model.named_modules())
 
     patterns = [(torch.nn.modules.conv._ConvNd,
                  torch.nn.modules.batchnorm._BatchNorm)]
 
+    pairs = []
     # Iterate through nodes in the graph to find ConvBN blocks
     for node in fx_model.graph.nodes:
         # If our current node isn't calling a Module then we can ignore it.
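Since `efficient_conv_bn_eval_control` now receives the modules as arguments and falls back to an explicit `bn(conv_out)` in training mode, the one-shot identity `bn.forward` hack is no longer needed. The transform walks the nodes produced by `fx.symbolic_trace`; if FX is unfamiliar, this sketch (with a toy `Tiny` model, illustrative only) shows the `call_module` nodes it pattern-matches:

import torch.fx as fx
import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.bn = nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(self.conv(x))

gm = fx.symbolic_trace(Tiny())
for node in gm.graph.nodes:
    print(node.op, node.target)
# placeholder x
# call_module conv
# call_module bn
# output output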
@@ -116,26 +99,54 @@ def turn_on_efficient_conv_bn_eval_for_single_model(model: torch.nn.Module):
         if not found_pair or len(node.args[0].users) > 1:
             continue
 
-        # check if the conv modules are used in multiple nodes
-        conv_name = node.args[0].target
-        bn_name = node.target
+        # Find a pair of conv and bn computation nodes to optimize
+        conv_node = node.args[0]
+        bn_node = node
+        pairs.append([conv_node, bn_node])
 
-        conv_usage_count = 0
-        for _node in fx_model.graph.nodes:
-            if _node.op != 'call_module':
-                continue
-            if _node.target == conv_name:
-                conv_usage_count += 1
-
-        if conv_usage_count > 1:
-            continue
-
-        # Find a pair of conv and bn to optimize
-        conv_module = modules[conv_name]
-        bn_module = modules[bn_name]
-
-        conv_module.forward = partial(efficient_conv_bn_eval_control,
-                                      bn_module, conv_module)
+    for conv_node, bn_node in pairs:
+        # set insertion point
+        fx_model.graph.inserting_before(conv_node)
+        # create `get_attr` node to access modules
+        # note that we directly call `create_node` to fill the `name`
+        # argument. `fx_model.graph.get_attr` and
+        # `fx_model.graph.call_function` does not allow the `name` argument.
+        conv_get_node = fx_model.graph.create_node(
+            op='get_attr', target=conv_node.target, name='get_conv')
+        bn_get_node = fx_model.graph.create_node(
+            op='get_attr', target=bn_node.target, name='get_bn')
+        # prepare args for the fused function
+        args = (bn_get_node, conv_get_node, conv_node.args[0])
+        # create a new node
+        new_node = fx_model.graph.create_node(
+            op='call_function',
+            target=efficient_conv_bn_eval_control,
+            args=args,
+            name='efficient_conv_bn_eval')
+        # this node replaces the original conv + bn, and therefore
+        # should replace the uses of bn_node
+        bn_node.replace_all_uses_with(new_node)
+        # take care of the deletion order:
+        # delete bn_node first, and then conv_node
+        fx_model.graph.erase_node(bn_node)
+        fx_model.graph.erase_node(conv_node)
+
+    # regenerate the code
+    fx_model.graph.lint()
+    fx_model.recompile()
 
 
+def turn_on_efficient_conv_bn_eval_for_single_model(model: torch.nn.Module):
+    import torch.fx as fx
+
+    # currently we use `fx.symbolic_trace` to trace models.
+    # in the future, we might turn to pytorch 2.0 compile infrastructure to
+    # get the `fx.GraphModule` IR. Nonetheless, the graph transform function
+    # can remain unchanged. We just need to change the way
+    # we get `fx.GraphModule`.
+    fx_model: fx.GraphModule = fx.symbolic_trace(model)
+    efficient_conv_bn_eval_graph_transform(fx_model)
+    model.forward = fx_model.forward
+
+
 def turn_on_efficient_conv_bn_eval(model: torch.nn.Module,
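Putting the pieces together: a sketch of applying the transform to the toy `Tiny` model from above and checking that eval-mode outputs are unchanged. It assumes `efficient_conv_bn_eval_graph_transform` and `efficient_conv_bn_eval_control` from this patch are in scope; the exact import path is not shown in the diff.

import torch

model = Tiny().eval()
x = torch.randn(1, 3, 16, 16)
ref = model(x)

fx_model = fx.symbolic_trace(model)
efficient_conv_bn_eval_graph_transform(fx_model)

# the conv/bn `call_module` pair is now a single `call_function` node
assert any(n.op == 'call_function' for n in fx_model.graph.nodes)
assert torch.allclose(ref, fx_model(x), atol=1e-5)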
@@ -37,8 +37,8 @@ class BackboneModel(nn.Module):
         x = self.mod1(x)
         # this conv-bn pair can use efficient_conv_bn_eval feature
         x = self.bn1(self.conv1(x))
-        # this conv-bn pair cannot use efficient_conv_bn_eval feature
-        # because `self.conv2` is used twice
+        # this conv-bn pair can use efficient_conv_bn_eval feature
+        # only for the second `self.conv2` call.
         x = self.bn2(self.conv2(self.conv2(x)))
         # this conv-bn pair can use efficient_conv_bn_eval feature
         # just for the first forward of the `self.bn3`
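The updated test comment reflects the finer granularity of the graph transform: a conv module invoked twice no longer disables the optimization wholesale, because each invocation is a distinct graph node. A sketch of that case (toy module, illustrative only, not the test file itself):

class TwoCalls(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv2 = nn.Conv2d(8, 8, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(8)

    def forward(self, x):
        # two graph nodes share `self.conv2`; only the outer call, whose
        # sole user is `self.bn2`, gets fused by the transform
        return self.bn2(self.conv2(self.conv2(x)))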