diff --git a/configs/distill/mmcls/kl/kl_r34_r18_8xb32_in1k.py b/configs/distill/mmcls/kl/kl_r34_r18_8xb32_in1k.py new file mode 100644 index 00000000..18df2649 --- /dev/null +++ b/configs/distill/mmcls/kl/kl_r34_r18_8xb32_in1k.py @@ -0,0 +1,39 @@ +_base_ = [ + 'mmcls::_base_/datasets/imagenet_bs32.py', + 'mmcls::_base_/schedules/imagenet_bs256.py', + 'mmcls::_base_/default_runtime.py' +] + +model = dict( + _scope_='mmrazor', + type='SingleTeacherDistill', + data_preprocessor=dict( + type='ImgDataPreprocessor', + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + bgr_to_rgb=True), + architecture=dict( + cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False), + teacher=dict( + cfg_path='mmcls::resnet/resnet34_8xb32_in1k.py', pretrained=True), + teacher_ckpt='resnet34_8xb32_in1k_20210831-f257d4e6.pth', + student_recorders=dict(fc=dict(type='ModuleOutputs', source='head.fc')), + teacher_recorders=dict(fc=dict(type='ModuleOutputs', source='head.fc')), + distill_losses=dict( + loss_kl=dict(type='KLDivergence', tau=1, loss_weight=5)), + loss_forward_mappings=dict( + loss_kl=dict( + preds_S=dict( + from_student=True, + recorder='fc', + ), + preds_T=dict( + from_student=False, + recorder='fc', + )))) + +find_unused_parameters = True + +val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop') diff --git a/configs/distill/mmcls/wsld/wsld_r34_r18_8xb32_in1k.py b/configs/distill/mmcls/wsld/wsld_r34_r18_8xb32_in1k.py new file mode 100644 index 00000000..41e1943a --- /dev/null +++ b/configs/distill/mmcls/wsld/wsld_r34_r18_8xb32_in1k.py @@ -0,0 +1,36 @@ +_base_ = [ + 'mmcls::_base_/datasets/imagenet_bs32.py', + 'mmcls::_base_/schedules/imagenet_bs256.py', + 'mmcls::_base_/default_runtime.py' +] + +model = dict( + _scope_='mmrazor', + type='SingleTeacherDistill', + data_preprocessor=dict( + type='ImgDataPreprocessor', + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + bgr_to_rgb=True), + architecture=dict( + cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False), + teacher=dict( + cfg_path='mmcls::resnet/resnet34_8xb32_in1k.py', pretrained=True), + teacher_ckpt='resnet34_8xb32_in1k_20210831-f257d4e6.pth', + student_recorders=dict( + fc=dict(type='ModuleOutputs', source='head.fc'), + data_samples=dict(type='ModuleInputs', source='')), + teacher_recorders=dict(fc=dict(type='ModuleOutputs', source='head.fc')), + distill_losses=dict(loss_wsld=dict(type='WSLD', tau=2, loss_weight=2.5)), + loss_forward_mappings=dict( + loss_wsld=dict( + student=dict(recorder='fc', from_student=True), + teacher=dict(recorder='fc', from_student=False), + data_samples=dict( + recorder='data_samples', from_student=True, data_idx=1)))) + +find_unused_parameters = True + +val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop') diff --git a/configs/distill/mmdet/cwd/cwd_fpn_frcnn_r101_frcnn_r50_1x_coco.py b/configs/distill/mmdet/cwd/cwd_fpn_frcnn_r101_frcnn_r50_1x_coco.py new file mode 100644 index 00000000..13470a70 --- /dev/null +++ b/configs/distill/mmdet/cwd/cwd_fpn_frcnn_r101_frcnn_r50_1x_coco.py @@ -0,0 +1,58 @@ +_base_ = [ + 'mmdet::_base_/datasets/coco_detection.py', + 'mmdet::_base_/schedules/schedule_1x.py', + 'mmdet::_base_/default_runtime.py' +] + +# default_scope = 'mmrazor' +teacher_ckpt = 'faster_rcnn_r101_fpn_2x_coco_bbox_mAP-0.398_20200504_210455-1d2dac9c.pth' # noqa: E501 +model = 
dict( + _scope_='mmrazor', + type='FpnTeacherDistill', + architecture=dict( + cfg_path='mmdet::faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py', + pretrained=False), + teacher=dict( + cfg_path='mmdet::faster_rcnn/faster_rcnn_r101_fpn_2x_coco.py', + pretrained=False), + teacher_ckpt=teacher_ckpt, + distill_losses=dict( + loss_cwd_fpn0=dict( + type='ChannelWiseDivergence', tau=1, loss_weight=10), + loss_cwd_fpn1=dict( + type='ChannelWiseDivergence', tau=1, loss_weight=10), + loss_cwd_fpn2=dict( + type='ChannelWiseDivergence', tau=1, loss_weight=10), + loss_cwd_fpn3=dict( + type='ChannelWiseDivergence', tau=1, loss_weight=10), + loss_cwd_fpn4=dict( + type='ChannelWiseDivergence', tau=1, loss_weight=10)), + student_recorders=dict(fpn=dict(type='ModuleOutputs', source='neck')), + teacher_recorders=dict(fpn=dict(type='ModuleOutputs', source='neck')), + loss_forward_mappings=dict( + loss_cwd_fpn0=dict( + preds_S=dict(from_student=True, recorder='fpn', data_idx=0), + preds_T=dict(from_student=False, recorder='fpn', data_idx=0), + ), + loss_cwd_fpn1=dict( + preds_S=dict(from_student=True, recorder='fpn', data_idx=1), + preds_T=dict(from_student=False, recorder='fpn', data_idx=1), + ), + loss_cwd_fpn2=dict( + preds_S=dict(from_student=True, recorder='fpn', data_idx=2), + preds_T=dict(from_student=False, recorder='fpn', data_idx=2), + ), + loss_cwd_fpn3=dict( + preds_S=dict(from_student=True, recorder='fpn', data_idx=3), + preds_T=dict(from_student=False, recorder='fpn', data_idx=3), + ), + loss_cwd_fpn4=dict( + preds_S=dict(from_student=True, recorder='fpn', data_idx=4), + preds_T=dict(from_student=False, recorder='fpn', data_idx=4), + ), + ), +) + +find_unused_parameters = True + +val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop') diff --git a/configs/distill/mmdet/cwd/cwd_fpn_gfl_r101_gfl_r50_1x_coco.py b/configs/distill/mmdet/cwd/cwd_fpn_gfl_r101_gfl_r50_1x_coco.py new file mode 100644 index 00000000..da1e4436 --- /dev/null +++ b/configs/distill/mmdet/cwd/cwd_fpn_gfl_r101_gfl_r50_1x_coco.py @@ -0,0 +1,8 @@ +_base_ = ['./cwd_fpn_frcnn_r101_frcnn_r50_1x_coco.py'] + +model = dict( + architecture=dict( + cfg_path='mmdet::gfl/gfl_r50_fpn_1x_coco.py', pretrained=False), + teacher=dict( + cfg_path='mmdet::gfl/gfl_r101_fpn_mstrain_2x_coco.py', + pretrained=True)) diff --git a/configs/distill/mmdet/cwd/cwd_fpn_retina_r101_retina_r50_1x_coco.py b/configs/distill/mmdet/cwd/cwd_fpn_retina_r101_retina_r50_1x_coco.py new file mode 100644 index 00000000..6e5c1231 --- /dev/null +++ b/configs/distill/mmdet/cwd/cwd_fpn_retina_r101_retina_r50_1x_coco.py @@ -0,0 +1,9 @@ +_base_ = ['./cwd_fpn_frcnn_r101_frcnn_r50_1x_coco.py'] + +model = dict( + architecture=dict( + cfg_path='mmdet::retinanet/retinanet_r50_fpn_1x_coco.py', + pretrained=False), + teacher=dict( + cfg_path='mmdet::retinanet/retinanet_r101_fpn_2x_coco.py', + pretrained=True)) diff --git a/mmrazor/core/__init__.py b/mmrazor/core/__init__.py index ad88a0c2..89b8d67b 100644 --- a/mmrazor/core/__init__.py +++ b/mmrazor/core/__init__.py @@ -1,3 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. from .delivers import * # noqa: F401,F403 +from .recorders import * # noqa: F401,F403 from .tracer import * # noqa: F401,F403 diff --git a/mmrazor/core/delivers/__init__.py b/mmrazor/core/delivers/__init__.py index 6519db3c..569ac3c5 100644 --- a/mmrazor/core/delivers/__init__.py +++ b/mmrazor/core/delivers/__init__.py @@ -1,8 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved.
-from .deliver_manager import DistillDeliverManager -from .function_outputs_deliver import FunctionOutputsDeliver -from .method_outputs_deliver import MethodOutputsDeliver +from .deliver_manager import DistillDeliveryManager +from .function_outputs_deliver import FunctionOutputsDelivery +from .method_outputs_deliver import MethodOutputsDelivery __all__ = [ - 'FunctionOutputsDeliver', 'MethodOutputsDeliver', 'DistillDeliverManager' + 'FunctionOutputsDelivery', 'MethodOutputsDelivery', + 'DistillDeliveryManager' ] diff --git a/mmrazor/core/delivers/deliver_manager.py b/mmrazor/core/delivers/deliver_manager.py index 60baeff5..af493380 100644 --- a/mmrazor/core/delivers/deliver_manager.py +++ b/mmrazor/core/delivers/deliver_manager.py @@ -1,22 +1,26 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy -from typing import List +from typing import Dict, Optional from mmrazor.registry import TASK_UTILS +from .distill_deliver import DistillDelivery + +SUPPORT_DELIVERIES = ['FunctionOutputs', 'MethodOutputs'] -class DistillDeliverManager: - """Various types delivers' manager. The ``DistillDeliverManager`` is also a - context manager, managing various types of delivers. When entering the - ``DistillDeliverManager``, all delivers managed by it will be started. +class DistillDeliveryManager: + """Various types deliveries' manager. The ``DistillDeliveryManager`` is + also a context manager, managing various types of deliveries. + + When entering the ``DistillDeliveryManager``, all deliveries managed by it + will be started. Notes: - DistillDeliver is a context manager used to override function (method) - outputs from the target model with function (method) outputs from the - source model. + DistillDelivery is a context manager used to override function(method) + outputs during teacher(student) forward. Args: - deliveries (list(dict)): Configs of all deliveries. + deliveries (dict): Configs of all deliveries. Examples: >>> from mmcls.models.utils import Augments @@ -27,8 +31,8 @@ class DistillDeliverManager: >>> imgs = torch.randn(2, 3, 32, 32) >>> label = torch.randint(0, 10, (2,)) - >>> # Without ``MethodOutputsDeliver``, outputs of the teacher and the - >>> # student are very likely to be different. + >>> # Without ``MethodOutputsDelivery``, outputs of the teacher and + >>> # the student are different. >>> imgs_tea, label_tea = augments(imgs, label) >>> imgs_stu, label_stu = augments(imgs, label) >>> torch.equal(label_tea, label_stu) @@ -36,10 +40,10 @@ class DistillDeliverManager: >>> torch.equal(imgs_tea, imgs_stu) False - >>> distill_deliveries = [ - ... ConfigDict(type='MethodOutputs', max_keep_data=1, - ... method_path='mmcls.models.utils.Augments.__call__')] - >>> manager = DistillDeliverManager(distill_deliveries) + >>> distill_deliveries = ConfigDict( + ... aug=dict(type='MethodOutputs', max_keep_data=1, + ... method_path='mmcls.models.utils.Augments.__call__')) + >>> manager = DistillDeliveryManager(distill_deliveries) >>> manager.override_data = False >>> with manager: @@ -55,44 +59,55 @@ class DistillDeliverManager: True """ - def __init__(self, deliveries: List) -> None: + def __init__(self, deliveries: Optional[Dict[str, Dict]] = None) -> None: - # As there may be several delivers belong to a same deliver type, - # we use a list to save delivers rather than a dict. 
- self.deliveries = list() - for cfg in deliveries: - deliver_cfg = copy.deepcopy(cfg) - deliver_type = cfg.type - deliver_type = deliver_type + 'Deliver' - deliver_cfg.type = deliver_type - self.deliveries.append(TASK_UTILS.build(deliver_cfg)) + self._deliveries: Dict[str, DistillDelivery] = dict() + if deliveries: + for delivery_name, delivery_cfg in deliveries.items(): + delivery_cfg_ = copy.deepcopy(delivery_cfg) + delivery_type_ = delivery_cfg_.get('type', '') + assert isinstance(delivery_type_, str) + assert delivery_type_ in SUPPORT_DELIVERIES + + delivery_type_ = delivery_type_ + 'Delivery' + delivery_cfg_.update(dict(type=delivery_type_)) + + delivery = TASK_UTILS.build(delivery_cfg_) + self.deliveries[delivery_name] = delivery self._override_data = False @property - def override_data(self): + def deliveries(self) -> Dict[str, DistillDelivery]: + """dict: all deliveries.""" + return self._deliveries + + @property + def override_data(self) -> bool: """bool: indicate whether to override the data with the recorded data. """ return self._override_data @override_data.setter - def override_data(self, override): - """Set the override_data property to all the delivers. + def override_data(self, override: bool) -> None: + """Set the override_data property to all the deliveries. - If the `override_data` of a deliver is False, the deliver will record - and keep the origin data. If the current_mode of a deliver is True, the - deliver will override the origin data with the recorded data. + If the `override_data` of a delivery is False, the delivery will + record the origin data. + + If the `override_data` of a delivery is True, the delivery will + override the origin data with the recorded data. """ self._override_data = override - for deliver in self.deliveries: - deliver.override_data = override + for delivery in self.deliveries.values(): + delivery.override_data = override def __enter__(self) -> None: """Enter the context manager.""" - for deliver in self.deliveries: - deliver.__enter__() + for delivery in self.deliveries.values(): + delivery.__enter__() def __exit__(self, exc_type, exc_value, traceback) -> None: """Exit the context manager.""" - for deliver in self.deliveries: - deliver.__exit__(exc_type, exc_value, traceback) + for delivery in self.deliveries.values(): + delivery.__exit__(exc_type, exc_value, traceback) diff --git a/mmrazor/core/delivers/distill_deliver.py b/mmrazor/core/delivers/distill_deliver.py index d6b6885e..dcd56f38 100644 --- a/mmrazor/core/delivers/distill_deliver.py +++ b/mmrazor/core/delivers/distill_deliver.py @@ -4,32 +4,33 @@ from queue import Queue from typing import Callable -class DistillDeliver(metaclass=ABCMeta): - """Base class for delivers for distillation. +# TODO: Support overriding part of the outputs of a function or method +class DistillDelivery(metaclass=ABCMeta): + """Base class for deliveries for distillation. - DistillDeliver is a context manager used to override function (method) - outputs from the target model with function (method) outputs from the - source model. - In MMRazor, there will be different types of delivers to deliver different - types of data. They can be used in combination with the + DistillDelivery is a context manager used to override function(method) + outputs during teacher(student) forward. + + A delivery can only handle one function or method. Some algorithms may use + multiple deliveries, which can be managed uniformly using ``DistillDeliverManager``. 
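+    # A delivery is used as a context manager: ``__enter__`` is expected to + # replace the target function or method with the wrapper returned by + # ``deliver_wrapper``, and ``__exit__`` to restore the original object.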
+ Args: + max_keep_data (int): The length limitation of the queue, should be + larger than the execute times of the function or method. Defaults + to 1. + Notes: If a function (method) is executed more than once during the forward of the source model, all the outputs of this function (method) will be used to override function (method) outputs from the target model. - TODO: - Support overriding some of the outputs of a function (method) - - Args: - max_keep_data (int): The length limitation of the queue. If a function - (method) is executed more than once during the forward of the - target model, function (method) outputs from the source model are - pushed into the queue in order. Default to 1. + If a function or method is executed more than once during the forward + of the target model, its' outputs from the source model are pushed + into the queue in order. """ - def __init__(self, max_keep_data: int = 1): + def __init__(self, max_keep_data: int = 1) -> None: self._override_data = False self.data_queue: Queue = Queue(maxsize=max_keep_data) @@ -55,3 +56,11 @@ class DistillDeliver(metaclass=ABCMeta): def deliver_wrapper(self, origin: Callable) -> Callable: """Wrap the specific object to make the intermediate results of the model can be delivered.""" + + @abstractmethod + def __enter__(self) -> None: + """Enter the context manager.""" + + @abstractmethod + def __exit__(self, exc_type, exc_value, traceback) -> None: + """Exit the context manager.""" diff --git a/mmrazor/core/delivers/function_outputs_deliver.py b/mmrazor/core/delivers/function_outputs_deliver.py index 5089f0c5..bf5f81bb 100644 --- a/mmrazor/core/delivers/function_outputs_deliver.py +++ b/mmrazor/core/delivers/function_outputs_deliver.py @@ -6,11 +6,11 @@ from typing import Callable from mmcv.utils import import_modules_from_strings from mmrazor.registry import TASK_UTILS -from .distill_deliver import DistillDeliver +from .distill_deliver import DistillDelivery @TASK_UTILS.register_module() -class FunctionOutputsDeliver(DistillDeliver): +class FunctionOutputsDelivery(DistillDelivery): """Delivery for intermediate results which are ``FunctionType``'s outputs. Args: @@ -29,23 +29,28 @@ class FunctionOutputsDeliver(DistillDeliver): `mmdet.core.anchor.utils.anchor_inside_flags`. Examples: - >>> # Suppose there is a toy function named ``toy_func`` in test.py. - >>> # It return random integers from 0 to 999. - >>> import toy_module - >>> # Suppose we want to deliver outputs from the teacher to + >>> # Below code in toy_module.py + >>> import random + >>> def toy_func(): + >>> return random.randint(0, 1000) + + >>> # Below code in main.py + >>> # Teacher and student both will execute toy_func. + >>> # Now, we want to deliver outputs from the teacher to >>> # the student + >>> import toy_module >>> delivery = FunctionOutputsDeliver( ... max_keep_data=1, func_path='toy_module.toy_func') >>> delivery.override_data = False >>> with delivery: - ... output_tea = toy_module.toy_func() + ... output_teacher = toy_module.toy_func() >>> delivery.override_data = True >>> with delivery: - ... output_stu = toy_module.toy_func() + ... 
output_student = toy_module.toy_func() - >>> output_tea == output_stu + >>> output_teacher == output_student True >>> # If a function (method) is executed more than once during the diff --git a/mmrazor/core/delivers/method_outputs_deliver.py b/mmrazor/core/delivers/method_outputs_deliver.py index b71f6e8a..50c9c82c 100644 --- a/mmrazor/core/delivers/method_outputs_deliver.py +++ b/mmrazor/core/delivers/method_outputs_deliver.py @@ -6,11 +6,11 @@ from typing import Callable from mmcv.utils import import_modules_from_strings from mmrazor.registry import TASK_UTILS -from .distill_deliver import DistillDeliver +from .distill_deliver import DistillDelivery @TASK_UTILS.register_module() -class MethodOutputsDeliver(DistillDeliver): +class MethodOutputsDelivery(DistillDelivery): """Delivery for intermediate results which are ``MethodType``'s outputs. Note: diff --git a/mmrazor/core/recorders/__init__.py b/mmrazor/core/recorders/__init__.py new file mode 100644 index 00000000..6d1858f0 --- /dev/null +++ b/mmrazor/core/recorders/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .function_outputs_recorder import FunctionOutputsRecorder +from .method_outputs_recorder import MethodOutputsRecorder +from .module_inputs_recorder import ModuleInputsRecorder +from .module_outputs_recorder import ModuleOutputsRecorder +from .param_recorder import ParameterRecorder +from .recorder_manager import RecorderManager + +__all__ = [ + 'FunctionOutputsRecorder', 'MethodOutputsRecorder', + 'ModuleOutputsRecorder', 'ParameterRecorder', 'RecorderManager', + 'ModuleInputsRecorder' +] diff --git a/mmrazor/core/recorders/base_recorder.py b/mmrazor/core/recorders/base_recorder.py new file mode 100644 index 00000000..b34c6918 --- /dev/null +++ b/mmrazor/core/recorders/base_recorder.py @@ -0,0 +1,116 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import Any, List, Optional + +from torch import nn + + +class BaseRecorder(metaclass=ABCMeta): + """Base class for recorders. + + Recorder is a context manager used to record various intermediate results + during the model forward. It can be used in distillation algorithm and + can also be used to obtain some specific data for visual analysis. + + In MMRazor, there will be different types of recorders to obtain different + types of intermediate results. They can be used in combination with the + ``RecorderManager``. + + Note: + The recorder will be lazily initialized in the ``RecorderManager`` by + default. If you want to use the recorder without the + ``RecorderManager``, you need to initialize it first. + """ + + def __init__(self, source: str) -> None: + + self._source = source + # Intermediate results are recorded in dictionary format according + # to the data source. + # One data source may generate multiple records, which need to be + # recorded through list. + self._data_buffer: List = list() + # Before using the recorder for the first time, it needs to be + # initialized. + self._initialized = False + + @property + def source(self) -> str: + """str: source of recorded data.""" + return self._source + + @property + def data_buffer(self) -> List: + """list: data buffer.""" + return self._data_buffer + + @abstractmethod + def prepare_from_model(self, model: Optional[nn.Module] = None) -> None: + """Make the intermediate results of the model can be record.""" + + def initialize(self, model: Optional[nn.Module] = None) -> None: + """Init the recorder. 
+ + Args: + model (nn.Module): The model which need to record intermediate + results. + """ + self.prepare_from_model(model) + self._initialized = True + + def get_record_data(self, + record_idx: int = 0, + data_idx: Optional[int] = None) -> Any: + """Get data from ``data_buffer``. + + Args: + record_idx (int): The index of the record saved in + ``data_buffer``. If a source is executed N times during + forward, there will be N records in ``data_buffer``. + data_index (int, optional): The index of target data in + a record. A record may be a tuple or a list, if data_idx is + None, the whole list or tuple is returned. Defaults to None. + + Returns: + Any: The type of the return value is undefined, and different + source data may have different types. + """ + assert record_idx < len(self._data_buffer), \ + 'record_idx is illegal. The length of data_buffer is ' \ + f'{len(self._data_buffer)}, but record_idx is ' \ + f'{record_idx}.' + + record = self._data_buffer[record_idx] + + if data_idx is None: + target_data = record + else: + if isinstance(record, (list, tuple)): + assert data_idx < len(record), \ + 'data_idx is illegal. The length of record is ' \ + f'{len(record)}, but data_idx is {data_idx}.' + target_data = record[data_idx] + else: + raise TypeError('When data_idx is not None, record should be ' + 'a list or tuple instance, but got ' + f'{type(record)}.') + + return target_data + + def reset_data_buffer(self) -> None: + """Clear data in data_buffer.""" + + self._data_buffer = list() + + def __enter__(self): + """Enter the context manager.""" + + assert self._initialized, \ + 'The recorder will be initialized in the RecorderManager by '\ + 'default. If you want to use the recorder without the '\ + 'RecorderManager, you need to initialize it first.' + + self.reset_data_buffer() + + def __exit__(self, exc_type, exc_value, traceback): + """Exit the context manager.""" diff --git a/mmrazor/core/recorders/function_outputs_recorder.py b/mmrazor/core/recorders/function_outputs_recorder.py new file mode 100644 index 00000000..1d23a512 --- /dev/null +++ b/mmrazor/core/recorders/function_outputs_recorder.py @@ -0,0 +1,161 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools +from types import FunctionType, ModuleType +from typing import Callable, List, Optional + +from mmcv.utils import import_modules_from_strings +from torch import nn + +from mmrazor.registry import TASK_UTILS +from .base_recorder import BaseRecorder + + +@TASK_UTILS.register_module() +class FunctionOutputsRecorder(BaseRecorder): + """Recorder for intermediate results which are ``FunctionType``'s outputs. + + Notes: + The form of `source` needs special attention. For example, + `anchor_inside_flags` is a function in mmdetection to check whether the + anchors are inside the border. This function is in + `mmdet/core/anchor/utils.py` and used in + `mmdet/models/dense_heads/anchor_head.py`. Then the source should be + `mmdet.models.dense_heads.anchor_head.anchor_inside_flags` but not + `mmdet.core.anchor.utils.anchor_inside_flags`. + + + Examples: + >>> # Below code in toy_module.py + >>> import random + >>> def toy_func(): + ... return random.randint(0, 1000) + >>> def toy_list_func(): + ... return [random.randint(0,1000) for _ in range(3)] + + >>> # Below code in main.py + >>> # Now, we want to get teacher's outputs by recorder. + + >>> import toy_module + >>> r1 = FunctionOutputsRecorder('toy_module.toy_func') + >>> r1.initialize() + >>> with r1: + ... output_teacher1 = toy_module.toy_func() + ... 
output_teacher2 = toy_module.toy_func() + ... output_teacher3 = toy_module.toy_func() + + >>> r1.data_buffer + [33, 41, 12] + >>> r1.get_record_data(record_idx=2) + 12 + >>> output_teacher1==33 and output_teacher2==41 and output_teacher3==12 + True + + >>> r2 = FunctionOutputsRecorder('toy_module.toy_list_func') + >>> r2.initialize() + >>> with r2: + ... output_teacher1 = toy_module.toy_list_func() + ... output_teacher2 = toy_module.toy_list_func() + ... output_teacher3 = toy_module.toy_list_func() + + >>> r2.data_buffer + [[1, 2, 3], [4, 5, 6], [7, 8, 9]] + >>> r2.get_record_data(record_idx=2, data_idx=2) + 9 + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self._check_valid_source(self.source) + + # import the module that contains the function + try: + mod = import_modules_from_strings(self.module_string) + except ImportError: + raise ImportError( + f'{self.module_string} is not imported correctly.') + + self.imported_module: ModuleType = mod + + assert hasattr(mod, self.func_name), \ + f'{self.func_name} is not in {self.module_string}.' + + origin_func = getattr(mod, self.func_name) + if not isinstance(origin_func, FunctionType): + raise TypeError(f'{self.func_name} should be a FunctionType ' + f'instance, but got {type(origin_func)}') + + self.origin_func: Callable = origin_func + + @staticmethod + def _check_valid_source(source): + """Check if the source's format is valid.""" + if not isinstance(source, str): + raise TypeError(f'source should be a str ' + f'instance, but got {type(source)}') + + assert len(source.split('.')) > 1, \ + 'source must have at least one `.`' + + @property + def func_name(self): + """Get the function name according to `source`.""" + return self.source.split('.')[-1] + + @property + def module_string(self): + """Get the module name according to `source`.""" + return '.'.join(self.source.split('.')[:-1]) + + def prepare_from_model(self, model: Optional[nn.Module] = None) -> None: + """The `model` is not used by `FunctionOutputsRecorder`.""" + pass + + def func_record_wrapper(self, origin_func: Callable, + data_buffer: List) -> Callable: + """Save the function's outputs. + + Args: + origin_func (FunctionType): The function whose outputs need to be + recorded. + data_buffer (list): The buffer used to save the function's + outputs. + """ + + @functools.wraps(origin_func) + def wrap_func(*args, **kwargs): + outputs = origin_func(*args, **kwargs) + # If the function is executed N times, there will be N outputs + # to save. + data_buffer.append(outputs) + return outputs + + return wrap_func + + def __enter__(self): + """Enter the context manager.""" + super().__enter__() + + mod = self.imported_module + origin_func = self.origin_func + # add record wrapper to origin function. + record_func = self.func_record_wrapper(origin_func, self.data_buffer) + + assert hasattr(mod, self.func_name), \ + f'{self.func_name} is not in {self.module_string}.' + + # rewrite the origin function + setattr(mod, self.func_name, record_func) + + def __exit__(self, exc_type, exc_value, traceback): + """Exit the context manager.""" + super().__exit__(exc_type, exc_value, traceback) + + mod = self.imported_module + origin_func = self.origin_func + + assert hasattr(mod, self.func_name), \ + f'{self.func_name} is not in {self.module_string}.'
+ + # restore the origin function + setattr(mod, self.func_name, origin_func) diff --git a/mmrazor/core/recorders/method_outputs_recorder.py b/mmrazor/core/recorders/method_outputs_recorder.py new file mode 100644 index 00000000..ca8ba182 --- /dev/null +++ b/mmrazor/core/recorders/method_outputs_recorder.py @@ -0,0 +1,168 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools +from types import FunctionType, ModuleType +from typing import Callable, List, Optional + +from mmcv.utils import import_modules_from_strings +from torch import nn + +from mmrazor.registry import TASK_UTILS +from .base_recorder import BaseRecorder + + +@TASK_UTILS.register_module() +class MethodOutputsRecorder(BaseRecorder): + """Recorder for intermediate results which are ``MethodType``'s outputs. + + Note: + Different from ``FunctionType``, ``MethodType`` is the type of methods + of class instances. + + Examples: + >>> # Below code in toy_module.py + >>> import random + >>> class Toy(): + ... def toy_func(self): + ... return random.randint(0, 1000) + ... def toy_list_func(self): + ... return [random.randint(0, 1000) for _ in range(3)] + + >>> # Below code in main.py + >>> # Now, we want to get teacher's outputs by recorder. + + >>> from toy_module import Toy + >>> toy = Toy() + >>> r1 = MethodOutputsRecorder('toy_module.Toy.toy_func') + >>> r1.initialize() + >>> with r1: + ... output_teacher1 = toy.toy_func() + ... output_teacher2 = toy.toy_func() + ... output_teacher3 = toy.toy_func() + + >>> r1.data_buffer + [33, 41, 12] + >>> r1.get_record_data(record_idx=2) + 12 + >>> output_teacher1==33 and output_teacher2==41 and output_teacher3==12 + True + + >>> r2 = MethodOutputsRecorder('toy_module.Toy.toy_list_func') + >>> r2.initialize() + >>> with r2: + ... output_teacher1 = toy.toy_list_func() + ... output_teacher2 = toy.toy_list_func() + ... output_teacher3 = toy.toy_list_func() + + >>> r2.data_buffer + [[1, 2, 3], [4, 5, 6], [7, 8, 9]] + >>> r2.get_record_data(record_idx=2, data_idx=2) + 9 + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._check_valid_source(self.source) + + # import the module that contains the method's class + try: + mod: ModuleType = import_modules_from_strings(self.module_string) + except ImportError: + raise ImportError( + f'{self.module_string} is not imported correctly.') + + assert hasattr(mod, self.cls_name), \ + f'{self.cls_name} is not in {self.module_string}.' + + imported_cls: type = getattr(mod, self.cls_name) + if not isinstance(imported_cls, type): + raise TypeError(f'{self.cls_name} should be a type ' + f'instance, but got {type(imported_cls)}') + self.imported_class = imported_cls + + assert hasattr(imported_cls, self.method_name), \ + f'{self.method_name} is not in {self.cls_name}.'
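+ + # Keep a reference to the original (unbound) method; ``__exit__`` will + # use it to restore the method after recording.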
+ + origin_method = getattr(imported_cls, self.method_name) + if not isinstance(origin_method, FunctionType): + raise TypeError(f'{self.method_name} should be a FunctionType ' + f'instance, but got {type(origin_method)}') + self.origin_method = origin_method + + @staticmethod + def _check_valid_source(source: str) -> None: + """Check if the `source` is valid.""" + if not isinstance(source, str): + raise TypeError(f'source should be a str ' + f'instance, but got {type(source)}') + + assert len(source.split('.')) > 2, \ + 'source must have at least two `.`' + + @property + def method_name(self): + """Get the method name according to `method_path`.""" + return self.source.split('.')[-1] + + @property + def cls_name(self): + """Get the class name corresponding to this method according to + `method_path`.""" + return self.source.split('.')[-2] + + @property + def module_string(self): + """Get the module name according to `method_path`.""" + return '.'.join(self.source.split('.')[:-2]) + + def prepare_from_model(self, model: Optional[nn.Module] = None) -> None: + """Wrapper the origin source methods. + + The ``model`` is useless in this recorder, just to be consistent with + other recorders. + """ + pass + + def method_record_wrapper(self, orgin_method: Callable, + data_buffer: List) -> Callable: + """Save the method's outputs. + + Args: + origin_method (MethodType): The method whose outputs need to be + recorded. + buffer_key (str): The key of the method's outputs saved in + ``data_buffer``. + """ + + @functools.wraps(orgin_method) + def wrap_method(*args, **kwargs): + outputs = orgin_method(*args, **kwargs) + # assume a func execute N times, there will be N outputs need to + # save. + data_buffer.append(outputs) + return outputs + + return wrap_method + + def __enter__(self): + """Enter the context manager.""" + super().__enter__() + + imported_cls = self.imported_class + origin_method = self.origin_method + # add record wrapper to origin method. + record_method = self.method_record_wrapper(origin_method, + self.data_buffer) + + # rewrite the origin method. + setattr(imported_cls, self.method_name, record_method) + + def __exit__(self, exc_type, exc_value, traceback): + """Exit the context manager.""" + super().__exit__(exc_type, exc_value, traceback) + + imported_cls = self.imported_class + origin_method = self.origin_method + + # restore the origin method + setattr(imported_cls, self.method_name, origin_method) diff --git a/mmrazor/core/recorders/module_inputs_recorder.py b/mmrazor/core/recorders/module_inputs_recorder.py new file mode 100644 index 00000000..53ea9094 --- /dev/null +++ b/mmrazor/core/recorders/module_inputs_recorder.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Tuple + +from torch import nn + +from mmrazor.registry import TASK_UTILS +from .module_outputs_recorder import ModuleOutputsRecorder + + +@TASK_UTILS.register_module() +class ModuleInputsRecorder(ModuleOutputsRecorder): + """Recorder for intermediate results which are Pytorch moudle's inputs.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward_hook(self, module: nn.Module, inputs: Tuple, + outputs: Any) -> None: + """Save the module's forward input. + + Args: + module (:obj:`torch.nn.Module`): The module to register hook. + inputs (tuple): The input of the module. + outputs : The output of the module. 
+ """ + if self.recording: + self.data_buffer.append(inputs) diff --git a/mmrazor/core/recorders/module_outputs_recorder.py b/mmrazor/core/recorders/module_outputs_recorder.py new file mode 100644 index 00000000..277d7393 --- /dev/null +++ b/mmrazor/core/recorders/module_outputs_recorder.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Optional, Tuple + +from torch import nn + +from mmrazor.registry import TASK_UTILS +from .base_recorder import BaseRecorder + + +@TASK_UTILS.register_module() +class ModuleOutputsRecorder(BaseRecorder): + """Recorder for intermediate results which are Pytorch moudle's outputs. + + Examples: + >>> from torch import nn + >>> class ToyModel(nn.Module): + ... def __init__(self): + ... super().__init__() + ... self.conv1 = nn.Conv2d(8,8,1) + ... self.conv2 = nn.Conv2d(1,1,1) + ... def forward(self, x): + ... x1 = self.conv1(x) + ... x2 = self.conv1(x+1) + ... return self.conv2(x1 + x2) + + >>> model = ToyModel() + >>> [ name for name,_ in model.named_modules() ] + ['conv1', 'conv2'] + + >>> r1 = ModuleOutputsRecorder('conv1') + >>> r1.initialize(model) + + >>> with r1: + >>> res = model(torch.randn(1,1,1,1)) + + >>> r1.data_buffer + [tensor([[[[0.6734]]]]), tensor([[[[1.2514]]]]) ] + >>> r1.get_record_data(record_idx=1) + tensor([[[[1.2514]]]]) + + >>> r2 = ModuleOutputsRecorder('conv2') + >>> r2.initialize(model) + + >>> with r2: + >>> res = model(torch.randn(1,1,1,1)) + + >>> r2.data_buffer + [tensor([[[[0.9534]]]])] + >>> r2.get_record_data() + tensor([[[[0.9534]]]]) + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._recording = False + + @property + def recording(self) -> bool: + """bool: whether to record data in forward hook.""" + return self._recording + + def prepare_from_model(self, model: Optional[nn.Module] = None) -> None: + """Register Pytorch forward hook to corresponding module.""" + + assert model is not None, 'model can not be None.' + + founded = False + for name, module in model.named_modules(): + if name == self.source: + module.register_forward_hook(self.forward_hook) + founded = True + break + + assert founded, f'"{self.source}" is not in the model.' + + def forward_hook(self, module: nn.Module, inputs: Tuple, + outputs: Any) -> None: + """Save the module's forward output. + + Args: + module (:obj:`torch.nn.Module`): The module to register hook. + inputs (tuple): The input of the module. + outputs : The output of the module. + """ + if self._recording: + self.data_buffer.append(outputs) + + def __enter__(self): + """Enter the context manager.""" + super().__enter__() + self._recording = True + + def __exit__(self, exc_type, exc_value, traceback): + """Exit the context manager.""" + super().__exit__(exc_type, exc_value, traceback) + self._recording = False diff --git a/mmrazor/core/recorders/param_recorder.py b/mmrazor/core/recorders/param_recorder.py new file mode 100644 index 00000000..afd0d1c0 --- /dev/null +++ b/mmrazor/core/recorders/param_recorder.py @@ -0,0 +1,57 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +from torch import nn + +from mmrazor.registry import TASK_UTILS +from .base_recorder import BaseRecorder + + +@TASK_UTILS.register_module() +class ParameterRecorder(BaseRecorder): + """Recorder for Pytorch model's parameters. + + Examples: + >>> from torch import nn + >>> class ToyModel(nn.Module): + ... def __init__(self): + ... super().__init__() + ... self.toy_conv = nn.Conv2d(1,1,1) + ... 
def forward(self, x): + ... return self.toy_conv(x) + + >>> model = ToyModel() + >>> [ name for name,_ in model.named_parameters() ] + ['toy_conv.weight', 'toy_conv.bias'] + + >>> recorder = ParameterRecorder('toy_conv.weight') + >>> recorder.initialize(model) + + >>> recorder.data_buffer + [Parameter containing: tensor([[[[0.3244]]]], requires_grad=True)] + >>> recorder.get_record_data() + Parameter containing: tensor([[[[0.3244]]]], requires_grad=True) + """ + + def prepare_from_model(self, model: Optional[nn.Module] = None) -> None: + """Record the Pytorch model's parameters.""" + assert model is not None, \ + 'model can not be None when use ParameterRecorder.' + + founded = False + for param_name, param in model.named_parameters(): + if param_name == self.source: + self.data_buffer.append(param) + founded = True + break + + assert founded, f'"{self.source}" is not in the model.' + + def reset_data_buffer(self): + """Clear data in data_buffer. + + Note: + The data_buffer stores the address of the parameter in memory and + does not need to be reset. + """ + pass diff --git a/mmrazor/core/recorders/recorder_manager.py b/mmrazor/core/recorders/recorder_manager.py new file mode 100644 index 00000000..c1da4904 --- /dev/null +++ b/mmrazor/core/recorders/recorder_manager.py @@ -0,0 +1,117 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Dict, Optional + +from mmcv import ConfigDict +from torch import nn + +from mmrazor.registry import TASK_UTILS +from .base_recorder import BaseRecorder + + +class RecorderManager: + """Various types recorders' manager. The ``RecorderManager`` also is a + context manager, managing various types of Recorder. When entering the + ``RecorderManager``, all recorders managed by it will be started. + + Note: + The recorders will be initialized in the ``RecorderManager`` by + default. If you want to just use a recorder without the + ``RecorderManager``, you need to initialize it first. + + Args: + recorders (dict, optional): All recorders' config. + + + Examples: + >>> # Below code in toy_module.py + >>> import random + >>> class Toy(): + ... def toy_func(self): + ... return random.randint(0, 1000) + + >>> # Below code in main.py + >>> from torch import nn + >>> from toy_module import Toy + + >>> class ToyModel(nn.Module): + ... def __init__(self): + ... super().__init__() + ... self.conv1 = nn.Conv2d(1,1,1) + ... self.conv2 = nn.Conv2d(1,1,1) + ... self.toy = Toy() + ... def forward(self, x): + ... return self.conv2(self.conv1(x)) + self.toy.toy_func() + + >>> model = ToyModel() + >>> [ name for name,_ in model.named_modules() ] + ['conv1', 'conv2'] + + >>> conv1_rec = ModuleOutputsRecorder('conv1') + >>> conv2_rec = ModuleOutputsRecorder('conv2') + >>> func_rec = MethodOutputsRecorder('toy_module.Toy.toy_func') + >>> manager = RecorderManager( + ... {'conv1_rec': conv1_rec , + ... 'conv2_rec': conv2_rec, + ... 'func_rec': func_rec}) + >>> manager.initialize(model) + + >>> with manager: + ... 
res = model(torch.ones(1,1,1,1)) + >>> res + tensor([[[[22.9534]]]]) + + >>> conv2_data = manager.get_recorder('conv2_rec').get_record_data() + >>> conv2_data + tensor([[[[0.9534]]]]) + + >>> func_data = manager.get_recorder('func_rec').get_record_data() + >>> func_data + 22 + + >>> res.sum() == (conv2_data + func_data).sum() + True + """ + + def __init__(self, recorders: Optional[ConfigDict] = None) -> None: + + self._recorders: Dict[str, BaseRecorder] = dict() + if recorders: + for name, cfg in recorders.items(): + recorder_cfg = copy.deepcopy(cfg) + recorder_type = cfg.type + recorder_type_ = recorder_type + 'Recorder' + + recorder_cfg.type = recorder_type_ + recorder = TASK_UTILS.build(recorder_cfg) + + self._recorders[name] = recorder + + @property + def recorders(self) -> Dict[str, BaseRecorder]: + """dict: all recorders.""" + return self._recorders + + def get_recorder(self, recorder: str) -> BaseRecorder: + """Get the corresponding recorder according to the name.""" + return self.recorders[recorder] + + def initialize(self, model: nn.Module): + """Init all recorders. + + Args: + model (nn.Module): The model which need to record intermediate + results. + """ + for recorder in self.recorders.values(): + recorder.initialize(model) + + def __enter__(self): + """Enter the context manager.""" + for recorder in self.recorders.values(): + recorder.__enter__() + + def __exit__(self, exc_type, exc_value, traceback): + """Exit the context manager.""" + for recorder in self.recorders.values(): + recorder.__exit__(exc_type, exc_value, traceback) diff --git a/mmrazor/models/algorithms/__init__.py b/mmrazor/models/algorithms/__init__.py index 0cef0274..36efb0e2 100644 --- a/mmrazor/models/algorithms/__init__.py +++ b/mmrazor/models/algorithms/__init__.py @@ -1,5 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. from .base import BaseAlgorithm +from .distill import (ConfigurableDistill, FpnTeacherDistill, + SingleTeacherDistill) from .nas import SPOS -__all__ = ['BaseAlgorithm', 'SPOS'] +__all__ = [ + 'SingleTeacherDistill', 'ConfigurableDistill', 'BaseAlgorithm', + 'FpnTeacherDistill', 'SPOS' +] diff --git a/mmrazor/models/algorithms/distill/__init__.py b/mmrazor/models/algorithms/distill/__init__.py new file mode 100644 index 00000000..5cee59f2 --- /dev/null +++ b/mmrazor/models/algorithms/distill/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .configurable import (ConfigurableDistill, FpnTeacherDistill, + SingleTeacherDistill) + +__all__ = ['SingleTeacherDistill', 'ConfigurableDistill', 'FpnTeacherDistill'] diff --git a/mmrazor/models/algorithms/distill/configurable/__init__.py b/mmrazor/models/algorithms/distill/configurable/__init__.py new file mode 100644 index 00000000..9ab3b9ba --- /dev/null +++ b/mmrazor/models/algorithms/distill/configurable/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .configurable_distill import ConfigurableDistill +from .fpn_teacher_distill import FpnTeacherDistill +from .single_teacher_distill import SingleTeacherDistill + +__all__ = [ + 'SingleTeacherDistill', 'FpnTeacherDistill', 'ConfigurableDistill', + 'ConfigurableDistill' +] diff --git a/mmrazor/models/algorithms/distill/configurable/configurable_distill.py b/mmrazor/models/algorithms/distill/configurable/configurable_distill.py new file mode 100644 index 00000000..b7afb3f8 --- /dev/null +++ b/mmrazor/models/algorithms/distill/configurable/configurable_distill.py @@ -0,0 +1,248 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import warnings +from inspect import signature +from typing import Dict, List, Optional, Union + +from mmengine.model import BaseModel +from torch import nn + +from mmrazor.core import DistillDeliveryManager, RecorderManager +from mmrazor.registry import MODELS +from ...base import BaseAlgorithm, LossResults + + +@MODELS.register_module() +class ConfigurableDistill(BaseAlgorithm): + """``ConfigurableDistill`` is a powerful tool that can reproduce most + distillation algorithms without modifying the code of teacher or student + models. + + ``ConfigurableDistill`` can get various intermediate results of the model + in a hacky way by ``Recorder``. More details see user-docs for ``Recorder`` + + ``ConfigurableDistill`` can use the teacher's intermediate results to + override the student's intermediate results in a hacky way by ``Delivery``. + More details see user-docs for ``Delivery``. + + Args: + architecture (dict | :obj:`BaseModel`): The config of + :class:`BaseModel` or built model. + student_recorders (dict, optional): Config for multiple recorders. A + student model may have more than one recorder. These recorders + only record the student model's intermediate results. Defaults to + None. + teacher_recorders (dict, optional): Config for multiple recorders. A + teacher model may have more than one recorder. These recorders + only record the teacher model's intermediate results. Defaults to + None. + distill_deliveries (dict, optional): Config for multiple deliveries. A + distill algorithm may have more than one delivery. Defaults to + None. + distill_losses: (Dict[str, Dict], optional): Config for multiple + distill losses. A distill algorithm may have more than one distill + loss. Defaults to None. + loss_forward_mappings: (Dict[str, Dict], optional): Mapping between + distill loss forward arguments and records. + data_preprocessor (:obj:`BaseDataPreprocessor`): Used for + pre-processing data sampled by dataloader to the format accepted by + :meth:`forward`. + init_cfg (dict, optional): Initialization config dict. + + Note: + If a distill loss needs to backward, the name of the loss must contain + "loss". If it is only used as a statistical value, the name can not + contain "loss". More details see docs for + :func:`mmengine.model.BaseModel._parse_loss`. + + Note: + The keys of ``loss_forward_mappings`` should be consistent with the + keys of ``distill_losses``. + + Each item in ``loss_forward_mappings`` is a mapping between a distill + loss and its forward arguments. The keys of the mapping are the + signature of the loss's forward, and the values of the mapping are the + recorded data location. + + ``from_recorder``refers to the recorder where the data is stored, and + if ``from_student`` is True, it means the recorder is in ` + `student_recorders``; otherwise, it means the recorder is in + ``teacher_recorders``. + + Examples: + >>> distill_losses = dict( + ... loss_kl=dict(type='KLDivergence', tau=1, loss_weight=5)) + + >>> student_recorders = dict( + ... fc = dict(type='ModuleOutputs', sources=['head.fc'])) + + >>> teacher_recorders = dict( + ... fc = dict(type='ModuleOutputs', sources=['head.fc'])) + + >>> loss_forward_mappings = dict( + ... loss_kl=dict( + ... preds_S=dict(from_recorder='fc', from_student=True), + ... 
preds_T=dict(from_recorder='fc', from_student=False))) + """ + + def __init__(self, + architecture: Union[BaseModel, Dict], + student_recorders: Optional[Dict[str, Dict]] = None, + teacher_recorders: Optional[Dict[str, Dict]] = None, + distill_deliveries: Optional[Dict[str, Dict]] = None, + distill_losses: Optional[Dict[str, Dict]] = None, + loss_forward_mappings: Optional[Dict[str, Dict]] = None, + data_preprocessor: Optional[Union[dict, nn.Module]] = None, + init_cfg: Optional[dict] = None): + super().__init__(architecture, data_preprocessor, init_cfg) + + # The recorder manager is just constructed, but not really initialized + # yet. Recorder manager initialization needs to input the corresponding + # model. + # Different subclasses may have different teacher models, and it is + # inconvenient to initialize the recorder manager in + # ``ConfigurableDistll``. + # During the initialization of the subclass, need to execute + # `self.student_recorder_manager.initialize(student)` and + # `self.teacher_recorder_manager.initialize(teacher)` according to the + # corresponding student and teacher. + self.student_recorders = RecorderManager(student_recorders) + self.teacher_recorders = RecorderManager(teacher_recorders) + + self.distill_deliveries = DistillDeliveryManager(distill_deliveries) + + self.distill_losses = self.build_distill_losses(distill_losses) + + if loss_forward_mappings: + # Check if loss_forward_mappings is in the correct format + self._check_loss_forward_mappings(self.distill_losses, + loss_forward_mappings, + self.student_recorders, + self.teacher_recorders) + self.loss_forward_mappings = loss_forward_mappings + else: + self.loss_forward_mappings = dict() + + @property + def student(self) -> BaseModel: + """Alias for ``architecture``.""" + return self.architecture + + def build_distill_losses( + self, + losses: Optional[Dict[str, Dict]] = None, + ) -> nn.ModuleDict: + """build distill losses according config.""" + + distill_losses = nn.ModuleDict() + if losses: + for loss_name, loss_cfg in losses.items(): + assert loss_name not in distill_losses + if 'loss' not in loss_name: + warnings.warn( + f'Warning: If {loss_name} is a loss that needs to ' + f'backward, the name of {loss_name} must contain ' + f'"loss". If it is only used as a statistical value, ' + 'then the name must not contain "loss". More details ' + 'see docs for ' + ':func:`mmengine.model.BaseModel._parse_loss`', + UserWarning) + item_loss = MODELS.build(loss_cfg) + distill_losses[loss_name] = item_loss + + return distill_losses + + def get_record(self, + recorder: str, + from_student: bool, + record_idx: int = 0, + data_idx: Optional[int] = None) -> List: + """According to each item in ``record_infos``, get the corresponding + record in ``recorder_manager``.""" + + if from_student: + recorder_ = self.student_recorders.get_recorder(recorder) + else: + recorder_ = self.teacher_recorders.get_recorder(recorder) + + return recorder_.get_record_data(record_idx, data_idx) + + def compute_distill_losses( + self, + distill_losses: nn.ModuleDict, + loss_forward_mappings: Dict[str, Dict], + student_recorders: RecorderManager, + teacher_recorders: RecorderManager, + ) -> LossResults: + """Compute distill losses automatically.""" + # Record all computed losses' results. 
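+        # Each entry in ``loss_forward_mappings`` maps the keyword arguments of + # one distill loss to recorded data, e.g. + # loss_kl=dict(preds_S=dict(from_student=True, recorder='fc'), + #              preds_T=dict(from_student=False, recorder='fc')). + # Every record is fetched via ``get_record`` and passed to the loss + # module under the corresponding keyword.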
+ losses = dict() + for loss_name, forward_mappings in loss_forward_mappings.items(): + forward_kwargs = dict() + for forward_key, record_info in forward_mappings.items(): + forward_var = self.get_record(**record_info) + forward_kwargs[forward_key] = forward_var + + loss_module = distill_losses[loss_name] + loss = loss_module(**forward_kwargs) # type: ignore + # add computed loss result. + losses[loss_name] = loss + + return losses + + def _check_loss_forward_mappings( + self, losses: nn.ModuleDict, loss_forward_mappings: Dict[str, + Dict], + student_recorders: RecorderManager, + teacher_recorders: RecorderManager) -> None: + """Check if ``loss_forward_mappings`` is in the correct format.""" + + if not isinstance(loss_forward_mappings, dict): + raise TypeError( + 'loss_forward_mappings should be a dict instance, but got' + f'{type(loss_forward_mappings)}') + + for loss_name, forward_mappings in loss_forward_mappings.items(): + assert loss_name in losses, \ + f'"{loss_name}" is not in distill losses. The keys of ' \ + 'loss_forward_kwargs must match the keys of distill_losses.' + + if not isinstance(forward_mappings, dict): + raise TypeError( + 'Each item of loss_forward_mappings should be a dict ' + f'instance, but got {type(forward_mappings)}') + + loss_module = losses[loss_name] + loss_forward_keys = signature( + loss_module.forward).parameters.keys() + assert len(loss_forward_keys) == len(forward_mappings.keys()) + + for forward_key, record_info in forward_mappings.items(): + assert forward_key in loss_forward_keys, \ + f'{forward_key} is not in the signature of \ + {type(loss_module).__name__} forward, \ + please check your config.' + + assert 'recorder' in record_info, \ + 'Each item of loss_forward_mappings should have ' \ + '"recorder", pls check your config.' + + assert 'from_student' in record_info, \ + 'Each item of loss_forward_mappings should have ' \ + '"from_student", pls check your config.' + + recorder: str = record_info['recorder'] + from_student: bool = record_info['from_student'] + + if not isinstance(from_student, bool): + raise TypeError(f'from_student should be a bool instance, ' + f'but got {type(from_student)}') + + if from_student: + assert recorder in self.student_recorders.recorders, \ + f'For {forward_key}, "{recorder}" must be in \ + `student_recorders`.' + + else: + assert recorder in self.teacher_recorders.recorders, \ + f'For {forward_key}, "{recorder}" must be in \ + `teacher_recorders`.' diff --git a/mmrazor/models/algorithms/distill/configurable/fpn_teacher_distill.py b/mmrazor/models/algorithms/distill/configurable/fpn_teacher_distill.py new file mode 100644 index 00000000..8fca6e23 --- /dev/null +++ b/mmrazor/models/algorithms/distill/configurable/fpn_teacher_distill.py @@ -0,0 +1,57 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +import torch +from mmengine import BaseDataElement + +from mmrazor.models.utils import add_prefix +from mmrazor.registry import MODELS +from ...base import LossResults +from .single_teacher_distill import SingleTeacherDistill + + +@MODELS.register_module() +class FpnTeacherDistill(SingleTeacherDistill): + """``FpnTeacherDistill`` means teacher only execute backbone and neck. + + If the intermediate results required for distill algorithm are generated by + the backbone and neck parts, using ``FpnTeacherDistill`` can speed up + training. 
+ """ + + def loss( + self, + batch_inputs: torch.Tensor, + data_samples: Optional[List[BaseDataElement]] = None, + ) -> LossResults: + """Calculate losses from a batch of inputs and data samples.""" + + losses = dict() + # If the `override_data` of a delivery is False, the delivery will + # record the origin data. + self.delivery_manager.override_data = False + if self.teacher_trainable: + # Unlike ``SingleTeacherDistill``, teacher will only execute + # back + neck, not head, so there will be no loss. + with self.teacher_recorders, self.delivery_manager: + _ = self.teacher.extract_feat(batch_inputs) + else: + with self.teacher_recorders, self.distill_deliveries: + with torch.no_grad(): + _ = self.teacher(batch_inputs, data_samples, mode='loss') + + # If the `override_data` of a delivery is True, the delivery will + # override the origin data with the recorded data. + self.delivery_manager.override_data = True + with self.student_recorders, self.delivery_manager: + student_losses = self.student( + batch_inputs, data_samples, mode='loss') + losses.update(add_prefix(student_losses, 'student')) + + # Automatically compute distill losses based on `loss_forward_mappings` + distill_losses = self.compute_distill_losses( + self.distill_losses, self.loss_forward_mappings, + self.student_recorders, self.teacher_recorders) + losses.update(add_prefix(distill_losses, 'distill')) + + return losses diff --git a/mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py b/mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py new file mode 100644 index 00000000..15100a06 --- /dev/null +++ b/mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py @@ -0,0 +1,107 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Union + +import torch +from mmcv.runner import load_checkpoint +from mmengine import BaseDataElement +from mmengine.model import BaseModel +from torch.nn.modules.batchnorm import _BatchNorm + +from mmrazor.models.utils import add_prefix +from mmrazor.registry import MODELS +from ...base import LossResults +from .configurable_distill import ConfigurableDistill + + +@MODELS.register_module() +class SingleTeacherDistill(ConfigurableDistill): + """``SingleTeacherDistill`` can be used to develop distill algorithms which + only use one teacher. + + Args: + teacher (dict | BaseModel): The config dict for teacher model or built + teacher model. + teacher_ckpt (str): The path of teacher's checkpoint. Defaults to None. + teacher_trainable (bool): Whether the teacher is trainable. Defaults + to False. + teacher_norm_eval (bool): Whether to set teacher's norm layers to eval + mode, namely, freeze running stats (mean and var). Note: Effect on + Batch Norm and its variants only. Defaults to True. 
+ """ + + def __init__(self, + teacher: Union[BaseModel, Dict], + teacher_ckpt: Optional[str] = None, + teacher_trainable: bool = False, + teacher_norm_eval: bool = True, + **kwargs): + super().__init__(**kwargs) + + if isinstance(teacher, Dict): + teacher = MODELS.build(teacher) + + if not isinstance(teacher, BaseModel): + raise TypeError('teacher should be a `dict` or ' + f'`BaseModel` instance, but got ' + f'{type(teacher)}') + + self.teacher = teacher + if teacher_ckpt: + # avoid loaded parameters be overwritten + self.teacher.init_weights() + _ = load_checkpoint(self.teacher, teacher_ckpt) + self.teacher_trainable = teacher_trainable + self.teacher_norm_eval = teacher_norm_eval + + # In ``ConfigurableDistll``, the recorder manager is just constructed, + # but not really initialized yet. + self.student_recorders.initialize(self.student) + self.teacher_recorders.initialize(self.teacher) + + def loss( + self, + batch_inputs: torch.Tensor, + data_samples: Optional[List[BaseDataElement]] = None, + ) -> LossResults: + """Calculate losses from a batch of inputs and data samples.""" + + losses = dict() + + # If the `override_data` of a delivery is False, the delivery will + # record the origin data. + self.distill_deliveries.override_data = False + if self.teacher_trainable: + with self.teacher_recorders, self.distill_deliveries: + teacher_losses = self.teacher( + batch_inputs, data_samples, mode='loss') + + losses.update(add_prefix(teacher_losses, 'teacher')) + else: + with self.teacher_recorders, self.distill_deliveries: + with torch.no_grad(): + + _ = self.teacher(batch_inputs, data_samples, mode='loss') + + # If the `override_data` of a delivery is True, the delivery will + # override the origin data with the recorded data. + self.distill_deliveries.override_data = True + with self.student_recorders, self.distill_deliveries: + student_losses = self.student( + batch_inputs, data_samples, mode='loss') + losses.update(add_prefix(student_losses, 'student')) + + # Automatically compute distill losses based on `loss_forward_mappings` + distill_losses = self.compute_distill_losses( + self.distill_losses, self.loss_forward_mappings, + self.student_recorders, self.teacher_recorders) + losses.update(add_prefix(distill_losses, 'distill')) + + return losses + + def train(self, mode=True): + """Set distiller's forward mode.""" + super().train(mode) + if mode and self.teacher_norm_eval: + for m in self.teacher.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmrazor/models/distillers/__init__.py b/mmrazor/models/distillers/__init__.py deleted file mode 100644 index d8f2646d..00000000 --- a/mmrazor/models/distillers/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .self_distiller import SelfDistiller -from .single_teacher import SingleTeacherDistiller - -__all__ = ['SelfDistiller', 'SingleTeacherDistiller'] diff --git a/mmrazor/models/distillers/base.py b/mmrazor/models/distillers/base.py deleted file mode 100644 index c6229488..00000000 --- a/mmrazor/models/distillers/base.py +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-from abc import ABCMeta, abstractmethod - -from mmcv.runner import BaseModule -from mmcv.utils import import_modules_from_strings - - -def function_wrapper(ctx, method, method_str): - """Pass teacher's outputs to student.""" - - def wrapper(*args, **kwargs): - # record inputs - ctx.method_args[method_str] = args - ctx.method_kwargs[method_str] = kwargs - # TODO cover more usecases, not only pass teacher's outputs to - # student. - if ctx.is_teacher: - # execute the raw function - outputs = method(*args, **kwargs) - # record outputs - ctx.method_return[method_str] = outputs - else: - # modify student's outputs to be same with teacher - outputs = ctx.method_return[method_str] - - return outputs - - return wrapper - - -class FunctionContext(): - """Function context manager for rewrite function. - - Args: - ctx (ConversionContext): The distiller's overall context manager. - method (str): The name of the function to rewrite. - """ - - def __init__(self, ctx, method, import_module=None): - self.ctx = ctx - - self.import_module = import_modules_from_strings(import_module) - self.method_str = method - self.method_exec_str = f'self.import_module.{method}' - - def _set_method(self, method): - """Modify a function.""" - exec(f'{self.method_exec_str} = method') - - def __enter__(self): - """Rewrite the function.""" - self.method_impl = eval(self.method_exec_str) - - if self.method_impl: - self._set_method( - function_wrapper(self.ctx, self.method_impl, self.method_str, - self.align_mode)) - - def __exit__(self, exc_type, exc_value, traceback): - """Restore the function.""" - if self.method_impl: - self._set_method(self.method_impl) - - -class ConversionContext(): - """Context manager for record functions' inputs or outputs.""" - - def __init__(self, hooks): - # save functions' inputs - self.method_args = dict() - self.method_kwargs = dict() - # save functions' outputs - self.method_return = dict() - - # Each function will have a sub context manager, the function will be - # rewritten when enter the sub context manager. - self.hooks = [] - self.is_teacher = True - for hook in hooks: - self.hooks.append(FunctionContext(self, **hook)) - - def __enter__(self): - """Enter every sub context managers.""" - for hook in self.hooks: - hook.__enter__() - return self - - def __exit__(self, exc_type, exc_value, traceback): - """Exit every sub context managers.""" - for hook in self.hooks: - hook.__exit__(exc_type, exc_value, traceback) - - -class BaseDistiller(BaseModule, metaclass=ABCMeta): - """Base Distiller. - - In the distillation algorithm, some intermediate results of the teacher - need to be obtained and passed to the student. - - For nn.Module's outputs, obtained by pytorch forward hook. - For python function's outputs, obtained by a specific context manager. - - Args: - align_methods (dict): The details of the functions which outputs need - to be obtained. - """ - - def __init__(self, align_methods=None, **kwargs): - super(BaseDistiller, self).__init__(**kwargs) - - if align_methods is None: - self.context_manager = None - else: - # To obtain the python function's outputs, there will build a - # specific context manager. When enter the context manager, the - # functions will be rewrite. The context manager could record - # inputs or outputs of the functions , and pass from teachr to - # student. When exit the context manager, the rewritten functions - # will restore. 
- self.context_manager = ConversionContext(align_methods) - - @abstractmethod - def prepare_from_student(self, student): - """Register forward hooks to students and teachers.""" - pass - - @abstractmethod - def teacher_forward_output_hook(self, module, inputs, outputs): - """Save the teacher output.""" - pass - - @abstractmethod - def student_forward_output_hook(self, module, inputs, outputs): - """Save the student output.""" - pass - - def reset_ctx_teacher_mode(self, mode=True): - if self.context_manager is not None: - self.context_manager.is_teacher = mode - - @abstractmethod - def exec_teacher_forward(self, data): - """Execute the teacher's forward function.""" - pass - - @abstractmethod - def exec_student_forward(self, student, data): - """Execute the student's forward function.""" - pass - - @abstractmethod - def compute_distill_loss(self, data): - """Compute distill loss according teacher's outputs and student's - outputs.""" - pass diff --git a/mmrazor/models/distillers/self_distiller.py b/mmrazor/models/distillers/self_distiller.py deleted file mode 100644 index 8f1926f4..00000000 --- a/mmrazor/models/distillers/self_distiller.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch.nn as nn - -from mmrazor.registry import MODELS -from .base import BaseDistiller - - -@MODELS.register_module() -class SelfDistiller(BaseDistiller): - """Transfer knowledge inside a single model. - - Args: - components (dict): The details of the distillation. It usually includes - the module names of the teacher and the student, and the losses - used in the distillation. - """ - - def __init__(self, components, **kwargs): - super().__init__(**kwargs) - self.components = components - self.losses = nn.ModuleDict() - - self.student_outputs = dict() - self.teacher_outputs = dict() - - for component in self.components: - student_module_name = component['student_module'] - teacher_module_name = component['teacher_module'] - self.student_outputs[student_module_name] = list() - self.teacher_outputs[teacher_module_name] = list() - - for loss in component.losses: - loss_cfg = loss.copy() - loss_name = loss_cfg.pop('name') - self.losses[loss_name] = MODELS.build(loss_cfg) - - def prepare_from_student(self, student): - """Registers a global forward hook for each teacher module and student - module to be used in the distillation. - - Args: - student (:obj:`torch.nn.Module`): The student model to be used - in the distillation. - """ - self.module2name = {} - for name, module in student.model.named_modules(): - self.module2name[module] = name - self.name_modules = dict(student.model.named_modules()) - - for component in self.components: - student_module_name = component['student_module'] - teacher_module_name = component['teacher_module'] - - student_module = self.name_modules[student_module_name] - teacher_module = self.name_modules[teacher_module_name] - - student_module.register_forward_hook( - self.student_forward_output_hook) - teacher_module.register_forward_hook( - self.teacher_forward_output_hook) - - def teacher_forward_output_hook(self, module, inputs, outputs): - """Save the output. - - Args: - module (:obj:`torch.nn.Module`): the module of register hook - inputs (tuple): input of module - outputs (tuple): out of module - """ - if self.training and getattr(self, 'is_teacher', None): - self.teacher_outputs[self.module2name[module]].append(outputs) - - def student_forward_output_hook(self, module, inputs, outputs): - """Save the output. 
- - Args: - module (:obj:`torch.nn.Module`): the module of register hook - inputs (tuple): input of module - outputs (tuple): out of module - """ - if self.training and not getattr(self, 'is_teacher', None): - self.student_outputs[self.module2name[module]].append(outputs) - - def reset_outputs(self, outputs): - """Reset the teacher's outputs or student's outputs.""" - for key in outputs.keys(): - outputs[key] = list() - - def exec_teacher_forward(self, teacher, data): - """Forward computation of the teacher. - - Args: - teacher (:obj:`torch.nn.Module`): The teacher model to be used - in the distillation. - data (dict): The output of dataloader. - """ - self.reset_outputs(self.teacher_outputs) - self.is_teacher = True - output = teacher(**data) - self.is_teacher = False - - return output - - def exec_student_forward(self, student, data): - """Forward computation of the student. - - Args: - student (:obj:`torch.nn.Module`): The student model to be used - in the distillation. - data (dict): The output of dataloader. - """ - assert not self.is_teacher - self.reset_outputs(self.student_outputs) - output = student(**data) - - return output - - def compute_distill_loss(self, data): - """Compute the distillation loss.""" - - losses = dict() - - for i, component in enumerate(self.components): - student_module_name = component['student_module'] - student_outputs = self.student_outputs[student_module_name] - - teacher_module_name = component['teacher_module'] - teacher_outputs = self.teacher_outputs[teacher_module_name] - - for out_idx, (s_out, t_out) in enumerate( - zip(student_outputs, teacher_outputs)): - - for loss in component.losses: - loss_module = self.losses[loss.name] - loss_name = f'{loss.name}.{out_idx}' - - loss_module.current_data = data - losses[loss_name] = loss_module(s_out, t_out) - loss_module.current_data = None - - return losses diff --git a/mmrazor/models/distillers/single_teacher.py b/mmrazor/models/distillers/single_teacher.py deleted file mode 100644 index e1d524f1..00000000 --- a/mmrazor/models/distillers/single_teacher.py +++ /dev/null @@ -1,243 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch -import torch.nn as nn -from torch.nn.modules.batchnorm import _BatchNorm - -from mmrazor.registry import MODELS -from .base import BaseDistiller - - -@MODELS.register_module() -class SingleTeacherDistiller(BaseDistiller): - """Distiller with single teacher. - - Args: - teacher (dict): The config dict for teacher. - teacher_trainable (bool): Whether the teacher is trainable. - Default: False. - teacher_norm_eval (bool): Whether to set teacher's norm layers to eval - mode, namely, freeze running stats (mean and var). Note: Effect on - Batch Norm and its variants only. Default: True. - components (dict): The details of the distillation. It usually includes - the module names of the teacher and the student, and the losses - used in the distillation. - """ - - def __init__(self, - teacher, - teacher_trainable=False, - teacher_norm_eval=True, - components=tuple(), - **kwargs): - super().__init__(**kwargs) - self.teacher_trainable = teacher_trainable - self.teacher_norm_eval = teacher_norm_eval - self.teacher = self.build_teacher(teacher) - - self.components = components - self.losses = nn.ModuleDict() - self.align_modules = nn.ModuleDict() - - # Record the featuremaps that need to calculate the distillation loss. 
- self.student_outputs = dict() - self.teacher_outputs = dict() - - for i, component in enumerate(self.components): - student_module_name = component['student_module'] - teacher_module_name = component['teacher_module'] - # The type of every student_output is a list by default, because - # some modules will execute multiple forward calculations, such as - # the shareable head in Retinanet - self.student_outputs[student_module_name] = list() - self.teacher_outputs[teacher_module_name] = list() - - # If the number of featuremap channels of student and teacher are - # inconsistent, they need to be aligned by a 1x1 convolution - align_module_cfg = getattr(component, 'align_module', None) - if align_module_cfg is not None: - align_module_name = f'component_{i}' - align_module = self.build_align_module(align_module_cfg) - self.align_modules[align_module_name] = align_module - - # Multiple losses can be calculated at the same location - for loss in component.losses: - loss_cfg = loss.copy() - loss_name = loss_cfg.pop('name') - self.losses[loss_name] = MODELS.build(loss_cfg) - - def build_teacher(self, cfg): - """Build a model from the `cfg`.""" - - teacher = MODELS.build(cfg) - - return teacher - - def build_align_module(self, cfg): - """Build ``align_module`` from the `cfg`. - - ``align_module`` is needed when the number of channels output by the - teacher module is not equal to that of the student module, or for some - other reasons. - - Args: - cfg (dict): The config dict for ``align_module``. - """ - - in_channels = cfg.student_channels - out_channels = cfg.teacher_channels - if cfg.type == 'conv2d': - align_module = nn.Conv2d(in_channels, out_channels, 1) - elif cfg.type == 'linear': - align_module = nn.Linear(in_channels, out_channels) - return align_module - - def prepare_from_student(self, student): - """Registers a global forward hook for each teacher module and student - module to be used in the distillation. - - Args: - student (:obj:`torch.nn.Module`): The student model to be used - in the distillation. - """ - - # Record the mapping relationship between student's modules and module - # names. - self.student_module2name = {} - for name, module in student.model.named_modules(): - self.student_module2name[module] = name - self.student_name2module = dict(student.model.named_modules()) - - # Record the mapping relationship between teacher's modules and module - # names. - self.teacher_module2name = {} - for name, module in self.teacher.named_modules(): - self.teacher_module2name[module] = name - self.teacher_name2module = dict(self.teacher.named_modules()) - - # Register forward hooks for modules that need to participate in loss - # calculation. - for component in self.components: - student_module_name = component['student_module'] - teacher_module_name = component['teacher_module'] - - student_module = self.student_name2module[student_module_name] - teacher_module = self.teacher_name2module[teacher_module_name] - - student_module.register_forward_hook( - self.student_forward_output_hook) - teacher_module.register_forward_hook( - self.teacher_forward_output_hook) - - def teacher_forward_output_hook(self, module, inputs, outputs): - """Save the module's forward output. - - Args: - module (:obj:`torch.nn.Module`): The module to register hook. - inputs (tuple): The input of the module. - outputs (tuple): The output of the module. 
- """ - if self.training: - self.teacher_outputs[self.teacher_module2name[module]].append( - outputs) - - def student_forward_output_hook(self, module, inputs, outputs): - """Save the module's forward output. - - Args: - module (:obj:`torch.nn.Module`): The module to register hook. - inputs (tuple): The input of the module. - outputs (tuple): The output of the module. - """ - if self.training: - self.student_outputs[self.student_module2name[module]].append( - outputs) - - def reset_outputs(self, outputs): - """Reset the teacher's outputs or student's outputs.""" - for key in outputs.keys(): - outputs[key] = list() - - def exec_teacher_forward(self, data): - """Execute the teacher's forward function. - - After this function, the teacher's featuremaps will be saved in - ``teacher_outputs``. - """ - - # Convert the context manager's mode to teacher. - self.reset_ctx_teacher_mode(True) - # Clear the saved data of the last forward。 - self.reset_outputs(self.teacher_outputs) - - if self.teacher_trainable: - output = self.teacher(**data) - else: - with torch.no_grad(): - output = self.teacher(**data) - - return output - - def exec_student_forward(self, student, data): - """Execute the teacher's forward function. - - After this function, the student's featuremaps will be saved in - ``student_outputs``. - """ - # Convert the context manager's mode to teacher. - self.reset_ctx_teacher_mode(False) - # Clear the saved data of the last forward。 - self.reset_outputs(self.student_outputs) - - output = student(**data) - return output - - def train(self, mode=True): - """Set distiller's forward mode.""" - super(SingleTeacherDistiller, self).train(mode) - if mode and self.teacher_norm_eval: - for m in self.teacher.modules(): - if isinstance(m, _BatchNorm): - m.eval() - - def get_teacher_outputs(self, teacher_module_name): - """Get the outputs according module name.""" - return self.teacher_outputs[teacher_module_name] - - def compute_distill_loss(self, data=None): - """Compute the distillation loss.""" - - losses = dict() - - for i, component in enumerate(self.components): - # Get the student's outputs. - student_module_name = component['student_module'] - student_outputs = self.student_outputs[student_module_name] - - # Align student output's channels with teacher. - align_module_name = f'component_{i}' - if align_module_name in self.align_modules: - align_module = self.align_modules[align_module_name] - student_outputs = [ - align_module(s_out) for s_out in student_outputs - ] - - # Get the teacher's outputs. - teacher_module_name = component['teacher_module'] - teacher_outputs = self.get_teacher_outputs(teacher_module_name) - - # One module maybe have N outputs, such as the shareable head in - # RetinaNet. - for out_idx, (s_out, t_out) in enumerate( - zip(student_outputs, teacher_outputs)): - - for loss in component.losses: - loss_module = self.losses[loss.name] - loss_name = f'{loss.name}.{out_idx}' - # TODO ugly implementation. - # Pass the gt_label to loss function. - # Only used by WSLD. 
- loss_module.current_data = data - losses[loss_name] = loss_module(s_out, t_out) - loss_module.current_data = None - - return losses diff --git a/mmrazor/models/losses/kl_divergence.py b/mmrazor/models/losses/kl_divergence.py index d4c82d17..ff0ab1df 100644 --- a/mmrazor/models/losses/kl_divergence.py +++ b/mmrazor/models/losses/kl_divergence.py @@ -26,9 +26,9 @@ class KLDivergence(nn.Module): def __init__( self, - tau=1.0, - reduction='batchmean', - loss_weight=1.0, + tau: float = 1.0, + reduction: str = 'batchmean', + loss_weight: float = 1.0, ): super(KLDivergence, self).__init__() self.tau = tau diff --git a/mmrazor/models/losses/weighted_soft_label_distillation.py b/mmrazor/models/losses/weighted_soft_label_distillation.py index f49bda84..6813a6de 100644 --- a/mmrazor/models/losses/weighted_soft_label_distillation.py +++ b/mmrazor/models/losses/weighted_soft_label_distillation.py @@ -24,12 +24,20 @@ class WSLD(nn.Module): self.tau = tau self.loss_weight = loss_weight self.num_classes = num_classes - self.softmax = nn.Softmax(dim=1).cuda() - self.logsoftmax = nn.LogSoftmax(dim=1).cuda() + self.softmax = nn.Softmax(dim=1) + self.logsoftmax = nn.LogSoftmax(dim=1) - def forward(self, student, teacher): + def forward(self, student, teacher, data_samples): - gt_labels = self.current_data['gt_label'] + # Unpack data samples and pack targets + if 'score' in data_samples[0].gt_label: + # Batch augmentation may convert labels to one-hot format scores. + gt_labels = torch.stack([i.gt_label.score for i in data_samples]) + one_hot_labels = gt_labels.float() + else: + gt_labels = torch.hstack([i.gt_label.label for i in data_samples]) + one_hot_labels = F.one_hot( + gt_labels, num_classes=self.num_classes).float() student_logits = student / self.tau teacher_logits = teacher / self.tau @@ -49,7 +57,7 @@ class WSLD(nn.Module): ce_loss_t = -torch.sum(one_hot_labels * log_softmax_t, 1, keepdim=True) focal_weight = ce_loss_s / (ce_loss_t + 1e-7) - ratio_lower = torch.zeros(1).cuda() + ratio_lower = torch.zeros_like(focal_weight) focal_weight = torch.max(focal_weight, ratio_lower) focal_weight = 1 - torch.exp(-focal_weight) ce_loss = focal_weight * ce_loss diff --git a/mmrazor/registry/registry.py b/mmrazor/registry/registry.py index 564a53e6..8ab60fe6 100644 --- a/mmrazor/registry/registry.py +++ b/mmrazor/registry/registry.py @@ -38,10 +38,9 @@ def build_razor_model_from_cfg( default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any: # TODO relay on mmengine:HAOCHENYE/config_new_feature - # if cfg.get('cfg_path', None) and not cfg.get('type', None): - # from mmengine.config import get_model - # teacher = get_model(**cfg) - # return teacher + if cfg.get('cfg_path', None) and not cfg.get('type', None): + from mmengine.config import get_model + return get_model(**cfg) if cfg.get('_fix_subnet_', None): fix_subnet = cfg.pop('_fix_subnet_') diff --git a/mmrazor/runners/distill_val_loop.py b/mmrazor/runners/distill_val_loop.py index cdf845ff..8bb7dbc0 100644 --- a/mmrazor/runners/distill_val_loop.py +++ b/mmrazor/runners/distill_val_loop.py @@ -3,7 +3,7 @@ from typing import Dict, List, Sequence, Union import torch from mmengine.evaluator import Evaluator -from mmengine.runner import ValLoop +from mmengine.runner import ValLoop, autocast from torch.utils.data import DataLoader from mmrazor.registry import LOOPS @@ -19,16 +19,29 @@ class SingleTeacherDistillValLoop(ValLoop): dataloader (Dataloader or dict): A dataloader object or a dict to build a dataloader. 
evaluator (Evaluator or dict or list): Used for computing metrics. + fp16 (bool): Whether to enable fp16 validation. Defaults to + False. """ - def __init__(self, runner, dataloader: Union[DataLoader, Dict], - evaluator: Union[Evaluator, Dict, List]) -> None: - super().__init__(runner, dataloader, evaluator) + def __init__(self, + runner, + dataloader: Union[DataLoader, Dict], + evaluator: Union[Evaluator, Dict, List], + fp16: bool = False) -> None: + super().__init__(runner, dataloader, evaluator, fp16) if self.runner.distributed: - self.model = runner.model.module + assert hasattr(self.runner.model.module, 'teacher') + # TODO: remove hard code after mmcls add data_preprocessor + data_preprocessor = self.runner.model.module.data_preprocessor + self.teacher = self.runner.model.module.teacher + self.teacher.data_preprocessor = data_preprocessor + else: - self.model = runner.model - assert hasattr(self.model, 'teacher') + assert hasattr(self.runner.model, 'teacher') + # TODO: remove hard code after mmcls add data_preprocessor + data_preprocessor = self.runner.model.data_preprocessor + self.teacher = self.runner.model.teacher + self.teacher.data_preprocessor = data_preprocessor def run(self): """Launch validation.""" @@ -38,19 +51,32 @@ class SingleTeacherDistillValLoop(ValLoop): for idx, data_batch in enumerate(self.dataloader): self.run_iter(idx, data_batch) - # compute metrics - metrics_s = self.evaluator.evaluate(len(self.dataloader.dataset)) - for key, value in metrics_s.items(): - self.runner.message_hub.update_scalar(f'val_student/{key}', value) + # compute student metrics + metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) + student_metrics = dict() + for key, value in metrics.items(): + student_key = 'student.' + key + teacher_key = 'teacher.' + key + student_metrics[student_key] = value + self.runner.message_hub.log_scalars.pop(f'val/{teacher_key}', None) + + self.runner.call_hook('after_val_epoch', metrics=student_metrics) + + self.runner.call_hook('before_val_epoch') for idx, data_batch in enumerate(self.dataloader): self.run_iter_teacher(idx, data_batch) - # compute metrics - metrics_t = self.evaluator.evaluate(len(self.dataloader.dataset)) - for key, value in metrics_t.items(): - self.runner.message_hub.update_scalar(f'val_teacher/{key}', value) + # compute teacher metrics + metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) + teacher_metrics = dict() + for key, value in metrics.items(): + student_key = 'student.' + key + teacher_key = 'teacher.' 
+ key - self.runner.call_hook('after_val_epoch', metrics=None) + teacher_metrics[teacher_key] = value + self.runner.message_hub.log_scalars.pop(f'val/{student_key}', None) + + self.runner.call_hook('after_val_epoch', metrics=teacher_metrics) self.runner.call_hook('after_val') @torch.no_grad() @@ -63,8 +89,11 @@ class SingleTeacherDistillValLoop(ValLoop): """ self.runner.call_hook( 'before_val_iter', batch_idx=idx, data_batch=data_batch) - # outputs should be sequence of BaseDataElement - outputs = self.model.teacher(data_batch) + + with autocast(enabled=self.fp16): + # outputs should be sequence of BaseDataElement + outputs = self.teacher.val_step(data_batch) + self.evaluator.process(data_batch, outputs) self.runner.call_hook( 'after_val_iter', diff --git a/mmrazor/utils/setup_env.py b/mmrazor/utils/setup_env.py index a8dba042..1f42c4dd 100644 --- a/mmrazor/utils/setup_env.py +++ b/mmrazor/utils/setup_env.py @@ -62,9 +62,7 @@ def register_all_modules(init_default_scope: bool = True) -> None: """ # noqa import mmrazor.core # noqa: F401,F403 import mmrazor.models # noqa: F401,F403 - - # TODO add mmrazor.runners - + import mmrazor.runners # noqa: F401,F403 if init_default_scope: never_created = DefaultScope.get_current_instance() is None \ or not DefaultScope.check_instance_created('mmrazor') diff --git a/tests/test_core/test_delivers/test_deliver_manager.py b/tests/test_core/test_delivers/test_deliver_manager.py index ef932e90..b702d261 100644 --- a/tests/test_core/test_delivers/test_deliver_manager.py +++ b/tests/test_core/test_delivers/test_deliver_manager.py @@ -3,24 +3,38 @@ from unittest import TestCase from mmcv import ConfigDict -from mmrazor.core import DistillDeliverManager +from mmrazor.core import DistillDeliveryManager class TestDeliverManager(TestCase): + def test_init(self): + + distill_deliveries = ConfigDict( + delivery1=dict( + type='MethodOutputs', + max_keep_data=2, + method_path='toy_module.ToyClass.random_int')) + + manager = DistillDeliveryManager(distill_deliveries) + self.assertEquals(len(manager.deliveries), 1) + + manager = DistillDeliveryManager() + self.assertEquals(len(manager.deliveries), 0) + def test_context_manager(self): from toy_module import ToyClass - distill_deliveries = [ - ConfigDict( + distill_deliveries = ConfigDict( + delivery1=dict( type='MethodOutputs', max_keep_data=2, - method_path='toy_module.ToyClass.random_int') - ] + method_path='toy_module.ToyClass.random_int')) - manager = DistillDeliverManager(distill_deliveries) + manager = DistillDeliveryManager(distill_deliveries) manager.override_data = False + self.assertFalse(manager.override_data) with manager: toy_class = ToyClass() output1_tea = toy_class.random_int() @@ -30,7 +44,9 @@ class TestDeliverManager(TestCase): with manager: _ = toy_class.random_int() + self.assertFalse(manager.override_data) manager.override_data = True + self.assertTrue(manager.override_data) with manager: output1_stu = toy_class.random_int() output2_stu = toy_class.random_int() diff --git a/tests/test_core/test_delivers/test_function_outputs_deliver.py b/tests/test_core/test_delivers/test_function_outputs_deliver.py index 64895fc7..f3ffacc2 100644 --- a/tests/test_core/test_delivers/test_function_outputs_deliver.py +++ b/tests/test_core/test_delivers/test_function_outputs_deliver.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from unittest import TestCase -from mmrazor.core import FunctionOutputsDeliver +from mmrazor.core import FunctionOutputsDelivery class TestFuncOutputsDeliver(TestCase): @@ -9,26 +9,26 @@ class TestFuncOutputsDeliver(TestCase): def test_init(self): with self.assertRaisesRegex(TypeError, 'func_path should be'): - _ = FunctionOutputsDeliver(max_keep_data=1, func_path=1) + _ = FunctionOutputsDelivery(max_keep_data=1, func_path=1) with self.assertRaisesRegex(AssertionError, 'func_path must have at '): - _ = FunctionOutputsDeliver(max_keep_data=1, func_path='toy_func') + _ = FunctionOutputsDelivery(max_keep_data=1, func_path='toy_func') with self.assertRaisesRegex(ImportError, 'aaa is not imported'): - _ = FunctionOutputsDeliver(max_keep_data=1, func_path='aaa.bb') + _ = FunctionOutputsDelivery(max_keep_data=1, func_path='aaa.bb') with self.assertRaisesRegex(AssertionError, 'bb is not in toy_mod'): - _ = FunctionOutputsDeliver( + _ = FunctionOutputsDelivery( max_keep_data=1, func_path='toy_module.bb') with self.assertRaisesRegex(TypeError, 'TOY_VAR should be'): - _ = FunctionOutputsDeliver( + _ = FunctionOutputsDelivery( max_keep_data=1, func_path='toy_module.TOY_VAR') def test_context_manager(self): import toy_module - delivery = FunctionOutputsDeliver( + delivery = FunctionOutputsDelivery( max_keep_data=2, func_path='toy_module.toy_func') delivery.override_data = False diff --git a/tests/test_core/test_delivers/test_method_outputs_deliver.py b/tests/test_core/test_delivers/test_method_outputs_deliver.py index e4a2a502..ba5e2629 100644 --- a/tests/test_core/test_delivers/test_method_outputs_deliver.py +++ b/tests/test_core/test_delivers/test_method_outputs_deliver.py @@ -1,45 +1,45 @@ # Copyright (c) OpenMMLab. All rights reserved. from unittest import TestCase -from mmrazor.core import MethodOutputsDeliver +from mmrazor.core import MethodOutputsDelivery class TestMethodOutputsDeliver(TestCase): def test_init(self): with self.assertRaisesRegex(TypeError, 'method_path should be'): - _ = MethodOutputsDeliver(max_keep_data=1, method_path=1) + _ = MethodOutputsDelivery(max_keep_data=1, method_path=1) with self.assertRaisesRegex(AssertionError, 'method_path must have at '): - _ = MethodOutputsDeliver(max_keep_data=1, method_path='toy_func') + _ = MethodOutputsDelivery(max_keep_data=1, method_path='toy_func') with self.assertRaisesRegex(ImportError, 'aaa is not imported'): - _ = MethodOutputsDeliver(max_keep_data=1, method_path='aaa.bb.b') + _ = MethodOutputsDelivery(max_keep_data=1, method_path='aaa.bb.b') with self.assertRaisesRegex(AssertionError, 'bb is not in toy_module'): - _ = MethodOutputsDeliver( + _ = MethodOutputsDelivery( max_keep_data=1, method_path='toy_module.bb.bbb') with self.assertRaisesRegex(TypeError, 'toy_func should be a type'): - _ = MethodOutputsDeliver( + _ = MethodOutputsDelivery( max_keep_data=1, method_path='toy_module.toy_func.bbb') with self.assertRaisesRegex(AssertionError, 'bbb is not in'): - _ = MethodOutputsDeliver( + _ = MethodOutputsDelivery( max_keep_data=1, method_path='toy_module.ToyClass.bbb') with self.assertRaisesRegex(TypeError, 'count should be'): - _ = MethodOutputsDeliver( + _ = MethodOutputsDelivery( max_keep_data=1, method_path='toy_module.ToyClass.count') def test_context_manager(self): from toy_module import ToyClass - delivery = MethodOutputsDeliver( + delivery = MethodOutputsDelivery( max_keep_data=2, method_path='toy_module.ToyClass.random_int') - # Without ``MethodOutputsDeliver``, outputs of the teacher and the + # Without 
``MethodOutputsDelivery``, outputs of the teacher and the # student are very likely to be different. # from toy_module import ToyClass # toy_class = ToyClass() diff --git a/tests/test_core/test_recorders/test_base_recorder.py b/tests/test_core/test_recorders/test_base_recorder.py new file mode 100644 index 00000000..c23e9b34 --- /dev/null +++ b/tests/test_core/test_recorders/test_base_recorder.py @@ -0,0 +1,50 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +from toy_mod import Toy + +from mmrazor.core.recorders import MethodOutputsRecorder + + +class TestFuncOutputsRecorder(TestCase): + + def test_get_record_data(self): + + toy = Toy() + + recorder = MethodOutputsRecorder('toy_mod.Toy.toy_func') + recorder.initialize() + + with recorder: + res0 = toy.toy_func() + res1 = toy.toy_func() + + self.assertEquals(res0, recorder.get_record_data(record_idx=0)) + self.assertEquals(res1, recorder.get_record_data(record_idx=1)) + + with self.assertRaisesRegex( + AssertionError, + 'record_idx is illegal. The length of data_buffer is 2, ' + 'but record_idx is 2'): + _ = recorder.get_record_data(record_idx=2) + + with self.assertRaisesRegex( + TypeError, + 'When data_idx is not None, record should be a list or ' + 'tuple instance'): + _ = recorder.get_record_data(data_idx=0) + + recorder = MethodOutputsRecorder('toy_mod.Toy.toy_list_func') + recorder.initialize() + + with recorder: + res = toy.toy_list_func() + + self.assertEqual(len(res), 3) + + with self.assertRaisesRegex( + AssertionError, + 'data_idx is illegal. The length of record is 3'): + _ = recorder.get_record_data(data_idx=3) + + self.assertEquals(res[2], recorder.get_record_data(data_idx=2)) diff --git a/tests/test_core/test_recorders/test_func_outputs_recorder.py b/tests/test_core/test_recorders/test_func_outputs_recorder.py new file mode 100644 index 00000000..7dec51c2 --- /dev/null +++ b/tests/test_core/test_recorders/test_func_outputs_recorder.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +from mmrazor.core.recorders import FunctionOutputsRecorder + + +class TestFuncOutputsRecorder(TestCase): + + def test_init(self): + + _ = FunctionOutputsRecorder('toy_mod.toy_func') + + with self.assertRaisesRegex(TypeError, 'source should be'): + _ = FunctionOutputsRecorder([1]) + + with self.assertRaisesRegex(AssertionError, 'source must have at '): + _ = FunctionOutputsRecorder('aaaaa') + + with self.assertRaisesRegex(ImportError, 'aaa is not imported'): + _ = FunctionOutputsRecorder('aaa.bbb') + + with self.assertRaisesRegex(AssertionError, 'aaa is not in toy_mod'): + _ = FunctionOutputsRecorder('toy_mod.aaa') + + with self.assertRaisesRegex(TypeError, 'TOY_VAR should be'): + _ = FunctionOutputsRecorder('toy_mod.TOY_VAR') + + def test_context_manager(self): + from toy_mod import execute_toy_func + + recorder = FunctionOutputsRecorder('toy_mod.toy_func') + recorder.initialize() + + with recorder: + execute_toy_func(1) + + data = recorder.get_record_data() + self.assertTrue(data == 1) + + execute_toy_func(1) + data = recorder.get_record_data() + self.assertTrue(data == 1) diff --git a/tests/test_core/test_recorders/test_method_outputs_recorder.py b/tests/test_core/test_recorders/test_method_outputs_recorder.py new file mode 100644 index 00000000..c97cca2c --- /dev/null +++ b/tests/test_core/test_recorders/test_method_outputs_recorder.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
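[Editor's note] The function- and method-output recorders tested above act as context managers around ordinary Python callables. The following is only a rough, self-contained sketch of that pattern (not the actual FunctionOutputsRecorder/MethodOutputsRecorder implementation); the SimpleNamespace stand-in and the simplified get_record_data signature are assumptions for illustration.

# Rough sketch of the record-on-call pattern used by the function/method
# output recorders tested above. Not mmrazor code; names are illustrative.
from types import SimpleNamespace


class ToyFunctionOutputsRecorder:

    def __init__(self, owner, func_name):
        self.owner = owner          # object that owns the target callable
        self.func_name = func_name  # attribute name of the target callable
        self.data_buffer = []       # outputs recorded while context is active

    def __enter__(self):
        self._origin = getattr(self.owner, self.func_name)

        def wrapper(*args, **kwargs):
            output = self._origin(*args, **kwargs)
            self.data_buffer.append(output)
            return output

        setattr(self.owner, self.func_name, wrapper)
        return self

    def __exit__(self, *exc_info):
        # Restore the original callable; the recorded data stays available.
        setattr(self.owner, self.func_name, self._origin)

    def get_record_data(self, record_idx=0):
        return self.data_buffer[record_idx]


toy_mod = SimpleNamespace(toy_func=lambda a: a + 1)
recorder = ToyFunctionOutputsRecorder(toy_mod, 'toy_func')
with recorder:
    toy_mod.toy_func(1)
    toy_mod.toy_func(2)
assert recorder.get_record_data(record_idx=1) == 3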
+from unittest import TestCase + +from mmrazor.core.recorders import MethodOutputsRecorder + + +class TestFuncOutputsRecorder(TestCase): + + def test_init(self): + + _ = MethodOutputsRecorder('toy_mod.ToyClass.toy') + + with self.assertRaisesRegex(TypeError, 'source should be'): + _ = MethodOutputsRecorder([1]) + + with self.assertRaisesRegex(AssertionError, 'source must have at '): + _ = MethodOutputsRecorder('aaaaa') + + with self.assertRaisesRegex(AssertionError, 'source must have at '): + _ = MethodOutputsRecorder('aaa.bbb') + + with self.assertRaisesRegex(ImportError, 'aaa is not imported'): + _ = MethodOutputsRecorder('aaa.bbb.ccc') + + with self.assertRaisesRegex(AssertionError, 'aaa is not in toy_mod'): + _ = MethodOutputsRecorder('toy_mod.aaa.bbb') + + with self.assertRaisesRegex(TypeError, 'toy_func should be'): + _ = MethodOutputsRecorder('toy_mod.toy_func.bbb') + + with self.assertRaisesRegex(AssertionError, 'bbb is not in ToyClass'): + _ = MethodOutputsRecorder('toy_mod.ToyClass.bbb') + + with self.assertRaisesRegex(TypeError, 'TOY_CLS should be'): + _ = MethodOutputsRecorder('toy_mod.ToyClass.TOY_CLS') + + def test_context_manager(self): + from toy_mod import ToyClass + + toy = ToyClass() + + recorder = MethodOutputsRecorder('toy_mod.ToyClass.toy') + recorder.initialize() + + with recorder: + result = toy.toy() + + data = recorder.get_record_data() + self.assertTrue(data == result) + + result_ = toy.toy() + + data = recorder.get_record_data() + self.assertTrue(data == result) + self.assertFalse(result_ == result) diff --git a/tests/test_core/test_recorders/test_module_recorders.py b/tests/test_core/test_recorders/test_module_recorders.py new file mode 100644 index 00000000..d6505f71 --- /dev/null +++ b/tests/test_core/test_recorders/test_module_recorders.py @@ -0,0 +1,74 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from unittest import TestCase + +import torch +from torch import nn + +from mmrazor.core.recorders import ModuleInputsRecorder, ModuleOutputsRecorder + + +class ToyModel(nn.Module): + + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1, 1, 1) + self.conv2 = nn.Conv2d(1, 1, 1) + + def forward(self, x): + return self.conv2(self.conv1(x)) + + +class TestModuleOutputsRecorder(TestCase): + + def test_prepare_from_model(self): + + recorder = ModuleOutputsRecorder('conv1') + with self.assertRaisesRegex(AssertionError, 'model can not be'): + recorder.prepare_from_model() + + recorder = ModuleOutputsRecorder('conv3') + model = ToyModel() + with self.assertRaisesRegex(AssertionError, '"conv3" is not in'): + recorder.prepare_from_model(model) + + recorder = ModuleOutputsRecorder('conv2') + model = ToyModel() + recorder.prepare_from_model(model) + + def test_module_outputs(self): + + recorder = ModuleOutputsRecorder('conv2') + model = ToyModel() + recorder.initialize(model) + + with recorder: + self.assertTrue(recorder.recording) + res = model(torch.randn(1, 1, 1, 1)) + + self.assertEquals(res, recorder.get_record_data()) + + with recorder: + self.assertTrue(len(recorder.data_buffer) == 0) + + _ = model(torch.randn(1, 1, 1, 1)) + self.assertTrue(len(recorder.data_buffer) == 0) + + def test_module_intputs(self): + + recorder = ModuleInputsRecorder('conv1') + model = ToyModel() + recorder.initialize(model) + + tensor = torch.randn(1, 1, 1, 1) + with recorder: + self.assertTrue(recorder.recording) + _ = model(tensor) + + conv1_input = recorder.get_record_data(data_idx=0) + self.assertEquals(conv1_input.sum(), tensor.sum()) + + with recorder: + self.assertTrue(len(recorder.data_buffer) == 0) + + _ = model(torch.randn(1, 1, 1, 1)) + self.assertTrue(len(recorder.data_buffer) == 0) diff --git a/tests/test_core/test_recorders/test_param_recorder.py b/tests/test_core/test_recorders/test_param_recorder.py new file mode 100644 index 00000000..37b184ac --- /dev/null +++ b/tests/test_core/test_recorders/test_param_recorder.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch +from torch import nn + +from mmrazor.core.recorders import ParameterRecorder + + +class ToyModel(nn.Module): + + def __init__(self): + super().__init__() + self.toy_conv = nn.Conv2d(1, 1, 1) + self.no_record_conv = nn.Conv2d(1, 1, 1) + + def forward(self, x): + return self.toy_conv(x) + + +class TestParameterRecorder(TestCase): + + def test_prepare_from_model(self): + + model = ToyModel() + recorder = ParameterRecorder('AAA') + with self.assertRaisesRegex(AssertionError, '"AAA" is not in the'): + recorder.initialize(model) + + recorder = ParameterRecorder('toy_conv.bias') + with self.assertRaisesRegex(AssertionError, 'model can not be None'): + recorder.prepare_from_model() + + recorder.initialize(model) + bias_weight = recorder.get_record_data() + + self.assertEquals(bias_weight, model.toy_conv.bias) + + with recorder: + _ = model(torch.randn(1, 1, 1, 1)) + + self.assertEquals(bias_weight, model.toy_conv.bias) diff --git a/tests/test_core/test_recorders/test_recorder_manager.py b/tests/test_core/test_recorders/test_recorder_manager.py new file mode 100644 index 00000000..8ea443c6 --- /dev/null +++ b/tests/test_core/test_recorders/test_recorder_manager.py @@ -0,0 +1,58 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
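[Editor's note] The module recorder tests above capture intermediate tensors without touching the model code. A plausible standalone sketch of the mechanism follows, built on a plain PyTorch forward hook; the real ModuleOutputsRecorder/ModuleInputsRecorder may differ in details such as buffer resetting and input handling, so treat the class and method names as assumptions.

# Standalone sketch of recording a submodule's outputs with a forward hook,
# the mechanism the module recorders above are built around.
import torch
from torch import nn


class ToyModuleOutputsRecorder:

    def __init__(self, source):
        self.source = source     # name of the target submodule
        self.data_buffer = []    # outputs captured while recording is on
        self.recording = False

    def initialize(self, model):
        module = dict(model.named_modules())[self.source]
        module.register_forward_hook(self._forward_hook)

    def _forward_hook(self, module, inputs, outputs):
        if self.recording:
            self.data_buffer.append(outputs)

    def __enter__(self):
        self.data_buffer.clear()
        self.recording = True
        return self

    def __exit__(self, *exc_info):
        self.recording = False

    def get_record_data(self, record_idx=0):
        return self.data_buffer[record_idx]


model = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1))
recorder = ToyModuleOutputsRecorder('0')  # record the first conv's output
recorder.initialize(model)

with recorder:
    _ = model(torch.randn(1, 1, 4, 4))
print(recorder.get_record_data().shape)  # torch.Size([1, 1, 4, 4])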
+from unittest import TestCase + +import torch +from mmcv import ConfigDict +from torch import nn +from toy_mod import Toy + +from mmrazor.core import RecorderManager + + +class ToyModel(nn.Module): + + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1, 1, 1) + self.conv2 = nn.Conv2d(1, 1, 1) + self.toy = Toy() + + def forward(self, x): + return self.conv2(self.conv1(x)) + self.toy.toy_func() + + +class TestRecorderManager(TestCase): + + def test_init(self): + + manager = RecorderManager() + self.assertEquals(len(manager.recorders), 0) + + recorders = ConfigDict( + r1=dict(type='ModuleOutputs', source='conv1'), + r2=dict(type='MethodOutputs', source='toy_mod.Toy.toy_func'), + ) + manager = RecorderManager(recorders) + model = ToyModel() + manager.initialize(model) + + def test_context_manager(self): + + recorders = ConfigDict( + r1=dict(type='ModuleOutputs', source='conv2'), + r2=dict(type='MethodOutputs', source='toy_mod.Toy.toy_func'), + ) + manager = RecorderManager(recorders) + model = ToyModel() + manager.initialize(model) + + self.assertEquals(manager.get_recorder('r1'), manager.recorders['r1']) + self.assertEquals(manager.get_recorder('r2'), manager.recorders['r2']) + + with manager: + res = model(torch.ones(1, 1, 1, 1)) + + method_outputs = manager.recorders['r2'].get_record_data() + conv2_outputs = manager.recorders['r1'].get_record_data() + + self.assertEquals(res.sum(), method_outputs + conv2_outputs.sum()) diff --git a/tests/test_core/test_recorders/toy_mod.py b/tests/test_core/test_recorders/toy_mod.py new file mode 100644 index 00000000..3cc33147 --- /dev/null +++ b/tests/test_core/test_recorders/toy_mod.py @@ -0,0 +1,45 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random + +TOY_VAR = 'aaa' + + +def toy_func(a): + return a + + +def toy_list_func(a): + return [a, a, a] + + +def execute_toy_func(a): + toy_func(a) + + +def execute_toy_list_func(a): + toy_list_func(a) + + +class ToyClass: + + TOY_CLS = 'TOY_CLASS' + + def __init__(self): + self._count = 0 + + def toy(self): + self._count += 1 + return self._count + + def __call__(self): + self._count += 1 + return self._count + + +class Toy(): + + def toy_func(self): + return random.randint(0, 1000) + + def toy_list_func(self): + return [random.randint(0, 1000) for _ in range(3)] diff --git a/tests/test_models/test_algorithms/test_configurable_distill.py b/tests/test_models/test_algorithms/test_configurable_distill.py new file mode 100644 index 00000000..80960a99 --- /dev/null +++ b/tests/test_models/test_algorithms/test_configurable_distill.py @@ -0,0 +1,71 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import copy +from unittest import TestCase + +from mmcv import ConfigDict +from toy_models import ToyStudent + +from mmrazor.models import ConfigurableDistill + + +class TestConfigurableDistill(TestCase): + + def test_init(self): + + recorders_cfg = ConfigDict( + conv=dict(type='ModuleOutputs', source='conv')) + + student = ToyStudent() + + alg_kwargs = ConfigDict( + architecture=student, + student_recorders=recorders_cfg, + teacher_recorders=recorders_cfg, + distill_losses=dict(loss_toy=dict(type='ToyDistillLoss')), + loss_forward_mappings=dict( + loss_toy=dict( + arg1=dict(from_student=True, recorder='conv'), + arg2=dict(from_student=False, recorder='conv'), + )), + ) + + alg = ConfigurableDistill(**alg_kwargs) + self.assertEquals(alg.student, alg.architecture) + + alg_kwargs_ = copy.deepcopy(alg_kwargs) + alg_kwargs_['distill_losses'] = None + with self.assertRaisesRegex(AssertionError, + '"loss_toy" is not in distill'): + _ = ConfigurableDistill(**alg_kwargs_) + + alg_kwargs_ = copy.deepcopy(alg_kwargs) + alg_kwargs_['distill_losses'] = dict(toy=dict(type='ToyDistillLoss')) + alg_kwargs_['loss_forward_mappings'] = dict( + toy=dict( + arg1=dict(from_student=True, recorder='conv'), + arg2=dict(from_student=False, recorder='conv'))) + with self.assertWarnsRegex(UserWarning, 'Warning: If toy is a'): + _ = ConfigurableDistill(**alg_kwargs_) + + alg_kwargs_ = copy.deepcopy(alg_kwargs) + alg_kwargs_['loss_forward_mappings'] = None + _ = ConfigurableDistill(**alg_kwargs_) + + alg_kwargs_ = copy.deepcopy(alg_kwargs) + alg_kwargs_['loss_forward_mappings'] = list('AAA') + + with self.assertRaisesRegex(TypeError, + 'loss_forward_mappings should be '): + _ = ConfigurableDistill(**alg_kwargs_) + + alg_kwargs_ = copy.deepcopy(alg_kwargs) + alg_kwargs_['loss_forward_mappings']['loss_toy'] = list() + with self.assertRaisesRegex( + TypeError, 'Each item of loss_forward_mappings should be '): + _ = ConfigurableDistill(**alg_kwargs_) + + alg_kwargs_ = copy.deepcopy(alg_kwargs) + alg_kwargs_.loss_forward_mappings.loss_toy.arg1.from_student = '' + with self.assertRaisesRegex(TypeError, + 'from_student should be a bool'): + _ = ConfigurableDistill(**alg_kwargs_) diff --git a/tests/test_models/test_algorithms/test_single_teacher_distill.py b/tests/test_models/test_algorithms/test_single_teacher_distill.py new file mode 100644 index 00000000..2f00c020 --- /dev/null +++ b/tests/test_models/test_algorithms/test_single_teacher_distill.py @@ -0,0 +1,77 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
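[Editor's note] The `loss_forward_mappings` validated above declares, per distill loss, where each argument comes from (`from_student`, `recorder`, optional `data_idx`). Below is a minimal sketch of how such an entry can be resolved into keyword arguments at loss time; it only illustrates the config semantics and is not the actual `compute_distill_losses` implementation.

# Minimal sketch of resolving a `loss_forward_mappings` entry into keyword
# arguments for a distill loss. Illustrative only.


class ToyRecorder:
    """Stand-in for a recorder that already holds one recorded output."""

    def __init__(self, data):
        self._data = data

    def get_record_data(self, record_idx=0, data_idx=None):
        record = self._data
        return record if data_idx is None else record[data_idx]


def resolve_loss_inputs(mapping, student_recorders, teacher_recorders):
    kwargs = {}
    for arg_name, record_cfg in mapping.items():
        recorders = (student_recorders
                     if record_cfg['from_student'] else teacher_recorders)
        recorder = recorders[record_cfg['recorder']]
        kwargs[arg_name] = recorder.get_record_data(
            data_idx=record_cfg.get('data_idx'))
    return kwargs


student_recorders = {'fc': ToyRecorder(data=3.0)}
teacher_recorders = {'fc': ToyRecorder(data=5.0)}
mapping = dict(
    preds_S=dict(from_student=True, recorder='fc'),
    preds_T=dict(from_student=False, recorder='fc'))

loss_inputs = resolve_loss_inputs(mapping, student_recorders,
                                  teacher_recorders)
assert loss_inputs == dict(preds_S=3.0, preds_T=5.0)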
+import copy +from unittest import TestCase + +import torch +from mmcv import ConfigDict +from toy_models import ToyStudent + +from mmrazor.models import SingleTeacherDistill + + +class TestSingleTeacherDistill(TestCase): + + def test_init(self): + + recorders_cfg = ConfigDict( + conv=dict(type='ModuleOutputs', source='conv')) + + alg_kwargs = ConfigDict( + architecture=dict(type='ToyStudent'), + teacher=dict(type='ToyTeacher'), + student_recorders=recorders_cfg, + teacher_recorders=recorders_cfg, + distill_losses=dict(loss_toy=dict(type='ToyDistillLoss')), + loss_forward_mappings=dict( + loss_toy=dict( + arg1=dict(from_student=True, recorder='conv'), + arg2=dict(from_student=False, recorder='conv'), + )), + ) + + alg = SingleTeacherDistill(**alg_kwargs) + + teacher = ToyStudent() + alg_kwargs_ = copy.deepcopy(alg_kwargs) + alg_kwargs_['teacher'] = teacher + alg = SingleTeacherDistill(**alg_kwargs_) + self.assertEquals(alg.teacher, teacher) + + alg_kwargs_ = copy.deepcopy(alg_kwargs) + alg_kwargs_['teacher'] = 'teacher' + with self.assertRaisesRegex(TypeError, + 'teacher should be a `dict` or'): + _ = SingleTeacherDistill(**alg_kwargs_) + + def test_loss(self): + + recorders_cfg = ConfigDict( + conv=dict(type='ModuleOutputs', source='conv')) + + alg_kwargs = ConfigDict( + architecture=dict(type='ToyStudent'), + teacher=dict(type='ToyTeacher'), + student_recorders=recorders_cfg, + teacher_recorders=recorders_cfg, + distill_losses=dict(loss_toy=dict(type='ToyDistillLoss')), + loss_forward_mappings=dict( + loss_toy=dict( + arg1=dict(from_student=True, recorder='conv'), + arg2=dict(from_student=False, recorder='conv'), + )), + ) + + img = torch.randn(1, 3, 1, 1) + + alg = SingleTeacherDistill(**alg_kwargs) + losses = alg(img, mode='loss') + self.assertIn('distill.loss_toy', losses) + self.assertIn('student.loss', losses) + + alg_kwargs_ = copy.deepcopy(alg_kwargs) + alg_kwargs_['teacher_trainable'] = True + alg = SingleTeacherDistill(**alg_kwargs_) + losses = alg(img, mode='loss') + self.assertIn('distill.loss_toy', losses) + self.assertIn('student.loss', losses) + self.assertIn('teacher.loss', losses) diff --git a/tests/test_models/test_algorithms/toy_models.py b/tests/test_models/test_algorithms/toy_models.py new file mode 100644 index 00000000..b3ed359d --- /dev/null +++ b/tests/test_models/test_algorithms/toy_models.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
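[Editor's note] `SingleTeacherDistill.train`, added earlier in this diff, keeps the teacher's norm layers in eval mode when `teacher_norm_eval=True`. A tiny standalone illustration of the effect on a toy model with a BatchNorm layer; the model itself is an assumption for demonstration.

# BatchNorm layers switched to eval mode do not update their running
# statistics during forward, even while the rest of the model trains.
import torch
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm

teacher = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3))
teacher.train()
for m in teacher.modules():
    if isinstance(m, _BatchNorm):
        m.eval()  # freeze running_mean / running_var

before = teacher[1].running_mean.clone()
_ = teacher(torch.randn(2, 3, 4, 4))
assert torch.equal(before, teacher[1].running_mean)  # stats unchanged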
+ +from mmengine.model import BaseModel +from torch import nn + +from mmrazor.registry import MODELS + + +@MODELS.register_module() +class ToyStudent(BaseModel): + + def __init__(self, data_preprocessor=None): + super().__init__(data_preprocessor=data_preprocessor, init_cfg=None) + self.conv = nn.Conv2d(3, 1, 1) + + def forward(self, batch_inputs, data_samples=None, mode='tensor'): + if mode == 'loss': + out = self.conv(batch_inputs) + return dict(loss=out) + elif mode == 'predict': + out = self.conv(batch_inputs) + 1 + return out + elif mode == 'tensor': + out = self.conv(batch_inputs) + 2 + return out + + +@MODELS.register_module() +class ToyTeacher(ToyStudent): + + def __init__(self): + super().__init__() + + +@MODELS.register_module() +class ToyDistillLoss(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, arg1, arg2): + return arg1 + arg2 diff --git a/tests/test_runners/test_distill_val_loop.py b/tests/test_runners/test_distill_val_loop.py index 2512cfd0..548a4277 100644 --- a/tests/test_runners/test_distill_val_loop.py +++ b/tests/test_runners/test_distill_val_loop.py @@ -124,5 +124,4 @@ class TestSingleTeacherDistillValLoop(TestCase): runner = Runner.from_cfg(cfg) runner.val() - self.assertIn('val_student/acc', runner.message_hub.log_scalars.keys()) - self.assertIn('val_teacher/acc', runner.message_hub.log_scalars.keys()) + self.assertIn('val/teacher.acc', runner.message_hub.log_scalars.keys())
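[Editor's note] To close, a schematic, self-contained sketch of the data flow that the new `SingleTeacherDistill` and `KLDivergence` pieces implement: teacher and student see the same batch, their logits are recorded, and a temperature-scaled KL term is added to the student's task loss. The toy models, `tau`, `loss_weight` and the `tau**2` scaling are illustrative assumptions, not the exact mmrazor implementation.

# Schematic sketch of knowledge-distillation training step; not mmrazor code.
import torch
import torch.nn.functional as F
from torch import nn

tau, loss_weight = 1.0, 5.0
student = nn.Linear(8, 10)
teacher = nn.Linear(8, 10)
teacher.eval()

inputs = torch.randn(4, 8)
labels = torch.randint(0, 10, (4,))

with torch.no_grad():  # teacher is not trainable here
    teacher_logits = teacher(inputs)
student_logits = student(inputs)

task_loss = F.cross_entropy(student_logits, labels)
kd_loss = loss_weight * (tau**2) * F.kl_div(
    F.log_softmax(student_logits / tau, dim=1),
    F.softmax(teacher_logits / tau, dim=1),
    reduction='batchmean')

(task_loss + kd_loss).backward()
print(float(task_loss), float(kd_loss))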