# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Modified from
# https://github.com/facebookresearch/fvcore/blob/main/tests/test_jit_model_analysis.py

# pyre-ignore-all-errors[2,56]

import logging
import typing
import unittest
import warnings
from collections import Counter
from typing import Any, Dict, List

import torch
import torch.nn as nn

from mmengine import MMLogger
from mmengine.analysis import FlopAnalyzer
from mmengine.analysis.jit_analysis import JitModelAnalysis
from mmengine.analysis.jit_handles import (Handle, addmm_flop_jit,
                                           conv_flop_jit, linear_flop_jit)


class NestedNetInnerModule(nn.Module):
    """A submodule for the nested net test module below.

    Besides the layers, it records the expected per-operator counts in
    ``self.flops`` and a name -> module map in ``self.name_to_module``.
    """

    def __init__(self, lin_op: str = 'addmm') -> None:
        super().__init__()
        conv_input_size = (2, 5)
        conv_in = 2
        conv_out = 2
        kernel_size = 1
        padding = 0
        fc_in = 10
        fc_out = 10

        self.conv = nn.Conv1d(
            in_channels=conv_in,
            out_channels=conv_out,
            kernel_size=kernel_size,
            padding=padding,
        )
        self.fc = nn.Linear(in_features=fc_in, out_features=fc_out)

        # Expected linear count: in_features * out_features (batch size 1).
        fc_flops_ = fc_in * fc_out
        fc_flops = Counter({lin_op: fc_flops_})
        # Expected conv count: output length * kernel * in_ch * out_ch.
        spatial_pos = (conv_input_size[1] + 2 * padding) - 2 * (
            kernel_size // 2)
        conv_flops_ = spatial_pos * kernel_size * conv_in * conv_out
        conv_flops = Counter({'conv': conv_flops_})
        model_flops = conv_flops + fc_flops
        self.flops: 'Dict[str, typing.Counter[str]]' = {
            '': model_flops,
            'fc': fc_flops,
            'conv': conv_flops,
        }

        self.name_to_module: 'Dict[str, nn.Module]' = {
            '': self,
            'fc': self.fc,
            'conv': self.conv,
        }

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Reshape the incoming features to (N, conv_in, length) = (N, 2, 5).
        x = x.reshape(-1, 2, 5)
        x = self.conv(x)
        x = torch.flatten(x, 1)
        x = 3 * self.fc(x) + 1
        return x


class NestedNet(nn.Module):
    """A network with nested submodules for testing the ability to correctly
    capture scope information.

    ``self.flops`` holds the expected per-operator counts for every scope,
    including those of the nested ``submod``.
    """

    def __init__(self, lin_op: str = 'addmm') -> None:
        super().__init__()
        self.input_size = (4, 5)

        conv_in = 4
        conv_out = 4
        kernel_size = 3
        padding = 1
        # conv keeps the length at 5, so flatten gives conv_out * 5 = 20.
        fc_in = 20
        # fc_out must be 10 so the inner module can reshape it to (2, 5).
        fc_out = 10

        self.submod = NestedNetInnerModule(lin_op)
        self.fc = nn.Linear(in_features=fc_in, out_features=fc_out)
        self.conv = nn.Conv1d(
            in_channels=conv_in,
            out_channels=conv_out,
            kernel_size=kernel_size,
            padding=padding,
        )

        fc_flops_ = fc_in * fc_out
        fc_flops = Counter({lin_op: fc_flops_})
        spatial_pos = (self.input_size[1] + 2 * padding) - 2 * (
            kernel_size // 2)
        conv_flops_ = spatial_pos * kernel_size * conv_in * conv_out
        conv_flops = Counter({'conv': conv_flops_})

        model_flops = conv_flops + fc_flops + self.submod.flops['']
        self.flops: 'Dict[str, typing.Counter[str]]' = {
            '': model_flops,
            'fc': fc_flops,
            'conv': conv_flops,
            'submod': self.submod.flops[''],
            'submod.fc': self.submod.flops['fc'],
            'submod.conv': self.submod.flops['conv'],
        }

        self.name_to_module: 'Dict[str, nn.Module]' = {
            '': self,
            'fc': self.fc,
            'conv': self.conv,
            'submod': self.submod,
            'submod.fc': self.submod.name_to_module['fc'],
            'submod.conv': self.submod.name_to_module['conv'],
        }

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        x = self.submod(x)**2
        return x


class UnusedNet(nn.Module):
    """Has a submodule that is never called in the forward function."""

    def __init__(self) -> None:
        super().__init__()
        self.input_size = (10, )
        fc1_in, fc1_out = 10, 10
        fc2_in, fc2_out = 10, 1
        unused_in, unused_out = 20, 20

        self.fc1 = nn.Linear(in_features=fc1_in, out_features=fc1_out)
        self.fc2 = nn.Linear(in_features=fc2_in, out_features=fc2_out)
        self.unused = nn.Linear(
            in_features=unused_in, out_features=unused_out)
        self.act: 'nn.Module' = nn.ReLU()

        self.fc1_flops: int = fc1_in * fc1_out
        self.fc2_flops: int = fc2_in * fc2_out
        self.unused_flops: int = unused_in * unused_out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))


