[Features]Support `MethodInputsRecorder` and `FunctionInputsRecorder` (#320)
* support MethodInputsRecorder and FunctionInputsRecorder * fix bugs that the model can not be pickled * WIP: add pytest for ema model * fix bugs in recorder and delivery when ema_hook is used * don't register the DummyDataset * fix pytestpull/304/head^2
parent
31052ea322
commit
972fd8e0c7
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from queue import Queue
|
||||
from collections import deque
|
||||
from typing import Callable
|
||||
|
||||
|
||||
|
@ -33,7 +33,7 @@ class DistillDelivery(metaclass=ABCMeta):
|
|||
def __init__(self, max_keep_data: int = 1) -> None:
|
||||
|
||||
self._override_data = False
|
||||
self.data_queue: Queue = Queue(maxsize=max_keep_data)
|
||||
self.data_queue: deque = deque([], maxlen=max_keep_data)
|
||||
self.max_keep_data = max_keep_data
|
||||
|
||||
@property
|
||||
|
|
|
@ -78,23 +78,7 @@ class FunctionOutputsDelivery(DistillDelivery):
|
|||
super().__init__(max_keep_data)
|
||||
|
||||
self._check_valid_path(func_path)
|
||||
module_path = self._get_module_path(func_path)
|
||||
try:
|
||||
module = import_modules_from_strings(module_path)
|
||||
except ImportError:
|
||||
raise ImportError(f'{module_path} is not imported correctly.')
|
||||
self.module = module
|
||||
|
||||
func_name = self._get_func_name(func_path)
|
||||
assert hasattr(module, func_name), \
|
||||
f'{func_name} is not in {module_path}.'
|
||||
self.func_name = func_name
|
||||
|
||||
origin_func = getattr(module, func_name)
|
||||
if not isinstance(origin_func, FunctionType):
|
||||
raise TypeError(f'{func_name} should be a FunctionType '
|
||||
f'instance, but got {type(origin_func)}')
|
||||
self.origin_func = origin_func
|
||||
self.func_path = func_path
|
||||
|
||||
@staticmethod
|
||||
def _check_valid_path(func_path: str) -> None:
|
||||
|
@ -121,6 +105,24 @@ class FunctionOutputsDelivery(DistillDelivery):
|
|||
|
||||
Wrap the origin function.
|
||||
"""
|
||||
module_path = self._get_module_path(self.func_path)
|
||||
try:
|
||||
module = import_modules_from_strings(module_path)
|
||||
except ImportError:
|
||||
raise ImportError(f'{module_path} is not imported correctly.')
|
||||
self.module = module
|
||||
|
||||
func_name = self._get_func_name(self.func_path)
|
||||
assert hasattr(module, func_name), \
|
||||
f'{func_name} is not in {module_path}.'
|
||||
self.func_name = func_name
|
||||
|
||||
origin_func = getattr(module, func_name)
|
||||
if not isinstance(origin_func, FunctionType):
|
||||
raise TypeError(f'{func_name} should be a FunctionType '
|
||||
f'instance, but got {type(origin_func)}')
|
||||
self.origin_func = origin_func
|
||||
|
||||
wrapped_func = self.deliver_wrapper(self.origin_func)
|
||||
setattr(self.module, self.func_name, wrapped_func)
|
||||
|
||||
|
@ -131,6 +133,11 @@ class FunctionOutputsDelivery(DistillDelivery):
|
|||
"""
|
||||
setattr(self.module, self.func_name, self.origin_func)
|
||||
|
||||
# self.module and self.origin_func can not be pickled.
|
||||
# Delete these two attributes to avoid errors when ema model is used.
|
||||
del self.module
|
||||
del self.origin_func
|
||||
|
||||
def deliver_wrapper(self, origin_func: Callable) -> Callable:
|
||||
"""Wrap the specific function to make the intermediate results of the
|
||||
model can be delivered."""
|
||||
|
@ -139,12 +146,13 @@ class FunctionOutputsDelivery(DistillDelivery):
|
|||
def wrap_func(*args, **kwargs):
|
||||
|
||||
if self.override_data:
|
||||
assert not self.data_queue.empty(), 'pop from an empty queue'
|
||||
outputs = self.data_queue.get()
|
||||
assert len(self.data_queue) > 0, 'pop from an empty queue'
|
||||
outputs = self.data_queue.popleft()
|
||||
else:
|
||||
assert not self.data_queue.full(), 'push into an full queue'
|
||||
assert len(self.data_queue) < self.data_queue.maxlen,\
|
||||
'push into an full queue'
|
||||
outputs = origin_func(*args, **kwargs)
|
||||
self.data_queue.put(outputs)
|
||||
self.data_queue.append(outputs)
|
||||
return outputs
|
||||
|
||||
return wrap_func
|
||||
|
|
|
@ -143,12 +143,13 @@ class MethodOutputsDelivery(DistillDelivery):
|
|||
def wrap_method(*args, **kwargs):
|
||||
|
||||
if self.override_data:
|
||||
assert not self.data_queue.empty(), 'pop from an empty queue'
|
||||
outputs = self.data_queue.get()
|
||||
assert len(self.data_queue) > 0, 'pop from an empty queue'
|
||||
outputs = self.data_queue.popleft()
|
||||
else:
|
||||
assert not self.data_queue.full(), 'push into an full queue'
|
||||
assert len(self.data_queue) < self.data_queue.maxlen,\
|
||||
'push into an full queue'
|
||||
outputs = origin_method(*args, **kwargs)
|
||||
self.data_queue.put(outputs)
|
||||
self.data_queue.append(outputs)
|
||||
return outputs
|
||||
|
||||
return wrap_method
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .function_inputs_recorder import FunctionInputsRecorder
|
||||
from .function_outputs_recorder import FunctionOutputsRecorder
|
||||
from .method_inputs_recorder import MethodInputsRecorder
|
||||
from .method_outputs_recorder import MethodOutputsRecorder
|
||||
from .module_inputs_recorder import ModuleInputsRecorder
|
||||
from .module_outputs_recorder import ModuleOutputsRecorder
|
||||
|
@ -9,5 +11,5 @@ from .recorder_manager import RecorderManager
|
|||
__all__ = [
|
||||
'FunctionOutputsRecorder', 'MethodOutputsRecorder',
|
||||
'ModuleOutputsRecorder', 'ParameterRecorder', 'RecorderManager',
|
||||
'ModuleInputsRecorder'
|
||||
'ModuleInputsRecorder', 'MethodInputsRecorder', 'FunctionInputsRecorder'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import functools
|
||||
from inspect import signature
|
||||
from typing import Callable, List
|
||||
|
||||
from mmrazor.registry import TASK_UTILS
|
||||
from .function_outputs_recorder import FunctionOutputsRecorder
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class FunctionInputsRecorder(FunctionOutputsRecorder):
|
||||
"""Recorder for intermediate results which are ``FunctionType``'s inputs.
|
||||
|
||||
Notes:
|
||||
The form of `source` needs special attention. For example,
|
||||
`anchor_inside_flags` is a function in mmdetection to check whether the
|
||||
anchors are inside the border. This function is in
|
||||
`mmdet/core/anchor/utils.py` and used in
|
||||
`mmdet/models/dense_heads/anchor_head.py`. Then the source should be
|
||||
`mmdet.models.dense_heads.anchor_head.anchor_inside_flags` but not
|
||||
`mmdet.core.anchor.utils.anchor_inside_flags`.
|
||||
|
||||
|
||||
Examples:
|
||||
>>> # Below code in toy_module.py
|
||||
>>> import random
|
||||
>>> def toy_func(a, b):
|
||||
... return a, b
|
||||
>>> def execute_toy_func(a, b):
|
||||
... toy_func(a, b)
|
||||
|
||||
>>> # Below code in main.py
|
||||
>>> # Now, we want to get teacher's inputs by recorder.
|
||||
|
||||
>>> from toy_module import execute_toy_func
|
||||
>>> r1 = FunctionInputsRecorder('toy_module.toy_func')
|
||||
>>> r1.initialize()
|
||||
>>> with r1:
|
||||
... execute_toy_func(1, 2)
|
||||
... execute_toy_func(1, b=2)
|
||||
... execute_toy_func(b=2, a=1)
|
||||
|
||||
>>> r1.data_buffer
|
||||
[[1, 2], [1, 2], [1, 2]]
|
||||
"""
|
||||
|
||||
def func_record_wrapper(self, origin_func: Callable,
|
||||
data_buffer: List) -> Callable:
|
||||
"""Save the function's inputs.
|
||||
|
||||
Args:
|
||||
origin_func (FunctionType): The method whose inputs need to be
|
||||
recorded.
|
||||
data_buffer (list): A list of data.
|
||||
"""
|
||||
|
||||
func_input_params = signature(origin_func).parameters.keys()
|
||||
|
||||
@functools.wraps(origin_func)
|
||||
def wrap_func(*args, **kwargs):
|
||||
outputs = origin_func(*args, **kwargs)
|
||||
inputs = list(args)
|
||||
for keyword in func_input_params:
|
||||
if keyword in kwargs:
|
||||
inputs.append(kwargs[keyword])
|
||||
# assume a func execute N times, there will be N inputs need to
|
||||
# save.
|
||||
data_buffer.append(inputs)
|
||||
return outputs
|
||||
|
||||
return wrap_func
|
|
@ -65,28 +65,8 @@ class FunctionOutputsRecorder(BaseRecorder):
|
|||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._check_valid_source(self.source)
|
||||
|
||||
# import the function corrosponding module
|
||||
try:
|
||||
mod = import_modules_from_strings(self.module_string)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
f'{self.module_string} is not imported correctly.')
|
||||
|
||||
self.imported_module: ModuleType = mod
|
||||
|
||||
assert hasattr(mod, self.func_name), \
|
||||
f'{self.func_name} is not in {self.module_string}.'
|
||||
|
||||
origin_func = getattr(mod, self.func_name)
|
||||
if not isinstance(origin_func, FunctionType):
|
||||
raise TypeError(f'{self.func_name} should be a FunctionType '
|
||||
f'instance, but got {type(origin_func)}')
|
||||
|
||||
self.origin_func: Callable = origin_func
|
||||
|
||||
@staticmethod
|
||||
def _check_valid_source(source):
|
||||
"""Check if the source's format is valid."""
|
||||
|
@ -118,8 +98,7 @@ class FunctionOutputsRecorder(BaseRecorder):
|
|||
Args:
|
||||
origin_func (FunctionType): The method whose outputs need to be
|
||||
recorded.
|
||||
buffer_key (str): The key of the function's outputs saved in
|
||||
``data_buffer``.
|
||||
data_buffer (list): A list of data.
|
||||
"""
|
||||
|
||||
@functools.wraps(origin_func)
|
||||
|
@ -136,8 +115,25 @@ class FunctionOutputsRecorder(BaseRecorder):
|
|||
"""Enter the context manager."""
|
||||
super().__enter__()
|
||||
|
||||
mod = self.imported_module
|
||||
origin_func = self.origin_func
|
||||
# import the function corrosponding module
|
||||
try:
|
||||
mod = import_modules_from_strings(self.module_string)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
f'{self.module_string} is not imported correctly.')
|
||||
|
||||
self.imported_module: ModuleType = mod
|
||||
|
||||
assert hasattr(mod, self.func_name), \
|
||||
f'{self.func_name} is not in {self.module_string}.'
|
||||
|
||||
origin_func = getattr(mod, self.func_name)
|
||||
if not isinstance(origin_func, FunctionType):
|
||||
raise TypeError(f'{self.func_name} should be a FunctionType '
|
||||
f'instance, but got {type(origin_func)}')
|
||||
|
||||
self.origin_func: Callable = origin_func
|
||||
|
||||
# add record wrapper to origin function.
|
||||
record_func = self.func_record_wrapper(origin_func, self.data_buffer)
|
||||
|
||||
|
@ -159,3 +155,8 @@ class FunctionOutputsRecorder(BaseRecorder):
|
|||
|
||||
# restore the origin function
|
||||
setattr(mod, self.func_name, origin_func)
|
||||
|
||||
# self.imported_module and self.origin_func can not be pickled.
|
||||
# Delete these two attributes to avoid errors when ema model is used.
|
||||
del self.imported_module
|
||||
del self.origin_func
|
||||
|
|
|
@ -0,0 +1,83 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import functools
|
||||
from inspect import signature
|
||||
from typing import Callable, List
|
||||
|
||||
from mmrazor.registry import TASK_UTILS
|
||||
from .method_outputs_recorder import MethodOutputsRecorder
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class MethodInputsRecorder(MethodOutputsRecorder):
|
||||
"""Recorder for intermediate results which are ``MethodType``'s inputs.
|
||||
|
||||
Note:
|
||||
Different from ``FunctionType``, ``MethodType`` is the type of methods
|
||||
of class instances.
|
||||
|
||||
Examples:
|
||||
>>> # Below code in toy_module.py
|
||||
>>> import random
|
||||
>>> class Toy():
|
||||
... def toy_func(self, x, y=0):
|
||||
... return x + y
|
||||
|
||||
>>> # Below code in main.py
|
||||
>>> # Now, we want to get teacher's inputs by recorder.
|
||||
|
||||
>>> from toy_module import Toy
|
||||
>>> toy = Toy()
|
||||
>>> r1 = MethodInputsRecorder('toy_module.Toy.toy_func')
|
||||
>>> r1.initialize()
|
||||
>>> with r1:
|
||||
... _ = toy.toy_func(1, 2)
|
||||
|
||||
>>> r1.data_buffer
|
||||
[[1, 2]]
|
||||
>>> r1.get_record_data(record_idx=0, data_idx=0)
|
||||
1
|
||||
>>> r1.get_record_data(record_idx=0, data_idx=1)
|
||||
2
|
||||
|
||||
>>> from toy_module import Toy
|
||||
>>> toy = Toy()
|
||||
>>> r1 = MethodInputsRecorder('toy_module.Toy.toy_func')
|
||||
>>> r1.initialize()
|
||||
>>> with r1:
|
||||
... _ = toy.toy_func(1, 2)
|
||||
... _ = toy.toy_func(y=2, x=1)
|
||||
|
||||
>>> r1.data_buffer
|
||||
[[1, 2], [1, 2]]
|
||||
>>> r1.get_record_data(record_idx=1, data_idx=0)
|
||||
1
|
||||
>>> r1.get_record_data(record_idx=1, data_idx=1)
|
||||
2
|
||||
"""
|
||||
|
||||
def method_record_wrapper(self, orgin_method: Callable,
|
||||
data_buffer: List) -> Callable:
|
||||
"""Save the method's inputs.
|
||||
|
||||
Args:
|
||||
origin_method (MethodType): The method whose inputs need to be
|
||||
recorded.
|
||||
data_buffer (list): A list of data.
|
||||
"""
|
||||
|
||||
method_input_params = signature(orgin_method).parameters.keys()
|
||||
|
||||
@functools.wraps(orgin_method)
|
||||
def wrap_method(*args, **kwargs):
|
||||
outputs = orgin_method(*args, **kwargs)
|
||||
# the first element of a class method is the class itself
|
||||
inputs = list(args[1:])
|
||||
for keyword in method_input_params:
|
||||
if keyword in kwargs:
|
||||
inputs.append(kwargs[keyword])
|
||||
# Assume a func execute N times, there will be N inputs need to
|
||||
# save.
|
||||
data_buffer.append(inputs)
|
||||
return outputs
|
||||
|
||||
return wrap_method
|
|
@ -130,8 +130,7 @@ class MethodOutputsRecorder(BaseRecorder):
|
|||
Args:
|
||||
origin_method (MethodType): The method whose outputs need to be
|
||||
recorded.
|
||||
buffer_key (str): The key of the method's outputs saved in
|
||||
``data_buffer``.
|
||||
data_buffer (list): A list of data.
|
||||
"""
|
||||
|
||||
@functools.wraps(orgin_method)
|
||||
|
|
|
@ -1,11 +1,75 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
from unittest import TestCase
|
||||
from unittest.mock import Mock
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.evaluator import Evaluator
|
||||
from mmengine.hooks import EMAHook
|
||||
from mmengine.logging import MMLogger
|
||||
from mmengine.model import BaseModel, ExponentialMovingAverage
|
||||
from mmengine.optim import OptimWrapper
|
||||
from mmengine.runner import Runner
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from mmrazor.models.task_modules import FunctionOutputsDelivery
|
||||
|
||||
|
||||
class ToyModel(BaseModel):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(2, 1)
|
||||
# test FunctionOutputsDelivery when ema_hook is used
|
||||
self.deliver = FunctionOutputsDelivery(
|
||||
max_keep_data=2, func_path='toy_module.toy_func')
|
||||
|
||||
def forward(self, inputs, data_sample, mode='tensor'):
|
||||
labels = torch.stack(data_sample)
|
||||
inputs = torch.stack(inputs)
|
||||
with self.deliver:
|
||||
outputs = self.linear(inputs)
|
||||
if mode == 'tensor':
|
||||
return outputs
|
||||
elif mode == 'loss':
|
||||
loss = (labels - outputs).sum()
|
||||
outputs = dict(loss=loss)
|
||||
return outputs
|
||||
else:
|
||||
return outputs
|
||||
|
||||
|
||||
class DummyDataset(Dataset):
|
||||
METAINFO = dict() # type: ignore
|
||||
data = torch.randn(12, 2)
|
||||
label = torch.ones(12)
|
||||
|
||||
@property
|
||||
def metainfo(self):
|
||||
return self.METAINFO
|
||||
|
||||
def __len__(self):
|
||||
return self.data.size(0)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return dict(inputs=self.data[index], data_sample=self.label[index])
|
||||
|
||||
|
||||
class TestFuncOutputsDeliver(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
|
||||
def tearDown(self):
|
||||
# `FileHandler` should be closed in Windows, otherwise we cannot
|
||||
# delete the temporary directory
|
||||
logging.shutdown()
|
||||
MMLogger._instance_dict.clear()
|
||||
self.temp_dir.cleanup()
|
||||
|
||||
def test_init(self):
|
||||
|
||||
with self.assertRaisesRegex(TypeError, 'func_path should be'):
|
||||
|
@ -14,20 +78,26 @@ class TestFuncOutputsDeliver(TestCase):
|
|||
with self.assertRaisesRegex(AssertionError, 'func_path must have at '):
|
||||
_ = FunctionOutputsDelivery(max_keep_data=1, func_path='toy_func')
|
||||
|
||||
with self.assertRaisesRegex(ImportError, 'aaa is not imported'):
|
||||
_ = FunctionOutputsDelivery(max_keep_data=1, func_path='aaa.bb')
|
||||
|
||||
with self.assertRaisesRegex(AssertionError, 'bb is not in toy_mod'):
|
||||
_ = FunctionOutputsDelivery(
|
||||
max_keep_data=1, func_path='toy_module.bb')
|
||||
|
||||
with self.assertRaisesRegex(TypeError, 'TOY_VAR should be'):
|
||||
_ = FunctionOutputsDelivery(
|
||||
max_keep_data=1, func_path='toy_module.TOY_VAR')
|
||||
|
||||
def test_context_manager(self):
|
||||
import toy_module
|
||||
|
||||
delivery = FunctionOutputsDelivery(max_keep_data=2, func_path='aaa.bb')
|
||||
with self.assertRaisesRegex(ImportError, 'aaa is not imported'):
|
||||
with delivery:
|
||||
_ = toy_module.toy_func()
|
||||
|
||||
delivery = FunctionOutputsDelivery(
|
||||
max_keep_data=1, func_path='toy_module.bb')
|
||||
with self.assertRaisesRegex(AssertionError, 'bb is not in toy_mod'):
|
||||
with delivery:
|
||||
_ = toy_module.toy_func()
|
||||
|
||||
delivery = FunctionOutputsDelivery(
|
||||
max_keep_data=1, func_path='toy_module.TOY_VAR')
|
||||
with self.assertRaisesRegex(TypeError, 'TOY_VAR should be'):
|
||||
with delivery:
|
||||
_ = toy_module.toy_func()
|
||||
|
||||
delivery = FunctionOutputsDelivery(
|
||||
max_keep_data=2, func_path='toy_module.toy_func')
|
||||
|
||||
|
@ -52,3 +122,42 @@ class TestFuncOutputsDeliver(TestCase):
|
|||
with self.assertRaisesRegex(AssertionError, 'pop from an empty queue'):
|
||||
with delivery:
|
||||
_ = toy_module.toy_func()
|
||||
|
||||
def test_ema_hook(self):
|
||||
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
||||
model = ToyModel().to(device)
|
||||
evaluator = Evaluator([])
|
||||
evaluator.evaluate = Mock(return_value=dict(acc=0.5))
|
||||
runner = Runner(
|
||||
model=model,
|
||||
train_dataloader=dict(
|
||||
dataset=DummyDataset(),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
batch_size=3,
|
||||
num_workers=0),
|
||||
val_dataloader=dict(
|
||||
dataset=DummyDataset(),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
batch_size=3,
|
||||
num_workers=0),
|
||||
val_evaluator=evaluator,
|
||||
work_dir=self.temp_dir.name,
|
||||
default_scope='mmrazor',
|
||||
optim_wrapper=OptimWrapper(
|
||||
torch.optim.Adam(ToyModel().parameters())),
|
||||
train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
|
||||
val_cfg=dict(),
|
||||
default_hooks=dict(logger=None),
|
||||
custom_hooks=[dict(type='EMAHook', )],
|
||||
experiment_name='test_func_outputs_deliver')
|
||||
runner.train()
|
||||
for hook in runner.hooks:
|
||||
if isinstance(hook, EMAHook):
|
||||
self.assertTrue(
|
||||
isinstance(hook.ema_model, ExponentialMovingAverage))
|
||||
|
||||
self.assertTrue(
|
||||
osp.exists(osp.join(self.temp_dir.name, 'epoch_2.pth')))
|
||||
checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
|
||||
self.assertTrue('ema_state_dict' in checkpoint)
|
||||
self.assertTrue(checkpoint['ema_state_dict']['steps'] == 8)
|
||||
|
|
|
@ -0,0 +1,138 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
from unittest import TestCase
|
||||
from unittest.mock import Mock
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.evaluator import Evaluator
|
||||
from mmengine.hooks import EMAHook
|
||||
from mmengine.logging import MMLogger
|
||||
from mmengine.model import BaseModel, ExponentialMovingAverage
|
||||
from mmengine.optim import OptimWrapper
|
||||
from mmengine.runner import Runner
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from mmrazor.models.task_modules import FunctionInputsRecorder, RecorderManager
|
||||
|
||||
|
||||
class ToyModel(BaseModel):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(2, 1)
|
||||
# test FunctionInputsRecorder when ema_hook is used
|
||||
recorders_cfg = dict(
|
||||
out=dict(type='FunctionInputs', source='toy_mod.toy_func'))
|
||||
self.recorders = RecorderManager(recorders_cfg)
|
||||
self.recorders.initialize(self)
|
||||
|
||||
def forward(self, inputs, data_sample, mode='tensor'):
|
||||
labels = torch.stack(data_sample)
|
||||
inputs = torch.stack(inputs)
|
||||
with self.recorders:
|
||||
outputs = self.linear(inputs)
|
||||
if mode == 'tensor':
|
||||
return outputs
|
||||
elif mode == 'loss':
|
||||
loss = (labels - outputs).sum()
|
||||
outputs = dict(loss=loss)
|
||||
return outputs
|
||||
else:
|
||||
return outputs
|
||||
|
||||
|
||||
class DummyDataset(Dataset):
|
||||
METAINFO = dict() # type: ignore
|
||||
data = torch.randn(12, 2)
|
||||
label = torch.ones(12)
|
||||
|
||||
@property
|
||||
def metainfo(self):
|
||||
return self.METAINFO
|
||||
|
||||
def __len__(self):
|
||||
return self.data.size(0)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return dict(inputs=self.data[index], data_sample=self.label[index])
|
||||
|
||||
|
||||
class TestFuncInputsRecorder(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
|
||||
def tearDown(self):
|
||||
# `FileHandler` should be closed in Windows, otherwise we cannot
|
||||
# delete the temporary directory
|
||||
logging.shutdown()
|
||||
MMLogger._instance_dict.clear()
|
||||
self.temp_dir.cleanup()
|
||||
|
||||
def test_context_manager(self):
|
||||
from toy_mod import execute_toy_func2 as execute_toy_func
|
||||
|
||||
recorder = FunctionInputsRecorder('toy_mod.toy_func2')
|
||||
recorder.initialize()
|
||||
|
||||
with recorder:
|
||||
execute_toy_func(1, 2)
|
||||
execute_toy_func(1, b=2)
|
||||
execute_toy_func(b=2, a=1)
|
||||
|
||||
self.assertTrue(
|
||||
recorder.get_record_data(record_idx=0, data_idx=0) == 1)
|
||||
self.assertTrue(
|
||||
recorder.get_record_data(record_idx=0, data_idx=1) == 2)
|
||||
|
||||
self.assertTrue(
|
||||
recorder.get_record_data(record_idx=1, data_idx=0) == 1)
|
||||
self.assertTrue(
|
||||
recorder.get_record_data(record_idx=1, data_idx=1) == 2)
|
||||
|
||||
self.assertTrue(
|
||||
recorder.get_record_data(record_idx=2, data_idx=0) == 1)
|
||||
self.assertTrue(
|
||||
recorder.get_record_data(record_idx=2, data_idx=1) == 2)
|
||||
|
||||
def test_ema_hook(self):
|
||||
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
||||
model = ToyModel().to(device)
|
||||
evaluator = Evaluator([])
|
||||
evaluator.evaluate = Mock(return_value=dict(acc=0.5))
|
||||
runner = Runner(
|
||||
model=model,
|
||||
train_dataloader=dict(
|
||||
dataset=DummyDataset(),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
batch_size=3,
|
||||
num_workers=0),
|
||||
val_dataloader=dict(
|
||||
dataset=DummyDataset(),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
batch_size=3,
|
||||
num_workers=0),
|
||||
val_evaluator=evaluator,
|
||||
work_dir=self.temp_dir.name,
|
||||
default_scope='mmrazor',
|
||||
optim_wrapper=OptimWrapper(
|
||||
torch.optim.Adam(ToyModel().parameters())),
|
||||
train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
|
||||
val_cfg=dict(),
|
||||
default_hooks=dict(logger=None),
|
||||
custom_hooks=[dict(type='EMAHook', )],
|
||||
experiment_name='test_func_inputs_recorder')
|
||||
runner.train()
|
||||
for hook in runner.hooks:
|
||||
if isinstance(hook, EMAHook):
|
||||
self.assertTrue(
|
||||
isinstance(hook.ema_model, ExponentialMovingAverage))
|
||||
|
||||
self.assertTrue(
|
||||
osp.exists(osp.join(self.temp_dir.name, 'epoch_2.pth')))
|
||||
checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
|
||||
self.assertTrue('ema_state_dict' in checkpoint)
|
||||
self.assertTrue(checkpoint['ema_state_dict']['steps'] == 8)
|
|
@ -16,18 +16,27 @@ class TestFuncOutputsRecorder(TestCase):
|
|||
with self.assertRaisesRegex(AssertionError, 'source must have at '):
|
||||
_ = FunctionOutputsRecorder('aaaaa')
|
||||
|
||||
with self.assertRaisesRegex(ImportError, 'aaa is not imported'):
|
||||
_ = FunctionOutputsRecorder('aaa.bbb')
|
||||
|
||||
with self.assertRaisesRegex(AssertionError, 'aaa is not in toy_mod'):
|
||||
_ = FunctionOutputsRecorder('toy_mod.aaa')
|
||||
|
||||
with self.assertRaisesRegex(TypeError, 'TOY_VAR should be'):
|
||||
_ = FunctionOutputsRecorder('toy_mod.TOY_VAR')
|
||||
|
||||
def test_context_manager(self):
|
||||
from toy_mod import execute_toy_func
|
||||
|
||||
recorder = FunctionOutputsRecorder('aaa.bbb')
|
||||
recorder.initialize()
|
||||
with self.assertRaisesRegex(ImportError, 'aaa is not imported'):
|
||||
with recorder:
|
||||
execute_toy_func(1)
|
||||
|
||||
recorder = FunctionOutputsRecorder('toy_mod.aaa')
|
||||
recorder.initialize()
|
||||
with self.assertRaisesRegex(AssertionError, 'aaa is not in toy_mod'):
|
||||
with recorder:
|
||||
execute_toy_func(1)
|
||||
|
||||
recorder = FunctionOutputsRecorder('toy_mod.TOY_VAR')
|
||||
recorder.initialize()
|
||||
with self.assertRaisesRegex(TypeError, 'TOY_VAR should be'):
|
||||
with recorder:
|
||||
execute_toy_func(1)
|
||||
|
||||
recorder = FunctionOutputsRecorder('toy_mod.toy_func')
|
||||
recorder.initialize()
|
||||
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
|
||||
from mmrazor.models.task_modules import MethodInputsRecorder
|
||||
|
||||
|
||||
class TestFuncOutputsRecorder(TestCase):
|
||||
|
||||
def test_context_manager(self):
|
||||
from toy_mod import ToyClass
|
||||
|
||||
toy = ToyClass()
|
||||
|
||||
recorder = MethodInputsRecorder('toy_mod.ToyClass.func')
|
||||
recorder.initialize()
|
||||
|
||||
with recorder:
|
||||
_ = toy.func(x=1, y=2)
|
||||
_ = toy.func(1, y=2)
|
||||
_ = toy.func(y=2, x=1)
|
||||
|
||||
self.assertTrue(
|
||||
recorder.get_record_data(record_idx=0, data_idx=0) == 1)
|
||||
self.assertTrue(
|
||||
recorder.get_record_data(record_idx=0, data_idx=1) == 2)
|
||||
|
||||
self.assertTrue(
|
||||
recorder.get_record_data(record_idx=1, data_idx=0) == 1)
|
||||
self.assertTrue(
|
||||
recorder.get_record_data(record_idx=1, data_idx=1) == 2)
|
||||
|
||||
self.assertTrue(
|
||||
recorder.get_record_data(record_idx=2, data_idx=0) == 1)
|
||||
self.assertTrue(
|
||||
recorder.get_record_data(record_idx=2, data_idx=1) == 2)
|
|
@ -8,6 +8,10 @@ def toy_func(a):
|
|||
return a
|
||||
|
||||
|
||||
def toy_func2(a, b):
|
||||
return a, b
|
||||
|
||||
|
||||
def toy_list_func(a):
|
||||
return [a, a, a]
|
||||
|
||||
|
@ -16,6 +20,10 @@ def execute_toy_func(a):
|
|||
toy_func(a)
|
||||
|
||||
|
||||
def execute_toy_func2(a, b):
|
||||
toy_func2(a, b)
|
||||
|
||||
|
||||
def execute_toy_list_func(a):
|
||||
toy_list_func(a)
|
||||
|
||||
|
@ -31,6 +39,9 @@ class ToyClass:
|
|||
self._count += 1
|
||||
return self._count
|
||||
|
||||
def func(self, x, y=0):
|
||||
return x + y
|
||||
|
||||
def __call__(self):
|
||||
self._count += 1
|
||||
return self._count
|
||||
|
|
Loading…
Reference in New Issue