# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Modified from
# https://github.com/facebookresearch/fvcore/blob/main/tests/test_jit_model_analysis.py
# pyre-ignore-all-errors[2,56]

import logging
import typing
import unittest
import warnings
from collections import Counter
from typing import Any, Dict, List

import torch
import torch.nn as nn

from mmengine import MMLogger
from mmengine.analysis import FlopAnalyzer
from mmengine.analysis.jit_analysis import JitModelAnalysis
from mmengine.analysis.jit_handles import (Handle, addmm_flop_jit,
                                           conv_flop_jit, linear_flop_jit)

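# Reading aid (added commentary, not part of the original test logic): the
# fixture modules below hard-code their expected FLOP counts with the same
# conventions as the flop handles under test. A Linear layer is counted as
# in_features * out_features FLOPs per sample, keyed by the traced op name
# ('addmm' or 'linear' depending on the torch version). A Conv1d layer is
# counted as output_width * kernel_size * in_channels * out_channels FLOPs
# per sample, keyed by 'conv', where
#     output_width = (input_width + 2 * padding) - 2 * (kernel_size // 2).
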
class NestedNetInnerModule(nn.Module):
    """A submodule for the nested net test module below."""

    def __init__(self, lin_op: str = 'addmm') -> None:
        super().__init__()
        conv_input_size = (2, 5)
        conv_in = 2
        conv_out = 2
        kernel_size = 1
        padding = 0
        fc_in = 10
        fc_out = 10

        self.conv = nn.Conv1d(
            in_channels=conv_in,
            out_channels=conv_out,
            kernel_size=kernel_size,
            padding=padding,
        )
        self.fc = nn.Linear(in_features=fc_in, out_features=fc_out)

        fc_flops_ = fc_in * fc_out
        fc_flops = Counter({lin_op: fc_flops_})
        spatial_pos = (conv_input_size[1] + 2 * padding) - 2 * (
            kernel_size // 2)
        conv_flops_ = spatial_pos * kernel_size * conv_in * conv_out
        conv_flops = Counter({'conv': conv_flops_})
        model_flops = conv_flops + fc_flops
        self.flops: 'Dict[str, typing.Counter[str]]' = {
            '': model_flops,
            'fc': fc_flops,
            'conv': conv_flops,
        }

        self.name_to_module: 'Dict[str, nn.Module]' = {
            '': self,
            'fc': self.fc,
            'conv': self.conv,
        }

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.reshape(-1, 2, 5)
        x = self.conv(x)
        x = torch.flatten(x, 1)
        x = 3 * self.fc(x) + 1
        return x


class NestedNet(nn.Module):
    """A network with nested submodules for testing the ability to correctly
    capture scope information."""

    def __init__(self, lin_op: str = 'addmm') -> None:
        super().__init__()
        self.input_size = (4, 5)

        conv_in = 4
        conv_out = 4
        kernel_size = 3
        padding = 1
        fc_in = 20
        fc_out = 10

        self.submod = NestedNetInnerModule(lin_op)
        self.fc = nn.Linear(in_features=fc_in, out_features=fc_out)
        self.conv = nn.Conv1d(
            in_channels=conv_in,
            out_channels=conv_out,
            kernel_size=kernel_size,
            padding=padding,
        )

        fc_flops_ = fc_in * fc_out
        fc_flops = Counter({lin_op: fc_flops_})
        spatial_pos = (self.input_size[1] + 2 * padding) - 2 * (
            kernel_size // 2)
        conv_flops_ = spatial_pos * kernel_size * conv_in * conv_out
        conv_flops = Counter({'conv': conv_flops_})

        model_flops = conv_flops + fc_flops + self.submod.flops['']
        self.flops: 'Dict[str, typing.Counter[str]]' = {
            '': model_flops,
            'fc': fc_flops,
            'conv': conv_flops,
            'submod': self.submod.flops[''],
            'submod.fc': self.submod.flops['fc'],
            'submod.conv': self.submod.flops['conv'],
        }

        self.name_to_module: 'Dict[str, nn.Module]' = {
            '': self,
            'fc': self.fc,
            'conv': self.conv,
            'submod': self.submod,
            'submod.fc': self.submod.name_to_module['fc'],
            'submod.conv': self.submod.name_to_module['conv'],
        }

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        x = self.submod(x)**2
        return x


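# Illustrative arithmetic (added note) for the default sizes above, batch 1:
# NestedNetInnerModule: conv -> 5 * 1 * 2 * 2 = 20, fc -> 10 * 10 = 100,
# so 'submod' totals 120. NestedNet: conv -> 5 * 3 * 4 * 4 = 240,
# fc -> 20 * 10 = 200, giving a model total of 240 + 200 + 120 = 560 FLOPs.
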
class UnusedNet(nn.Module):
    """Has a submodule that is never called in the forward function."""

    def __init__(self) -> None:
        super().__init__()
        self.input_size = (10, )
        fc1_in, fc1_out = 10, 10
        fc2_in, fc2_out = 10, 1
        unused_in, unused_out = 20, 20

        self.fc1 = nn.Linear(in_features=fc1_in, out_features=fc1_out)
        self.fc2 = nn.Linear(in_features=fc2_in, out_features=fc2_out)
        self.unused = nn.Linear(
            in_features=unused_in, out_features=unused_out)
        self.act: 'nn.Module' = nn.ReLU()

        self.fc1_flops: int = fc1_in * fc1_out
        self.fc2_flops: int = fc2_in * fc2_out
        self.unused_flops: int = unused_in * unused_out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))


class RepeatedNet(nn.Module):
    """Makes repeated calls to the same submodule."""

    def __init__(self) -> None:
        super().__init__()
        self.input_size = (10, )
        fc1_in, fc1_out = 10, 10
        fc2_in, fc2_out = 10, 10
        self.fc1_num = 3
        self.fc2_num = 2

        self.fc1 = nn.Linear(in_features=fc1_in, out_features=fc1_out)
        self.fc2 = nn.Linear(in_features=fc2_in, out_features=fc2_out)

        self.fc1_flops: int = fc1_in * fc1_out
        self.fc2_flops: int = fc2_in * fc2_out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for _i in range(self.fc1_num):
            x = self.fc1(x)
        for _i in range(self.fc2_num):
            x = self.fc2(x)
        return x


class NonForwardInnerModule(nn.Module):
    """Has a function separate from the forward function."""

    def __init__(self) -> None:
        super().__init__()
        self.input_size = (10, )
        fc_in, fc_out = 10, 1

        self.fc = nn.Linear(in_features=fc_in, out_features=fc_out)

        self.fc_flops: int = fc_in * fc_out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

    def other_func(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


class NonForwardNet(nn.Module):
    """The submodule has a non-forward function called by the parent
    module."""

    def __init__(self) -> None:
        super().__init__()
        self.input_size = (10, )
        fc_in, fc_out = 10, 10

        self.submod = NonForwardInnerModule()
        self.fc = nn.Linear(in_features=fc_in, out_features=fc_out)

        self.fc_flops: int = fc_in * fc_out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.submod.other_func(self.fc(x))


class SharedInnerModule(nn.Module):
    """Is initialized with a module that it may share with other modules."""

    def __init__(self, submod: nn.Module) -> None:
        super().__init__()
        self.submod = submod

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.submod(x)


class SharedModuleNet(nn.Module):
    """A subsubmodule is shared by multiple submodules.

    Also calls a module using multiple names.
    """

    def __init__(self) -> None:
        super().__init__()
        self.input_size = (10, )
        fc1_in, fc1_out = 10, 10
        fc2_in, fc2_out = 10, 1

        inner = nn.Linear(in_features=fc1_in, out_features=fc1_out)
        self.submod1 = SharedInnerModule(inner)
        self.submod2 = SharedInnerModule(inner)
        multiname = nn.Linear(in_features=fc2_in, out_features=fc2_out)
        self.multiname1: 'nn.Module' = multiname
        self.multiname2: 'nn.Module' = multiname

        self.multiname_flops: int = fc2_in * fc2_out
        self.shared_flops: int = fc1_in * fc1_out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.submod1(x) + self.submod2(x)
        x = self.multiname1(x) + self.multiname2(x)
        return x


class RecursiveScopeNet(nn.Module):
    """An op is in the same module's scope multiple times."""

    def __init__(self) -> None:
        super().__init__()
        self.input_size = (10, )
        fc_in, fc_out = 10, 1

        self.fc = nn.Linear(in_features=fc_in, out_features=fc_out)

        self.flops: int = fc_in * fc_out

    def forward(self, x: torch.Tensor, count: int = 3) -> torch.Tensor:
        if count > 0:
            return self(x, count - 1)
        return self.fc(x)


class TraceWarningNet(nn.Module):
    """Raises a warning on trace due to a Python comparison of tensor data,
    and explicitly raises a RuntimeWarning.

    Also has an aten::add op that will be skipped and raise a warning.
    """

    def __init__(self) -> None:
        super().__init__()
        self.input_size = (10, )
        fc1_in, fc1_out = 10, 1
        fc2_in, fc2_out = 10, 10

        self.fc1 = nn.Linear(in_features=fc1_in, out_features=fc1_out)
        self.fc2 = nn.Linear(in_features=fc2_in, out_features=fc2_out)

        self.fc1_flops: int = fc1_in * fc1_out
        self.fc2_flops: int = fc2_in * fc2_out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.fc1(x).item()
        warnings.warn('Dummy RuntimeWarning.', RuntimeWarning)
        if y < 0.0:
            x = self.fc2(x)
        return x + 2


class TestJitModelAnalysis(unittest.TestCase):
    """Unittest for JitModelAnalysis.

    Tests for specific jit_handles are covered in test_flop_count.py and
    test_activation_count.py.
    """

    def setUp(self) -> None:
        # nn.Linear uses a different operator based on version, so make sure
        # we are testing the right thing.
        lin = nn.Linear(10, 10)
        lin_x: torch.Tensor = torch.randn(10, 10)
        trace = torch.jit.trace(lin, (lin_x, ))
        node_kinds = [node.kind() for node in trace.graph.nodes()]
        assert 'aten::addmm' in node_kinds or 'aten::linear' in node_kinds
        if 'aten::addmm' in node_kinds:
            self.lin_op = 'addmm'
        else:
            self.lin_op = 'linear'

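    # Reading aid (added commentary; a sketch, not executed by the tests):
    # the methods below exercise the query API that JitModelAnalysis and
    # FlopAnalyzer share, roughly:
    #   analyzer = FlopAnalyzer(model, inputs)
    #   analyzer.total()                   # count for the whole model
    #   analyzer.total('submod.fc')        # count for one named submodule
    #   analyzer.by_operator('submod')     # Counter keyed by operator name
    #   analyzer.by_module()               # {module name: total count}
    #   analyzer.by_module_and_operator()  # {module name: Counter}
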
    def test_total(self) -> None:
        """Tests that JitModelAnalysis.total(module) returns the correct
        counts for string and module inputs."""

        model = NestedNet(lin_op=self.lin_op)
        inputs = (torch.randn((1, *model.input_size)), )

        analyzer = FlopAnalyzer(model=model, inputs=inputs)
        analyzer.unsupported_ops_warnings(enabled=False)

        # Using a string input
        for name in model.flops:
            with self.subTest(name=name):
                gt_flops = sum(model.flops[name].values())
                self.assertEqual(analyzer.total(name), gt_flops)

    def test_by_module(self) -> None:
        """Tests that JitModelAnalysis.by_module() returns the correct counts
        in the correctly structured dictionary."""

        model = NestedNet(lin_op=self.lin_op)
        inputs = (torch.randn((1, *model.input_size)), )

        analyzer = FlopAnalyzer(model=model, inputs=inputs)
        analyzer.unsupported_ops_warnings(enabled=False)

        flops = {
            name: sum(counts.values())
            for name, counts in model.flops.items()
        }

        self.assertEqual(analyzer.by_module(), flops)

    def test_by_operator(self) -> None:
        """Tests that JitModelAnalysis.by_operator(module) returns the correct
        counts for string and module inputs."""

        model = NestedNet(lin_op=self.lin_op)
        inputs = (torch.randn((1, *model.input_size)), )

        analyzer = FlopAnalyzer(model=model, inputs=inputs)
        analyzer.unsupported_ops_warnings(enabled=False)

        # Using a string input
        for name in model.flops:
            with self.subTest(name=name):
                self.assertEqual(
                    analyzer.by_operator(name), model.flops[name])

    def test_by_module_and_operator(self) -> None:
        """Tests that JitModelAnalysis.by_module_and_operator() returns the
        correct counts in the correct structure."""

        model = NestedNet(lin_op=self.lin_op)
        inputs = (torch.randn((1, *model.input_size)), )

        analyzer = FlopAnalyzer(model=model, inputs=inputs)
        analyzer.unsupported_ops_warnings(enabled=False)

        self.assertEqual(analyzer.by_module_and_operator(), model.flops)

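    # Added note: the four tests above check the same identities from
    # different angles. For any module name,
    #   total(name) == sum(by_operator(name).values())
    #   by_module()[name] == total(name)
    #   by_module_and_operator()[name] == by_operator(name)
    # with NestedNet.flops serving as the hand-computed ground truth.
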
    def test_unused_module(self) -> None:
        """Tests that unused modules return 0 count for operator sums and an
        empty Counter() for per-operator results.

        Also tests that unused modules are reported by .uncalled_modules(),
        but that modules that simply have zero flops (like ReLU) are not.
        """

        model = UnusedNet()
        inputs = (torch.randn((1, *model.input_size)), )
        analyzer = FlopAnalyzer(model=model, inputs=inputs)

        unused_count = 0
        unused_per_operator = Counter()  # type: Counter
        model_count = model.fc1_flops + model.fc2_flops

        self.assertEqual(analyzer.total('unused'), unused_count)
        self.assertEqual(analyzer.by_operator('unused'), unused_per_operator)
        self.assertEqual(analyzer.total(''), model_count)

        # The unused mod is recognized as never called
        self.assertEqual(analyzer.uncalled_modules(), {'unused'})

    def test_repeated_module(self) -> None:
        """Tests that repeated calls to the same submodule are correctly
        aggregated to that submodule."""

        model = RepeatedNet()
        inputs = (torch.randn((1, *model.input_size)), )

        analyzer = FlopAnalyzer(model=model, inputs=inputs)
        fc1_count = model.fc1_num * model.fc1_flops
        fc2_count = model.fc2_num * model.fc2_flops
        total_count = fc1_count + fc2_count
        fc1_per_operator = Counter({self.lin_op: fc1_count})

        self.assertEqual(analyzer.total('fc1'), fc1_count)
        self.assertEqual(analyzer.total('fc2'), fc2_count)
        self.assertEqual(analyzer.total(''), total_count)
        self.assertEqual(analyzer.by_operator('fc1'), fc1_per_operator)

        # Tests no uncalled mods
        self.assertEqual(analyzer.uncalled_modules(), set())

    def test_non_forward_func_call(self) -> None:
        """Tests calls to a submodule's non-forward function.

        Also tests that the intermediate module is correctly identified as a
        skipped module.
        """

        model = NonForwardNet()
        inputs = (torch.randn((1, 10)), )
        analyzer = FlopAnalyzer(
            model=model, inputs=inputs).ancestor_mode('caller')

        inner_fc_count = model.submod.fc_flops
        total_count = model.fc_flops + inner_fc_count

        self.assertEqual(analyzer.total('submod'), 0)
        self.assertEqual(analyzer.total('submod.fc'), inner_fc_count)
        self.assertEqual(analyzer.total(''), total_count)

        # The mod not directly called is registered as such
        self.assertEqual(analyzer.uncalled_modules(), {'submod'})

        analyzer = FlopAnalyzer(
            model=model, inputs=inputs).ancestor_mode('owner')
        self.assertEqual(analyzer.total('submod'), inner_fc_count)
        self.assertEqual(analyzer.total('submod.fc'), inner_fc_count)
        self.assertEqual(analyzer.total(''), total_count)
        self.assertEqual(analyzer.uncalled_modules(), set())

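    # Added note on ancestor_mode, as exercised above: in 'caller' mode an op
    # is attributed only to the modules whose forward() is on the call stack,
    # so flops reached via submod.other_func() bypass 'submod' and that module
    # is reported by uncalled_modules(); in 'owner' mode the op is also
    # credited to the module that owns it in the module hierarchy, so
    # 'submod' receives the count and nothing is reported as uncalled.
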
    def test_shared_module(self) -> None:
        """Tests the behavior of shared submodules that may have multiple
        names."""

        model = SharedModuleNet()
        inputs = (torch.randn((1, *model.input_size)), )

        analyzer = (
            FlopAnalyzer(model=model, inputs=inputs).unsupported_ops_warnings(
                enabled=False).ancestor_mode('caller'))

        # The names `submod2.submod` and `multiname2` are not included,
        # since only the first name of a module is made the canonical one.
        # The counts associated with these cases are included under
        # `submod1.submod` and `multiname1` respectively.
        multiname_flops = 2 * model.multiname_flops  # Called under 2 names
        shared_flops = 2 * model.shared_flops  # Shared under 2 submodules
        total_flops = multiname_flops + shared_flops
        flops = {
            '': total_flops,
            'submod1': model.shared_flops,
            'submod1.submod': shared_flops,
            'submod2': model.shared_flops,
            'multiname1': multiname_flops,
        }

        self.assertEqual(analyzer.by_module(), flops)

        # Test access by alternative name
        self.assertEqual(
            analyzer.total('submod2.submod'),
            flops['submod1.submod'],
        )
        self.assertEqual(
            analyzer.total('multiname2'),
            flops['multiname1'],
        )

        # Test getting canonical name
        self.assertEqual(
            analyzer.canonical_module_name('multiname2'), 'multiname1')
        self.assertEqual(
            analyzer.canonical_module_name('multiname1'), 'multiname1')
        self.assertEqual(
            analyzer.canonical_module_name('submod2.submod'), 'submod1.submod')
        self.assertEqual(
            analyzer.canonical_module_name('submod1.submod'), 'submod1.submod')

        # Tests no uncalled modules
        self.assertEqual(analyzer.uncalled_modules(), set())

    def test_recursive_scope(self) -> None:
        """Tests that an op is only counted once per module, even if it is in
        the scope of that module multiple times."""
        model = RecursiveScopeNet()
        inputs = (torch.randn((1, *model.input_size)), )

        analyzer = FlopAnalyzer(model, inputs)

        self.assertEqual(analyzer.total(), model.flops)
        self.assertEqual(analyzer.total('fc'), model.flops)

        # Tests no uncalled modules
        self.assertEqual(analyzer.uncalled_modules(), set())

    def test_data_parallel(self) -> None:
        """Tests that a model wrapped in DataParallel still returns results
        labeled by the correct scopes."""
        model = NestedNet(lin_op=self.lin_op)
        inputs = (torch.randn((1, *model.input_size)), )

        # Find flops for wrapper
        flops = {
            'module' + ('.' if name else '') + name: flop
            for name, flop in model.flops.items()
        }
        flops[''] = model.flops['']
        name_to_module = {
            'module' + ('.' if name else '') + name: mod
            for name, mod in model.name_to_module.items()
        }
        name_to_module[''] = model.name_to_module['']

        model = torch.nn.DataParallel(model).cpu()
        analyzer = FlopAnalyzer(model=model, inputs=inputs)
        analyzer.unsupported_ops_warnings(enabled=False)

        # Using a string input
        for name in flops:
            with self.subTest(name=name):
                gt_flops = sum(flops[name].values())
                self.assertEqual(analyzer.total(name), gt_flops)

        # Output as dictionary
        self.assertEqual(analyzer.by_module_and_operator(), flops)

        # Test no uncalled modules
        self.assertEqual(analyzer.uncalled_modules(), set())

    def test_data_parallel_root_scope(self) -> None:
        # A test case discussed in D32227000
        model = nn.DataParallel(nn.Linear(10, 10)).cpu()
        for mode in ['caller', 'owner']:
            flop = FlopAnalyzer(model, (torch.randn(10, 10), ))
            flop.ancestor_mode(mode)
            self.assertEqual(flop.total(), 1000)

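    # Added arithmetic note: nn.Linear(10, 10) applied to a (10, 10) input is
    # counted as 10 samples * 10 in_features * 10 out_features = 1000 FLOPs,
    # which is the total asserted above for both ancestor modes.
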
    def test_unsupported_ops(self) -> None:
        """Tests per-module recording of unsupported operations."""

        model = NestedNet(lin_op=self.lin_op)
        inputs = (torch.randn((1, *model.input_size)), )

        analyzer = JitModelAnalysis(
            model=model, inputs=inputs).set_op_handle(
                'aten::addmm',
                addmm_flop_jit,
                'aten::linear',
                linear_flop_jit,
            )
        analyzer.total()

        skipped_inner_conv = Counter({'aten::_convolution': 1})
        skipped_inner_fc = Counter()  # type: Counter
        skipped_inner = Counter({'aten::add': 1, 'aten::mul': 1})
        skipped_inner += skipped_inner_fc
        skipped_inner += skipped_inner_conv

        skipped_outer_conv = Counter({'aten::_convolution': 1})
        skipped_outer_fc = Counter()  # type: Counter
        skipped_outer = Counter({'aten::pow': 1})
        skipped_outer += skipped_outer_conv
        skipped_outer += skipped_outer_fc
        skipped_outer += skipped_inner

        skipped = {
            '': skipped_outer,
            'conv': skipped_outer_conv,
            'fc': skipped_outer_fc,
            'submod': skipped_inner,
            'submod.conv': skipped_inner_conv,
            'submod.fc': skipped_inner_fc,
        }

        # Access by string
        for name in skipped:
            with self.subTest(name=name):
                self.assertEqual(analyzer.unsupported_ops(name), skipped[name])

    def test_changing_handles(self) -> None:
        """Tests .set_op_handle(), .clear_op_handles()"""
        model = NestedNet(lin_op=self.lin_op)
        inputs = (torch.randn((1, *model.input_size)), )
        op_handles: 'Dict[str, Handle]' = {
            'aten::addmm': addmm_flop_jit,
            'aten::linear': linear_flop_jit,
        }

        analyzer = JitModelAnalysis(
            model=model, inputs=inputs).set_op_handle(**op_handles)
        analyzer.unsupported_ops_warnings(enabled=False)

        # Request a result once to cache flop counts
        _ = analyzer.total('')

        # Add an op handle
        analyzer.set_op_handle('aten::_convolution', conv_flop_jit)

        self.assertEqual(analyzer.by_module_and_operator(), model.flops)

        # Overwrite an op handle
        def make_dummy_op(name: str, output: int) -> Handle:

            def dummy_ops_handle(inputs: List[Any],
                                 outputs: List[Any]) -> typing.Counter[str]:
                return Counter({name: output})

            return dummy_ops_handle

        dummy_name = 'dummy_op'
        dummy_out = 1000
        analyzer.set_op_handle(f'aten::{self.lin_op}',
                               make_dummy_op(dummy_name, dummy_out))

        dummy_flops = {}
        for name, counts in model.flops.items():
            dummy_flops[name] = Counter(
                {op: flop
                 for op, flop in counts.items() if op != self.lin_op})
        dummy_flops[''][dummy_name] = 2 * dummy_out
        dummy_flops['fc'][dummy_name] = dummy_out
        dummy_flops['submod'][dummy_name] = dummy_out
        dummy_flops['submod.fc'][dummy_name] = dummy_out

        self.assertEqual(analyzer.by_module_and_operator(), dummy_flops)

        # Clear ops handles
        analyzer.clear_op_handles()

        empty_flops = {name: Counter() for name in model.flops}  # type: Dict

        self.assertEqual(analyzer.by_module_and_operator(), empty_flops)

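    # Added note: as the test above relies on, op handles may be given as
    # alternating name/handle arguments or expanded from a dict, newly set or
    # cleared handles invalidate previously computed results (the analysis is
    # rerun on the next query), and clear_op_handles() leaves every module
    # with an empty Counter.
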
    def test_copy(self) -> None:
        """Tests .copy(...)"""

        model = RepeatedNet()
        inputs = (torch.randn((1, *model.input_size)), )

        analyzer = (
            JitModelAnalysis(model=model, inputs=inputs).set_op_handle(
                'aten::addmm',
                addmm_flop_jit,
                'aten::linear',
                linear_flop_jit,
            ).unsupported_ops_warnings(enabled=False).tracer_warnings(
                mode='none'))

        repeated_net_flops = model.fc1_num * model.fc1_flops
        repeated_net_flops += model.fc2_num * model.fc2_flops

        analyzer_copy = analyzer.copy()

        # Outputs are the same
        self.assertEqual(
            analyzer.by_module_and_operator(),
            analyzer_copy.by_module_and_operator(),
        )

        # Settings match
        self.assertEqual(
            analyzer._enable_warn_unsupported_ops,
            analyzer_copy._enable_warn_unsupported_ops,
        )
        self.assertEqual(
            analyzer._enable_warn_uncalled_mods,
            analyzer_copy._enable_warn_uncalled_mods,
        )
        self.assertEqual(analyzer._warn_trace, analyzer_copy._warn_trace)

        # Changing copy does not change original
        analyzer_copy.unsupported_ops_warnings(enabled=True)
        self.assertNotEqual(
            analyzer._enable_warn_unsupported_ops,
            analyzer_copy._enable_warn_unsupported_ops,
        )

        # Copy with new model and inputs
        new_model = NonForwardNet()
        bs = 5
        new_inputs = (torch.randn((bs, *new_model.input_size)), )
        analyzer_new = analyzer.copy(
            new_model=new_model, new_inputs=new_inputs)

        non_forward_flops = new_model.fc_flops + new_model.submod.fc_flops

        # Total is correct for new model and inputs
        self.assertEqual(analyzer_new.total(), non_forward_flops * bs)

        # Original is unaffected
        self.assertEqual(analyzer.total(), repeated_net_flops)

        # Settings match
        self.assertEqual(
            analyzer._enable_warn_unsupported_ops,
            analyzer_new._enable_warn_unsupported_ops,
        )
        self.assertEqual(analyzer._warn_trace, analyzer_new._warn_trace)

    def test_disable_warnings(self) -> None:
        """Tests .unsupported_ops_warnings(...) and .tracer_warnings(...)"""
        model = TraceWarningNet()
        inputs = (torch.randn((1, *model.input_size)), )
        analyzer = FlopAnalyzer(model=model, inputs=inputs)

        # Tracer warnings
        analyzer.tracer_warnings(mode='all')
        analyzer._stats = None  # Manually clear cache so trace is rerun
        self.assertWarns(torch.jit.TracerWarning, analyzer.total)
        analyzer._stats = None  # Manually clear cache so trace is rerun
        self.assertWarns(RuntimeWarning, analyzer.total)

        analyzer.tracer_warnings(mode='none')
        analyzer._stats = None  # Manually clear cache so trace is rerun
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter('always')
            _ = analyzer.total()
        if w:
            warning_types = [s.category for s in w]
            self.assertFalse(torch.jit.TracerWarning in warning_types)
            self.assertFalse(RuntimeWarning in warning_types)

        analyzer.tracer_warnings(mode='no_tracer_warning')
        analyzer._stats = None  # Manually clear cache so trace is rerun
        self.assertWarns(RuntimeWarning, analyzer.total)
        analyzer._stats = None  # Manually clear cache so trace is rerun
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter('always')
            _ = analyzer.total()
        if w:
            warning_types = [s.category for s in w]
            self.assertFalse(torch.jit.TracerWarning in warning_types)

        # Unsupported ops and uncalled modules warnings

        logger = MMLogger.get_current_instance()
        skipped_msg = 'Unsupported operator aten::add encountered 1 time(s)'
        uncalled_msg = 'never called'
        uncalled_modules = 'fc1'  # fc2 is called by chance

        analyzer.uncalled_modules_warnings(enabled=False)
        analyzer.unsupported_ops_warnings(enabled=False)
        analyzer._stats = None  # Manually clear cache so trace is rerun
        with self.assertLogs(logger, logging.WARN) as cm:
            logger.warning('Dummy warning.')
            _ = analyzer.total()
        self.assertFalse(any(skipped_msg in s for s in cm.output))
        self.assertFalse(any(uncalled_msg in s for s in cm.output))

        analyzer.unsupported_ops_warnings(enabled=True)
        analyzer.uncalled_modules_warnings(enabled=True)
        analyzer._stats = None  # Manually clear cache so trace is rerun

        logger = MMLogger.get_current_instance()
        with self.assertLogs(logger, logging.WARN) as cm:
            _ = analyzer.total()
        self.assertTrue(any(skipped_msg in s for s in cm.output))
        self.assertTrue(any(uncalled_msg in s for s in cm.output))
        self.assertTrue(any(uncalled_modules in s for s in cm.output))

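    # Added note on the knobs exercised above: tracer_warnings() accepts the
    # modes 'all', 'none' and 'no_tracer_warning', while
    # unsupported_ops_warnings() and uncalled_modules_warnings() toggle the
    # messages emitted through MMLogger. Resetting the private _stats cache
    # is done here only so that each query re-runs the trace and the warnings
    # can fire again.
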
    def test_skip_uncalled_containers_warnings(self) -> None:
        # uncalled containers should not warn

        class A(nn.Module):

            def forward(self, x):
                return self.submod[0](x) + 1

        mod = A()
        mod.submod = nn.ModuleList([nn.Linear(3, 3)])  # pyre-ignore
        analyzer = FlopAnalyzer(model=mod, inputs=torch.rand(1, 3))
        analyzer.unsupported_ops_warnings(enabled=False)

        logger = MMLogger.get_current_instance()
        with self.assertLogs(logger, logging.WARN) as cm:
            logger.warning('Dummy warning.')
            _ = analyzer.total()
        uncalled_string = 'Module never called: submod'
        self.assertFalse(any(uncalled_string in s for s in cm.output))