class RepeatedNet(nn.Module):
    """Makes repeated calls to the same submodule."""

    def __init__(self) -> None:
        super().__init__()
        self.input_size = (10, )
        fc1_in, fc1_out = 10, 10
        fc2_in, fc2_out = 10, 10
        # Number of times fc1 / fc2 are applied in forward().
        self.fc1_num = 3
        self.fc2_num = 2

        self.fc1 = nn.Linear(in_features=fc1_in, out_features=fc1_out)
        self.fc2 = nn.Linear(in_features=fc2_in, out_features=fc2_out)

        self.fc1_flops: int = fc1_in * fc1_out
        self.fc2_flops: int = fc2_in * fc2_out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for _i in range(self.fc1_num):
            x = self.fc1(x)
        for _i in range(self.fc2_num):
            x = self.fc2(x)
        return x


class NonForwardInnerModule(nn.Module):
    """Has a function separate from the forward function."""

    def __init__(self) -> None:
        super().__init__()
        self.input_size = (10, )
        fc_in, fc_out = 10, 1

        self.fc = nn.Linear(in_features=fc_in, out_features=fc_out)

        self.fc_flops: int = fc_in * fc_out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

    def other_func(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


class NonForwardNet(nn.Module):
    """The submodule has a non-forward function called by the parent
    module."""

    def __init__(self) -> None:
        super().__init__()
        self.input_size = (10, )
        fc_in, fc_out = 10, 10

        self.submod = NonForwardInnerModule()
        self.fc = nn.Linear(in_features=fc_in, out_features=fc_out)

        self.fc_flops: int = fc_in * fc_out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Bypasses submod.forward(), so `submod` itself is never "called".
        return self.submod.other_func(self.fc(x))


class SharedInnerModule(nn.Module):
    """Is initialized with a module that it may share with other modules."""

    def __init__(self, submod: nn.Module) -> None:
        super().__init__()
        self.submod = submod

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.submod(x)


class SharedModuleNet(nn.Module):
    """A subsubmodule is shared by multiple submodules.

    Also calls a module using multiple names.
    """

    def __init__(self) -> None:
        super().__init__()
        self.input_size = (10, )
        fc1_in, fc1_out = 10, 10
        fc2_in, fc2_out = 10, 1

        # The same Linear instance is reachable as submod1.submod and
        # submod2.submod.
        inner = nn.Linear(in_features=fc1_in, out_features=fc1_out)
        self.submod1 = SharedInnerModule(inner)
        self.submod2 = SharedInnerModule(inner)

        # The same Linear instance is registered under two attribute names.
        multiname = nn.Linear(in_features=fc2_in, out_features=fc2_out)
        self.multiname1: 'nn.Module' = multiname
        self.multiname2: 'nn.Module' = multiname

        self.multiname_flops: int = fc2_in * fc2_out
        self.shared_flops: int = fc1_in * fc1_out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.submod1(x) + self.submod2(x)
        x = self.multiname1(x) + self.multiname2(x)
        return x


class RecursiveScopeNet(nn.Module):
    """An op is in the same module's scope multiple times."""

    def __init__(self) -> None:
        super().__init__()
        self.input_size = (10, )
        fc_in, fc_out = 10, 1

        self.fc = nn.Linear(in_features=fc_in, out_features=fc_out)

        self.flops: int = fc_in * fc_out

    def forward(self, x: torch.Tensor, count: int = 3) -> torch.Tensor:
        # Re-enter the root module `count` times before applying fc once.
        if count > 0:
            return self(x, count - 1)
        return self.fc(x)


class TraceWarningNet(nn.Module):
    """Will raise a warning on trace due to python comparison of tensor data,
    and explicitly raises a runtime warning.

    Also has an aten::add op that will be skipped and raise a warning.
    """

    def __init__(self) -> None:
        super().__init__()
        self.input_size = (10, )
        fc1_in, fc1_out = 10, 1
        fc2_in, fc2_out = 10, 10

        self.fc1 = nn.Linear(in_features=fc1_in, out_features=fc1_out)
        self.fc2 = nn.Linear(in_features=fc2_in, out_features=fc2_out)

        self.fc1_flops: int = fc1_in * fc1_out
        self.fc2_flops: int = fc2_in * fc2_out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.fc1(x).item()
        warnings.warn('Dummy RuntimeWarning.', RuntimeWarning)
        # Data-dependent Python branch: fc2 only runs when y is negative.
        if y < 0.0:
            x = self.fc2(x)
        return x + 2


class TestJitModelAnalysis(unittest.TestCase):
    """Unittest for JitModelAnalysis.

    Tests for specific jit_handles are covered in test_flop_count.py and
    test_activation_count.py.
    """

    def setUp(self) -> None:
        # nn.Linear uses a different operator based on version, so make sure
        # we are testing the right thing.
        lin = nn.Linear(10, 10)
        lin_x: torch.Tensor = torch.randn(10, 10)
        trace = torch.jit.trace(lin, (lin_x, ))
        node_kinds = [node.kind() for node in trace.graph.nodes()]
        # A unittest assertion (unlike a bare ``assert``) is not stripped
        # when Python runs with ``-O``.
        self.assertTrue(
            'aten::addmm' in node_kinds or 'aten::linear' in node_kinds,
            msg=f'Unexpected ops in traced nn.Linear: {node_kinds}')
        if 'aten::addmm' in node_kinds:
            self.lin_op = 'addmm'
        else:
            self.lin_op = 'linear'

    def test_total(self) -> None:
        """Tests that JitModelAnalysis.total(module_name) returns the correct
        counts for every module name of the model."""
        model = NestedNet(lin_op=self.lin_op)
        inputs = (torch.randn((1, *model.input_size)), )

        analyzer = FlopAnalyzer(model=model, inputs=inputs)
        analyzer.unsupported_ops_warnings(enabled=False)

        # Using a string input
        for name in model.flops:
            with self.subTest(name=name):
                gt_flops = sum(model.flops[name].values())
                self.assertEqual(analyzer.total(name), gt_flops)

    def test_by_module(self) -> None:
        """Tests that JitModelAnalysis.by_module() returns the correct counts
        in the correctly structured dictionary."""
        model = NestedNet(lin_op=self.lin_op)
        inputs = (torch.randn((1, *model.input_size)), )

        analyzer = FlopAnalyzer(model=model, inputs=inputs)
        analyzer.unsupported_ops_warnings(enabled=False)

        flops = {
            name: sum(counts.values())
            for name, counts in model.flops.items()
        }

        self.assertEqual(analyzer.by_module(), flops)

    def test_by_operator(self) -> None:
        """Tests that JitModelAnalysis.by_operator(module_name) returns the
        correct per-operator counts for every module name of the model."""
        model = NestedNet(lin_op=self.lin_op)
        inputs = (torch.randn((1, *model.input_size)), )

        analyzer = FlopAnalyzer(model=model, inputs=inputs)
        analyzer.unsupported_ops_warnings(enabled=False)

        # Using a string input
        for name in model.flops:
            with self.subTest(name=name):
                self.assertEqual(analyzer.by_operator(name), model.flops[name])

    def test_by_module_and_operator(self) -> None:
        """Tests that JitModelAnalysis.by_module_and_operator() returns the
        correct counts in the correct structure."""
        model = NestedNet(lin_op=self.lin_op)
        inputs = (torch.randn((1, *model.input_size)), )

        analyzer = FlopAnalyzer(model=model, inputs=inputs)
        analyzer.unsupported_ops_warnings(enabled=False)

        self.assertEqual(analyzer.by_module_and_operator(), model.flops)

    def test_unused_module(self) -> None:
        """Tests that unused modules return 0 count for operator sums and an
        empty Counter() for per-operator results.

        Also tests that unused modules are reported by .uncalled_modules(), but
        that modules that simply have zero flops (like ReLU) are not.
        """
        model = UnusedNet()
        inputs = (torch.randn((1, *model.input_size)), )
        analyzer = FlopAnalyzer(model=model, inputs=inputs)

        unused_count = 0
        unused_per_operator = Counter()  # type: Counter
        model_count = model.fc1_flops + model.fc2_flops

        self.assertEqual(analyzer.total('unused'), unused_count)
        self.assertEqual(analyzer.by_operator('unused'), unused_per_operator)
        self.assertEqual(analyzer.total(''), model_count)

        # The unused mod is recognized as never called
        self.assertEqual(analyzer.uncalled_modules(), {'unused'})

    def test_repeated_module(self) -> None:
        """Tests that repeated calls to the same submodule correctly aggregate
        results to that submodule."""
        model = RepeatedNet()
        inputs = (torch.randn((1, *model.input_size)), )

        analyzer = FlopAnalyzer(model=model, inputs=inputs)
        fc1_count = model.fc1_num * model.fc1_flops
        fc2_count = model.fc2_num * model.fc2_flops
        total_count = fc1_count + fc2_count
        fc1_per_operator = Counter({self.lin_op: fc1_count})

        self.assertEqual(analyzer.total('fc1'), fc1_count)
        self.assertEqual(analyzer.total('fc2'), fc2_count)
        self.assertEqual(analyzer.total(''), total_count)
        self.assertEqual(analyzer.by_operator('fc1'), fc1_per_operator)

        # Tests no uncalled mods
        self.assertEqual(analyzer.uncalled_modules(), set())

    def test_non_forward_func_call(self) -> None:
        """Tests calls to a submodule's non-forward function.

        Also tests that the intermediate module is correctly identified as a
        skipped module.
        """
        model = NonForwardNet()
        inputs = (torch.randn((1, *model.input_size)), )
        analyzer = FlopAnalyzer(
            model=model, inputs=inputs).ancestor_mode('caller')

        inner_fc_count = model.submod.fc_flops
        total_count = model.fc_flops + inner_fc_count

        self.assertEqual(analyzer.total('submod'), 0)
        self.assertEqual(analyzer.total('submod.fc'), inner_fc_count)
        self.assertEqual(analyzer.total(''), total_count)

        # The mod not directly called is registered as such
        self.assertEqual(analyzer.uncalled_modules(), {'submod'})

        analyzer = FlopAnalyzer(
            model=model, inputs=inputs).ancestor_mode('owner')
        self.assertEqual(analyzer.total('submod'), inner_fc_count)
        self.assertEqual(analyzer.total('submod.fc'), inner_fc_count)
        self.assertEqual(analyzer.total(''), total_count)
        self.assertEqual(analyzer.uncalled_modules(), set())

    def test_shared_module(self) -> None:
        """Tests the behavior of shared submodules that may have multiple
        names."""
        model = SharedModuleNet()
        inputs = (torch.randn((1, *model.input_size)), )

        analyzer = (
            FlopAnalyzer(model=model, inputs=inputs).unsupported_ops_warnings(
                enabled=False).ancestor_mode('caller'))

        # The names `submod2.submod` and `multiname2` are not included,
        # since only the first name of a module is made the canonical one.
        # The counts associated with these cases are included under
        # `submod1.submod` and `multiname1` respectively.
        multiname_flops = 2 * model.multiname_flops  # Called under 2 names
        shared_flops = 2 * model.shared_flops  # Shared under 2 submodules
        total_flops = multiname_flops + shared_flops
        flops = {
            '': total_flops,
            'submod1': model.shared_flops,
            'submod1.submod': shared_flops,
            'submod2': model.shared_flops,
            'multiname1': multiname_flops,
        }

        self.assertEqual(analyzer.by_module(), flops)

        # Test access by alternative name
        self.assertEqual(
            analyzer.total('submod2.submod'),
            flops['submod1.submod'],
        )
        self.assertEqual(
            analyzer.total('multiname2'),
            flops['multiname1'],
        )

        # Test getting canonical name
        self.assertEqual(
            analyzer.canonical_module_name('multiname2'), 'multiname1')
        self.assertEqual(
            analyzer.canonical_module_name('multiname1'), 'multiname1')
        self.assertEqual(
            analyzer.canonical_module_name('submod2.submod'), 'submod1.submod')
        self.assertEqual(
            analyzer.canonical_module_name('submod1.submod'), 'submod1.submod')

        # Tests no uncalled modules
        self.assertEqual(analyzer.uncalled_modules(), set())

    def test_recursive_scope(self) -> None:
        """Tests that an op is only counted once per module, even if it is in
        the scope of that module multiple times."""
        model = RecursiveScopeNet()
        inputs = (torch.randn((1, *model.input_size)), )

        analyzer = FlopAnalyzer(model, inputs)

        self.assertEqual(analyzer.total(), model.flops)
        self.assertEqual(analyzer.total('fc'), model.flops)

        # Tests no uncalled modules
        self.assertEqual(analyzer.uncalled_modules(), set())

    def test_data_parallel(self) -> None:
        """Tests that a model wrapped in DataParallel still returns results
        labeled by the correct scopes."""
        model = NestedNet(lin_op=self.lin_op)
        inputs = (torch.randn((1, *model.input_size)), )

        # Find flops for wrapper: every scope moves under the `module.`
        # prefix, while the root scope '' keeps the whole-model counts.
        flops = {
            'module' + ('.' if name else '') + name: flop
            for name, flop in model.flops.items()
        }
        flops[''] = model.flops['']

        model = torch.nn.DataParallel(model).cpu()
        analyzer = FlopAnalyzer(model=model, inputs=inputs)
        analyzer.unsupported_ops_warnings(enabled=False)

        # Using a string input
        for name in flops:
            with self.subTest(name=name):
                gt_flops = sum(flops[name].values())
                self.assertEqual(analyzer.total(name), gt_flops)

        # Output as dictionary
        self.assertEqual(analyzer.by_module_and_operator(), flops)

        # Test no uncalled modules
        self.assertEqual(analyzer.uncalled_modules(), set())

    def test_data_parallel_root_scope(self) -> None:
        # A test case discussed in D32227000
        model = nn.DataParallel(nn.Linear(10, 10)).cpu()
        for mode in ['caller', 'owner']:
            flop = FlopAnalyzer(model, (torch.randn(10, 10), ))
            flop.ancestor_mode(mode)
            # batch 10 * in_features 10 * out_features 10
            self.assertEqual(flop.total(), 1000)

    def test_unsupported_ops(self) -> None:
        """Tests per-module recording of unsupported operations."""
        model = NestedNet(lin_op=self.lin_op)
        inputs = (torch.randn((1, *model.input_size)), )

        analyzer = JitModelAnalysis(
            model=model, inputs=inputs).set_op_handle(
                'aten::addmm',
                addmm_flop_jit,
                'aten::linear',
                linear_flop_jit,
            )
        analyzer.total()

        # Only the linear handles are registered, so convolutions and the
        # elementwise ops are expected to be reported as unsupported.
        skipped_inner_conv = Counter({'aten::_convolution': 1})
        skipped_inner_fc = Counter()  # type: Counter
        skipped_inner = Counter({'aten::add': 1, 'aten::mul': 1})
        skipped_inner += skipped_inner_fc
        skipped_inner += skipped_inner_conv

        skipped_outer_conv = Counter({'aten::_convolution': 1})
        skipped_outer_fc = Counter()  # type: Counter
        skipped_outer = Counter({'aten::pow': 1})
        skipped_outer += skipped_outer_conv
        skipped_outer += skipped_outer_fc
        skipped_outer += skipped_inner

        skipped = {
            '': skipped_outer,
            'conv': skipped_outer_conv,
            'fc': skipped_outer_fc,
            'submod': skipped_inner,
            'submod.conv': skipped_inner_conv,
            'submod.fc': skipped_inner_fc,
        }

        # Access by string
        for name in skipped:
            with self.subTest(name=name):
                self.assertEqual(analyzer.unsupported_ops(name), skipped[name])

    def test_changing_handles(self) -> None:
        """Tests .set_op_handle(), .clear_op_handles()"""
        model = NestedNet(lin_op=self.lin_op)
        inputs = (torch.randn((1, *model.input_size)), )
        op_handles: 'Dict[str, Handle]' = {
            'aten::addmm': addmm_flop_jit,
            'aten::linear': linear_flop_jit,
        }

        analyzer = JitModelAnalysis(
            model=model, inputs=inputs).set_op_handle(**op_handles)
        analyzer.unsupported_ops_warnings(enabled=False)

        # Request a result once to cache flop counts
        _ = analyzer.total('')

        # Add an op handle
        analyzer.set_op_handle('aten::_convolution', conv_flop_jit)

        self.assertEqual(analyzer.by_module_and_operator(), model.flops)

        # Overwrite an op handle
        def make_dummy_op(name: str, output: int) -> Handle:

            def dummy_ops_handle(inputs: List[Any],
                                 outputs: List[Any]) -> typing.Counter[str]:
                return Counter({name: output})

            return dummy_ops_handle

        dummy_name = 'dummy_op'
        dummy_out = 1000
        analyzer.set_op_handle(f'aten::{self.lin_op}',
                               make_dummy_op(dummy_name, dummy_out))

        # Fresh Counters, so the model's own expected counts are not mutated.
        dummy_flops = {}
        for name, counts in model.flops.items():
            dummy_flops[name] = Counter(
                {op: flop
                 for op, flop in counts.items() if op != self.lin_op})
        dummy_flops[''][dummy_name] = 2 * dummy_out
        dummy_flops['fc'][dummy_name] = dummy_out
        dummy_flops['submod'][dummy_name] = dummy_out
        dummy_flops['submod.fc'][dummy_name] = dummy_out

        self.assertEqual(analyzer.by_module_and_operator(), dummy_flops)

        # Clear ops handles
        analyzer.clear_op_handles()

        empty_flops = {name: Counter() for name in model.flops}  # type: Dict

        self.assertEqual(analyzer.by_module_and_operator(), empty_flops)

    def test_copy(self) -> None:
        """Tests .copy(...)"""
        model = RepeatedNet()
        inputs = (torch.randn((1, *model.input_size)), )

        analyzer = (
            JitModelAnalysis(model=model, inputs=inputs).set_op_handle(
                'aten::addmm',
                addmm_flop_jit,
                'aten::linear',
                linear_flop_jit,
            ).unsupported_ops_warnings(enabled=False).tracer_warnings(
                mode='none'))

        repeated_net_flops = model.fc1_num * model.fc1_flops
        repeated_net_flops += model.fc2_num * model.fc2_flops

        analyzer_copy = analyzer.copy()

        # Outputs are the same
        self.assertEqual(
            analyzer.by_module_and_operator(),
            analyzer_copy.by_module_and_operator(),
        )

        # Settings match
        self.assertEqual(
            analyzer._enable_warn_unsupported_ops,
            analyzer_copy._enable_warn_unsupported_ops,
        )
        self.assertEqual(
            analyzer._enable_warn_uncalled_mods,
            analyzer_copy._enable_warn_uncalled_mods,
        )
        self.assertEqual(analyzer._warn_trace, analyzer_copy._warn_trace)

        # Changing copy does not change original
        analyzer_copy.unsupported_ops_warnings(enabled=True)
        self.assertNotEqual(
            analyzer._enable_warn_unsupported_ops,
            analyzer_copy._enable_warn_unsupported_ops,
        )

        # Copy with new model and inputs
        new_model = NonForwardNet()
        bs = 5
        new_inputs = (torch.randn((bs, *new_model.input_size)), )
        analyzer_new = analyzer.copy(
            new_model=new_model, new_inputs=new_inputs)

        non_forward_flops = new_model.fc_flops + new_model.submod.fc_flops

        # Total is correct for new model and inputs
        self.assertEqual(analyzer_new.total(), non_forward_flops * bs)

        # Original is unaffected
        self.assertEqual(analyzer.total(), repeated_net_flops)

        # Settings match
        self.assertEqual(
            analyzer._enable_warn_unsupported_ops,
            analyzer_new._enable_warn_unsupported_ops,
        )
        self.assertEqual(analyzer._warn_trace, analyzer_new._warn_trace)

    def test_disable_warnings(self) -> None:
        """Tests .unsupported_ops_warnings(...) and .tracer_warnings(...)"""
        model = TraceWarningNet()
        inputs = (torch.randn((1, *model.input_size)), )
        analyzer = FlopAnalyzer(model=model, inputs=inputs)

        # Tracer warnings
        analyzer.tracer_warnings(mode='all')
        analyzer._stats = None  # Manually clear cache so trace is rerun
        self.assertWarns(torch.jit.TracerWarning, analyzer.total)
        analyzer._stats = None  # Manually clear cache so trace is rerun
        self.assertWarns(RuntimeWarning, analyzer.total)

        analyzer.tracer_warnings(mode='none')
        analyzer._stats = None  # Manually clear cache so trace is rerun
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter('always')
            _ = analyzer.total()
            warning_types = [s.category for s in w]
            self.assertNotIn(torch.jit.TracerWarning, warning_types)
            self.assertNotIn(RuntimeWarning, warning_types)

        analyzer.tracer_warnings(mode='no_tracer_warning')
        analyzer._stats = None  # Manually clear cache so trace is rerun
        self.assertWarns(RuntimeWarning, analyzer.total)
        analyzer._stats = None  # Manually clear cache so trace is rerun
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter('always')
            _ = analyzer.total()
            warning_types = [s.category for s in w]
            self.assertNotIn(torch.jit.TracerWarning, warning_types)

        # Unsupported ops and uncalled modules warnings

        logger = MMLogger.get_current_instance()
        skipped_msg = 'Unsupported operator aten::add encountered 1 time(s)'
        uncalled_msg = 'never called'
        # fc2 is called by chance (only when fc1's random output is
        # negative), so only fc1 is asserted. NOTE(review): fc1 is presumably
        # reported because its result leaves the graph only via ``.item()``.
        uncalled_modules = 'fc1'

        analyzer.uncalled_modules_warnings(enabled=False)
        analyzer.unsupported_ops_warnings(enabled=False)
        analyzer._stats = None  # Manually clear cache so trace is rerun
        with self.assertLogs(logger, logging.WARN) as cm:
            # assertLogs fails when nothing is logged, so emit a dummy record
            # to be able to assert that the analyzer itself stayed silent.
            logger.warning('Dummy warning.')
            _ = analyzer.total()
        self.assertFalse(any(skipped_msg in s for s in cm.output))
        self.assertFalse(any(uncalled_msg in s for s in cm.output))

        analyzer.unsupported_ops_warnings(enabled=True)
        analyzer.uncalled_modules_warnings(enabled=True)
        analyzer._stats = None  # Manually clear cache so trace is rerun

        logger = MMLogger.get_current_instance()
        with self.assertLogs(logger, logging.WARN) as cm:
            _ = analyzer.total()
        self.assertTrue(any(skipped_msg in s for s in cm.output))
        self.assertTrue(any(uncalled_msg in s for s in cm.output))
        self.assertTrue(any(uncalled_modules in s for s in cm.output))

    def test_skip_uncalled_containers_warnings(self) -> None:
        """Tests that a container (here a ModuleList) whose child is called
        directly is not reported as a never-called module."""

        # uncalled containers should not warn

        class A(nn.Module):

            def forward(self, x):
                return self.submod[0](x) + 1

        mod = A()
        mod.submod = nn.ModuleList([nn.Linear(3, 3)])  # pyre-ignore
        analyzer = FlopAnalyzer(model=mod, inputs=torch.rand(1, 3))
        analyzer.unsupported_ops_warnings(enabled=False)

        logger = MMLogger.get_current_instance()
        with self.assertLogs(logger, logging.WARN) as cm:
            # Guarantees assertLogs sees at least one record.
            logger.warning('Dummy warning.')
            _ = analyzer.total()
        uncalled_string = 'Module never called: submod'
        self.assertFalse(any(uncalled_string in s for s in cm.output))