[Feature] Add Recorder to improve Distiller

This commit is contained in:
pppppM 2022-07-08 08:09:10 +00:00 committed by pppppM
parent 8913d6840d
commit cb238e36e3
48 changed files with 2099 additions and 671 deletions

View File

@ -0,0 +1,39 @@
_base_ = [
'mmcls::_base_/datasets/imagenet_bs32.py',
'mmcls::_base_/schedules/imagenet_bs256.py',
'mmcls::_base_/default_runtime.py'
]
model = dict(
_scope_='mmrazor',
type='SingleTeacherDistill',
data_preprocessor=dict(
type='ImgDataPreprocessor',
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
bgr_to_rgb=True),
architecture=dict(
cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False),
teacher=dict(
cfg_path='mmcls::resnet/resnet34_8xb32_in1k.py', pretrained=True),
teacher_ckpt='resnet34_8xb32_in1k_20210831-f257d4e6.pth',
student_recorders=dict(fc=dict(type='ModuleOutputs', source='head.fc')),
teacher_recorders=dict(fc=dict(type='ModuleOutputs', source='head.fc')),
distill_losses=dict(
loss_kl=dict(type='KLDivergence', tau=1, loss_weight=5)),
loss_forward_mappings=dict(
loss_kl=dict(
preds_S=dict(
from_student=True,
recorder='fc',
),
preds_T=dict(
from_student=False,
recorder='fc',
))))
find_unused_parameters = True
val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
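As a rough illustration of how this config is consumed, the sketch below (pure PyTorch; `toy_kl_loss` is a hypothetical stand-in, not the MMRazor `KLDivergence` implementation) shows how the two `fc` recorders and `loss_forward_mappings` resolve into one loss call: the student's and teacher's `head.fc` outputs are recorded during their forwards and passed to the loss as `preds_S` and `preds_T`.

import torch
import torch.nn.functional as F

def toy_kl_loss(preds_S, preds_T, tau=1.0, loss_weight=5.0):
    # Soften both logits with the temperature and compare the distributions,
    # which is the general shape of a KL-based distillation loss.
    s = F.log_softmax(preds_S / tau, dim=1)
    t = F.softmax(preds_T / tau, dim=1)
    return loss_weight * (tau**2) * F.kl_div(s, t, reduction='batchmean')

# Stand-ins for what the two 'fc' recorders capture from student and teacher.
student_fc = torch.randn(4, 1000)
teacher_fc = torch.randn(4, 1000)

# loss_forward_mappings: preds_S <- student recorder 'fc',
#                        preds_T <- teacher recorder 'fc'.
print(toy_kl_loss(preds_S=student_fc, preds_T=teacher_fc))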

View File

@ -0,0 +1,36 @@
_base_ = [
'mmcls::_base_/datasets/imagenet_bs32.py',
'mmcls::_base_/schedules/imagenet_bs256.py',
'mmcls::_base_/default_runtime.py'
]
model = dict(
_scope_='mmrazor',
type='SingleTeacherDistill',
data_preprocessor=dict(
type='ImgDataPreprocessor',
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
bgr_to_rgb=True),
architecture=dict(
cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False),
teacher=dict(
cfg_path='mmcls::resnet/resnet34_8xb32_in1k.py', pretrained=True),
teacher_ckpt='resnet34_8xb32_in1k_20210831-f257d4e6.pth',
student_recorders=dict(
fc=dict(type='ModuleOutputs', source='head.fc'),
data_samples=dict(type='ModuleInputs', source='')),
teacher_recorders=dict(fc=dict(type='ModuleOutputs', source='head.fc')),
distill_losses=dict(loss_wsld=dict(type='WSLD', tau=2, loss_weight=2.5)),
loss_forward_mappings=dict(
loss_wsld=dict(
student=dict(recorder='fc', from_student=True),
teacher=dict(recorder='fc', from_student=False),
data_samples=dict(
recorder='data_samples', from_student=True, data_idx=1))))
find_unused_parameters = True
val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
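The `data_samples` recorder above uses `ModuleInputs` with an empty `source`, which matches the top-level module, so the whole forward input tuple is recorded and `data_idx=1` picks out `data_samples`. The toy sketch below (hypothetical `ToyAlgorithm`; it assumes the forward is called positionally as `forward(batch_inputs, data_samples, ...)`) mimics that behaviour with a plain forward hook.

import torch
from torch import nn

class ToyAlgorithm(nn.Module):
    def forward(self, batch_inputs, data_samples=None, mode='loss'):
        return batch_inputs.mean()

inputs_buffer = []
model = ToyAlgorithm()
# ModuleInputsRecorder registers a forward hook like this one; the hook
# receives the positional input tuple of the hooked module.
model.register_forward_hook(lambda mod, inp, out: inputs_buffer.append(inp))
_ = model(torch.randn(2, 3, 32, 32), ['sample_0', 'sample_1'])
# data_idx=1 in the mapping selects `data_samples` from that tuple.
print(inputs_buffer[0][1])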

View File

@ -0,0 +1,58 @@
_base_ = [
'mmdet::_base_/datasets/coco_detection.py',
'mmdet::_base_/schedules/schedule_1x.py',
'mmdet::_base_/default_runtime.py'
]
# default_scope = 'mmrazor'
teacher_ckpt = 'faster_rcnn_r101_fpn_2x_coco_bbox_mAP-0.398_20200504_210455-1d2dac9c.pth' # noqa: E501
model = dict(
_scope_='mmrazor',
type='FpnTeacherDistill',
architecture=dict(
cfg_path='mmdet::faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py',
pretrained=False),
teacher=dict(
cfg_path='mmdet::faster_rcnn/faster_rcnn_r101_fpn_2x_coco.py',
pretrained=False),
teacher_ckpt=teacher_ckpt,
distill_losses=dict(
loss_cwd_fpn0=dict(
type='ChannelWiseDivergence', tau=1, loss_weight=10),
loss_cwd_fpn1=dict(
type='ChannelWiseDivergence', tau=1, loss_weight=10),
loss_cwd_fpn2=dict(
type='ChannelWiseDivergence', tau=1, loss_weight=10),
loss_cwd_fpn3=dict(
type='ChannelWiseDivergence', tau=1, loss_weight=10),
loss_cwd_fpn4=dict(
type='ChannelWiseDivergence', tau=1, loss_weight=10)),
student_recorders=dict(fpn=dict(type='ModuleOutputs', source='neck')),
teacher_recorders=dict(fpn=dict(type='ModuleOutputs', source='neck')),
loss_forward_mappings=dict(
loss_cwd_fpn0=dict(
preds_S=dict(from_student=True, recorder='fpn', data_idx=0),
preds_T=dict(from_student=False, recorder='fpn', data_idx=0),
),
loss_cwd_fpn1=dict(
preds_S=dict(from_student=True, recorder='fpn', data_idx=1),
preds_T=dict(from_student=False, recorder='fpn', data_idx=1),
),
loss_cwd_fpn2=dict(
preds_S=dict(from_student=True, recorder='fpn', data_idx=2),
preds_T=dict(from_student=False, recorder='fpn', data_idx=2),
),
loss_cwd_fpn3=dict(
preds_S=dict(from_student=True, recorder='fpn', data_idx=3),
preds_T=dict(from_student=False, recorder='fpn', data_idx=3),
),
loss_cwd_fpn4=dict(
preds_S=dict(from_student=True, recorder='fpn', data_idx=4),
preds_T=dict(from_student=False, recorder='fpn', data_idx=4),
),
),
)
find_unused_parameters = True
val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
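Here a single `fpn` recorder per model stores the neck's whole output, and each `loss_cwd_fpnX` mapping selects one pyramid level with `data_idx`. A minimal sketch (illustrative tensor shapes, not taken from the actual detectors):

import torch

# An FPN neck returns a tuple of multi-level feature maps; the 'fpn'
# recorder stores that whole tuple as one record.
student_fpn = tuple(torch.randn(2, 256, s, s) for s in (64, 32, 16, 8, 4))
teacher_fpn = tuple(torch.randn(2, 256, s, s) for s in (64, 32, 16, 8, 4))

# data_idx=k picks level k, so loss_cwd_fpn0 compares level 0, and so on.
for k in range(5):
    preds_S, preds_T = student_fpn[k], teacher_fpn[k]
    assert preds_S.shape == preds_T.shape  # each level is compared pairwise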

View File

@ -0,0 +1,8 @@
_base_ = ['./cwd_fpn_gfl_r101_gfl_r50_1x_coco.py']
model = dict(
architecture=dict(
cfg_path='mmdet::gfl/gfl_r50_fpn_1x_coco.py', pretrained=False),
teacher=dict(
cfg_path='mmdet::gfl/gfl_r101_fpn_mstrain_2x_coco.py',
pretrained=True))

View File

@ -0,0 +1,9 @@
_base_ = ['./cwd_fpn_frcnn_r101_frcnn_r50_1x_coco.py']
model = dict(
architecture=dict(
cfg_path='mmdet::retinanet/retinanet_r50_fpn_1x_coco.py',
pretrained=False),
teacher=dict(
cfg_path='mmdet::retinanet/retinanet_r101_fpn_2x_coco.py',
pretrained=True))

View File

@ -1,3 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .delivers import * # noqa: F401,F403
from .recorders import * # noqa: F401,F403
from .tracer import * # noqa: F401,F403

View File

@ -1,8 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .deliver_manager import DistillDeliverManager
from .function_outputs_deliver import FunctionOutputsDeliver
from .method_outputs_deliver import MethodOutputsDeliver
from .deliver_manager import DistillDeliveryManager
from .function_outputs_deliver import FunctionOutputsDelivery
from .method_outputs_deliver import MethodOutputsDelivery
__all__ = [
'FunctionOutputsDeliver', 'MethodOutputsDeliver', 'DistillDeliverManager'
'FunctionOutputsDelivery', 'MethodOutputsDelivery',
'DistillDeliveryManager'
]

View File

@ -1,22 +1,26 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import List
from typing import Dict, Optional
from mmrazor.registry import TASK_UTILS
from .distill_deliver import DistillDelivery
SUPPORT_DELIVERIES = ['FunctionOutputs', 'MethodOutputs']
class DistillDeliverManager:
"""Various types delivers' manager. The ``DistillDeliverManager`` is also a
context manager, managing various types of delivers. When entering the
``DistillDeliverManager``, all delivers managed by it will be started.
class DistillDeliveryManager:
"""Various types deliveries' manager. The ``DistillDeliveryManager`` is
also a context manager, managing various types of deliveries.
When entering the ``DistillDeliveryManager``, all deliveries managed by it
will be started.
Notes:
DistillDeliver is a context manager used to override function (method)
outputs from the target model with function (method) outputs from the
source model.
DistillDelivery is a context manager used to override function (method)
outputs during the teacher's (student's) forward.
Args:
deliveries (list(dict)): Configs of all deliveries.
deliveries (dict): Configs of all deliveries.
Examples:
>>> from mmcls.models.utils import Augments
@ -27,8 +31,8 @@ class DistillDeliverManager:
>>> imgs = torch.randn(2, 3, 32, 32)
>>> label = torch.randint(0, 10, (2,))
>>> # Without ``MethodOutputsDeliver``, outputs of the teacher and the
>>> # student are very likely to be different.
>>> # Without ``MethodOutputsDelivery``, outputs of the teacher and
>>> # the student are different.
>>> imgs_tea, label_tea = augments(imgs, label)
>>> imgs_stu, label_stu = augments(imgs, label)
>>> torch.equal(label_tea, label_stu)
@ -36,10 +40,10 @@ class DistillDeliverManager:
>>> torch.equal(imgs_tea, imgs_stu)
False
>>> distill_deliveries = [
... ConfigDict(type='MethodOutputs', max_keep_data=1,
... method_path='mmcls.models.utils.Augments.__call__')]
>>> manager = DistillDeliverManager(distill_deliveries)
>>> distill_deliveries = ConfigDict(
... aug=dict(type='MethodOutputs', max_keep_data=1,
... method_path='mmcls.models.utils.Augments.__call__'))
>>> manager = DistillDeliveryManager(distill_deliveries)
>>> manager.override_data = False
>>> with manager:
@ -55,44 +59,55 @@ class DistillDeliverManager:
True
"""
def __init__(self, deliveries: List) -> None:
def __init__(self, deliveries: Optional[Dict[str, Dict]] = None) -> None:
# As there may be several delivers belong to a same deliver type,
# we use a list to save delivers rather than a dict.
self.deliveries = list()
for cfg in deliveries:
deliver_cfg = copy.deepcopy(cfg)
deliver_type = cfg.type
deliver_type = deliver_type + 'Deliver'
deliver_cfg.type = deliver_type
self.deliveries.append(TASK_UTILS.build(deliver_cfg))
self._deliveries: Dict[str, DistillDelivery] = dict()
if deliveries:
for delivery_name, delivery_cfg in deliveries.items():
delivery_cfg_ = copy.deepcopy(delivery_cfg)
delivery_type_ = delivery_cfg_.get('type', '')
assert isinstance(delivery_type_, str)
assert delivery_type_ in SUPPORT_DELIVERIES
delivery_type_ = delivery_type_ + 'Delivery'
delivery_cfg_.update(dict(type=delivery_type_))
delivery = TASK_UTILS.build(delivery_cfg_)
self.deliveries[delivery_name] = delivery
self._override_data = False
@property
def override_data(self):
def deliveries(self) -> Dict[str, DistillDelivery]:
"""dict: all deliveries."""
return self._deliveries
@property
def override_data(self) -> bool:
"""bool: indicate whether to override the data with the recorded data.
"""
return self._override_data
@override_data.setter
def override_data(self, override):
"""Set the override_data property to all the delivers.
def override_data(self, override: bool) -> None:
"""Set the override_data property to all the deliveries.
If the `override_data` of a deliver is False, the deliver will record
and keep the origin data. If the current_mode of a deliver is True, the
deliver will override the origin data with the recorded data.
If the `override_data` of a delivery is False, the delivery will
record the origin data.
If the `override_data` of a delivery is True, the delivery will
override the origin data with the recorded data.
"""
self._override_data = override
for deliver in self.deliveries:
deliver.override_data = override
for delivery in self.deliveries.values():
delivery.override_data = override
def __enter__(self) -> None:
"""Enter the context manager."""
for deliver in self.deliveries:
deliver.__enter__()
for delivery in self.deliveries.values():
delivery.__enter__()
def __exit__(self, exc_type, exc_value, traceback) -> None:
"""Exit the context manager."""
for deliver in self.deliveries:
deliver.__exit__(exc_type, exc_value, traceback)
for delivery in self.deliveries.values():
delivery.__exit__(exc_type, exc_value, traceback)

View File

@ -4,32 +4,33 @@ from queue import Queue
from typing import Callable
class DistillDeliver(metaclass=ABCMeta):
"""Base class for delivers for distillation.
# TODO: Support overriding part of the outputs of a function or method
class DistillDelivery(metaclass=ABCMeta):
"""Base class for deliveries for distillation.
DistillDeliver is a context manager used to override function (method)
outputs from the target model with function (method) outputs from the
source model.
In MMRazor, there will be different types of delivers to deliver different
types of data. They can be used in combination with the
DistillDelivery is a context manager used to override function (method)
outputs during the teacher's (student's) forward.
A delivery can only handle one function or method. Some algorithms may use
multiple deliveries, which can be managed uniformly using
``DistillDeliveryManager``.
Args:
max_keep_data (int): The length limitation of the queue, should be
larger than the execute times of the function or method. Defaults
to 1.
Notes:
If a function (method) is executed more than once during the forward
of the source model, all the outputs of this function (method) will be
used to override function (method) outputs from the target model.
TODO:
Support overriding some of the outputs of a function (method)
Args:
max_keep_data (int): The length limitation of the queue. If a function
(method) is executed more than once during the forward of the
target model, function (method) outputs from the source model are
pushed into the queue in order. Default to 1.
If a function or method is executed more than once during the forward
of the target model, its outputs from the source model are pushed
into the queue in order.
"""
def __init__(self, max_keep_data: int = 1):
def __init__(self, max_keep_data: int = 1) -> None:
self._override_data = False
self.data_queue: Queue = Queue(maxsize=max_keep_data)
@ -55,3 +56,11 @@ class DistillDeliver(metaclass=ABCMeta):
def deliver_wrapper(self, origin: Callable) -> Callable:
"""Wrap the specific object to make the intermediate results of the
model can be delivered."""
@abstractmethod
def __enter__(self) -> None:
"""Enter the context manager."""
@abstractmethod
def __exit__(self, exc_type, exc_value, traceback) -> None:
"""Exit the context manager."""

View File

@ -6,11 +6,11 @@ from typing import Callable
from mmcv.utils import import_modules_from_strings
from mmrazor.registry import TASK_UTILS
from .distill_deliver import DistillDeliver
from .distill_deliver import DistillDelivery
@TASK_UTILS.register_module()
class FunctionOutputsDeliver(DistillDeliver):
class FunctionOutputsDelivery(DistillDelivery):
"""Delivery for intermediate results which are ``FunctionType``'s outputs.
Args:
@ -29,23 +29,28 @@ class FunctionOutputsDeliver(DistillDeliver):
`mmdet.core.anchor.utils.anchor_inside_flags`.
Examples:
>>> # Suppose there is a toy function named ``toy_func`` in test.py.
>>> # It return random integers from 0 to 999.
>>> import toy_module
>>> # Suppose we want to deliver outputs from the teacher to
>>> # Below code in toy_module.py
>>> import random
>>> def toy_func():
...     return random.randint(0, 1000)
>>> # Below code in main.py
>>> # Teacher and student both will execute toy_func.
>>> # Now, we want to deliver outputs from the teacher to
>>> # the student
>>> import toy_module
>>> delivery = FunctionOutputsDeliver(
... max_keep_data=1, func_path='toy_module.toy_func')
>>> delivery.override_data = False
>>> with delivery:
... output_tea = toy_module.toy_func()
... output_teacher = toy_module.toy_func()
>>> delivery.override_data = True
>>> with delivery:
... output_stu = toy_module.toy_func()
... output_student = toy_module.toy_func()
>>> output_tea == output_stu
>>> output_teacher == output_student
True
>>> # If a function (method) is executed more than once during the

View File

@ -6,11 +6,11 @@ from typing import Callable
from mmcv.utils import import_modules_from_strings
from mmrazor.registry import TASK_UTILS
from .distill_deliver import DistillDeliver
from .distill_deliver import DistillDelivery
@TASK_UTILS.register_module()
class MethodOutputsDeliver(DistillDeliver):
class MethodOutputsDelivery(DistillDelivery):
"""Delivery for intermediate results which are ``MethodType``'s outputs.
Note:

View File

@ -0,0 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .function_outputs_recorder import FunctionOutputsRecorder
from .method_outputs_recorder import MethodOutputsRecorder
from .module_inputs_recorder import ModuleInputsRecorder
from .module_outputs_recorder import ModuleOutputsRecorder
from .param_recorder import ParameterRecorder
from .recorder_manager import RecorderManager
__all__ = [
'FunctionOutputsRecorder', 'MethodOutputsRecorder',
'ModuleOutputsRecorder', 'ParameterRecorder', 'RecorderManager',
'ModuleInputsRecorder'
]

View File

@ -0,0 +1,116 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Any, List, Optional
from torch import nn
class BaseRecorder(metaclass=ABCMeta):
"""Base class for recorders.
Recorder is a context manager used to record various intermediate results
during the model forward. It can be used in distillation algorithms and
can also be used to obtain specific data for visualization and analysis.
In MMRazor, there will be different types of recorders to obtain different
types of intermediate results. They can be used in combination with the
``RecorderManager``.
Note:
The recorder will be lazily initialized in the ``RecorderManager`` by
default. If you want to use the recorder without the
``RecorderManager``, you need to initialize it first.
"""
def __init__(self, source: str) -> None:
self._source = source
# Intermediate results from the data source are recorded in a list:
# one data source may be executed several times during a single
# forward, and each execution appends one record to the list.
self._data_buffer: List = list()
# Before using the recorder for the first time, it needs to be
# initialized.
self._initialized = False
@property
def source(self) -> str:
"""str: source of recorded data."""
return self._source
@property
def data_buffer(self) -> List:
"""list: data buffer."""
return self._data_buffer
@abstractmethod
def prepare_from_model(self, model: Optional[nn.Module] = None) -> None:
"""Make the intermediate results of the model can be record."""
def initialize(self, model: Optional[nn.Module] = None) -> None:
"""Init the recorder.
Args:
model (nn.Module): The model which need to record intermediate
results.
"""
self.prepare_from_model(model)
self._initialized = True
def get_record_data(self,
record_idx: int = 0,
data_idx: Optional[int] = None) -> Any:
"""Get data from ``data_buffer``.
Args:
record_idx (int): The index of the record saved in
``data_buffer``. If a source is executed N times during
forward, there will be N records in ``data_buffer``.
data_idx (int, optional): The index of the target data in
a record. A record may be a tuple or a list; if data_idx is
None, the whole record is returned. Defaults to None.
Returns:
Any: The type of the return value is undefined, and different
source data may have different types.
"""
assert record_idx < len(self._data_buffer), \
'record_idx is illegal. The length of data_buffer is ' \
f'{len(self._data_buffer)}, but record_idx is ' \
f'{record_idx}.'
record = self._data_buffer[record_idx]
if data_idx is None:
target_data = record
else:
if isinstance(record, (list, tuple)):
assert data_idx < len(record), \
'data_idx is illegal. The length of record is ' \
f'{len(record)}, but data_idx is {data_idx}.'
target_data = record[data_idx]
else:
raise TypeError('When data_idx is not None, record should be '
'a list or tuple instance, but got '
f'{type(record)}.')
return target_data
def reset_data_buffer(self) -> None:
"""Clear data in data_buffer."""
self._data_buffer = list()
def __enter__(self):
"""Enter the context manager."""
assert self._initialized, \
'The recorder will be initialized in the RecorderManager by '\
'default. If you want to use the recorder without the '\
'RecorderManager, you need to initialize it first.'
self.reset_data_buffer()
def __exit__(self, exc_type, exc_value, traceback):
"""Exit the context manager."""

View File

@ -0,0 +1,161 @@
# Copyright (c) OpenMMLab. All rights reserved.
import functools
from types import FunctionType, ModuleType
from typing import Callable, List, Optional
from mmcv.utils import import_modules_from_strings
from torch import nn
from mmrazor.registry import TASK_UTILS
from .base_recorder import BaseRecorder
@TASK_UTILS.register_module()
class FunctionOutputsRecorder(BaseRecorder):
"""Recorder for intermediate results which are ``FunctionType``'s outputs.
Notes:
The form of `source` needs special attention. For example,
`anchor_inside_flags` is a function in mmdetection to check whether the
anchors are inside the border. This function is in
`mmdet/core/anchor/utils.py` and used in
`mmdet/models/dense_heads/anchor_head.py`. Then the source should be
`mmdet.models.dense_heads.anchor_head.anchor_inside_flags` but not
`mmdet.core.anchor.utils.anchor_inside_flags`.
Examples:
>>> # Below code in toy_module.py
>>> import random
>>> def toy_func():
... return random.randint(0, 1000)
>>> def toy_list_func():
... return [random.randint(0,1000) for _ in range(3)]
>>> # Below code in main.py
>>> # Now, we want to get teacher's outputs by recorder.
>>> import toy_module
>>> r1 = FunctionOutputsRecorder('toy_module.toy_func')
>>> r1.initialize()
>>> with r1:
... output_teacher1 = toy_module.toy_func()
... output_teacher2 = toy_module.toy_func()
... output_teacher3 = toy_module.toy_func()
>>> r1.data_buffer
[33, 41, 12]
>>> r1.get_record_data(record_idx=2)
12
>>> output_teacher1==33 and output_teacher2==41 and output_teacher3==12
True
>>> r2 = FunctionOutputsRecorder('toy_module.toy_list_func')
>>> r2.initialize()
>>> with r2:
... output_teacher1 = toy_module.toy_list_func()
... output_teacher2 = toy_module.toy_list_func()
... output_teacher3 = toy_module.toy_list_func()
>>> r2.data_buffer
[[1, 2, 3], [4, 5, 6], [7, 8, 9]]
>>> r2.get_record_data(record_idx=2, data_idx=2)
9
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._check_valid_source(self.source)
# import the module corresponding to the function
try:
mod = import_modules_from_strings(self.module_string)
except ImportError:
raise ImportError(
f'{self.module_string} is not imported correctly.')
self.imported_module: ModuleType = mod
assert hasattr(mod, self.func_name), \
f'{self.func_name} is not in {self.module_string}.'
origin_func = getattr(mod, self.func_name)
if not isinstance(origin_func, FunctionType):
raise TypeError(f'{self.func_name} should be a FunctionType '
f'instance, but got {type(origin_func)}')
self.origin_func: Callable = origin_func
@staticmethod
def _check_valid_source(source):
"""Check if the source's format is valid."""
if not isinstance(source, str):
raise TypeError(f'source should be a str '
f'instance, but got {type(source)}')
assert len(source.split('.')) > 1, \
'source must have at least one `.`'
@property
def func_name(self):
"""Get the function name according to `func_path`."""
return self.source.split('.')[-1]
@property
def module_string(self):
"""Get the module name according to `func_path`."""
return '.'.join(self.source.split('.')[:-1])
def prepare_from_model(self, model: Optional[nn.Module] = None) -> None:
"""The `model` is useless in `FunctionOutputsRecorder`."""
pass
def func_record_wrapper(self, origin_func: Callable,
data_buffer: List) -> Callable:
"""Save the function's outputs.
Args:
origin_func (FunctionType): The function whose outputs need to be
recorded.
data_buffer (list): The buffer used to save the function's
outputs.
"""
@functools.wraps(origin_func)
def wrap_func(*args, **kwargs):
outputs = origin_func(*args, **kwargs)
# If a function is executed N times, there will be N outputs to
# save.
data_buffer.append(outputs)
return outputs
return wrap_func
def __enter__(self):
"""Enter the context manager."""
super().__enter__()
mod = self.imported_module
origin_func = self.origin_func
# add record wrapper to origin function.
record_func = self.func_record_wrapper(origin_func, self.data_buffer)
assert hasattr(mod, self.func_name), \
f'{self.func_name} is not in {self.module_string}.'
# rewrite the origin function
setattr(mod, self.func_name, record_func)
def __exit__(self, exc_type, exc_value, traceback):
"""Exit the context manager."""
super().__exit__(exc_type, exc_value, traceback)
mod = self.imported_module
origin_func = self.origin_func
assert hasattr(mod, self.func_name), \
f'{self.func_name} is not in {self.module_string}.'
# restore the origin function
setattr(mod, self.func_name, origin_func)

View File

@ -0,0 +1,168 @@
# Copyright (c) OpenMMLab. All rights reserved.
import functools
from types import FunctionType, ModuleType
from typing import Callable, List, Optional
from mmcv.utils import import_modules_from_strings
from torch import nn
from mmrazor.registry import TASK_UTILS
from .base_recorder import BaseRecorder
@TASK_UTILS.register_module()
class MethodOutputsRecorder(BaseRecorder):
"""Recorder for intermediate results which are ``MethodType``'s outputs.
Note:
Different from ``FunctionType``, ``MethodType`` is the type of methods
of class instances.
Examples:
>>> # Below code in toy_module.py
>>> import random
>>> class Toy():
... def toy_func(self):
... return random.randint(0, 1000)
... def toy_list_func(self):
... return [random.randint(0, 1000) for _ in range(3)]
>>> # Below code in main.py
>>> # Now, we want to get teacher's outputs by recorder.
>>> from toy_module import Toy
>>> toy = Toy()
>>> r1 = MethodOutputsRecorder('toy_module.Toy.toy_func')
>>> r1.initialize()
>>> with r1:
... output_teacher1 = toy.toy_func()
... output_teacher2 = toy.toy_func()
... output_teacher3 = toy.toy_func()
>>> r1.data_buffer
[33, 41, 12]
>>> r1.get_record_data(record_idx=2)
12
>>> output_teacher1==33 and output_teacher2==41 and output_teacher3==12
True
>>> r2 = MethodOutputsRecorder('toy_module.Toy.toy_list_func')
>>> r2.initialize()
>>> with r2:
... output_teacher1 = toy.toy_list_func()
... output_teacher2 = toy.toy_list_func()
... output_teacher3 = toy.toy_list_func()
>>> r2.data_buffer
[[1, 2, 3], [4, 5, 6], [7, 8, 9]]
>>> r2.get_record_data(record_idx=2, data_idx=2)
9
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._check_valid_source(self.source)
# import the module corresponding to the method
try:
mod: ModuleType = import_modules_from_strings(self.module_string)
except ImportError:
raise ImportError(
f'{self.module_string} is not imported correctly.')
assert hasattr(mod, self.cls_name), \
f'{self.cls_name} is not in {self.module_string}.'
imported_cls: type = getattr(mod, self.cls_name)
if not isinstance(imported_cls, type):
raise TypeError(f'{self.cls_name} should be a type '
f'instance, but got {type(imported_cls)}')
self.imported_class = imported_cls
assert hasattr(imported_cls, self.method_name), \
f'{self.method_name} is not in {self.cls_name}.'
origin_method = getattr(imported_cls, self.method_name)
if not isinstance(origin_method, FunctionType):
raise TypeError(f'{self.method_name} should be a FunctionType '
f'instance, but got {type(origin_method)}')
self.origin_method = origin_method
@staticmethod
def _check_valid_source(source: str) -> None:
"""Check if the `source` is valid."""
if not isinstance(source, str):
raise TypeError(f'source should be a str '
f'instance, but got {type(source)}')
assert len(source.split('.')) > 2, \
'source must have at least two `.`'
@property
def method_name(self):
"""Get the method name according to `method_path`."""
return self.source.split('.')[-1]
@property
def cls_name(self):
"""Get the class name corresponding to this method according to
`method_path`."""
return self.source.split('.')[-2]
@property
def module_string(self):
"""Get the module name according to `method_path`."""
return '.'.join(self.source.split('.')[:-2])
def prepare_from_model(self, model: Optional[nn.Module] = None) -> None:
"""Wrapper the origin source methods.
The ``model`` is useless in this recorder, just to be consistent with
other recorders.
"""
pass
def method_record_wrapper(self, origin_method: Callable,
data_buffer: List) -> Callable:
"""Save the method's outputs.
Args:
origin_method (MethodType): The method whose outputs need to be
recorded.
data_buffer (list): The buffer used to save the method's outputs.
"""
@functools.wraps(origin_method)
def wrap_method(*args, **kwargs):
outputs = origin_method(*args, **kwargs)
# If a method is executed N times, there will be N outputs to save.
data_buffer.append(outputs)
return outputs
return wrap_method
def __enter__(self):
"""Enter the context manager."""
super().__enter__()
imported_cls = self.imported_class
origin_method = self.origin_method
# add record wrapper to origin method.
record_method = self.method_record_wrapper(origin_method,
self.data_buffer)
# rewrite the origin method.
setattr(imported_cls, self.method_name, record_method)
def __exit__(self, exc_type, exc_value, traceback):
"""Exit the context manager."""
super().__exit__(exc_type, exc_value, traceback)
imported_cls = self.imported_class
origin_method = self.origin_method
# restore the origin method
setattr(imported_cls, self.method_name, origin_method)

View File

@ -0,0 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Tuple
from torch import nn
from mmrazor.registry import TASK_UTILS
from .module_outputs_recorder import ModuleOutputsRecorder
@TASK_UTILS.register_module()
class ModuleInputsRecorder(ModuleOutputsRecorder):
"""Recorder for intermediate results which are Pytorch moudle's inputs."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward_hook(self, module: nn.Module, inputs: Tuple,
outputs: Any) -> None:
"""Save the module's forward input.
Args:
module (:obj:`torch.nn.Module`): The module to register hook.
inputs (tuple): The input of the module.
outputs : The output of the module.
"""
if self.recording:
self.data_buffer.append(inputs)

View File

@ -0,0 +1,96 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Optional, Tuple
from torch import nn
from mmrazor.registry import TASK_UTILS
from .base_recorder import BaseRecorder
@TASK_UTILS.register_module()
class ModuleOutputsRecorder(BaseRecorder):
"""Recorder for intermediate results which are Pytorch moudle's outputs.
Examples:
>>> from torch import nn
>>> class ToyModel(nn.Module):
... def __init__(self):
... super().__init__()
... self.conv1 = nn.Conv2d(8,8,1)
... self.conv2 = nn.Conv2d(1,1,1)
... def forward(self, x):
... x1 = self.conv1(x)
... x2 = self.conv1(x+1)
... return self.conv2(x1 + x2)
>>> model = ToyModel()
>>> [ name for name,_ in model.named_modules() ]
['conv1', 'conv2']
>>> r1 = ModuleOutputsRecorder('conv1')
>>> r1.initialize(model)
>>> with r1:
>>> res = model(torch.randn(1,1,1,1))
>>> r1.data_buffer
[tensor([[[[0.6734]]]]), tensor([[[[1.2514]]]]) ]
>>> r1.get_record_data(record_idx=1)
tensor([[[[1.2514]]]])
>>> r2 = ModuleOutputsRecorder('conv2')
>>> r2.initialize(model)
>>> with r2:
>>> res = model(torch.randn(1,1,1,1))
>>> r2.data_buffer
[tensor([[[[0.9534]]]])]
>>> r2.get_record_data()
tensor([[[[0.9534]]]])
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._recording = False
@property
def recording(self) -> bool:
"""bool: whether to record data in forward hook."""
return self._recording
def prepare_from_model(self, model: Optional[nn.Module] = None) -> None:
"""Register Pytorch forward hook to corresponding module."""
assert model is not None, 'model can not be None.'
founded = False
for name, module in model.named_modules():
if name == self.source:
module.register_forward_hook(self.forward_hook)
founded = True
break
assert founded, f'"{self.source}" is not in the model.'
def forward_hook(self, module: nn.Module, inputs: Tuple,
outputs: Any) -> None:
"""Save the module's forward output.
Args:
module (:obj:`torch.nn.Module`): The module to register hook.
inputs (tuple): The input of the module.
outputs : The output of the module.
"""
if self._recording:
self.data_buffer.append(outputs)
def __enter__(self):
"""Enter the context manager."""
super().__enter__()
self._recording = True
def __exit__(self, exc_type, exc_value, traceback):
"""Exit the context manager."""
super().__exit__(exc_type, exc_value, traceback)
self._recording = False

View File

@ -0,0 +1,57 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
from torch import nn
from mmrazor.registry import TASK_UTILS
from .base_recorder import BaseRecorder
@TASK_UTILS.register_module()
class ParameterRecorder(BaseRecorder):
"""Recorder for Pytorch model's parameters.
Examples:
>>> from torch import nn
>>> class ToyModel(nn.Module):
... def __init__(self):
... super().__init__()
... self.toy_conv = nn.Conv2d(1,1,1)
... def forward(self, x):
... return self.toy_conv(x)
>>> model = ToyModel()
>>> [ name for name,_ in model.named_parameters() ]
['toy_conv.weight', 'toy_conv.bias']
>>> recorder = ParameterRecorder('toy_conv.weight')
>>> recorder.initialize(model)
>>> recorder.data_buffer
[Parameter containing: tensor([[[[0.3244]]]], requires_grad=True)]
>>> recorder.get_record_data()
Parameter containing: tensor([[[[0.3244]]]], requires_grad=True)
"""
def prepare_from_model(self, model: Optional[nn.Module] = None) -> None:
"""Record the Pytorch model's parameters."""
assert model is not None, \
'model can not be None when use ParameterRecorder.'
founded = False
for param_name, param in model.named_parameters():
if param_name == self.source:
self.data_buffer.append(param)
founded = True
break
assert founded, f'"{self.source}" is not in the model.'
def reset_data_buffer(self):
"""Clear data in data_buffer.
Note:
The data_buffer stores the address of the parameter in memory and
does not need to be reset.
"""
pass

View File

@ -0,0 +1,117 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, Optional
from mmcv import ConfigDict
from torch import nn
from mmrazor.registry import TASK_UTILS
from .base_recorder import BaseRecorder
class RecorderManager:
"""Various types recorders' manager. The ``RecorderManager`` also is a
context manager, managing various types of Recorder. When entering the
``RecorderManager``, all recorders managed by it will be started.
Note:
The recorders will be initialized in the ``RecorderManager`` by
default. If you want to just use a recorder without the
``RecorderManager``, you need to initialize it first.
Args:
recorders (dict, optional): All recorders' config.
Examples:
>>> # Below code in toy_module.py
>>> import random
>>> class Toy():
... def toy_func(self):
... return random.randint(0, 1000)
>>> # Below code in main.py
>>> from torch import nn
>>> from toy_module import Toy
>>> class ToyModel(nn.Module):
... def __init__(self):
... super().__init__()
... self.conv1 = nn.Conv2d(1,1,1)
... self.conv2 = nn.Conv2d(1,1,1)
... self.toy = Toy()
... def forward(self, x):
... return self.conv2(self.conv1(x)) + self.toy.toy_func()
>>> model = ToyModel()
>>> [ name for name,_ in model.named_modules() ]
['conv1', 'conv2']
>>> from mmcv import ConfigDict
>>> recorders_cfg = ConfigDict(
...     conv1_rec=dict(type='ModuleOutputs', source='conv1'),
...     conv2_rec=dict(type='ModuleOutputs', source='conv2'),
...     func_rec=dict(type='MethodOutputs',
...                   source='toy_module.Toy.toy_func'))
>>> manager = RecorderManager(recorders_cfg)
>>> manager.initialize(model)
>>> with manager:
... res = model(torch.ones(1,1,1,1))
>>> res
tensor([[[[22.9534]]]])
>>> conv2_data = manager.get_recorder('conv2_rec').get_record_data()
>>> conv2_data
tensor([[[[0.9534]]]])
>>> func_data = manager.get_recorder('func_rec').get_record_data()
>>> func_data
22
>>> res.sum() == (conv2_data + func_data).sum()
True
"""
def __init__(self, recorders: Optional[ConfigDict] = None) -> None:
self._recorders: Dict[str, BaseRecorder] = dict()
if recorders:
for name, cfg in recorders.items():
recorder_cfg = copy.deepcopy(cfg)
recorder_type = cfg.type
recorder_type_ = recorder_type + 'Recorder'
recorder_cfg.type = recorder_type_
recorder = TASK_UTILS.build(recorder_cfg)
self._recorders[name] = recorder
@property
def recorders(self) -> Dict[str, BaseRecorder]:
"""dict: all recorders."""
return self._recorders
def get_recorder(self, recorder: str) -> BaseRecorder:
"""Get the corresponding recorder according to the name."""
return self.recorders[recorder]
def initialize(self, model: nn.Module):
"""Init all recorders.
Args:
model (nn.Module): The model which need to record intermediate
results.
"""
for recorder in self.recorders.values():
recorder.initialize(model)
def __enter__(self):
"""Enter the context manager."""
for recorder in self.recorders.values():
recorder.__enter__()
def __exit__(self, exc_type, exc_value, traceback):
"""Exit the context manager."""
for recorder in self.recorders.values():
recorder.__exit__(exc_type, exc_value, traceback)

View File

@ -1,5 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseAlgorithm
from .distill import (ConfigurableDistill, FpnTeacherDistill,
SingleTeacherDistill)
from .nas import SPOS
__all__ = ['BaseAlgorithm', 'SPOS']
__all__ = [
'SingleTeacherDistill', 'ConfigurableDistill', 'BaseAlgorithm',
'FpnTeacherDistill', 'SPOS'
]

View File

@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .configurable import (ConfigurableDistill, FpnTeacherDistill,
SingleTeacherDistill)
__all__ = ['SingleTeacherDistill', 'ConfigurableDistill', 'FpnTeacherDistill']

View File

@ -0,0 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .configurable_distill import ConfigurableDistill
from .fpn_teacher_distill import FpnTeacherDistill
from .single_teacher_distill import SingleTeacherDistill
__all__ = [
'SingleTeacherDistill', 'FpnTeacherDistill', 'ConfigurableDistill'
]

View File

@ -0,0 +1,248 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from inspect import signature
from typing import Dict, List, Optional, Union
from mmengine.model import BaseModel
from torch import nn
from mmrazor.core import DistillDeliveryManager, RecorderManager
from mmrazor.registry import MODELS
from ...base import BaseAlgorithm, LossResults
@MODELS.register_module()
class ConfigurableDistill(BaseAlgorithm):
"""``ConfigurableDistill`` is a powerful tool that can reproduce most
distillation algorithms without modifying the code of teacher or student
models.
``ConfigurableDistill`` can get various intermediate results of the model
in a hacky way through ``Recorder``. See the user docs for ``Recorder`` for
more details.
``ConfigurableDistill`` can use the teacher's intermediate results to
override the student's intermediate results in a hacky way through
``Delivery``. See the user docs for ``Delivery`` for more details.
Args:
architecture (dict | :obj:`BaseModel`): The config of
:class:`BaseModel` or built model.
student_recorders (dict, optional): Config for multiple recorders. A
student model may have more than one recorder. These recorders
only record the student model's intermediate results. Defaults to
None.
teacher_recorders (dict, optional): Config for multiple recorders. A
teacher model may have more than one recorder. These recorders
only record the teacher model's intermediate results. Defaults to
None.
distill_deliveries (dict, optional): Config for multiple deliveries. A
distill algorithm may have more than one delivery. Defaults to
None.
distill_losses (Dict[str, Dict], optional): Config for multiple
distill losses. A distill algorithm may have more than one distill
loss. Defaults to None.
loss_forward_mappings (Dict[str, Dict], optional): Mapping between
distill loss forward arguments and records. Defaults to None.
data_preprocessor (:obj:`BaseDataPreprocessor`): Used for
pre-processing data sampled by dataloader to the format accepted by
:meth:`forward`.
init_cfg (dict, optional): Initialization config dict.
Note:
If a distill loss needs to be backpropagated, its name must contain
"loss". If it is only used as a statistical value, its name must not
contain "loss". See the docs for
:func:`mmengine.model.BaseModel._parse_loss` for more details.
Note:
The keys of ``loss_forward_mappings`` should be consistent with the
keys of ``distill_losses``.
Each item in ``loss_forward_mappings`` is a mapping between a distill
loss and its forward arguments. The keys of the mapping are the
signature of the loss's forward, and the values of the mapping are the
recorded data location.
``recorder`` refers to the recorder where the data is stored, and
if ``from_student`` is True, it means the recorder is in
``student_recorders``; otherwise, it means the recorder is in
``teacher_recorders``.
Examples:
>>> distill_losses = dict(
... loss_kl=dict(type='KLDivergence', tau=1, loss_weight=5))
>>> student_recorders = dict(
...     fc=dict(type='ModuleOutputs', source='head.fc'))
>>> teacher_recorders = dict(
...     fc=dict(type='ModuleOutputs', source='head.fc'))
>>> loss_forward_mappings = dict(
...     loss_kl=dict(
...         preds_S=dict(recorder='fc', from_student=True),
...         preds_T=dict(recorder='fc', from_student=False)))
"""
def __init__(self,
architecture: Union[BaseModel, Dict],
student_recorders: Optional[Dict[str, Dict]] = None,
teacher_recorders: Optional[Dict[str, Dict]] = None,
distill_deliveries: Optional[Dict[str, Dict]] = None,
distill_losses: Optional[Dict[str, Dict]] = None,
loss_forward_mappings: Optional[Dict[str, Dict]] = None,
data_preprocessor: Optional[Union[dict, nn.Module]] = None,
init_cfg: Optional[dict] = None):
super().__init__(architecture, data_preprocessor, init_cfg)
# The recorder managers are just constructed here, but not really
# initialized yet. Initializing a recorder manager requires the
# corresponding model.
# Different subclasses may have different teacher models, so it is
# inconvenient to initialize the recorder managers in
# ``ConfigurableDistill``.
# During the initialization of a subclass,
# `self.student_recorders.initialize(student)` and
# `self.teacher_recorders.initialize(teacher)` need to be executed with
# the corresponding student and teacher.
self.student_recorders = RecorderManager(student_recorders)
self.teacher_recorders = RecorderManager(teacher_recorders)
self.distill_deliveries = DistillDeliveryManager(distill_deliveries)
self.distill_losses = self.build_distill_losses(distill_losses)
if loss_forward_mappings:
# Check if loss_forward_mappings is in the correct format
self._check_loss_forward_mappings(self.distill_losses,
loss_forward_mappings,
self.student_recorders,
self.teacher_recorders)
self.loss_forward_mappings = loss_forward_mappings
else:
self.loss_forward_mappings = dict()
@property
def student(self) -> BaseModel:
"""Alias for ``architecture``."""
return self.architecture
def build_distill_losses(
self,
losses: Optional[Dict[str, Dict]] = None,
) -> nn.ModuleDict:
"""build distill losses according config."""
distill_losses = nn.ModuleDict()
if losses:
for loss_name, loss_cfg in losses.items():
assert loss_name not in distill_losses
if 'loss' not in loss_name:
warnings.warn(
f'Warning: If {loss_name} is a loss that needs to be '
f'backpropagated, the name of {loss_name} must contain '
f'"loss". If it is only used as a statistical value, '
'then the name must not contain "loss". More details '
'see docs for '
':func:`mmengine.model.BaseModel._parse_loss`',
UserWarning)
item_loss = MODELS.build(loss_cfg)
distill_losses[loss_name] = item_loss
return distill_losses
def get_record(self,
recorder: str,
from_student: bool,
record_idx: int = 0,
data_idx: Optional[int] = None) -> List:
"""According to each item in ``record_infos``, get the corresponding
record in ``recorder_manager``."""
if from_student:
recorder_ = self.student_recorders.get_recorder(recorder)
else:
recorder_ = self.teacher_recorders.get_recorder(recorder)
return recorder_.get_record_data(record_idx, data_idx)
def compute_distill_losses(
self,
distill_losses: nn.ModuleDict,
loss_forward_mappings: Dict[str, Dict],
student_recorders: RecorderManager,
teacher_recorders: RecorderManager,
) -> LossResults:
"""Compute distill losses automatically."""
# Record all computed losses' results.
losses = dict()
for loss_name, forward_mappings in loss_forward_mappings.items():
forward_kwargs = dict()
for forward_key, record_info in forward_mappings.items():
forward_var = self.get_record(**record_info)
forward_kwargs[forward_key] = forward_var
loss_module = distill_losses[loss_name]
loss = loss_module(**forward_kwargs) # type: ignore
# add computed loss result.
losses[loss_name] = loss
return losses
def _check_loss_forward_mappings(
self, losses: nn.ModuleDict, loss_forward_mappings: Dict[str,
Dict],
student_recorders: RecorderManager,
teacher_recorders: RecorderManager) -> None:
"""Check if ``loss_forward_mappings`` is in the correct format."""
if not isinstance(loss_forward_mappings, dict):
raise TypeError(
'loss_forward_mappings should be a dict instance, but got'
f'{type(loss_forward_mappings)}')
for loss_name, forward_mappings in loss_forward_mappings.items():
assert loss_name in losses, \
f'"{loss_name}" is not in distill losses. The keys of ' \
'loss_forward_mappings must match the keys of distill_losses.'
if not isinstance(forward_mappings, dict):
raise TypeError(
'Each item of loss_forward_mappings should be a dict '
f'instance, but got {type(forward_mappings)}')
loss_module = losses[loss_name]
loss_forward_keys = signature(
loss_module.forward).parameters.keys()
assert len(loss_forward_keys) == len(forward_mappings.keys())
for forward_key, record_info in forward_mappings.items():
assert forward_key in loss_forward_keys, \
f'{forward_key} is not in the signature of \
{type(loss_module).__name__} forward, \
please check your config.'
assert 'recorder' in record_info, \
'Each item of loss_forward_mappings should have ' \
'"recorder", pls check your config.'
assert 'from_student' in record_info, \
'Each item of loss_forward_mappings should have ' \
'"from_student", pls check your config.'
recorder: str = record_info['recorder']
from_student: bool = record_info['from_student']
if not isinstance(from_student, bool):
raise TypeError(f'from_student should be a bool instance, '
f'but got {type(from_student)}')
if from_student:
assert recorder in self.student_recorders.recorders, \
f'For {forward_key}, "{recorder}" must be in \
`student_recorders`.'
else:
assert recorder in self.teacher_recorders.recorders, \
f'For {forward_key}, "{recorder}" must be in \
`teacher_recorders`.'
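As a compact illustration of what ``compute_distill_losses`` and ``get_record`` do with a ``loss_forward_mappings`` entry, here is a pure-Python sketch with toy data (the ``recorded`` dict and the ``get_record`` stand-in are hypothetical):

recorded = {
    ('fc', True): [[0.1, 0.9]],    # student 'fc' data_buffer
    ('fc', False): [[0.2, 0.8]],   # teacher 'fc' data_buffer
}

def get_record(recorder, from_student, record_idx=0, data_idx=None):
    record = recorded[(recorder, from_student)][record_idx]
    return record if data_idx is None else record[data_idx]

loss_forward_mappings = dict(
    loss_kl=dict(
        preds_S=dict(recorder='fc', from_student=True),
        preds_T=dict(recorder='fc', from_student=False)))

# For each loss, every mapping entry is resolved into a keyword argument
# of that loss's forward().
for loss_name, mappings in loss_forward_mappings.items():
    kwargs = {key: get_record(**info) for key, info in mappings.items()}
    print(loss_name, kwargs)
# -> loss_kl {'preds_S': [0.1, 0.9], 'preds_T': [0.2, 0.8]}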

View File

@ -0,0 +1,57 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional
import torch
from mmengine import BaseDataElement
from mmrazor.models.utils import add_prefix
from mmrazor.registry import MODELS
from ...base import LossResults
from .single_teacher_distill import SingleTeacherDistill
@MODELS.register_module()
class FpnTeacherDistill(SingleTeacherDistill):
"""``FpnTeacherDistill`` means teacher only execute backbone and neck.
If the intermediate results required for distill algorithm are generated by
the backbone and neck parts, using ``FpnTeacherDistill`` can speed up
training.
"""
def loss(
self,
batch_inputs: torch.Tensor,
data_samples: Optional[List[BaseDataElement]] = None,
) -> LossResults:
"""Calculate losses from a batch of inputs and data samples."""
losses = dict()
# If the `override_data` of a delivery is False, the delivery will
# record the origin data.
self.distill_deliveries.override_data = False
if self.teacher_trainable:
# Unlike ``SingleTeacherDistill``, the teacher only executes
# backbone + neck, not the head, so there will be no loss.
with self.teacher_recorders, self.distill_deliveries:
_ = self.teacher.extract_feat(batch_inputs)
else:
with self.teacher_recorders, self.distill_deliveries:
with torch.no_grad():
_ = self.teacher(batch_inputs, data_samples, mode='loss')
# If the `override_data` of a delivery is True, the delivery will
# override the origin data with the recorded data.
self.distill_deliveries.override_data = True
with self.student_recorders, self.distill_deliveries:
student_losses = self.student(
batch_inputs, data_samples, mode='loss')
losses.update(add_prefix(student_losses, 'student'))
# Automatically compute distill losses based on `loss_forward_mappings`
distill_losses = self.compute_distill_losses(
self.distill_losses, self.loss_forward_mappings,
self.student_recorders, self.teacher_recorders)
losses.update(add_prefix(distill_losses, 'distill'))
return losses
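The speed-up comes from skipping the teacher's head when the distill losses only consume neck outputs. A toy illustration (``ToyDetector`` is hypothetical and not mmdet's API; real detectors expose an analogous ``extract_feat``):

import torch
from torch import nn

class ToyDetector(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Conv2d(3, 8, 3, padding=1)
        self.neck = nn.Conv2d(8, 8, 1)
        self.head = nn.Conv2d(8, 80, 1)   # the part FpnTeacherDistill skips

    def extract_feat(self, x):
        return self.neck(self.backbone(x))

    def forward(self, x):
        return self.head(self.extract_feat(x))

teacher = ToyDetector()
feats = teacher.extract_feat(torch.randn(1, 3, 32, 32))  # no head forward
print(feats.shape)  # torch.Size([1, 8, 32, 32])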

View File

@ -0,0 +1,107 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union
import torch
from mmcv.runner import load_checkpoint
from mmengine import BaseDataElement
from mmengine.model import BaseModel
from torch.nn.modules.batchnorm import _BatchNorm
from mmrazor.models.utils import add_prefix
from mmrazor.registry import MODELS
from ...base import LossResults
from .configurable_distill import ConfigurableDistill
@MODELS.register_module()
class SingleTeacherDistill(ConfigurableDistill):
"""``SingleTeacherDistill`` can be used to develop distill algorithms which
only use one teacher.
Args:
teacher (dict | BaseModel): The config dict for teacher model or built
teacher model.
teacher_ckpt (str): The path of teacher's checkpoint. Defaults to None.
teacher_trainable (bool): Whether the teacher is trainable. Defaults
to False.
teacher_norm_eval (bool): Whether to set teacher's norm layers to eval
mode, namely, freeze running stats (mean and var). Note: Effect on
Batch Norm and its variants only. Defaults to True.
"""
def __init__(self,
teacher: Union[BaseModel, Dict],
teacher_ckpt: Optional[str] = None,
teacher_trainable: bool = False,
teacher_norm_eval: bool = True,
**kwargs):
super().__init__(**kwargs)
if isinstance(teacher, Dict):
teacher = MODELS.build(teacher)
if not isinstance(teacher, BaseModel):
raise TypeError('teacher should be a `dict` or '
f'`BaseModel` instance, but got '
f'{type(teacher)}')
self.teacher = teacher
if teacher_ckpt:
# avoid loaded parameters be overwritten
self.teacher.init_weights()
_ = load_checkpoint(self.teacher, teacher_ckpt)
self.teacher_trainable = teacher_trainable
self.teacher_norm_eval = teacher_norm_eval
# In ``ConfigurableDistill``, the recorder managers are just constructed,
# but not really initialized yet.
self.student_recorders.initialize(self.student)
self.teacher_recorders.initialize(self.teacher)
def loss(
self,
batch_inputs: torch.Tensor,
data_samples: Optional[List[BaseDataElement]] = None,
) -> LossResults:
"""Calculate losses from a batch of inputs and data samples."""
losses = dict()
# If the `override_data` of a delivery is False, the delivery will
# record the origin data.
self.distill_deliveries.override_data = False
if self.teacher_trainable:
with self.teacher_recorders, self.distill_deliveries:
teacher_losses = self.teacher(
batch_inputs, data_samples, mode='loss')
losses.update(add_prefix(teacher_losses, 'teacher'))
else:
with self.teacher_recorders, self.distill_deliveries:
with torch.no_grad():
_ = self.teacher(batch_inputs, data_samples, mode='loss')
# If the `override_data` of a delivery is True, the delivery will
# override the origin data with the recorded data.
self.distill_deliveries.override_data = True
with self.student_recorders, self.distill_deliveries:
student_losses = self.student(
batch_inputs, data_samples, mode='loss')
losses.update(add_prefix(student_losses, 'student'))
# Automatically compute distill losses based on `loss_forward_mappings`
distill_losses = self.compute_distill_losses(
self.distill_losses, self.loss_forward_mappings,
self.student_recorders, self.teacher_recorders)
losses.update(add_prefix(distill_losses, 'distill'))
return losses
def train(self, mode=True):
"""Set distiller's forward mode."""
super().train(mode)
if mode and self.teacher_norm_eval:
for m in self.teacher.modules():
if isinstance(m, _BatchNorm):
m.eval()

View File

@ -1,5 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .self_distiller import SelfDistiller
from .single_teacher import SingleTeacherDistiller
__all__ = ['SelfDistiller', 'SingleTeacherDistiller']

View File

@ -1,155 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from mmcv.runner import BaseModule
from mmcv.utils import import_modules_from_strings
def function_wrapper(ctx, method, method_str):
"""Pass teacher's outputs to student."""
def wrapper(*args, **kwargs):
# record inputs
ctx.method_args[method_str] = args
ctx.method_kwargs[method_str] = kwargs
# TODO cover more usecases, not only pass teacher's outputs to
# student.
if ctx.is_teacher:
# execute the raw function
outputs = method(*args, **kwargs)
# record outputs
ctx.method_return[method_str] = outputs
else:
# modify student's outputs to be same with teacher
outputs = ctx.method_return[method_str]
return outputs
return wrapper
class FunctionContext():
"""Function context manager for rewrite function.
Args:
ctx (ConversionContext): The distiller's overall context manager.
method (str): The name of the function to rewrite.
"""
def __init__(self, ctx, method, import_module=None):
self.ctx = ctx
self.import_module = import_modules_from_strings(import_module)
self.method_str = method
self.method_exec_str = f'self.import_module.{method}'
def _set_method(self, method):
"""Modify a function."""
exec(f'{self.method_exec_str} = method')
def __enter__(self):
"""Rewrite the function."""
self.method_impl = eval(self.method_exec_str)
if self.method_impl:
self._set_method(
function_wrapper(self.ctx, self.method_impl, self.method_str,
self.align_mode))
def __exit__(self, exc_type, exc_value, traceback):
"""Restore the function."""
if self.method_impl:
self._set_method(self.method_impl)
class ConversionContext():
"""Context manager for record functions' inputs or outputs."""
def __init__(self, hooks):
# save functions' inputs
self.method_args = dict()
self.method_kwargs = dict()
# save functions' outputs
self.method_return = dict()
# Each function will have a sub context manager, the function will be
# rewritten when enter the sub context manager.
self.hooks = []
self.is_teacher = True
for hook in hooks:
self.hooks.append(FunctionContext(self, **hook))
def __enter__(self):
"""Enter every sub context managers."""
for hook in self.hooks:
hook.__enter__()
return self
def __exit__(self, exc_type, exc_value, traceback):
"""Exit every sub context managers."""
for hook in self.hooks:
hook.__exit__(exc_type, exc_value, traceback)
class BaseDistiller(BaseModule, metaclass=ABCMeta):
"""Base Distiller.
In the distillation algorithm, some intermediate results of the teacher
need to be obtained and passed to the student.
For nn.Module's outputs, obtained by pytorch forward hook.
For python function's outputs, obtained by a specific context manager.
Args:
align_methods (dict): The details of the functions which outputs need
to be obtained.
"""
def __init__(self, align_methods=None, **kwargs):
super(BaseDistiller, self).__init__(**kwargs)
if align_methods is None:
self.context_manager = None
else:
# To obtain the python function's outputs, a specific context
# manager is built. When entering the context manager, the
# functions are rewritten. The context manager can record the
# inputs or outputs of the functions and pass them from teacher
# to student. When exiting the context manager, the rewritten
# functions are restored.
self.context_manager = ConversionContext(align_methods)
@abstractmethod
def prepare_from_student(self, student):
"""Register forward hooks to students and teachers."""
pass
@abstractmethod
def teacher_forward_output_hook(self, module, inputs, outputs):
"""Save the teacher output."""
pass
@abstractmethod
def student_forward_output_hook(self, module, inputs, outputs):
"""Save the student output."""
pass
def reset_ctx_teacher_mode(self, mode=True):
if self.context_manager is not None:
self.context_manager.is_teacher = mode
@abstractmethod
def exec_teacher_forward(self, data):
"""Execute the teacher's forward function."""
pass
@abstractmethod
def exec_student_forward(self, student, data):
"""Execute the student's forward function."""
pass
@abstractmethod
def compute_distill_loss(self, data):
"""Compute distill loss according teacher's outputs and student's
outputs."""
pass

View File

@ -1,141 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmrazor.registry import MODELS
from .base import BaseDistiller
@MODELS.register_module()
class SelfDistiller(BaseDistiller):
"""Transfer knowledge inside a single model.
Args:
components (dict): The details of the distillation. It usually includes
the module names of the teacher and the student, and the losses
used in the distillation.
"""
def __init__(self, components, **kwargs):
super().__init__(**kwargs)
self.components = components
self.losses = nn.ModuleDict()
self.student_outputs = dict()
self.teacher_outputs = dict()
for component in self.components:
student_module_name = component['student_module']
teacher_module_name = component['teacher_module']
self.student_outputs[student_module_name] = list()
self.teacher_outputs[teacher_module_name] = list()
for loss in component.losses:
loss_cfg = loss.copy()
loss_name = loss_cfg.pop('name')
self.losses[loss_name] = MODELS.build(loss_cfg)
def prepare_from_student(self, student):
"""Registers a global forward hook for each teacher module and student
module to be used in the distillation.
Args:
student (:obj:`torch.nn.Module`): The student model to be used
in the distillation.
"""
self.module2name = {}
for name, module in student.model.named_modules():
self.module2name[module] = name
self.name_modules = dict(student.model.named_modules())
for component in self.components:
student_module_name = component['student_module']
teacher_module_name = component['teacher_module']
student_module = self.name_modules[student_module_name]
teacher_module = self.name_modules[teacher_module_name]
student_module.register_forward_hook(
self.student_forward_output_hook)
teacher_module.register_forward_hook(
self.teacher_forward_output_hook)
def teacher_forward_output_hook(self, module, inputs, outputs):
"""Save the output.
Args:
module (:obj:`torch.nn.Module`): the module of register hook
inputs (tuple): input of module
outputs (tuple): out of module
"""
if self.training and getattr(self, 'is_teacher', None):
self.teacher_outputs[self.module2name[module]].append(outputs)
def student_forward_output_hook(self, module, inputs, outputs):
"""Save the output.
Args:
module (:obj:`torch.nn.Module`): the module of register hook
inputs (tuple): input of module
outputs (tuple): out of module
"""
if self.training and not getattr(self, 'is_teacher', None):
self.student_outputs[self.module2name[module]].append(outputs)
def reset_outputs(self, outputs):
"""Reset the teacher's outputs or student's outputs."""
for key in outputs.keys():
outputs[key] = list()
def exec_teacher_forward(self, teacher, data):
"""Forward computation of the teacher.
Args:
teacher (:obj:`torch.nn.Module`): The teacher model to be used
in the distillation.
data (dict): The output of dataloader.
"""
self.reset_outputs(self.teacher_outputs)
self.is_teacher = True
output = teacher(**data)
self.is_teacher = False
return output
def exec_student_forward(self, student, data):
"""Forward computation of the student.
Args:
student (:obj:`torch.nn.Module`): The student model to be used
in the distillation.
data (dict): The output of dataloader.
"""
assert not self.is_teacher
self.reset_outputs(self.student_outputs)
output = student(**data)
return output
def compute_distill_loss(self, data):
"""Compute the distillation loss."""
losses = dict()
for i, component in enumerate(self.components):
student_module_name = component['student_module']
student_outputs = self.student_outputs[student_module_name]
teacher_module_name = component['teacher_module']
teacher_outputs = self.teacher_outputs[teacher_module_name]
for out_idx, (s_out, t_out) in enumerate(
zip(student_outputs, teacher_outputs)):
for loss in component.losses:
loss_module = self.losses[loss.name]
loss_name = f'{loss.name}.{out_idx}'
loss_module.current_data = data
losses[loss_name] = loss_module(s_out, t_out)
loss_module.current_data = None
return losses
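As a hedged illustration of the ``components`` argument consumed by ``SelfDistiller``, the sketch below shows one plausible entry. The module names and the loss settings are hypothetical and must match the wrapped model; in practice the entries are mmcv ConfigDicts, since the distiller accesses ``component.losses`` by attribute:

# Hypothetical self-distillation setup: the main head supervises an
# auxiliary head inside the same model. `name` becomes the key in
# `self.losses`; `type` is built through the MODELS registry.
components = [
    dict(
        student_module='head.aux_fc',   # hypothetical shallow/auxiliary head
        teacher_module='head.fc',       # hypothetical main head
        losses=[
            dict(name='loss_self_kd', type='KLDivergence', tau=1,
                 loss_weight=1.0)
        ])
]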

View File

@ -1,243 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm
from mmrazor.registry import MODELS
from .base import BaseDistiller
@MODELS.register_module()
class SingleTeacherDistiller(BaseDistiller):
"""Distiller with single teacher.
Args:
teacher (dict): The config dict for teacher.
teacher_trainable (bool): Whether the teacher is trainable.
Default: False.
teacher_norm_eval (bool): Whether to set teacher's norm layers to eval
mode, namely, freeze running stats (mean and var). Note: Effect on
Batch Norm and its variants only. Default: True.
components (tuple[dict]): The details of each distillation component.
Each item usually includes the module names of the teacher and the
student, and the losses used in the distillation. Defaults to an
empty tuple.
"""
def __init__(self,
teacher,
teacher_trainable=False,
teacher_norm_eval=True,
components=tuple(),
**kwargs):
super().__init__(**kwargs)
self.teacher_trainable = teacher_trainable
self.teacher_norm_eval = teacher_norm_eval
self.teacher = self.build_teacher(teacher)
self.components = components
self.losses = nn.ModuleDict()
self.align_modules = nn.ModuleDict()
# Record the featuremaps that need to calculate the distillation loss.
self.student_outputs = dict()
self.teacher_outputs = dict()
for i, component in enumerate(self.components):
student_module_name = component['student_module']
teacher_module_name = component['teacher_module']
# The type of every student_output is a list by default, because
# some modules may execute multiple forward passes, such as the
# shared head in RetinaNet.
self.student_outputs[student_module_name] = list()
self.teacher_outputs[teacher_module_name] = list()
# If the numbers of feature-map channels of the student and the
# teacher are inconsistent, they need to be aligned by a 1x1
# convolution.
align_module_cfg = getattr(component, 'align_module', None)
if align_module_cfg is not None:
align_module_name = f'component_{i}'
align_module = self.build_align_module(align_module_cfg)
self.align_modules[align_module_name] = align_module
# Multiple losses can be calculated at the same location
for loss in component.losses:
loss_cfg = loss.copy()
loss_name = loss_cfg.pop('name')
self.losses[loss_name] = MODELS.build(loss_cfg)
def build_teacher(self, cfg):
"""Build a model from the `cfg`."""
teacher = MODELS.build(cfg)
return teacher
def build_align_module(self, cfg):
"""Build ``align_module`` from the `cfg`.
``align_module`` is needed when the number of channels output by the
teacher module is not equal to that of the student module, or for some
other reasons.
Args:
cfg (dict): The config dict for ``align_module``.
"""
in_channels = cfg.student_channels
out_channels = cfg.teacher_channels
if cfg.type == 'conv2d':
align_module = nn.Conv2d(in_channels, out_channels, 1)
elif cfg.type == 'linear':
align_module = nn.Linear(in_channels, out_channels)
return align_module
def prepare_from_student(self, student):
"""Registers a global forward hook for each teacher module and student
module to be used in the distillation.
Args:
student (:obj:`torch.nn.Module`): The student model to be used
in the distillation.
"""
# Record the mapping relationship between student's modules and module
# names.
self.student_module2name = {}
for name, module in student.model.named_modules():
self.student_module2name[module] = name
self.student_name2module = dict(student.model.named_modules())
# Record the mapping relationship between teacher's modules and module
# names.
self.teacher_module2name = {}
for name, module in self.teacher.named_modules():
self.teacher_module2name[module] = name
self.teacher_name2module = dict(self.teacher.named_modules())
# Register forward hooks for modules that need to participate in loss
# calculation.
for component in self.components:
student_module_name = component['student_module']
teacher_module_name = component['teacher_module']
student_module = self.student_name2module[student_module_name]
teacher_module = self.teacher_name2module[teacher_module_name]
student_module.register_forward_hook(
self.student_forward_output_hook)
teacher_module.register_forward_hook(
self.teacher_forward_output_hook)
def teacher_forward_output_hook(self, module, inputs, outputs):
"""Save the module's forward output.
Args:
module (:obj:`torch.nn.Module`): The module to register hook.
inputs (tuple): The input of the module.
outputs (tuple): The output of the module.
"""
if self.training:
self.teacher_outputs[self.teacher_module2name[module]].append(
outputs)
def student_forward_output_hook(self, module, inputs, outputs):
"""Save the module's forward output.
Args:
module (:obj:`torch.nn.Module`): The module to register hook.
inputs (tuple): The input of the module.
outputs (tuple): The output of the module.
"""
if self.training:
self.student_outputs[self.student_module2name[module]].append(
outputs)
def reset_outputs(self, outputs):
"""Reset the teacher's outputs or student's outputs."""
for key in outputs.keys():
outputs[key] = list()
def exec_teacher_forward(self, data):
"""Execute the teacher's forward function.
After this function, the teacher's featuremaps will be saved in
``teacher_outputs``.
"""
# Convert the context manager's mode to teacher.
self.reset_ctx_teacher_mode(True)
# Clear the saved data of the last forward.
self.reset_outputs(self.teacher_outputs)
if self.teacher_trainable:
output = self.teacher(**data)
else:
with torch.no_grad():
output = self.teacher(**data)
return output
def exec_student_forward(self, student, data):
"""Execute the teacher's forward function.
After this function, the student's featuremaps will be saved in
``student_outputs``.
"""
# Convert the context manager's mode to student.
self.reset_ctx_teacher_mode(False)
# Clear the saved data of the last forward.
self.reset_outputs(self.student_outputs)
output = student(**data)
return output
def train(self, mode=True):
"""Set distiller's forward mode."""
super(SingleTeacherDistiller, self).train(mode)
if mode and self.teacher_norm_eval:
for m in self.teacher.modules():
if isinstance(m, _BatchNorm):
m.eval()
def get_teacher_outputs(self, teacher_module_name):
"""Get the outputs according module name."""
return self.teacher_outputs[teacher_module_name]
def compute_distill_loss(self, data=None):
"""Compute the distillation loss."""
losses = dict()
for i, component in enumerate(self.components):
# Get the student's outputs.
student_module_name = component['student_module']
student_outputs = self.student_outputs[student_module_name]
# Align student output's channels with teacher.
align_module_name = f'component_{i}'
if align_module_name in self.align_modules:
align_module = self.align_modules[align_module_name]
student_outputs = [
align_module(s_out) for s_out in student_outputs
]
# Get the teacher's outputs.
teacher_module_name = component['teacher_module']
teacher_outputs = self.get_teacher_outputs(teacher_module_name)
# One module may have N outputs, such as the shared head in
# RetinaNet.
for out_idx, (s_out, t_out) in enumerate(
zip(student_outputs, teacher_outputs)):
for loss in component.losses:
loss_module = self.losses[loss.name]
loss_name = f'{loss.name}.{out_idx}'
# TODO ugly implementation.
# Pass the gt_label to loss function.
# Only used by WSLD.
loss_module.current_data = data
losses[loss_name] = loss_module(s_out, t_out)
loss_module.current_data = None
return losses
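As a hedged sketch of one ``components`` entry for ``SingleTeacherDistiller``, including the channel alignment discussed above, the snippet below uses hypothetical module names and channel counts; they must match the actual student and teacher models:

components = [
    dict(
        # Hypothetical module names inside the student and teacher models.
        student_module='neck.fpn_convs.0.conv',
        teacher_module='neck.fpn_convs.0.conv',
        # A 1x1 conv maps the student's 128 channels to the teacher's 256
        # before the loss is computed.
        align_module=dict(
            type='conv2d', student_channels=128, teacher_channels=256),
        # Multiple losses can be attached to the same location.
        losses=[
            dict(name='loss_cwd_neck', type='ChannelWiseDivergence',
                 tau=1, loss_weight=5)
        ])
]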

View File

@ -26,9 +26,9 @@ class KLDivergence(nn.Module):
def __init__(
self,
tau=1.0,
reduction='batchmean',
loss_weight=1.0,
tau: float = 1.0,
reduction: str = 'batchmean',
loss_weight: float = 1.0,
):
super(KLDivergence, self).__init__()
self.tau = tau
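For context, a minimal sketch of the temperature-scaled KL distillation loss that ``tau``, ``reduction`` and ``loss_weight`` configure; this is the standard formulation, not a verbatim copy of the mmrazor implementation:

import torch.nn.functional as F

def kl_distill_loss(preds_S, preds_T, tau=1.0, loss_weight=1.0):
    # Softened teacher probabilities supervise softened student
    # log-probabilities; the result is rescaled by tau**2 as usual in KD.
    softmax_T = F.softmax(preds_T / tau, dim=1)
    log_softmax_S = F.log_softmax(preds_S / tau, dim=1)
    loss = F.kl_div(log_softmax_S, softmax_T, reduction='batchmean')
    return loss_weight * (tau ** 2) * loss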

View File

@ -24,12 +24,20 @@ class WSLD(nn.Module):
self.tau = tau
self.loss_weight = loss_weight
self.num_classes = num_classes
self.softmax = nn.Softmax(dim=1).cuda()
self.logsoftmax = nn.LogSoftmax(dim=1).cuda()
self.softmax = nn.Softmax(dim=1)
self.logsoftmax = nn.LogSoftmax(dim=1)
def forward(self, student, teacher):
def forward(self, student, teacher, data_samples):
gt_labels = self.current_data['gt_label']
# Unpack data samples and pack targets
if 'score' in data_samples[0].gt_label:
# Batch augmentation may convert labels to one-hot format scores.
gt_labels = torch.stack([i.gt_label.score for i in data_samples])
one_hot_labels = gt_labels.float()
else:
gt_labels = torch.hstack([i.gt_label.label for i in data_samples])
one_hot_labels = F.one_hot(
gt_labels, num_classes=self.num_classes).float()
student_logits = student / self.tau
teacher_logits = teacher / self.tau
@ -49,7 +57,7 @@ class WSLD(nn.Module):
ce_loss_t = -torch.sum(one_hot_labels * log_softmax_t, 1, keepdim=True)
focal_weight = ce_loss_s / (ce_loss_t + 1e-7)
ratio_lower = torch.zeros(1).cuda()
ratio_lower = torch.zeros_like(focal_weight)
focal_weight = torch.max(focal_weight, ratio_lower)
focal_weight = 1 - torch.exp(-focal_weight)
ce_loss = focal_weight * ce_loss
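A hedged illustration of the two label formats handled by the new ``data_samples`` unpacking above; the attribute layout follows mmcls-style data samples and may differ across versions:

import torch
import torch.nn.functional as F

# Hard labels (one integer per sample) are expanded with F.one_hot ...
hard_labels = torch.tensor([3, 1])
one_hot = F.one_hot(hard_labels, num_classes=5).float()  # shape (2, 5)

# ... whereas batch-augmented samples already carry soft `score` vectors,
# which are simply stacked and used as-is.
soft_scores = torch.stack([torch.tensor([0., 0., 0.3, 0.7, 0.]),
                           torch.tensor([0., 1., 0., 0., 0.])])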

View File

@ -38,10 +38,9 @@ def build_razor_model_from_cfg(
default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any:
# TODO relay on mmengine:HAOCHENYE/config_new_feature
# if cfg.get('cfg_path', None) and not cfg.get('type', None):
# from mmengine.config import get_model
# teacher = get_model(**cfg)
# return teacher
if cfg.get('cfg_path', None) and not cfg.get('type', None):
from mmengine.config import get_model
return get_model(**cfg)
if cfg.get('_fix_subnet_', None):
fix_subnet = cfg.pop('_fix_subnet_')
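A hedged usage sketch of the newly enabled branch above: when a config carries ``cfg_path`` and no ``type``, it is forwarded to ``get_model`` from mmengine, which builds a model from another OpenMMLab repo's config file. The config path below is illustrative and requires the referenced repo to be installed:

teacher_cfg = dict(
    cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=False)
# Inside build_razor_model_from_cfg this effectively reduces to:
#     from mmengine.config import get_model
#     model = get_model(**teacher_cfg)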

View File

@ -3,7 +3,7 @@ from typing import Dict, List, Sequence, Union
import torch
from mmengine.evaluator import Evaluator
from mmengine.runner import ValLoop
from mmengine.runner import ValLoop, autocast
from torch.utils.data import DataLoader
from mmrazor.registry import LOOPS
@ -19,16 +19,29 @@ class SingleTeacherDistillValLoop(ValLoop):
dataloader (Dataloader or dict): A dataloader object or a dict to
build a dataloader.
evaluator (Evaluator or dict or list): Used for computing metrics.
fp16 (bool): Whether to enable fp16 validation. Defaults to
False.
"""
def __init__(self, runner, dataloader: Union[DataLoader, Dict],
evaluator: Union[Evaluator, Dict, List]) -> None:
super().__init__(runner, dataloader, evaluator)
def __init__(self,
runner,
dataloader: Union[DataLoader, Dict],
evaluator: Union[Evaluator, Dict, List],
fp16: bool = False) -> None:
super().__init__(runner, dataloader, evaluator, fp16)
if self.runner.distributed:
self.model = runner.model.module
assert hasattr(self.runner.model.module, 'teacher')
# TODO: remove hard code after mmcls add data_preprocessor
data_preprocessor = self.runner.model.module.data_preprocessor
self.teacher = self.runner.model.module.teacher
self.teacher.data_preprocessor = data_preprocessor
else:
self.model = runner.model
assert hasattr(self.model, 'teacher')
assert hasattr(self.runner.model, 'teacher')
# TODO: remove hard code after mmcls add data_preprocessor
data_preprocessor = self.runner.model.data_preprocessor
self.teacher = self.runner.model.teacher
self.teacher.data_preprocessor = data_preprocessor
def run(self):
"""Launch validation."""
@ -38,19 +51,32 @@ class SingleTeacherDistillValLoop(ValLoop):
for idx, data_batch in enumerate(self.dataloader):
self.run_iter(idx, data_batch)
# compute metrics
metrics_s = self.evaluator.evaluate(len(self.dataloader.dataset))
for key, value in metrics_s.items():
self.runner.message_hub.update_scalar(f'val_student/{key}', value)
# compute student metrics
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
student_metrics = dict()
for key, value in metrics.items():
student_key = 'student.' + key
teacher_key = 'teacher.' + key
student_metrics[student_key] = value
self.runner.message_hub.log_scalars.pop(f'val/{teacher_key}', None)
self.runner.call_hook('after_val_epoch', metrics=student_metrics)
self.runner.call_hook('before_val_epoch')
for idx, data_batch in enumerate(self.dataloader):
self.run_iter_teacher(idx, data_batch)
# compute metrics
metrics_t = self.evaluator.evaluate(len(self.dataloader.dataset))
for key, value in metrics_t.items():
self.runner.message_hub.update_scalar(f'val_teacher/{key}', value)
# compute teacher metrics
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
teacher_metrics = dict()
for key, value in metrics.items():
student_key = 'student.' + key
teacher_key = 'teacher.' + key
self.runner.call_hook('after_val_epoch', metrics=None)
teacher_metrics[teacher_key] = value
self.runner.message_hub.log_scalars.pop(f'val/{student_key}', None)
self.runner.call_hook('after_val_epoch', metrics=teacher_metrics)
self.runner.call_hook('after_val')
@torch.no_grad()
@ -63,8 +89,11 @@ class SingleTeacherDistillValLoop(ValLoop):
"""
self.runner.call_hook(
'before_val_iter', batch_idx=idx, data_batch=data_batch)
# outputs should be sequence of BaseDataElement
outputs = self.model.teacher(data_batch)
with autocast(enabled=self.fp16):
# outputs should be sequence of BaseDataElement
outputs = self.teacher.val_step(data_batch)
self.evaluator.process(data_batch, outputs)
self.runner.call_hook(
'after_val_iter',

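A hedged illustration of the metric renaming performed by the validation loop above: the same evaluator keys are re-emitted under ``student.`` and ``teacher.`` prefixes so both sets of results can coexist in the message hub (the metric name and value below are made up):

metrics = {'accuracy/top1': 69.9}  # hypothetical evaluator output
student_metrics = {f'student.{k}': v for k, v in metrics.items()}
teacher_metrics = {f'teacher.{k}': v for k, v in metrics.items()}
# e.g. logged as 'val/student.accuracy/top1' and 'val/teacher.accuracy/top1'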
View File

@ -62,9 +62,7 @@ def register_all_modules(init_default_scope: bool = True) -> None:
""" # noqa
import mmrazor.core # noqa: F401,F403
import mmrazor.models # noqa: F401,F403
# TODO add mmrazor.runners
import mmrazor.runners # noqa: F401,F403
if init_default_scope:
never_created = DefaultScope.get_current_instance() is None \
or not DefaultScope.check_instance_created('mmrazor')

View File

@ -3,24 +3,38 @@ from unittest import TestCase
from mmcv import ConfigDict
from mmrazor.core import DistillDeliverManager
from mmrazor.core import DistillDeliveryManager
class TestDeliverManager(TestCase):
def test_init(self):
distill_deliveries = ConfigDict(
delivery1=dict(
type='MethodOutputs',
max_keep_data=2,
method_path='toy_module.ToyClass.random_int'))
manager = DistillDeliveryManager(distill_deliveries)
self.assertEquals(len(manager.deliveries), 1)
manager = DistillDeliveryManager()
self.assertEquals(len(manager.deliveries), 0)
def test_context_manager(self):
from toy_module import ToyClass
distill_deliveries = [
ConfigDict(
distill_deliveries = ConfigDict(
delivery1=dict(
type='MethodOutputs',
max_keep_data=2,
method_path='toy_module.ToyClass.random_int')
]
method_path='toy_module.ToyClass.random_int'))
manager = DistillDeliverManager(distill_deliveries)
manager = DistillDeliveryManager(distill_deliveries)
manager.override_data = False
self.assertFalse(manager.override_data)
with manager:
toy_class = ToyClass()
output1_tea = toy_class.random_int()
@ -30,7 +44,9 @@ class TestDeliverManager(TestCase):
with manager:
_ = toy_class.random_int()
self.assertFalse(manager.override_data)
manager.override_data = True
self.assertTrue(manager.override_data)
with manager:
output1_stu = toy_class.random_int()
output2_stu = toy_class.random_int()

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from mmrazor.core import FunctionOutputsDeliver
from mmrazor.core import FunctionOutputsDelivery
class TestFuncOutputsDeliver(TestCase):
@ -9,26 +9,26 @@ class TestFuncOutputsDeliver(TestCase):
def test_init(self):
with self.assertRaisesRegex(TypeError, 'func_path should be'):
_ = FunctionOutputsDeliver(max_keep_data=1, func_path=1)
_ = FunctionOutputsDelivery(max_keep_data=1, func_path=1)
with self.assertRaisesRegex(AssertionError, 'func_path must have at '):
_ = FunctionOutputsDeliver(max_keep_data=1, func_path='toy_func')
_ = FunctionOutputsDelivery(max_keep_data=1, func_path='toy_func')
with self.assertRaisesRegex(ImportError, 'aaa is not imported'):
_ = FunctionOutputsDeliver(max_keep_data=1, func_path='aaa.bb')
_ = FunctionOutputsDelivery(max_keep_data=1, func_path='aaa.bb')
with self.assertRaisesRegex(AssertionError, 'bb is not in toy_mod'):
_ = FunctionOutputsDeliver(
_ = FunctionOutputsDelivery(
max_keep_data=1, func_path='toy_module.bb')
with self.assertRaisesRegex(TypeError, 'TOY_VAR should be'):
_ = FunctionOutputsDeliver(
_ = FunctionOutputsDelivery(
max_keep_data=1, func_path='toy_module.TOY_VAR')
def test_context_manager(self):
import toy_module
delivery = FunctionOutputsDeliver(
delivery = FunctionOutputsDelivery(
max_keep_data=2, func_path='toy_module.toy_func')
delivery.override_data = False

View File

@ -1,45 +1,45 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from mmrazor.core import MethodOutputsDeliver
from mmrazor.core import MethodOutputsDelivery
class TestMethodOutputsDeliver(TestCase):
def test_init(self):
with self.assertRaisesRegex(TypeError, 'method_path should be'):
_ = MethodOutputsDeliver(max_keep_data=1, method_path=1)
_ = MethodOutputsDelivery(max_keep_data=1, method_path=1)
with self.assertRaisesRegex(AssertionError,
'method_path must have at '):
_ = MethodOutputsDeliver(max_keep_data=1, method_path='toy_func')
_ = MethodOutputsDelivery(max_keep_data=1, method_path='toy_func')
with self.assertRaisesRegex(ImportError, 'aaa is not imported'):
_ = MethodOutputsDeliver(max_keep_data=1, method_path='aaa.bb.b')
_ = MethodOutputsDelivery(max_keep_data=1, method_path='aaa.bb.b')
with self.assertRaisesRegex(AssertionError, 'bb is not in toy_module'):
_ = MethodOutputsDeliver(
_ = MethodOutputsDelivery(
max_keep_data=1, method_path='toy_module.bb.bbb')
with self.assertRaisesRegex(TypeError, 'toy_func should be a type'):
_ = MethodOutputsDeliver(
_ = MethodOutputsDelivery(
max_keep_data=1, method_path='toy_module.toy_func.bbb')
with self.assertRaisesRegex(AssertionError, 'bbb is not in'):
_ = MethodOutputsDeliver(
_ = MethodOutputsDelivery(
max_keep_data=1, method_path='toy_module.ToyClass.bbb')
with self.assertRaisesRegex(TypeError, 'count should be'):
_ = MethodOutputsDeliver(
_ = MethodOutputsDelivery(
max_keep_data=1, method_path='toy_module.ToyClass.count')
def test_context_manager(self):
from toy_module import ToyClass
delivery = MethodOutputsDeliver(
delivery = MethodOutputsDelivery(
max_keep_data=2, method_path='toy_module.ToyClass.random_int')
# Without ``MethodOutputsDeliver``, outputs of the teacher and the
# Without ``MethodOutputsDelivery``, outputs of the teacher and the
# student are very likely to be different.
# from toy_module import ToyClass
# toy_class = ToyClass()

View File

@ -0,0 +1,50 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from toy_mod import Toy
from mmrazor.core.recorders import MethodOutputsRecorder
class TestFuncOutputsRecorder(TestCase):
def test_get_record_data(self):
toy = Toy()
recorder = MethodOutputsRecorder('toy_mod.Toy.toy_func')
recorder.initialize()
with recorder:
res0 = toy.toy_func()
res1 = toy.toy_func()
self.assertEquals(res0, recorder.get_record_data(record_idx=0))
self.assertEquals(res1, recorder.get_record_data(record_idx=1))
with self.assertRaisesRegex(
AssertionError,
'record_idx is illegal. The length of data_buffer is 2, '
'but record_idx is 2'):
_ = recorder.get_record_data(record_idx=2)
with self.assertRaisesRegex(
TypeError,
'When data_idx is not None, record should be a list or '
'tuple instance'):
_ = recorder.get_record_data(data_idx=0)
recorder = MethodOutputsRecorder('toy_mod.Toy.toy_list_func')
recorder.initialize()
with recorder:
res = toy.toy_list_func()
self.assertEqual(len(res), 3)
with self.assertRaisesRegex(
AssertionError,
'data_idx is illegal. The length of record is 3'):
_ = recorder.get_record_data(data_idx=3)
self.assertEquals(res[2], recorder.get_record_data(data_idx=2))

View File

@ -0,0 +1,42 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from mmrazor.core.recorders import FunctionOutputsRecorder
class TestFuncOutputsRecorder(TestCase):
def test_init(self):
_ = FunctionOutputsRecorder('toy_mod.toy_func')
with self.assertRaisesRegex(TypeError, 'source should be'):
_ = FunctionOutputsRecorder([1])
with self.assertRaisesRegex(AssertionError, 'source must have at '):
_ = FunctionOutputsRecorder('aaaaa')
with self.assertRaisesRegex(ImportError, 'aaa is not imported'):
_ = FunctionOutputsRecorder('aaa.bbb')
with self.assertRaisesRegex(AssertionError, 'aaa is not in toy_mod'):
_ = FunctionOutputsRecorder('toy_mod.aaa')
with self.assertRaisesRegex(TypeError, 'TOY_VAR should be'):
_ = FunctionOutputsRecorder('toy_mod.TOY_VAR')
def test_context_manager(self):
from toy_mod import execute_toy_func
recorder = FunctionOutputsRecorder('toy_mod.toy_func')
recorder.initialize()
with recorder:
execute_toy_func(1)
data = recorder.get_record_data()
self.assertTrue(data == 1)
execute_toy_func(1)
data = recorder.get_record_data()
self.assertTrue(data == 1)

View File

@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from mmrazor.core.recorders import MethodOutputsRecorder
class TestFuncOutputsRecorder(TestCase):
def test_init(self):
_ = MethodOutputsRecorder('toy_mod.ToyClass.toy')
with self.assertRaisesRegex(TypeError, 'source should be'):
_ = MethodOutputsRecorder([1])
with self.assertRaisesRegex(AssertionError, 'source must have at '):
_ = MethodOutputsRecorder('aaaaa')
with self.assertRaisesRegex(AssertionError, 'source must have at '):
_ = MethodOutputsRecorder('aaa.bbb')
with self.assertRaisesRegex(ImportError, 'aaa is not imported'):
_ = MethodOutputsRecorder('aaa.bbb.ccc')
with self.assertRaisesRegex(AssertionError, 'aaa is not in toy_mod'):
_ = MethodOutputsRecorder('toy_mod.aaa.bbb')
with self.assertRaisesRegex(TypeError, 'toy_func should be'):
_ = MethodOutputsRecorder('toy_mod.toy_func.bbb')
with self.assertRaisesRegex(AssertionError, 'bbb is not in ToyClass'):
_ = MethodOutputsRecorder('toy_mod.ToyClass.bbb')
with self.assertRaisesRegex(TypeError, 'TOY_CLS should be'):
_ = MethodOutputsRecorder('toy_mod.ToyClass.TOY_CLS')
def test_context_manager(self):
from toy_mod import ToyClass
toy = ToyClass()
recorder = MethodOutputsRecorder('toy_mod.ToyClass.toy')
recorder.initialize()
with recorder:
result = toy.toy()
data = recorder.get_record_data()
self.assertTrue(data == result)
result_ = toy.toy()
data = recorder.get_record_data()
self.assertTrue(data == result)
self.assertFalse(result_ == result)

View File

@ -0,0 +1,74 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from torch import nn
from mmrazor.core.recorders import ModuleInputsRecorder, ModuleOutputsRecorder
class ToyModel(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 1, 1)
self.conv2 = nn.Conv2d(1, 1, 1)
def forward(self, x):
return self.conv2(self.conv1(x))
class TestModuleOutputsRecorder(TestCase):
def test_prepare_from_model(self):
recorder = ModuleOutputsRecorder('conv1')
with self.assertRaisesRegex(AssertionError, 'model can not be'):
recorder.prepare_from_model()
recorder = ModuleOutputsRecorder('conv3')
model = ToyModel()
with self.assertRaisesRegex(AssertionError, '"conv3" is not in'):
recorder.prepare_from_model(model)
recorder = ModuleOutputsRecorder('conv2')
model = ToyModel()
recorder.prepare_from_model(model)
def test_module_outputs(self):
recorder = ModuleOutputsRecorder('conv2')
model = ToyModel()
recorder.initialize(model)
with recorder:
self.assertTrue(recorder.recording)
res = model(torch.randn(1, 1, 1, 1))
self.assertEquals(res, recorder.get_record_data())
with recorder:
self.assertTrue(len(recorder.data_buffer) == 0)
_ = model(torch.randn(1, 1, 1, 1))
self.assertTrue(len(recorder.data_buffer) == 0)
def test_module_inputs(self):
recorder = ModuleInputsRecorder('conv1')
model = ToyModel()
recorder.initialize(model)
tensor = torch.randn(1, 1, 1, 1)
with recorder:
self.assertTrue(recorder.recording)
_ = model(tensor)
conv1_input = recorder.get_record_data(data_idx=0)
self.assertEquals(conv1_input.sum(), tensor.sum())
with recorder:
self.assertTrue(len(recorder.data_buffer) == 0)
_ = model(torch.randn(1, 1, 1, 1))
self.assertTrue(len(recorder.data_buffer) == 0)

View File

@ -0,0 +1,42 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from torch import nn
from mmrazor.core.recorders import ParameterRecorder
class ToyModel(nn.Module):
def __init__(self):
super().__init__()
self.toy_conv = nn.Conv2d(1, 1, 1)
self.no_record_conv = nn.Conv2d(1, 1, 1)
def forward(self, x):
return self.toy_conv(x)
class TestParameterRecorder(TestCase):
def test_prepare_from_model(self):
model = ToyModel()
recorder = ParameterRecorder('AAA')
with self.assertRaisesRegex(AssertionError, '"AAA" is not in the'):
recorder.initialize(model)
recorder = ParameterRecorder('toy_conv.bias')
with self.assertRaisesRegex(AssertionError, 'model can not be None'):
recorder.prepare_from_model()
recorder.initialize(model)
bias_weight = recorder.get_record_data()
self.assertEquals(bias_weight, model.toy_conv.bias)
with recorder:
_ = model(torch.randn(1, 1, 1, 1))
self.assertEquals(bias_weight, model.toy_conv.bias)

View File

@ -0,0 +1,58 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmcv import ConfigDict
from torch import nn
from toy_mod import Toy
from mmrazor.core import RecorderManager
class ToyModel(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 1, 1)
self.conv2 = nn.Conv2d(1, 1, 1)
self.toy = Toy()
def forward(self, x):
return self.conv2(self.conv1(x)) + self.toy.toy_func()
class TestRecorderManager(TestCase):
def test_init(self):
manager = RecorderManager()
self.assertEquals(len(manager.recorders), 0)
recorders = ConfigDict(
r1=dict(type='ModuleOutputs', source='conv1'),
r2=dict(type='MethodOutputs', source='toy_mod.Toy.toy_func'),
)
manager = RecorderManager(recorders)
model = ToyModel()
manager.initialize(model)
def test_context_manager(self):
recorders = ConfigDict(
r1=dict(type='ModuleOutputs', source='conv2'),
r2=dict(type='MethodOutputs', source='toy_mod.Toy.toy_func'),
)
manager = RecorderManager(recorders)
model = ToyModel()
manager.initialize(model)
self.assertEquals(manager.get_recorder('r1'), manager.recorders['r1'])
self.assertEquals(manager.get_recorder('r2'), manager.recorders['r2'])
with manager:
res = model(torch.ones(1, 1, 1, 1))
method_outputs = manager.recorders['r2'].get_record_data()
conv2_outputs = manager.recorders['r1'].get_record_data()
self.assertEquals(res.sum(), method_outputs + conv2_outputs.sum())

View File

@ -0,0 +1,45 @@
# Copyright (c) OpenMMLab. All rights reserved.
import random
TOY_VAR = 'aaa'
def toy_func(a):
return a
def toy_list_func(a):
return [a, a, a]
def execute_toy_func(a):
toy_func(a)
def execute_toy_list_func(a):
toy_list_func(a)
class ToyClass:
TOY_CLS = 'TOY_CLASS'
def __init__(self):
self._count = 0
def toy(self):
self._count += 1
return self._count
def __call__(self):
self._count += 1
return self._count
class Toy():
def toy_func(self):
return random.randint(0, 1000)
def toy_list_func(self):
return [random.randint(0, 1000) for _ in range(3)]

View File

@ -0,0 +1,71 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from unittest import TestCase
from mmcv import ConfigDict
from toy_models import ToyStudent
from mmrazor.models import ConfigurableDistill
class TestConfigurableDistill(TestCase):
def test_init(self):
recorders_cfg = ConfigDict(
conv=dict(type='ModuleOutputs', source='conv'))
student = ToyStudent()
alg_kwargs = ConfigDict(
architecture=student,
student_recorders=recorders_cfg,
teacher_recorders=recorders_cfg,
distill_losses=dict(loss_toy=dict(type='ToyDistillLoss')),
loss_forward_mappings=dict(
loss_toy=dict(
arg1=dict(from_student=True, recorder='conv'),
arg2=dict(from_student=False, recorder='conv'),
)),
)
alg = ConfigurableDistill(**alg_kwargs)
self.assertEquals(alg.student, alg.architecture)
alg_kwargs_ = copy.deepcopy(alg_kwargs)
alg_kwargs_['distill_losses'] = None
with self.assertRaisesRegex(AssertionError,
'"loss_toy" is not in distill'):
_ = ConfigurableDistill(**alg_kwargs_)
alg_kwargs_ = copy.deepcopy(alg_kwargs)
alg_kwargs_['distill_losses'] = dict(toy=dict(type='ToyDistillLoss'))
alg_kwargs_['loss_forward_mappings'] = dict(
toy=dict(
arg1=dict(from_student=True, recorder='conv'),
arg2=dict(from_student=False, recorder='conv')))
with self.assertWarnsRegex(UserWarning, 'Warning: If toy is a'):
_ = ConfigurableDistill(**alg_kwargs_)
alg_kwargs_ = copy.deepcopy(alg_kwargs)
alg_kwargs_['loss_forward_mappings'] = None
_ = ConfigurableDistill(**alg_kwargs_)
alg_kwargs_ = copy.deepcopy(alg_kwargs)
alg_kwargs_['loss_forward_mappings'] = list('AAA')
with self.assertRaisesRegex(TypeError,
'loss_forward_mappings should be '):
_ = ConfigurableDistill(**alg_kwargs_)
alg_kwargs_ = copy.deepcopy(alg_kwargs)
alg_kwargs_['loss_forward_mappings']['loss_toy'] = list()
with self.assertRaisesRegex(
TypeError, 'Each item of loss_forward_mappings should be '):
_ = ConfigurableDistill(**alg_kwargs_)
alg_kwargs_ = copy.deepcopy(alg_kwargs)
alg_kwargs_.loss_forward_mappings.loss_toy.arg1.from_student = ''
with self.assertRaisesRegex(TypeError,
'from_student should be a bool'):
_ = ConfigurableDistill(**alg_kwargs_)

View File

@ -0,0 +1,77 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from unittest import TestCase
import torch
from mmcv import ConfigDict
from toy_models import ToyStudent
from mmrazor.models import SingleTeacherDistill
class TestSingleTeacherDistill(TestCase):
def test_init(self):
recorders_cfg = ConfigDict(
conv=dict(type='ModuleOutputs', source='conv'))
alg_kwargs = ConfigDict(
architecture=dict(type='ToyStudent'),
teacher=dict(type='ToyTeacher'),
student_recorders=recorders_cfg,
teacher_recorders=recorders_cfg,
distill_losses=dict(loss_toy=dict(type='ToyDistillLoss')),
loss_forward_mappings=dict(
loss_toy=dict(
arg1=dict(from_student=True, recorder='conv'),
arg2=dict(from_student=False, recorder='conv'),
)),
)
alg = SingleTeacherDistill(**alg_kwargs)
teacher = ToyStudent()
alg_kwargs_ = copy.deepcopy(alg_kwargs)
alg_kwargs_['teacher'] = teacher
alg = SingleTeacherDistill(**alg_kwargs_)
self.assertEquals(alg.teacher, teacher)
alg_kwargs_ = copy.deepcopy(alg_kwargs)
alg_kwargs_['teacher'] = 'teacher'
with self.assertRaisesRegex(TypeError,
'teacher should be a `dict` or'):
_ = SingleTeacherDistill(**alg_kwargs_)
def test_loss(self):
recorders_cfg = ConfigDict(
conv=dict(type='ModuleOutputs', source='conv'))
alg_kwargs = ConfigDict(
architecture=dict(type='ToyStudent'),
teacher=dict(type='ToyTeacher'),
student_recorders=recorders_cfg,
teacher_recorders=recorders_cfg,
distill_losses=dict(loss_toy=dict(type='ToyDistillLoss')),
loss_forward_mappings=dict(
loss_toy=dict(
arg1=dict(from_student=True, recorder='conv'),
arg2=dict(from_student=False, recorder='conv'),
)),
)
img = torch.randn(1, 3, 1, 1)
alg = SingleTeacherDistill(**alg_kwargs)
losses = alg(img, mode='loss')
self.assertIn('distill.loss_toy', losses)
self.assertIn('student.loss', losses)
alg_kwargs_ = copy.deepcopy(alg_kwargs)
alg_kwargs_['teacher_trainable'] = True
alg = SingleTeacherDistill(**alg_kwargs_)
losses = alg(img, mode='loss')
self.assertIn('distill.loss_toy', losses)
self.assertIn('student.loss', losses)
self.assertIn('teacher.loss', losses)

View File

@ -0,0 +1,42 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.model import BaseModel
from torch import nn
from mmrazor.registry import MODELS
@MODELS.register_module()
class ToyStudent(BaseModel):
def __init__(self, data_preprocessor=None):
super().__init__(data_preprocessor=data_preprocessor, init_cfg=None)
self.conv = nn.Conv2d(3, 1, 1)
def forward(self, batch_inputs, data_samples=None, mode='tensor'):
if mode == 'loss':
out = self.conv(batch_inputs)
return dict(loss=out)
elif mode == 'predict':
out = self.conv(batch_inputs) + 1
return out
elif mode == 'tensor':
out = self.conv(batch_inputs) + 2
return out
@MODELS.register_module()
class ToyTeacher(ToyStudent):
def __init__(self):
super().__init__()
@MODELS.register_module()
class ToyDistillLoss(nn.Module):
def __init__(self):
super().__init__()
def forward(self, arg1, arg2):
return arg1 + arg2

View File

@ -124,5 +124,4 @@ class TestSingleTeacherDistillValLoop(TestCase):
runner = Runner.from_cfg(cfg)
runner.val()
self.assertIn('val_student/acc', runner.message_hub.log_scalars.keys())
self.assertIn('val_teacher/acc', runner.message_hub.log_scalars.keys())
self.assertIn('val/teacher.acc', runner.message_hub.log_scalars.keys())