[Feature] Add Recorder to improve Distiller
parent 8913d6840d
commit cb238e36e3
configs/distill/mmcls/kl/kl_r34_r18_8xb32_in1k.py (new file, 39 lines)
@@ -0,0 +1,39 @@
_base_ = [
    'mmcls::_base_/datasets/imagenet_bs32.py',
    'mmcls::_base_/schedules/imagenet_bs256.py',
    'mmcls::_base_/default_runtime.py'
]

model = dict(
    _scope_='mmrazor',
    type='SingleTeacherDistill',
    data_preprocessor=dict(
        type='ImgDataPreprocessor',
        # RGB format normalization parameters
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        # convert image from BGR to RGB
        bgr_to_rgb=True),
    architecture=dict(
        cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False),
    teacher=dict(
        cfg_path='mmcls::resnet/resnet34_8xb32_in1k.py', pretrained=True),
    teacher_ckpt='resnet34_8xb32_in1k_20210831-f257d4e6.pth',
    student_recorders=dict(fc=dict(type='ModuleOutputs', source='head.fc')),
    teacher_recorders=dict(fc=dict(type='ModuleOutputs', source='head.fc')),
    distill_losses=dict(
        loss_kl=dict(type='KLDivergence', tau=1, loss_weight=5)),
    loss_forward_mappings=dict(
        loss_kl=dict(
            preds_S=dict(
                from_student=True,
                recorder='fc',
            ),
            preds_T=dict(
                from_student=False,
                recorder='fc',
            ))))

find_unused_parameters = True

val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
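For intuition, the wiring above means each training step fetches the student's and teacher's recorded `head.fc` logits and passes them to the loss as `preds_S`/`preds_T`. A minimal sketch of the resulting computation, assuming `KLDivergence` behaves like a temperature-scaled KL loss over logits (the helper below is illustrative, not mmrazor's implementation):

import torch
import torch.nn.functional as F

def kl_divergence(preds_S, preds_T, tau=1.0, loss_weight=5.0):
    # Temperature-scaled KL(teacher || student) over class logits.
    softmax_T = F.softmax(preds_T / tau, dim=1)
    log_softmax_S = F.log_softmax(preds_S / tau, dim=1)
    # tau**2 keeps gradient magnitudes comparable across temperatures.
    return loss_weight * tau**2 * F.kl_div(
        log_softmax_S, softmax_T, reduction='batchmean')

fc_S = torch.randn(32, 1000)  # logits captured by the student 'fc' recorder
fc_T = torch.randn(32, 1000)  # logits captured by the teacher 'fc' recorder
loss_kl = kl_divergence(fc_S, fc_T)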
configs/distill/mmcls/wsld/wsld_r34_r18_8xb32_in1k.py (new file, 36 lines)
@@ -0,0 +1,36 @@
_base_ = [
    'mmcls::_base_/datasets/imagenet_bs32.py',
    'mmcls::_base_/schedules/imagenet_bs256.py',
    'mmcls::_base_/default_runtime.py'
]

model = dict(
    _scope_='mmrazor',
    type='SingleTeacherDistill',
    data_preprocessor=dict(
        type='ImgDataPreprocessor',
        # RGB format normalization parameters
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        # convert image from BGR to RGB
        bgr_to_rgb=True),
    architecture=dict(
        cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False),
    teacher=dict(
        cfg_path='mmcls::resnet/resnet34_8xb32_in1k.py', pretrained=True),
    teacher_ckpt='resnet34_8xb32_in1k_20210831-f257d4e6.pth',
    student_recorders=dict(
        fc=dict(type='ModuleOutputs', source='head.fc'),
        data_samples=dict(type='ModuleInputs', source='')),
    teacher_recorders=dict(fc=dict(type='ModuleOutputs', source='head.fc')),
    distill_losses=dict(loss_wsld=dict(type='WSLD', tau=2, loss_weight=2.5)),
    loss_forward_mappings=dict(
        loss_wsld=dict(
            student=dict(recorder='fc', from_student=True),
            teacher=dict(recorder='fc', from_student=False),
            data_samples=dict(
                recorder='data_samples', from_student=True, data_idx=1))))

find_unused_parameters = True

val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
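Note how the `data_samples` mapping works: `source=''` hooks the root module (the empty string is the root's name in `named_modules()`), so each record is the tuple of the model's positional forward arguments, and `data_idx=1` picks the second one. A small sketch under that assumption, using a hypothetical toy module and assuming `ModuleInputsRecorder` is exported from `mmrazor.core` as the `__init__` changes below suggest:

import torch
from torch import nn
from mmrazor.core import ModuleInputsRecorder

class ToyHead(nn.Module):
    def forward(self, feats, data_samples):
        return feats.sum()

head = ToyHead()
r = ModuleInputsRecorder('')  # '' matches the root module's name
r.initialize(head)
with r:
    head(torch.ones(2), ['sample0', 'sample1'])

# data_idx=1 selects the second forward argument, i.e. the data samples.
assert r.get_record_data(record_idx=0, data_idx=1) == ['sample0', 'sample1']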
@@ -0,0 +1,58 @@
_base_ = [
    'mmdet::_base_/datasets/coco_detection.py',
    'mmdet::_base_/schedules/schedule_1x.py',
    'mmdet::_base_/default_runtime.py'
]

# default_scope = 'mmrazor'
teacher_ckpt = 'faster_rcnn_r101_fpn_2x_coco_bbox_mAP-0.398_20200504_210455-1d2dac9c.pth'  # noqa: E501
model = dict(
    _scope_='mmrazor',
    type='FpnTeacherDistill',
    architecture=dict(
        cfg_path='mmdet::faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py',
        pretrained=False),
    teacher=dict(
        cfg_path='mmdet::faster_rcnn/faster_rcnn_r101_fpn_2x_coco.py',
        pretrained=False),
    teacher_ckpt=teacher_ckpt,
    distill_losses=dict(
        loss_cwd_fpn0=dict(
            type='ChannelWiseDivergence', tau=1, loss_weight=10),
        loss_cwd_fpn1=dict(
            type='ChannelWiseDivergence', tau=1, loss_weight=10),
        loss_cwd_fpn2=dict(
            type='ChannelWiseDivergence', tau=1, loss_weight=10),
        loss_cwd_fpn3=dict(
            type='ChannelWiseDivergence', tau=1, loss_weight=10),
        loss_cwd_fpn4=dict(
            type='ChannelWiseDivergence', tau=1, loss_weight=10)),
    student_recorders=dict(fpn=dict(type='ModuleOutputs', source='neck')),
    teacher_recorders=dict(fpn=dict(type='ModuleOutputs', source='neck')),
    loss_forward_mappings=dict(
        loss_cwd_fpn0=dict(
            preds_S=dict(from_student=True, recorder='fpn', data_idx=0),
            preds_T=dict(from_student=False, recorder='fpn', data_idx=0),
        ),
        loss_cwd_fpn1=dict(
            preds_S=dict(from_student=True, recorder='fpn', data_idx=1),
            preds_T=dict(from_student=False, recorder='fpn', data_idx=1),
        ),
        loss_cwd_fpn2=dict(
            preds_S=dict(from_student=True, recorder='fpn', data_idx=2),
            preds_T=dict(from_student=False, recorder='fpn', data_idx=2),
        ),
        loss_cwd_fpn3=dict(
            preds_S=dict(from_student=True, recorder='fpn', data_idx=3),
            preds_T=dict(from_student=False, recorder='fpn', data_idx=3),
        ),
        loss_cwd_fpn4=dict(
            preds_S=dict(from_student=True, recorder='fpn', data_idx=4),
            preds_T=dict(from_student=False, recorder='fpn', data_idx=4),
        ),
    ),
)

find_unused_parameters = True

val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
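The five mappings differ only in `data_idx`. Since the `neck` returns a tuple of FPN feature maps, each recorded 'fpn' entry is that whole tuple, and `data_idx=k` selects level k for `loss_cwd_fpnk`. A sketch of the selection with made-up shapes:

import torch

# Each 'fpn' record is the neck's output: a tuple of five feature maps.
student_fpn = tuple(
    torch.randn(2, 256, 64 // 2**k, 64 // 2**k) for k in range(5))
teacher_fpn = tuple(
    torch.randn(2, 256, 64 // 2**k, 64 // 2**k) for k in range(5))

# data_idx=k resolves to record[k]: loss_cwd_fpn0 gets level 0, and so on.
pairs = [(student_fpn[k], teacher_fpn[k]) for k in range(5)]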
@@ -0,0 +1,8 @@
_base_ = ['./cwd_fpn_gfl_r101_gfl_r50_1x_coco.py']

model = dict(
    architecture=dict(
        cfg_path='mmdet::gfl/gfl_r50_fpn_1x_coco.py', pretrained=False),
    teacher=dict(
        cfg_path='mmdet::gfl/gfl_r101_fpn_mstrain_2x_coco.py',
        pretrained=True))
@@ -0,0 +1,9 @@
_base_ = ['./cwd_fpn_frcnn_r101_frcnn_r50_1x_coco.py']

model = dict(
    architecture=dict(
        cfg_path='mmdet::retinanet/retinanet_r50_fpn_1x_coco.py',
        pretrained=False),
    teacher=dict(
        cfg_path='mmdet::retinanet/retinanet_r101_fpn_2x_coco.py',
        pretrained=True))
@@ -1,3 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .delivers import *  # noqa: F401,F403
+from .recorders import *  # noqa: F401,F403
 from .tracer import *  # noqa: F401,F403
@@ -1,8 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .deliver_manager import DistillDeliverManager
-from .function_outputs_deliver import FunctionOutputsDeliver
-from .method_outputs_deliver import MethodOutputsDeliver
+from .deliver_manager import DistillDeliveryManager
+from .function_outputs_deliver import FunctionOutputsDelivery
+from .method_outputs_deliver import MethodOutputsDelivery

 __all__ = [
-    'FunctionOutputsDeliver', 'MethodOutputsDeliver', 'DistillDeliverManager'
+    'FunctionOutputsDelivery', 'MethodOutputsDelivery',
+    'DistillDeliveryManager'
 ]
@ -1,22 +1,26 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
from typing import List
|
||||
from typing import Dict, Optional
|
||||
|
||||
from mmrazor.registry import TASK_UTILS
|
||||
from .distill_deliver import DistillDelivery
|
||||
|
||||
SUPPORT_DELIVERIES = ['FunctionOutputs', 'MethodOutputs']
|
||||
|
||||
|
||||
class DistillDeliverManager:
|
||||
"""Various types delivers' manager. The ``DistillDeliverManager`` is also a
|
||||
context manager, managing various types of delivers. When entering the
|
||||
``DistillDeliverManager``, all delivers managed by it will be started.
|
||||
class DistillDeliveryManager:
|
||||
"""Various types deliveries' manager. The ``DistillDeliveryManager`` is
|
||||
also a context manager, managing various types of deliveries.
|
||||
|
||||
When entering the ``DistillDeliveryManager``, all deliveries managed by it
|
||||
will be started.
|
||||
|
||||
Notes:
|
||||
DistillDeliver is a context manager used to override function (method)
|
||||
outputs from the target model with function (method) outputs from the
|
||||
source model.
|
||||
DistillDelivery is a context manager used to override function(method)
|
||||
outputs during teacher(student) forward.
|
||||
|
||||
Args:
|
||||
deliveries (list(dict)): Configs of all deliveries.
|
||||
deliveries (dict): Configs of all deliveries.
|
||||
|
||||
Examples:
|
||||
>>> from mmcls.models.utils import Augments
|
||||
@@ -27,8 +31,8 @@ class DistillDeliverManager:
         >>> imgs = torch.randn(2, 3, 32, 32)
         >>> label = torch.randint(0, 10, (2,))

-        >>> # Without ``MethodOutputsDeliver``, outputs of the teacher and the
-        >>> # student are very likely to be different.
+        >>> # Without ``MethodOutputsDelivery``, outputs of the teacher and
+        >>> # the student are different.
         >>> imgs_tea, label_tea = augments(imgs, label)
         >>> imgs_stu, label_stu = augments(imgs, label)
         >>> torch.equal(label_tea, label_stu)
@@ -36,10 +40,10 @@ class DistillDeliverManager:
         >>> torch.equal(imgs_tea, imgs_stu)
         False

-        >>> distill_deliveries = [
-        ...     ConfigDict(type='MethodOutputs', max_keep_data=1,
-        ...         method_path='mmcls.models.utils.Augments.__call__')]
-        >>> manager = DistillDeliverManager(distill_deliveries)
+        >>> distill_deliveries = ConfigDict(
+        ...     aug=dict(type='MethodOutputs', max_keep_data=1,
+        ...         method_path='mmcls.models.utils.Augments.__call__'))
+        >>> manager = DistillDeliveryManager(distill_deliveries)

         >>> manager.override_data = False
         >>> with manager:
@@ -55,44 +59,55 @@ class DistillDeliverManager:
         True
     """

-    def __init__(self, deliveries: List) -> None:
+    def __init__(self, deliveries: Optional[Dict[str, Dict]] = None) -> None:

-        # As there may be several delivers belong to a same deliver type,
-        # we use a list to save delivers rather than a dict.
-        self.deliveries = list()
-        for cfg in deliveries:
-            deliver_cfg = copy.deepcopy(cfg)
-            deliver_type = cfg.type
-            deliver_type = deliver_type + 'Deliver'
-            deliver_cfg.type = deliver_type
-            self.deliveries.append(TASK_UTILS.build(deliver_cfg))
+        self._deliveries: Dict[str, DistillDelivery] = dict()
+        if deliveries:
+            for delivery_name, delivery_cfg in deliveries.items():
+                delivery_cfg_ = copy.deepcopy(delivery_cfg)
+                delivery_type_ = delivery_cfg_.get('type', '')
+                assert isinstance(delivery_type_, str)
+                assert delivery_type_ in SUPPORT_DELIVERIES
+
+                delivery_type_ = delivery_type_ + 'Delivery'
+                delivery_cfg_.update(dict(type=delivery_type_))
+
+                delivery = TASK_UTILS.build(delivery_cfg_)
+                self.deliveries[delivery_name] = delivery

         self._override_data = False

     @property
-    def override_data(self):
+    def deliveries(self) -> Dict[str, DistillDelivery]:
+        """dict: all deliveries."""
+        return self._deliveries
+
+    @property
+    def override_data(self) -> bool:
+        """bool: indicate whether to override the data with the recorded data.
+        """
         return self._override_data

     @override_data.setter
-    def override_data(self, override):
-        """Set the override_data property to all the delivers.
+    def override_data(self, override: bool) -> None:
+        """Set the override_data property to all the deliveries.

-        If the `override_data` of a deliver is False, the deliver will record
-        and keep the origin data. If the current_mode of a deliver is True, the
-        deliver will override the origin data with the recorded data.
+        If the `override_data` of a delivery is False, the delivery will
+        record the origin data.
+
+        If the `override_data` of a delivery is True, the delivery will
+        override the origin data with the recorded data.
         """
         self._override_data = override
-        for deliver in self.deliveries:
-            deliver.override_data = override
+        for delivery in self.deliveries.values():
+            delivery.override_data = override

     def __enter__(self) -> None:
         """Enter the context manager."""
-        for deliver in self.deliveries:
-            deliver.__enter__()
+        for delivery in self.deliveries.values():
+            delivery.__enter__()

     def __exit__(self, exc_type, exc_value, traceback) -> None:
         """Exit the context manager."""
-        for deliver in self.deliveries:
-            deliver.__exit__(exc_type, exc_value, traceback)
+        for delivery in self.deliveries.values():
+            delivery.__exit__(exc_type, exc_value, traceback)
@@ -4,32 +4,33 @@ from queue import Queue
 from typing import Callable


-class DistillDeliver(metaclass=ABCMeta):
-    """Base class for delivers for distillation.
-
-    DistillDeliver is a context manager used to override function (method)
-    outputs from the target model with function (method) outputs from the
-    source model.
-
-    In MMRazor, there will be different types of delivers to deliver different
-    types of data. They can be used in combination with the
-    ``DistillDeliverManager``.
-
-    Args:
-        max_keep_data (int): The length limitation of the queue, should be
-            larger than the execute times of the function or method. Defaults
-            to 1.
-
-    Notes:
-        If a function (method) is executed more than once during the forward
-        of the source model, all the outputs of this function (method) will be
-        used to override function (method) outputs from the target model.
-
-    TODO:
-        Support overriding some of the outputs of a function (method)
+# TODO: Support overriding part of the outputs of a function or method
+class DistillDelivery(metaclass=ABCMeta):
+    """Base class for deliveries for distillation.
+
+    DistillDelivery is a context manager used to override function(method)
+    outputs during teacher(student) forward.
+
+    A delivery can only handle one function or method. Some algorithms may use
+    multiple deliveries, which can be managed uniformly using
+    ``DistillDeliveryManager``.
+
+    Args:
+        max_keep_data (int): The length limitation of the queue. If a function
+            (method) is executed more than once during the forward of the
+            target model, function (method) outputs from the source model are
+            pushed into the queue in order. Defaults to 1.
+
+    Note:
+        If a function or method is executed more than once during the forward
+        of the target model, its outputs from the source model are pushed
+        into the queue in order.
     """

-    def __init__(self, max_keep_data: int = 1):
+    def __init__(self, max_keep_data: int = 1) -> None:

         self._override_data = False
         self.data_queue: Queue = Queue(maxsize=max_keep_data)
@@ -55,3 +56,11 @@ class DistillDeliver(metaclass=ABCMeta):
     def deliver_wrapper(self, origin: Callable) -> Callable:
         """Wrap the specific object so that the intermediate results of the
         model can be delivered."""

+    @abstractmethod
+    def __enter__(self) -> None:
+        """Enter the context manager."""
+
+    @abstractmethod
+    def __exit__(self, exc_type, exc_value, traceback) -> None:
+        """Exit the context manager."""
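To make the contract concrete, here is a hypothetical minimal subclass of ``DistillDelivery``, assuming only the ``data_queue`` and ``_override_data`` attributes shown above: in record mode it pushes the wrapped callable's outputs into the queue, and in override mode it pops them instead.

from typing import Callable

class ToyDelivery(DistillDelivery):
    """Hypothetical delivery that wraps one callable attribute of a holder."""

    def __init__(self, holder, attr: str, max_keep_data: int = 1) -> None:
        super().__init__(max_keep_data)
        self.holder = holder
        self.attr = attr

    def deliver_wrapper(self, origin: Callable) -> Callable:
        def wrapped(*args, **kwargs):
            if self._override_data:
                # Override: replay an output recorded from the source model.
                return self.data_queue.get()
            outputs = origin(*args, **kwargs)
            # Record: keep the origin output for the target model.
            self.data_queue.put(outputs)
            return outputs
        return wrapped

    def __enter__(self) -> None:
        self._origin = getattr(self.holder, self.attr)
        setattr(self.holder, self.attr, self.deliver_wrapper(self._origin))

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        setattr(self.holder, self.attr, self._origin)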
@@ -6,11 +6,11 @@ from typing import Callable
 from mmcv.utils import import_modules_from_strings

 from mmrazor.registry import TASK_UTILS
-from .distill_deliver import DistillDeliver
+from .distill_deliver import DistillDelivery


 @TASK_UTILS.register_module()
-class FunctionOutputsDeliver(DistillDeliver):
+class FunctionOutputsDelivery(DistillDelivery):
     """Delivery for intermediate results which are ``FunctionType``'s outputs.

     Args:
@@ -29,23 +29,28 @@ class FunctionOutputsDeliver(DistillDeliver):
         `mmdet.core.anchor.utils.anchor_inside_flags`.

     Examples:
-        >>> # Suppose there is a toy function named ``toy_func`` in test.py.
-        >>> # It return random integers from 0 to 999.
-        >>> import toy_module
-        >>> # Suppose we want to deliver outputs from the teacher to
+        >>> # Below code in toy_module.py
+        >>> import random
+        >>> def toy_func():
+        >>>     return random.randint(0, 1000)
+
+        >>> # Below code in main.py
+        >>> # Teacher and student both will execute toy_func.
+        >>> # Now, we want to deliver outputs from the teacher to
         >>> # the student
+        >>> import toy_module
         >>> delivery = FunctionOutputsDeliver(
         ...     max_keep_data=1, func_path='toy_module.toy_func')

         >>> delivery.override_data = False
         >>> with delivery:
-        ...     output_tea = toy_module.toy_func()
+        ...     output_teacher = toy_module.toy_func()

         >>> delivery.override_data = True
         >>> with delivery:
-        ...     output_stu = toy_module.toy_func()
+        ...     output_student = toy_module.toy_func()

-        >>> output_tea == output_stu
+        >>> output_teacher == output_student
         True

         >>> # If a function (method) is executed more than once during the
@@ -6,11 +6,11 @@ from typing import Callable
 from mmcv.utils import import_modules_from_strings

 from mmrazor.registry import TASK_UTILS
-from .distill_deliver import DistillDeliver
+from .distill_deliver import DistillDelivery


 @TASK_UTILS.register_module()
-class MethodOutputsDeliver(DistillDeliver):
+class MethodOutputsDelivery(DistillDelivery):
     """Delivery for intermediate results which are ``MethodType``'s outputs.

     Note:
mmrazor/core/recorders/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .function_outputs_recorder import FunctionOutputsRecorder
from .method_outputs_recorder import MethodOutputsRecorder
from .module_inputs_recorder import ModuleInputsRecorder
from .module_outputs_recorder import ModuleOutputsRecorder
from .param_recorder import ParameterRecorder
from .recorder_manager import RecorderManager

__all__ = [
    'FunctionOutputsRecorder', 'MethodOutputsRecorder',
    'ModuleOutputsRecorder', 'ParameterRecorder', 'RecorderManager',
    'ModuleInputsRecorder'
]
mmrazor/core/recorders/base_recorder.py (new file, 116 lines)
@@ -0,0 +1,116 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Any, List, Optional

from torch import nn


class BaseRecorder(metaclass=ABCMeta):
    """Base class for recorders.

    Recorder is a context manager used to record various intermediate results
    during the model forward. It can be used in distillation algorithms and
    can also be used to obtain some specific data for visual analysis.

    In MMRazor, there will be different types of recorders to obtain different
    types of intermediate results. They can be used in combination with the
    ``RecorderManager``.

    Note:
        The recorder will be lazily initialized in the ``RecorderManager`` by
        default. If you want to use the recorder without the
        ``RecorderManager``, you need to initialize it first.
    """

    def __init__(self, source: str) -> None:

        self._source = source
        # Intermediate results are recorded according to the data source.
        # One data source may generate multiple records, which need to be
        # recorded through a list.
        self._data_buffer: List = list()
        # Before using the recorder for the first time, it needs to be
        # initialized.
        self._initialized = False

    @property
    def source(self) -> str:
        """str: source of the recorded data."""
        return self._source

    @property
    def data_buffer(self) -> List:
        """list: data buffer."""
        return self._data_buffer

    @abstractmethod
    def prepare_from_model(self, model: Optional[nn.Module] = None) -> None:
        """Make it possible to record the intermediate results of the
        model."""

    def initialize(self, model: Optional[nn.Module] = None) -> None:
        """Init the recorder.

        Args:
            model (nn.Module): The model which needs to record intermediate
                results.
        """
        self.prepare_from_model(model)
        self._initialized = True

    def get_record_data(self,
                        record_idx: int = 0,
                        data_idx: Optional[int] = None) -> Any:
        """Get data from ``data_buffer``.

        Args:
            record_idx (int): The index of the record saved in
                ``data_buffer``. If a source is executed N times during
                forward, there will be N records in ``data_buffer``.
            data_idx (int, optional): The index of the target data in
                a record. A record may be a tuple or a list; if data_idx is
                None, the whole list or tuple is returned. Defaults to None.

        Returns:
            Any: The type of the return value is undefined, and different
                source data may have different types.
        """
        assert record_idx < len(self._data_buffer), \
            'record_idx is illegal. The length of data_buffer is ' \
            f'{len(self._data_buffer)}, but record_idx is ' \
            f'{record_idx}.'

        record = self._data_buffer[record_idx]

        if data_idx is None:
            target_data = record
        else:
            if isinstance(record, (list, tuple)):
                assert data_idx < len(record), \
                    'data_idx is illegal. The length of record is ' \
                    f'{len(record)}, but data_idx is {data_idx}.'
                target_data = record[data_idx]
            else:
                raise TypeError('When data_idx is not None, record should be '
                                'a list or tuple instance, but got '
                                f'{type(record)}.')

        return target_data

    def reset_data_buffer(self) -> None:
        """Clear data in data_buffer."""

        self._data_buffer = list()

    def __enter__(self):
        """Enter the context manager."""

        assert self._initialized, \
            'The recorder will be initialized in the RecorderManager by '\
            'default. If you want to use the recorder without the '\
            'RecorderManager, you need to initialize it first.'

        self.reset_data_buffer()

    def __exit__(self, exc_type, exc_value, traceback):
        """Exit the context manager."""
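A subclass only has to implement ``prepare_from_model`` (and append to ``data_buffer`` somewhere). A hypothetical minimal recorder, for illustration only:

from typing import Optional
from torch import nn

class ToyAttributeRecorder(BaseRecorder):
    """Hypothetical recorder: treats `source` as a plain attribute name and
    snapshots it once at initialization (illustration only)."""

    def prepare_from_model(self, model: Optional[nn.Module] = None) -> None:
        assert model is not None, 'model can not be None.'
        # Record once, up front; ParameterRecorder below follows this pattern.
        self.data_buffer.append(getattr(model, self.source))

    def reset_data_buffer(self) -> None:
        # The snapshot is taken at initialization, so do not clear it on
        # __enter__ (mirrors the reasoning in ParameterRecorder below).
        pass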
mmrazor/core/recorders/function_outputs_recorder.py (new file, 161 lines)
@@ -0,0 +1,161 @@
# Copyright (c) OpenMMLab. All rights reserved.
import functools
from types import FunctionType, ModuleType
from typing import Callable, List, Optional

from mmcv.utils import import_modules_from_strings
from torch import nn

from mmrazor.registry import TASK_UTILS
from .base_recorder import BaseRecorder


@TASK_UTILS.register_module()
class FunctionOutputsRecorder(BaseRecorder):
    """Recorder for intermediate results which are ``FunctionType``'s outputs.

    Notes:
        The form of `source` needs special attention. For example,
        `anchor_inside_flags` is a function in mmdetection to check whether
        the anchors are inside the border. This function is in
        `mmdet/core/anchor/utils.py` and used in
        `mmdet/models/dense_heads/anchor_head.py`. Then the source should be
        `mmdet.models.dense_heads.anchor_head.anchor_inside_flags` but not
        `mmdet.core.anchor.utils.anchor_inside_flags`.

    Examples:
        >>> # Below code in toy_module.py
        >>> import random
        >>> def toy_func():
        ...     return random.randint(0, 1000)
        >>> def toy_list_func():
        ...     return [random.randint(0, 1000) for _ in range(3)]

        >>> # Below code in main.py
        >>> # Now, we want to get the teacher's outputs by recorder.

        >>> import toy_module
        >>> r1 = FunctionOutputsRecorder('toy_module.toy_func')
        >>> r1.initialize()
        >>> with r1:
        ...     output_teacher1 = toy_module.toy_func()
        ...     output_teacher2 = toy_module.toy_func()
        ...     output_teacher3 = toy_module.toy_func()

        >>> r1.data_buffer
        [33, 41, 12]
        >>> r1.get_record_data(record_idx=2)
        12
        >>> output_teacher1==33 and output_teacher2==41 and output_teacher3==12
        True

        >>> r2 = FunctionOutputsRecorder('toy_module.toy_list_func')
        >>> r2.initialize()
        >>> with r2:
        ...     output_teacher1 = toy_module.toy_list_func()
        ...     output_teacher2 = toy_module.toy_list_func()
        ...     output_teacher3 = toy_module.toy_list_func()

        >>> r2.data_buffer
        [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
        >>> r2.get_record_data(record_idx=2, data_idx=2)
        9
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self._check_valid_source(self.source)

        # import the module corresponding to the function
        try:
            mod = import_modules_from_strings(self.module_string)
        except ImportError:
            raise ImportError(
                f'{self.module_string} is not imported correctly.')

        self.imported_module: ModuleType = mod

        assert hasattr(mod, self.func_name), \
            f'{self.func_name} is not in {self.module_string}.'

        origin_func = getattr(mod, self.func_name)
        if not isinstance(origin_func, FunctionType):
            raise TypeError(f'{self.func_name} should be a FunctionType '
                            f'instance, but got {type(origin_func)}')

        self.origin_func: Callable = origin_func

    @staticmethod
    def _check_valid_source(source):
        """Check if the source's format is valid."""
        if not isinstance(source, str):
            raise TypeError(f'source should be a str '
                            f'instance, but got {type(source)}')

        assert len(source.split('.')) > 1, \
            'source must have at least one `.`'

    @property
    def func_name(self):
        """Get the function name according to `source`."""
        return self.source.split('.')[-1]

    @property
    def module_string(self):
        """Get the module name according to `source`."""
        return '.'.join(self.source.split('.')[:-1])

    def prepare_from_model(self, model: Optional[nn.Module] = None) -> None:
        """The `model` is useless in `FunctionOutputsRecorder`."""
        pass

    def func_record_wrapper(self, origin_func: Callable,
                            data_buffer: List) -> Callable:
        """Save the function's outputs.

        Args:
            origin_func (FunctionType): The function whose outputs need to be
                recorded.
            data_buffer (list): The buffer in which the function's outputs
                are saved.
        """

        @functools.wraps(origin_func)
        def wrap_func(*args, **kwargs):
            outputs = origin_func(*args, **kwargs)
            # If a function is executed N times, there will be N outputs that
            # need to be saved.
            data_buffer.append(outputs)
            return outputs

        return wrap_func

    def __enter__(self):
        """Enter the context manager."""
        super().__enter__()

        mod = self.imported_module
        origin_func = self.origin_func
        # add the record wrapper to the origin function.
        record_func = self.func_record_wrapper(origin_func, self.data_buffer)

        assert hasattr(mod, self.func_name), \
            f'{self.func_name} is not in {self.module_string}.'

        # rewrite the origin function
        setattr(mod, self.func_name, record_func)

    def __exit__(self, exc_type, exc_value, traceback):
        """Exit the context manager."""
        super().__exit__(exc_type, exc_value, traceback)

        mod = self.imported_module
        origin_func = self.origin_func

        assert hasattr(mod, self.func_name), \
            f'{self.func_name} is not in {self.module_string}.'

        # restore the origin function
        setattr(mod, self.func_name, origin_func)
mmrazor/core/recorders/method_outputs_recorder.py (new file, 168 lines)
@@ -0,0 +1,168 @@
# Copyright (c) OpenMMLab. All rights reserved.
import functools
from types import FunctionType, ModuleType
from typing import Callable, List, Optional

from mmcv.utils import import_modules_from_strings
from torch import nn

from mmrazor.registry import TASK_UTILS
from .base_recorder import BaseRecorder


@TASK_UTILS.register_module()
class MethodOutputsRecorder(BaseRecorder):
    """Recorder for intermediate results which are ``MethodType``'s outputs.

    Note:
        Different from ``FunctionType``, ``MethodType`` is the type of methods
        of class instances.

    Examples:
        >>> # Below code in toy_module.py
        >>> import random
        >>> class Toy():
        ...     def toy_func(self):
        ...         return random.randint(0, 1000)
        ...     def toy_list_func(self):
        ...         return [random.randint(0, 1000) for _ in range(3)]

        >>> # Below code in main.py
        >>> # Now, we want to get the teacher's outputs by recorder.

        >>> from toy_module import Toy
        >>> toy = Toy()
        >>> r1 = MethodOutputsRecorder('toy_module.Toy.toy_func')
        >>> r1.initialize()
        >>> with r1:
        ...     output_teacher1 = toy.toy_func()
        ...     output_teacher2 = toy.toy_func()
        ...     output_teacher3 = toy.toy_func()

        >>> r1.data_buffer
        [33, 41, 12]
        >>> r1.get_record_data(record_idx=2)
        12
        >>> output_teacher1==33 and output_teacher2==41 and output_teacher3==12
        True

        >>> r2 = MethodOutputsRecorder('toy_module.Toy.toy_list_func')
        >>> r2.initialize()
        >>> with r2:
        ...     output_teacher1 = toy.toy_list_func()
        ...     output_teacher2 = toy.toy_list_func()
        ...     output_teacher3 = toy.toy_list_func()

        >>> r2.data_buffer
        [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
        >>> r2.get_record_data(record_idx=2, data_idx=2)
        9
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self._check_valid_source(self.source)

        # import the module corresponding to the method's class
        try:
            mod: ModuleType = import_modules_from_strings(self.module_string)
        except ImportError:
            raise ImportError(
                f'{self.module_string} is not imported correctly.')

        assert hasattr(mod, self.cls_name), \
            f'{self.cls_name} is not in {self.module_string}.'

        imported_cls: type = getattr(mod, self.cls_name)
        if not isinstance(imported_cls, type):
            raise TypeError(f'{self.cls_name} should be a type '
                            f'instance, but got {type(imported_cls)}')
        self.imported_class = imported_cls

        assert hasattr(imported_cls, self.method_name), \
            f'{self.method_name} is not in {self.cls_name}.'

        origin_method = getattr(imported_cls, self.method_name)
        if not isinstance(origin_method, FunctionType):
            raise TypeError(f'{self.method_name} should be a FunctionType '
                            f'instance, but got {type(origin_method)}')
        self.origin_method = origin_method

    @staticmethod
    def _check_valid_source(source: str) -> None:
        """Check if the `source` is valid."""
        if not isinstance(source, str):
            raise TypeError(f'source should be a str '
                            f'instance, but got {type(source)}')

        assert len(source.split('.')) > 2, \
            'source must have at least two `.`'

    @property
    def method_name(self):
        """Get the method name according to `source`."""
        return self.source.split('.')[-1]

    @property
    def cls_name(self):
        """Get the name of the class to which this method belongs, according
        to `source`."""
        return self.source.split('.')[-2]

    @property
    def module_string(self):
        """Get the module name according to `source`."""
        return '.'.join(self.source.split('.')[:-2])

    def prepare_from_model(self, model: Optional[nn.Module] = None) -> None:
        """Wrap the origin source methods.

        The ``model`` is useless in this recorder; the argument only exists
        to be consistent with the other recorders.
        """
        pass

    def method_record_wrapper(self, origin_method: Callable,
                              data_buffer: List) -> Callable:
        """Save the method's outputs.

        Args:
            origin_method (MethodType): The method whose outputs need to be
                recorded.
            data_buffer (list): The buffer in which the method's outputs are
                saved.
        """

        @functools.wraps(origin_method)
        def wrap_method(*args, **kwargs):
            outputs = origin_method(*args, **kwargs)
            # If a method is executed N times, there will be N outputs that
            # need to be saved.
            data_buffer.append(outputs)
            return outputs

        return wrap_method

    def __enter__(self):
        """Enter the context manager."""
        super().__enter__()

        imported_cls = self.imported_class
        origin_method = self.origin_method
        # add the record wrapper to the origin method.
        record_method = self.method_record_wrapper(origin_method,
                                                   self.data_buffer)

        # rewrite the origin method.
        setattr(imported_cls, self.method_name, record_method)

    def __exit__(self, exc_type, exc_value, traceback):
        """Exit the context manager."""
        super().__exit__(exc_type, exc_value, traceback)

        imported_cls = self.imported_class
        origin_method = self.origin_method

        # restore the origin method
        setattr(imported_cls, self.method_name, origin_method)
mmrazor/core/recorders/module_inputs_recorder.py (new file, 27 lines)
@@ -0,0 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Tuple

from torch import nn

from mmrazor.registry import TASK_UTILS
from .module_outputs_recorder import ModuleOutputsRecorder


@TASK_UTILS.register_module()
class ModuleInputsRecorder(ModuleOutputsRecorder):
    """Recorder for intermediate results which are PyTorch modules' inputs."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward_hook(self, module: nn.Module, inputs: Tuple,
                     outputs: Any) -> None:
        """Save the module's forward inputs.

        Args:
            module (:obj:`torch.nn.Module`): The module the hook is
                registered on.
            inputs (tuple): The inputs of the module.
            outputs: The output of the module.
        """
        if self.recording:
            self.data_buffer.append(inputs)
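This file ships without a docstring example, so here is a short usage sketch (assuming the class is importable as above). Each record is the tuple of positional inputs to the hooked module's forward:

import torch
from torch import nn

model = nn.Sequential(nn.Conv2d(1, 1, 1))
r = ModuleInputsRecorder('0')  # '0' is the conv's name in the Sequential
r.initialize(model)
with r:
    _ = model(torch.randn(1, 1, 3, 3))

inputs_tuple = r.get_record_data(record_idx=0)              # (input_tensor,)
first_input = r.get_record_data(record_idx=0, data_idx=0)   # the tensor itself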
mmrazor/core/recorders/module_outputs_recorder.py (new file, 96 lines)
@@ -0,0 +1,96 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Optional, Tuple

from torch import nn

from mmrazor.registry import TASK_UTILS
from .base_recorder import BaseRecorder


@TASK_UTILS.register_module()
class ModuleOutputsRecorder(BaseRecorder):
    """Recorder for intermediate results which are PyTorch modules' outputs.

    Examples:
        >>> from torch import nn
        >>> class ToyModel(nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.conv1 = nn.Conv2d(1, 1, 1)
        ...         self.conv2 = nn.Conv2d(1, 1, 1)
        ...     def forward(self, x):
        ...         x1 = self.conv1(x)
        ...         x2 = self.conv1(x + 1)
        ...         return self.conv2(x1 + x2)

        >>> model = ToyModel()
        >>> [name for name, _ in model.named_modules()]
        ['', 'conv1', 'conv2']

        >>> r1 = ModuleOutputsRecorder('conv1')
        >>> r1.initialize(model)

        >>> with r1:
        ...     res = model(torch.randn(1, 1, 1, 1))

        >>> r1.data_buffer
        [tensor([[[[0.6734]]]]), tensor([[[[1.2514]]]])]
        >>> r1.get_record_data(record_idx=1)
        tensor([[[[1.2514]]]])

        >>> r2 = ModuleOutputsRecorder('conv2')
        >>> r2.initialize(model)

        >>> with r2:
        ...     res = model(torch.randn(1, 1, 1, 1))

        >>> r2.data_buffer
        [tensor([[[[0.9534]]]])]
        >>> r2.get_record_data()
        tensor([[[[0.9534]]]])
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._recording = False

    @property
    def recording(self) -> bool:
        """bool: whether to record data in the forward hook."""
        return self._recording

    def prepare_from_model(self, model: Optional[nn.Module] = None) -> None:
        """Register a PyTorch forward hook on the corresponding module."""

        assert model is not None, 'model can not be None.'

        founded = False
        for name, module in model.named_modules():
            if name == self.source:
                module.register_forward_hook(self.forward_hook)
                founded = True
                break

        assert founded, f'"{self.source}" is not in the model.'

    def forward_hook(self, module: nn.Module, inputs: Tuple,
                     outputs: Any) -> None:
        """Save the module's forward outputs.

        Args:
            module (:obj:`torch.nn.Module`): The module the hook is
                registered on.
            inputs (tuple): The inputs of the module.
            outputs: The output of the module.
        """
        if self._recording:
            self.data_buffer.append(outputs)

    def __enter__(self):
        """Enter the context manager."""
        super().__enter__()
        self._recording = True

    def __exit__(self, exc_type, exc_value, traceback):
        """Exit the context manager."""
        super().__exit__(exc_type, exc_value, traceback)
        self._recording = False
mmrazor/core/recorders/param_recorder.py (new file, 57 lines)
@@ -0,0 +1,57 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

from torch import nn

from mmrazor.registry import TASK_UTILS
from .base_recorder import BaseRecorder


@TASK_UTILS.register_module()
class ParameterRecorder(BaseRecorder):
    """Recorder for a PyTorch model's parameters.

    Examples:
        >>> from torch import nn
        >>> class ToyModel(nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.toy_conv = nn.Conv2d(1, 1, 1)
        ...     def forward(self, x):
        ...         return self.toy_conv(x)

        >>> model = ToyModel()
        >>> [name for name, _ in model.named_parameters()]
        ['toy_conv.weight', 'toy_conv.bias']

        >>> recorder = ParameterRecorder('toy_conv.weight')
        >>> recorder.initialize(model)

        >>> recorder.data_buffer
        [Parameter containing: tensor([[[[0.3244]]]], requires_grad=True)]
        >>> recorder.get_record_data()
        Parameter containing: tensor([[[[0.3244]]]], requires_grad=True)
    """

    def prepare_from_model(self, model: Optional[nn.Module] = None) -> None:
        """Record the PyTorch model's parameter."""
        assert model is not None, \
            'model can not be None when using ParameterRecorder.'

        founded = False
        for param_name, param in model.named_parameters():
            if param_name == self.source:
                self.data_buffer.append(param)
                founded = True
                break

        assert founded, f'"{self.source}" is not in the model.'

    def reset_data_buffer(self):
        """Clear data in data_buffer.

        Note:
            The data_buffer stores the address of the parameter in memory and
            does not need to be reset.
        """
        pass
mmrazor/core/recorders/recorder_manager.py (new file, 117 lines)
@@ -0,0 +1,117 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, Optional

from mmcv import ConfigDict
from torch import nn

from mmrazor.registry import TASK_UTILS
from .base_recorder import BaseRecorder


class RecorderManager:
    """Manager for various types of recorders. ``RecorderManager`` is also a
    context manager, managing various types of recorders. When entering the
    ``RecorderManager``, all recorders managed by it will be started.

    Note:
        The recorders will be initialized in the ``RecorderManager`` by
        default. If you want to just use a recorder without the
        ``RecorderManager``, you need to initialize it first.

    Args:
        recorders (dict, optional): All recorders' configs.

    Examples:
        >>> # Below code in toy_module.py
        >>> import random
        >>> class Toy():
        ...     def toy_func(self):
        ...         return random.randint(0, 1000)

        >>> # Below code in main.py
        >>> from torch import nn
        >>> from toy_module import Toy

        >>> class ToyModel(nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.conv1 = nn.Conv2d(1, 1, 1)
        ...         self.conv2 = nn.Conv2d(1, 1, 1)
        ...         self.toy = Toy()
        ...     def forward(self, x):
        ...         return self.conv2(self.conv1(x)) + self.toy.toy_func()

        >>> model = ToyModel()
        >>> [name for name, _ in model.named_modules()]
        ['', 'conv1', 'conv2']

        >>> manager = RecorderManager(
        ...     ConfigDict(
        ...         conv1_rec=dict(type='ModuleOutputs', source='conv1'),
        ...         conv2_rec=dict(type='ModuleOutputs', source='conv2'),
        ...         func_rec=dict(type='MethodOutputs',
        ...                       source='toy_module.Toy.toy_func')))
        >>> manager.initialize(model)

        >>> with manager:
        ...     res = model(torch.ones(1, 1, 1, 1))
        >>> res
        tensor([[[[22.9534]]]])

        >>> conv2_data = manager.get_recorder('conv2_rec').get_record_data()
        >>> conv2_data
        tensor([[[[0.9534]]]])

        >>> func_data = manager.get_recorder('func_rec').get_record_data()
        >>> func_data
        22

        >>> res.sum() == (conv2_data + func_data).sum()
        True
    """

    def __init__(self, recorders: Optional[ConfigDict] = None) -> None:

        self._recorders: Dict[str, BaseRecorder] = dict()
        if recorders:
            for name, cfg in recorders.items():
                recorder_cfg = copy.deepcopy(cfg)
                recorder_type = cfg.type
                recorder_type_ = recorder_type + 'Recorder'

                recorder_cfg.type = recorder_type_
                recorder = TASK_UTILS.build(recorder_cfg)

                self._recorders[name] = recorder

    @property
    def recorders(self) -> Dict[str, BaseRecorder]:
        """dict: all recorders."""
        return self._recorders

    def get_recorder(self, recorder: str) -> BaseRecorder:
        """Get the corresponding recorder according to its name."""
        return self.recorders[recorder]

    def initialize(self, model: nn.Module):
        """Init all recorders.

        Args:
            model (nn.Module): The model which needs to record intermediate
                results.
        """
        for recorder in self.recorders.values():
            recorder.initialize(model)

    def __enter__(self):
        """Enter the context manager."""
        for recorder in self.recorders.values():
            recorder.__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        """Exit the context manager."""
        for recorder in self.recorders.values():
            recorder.__exit__(exc_type, exc_value, traceback)
@@ -1,5 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .base import BaseAlgorithm
+from .distill import (ConfigurableDistill, FpnTeacherDistill,
+                      SingleTeacherDistill)
 from .nas import SPOS

-__all__ = ['BaseAlgorithm', 'SPOS']
+__all__ = [
+    'SingleTeacherDistill', 'ConfigurableDistill', 'BaseAlgorithm',
+    'FpnTeacherDistill', 'SPOS'
+]
mmrazor/models/algorithms/distill/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .configurable import (ConfigurableDistill, FpnTeacherDistill,
                           SingleTeacherDistill)

__all__ = ['SingleTeacherDistill', 'ConfigurableDistill', 'FpnTeacherDistill']
@@ -0,0 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .configurable_distill import ConfigurableDistill
from .fpn_teacher_distill import FpnTeacherDistill
from .single_teacher_distill import SingleTeacherDistill

__all__ = [
    'SingleTeacherDistill', 'FpnTeacherDistill', 'ConfigurableDistill'
]
@@ -0,0 +1,248 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from inspect import signature
from typing import Dict, List, Optional, Union

from mmengine.model import BaseModel
from torch import nn

from mmrazor.core import DistillDeliveryManager, RecorderManager
from mmrazor.registry import MODELS
from ...base import BaseAlgorithm, LossResults


@MODELS.register_module()
class ConfigurableDistill(BaseAlgorithm):
    """``ConfigurableDistill`` is a powerful tool that can reproduce most
    distillation algorithms without modifying the code of the teacher or
    student models.

    ``ConfigurableDistill`` can get various intermediate results of the model
    in a hacky way by ``Recorder``. For more details, see the user docs for
    ``Recorder``.

    ``ConfigurableDistill`` can use the teacher's intermediate results to
    override the student's intermediate results in a hacky way by
    ``Delivery``. For more details, see the user docs for ``Delivery``.

    Args:
        architecture (dict | :obj:`BaseModel`): The config of
            :class:`BaseModel` or built model.
        student_recorders (dict, optional): Config for multiple recorders. A
            student model may have more than one recorder. These recorders
            only record the student model's intermediate results. Defaults to
            None.
        teacher_recorders (dict, optional): Config for multiple recorders. A
            teacher model may have more than one recorder. These recorders
            only record the teacher model's intermediate results. Defaults to
            None.
        distill_deliveries (dict, optional): Config for multiple deliveries. A
            distill algorithm may have more than one delivery. Defaults to
            None.
        distill_losses (Dict[str, Dict], optional): Config for multiple
            distill losses. A distill algorithm may have more than one distill
            loss. Defaults to None.
        loss_forward_mappings (Dict[str, Dict], optional): Mapping between
            distill loss forward arguments and records.
        data_preprocessor (:obj:`BaseDataPreprocessor`): Used for
            pre-processing data sampled by dataloader to the format accepted
            by :meth:`forward`.
        init_cfg (dict, optional): Initialization config dict.

    Note:
        If a distill loss needs to backward, the name of the loss must contain
        "loss". If it is only used as a statistical value, the name must not
        contain "loss". For more details, see the docs for
        :func:`mmengine.model.BaseModel._parse_loss`.

    Note:
        The keys of ``loss_forward_mappings`` should be consistent with the
        keys of ``distill_losses``.

        Each item in ``loss_forward_mappings`` is a mapping between a distill
        loss and its forward arguments. The keys of the mapping are the
        arguments in the signature of the loss's forward, and the values of
        the mapping are the recorded data locations.

        ``recorder`` refers to the recorder where the data is stored, and
        if ``from_student`` is True, it means the recorder is in
        ``student_recorders``; otherwise, it means the recorder is in
        ``teacher_recorders``.

    Examples:
        >>> distill_losses = dict(
        ...     loss_kl=dict(type='KLDivergence', tau=1, loss_weight=5))

        >>> student_recorders = dict(
        ...     fc=dict(type='ModuleOutputs', source='head.fc'))

        >>> teacher_recorders = dict(
        ...     fc=dict(type='ModuleOutputs', source='head.fc'))

        >>> loss_forward_mappings = dict(
        ...     loss_kl=dict(
        ...         preds_S=dict(recorder='fc', from_student=True),
        ...         preds_T=dict(recorder='fc', from_student=False)))
    """

    def __init__(self,
                 architecture: Union[BaseModel, Dict],
                 student_recorders: Optional[Dict[str, Dict]] = None,
                 teacher_recorders: Optional[Dict[str, Dict]] = None,
                 distill_deliveries: Optional[Dict[str, Dict]] = None,
                 distill_losses: Optional[Dict[str, Dict]] = None,
                 loss_forward_mappings: Optional[Dict[str, Dict]] = None,
                 data_preprocessor: Optional[Union[dict, nn.Module]] = None,
                 init_cfg: Optional[dict] = None):
        super().__init__(architecture, data_preprocessor, init_cfg)

        # The recorder managers are just constructed here, but not really
        # initialized yet. Recorder manager initialization needs the
        # corresponding model as input.
        # Different subclasses may have different teacher models, and it is
        # inconvenient to initialize the recorder managers in
        # ``ConfigurableDistill``.
        # During the initialization of the subclass, one needs to execute
        # `self.student_recorders.initialize(student)` and
        # `self.teacher_recorders.initialize(teacher)` on the corresponding
        # student and teacher.
        self.student_recorders = RecorderManager(student_recorders)
        self.teacher_recorders = RecorderManager(teacher_recorders)

        self.distill_deliveries = DistillDeliveryManager(distill_deliveries)

        self.distill_losses = self.build_distill_losses(distill_losses)

        if loss_forward_mappings:
            # Check if loss_forward_mappings is in the correct format
            self._check_loss_forward_mappings(self.distill_losses,
                                              loss_forward_mappings,
                                              self.student_recorders,
                                              self.teacher_recorders)
            self.loss_forward_mappings = loss_forward_mappings
        else:
            self.loss_forward_mappings = dict()

    @property
    def student(self) -> BaseModel:
        """Alias for ``architecture``."""
        return self.architecture

    def build_distill_losses(
        self,
        losses: Optional[Dict[str, Dict]] = None,
    ) -> nn.ModuleDict:
        """Build distill losses according to the config."""

        distill_losses = nn.ModuleDict()
        if losses:
            for loss_name, loss_cfg in losses.items():
                assert loss_name not in distill_losses
                if 'loss' not in loss_name:
                    warnings.warn(
                        f'Warning: If {loss_name} is a loss that needs to '
                        f'backward, the name of {loss_name} must contain '
                        '"loss". If it is only used as a statistical value, '
                        'then the name must not contain "loss". For more '
                        'details, see the docs for '
                        ':func:`mmengine.model.BaseModel._parse_loss`',
                        UserWarning)
                item_loss = MODELS.build(loss_cfg)
                distill_losses[loss_name] = item_loss

        return distill_losses

    def get_record(self,
                   recorder: str,
                   from_student: bool,
                   record_idx: int = 0,
                   data_idx: Optional[int] = None) -> List:
        """According to each item in ``record_infos``, get the corresponding
        record in ``recorder_manager``."""

        if from_student:
            recorder_ = self.student_recorders.get_recorder(recorder)
        else:
            recorder_ = self.teacher_recorders.get_recorder(recorder)

        return recorder_.get_record_data(record_idx, data_idx)

    def compute_distill_losses(
        self,
        distill_losses: nn.ModuleDict,
        loss_forward_mappings: Dict[str, Dict],
        student_recorders: RecorderManager,
        teacher_recorders: RecorderManager,
    ) -> LossResults:
        """Compute distill losses automatically."""
        # Record all computed losses' results.
        losses = dict()
        for loss_name, forward_mappings in loss_forward_mappings.items():
            forward_kwargs = dict()
            for forward_key, record_info in forward_mappings.items():
                forward_var = self.get_record(**record_info)
                forward_kwargs[forward_key] = forward_var

            loss_module = distill_losses[loss_name]
            loss = loss_module(**forward_kwargs)  # type: ignore
            # add the computed loss result.
            losses[loss_name] = loss

        return losses

    def _check_loss_forward_mappings(
            self, losses: nn.ModuleDict,
            loss_forward_mappings: Dict[str, Dict],
            student_recorders: RecorderManager,
            teacher_recorders: RecorderManager) -> None:
        """Check if ``loss_forward_mappings`` is in the correct format."""

        if not isinstance(loss_forward_mappings, dict):
            raise TypeError(
                'loss_forward_mappings should be a dict instance, but got '
                f'{type(loss_forward_mappings)}')

        for loss_name, forward_mappings in loss_forward_mappings.items():
            assert loss_name in losses, \
                f'"{loss_name}" is not in distill losses. The keys of ' \
                'loss_forward_kwargs must match the keys of distill_losses.'

            if not isinstance(forward_mappings, dict):
                raise TypeError(
                    'Each item of loss_forward_mappings should be a dict '
                    f'instance, but got {type(forward_mappings)}')

            loss_module = losses[loss_name]
            loss_forward_keys = signature(
                loss_module.forward).parameters.keys()
            assert len(loss_forward_keys) == len(forward_mappings.keys())

            for forward_key, record_info in forward_mappings.items():
                assert forward_key in loss_forward_keys, \
                    f'{forward_key} is not in the signature of ' \
                    f'{type(loss_module).__name__} forward, ' \
                    'please check your config.'

                assert 'recorder' in record_info, \
                    'Each item of loss_forward_mappings should have ' \
                    '"recorder", please check your config.'

                assert 'from_student' in record_info, \
                    'Each item of loss_forward_mappings should have ' \
                    '"from_student", please check your config.'

                recorder: str = record_info['recorder']
                from_student: bool = record_info['from_student']

                if not isinstance(from_student, bool):
                    raise TypeError(
                        f'from_student should be a bool instance, '
                        f'but got {type(from_student)}')

                if from_student:
                    assert recorder in self.student_recorders.recorders, \
                        f'For {forward_key}, "{recorder}" must be in ' \
                        '`student_recorders`.'
                else:
                    assert recorder in self.teacher_recorders.recorders, \
                        f'For {forward_key}, "{recorder}" must be in ' \
                        '`teacher_recorders`.'
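Putting the pieces together: every entry of ``loss_forward_mappings`` is resolved through ``get_record`` and then splatted into the loss's forward. A compact restatement of that lookup, for illustration (it mirrors ``get_record``/``compute_distill_losses`` above; it is not a separate API):

def resolve_record(algorithm, record_info: dict):
    # Pick the student or teacher recorder manager, then the named recorder.
    manager = (algorithm.student_recorders
               if record_info['from_student'] else algorithm.teacher_recorders)
    recorder = manager.get_recorder(record_info['recorder'])
    # Fetch one record (and optionally one element of it).
    return recorder.get_record_data(
        record_info.get('record_idx', 0), record_info.get('data_idx'))

# e.g. preds_S for loss_kl in the KL config above:
# resolve_record(algo, dict(from_student=True, recorder='fc'))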
@ -0,0 +1,57 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

import torch
from mmengine import BaseDataElement

from mmrazor.models.utils import add_prefix
from mmrazor.registry import MODELS
from ...base import LossResults
from .single_teacher_distill import SingleTeacherDistill


@MODELS.register_module()
class FpnTeacherDistill(SingleTeacherDistill):
    """``FpnTeacherDistill`` means the teacher only executes the backbone and
    neck.

    If the intermediate results required by the distill algorithm are
    generated by the backbone and neck parts, using ``FpnTeacherDistill`` can
    speed up training.
    """

    def loss(
        self,
        batch_inputs: torch.Tensor,
        data_samples: Optional[List[BaseDataElement]] = None,
    ) -> LossResults:
        """Calculate losses from a batch of inputs and data samples."""

        losses = dict()
        # If the `override_data` of a delivery is False, the delivery will
        # record the original data.
        self.delivery_manager.override_data = False
        if self.teacher_trainable:
            # Unlike ``SingleTeacherDistill``, the teacher only executes the
            # backbone and neck, not the head, so there is no teacher loss.
            with self.teacher_recorders, self.delivery_manager:
                _ = self.teacher.extract_feat(batch_inputs)
        else:
            with self.teacher_recorders, self.distill_deliveries:
                with torch.no_grad():
                    _ = self.teacher(batch_inputs, data_samples, mode='loss')

        # If the `override_data` of a delivery is True, the delivery will
        # override the original data with the recorded data.
        self.delivery_manager.override_data = True
        with self.student_recorders, self.delivery_manager:
            student_losses = self.student(
                batch_inputs, data_samples, mode='loss')
        losses.update(add_prefix(student_losses, 'student'))

        # Automatically compute distill losses based on
        # `loss_forward_mappings`.
        distill_losses = self.compute_distill_losses(
            self.distill_losses, self.loss_forward_mappings,
            self.student_recorders, self.teacher_recorders)
        losses.update(add_prefix(distill_losses, 'distill'))

        return losses
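A hedged config sketch of how this class could be used with a detector. The detector config paths and the ``neck`` recorder source are assumptions for illustration, not part of this commit:

# Illustrative only. FpnTeacherDistill skips the teacher head, so recorders
# should point at backbone/neck outputs.
model = dict(
    _scope_='mmrazor',
    type='FpnTeacherDistill',
    architecture=dict(
        cfg_path='mmdet::retinanet/retinanet_r50_fpn_1x_coco.py',  # assumed
        pretrained=False),
    teacher=dict(
        cfg_path='mmdet::retinanet/retinanet_r101_fpn_1x_coco.py',  # assumed
        pretrained=True),
    student_recorders=dict(fpn=dict(type='ModuleOutputs', source='neck')),
    teacher_recorders=dict(fpn=dict(type='ModuleOutputs', source='neck')))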
@ -0,0 +1,107 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union

import torch
from mmcv.runner import load_checkpoint
from mmengine import BaseDataElement
from mmengine.model import BaseModel
from torch.nn.modules.batchnorm import _BatchNorm

from mmrazor.models.utils import add_prefix
from mmrazor.registry import MODELS
from ...base import LossResults
from .configurable_distill import ConfigurableDistill


@MODELS.register_module()
class SingleTeacherDistill(ConfigurableDistill):
    """``SingleTeacherDistill`` can be used to develop distill algorithms
    which only use one teacher.

    Args:
        teacher (dict | BaseModel): The config dict for the teacher model or
            a built teacher model.
        teacher_ckpt (str): The path of the teacher's checkpoint. Defaults to
            None.
        teacher_trainable (bool): Whether the teacher is trainable. Defaults
            to False.
        teacher_norm_eval (bool): Whether to set the teacher's norm layers to
            eval mode, namely, freeze running stats (mean and var). Note:
            Effect on Batch Norm and its variants only. Defaults to True.
    """

    def __init__(self,
                 teacher: Union[BaseModel, Dict],
                 teacher_ckpt: Optional[str] = None,
                 teacher_trainable: bool = False,
                 teacher_norm_eval: bool = True,
                 **kwargs):
        super().__init__(**kwargs)

        if isinstance(teacher, Dict):
            teacher = MODELS.build(teacher)

        if not isinstance(teacher, BaseModel):
            raise TypeError('teacher should be a `dict` or '
                            f'`BaseModel` instance, but got '
                            f'{type(teacher)}')

        self.teacher = teacher
        if teacher_ckpt:
            # Init first to avoid the loaded parameters being overwritten.
            self.teacher.init_weights()
            _ = load_checkpoint(self.teacher, teacher_ckpt)
        self.teacher_trainable = teacher_trainable
        self.teacher_norm_eval = teacher_norm_eval

        # In ``ConfigurableDistill``, the recorder manager is only
        # constructed, not really initialized yet.
        self.student_recorders.initialize(self.student)
        self.teacher_recorders.initialize(self.teacher)

    def loss(
        self,
        batch_inputs: torch.Tensor,
        data_samples: Optional[List[BaseDataElement]] = None,
    ) -> LossResults:
        """Calculate losses from a batch of inputs and data samples."""

        losses = dict()

        # If the `override_data` of a delivery is False, the delivery will
        # record the original data.
        self.distill_deliveries.override_data = False
        if self.teacher_trainable:
            with self.teacher_recorders, self.distill_deliveries:
                teacher_losses = self.teacher(
                    batch_inputs, data_samples, mode='loss')

            losses.update(add_prefix(teacher_losses, 'teacher'))
        else:
            with self.teacher_recorders, self.distill_deliveries:
                with torch.no_grad():
                    _ = self.teacher(batch_inputs, data_samples, mode='loss')

        # If the `override_data` of a delivery is True, the delivery will
        # override the original data with the recorded data.
        self.distill_deliveries.override_data = True
        with self.student_recorders, self.distill_deliveries:
            student_losses = self.student(
                batch_inputs, data_samples, mode='loss')
        losses.update(add_prefix(student_losses, 'student'))

        # Automatically compute distill losses based on
        # `loss_forward_mappings`.
        distill_losses = self.compute_distill_losses(
            self.distill_losses, self.loss_forward_mappings,
            self.student_recorders, self.teacher_recorders)
        losses.update(add_prefix(distill_losses, 'distill'))

        return losses

    def train(self, mode=True):
        """Set the distiller's forward mode."""
        super().train(mode)
        if mode and self.teacher_norm_eval:
            for m in self.teacher.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
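A small self-contained check (toy module, not mmrazor code) of what ``teacher_norm_eval`` buys: the teacher's BatchNorm stays in eval mode even while the whole algorithm is switched to train mode, so its running statistics are frozen.

import torch
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm

teacher = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3))

# Mimic SingleTeacherDistill.train(mode=True) with teacher_norm_eval=True.
teacher.train()
for m in teacher.modules():
    if isinstance(m, _BatchNorm):
        m.eval()

bn = teacher[1]
before = bn.running_mean.clone()
_ = teacher(torch.randn(2, 3, 4, 4))
# Running stats are untouched because the BN layer stayed in eval mode.
assert torch.equal(bn.running_mean, before)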
@ -1,5 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .self_distiller import SelfDistiller
from .single_teacher import SingleTeacherDistiller

__all__ = ['SelfDistiller', 'SingleTeacherDistiller']
@ -1,155 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod

from mmcv.runner import BaseModule
from mmcv.utils import import_modules_from_strings


def function_wrapper(ctx, method, method_str):
    """Pass the teacher's outputs to the student."""

    def wrapper(*args, **kwargs):
        # record inputs
        ctx.method_args[method_str] = args
        ctx.method_kwargs[method_str] = kwargs
        # TODO: cover more use cases, not only passing the teacher's outputs
        # to the student.
        if ctx.is_teacher:
            # execute the raw function
            outputs = method(*args, **kwargs)
            # record outputs
            ctx.method_return[method_str] = outputs
        else:
            # modify the student's outputs to match the teacher's
            outputs = ctx.method_return[method_str]

        return outputs

    return wrapper


class FunctionContext():
    """Function context manager for rewriting a function.

    Args:
        ctx (ConversionContext): The distiller's overall context manager.
        method (str): The name of the function to rewrite.
    """

    def __init__(self, ctx, method, import_module=None):
        self.ctx = ctx

        self.import_module = import_modules_from_strings(import_module)
        self.method_str = method
        self.method_exec_str = f'self.import_module.{method}'

    def _set_method(self, method):
        """Modify a function."""
        exec(f'{self.method_exec_str} = method')

    def __enter__(self):
        """Rewrite the function."""
        self.method_impl = eval(self.method_exec_str)

        if self.method_impl:
            self._set_method(
                function_wrapper(self.ctx, self.method_impl,
                                 self.method_str))

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore the function."""
        if self.method_impl:
            self._set_method(self.method_impl)


class ConversionContext():
    """Context manager that records functions' inputs or outputs."""

    def __init__(self, hooks):
        # save functions' inputs
        self.method_args = dict()
        self.method_kwargs = dict()
        # save functions' outputs
        self.method_return = dict()

        # Each function has a sub context manager; the function is rewritten
        # on entering the sub context manager.
        self.hooks = []
        self.is_teacher = True
        for hook in hooks:
            self.hooks.append(FunctionContext(self, **hook))

    def __enter__(self):
        """Enter every sub context manager."""
        for hook in self.hooks:
            hook.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Exit every sub context manager."""
        for hook in self.hooks:
            hook.__exit__(exc_type, exc_value, traceback)


class BaseDistiller(BaseModule, metaclass=ABCMeta):
    """Base Distiller.

    In a distillation algorithm, some intermediate results of the teacher
    need to be obtained and passed to the student.

    An nn.Module's outputs are obtained by a PyTorch forward hook.
    A Python function's outputs are obtained by a specific context manager.

    Args:
        align_methods (dict): The details of the functions whose outputs need
            to be obtained.
    """

    def __init__(self, align_methods=None, **kwargs):
        super(BaseDistiller, self).__init__(**kwargs)

        if align_methods is None:
            self.context_manager = None
        else:
            # To obtain a Python function's outputs, a specific context
            # manager is built. On entering the context manager, the
            # functions are rewritten. The context manager records the
            # functions' inputs or outputs and passes them from the teacher
            # to the student. On exiting the context manager, the rewritten
            # functions are restored.
            self.context_manager = ConversionContext(align_methods)

    @abstractmethod
    def prepare_from_student(self, student):
        """Register forward hooks to students and teachers."""
        pass

    @abstractmethod
    def teacher_forward_output_hook(self, module, inputs, outputs):
        """Save the teacher output."""
        pass

    @abstractmethod
    def student_forward_output_hook(self, module, inputs, outputs):
        """Save the student output."""
        pass

    def reset_ctx_teacher_mode(self, mode=True):
        if self.context_manager is not None:
            self.context_manager.is_teacher = mode

    @abstractmethod
    def exec_teacher_forward(self, data):
        """Execute the teacher's forward function."""
        pass

    @abstractmethod
    def exec_student_forward(self, student, data):
        """Execute the student's forward function."""
        pass

    @abstractmethod
    def compute_distill_loss(self, data):
        """Compute the distill loss according to the teacher's outputs and
        the student's outputs."""
        pass
@ -1,141 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn

from mmrazor.registry import MODELS
from .base import BaseDistiller


@MODELS.register_module()
class SelfDistiller(BaseDistiller):
    """Transfer knowledge inside a single model.

    Args:
        components (dict): The details of the distillation. It usually
            includes the module names of the teacher and the student, and
            the losses used in the distillation.
    """

    def __init__(self, components, **kwargs):
        super().__init__(**kwargs)
        self.components = components
        self.losses = nn.ModuleDict()

        self.student_outputs = dict()
        self.teacher_outputs = dict()

        for component in self.components:
            student_module_name = component['student_module']
            teacher_module_name = component['teacher_module']
            self.student_outputs[student_module_name] = list()
            self.teacher_outputs[teacher_module_name] = list()

            for loss in component.losses:
                loss_cfg = loss.copy()
                loss_name = loss_cfg.pop('name')
                self.losses[loss_name] = MODELS.build(loss_cfg)

    def prepare_from_student(self, student):
        """Register a global forward hook for each teacher module and student
        module to be used in the distillation.

        Args:
            student (:obj:`torch.nn.Module`): The student model to be used
                in the distillation.
        """
        self.module2name = {}
        for name, module in student.model.named_modules():
            self.module2name[module] = name
        self.name_modules = dict(student.model.named_modules())

        for component in self.components:
            student_module_name = component['student_module']
            teacher_module_name = component['teacher_module']

            student_module = self.name_modules[student_module_name]
            teacher_module = self.name_modules[teacher_module_name]

            student_module.register_forward_hook(
                self.student_forward_output_hook)
            teacher_module.register_forward_hook(
                self.teacher_forward_output_hook)

    def teacher_forward_output_hook(self, module, inputs, outputs):
        """Save the output.

        Args:
            module (:obj:`torch.nn.Module`): The module the hook is
                registered on.
            inputs (tuple): The input of the module.
            outputs (tuple): The output of the module.
        """
        if self.training and getattr(self, 'is_teacher', None):
            self.teacher_outputs[self.module2name[module]].append(outputs)

    def student_forward_output_hook(self, module, inputs, outputs):
        """Save the output.

        Args:
            module (:obj:`torch.nn.Module`): The module the hook is
                registered on.
            inputs (tuple): The input of the module.
            outputs (tuple): The output of the module.
        """
        if self.training and not getattr(self, 'is_teacher', None):
            self.student_outputs[self.module2name[module]].append(outputs)

    def reset_outputs(self, outputs):
        """Reset the teacher's outputs or the student's outputs."""
        for key in outputs.keys():
            outputs[key] = list()

    def exec_teacher_forward(self, teacher, data):
        """Forward computation of the teacher.

        Args:
            teacher (:obj:`torch.nn.Module`): The teacher model to be used
                in the distillation.
            data (dict): The output of the dataloader.
        """
        self.reset_outputs(self.teacher_outputs)
        self.is_teacher = True
        output = teacher(**data)
        self.is_teacher = False

        return output

    def exec_student_forward(self, student, data):
        """Forward computation of the student.

        Args:
            student (:obj:`torch.nn.Module`): The student model to be used
                in the distillation.
            data (dict): The output of the dataloader.
        """
        assert not self.is_teacher
        self.reset_outputs(self.student_outputs)
        output = student(**data)

        return output

    def compute_distill_loss(self, data):
        """Compute the distillation loss."""

        losses = dict()

        for i, component in enumerate(self.components):
            student_module_name = component['student_module']
            student_outputs = self.student_outputs[student_module_name]

            teacher_module_name = component['teacher_module']
            teacher_outputs = self.teacher_outputs[teacher_module_name]

            for out_idx, (s_out, t_out) in enumerate(
                    zip(student_outputs, teacher_outputs)):

                for loss in component.losses:
                    loss_module = self.losses[loss.name]
                    loss_name = f'{loss.name}.{out_idx}'

                    loss_module.current_data = data
                    losses[loss_name] = loss_module(s_out, t_out)
                    loss_module.current_data = None

        return losses
@ -1,243 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm

from mmrazor.registry import MODELS
from .base import BaseDistiller


@MODELS.register_module()
class SingleTeacherDistiller(BaseDistiller):
    """Distiller with a single teacher.

    Args:
        teacher (dict): The config dict for the teacher.
        teacher_trainable (bool): Whether the teacher is trainable.
            Default: False.
        teacher_norm_eval (bool): Whether to set the teacher's norm layers to
            eval mode, namely, freeze running stats (mean and var). Note:
            Effect on Batch Norm and its variants only. Default: True.
        components (dict): The details of the distillation. It usually
            includes the module names of the teacher and the student, and
            the losses used in the distillation.
    """

    def __init__(self,
                 teacher,
                 teacher_trainable=False,
                 teacher_norm_eval=True,
                 components=tuple(),
                 **kwargs):
        super().__init__(**kwargs)
        self.teacher_trainable = teacher_trainable
        self.teacher_norm_eval = teacher_norm_eval
        self.teacher = self.build_teacher(teacher)

        self.components = components
        self.losses = nn.ModuleDict()
        self.align_modules = nn.ModuleDict()

        # Record the feature maps that need to calculate the distillation
        # loss.
        self.student_outputs = dict()
        self.teacher_outputs = dict()

        for i, component in enumerate(self.components):
            student_module_name = component['student_module']
            teacher_module_name = component['teacher_module']
            # Every student_output is a list by default, because some modules
            # execute multiple forward calculations, such as the shareable
            # head in RetinaNet.
            self.student_outputs[student_module_name] = list()
            self.teacher_outputs[teacher_module_name] = list()

            # If the numbers of feature map channels of the student and the
            # teacher are inconsistent, they need to be aligned by a 1x1
            # convolution.
            align_module_cfg = getattr(component, 'align_module', None)
            if align_module_cfg is not None:
                align_module_name = f'component_{i}'
                align_module = self.build_align_module(align_module_cfg)
                self.align_modules[align_module_name] = align_module

            # Multiple losses can be calculated at the same location.
            for loss in component.losses:
                loss_cfg = loss.copy()
                loss_name = loss_cfg.pop('name')
                self.losses[loss_name] = MODELS.build(loss_cfg)

    def build_teacher(self, cfg):
        """Build a model from the `cfg`."""

        teacher = MODELS.build(cfg)

        return teacher

    def build_align_module(self, cfg):
        """Build an ``align_module`` from the `cfg`.

        An ``align_module`` is needed when the number of channels output by
        the teacher module is not equal to that of the student module, or for
        some other reasons.

        Args:
            cfg (dict): The config dict for the ``align_module``.
        """

        in_channels = cfg.student_channels
        out_channels = cfg.teacher_channels
        if cfg.type == 'conv2d':
            align_module = nn.Conv2d(in_channels, out_channels, 1)
        elif cfg.type == 'linear':
            align_module = nn.Linear(in_channels, out_channels)
        return align_module

    def prepare_from_student(self, student):
        """Register a global forward hook for each teacher module and student
        module to be used in the distillation.

        Args:
            student (:obj:`torch.nn.Module`): The student model to be used
                in the distillation.
        """

        # Record the mapping relationship between the student's modules and
        # module names.
        self.student_module2name = {}
        for name, module in student.model.named_modules():
            self.student_module2name[module] = name
        self.student_name2module = dict(student.model.named_modules())

        # Record the mapping relationship between the teacher's modules and
        # module names.
        self.teacher_module2name = {}
        for name, module in self.teacher.named_modules():
            self.teacher_module2name[module] = name
        self.teacher_name2module = dict(self.teacher.named_modules())

        # Register forward hooks for modules that need to participate in the
        # loss calculation.
        for component in self.components:
            student_module_name = component['student_module']
            teacher_module_name = component['teacher_module']

            student_module = self.student_name2module[student_module_name]
            teacher_module = self.teacher_name2module[teacher_module_name]

            student_module.register_forward_hook(
                self.student_forward_output_hook)
            teacher_module.register_forward_hook(
                self.teacher_forward_output_hook)

    def teacher_forward_output_hook(self, module, inputs, outputs):
        """Save the module's forward output.

        Args:
            module (:obj:`torch.nn.Module`): The module the hook is
                registered on.
            inputs (tuple): The input of the module.
            outputs (tuple): The output of the module.
        """
        if self.training:
            self.teacher_outputs[self.teacher_module2name[module]].append(
                outputs)

    def student_forward_output_hook(self, module, inputs, outputs):
        """Save the module's forward output.

        Args:
            module (:obj:`torch.nn.Module`): The module the hook is
                registered on.
            inputs (tuple): The input of the module.
            outputs (tuple): The output of the module.
        """
        if self.training:
            self.student_outputs[self.student_module2name[module]].append(
                outputs)

    def reset_outputs(self, outputs):
        """Reset the teacher's outputs or the student's outputs."""
        for key in outputs.keys():
            outputs[key] = list()

    def exec_teacher_forward(self, data):
        """Execute the teacher's forward function.

        After this function, the teacher's feature maps are saved in
        ``teacher_outputs``.
        """

        # Convert the context manager's mode to teacher.
        self.reset_ctx_teacher_mode(True)
        # Clear the saved data of the last forward.
        self.reset_outputs(self.teacher_outputs)

        if self.teacher_trainable:
            output = self.teacher(**data)
        else:
            with torch.no_grad():
                output = self.teacher(**data)

        return output

    def exec_student_forward(self, student, data):
        """Execute the student's forward function.

        After this function, the student's feature maps are saved in
        ``student_outputs``.
        """
        # Convert the context manager's mode to student.
        self.reset_ctx_teacher_mode(False)
        # Clear the saved data of the last forward.
        self.reset_outputs(self.student_outputs)

        output = student(**data)
        return output

    def train(self, mode=True):
        """Set the distiller's forward mode."""
        super(SingleTeacherDistiller, self).train(mode)
        if mode and self.teacher_norm_eval:
            for m in self.teacher.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()

    def get_teacher_outputs(self, teacher_module_name):
        """Get the outputs according to the module name."""
        return self.teacher_outputs[teacher_module_name]

    def compute_distill_loss(self, data=None):
        """Compute the distillation loss."""

        losses = dict()

        for i, component in enumerate(self.components):
            # Get the student's outputs.
            student_module_name = component['student_module']
            student_outputs = self.student_outputs[student_module_name]

            # Align the student output's channels with the teacher's.
            align_module_name = f'component_{i}'
            if align_module_name in self.align_modules:
                align_module = self.align_modules[align_module_name]
                student_outputs = [
                    align_module(s_out) for s_out in student_outputs
                ]

            # Get the teacher's outputs.
            teacher_module_name = component['teacher_module']
            teacher_outputs = self.get_teacher_outputs(teacher_module_name)

            # One module may have N outputs, such as the shareable head in
            # RetinaNet.
            for out_idx, (s_out, t_out) in enumerate(
                    zip(student_outputs, teacher_outputs)):

                for loss in component.losses:
                    loss_module = self.losses[loss.name]
                    loss_name = f'{loss.name}.{out_idx}'
                    # TODO: ugly implementation.
                    # Pass the gt_label to the loss function.
                    # Only used by WSLD.
                    loss_module.current_data = data
                    losses[loss_name] = loss_module(s_out, t_out)
                    loss_module.current_data = None

        return losses
@ -26,9 +26,9 @@ class KLDivergence(nn.Module):

     def __init__(
         self,
-        tau=1.0,
-        reduction='batchmean',
-        loss_weight=1.0,
+        tau: float = 1.0,
+        reduction: str = 'batchmean',
+        loss_weight: float = 1.0,
     ):
         super(KLDivergence, self).__init__()
         self.tau = tau
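A hedged usage sketch: ``KLDivergence`` is a plain nn.Module, so it can be exercised standalone with random logits; the import path and shapes below are assumptions.

import torch
from mmrazor.models.losses import KLDivergence  # assumed import path

loss_fn = KLDivergence(tau=4.0, loss_weight=1.0)
preds_S = torch.randn(8, 100)  # student logits
preds_T = torch.randn(8, 100)  # teacher logits
loss = loss_fn(preds_S, preds_T)
assert loss.dim() == 0  # 'batchmean' reduction yields a scalar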
@ -24,12 +24,20 @@ class WSLD(nn.Module):
         self.tau = tau
         self.loss_weight = loss_weight
         self.num_classes = num_classes
-        self.softmax = nn.Softmax(dim=1).cuda()
-        self.logsoftmax = nn.LogSoftmax(dim=1).cuda()
+        self.softmax = nn.Softmax(dim=1)
+        self.logsoftmax = nn.LogSoftmax(dim=1)

-    def forward(self, student, teacher):
+    def forward(self, student, teacher, data_samples):

-        gt_labels = self.current_data['gt_label']
+        # Unpack data samples and pack targets.
+        if 'score' in data_samples[0].gt_label:
+            # Batch augmentation may convert labels to one-hot format scores.
+            gt_labels = torch.stack([i.gt_label.score for i in data_samples])
+            one_hot_labels = gt_labels.float()
+        else:
+            gt_labels = torch.hstack([i.gt_label.label for i in data_samples])
+            one_hot_labels = F.one_hot(
+                gt_labels, num_classes=self.num_classes).float()

         student_logits = student / self.tau
         teacher_logits = teacher / self.tau
@ -49,7 +57,7 @@ class WSLD(nn.Module):
         ce_loss_t = -torch.sum(one_hot_labels * log_softmax_t, 1, keepdim=True)

         focal_weight = ce_loss_s / (ce_loss_t + 1e-7)
-        ratio_lower = torch.zeros(1).cuda()
+        ratio_lower = torch.zeros_like(focal_weight)
         focal_weight = torch.max(focal_weight, ratio_lower)
         focal_weight = 1 - torch.exp(-focal_weight)
         ce_loss = focal_weight * ce_loss
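A quick check of the device-agnostic change above, using toy tensors: ``torch.zeros_like`` inherits shape, dtype, and device from ``focal_weight``, so the clamp works on CPU and GPU alike, unlike the previous hard-coded ``.cuda()`` call.

import torch

focal_weight = torch.tensor([-0.5, 0.3, 2.0])
# zeros_like matches shape, dtype and device of its argument.
ratio_lower = torch.zeros_like(focal_weight)
clamped = torch.max(focal_weight, ratio_lower)
assert torch.equal(clamped, torch.tensor([0.0, 0.3, 2.0]))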
@ -38,10 +38,9 @@ def build_razor_model_from_cfg(
         default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any:

-    # TODO relay on mmengine:HAOCHENYE/config_new_feature
-    # if cfg.get('cfg_path', None) and not cfg.get('type', None):
-    #     from mmengine.config import get_model
-    #     teacher = get_model(**cfg)
-    #     return teacher
+    if cfg.get('cfg_path', None) and not cfg.get('type', None):
+        from mmengine.config import get_model
+        return get_model(**cfg)

     if cfg.get('_fix_subnet_', None):
         fix_subnet = cfg.pop('_fix_subnet_')
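A short sketch of what the newly enabled branch does; the config path is an assumed example, not taken from this commit.

# With the branch above active, a config that only carries `cfg_path` is
# resolved through mmengine's `get_model` instead of the MODELS registry.
cfg = dict(cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py',  # assumed path
           pretrained=False)
# build_razor_model_from_cfg(cfg, ...) now effectively runs:
#     from mmengine.config import get_model
#     return get_model(**cfg)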
@ -3,7 +3,7 @@ from typing import Dict, List, Sequence, Union

 import torch
 from mmengine.evaluator import Evaluator
-from mmengine.runner import ValLoop
+from mmengine.runner import ValLoop, autocast
 from torch.utils.data import DataLoader

 from mmrazor.registry import LOOPS
@ -19,16 +19,29 @@ class SingleTeacherDistillValLoop(ValLoop):
         dataloader (Dataloader or dict): A dataloader object or a dict to
             build a dataloader.
         evaluator (Evaluator or dict or list): Used for computing metrics.
+        fp16 (bool): Whether to enable fp16 validation. Defaults to
+            False.
     """

-    def __init__(self, runner, dataloader: Union[DataLoader, Dict],
-                 evaluator: Union[Evaluator, Dict, List]) -> None:
-        super().__init__(runner, dataloader, evaluator)
+    def __init__(self,
+                 runner,
+                 dataloader: Union[DataLoader, Dict],
+                 evaluator: Union[Evaluator, Dict, List],
+                 fp16: bool = False) -> None:
+        super().__init__(runner, dataloader, evaluator, fp16)
         if self.runner.distributed:
-            self.model = runner.model.module
+            assert hasattr(self.runner.model.module, 'teacher')
+            # TODO: remove the hard code after mmcls adds data_preprocessor.
+            data_preprocessor = self.runner.model.module.data_preprocessor
+            self.teacher = self.runner.model.module.teacher
+            self.teacher.data_preprocessor = data_preprocessor
+
         else:
-            self.model = runner.model
-            assert hasattr(self.model, 'teacher')
+            assert hasattr(self.runner.model, 'teacher')
+            # TODO: remove the hard code after mmcls adds data_preprocessor.
+            data_preprocessor = self.runner.model.data_preprocessor
+            self.teacher = self.runner.model.teacher
+            self.teacher.data_preprocessor = data_preprocessor

     def run(self):
         """Launch validation."""
@ -38,19 +51,32 @@ class SingleTeacherDistillValLoop(ValLoop):

         for idx, data_batch in enumerate(self.dataloader):
             self.run_iter(idx, data_batch)
-        # compute metrics
-        metrics_s = self.evaluator.evaluate(len(self.dataloader.dataset))
-        for key, value in metrics_s.items():
-            self.runner.message_hub.update_scalar(f'val_student/{key}', value)
+        # compute student metrics
+        metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
+        student_metrics = dict()
+        for key, value in metrics.items():
+            student_key = 'student.' + key
+            teacher_key = 'teacher.' + key
+
+            student_metrics[student_key] = value
+            self.runner.message_hub.log_scalars.pop(f'val/{teacher_key}', None)
+
+        self.runner.call_hook('after_val_epoch', metrics=student_metrics)

         self.runner.call_hook('before_val_epoch')
         for idx, data_batch in enumerate(self.dataloader):
             self.run_iter_teacher(idx, data_batch)
-        # compute metrics
-        metrics_t = self.evaluator.evaluate(len(self.dataloader.dataset))
-        for key, value in metrics_t.items():
-            self.runner.message_hub.update_scalar(f'val_teacher/{key}', value)
+        # compute teacher metrics
+        metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
+        teacher_metrics = dict()
+        for key, value in metrics.items():
+            student_key = 'student.' + key
+            teacher_key = 'teacher.' + key

-        self.runner.call_hook('after_val_epoch', metrics=None)
+            teacher_metrics[teacher_key] = value
+            self.runner.message_hub.log_scalars.pop(f'val/{student_key}', None)
+
+        self.runner.call_hook('after_val_epoch', metrics=teacher_metrics)
         self.runner.call_hook('after_val')

     @torch.no_grad()
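For clarity, a toy illustration of the renaming this hunk performs: evaluator keys are turned into 'student.'/'teacher.'-prefixed keys under the single 'val/' namespace, replacing the old 'val_student/'/'val_teacher/' scheme.

# Toy values only; this mirrors the dict-building loop in run().
metrics = {'acc': 71.8}
student_metrics = {'student.' + k: v for k, v in metrics.items()}
assert student_metrics == {'student.acc': 71.8}
# The message hub later logs these as 'val/student.acc' instead of the old
# 'val_student/acc'.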
@ -63,8 +89,11 @@ class SingleTeacherDistillValLoop(ValLoop):
         """
         self.runner.call_hook(
             'before_val_iter', batch_idx=idx, data_batch=data_batch)
-        # outputs should be sequence of BaseDataElement
-        outputs = self.model.teacher(data_batch)
+
+        with autocast(enabled=self.fp16):
+            # outputs should be sequence of BaseDataElement
+            outputs = self.teacher.val_step(data_batch)
+
         self.evaluator.process(data_batch, outputs)
         self.runner.call_hook(
             'after_val_iter',
@ -62,9 +62,7 @@ def register_all_modules(init_default_scope: bool = True) -> None:
     """ # noqa
     import mmrazor.core  # noqa: F401,F403
     import mmrazor.models  # noqa: F401,F403
-
-    # TODO add mmrazor.runners
-
+    import mmrazor.runners  # noqa: F401,F403
     if init_default_scope:
         never_created = DefaultScope.get_current_instance() is None \
             or not DefaultScope.check_instance_created('mmrazor')
@ -3,24 +3,38 @@ from unittest import TestCase

 from mmcv import ConfigDict

-from mmrazor.core import DistillDeliverManager
+from mmrazor.core import DistillDeliveryManager


 class TestDeliverManager(TestCase):

     def test_init(self):

+        distill_deliveries = ConfigDict(
+            delivery1=dict(
+                type='MethodOutputs',
+                max_keep_data=2,
+                method_path='toy_module.ToyClass.random_int'))
+
+        manager = DistillDeliveryManager(distill_deliveries)
+        self.assertEquals(len(manager.deliveries), 1)
+
+        manager = DistillDeliveryManager()
+        self.assertEquals(len(manager.deliveries), 0)
+
     def test_context_manager(self):
         from toy_module import ToyClass

-        distill_deliveries = [
-            ConfigDict(
-                type='MethodOutputs',
-                max_keep_data=2,
-                method_path='toy_module.ToyClass.random_int')
-        ]
+        distill_deliveries = ConfigDict(
+            delivery1=dict(
+                type='MethodOutputs',
+                max_keep_data=2,
+                method_path='toy_module.ToyClass.random_int'))

-        manager = DistillDeliverManager(distill_deliveries)
+        manager = DistillDeliveryManager(distill_deliveries)

         manager.override_data = False
         self.assertFalse(manager.override_data)
         with manager:
             toy_class = ToyClass()
             output1_tea = toy_class.random_int()
@ -30,7 +44,9 @@ class TestDeliverManager(TestCase):
         with manager:
             _ = toy_class.random_int()

-        self.assertFalse(manager.override_data)
+        manager.override_data = True
+        self.assertTrue(manager.override_data)
         with manager:
             output1_stu = toy_class.random_int()
             output2_stu = toy_class.random_int()
@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from unittest import TestCase

-from mmrazor.core import FunctionOutputsDeliver
+from mmrazor.core import FunctionOutputsDelivery


 class TestFuncOutputsDeliver(TestCase):
@ -9,26 +9,26 @@ class TestFuncOutputsDeliver(TestCase):
     def test_init(self):

         with self.assertRaisesRegex(TypeError, 'func_path should be'):
-            _ = FunctionOutputsDeliver(max_keep_data=1, func_path=1)
+            _ = FunctionOutputsDelivery(max_keep_data=1, func_path=1)

         with self.assertRaisesRegex(AssertionError, 'func_path must have at '):
-            _ = FunctionOutputsDeliver(max_keep_data=1, func_path='toy_func')
+            _ = FunctionOutputsDelivery(max_keep_data=1, func_path='toy_func')

         with self.assertRaisesRegex(ImportError, 'aaa is not imported'):
-            _ = FunctionOutputsDeliver(max_keep_data=1, func_path='aaa.bb')
+            _ = FunctionOutputsDelivery(max_keep_data=1, func_path='aaa.bb')

         with self.assertRaisesRegex(AssertionError, 'bb is not in toy_mod'):
-            _ = FunctionOutputsDeliver(
+            _ = FunctionOutputsDelivery(
                 max_keep_data=1, func_path='toy_module.bb')

         with self.assertRaisesRegex(TypeError, 'TOY_VAR should be'):
-            _ = FunctionOutputsDeliver(
+            _ = FunctionOutputsDelivery(
                 max_keep_data=1, func_path='toy_module.TOY_VAR')

     def test_context_manager(self):
         import toy_module

-        delivery = FunctionOutputsDeliver(
+        delivery = FunctionOutputsDelivery(
             max_keep_data=2, func_path='toy_module.toy_func')

         delivery.override_data = False
@ -1,45 +1,45 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from unittest import TestCase

-from mmrazor.core import MethodOutputsDeliver
+from mmrazor.core import MethodOutputsDelivery


 class TestMethodOutputsDeliver(TestCase):

     def test_init(self):
         with self.assertRaisesRegex(TypeError, 'method_path should be'):
-            _ = MethodOutputsDeliver(max_keep_data=1, method_path=1)
+            _ = MethodOutputsDelivery(max_keep_data=1, method_path=1)

         with self.assertRaisesRegex(AssertionError,
                                     'method_path must have at '):
-            _ = MethodOutputsDeliver(max_keep_data=1, method_path='toy_func')
+            _ = MethodOutputsDelivery(max_keep_data=1, method_path='toy_func')

         with self.assertRaisesRegex(ImportError, 'aaa is not imported'):
-            _ = MethodOutputsDeliver(max_keep_data=1, method_path='aaa.bb.b')
+            _ = MethodOutputsDelivery(max_keep_data=1, method_path='aaa.bb.b')

         with self.assertRaisesRegex(AssertionError, 'bb is not in toy_module'):
-            _ = MethodOutputsDeliver(
+            _ = MethodOutputsDelivery(
                 max_keep_data=1, method_path='toy_module.bb.bbb')

         with self.assertRaisesRegex(TypeError, 'toy_func should be a type'):
-            _ = MethodOutputsDeliver(
+            _ = MethodOutputsDelivery(
                 max_keep_data=1, method_path='toy_module.toy_func.bbb')

         with self.assertRaisesRegex(AssertionError, 'bbb is not in'):
-            _ = MethodOutputsDeliver(
+            _ = MethodOutputsDelivery(
                 max_keep_data=1, method_path='toy_module.ToyClass.bbb')

         with self.assertRaisesRegex(TypeError, 'count should be'):
-            _ = MethodOutputsDeliver(
+            _ = MethodOutputsDelivery(
                 max_keep_data=1, method_path='toy_module.ToyClass.count')

     def test_context_manager(self):
         from toy_module import ToyClass

-        delivery = MethodOutputsDeliver(
+        delivery = MethodOutputsDelivery(
             max_keep_data=2, method_path='toy_module.ToyClass.random_int')

-        # Without ``MethodOutputsDeliver``, outputs of the teacher and the
+        # Without ``MethodOutputsDelivery``, outputs of the teacher and the
         # student are very likely to be different.
         # from toy_module import ToyClass
         # toy_class = ToyClass()
50
tests/test_core/test_recorders/test_base_recorder.py
Normal file
@ -0,0 +1,50 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

from toy_mod import Toy

from mmrazor.core.recorders import MethodOutputsRecorder


class TestFuncOutputsRecorder(TestCase):

    def test_get_record_data(self):

        toy = Toy()

        recorder = MethodOutputsRecorder('toy_mod.Toy.toy_func')
        recorder.initialize()

        with recorder:
            res0 = toy.toy_func()
            res1 = toy.toy_func()

        self.assertEquals(res0, recorder.get_record_data(record_idx=0))
        self.assertEquals(res1, recorder.get_record_data(record_idx=1))

        with self.assertRaisesRegex(
                AssertionError,
                'record_idx is illegal. The length of data_buffer is 2, '
                'but record_idx is 2'):
            _ = recorder.get_record_data(record_idx=2)

        with self.assertRaisesRegex(
                TypeError,
                'When data_idx is not None, record should be a list or '
                'tuple instance'):
            _ = recorder.get_record_data(data_idx=0)

        recorder = MethodOutputsRecorder('toy_mod.Toy.toy_list_func')
        recorder.initialize()

        with recorder:
            res = toy.toy_list_func()

        self.assertEqual(len(res), 3)

        with self.assertRaisesRegex(
                AssertionError,
                'data_idx is illegal. The length of record is 3'):
            _ = recorder.get_record_data(data_idx=3)

        self.assertEquals(res[2], recorder.get_record_data(data_idx=2))
42
tests/test_core/test_recorders/test_func_outputs_recorder.py
Normal file
@ -0,0 +1,42 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

from mmrazor.core.recorders import FunctionOutputsRecorder


class TestFuncOutputsRecorder(TestCase):

    def test_init(self):

        _ = FunctionOutputsRecorder('toy_mod.toy_func')

        with self.assertRaisesRegex(TypeError, 'source should be'):
            _ = FunctionOutputsRecorder([1])

        with self.assertRaisesRegex(AssertionError, 'source must have at '):
            _ = FunctionOutputsRecorder('aaaaa')

        with self.assertRaisesRegex(ImportError, 'aaa is not imported'):
            _ = FunctionOutputsRecorder('aaa.bbb')

        with self.assertRaisesRegex(AssertionError, 'aaa is not in toy_mod'):
            _ = FunctionOutputsRecorder('toy_mod.aaa')

        with self.assertRaisesRegex(TypeError, 'TOY_VAR should be'):
            _ = FunctionOutputsRecorder('toy_mod.TOY_VAR')

    def test_context_manager(self):
        from toy_mod import execute_toy_func

        recorder = FunctionOutputsRecorder('toy_mod.toy_func')
        recorder.initialize()

        with recorder:
            execute_toy_func(1)

        data = recorder.get_record_data()
        self.assertTrue(data == 1)

        execute_toy_func(1)
        data = recorder.get_record_data()
        self.assertTrue(data == 1)
@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

from mmrazor.core.recorders import MethodOutputsRecorder


class TestFuncOutputsRecorder(TestCase):

    def test_init(self):

        _ = MethodOutputsRecorder('toy_mod.ToyClass.toy')

        with self.assertRaisesRegex(TypeError, 'source should be'):
            _ = MethodOutputsRecorder([1])

        with self.assertRaisesRegex(AssertionError, 'source must have at '):
            _ = MethodOutputsRecorder('aaaaa')

        with self.assertRaisesRegex(AssertionError, 'source must have at '):
            _ = MethodOutputsRecorder('aaa.bbb')

        with self.assertRaisesRegex(ImportError, 'aaa is not imported'):
            _ = MethodOutputsRecorder('aaa.bbb.ccc')

        with self.assertRaisesRegex(AssertionError, 'aaa is not in toy_mod'):
            _ = MethodOutputsRecorder('toy_mod.aaa.bbb')

        with self.assertRaisesRegex(TypeError, 'toy_func should be'):
            _ = MethodOutputsRecorder('toy_mod.toy_func.bbb')

        with self.assertRaisesRegex(AssertionError, 'bbb is not in ToyClass'):
            _ = MethodOutputsRecorder('toy_mod.ToyClass.bbb')

        with self.assertRaisesRegex(TypeError, 'TOY_CLS should be'):
            _ = MethodOutputsRecorder('toy_mod.ToyClass.TOY_CLS')

    def test_context_manager(self):
        from toy_mod import ToyClass

        toy = ToyClass()

        recorder = MethodOutputsRecorder('toy_mod.ToyClass.toy')
        recorder.initialize()

        with recorder:
            result = toy.toy()

        data = recorder.get_record_data()
        self.assertTrue(data == result)

        result_ = toy.toy()

        data = recorder.get_record_data()
        self.assertTrue(data == result)
        self.assertFalse(result_ == result)
74
tests/test_core/test_recorders/test_module_recorders.py
Normal file
@ -0,0 +1,74 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch
from torch import nn

from mmrazor.core.recorders import ModuleInputsRecorder, ModuleOutputsRecorder


class ToyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.conv2 = nn.Conv2d(1, 1, 1)

    def forward(self, x):
        return self.conv2(self.conv1(x))


class TestModuleOutputsRecorder(TestCase):

    def test_prepare_from_model(self):

        recorder = ModuleOutputsRecorder('conv1')
        with self.assertRaisesRegex(AssertionError, 'model can not be'):
            recorder.prepare_from_model()

        recorder = ModuleOutputsRecorder('conv3')
        model = ToyModel()
        with self.assertRaisesRegex(AssertionError, '"conv3" is not in'):
            recorder.prepare_from_model(model)

        recorder = ModuleOutputsRecorder('conv2')
        model = ToyModel()
        recorder.prepare_from_model(model)

    def test_module_outputs(self):

        recorder = ModuleOutputsRecorder('conv2')
        model = ToyModel()
        recorder.initialize(model)

        with recorder:
            self.assertTrue(recorder.recording)
            res = model(torch.randn(1, 1, 1, 1))

        self.assertEquals(res, recorder.get_record_data())

        with recorder:
            self.assertTrue(len(recorder.data_buffer) == 0)

        _ = model(torch.randn(1, 1, 1, 1))
        self.assertTrue(len(recorder.data_buffer) == 0)

    def test_module_inputs(self):

        recorder = ModuleInputsRecorder('conv1')
        model = ToyModel()
        recorder.initialize(model)

        tensor = torch.randn(1, 1, 1, 1)
        with recorder:
            self.assertTrue(recorder.recording)
            _ = model(tensor)

        conv1_input = recorder.get_record_data(data_idx=0)
        self.assertEquals(conv1_input.sum(), tensor.sum())

        with recorder:
            self.assertTrue(len(recorder.data_buffer) == 0)

        _ = model(torch.randn(1, 1, 1, 1))
        self.assertTrue(len(recorder.data_buffer) == 0)
42
tests/test_core/test_recorders/test_param_recorder.py
Normal file
@ -0,0 +1,42 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch
from torch import nn

from mmrazor.core.recorders import ParameterRecorder


class ToyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.toy_conv = nn.Conv2d(1, 1, 1)
        self.no_record_conv = nn.Conv2d(1, 1, 1)

    def forward(self, x):
        return self.toy_conv(x)


class TestParameterRecorder(TestCase):

    def test_prepare_from_model(self):

        model = ToyModel()
        recorder = ParameterRecorder('AAA')
        with self.assertRaisesRegex(AssertionError, '"AAA" is not in the'):
            recorder.initialize(model)

        recorder = ParameterRecorder('toy_conv.bias')
        with self.assertRaisesRegex(AssertionError, 'model can not be None'):
            recorder.prepare_from_model()

        recorder.initialize(model)
        bias_weight = recorder.get_record_data()

        self.assertEquals(bias_weight, model.toy_conv.bias)

        with recorder:
            _ = model(torch.randn(1, 1, 1, 1))

        self.assertEquals(bias_weight, model.toy_conv.bias)
58
tests/test_core/test_recorders/test_recorder_manager.py
Normal file
@ -0,0 +1,58 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch
from mmcv import ConfigDict
from torch import nn
from toy_mod import Toy

from mmrazor.core import RecorderManager


class ToyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.conv2 = nn.Conv2d(1, 1, 1)
        self.toy = Toy()

    def forward(self, x):
        return self.conv2(self.conv1(x)) + self.toy.toy_func()


class TestRecorderManager(TestCase):

    def test_init(self):

        manager = RecorderManager()
        self.assertEquals(len(manager.recorders), 0)

        recorders = ConfigDict(
            r1=dict(type='ModuleOutputs', source='conv1'),
            r2=dict(type='MethodOutputs', source='toy_mod.Toy.toy_func'),
        )
        manager = RecorderManager(recorders)
        model = ToyModel()
        manager.initialize(model)

    def test_context_manager(self):

        recorders = ConfigDict(
            r1=dict(type='ModuleOutputs', source='conv2'),
            r2=dict(type='MethodOutputs', source='toy_mod.Toy.toy_func'),
        )
        manager = RecorderManager(recorders)
        model = ToyModel()
        manager.initialize(model)

        self.assertEquals(manager.get_recorder('r1'), manager.recorders['r1'])
        self.assertEquals(manager.get_recorder('r2'), manager.recorders['r2'])

        with manager:
            res = model(torch.ones(1, 1, 1, 1))

        method_outputs = manager.recorders['r2'].get_record_data()
        conv2_outputs = manager.recorders['r1'].get_record_data()

        self.assertEquals(res.sum(), method_outputs + conv2_outputs.sum())
45
tests/test_core/test_recorders/toy_mod.py
Normal file
@ -0,0 +1,45 @@
# Copyright (c) OpenMMLab. All rights reserved.
import random

TOY_VAR = 'aaa'


def toy_func(a):
    return a


def toy_list_func(a):
    return [a, a, a]


def execute_toy_func(a):
    toy_func(a)


def execute_toy_list_func(a):
    toy_list_func(a)


class ToyClass:

    TOY_CLS = 'TOY_CLASS'

    def __init__(self):
        self._count = 0

    def toy(self):
        self._count += 1
        return self._count

    def __call__(self):
        self._count += 1
        return self._count


class Toy():

    def toy_func(self):
        return random.randint(0, 1000)

    def toy_list_func(self):
        return [random.randint(0, 1000) for _ in range(3)]
@ -0,0 +1,71 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from unittest import TestCase

from mmcv import ConfigDict
from toy_models import ToyStudent

from mmrazor.models import ConfigurableDistill


class TestConfigurableDistill(TestCase):

    def test_init(self):

        recorders_cfg = ConfigDict(
            conv=dict(type='ModuleOutputs', source='conv'))

        student = ToyStudent()

        alg_kwargs = ConfigDict(
            architecture=student,
            student_recorders=recorders_cfg,
            teacher_recorders=recorders_cfg,
            distill_losses=dict(loss_toy=dict(type='ToyDistillLoss')),
            loss_forward_mappings=dict(
                loss_toy=dict(
                    arg1=dict(from_student=True, recorder='conv'),
                    arg2=dict(from_student=False, recorder='conv'),
                )),
        )

        alg = ConfigurableDistill(**alg_kwargs)
        self.assertEquals(alg.student, alg.architecture)

        alg_kwargs_ = copy.deepcopy(alg_kwargs)
        alg_kwargs_['distill_losses'] = None
        with self.assertRaisesRegex(AssertionError,
                                    '"loss_toy" is not in distill'):
            _ = ConfigurableDistill(**alg_kwargs_)

        alg_kwargs_ = copy.deepcopy(alg_kwargs)
        alg_kwargs_['distill_losses'] = dict(toy=dict(type='ToyDistillLoss'))
        alg_kwargs_['loss_forward_mappings'] = dict(
            toy=dict(
                arg1=dict(from_student=True, recorder='conv'),
                arg2=dict(from_student=False, recorder='conv')))
        with self.assertWarnsRegex(UserWarning, 'Warning: If toy is a'):
            _ = ConfigurableDistill(**alg_kwargs_)

        alg_kwargs_ = copy.deepcopy(alg_kwargs)
        alg_kwargs_['loss_forward_mappings'] = None
        _ = ConfigurableDistill(**alg_kwargs_)

        alg_kwargs_ = copy.deepcopy(alg_kwargs)
        alg_kwargs_['loss_forward_mappings'] = list('AAA')

        with self.assertRaisesRegex(TypeError,
                                    'loss_forward_mappings should be '):
            _ = ConfigurableDistill(**alg_kwargs_)

        alg_kwargs_ = copy.deepcopy(alg_kwargs)
        alg_kwargs_['loss_forward_mappings']['loss_toy'] = list()
        with self.assertRaisesRegex(
                TypeError, 'Each item of loss_forward_mappings should be '):
            _ = ConfigurableDistill(**alg_kwargs_)

        alg_kwargs_ = copy.deepcopy(alg_kwargs)
        alg_kwargs_.loss_forward_mappings.loss_toy.arg1.from_student = ''
        with self.assertRaisesRegex(TypeError,
                                    'from_student should be a bool'):
            _ = ConfigurableDistill(**alg_kwargs_)
@ -0,0 +1,77 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from unittest import TestCase

import torch
from mmcv import ConfigDict
from toy_models import ToyStudent

from mmrazor.models import SingleTeacherDistill


class TestSingleTeacherDistill(TestCase):

    def test_init(self):

        recorders_cfg = ConfigDict(
            conv=dict(type='ModuleOutputs', source='conv'))

        alg_kwargs = ConfigDict(
            architecture=dict(type='ToyStudent'),
            teacher=dict(type='ToyTeacher'),
            student_recorders=recorders_cfg,
            teacher_recorders=recorders_cfg,
            distill_losses=dict(loss_toy=dict(type='ToyDistillLoss')),
            loss_forward_mappings=dict(
                loss_toy=dict(
                    arg1=dict(from_student=True, recorder='conv'),
                    arg2=dict(from_student=False, recorder='conv'),
                )),
        )

        alg = SingleTeacherDistill(**alg_kwargs)

        teacher = ToyStudent()
        alg_kwargs_ = copy.deepcopy(alg_kwargs)
        alg_kwargs_['teacher'] = teacher
        alg = SingleTeacherDistill(**alg_kwargs_)
        self.assertEquals(alg.teacher, teacher)

        alg_kwargs_ = copy.deepcopy(alg_kwargs)
        alg_kwargs_['teacher'] = 'teacher'
        with self.assertRaisesRegex(TypeError,
                                    'teacher should be a `dict` or'):
            _ = SingleTeacherDistill(**alg_kwargs_)

    def test_loss(self):

        recorders_cfg = ConfigDict(
            conv=dict(type='ModuleOutputs', source='conv'))

        alg_kwargs = ConfigDict(
            architecture=dict(type='ToyStudent'),
            teacher=dict(type='ToyTeacher'),
            student_recorders=recorders_cfg,
            teacher_recorders=recorders_cfg,
            distill_losses=dict(loss_toy=dict(type='ToyDistillLoss')),
            loss_forward_mappings=dict(
                loss_toy=dict(
                    arg1=dict(from_student=True, recorder='conv'),
                    arg2=dict(from_student=False, recorder='conv'),
                )),
        )

        img = torch.randn(1, 3, 1, 1)

        alg = SingleTeacherDistill(**alg_kwargs)
        losses = alg(img, mode='loss')
        self.assertIn('distill.loss_toy', losses)
        self.assertIn('student.loss', losses)

        alg_kwargs_ = copy.deepcopy(alg_kwargs)
        alg_kwargs_['teacher_trainable'] = True
        alg = SingleTeacherDistill(**alg_kwargs_)
        losses = alg(img, mode='loss')
        self.assertIn('distill.loss_toy', losses)
        self.assertIn('student.loss', losses)
        self.assertIn('teacher.loss', losses)
42
tests/test_models/test_algorithms/toy_models.py
Normal file
@ -0,0 +1,42 @@
# Copyright (c) OpenMMLab. All rights reserved.

from mmengine.model import BaseModel
from torch import nn

from mmrazor.registry import MODELS


@MODELS.register_module()
class ToyStudent(BaseModel):

    def __init__(self, data_preprocessor=None):
        super().__init__(data_preprocessor=data_preprocessor, init_cfg=None)
        self.conv = nn.Conv2d(3, 1, 1)

    def forward(self, batch_inputs, data_samples=None, mode='tensor'):
        if mode == 'loss':
            out = self.conv(batch_inputs)
            return dict(loss=out)
        elif mode == 'predict':
            out = self.conv(batch_inputs) + 1
            return out
        elif mode == 'tensor':
            out = self.conv(batch_inputs) + 2
            return out


@MODELS.register_module()
class ToyTeacher(ToyStudent):

    def __init__(self):
        super().__init__()


@MODELS.register_module()
class ToyDistillLoss(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, arg1, arg2):
        return arg1 + arg2
@ -124,5 +124,4 @@ class TestSingleTeacherDistillValLoop(TestCase):
         runner = Runner.from_cfg(cfg)
         runner.val()

-        self.assertIn('val_student/acc', runner.message_hub.log_scalars.keys())
-        self.assertIn('val_teacher/acc', runner.message_hub.log_scalars.keys())
+        self.assertIn('val/teacher.acc', runner.message_hub.log_scalars.keys())