[Feature] Add Dsnas Algorithm (#226)

* [tmp] Update Dsnas

* [tmp] refactor arch_loss & flops_loss

* Update Dsnas & MMRAZOR_EVALUATOR:
1. finalized compute_loss & handle_grads in algorithm;
2. add MMRAZOR_EVALUATOR;
3. fix bugs.

* Update lr scheduler & fix a bug:
1. update param_scheduler & lr_scheduler for dsnas;
2. fix a bug of switching to finetune stage.

* remove old evaluators

* remove old evaluators

* update param_scheduler config

* merge dev-1.x into gy/estimator

* add flops_loss in Dsnas using ResourcesEstimator

* get resources before mutator.prepare_from_supernet

* delete unnecessary broadcast api from gml

* broadcast spec_modules_resources when estimating

* update early fix mechanism for Dsnas

* fix merge

* update units in estimator

* minor change

* fix data_preprocessor api

* add flops_loss_coef

* remove DsnasOptimWrapper

* fix bn eps and data_preprocessor

* fix bn weight decay bug

* add betas for mutator optimizer

* set diff_rank_seed=True for dsnas

* fix start_factor of lr when warm up

* remove .module in non-ddp mode

* add GlobalAveragePoolingWithDropout

* add UT for dsnas

* remove unnecessary channel adjustment for shufflenetv2

* update supernet configs

* delete unnecessary dropout

* delete unnecessary part with minor change on dsnas

* minor change on the flag of search stage

* update README and subnet configs

* add UT for OneHotMutableOP
Yang Gao 2022-09-29 16:48:47 +08:00 committed by GitHub
parent d07dee9887
commit 8d603d917e
18 changed files with 1187 additions and 30 deletions


@@ -0,0 +1,28 @@
norm_cfg = dict(type='BN', eps=0.01)
_STAGE_MUTABLE = dict(
type='mmrazor.OneHotMutableOP',
fix_threshold=0.3,
candidates=dict(
shuffle_3x3=dict(
type='ShuffleBlock', kernel_size=3, norm_cfg=norm_cfg),
shuffle_5x5=dict(
type='ShuffleBlock', kernel_size=5, norm_cfg=norm_cfg),
shuffle_7x7=dict(
type='ShuffleBlock', kernel_size=7, norm_cfg=norm_cfg),
shuffle_xception=dict(type='ShuffleXception', norm_cfg=norm_cfg)))
arch_setting = [
# Parameters to build layers. 3 parameters are needed to construct a
# layer, from left to right: channel, num_blocks, mutable_cfg.
[64, 4, _STAGE_MUTABLE],
[160, 4, _STAGE_MUTABLE],
[320, 8, _STAGE_MUTABLE],
[640, 4, _STAGE_MUTABLE]
]
nas_backbone = dict(
type='mmrazor.SearchableShuffleNetV2',
widen_factor=1.0,
arch_setting=arch_setting,
norm_cfg=norm_cfg)
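As a quick illustration of how the settings above are consumed (the snippet uses only names defined in this config and is not part of the file; it assumes it runs in the config's namespace, e.g. appended at the bottom of the file):

```python
# Illustration only: count the searchable blocks described by `arch_setting`
# and list the candidate ops every block chooses from.
total_blocks = sum(num_blocks for _, num_blocks, _ in arch_setting)
print(total_blocks)                        # 4 + 4 + 8 + 4 = 20 mutable blocks
print(list(_STAGE_MUTABLE['candidates']))  # shuffle_3x3 / 5x5 / 7x7 / xception
```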


@@ -0,0 +1,102 @@
# dataset settings
dataset_type = 'mmcls.ImageNet'
data_preprocessor = dict(
type='mmcls.ClsDataPreprocessor',
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
train_pipeline = [
dict(type='mmcls.LoadImageFromFile'),
dict(type='mmcls.RandomResizedCrop', scale=224),
dict(type='mmcls.RandomFlip', prob=0.5, direction='horizontal'),
dict(type='mmcls.PackClsInputs'),
]
test_pipeline = [
dict(type='mmcls.LoadImageFromFile'),
dict(type='mmcls.ResizeEdge', scale=256, edge='short'),
dict(type='mmcls.CenterCrop', crop_size=224),
dict(type='mmcls.PackClsInputs'),
]
train_dataloader = dict(
batch_size=128,
num_workers=4,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/train.txt',
data_prefix='train',
pipeline=train_pipeline),
sampler=dict(type='mmcls.DefaultSampler', shuffle=True),
persistent_workers=True,
)
val_dataloader = dict(
batch_size=128,
num_workers=4,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
pipeline=test_pipeline),
sampler=dict(type='mmcls.DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='mmcls.Accuracy', topk=(1, 5))
# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator
# optimizer
paramwise_cfg = dict(bias_decay_mult=0.0, norm_decay_mult=0.0)
optim_wrapper = dict(
constructor='mmrazor.SeparateOptimWrapperConstructor',
architecture=dict(
optimizer=dict(
type='mmcls.SGD', lr=0.5, momentum=0.9, weight_decay=4e-5),
paramwise_cfg=paramwise_cfg),
mutator=dict(
optimizer=dict(
type='mmcls.Adam', lr=0.001, weight_decay=0.0, betas=(0.5,
0.999))))
search_epochs = 85
# learning policy
param_scheduler = dict(
architecture=[
dict(
type='mmcls.LinearLR',
end=5,
start_factor=0.2,
by_epoch=True,
convert_to_iter_based=True),
dict(
type='mmcls.CosineAnnealingLR',
T_max=240,
begin=5,
end=search_epochs,
by_epoch=True,
convert_to_iter_based=True),
dict(
type='mmcls.CosineAnnealingLR',
T_max=160,
begin=search_epochs,
end=240,
eta_min=0.0,
by_epoch=True,
convert_to_iter_based=True)
],
mutator=[])
# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=240)
val_cfg = dict()
test_cfg = dict()
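The `param_scheduler` above chains a linear warm-up, a cosine schedule over the search stage, and a second cosine schedule for the finetune stage, while the mutator optimizer runs without a scheduler. A quick sanity check of the stage boundaries (illustration only, assuming it runs in the config's namespace):

```python
# Illustration only: the three architecture schedulers partition the 240 epochs
# into warm-up, search, and finetune phases.
warmup, search, finetune = param_scheduler['architecture']
assert warmup['end'] == 5                                        # linear warm-up
assert search['begin'] == 5 and search['end'] == search_epochs   # epochs 5-85: search
assert finetune['begin'] == search_epochs and \
    finetune['end'] == train_cfg['max_epochs']                   # epochs 85-240: finetune
```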


@@ -0,0 +1,20 @@
backbone.layers.0.0: shuffle_3x3
backbone.layers.0.1: shuffle_3x3
backbone.layers.0.2: shuffle_xception
backbone.layers.0.3: shuffle_3x3
backbone.layers.1.0: shuffle_xception
backbone.layers.1.1: shuffle_7x7
backbone.layers.1.2: shuffle_3x3
backbone.layers.1.3: shuffle_3x3
backbone.layers.2.0: shuffle_xception
backbone.layers.2.1: shuffle_xception
backbone.layers.2.2: shuffle_7x7
backbone.layers.2.3: shuffle_xception
backbone.layers.2.4: shuffle_xception
backbone.layers.2.5: shuffle_xception
backbone.layers.2.6: shuffle_7x7
backbone.layers.2.7: shuffle_3x3
backbone.layers.3.0: shuffle_3x3
backbone.layers.3.1: shuffle_xception
backbone.layers.3.2: shuffle_xception
backbone.layers.3.3: shuffle_3x3


@@ -0,0 +1,43 @@
# DSNAS
> [DSNAS: Direct Neural Architecture Search without Parameter Retraining](https://arxiv.org/abs/2002.09128.pdf)
<!-- [ALGORITHM] -->
## Abstract
Most existing NAS methods require two-stage parameter optimization.
However, performance of the same architecture in the two stages correlates poorly.
Based on this observation, DSNAS proposes a task-specific end-to-end differentiable NAS framework that simultaneously optimizes architecture and parameters with a low-biased Monte Carlo estimate. Child networks derived from DSNAS can be deployed directly without parameter retraining.
![pipeline](/docs/en/imgs/model_zoo/dsnas/pipeline.jpg)
## Results and models
### Supernet
| Dataset | Params(M) | FLOPs (G) | Top-1 Acc (%) | Top-5 Acc (%) | Config | Download | Remarks |
| :------: | :-------: | :-------: | :-----------: | :-----------: | :---------------------------------------: | :----------------------: | :--------------: |
| ImageNet | 3.33 | 0.299 | 73.56 | 91.24 | [config](./dsnas_supernet_8xb128_in1k.py) | [model](<>) \| [log](<>) | MMRazor searched |
**Note**:
1. There **might be (not in all cases)** some small differences in our experiments, made to stay consistent with other OpenMMLab repos. For example, images are normalized in data preprocessing, resizing uses cv2 rather than PIL during training, and dropout is not used in the network. **Please refer to the corresponding config for details.**
2. We converted the official searched checkpoint DSNASsearch240.pth into mmrazor-style and evaluated it with pytorch1.8_cuda11.0: Top-1 is 74.1 and Top-5 is 91.51.
3. The ShuffleNetV2 implementation in the official DSNAS differs from OpenMMLab's, and we follow the structure design in OpenMMLab. Note that with the original ShuffleNetV2 design in the official DSNAS, Top-1 is 73.92 and Top-5 is 91.59.
4. The finetune stage in our implementation corresponds to the 'search-from-search' stage mentioned in the official DSNAS.
5. We obtain params and FLOPs using `mmrazor.ResourceEstimator`, which may differ from the numbers in the original repo.
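
A rough sketch of how note 5 is applied (the `supernet` below is a placeholder for any built model, and the available result keys may vary across mmrazor versions):

```python
# Sketch only: measure resources the same way the Dsnas algorithm in this repo
# does, i.e. with a ResourceEstimator built from the TASK_UTILS registry.
from mmrazor.registry import TASK_UTILS

estimator = TASK_UTILS.build(dict(type='mmrazor.ResourceEstimator'))
results = estimator.estimate(supernet)  # `supernet`: a built (sub)model, placeholder here
print(results['flops'])  # FLOPs; params are reported as well (exact key depends on the estimator)
```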
## Citation
```latex
@inproceedings{hu2020dsnas,
title={Dsnas: Direct neural architecture search without parameter retraining},
author={Hu, Shoukang and Xie, Sirui and Zheng, Hehui and Liu, Chunxiao and Shi, Jianping and Liu, Xunying and Lin, Dahua},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={12084--12092},
year={2020}
}
```


@@ -0,0 +1,29 @@
_base_ = ['./dsnas_supernet_8xb128_in1k.py']
# NOTE: Replace this with the mutable_cfg searched by yourself.
fix_subnet = {
'backbone.layers.0.0': 'shuffle_3x3',
'backbone.layers.0.1': 'shuffle_7x7',
'backbone.layers.0.2': 'shuffle_3x3',
'backbone.layers.0.3': 'shuffle_5x5',
'backbone.layers.1.0': 'shuffle_3x3',
'backbone.layers.1.1': 'shuffle_3x3',
'backbone.layers.1.2': 'shuffle_3x3',
'backbone.layers.1.3': 'shuffle_7x7',
'backbone.layers.2.0': 'shuffle_xception',
'backbone.layers.2.1': 'shuffle_3x3',
'backbone.layers.2.2': 'shuffle_3x3',
'backbone.layers.2.3': 'shuffle_5x5',
'backbone.layers.2.4': 'shuffle_3x3',
'backbone.layers.2.5': 'shuffle_5x5',
'backbone.layers.2.6': 'shuffle_7x7',
'backbone.layers.2.7': 'shuffle_7x7',
'backbone.layers.3.0': 'shuffle_xception',
'backbone.layers.3.1': 'shuffle_3x3',
'backbone.layers.3.2': 'shuffle_7x7',
'backbone.layers.3.3': 'shuffle_3x3',
}
model = dict(fix_subnet=fix_subnet)
find_unused_parameters = False


@@ -0,0 +1,36 @@
_base_ = [
'mmrazor::_base_/settings/imagenet_bs1024_dsnas.py',
'mmrazor::_base_/nas_backbones/dsnas_shufflenet_supernet.py',
'mmcls::_base_/default_runtime.py',
]
# model
model = dict(
type='mmrazor.Dsnas',
architecture=dict(
type='ImageClassifier',
data_preprocessor=_base_.data_preprocessor,
backbone=_base_.nas_backbone,
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=1024,
loss=dict(
type='LabelSmoothLoss',
num_classes=1000,
label_smooth_val=0.1,
mode='original',
loss_weight=1.0),
topk=(1, 5))),
mutator=dict(type='mmrazor.DiffModuleMutator'),
pretrain_epochs=15,
finetune_epochs=_base_.search_epochs,
)
model_wrapper_cfg = dict(
type='mmrazor.DsnasDDP',
broadcast_buffers=False,
find_unused_parameters=True)
randomness = dict(seed=48, diff_rank_seed=True)
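To make the staging in this config concrete: with `pretrain_epochs=15`, `finetune_epochs=search_epochs=85`, and `max_epochs=240` (from the base settings), training passes through the three DSNAS stages described in the `Dsnas` docstring added in this commit. A small standalone sketch with the values copied from these configs:

```python
# Illustration only; mirrors Dsnas.need_update_mutator and the finetune switch.
pretrain_epochs, finetune_epochs, max_epochs = 15, 85, 240

def stage(cur_epoch: int) -> str:
    if cur_epoch < pretrain_epochs:
        return 'supernet pretraining (mutator not updated)'
    if cur_epoch < finetune_epochs:
        return 'supernet training with mutator updates'
    return 'subnet finetuning (arch params broadcast and fixed)'

for epoch in (0, 20, 100):
    print(epoch, stage(epoch))
```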


@@ -10,6 +10,6 @@ __all__ = [
'SeparateOptimWrapperConstructor', 'DumpSubnetHook',
'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop',
'EstimateResourcesHook'
'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'EstimateResourcesHook',
'SelfDistillValLoop'
]


@@ -3,12 +3,13 @@ from .base import BaseAlgorithm
from .distill import (DAFLDataFreeDistillation, DataFreeDistillation,
FpnTeacherDistill, OverhaulFeatureDistillation,
SelfDistill, SingleTeacherDistill)
from .nas import SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP
from .nas import SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP, Dsnas, DsnasDDP
from .pruning import SlimmableNetwork, SlimmableNetworkDDP
__all__ = [
'SingleTeacherDistill', 'BaseAlgorithm', 'FpnTeacherDistill', 'SPOS',
'SlimmableNetwork', 'SlimmableNetworkDDP', 'AutoSlim', 'AutoSlimDDP',
'Darts', 'DartsDDP', 'SelfDistill', 'DataFreeDistillation',
'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation'
'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation', 'Dsnas',
'DsnasDDP'
]


@@ -1,6 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .autoslim import AutoSlim, AutoSlimDDP
from .darts import Darts, DartsDDP
from .dsnas import Dsnas, DsnasDDP
from .spos import SPOS
__all__ = ['SPOS', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP']
__all__ = [
'SPOS', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP', 'Dsnas', 'DsnasDDP'
]


@@ -0,0 +1,347 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
from typing import Any, Dict, List, Optional, Union
import torch
import torch.distributed as dist
import torch.nn.functional as F
from mmengine.dist import get_dist_info
from mmengine.logging import MessageHub
from mmengine.model import BaseModel, MMDistributedDataParallel
from mmengine.optim import OptimWrapper, OptimWrapperDict
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm
from mmrazor.models.mutables.base_mutable import BaseMutable
from mmrazor.models.mutators import DiffModuleMutator
from mmrazor.models.utils import add_prefix
from mmrazor.registry import MODEL_WRAPPERS, MODELS, TASK_UTILS
from mmrazor.structures import load_fix_subnet
from mmrazor.utils import FixMutable
from ..base import BaseAlgorithm
@MODELS.register_module()
class Dsnas(BaseAlgorithm):
"""Implementation of `DSNAS <https://arxiv.org/abs/2002.09128>`_
Args:
architecture (dict|:obj:`BaseModel`): The config of :class:`BaseModel`
or built model. Corresponding to supernet in NAS algorithm.
mutator (dict|:obj:`DiffModuleMutator`): The config of
:class:`DiffModuleMutator` or built mutator.
fix_subnet (str | dict | :obj:`FixSubnet`): The path of yaml file or
loaded dict or built :obj:`FixSubnet`.
pretrain_epochs (int): Num of epochs for supernet pretraining.
finetune_epochs (int): Num of epochs for subnet finetuning.
flops_constraints (float): FLOPs constraint for judging whether to
backward the flops loss or not. Defaults to 300.0 (M).
estimator_cfg (Dict[str, Any], optional): Used for building a resource
estimator. Defaults to None.
norm_training (bool): Whether to set norm layers to training mode,
namely, not freeze running stats (mean and var). Note: Effect on
Batch Norm and its variants only. Defaults to False.
data_preprocessor (dict, optional): The pre-process config of
:class:`BaseDataPreprocessor`. Defaults to None.
init_cfg (dict): Init config for ``BaseModule``.
Note:
Dsnas doesn't require retraining. It has 3 stages in searching:
1. `cur_epoch` < `pretrain_epochs` refers to supernet pretraining.
2. `pretrain_epochs` <= `cur_epoch` < `finetune_epochs` refers to
normal supernet training while mutator is updated.
3. `cur_epoch` >= `finetune_epochs` refers to subnet finetuning.
"""
def __init__(self,
architecture: Union[BaseModel, Dict],
mutator: Optional[Union[DiffModuleMutator, Dict]] = None,
fix_subnet: Optional[FixMutable] = None,
pretrain_epochs: int = 0,
finetune_epochs: int = 80,
flops_constraints: float = 300.0,
estimator_cfg: Dict[str, Any] = None,
norm_training: bool = False,
data_preprocessor: Optional[Union[dict, nn.Module]] = None,
init_cfg: Optional[dict] = None,
**kwargs):
super().__init__(architecture, data_preprocessor, **kwargs)
if estimator_cfg is None:
estimator_cfg = dict(type='mmrazor.ResourceEstimator')
self.estimator = TASK_UTILS.build(estimator_cfg)
if fix_subnet:
# Avoid circular import
from mmrazor.structures import load_fix_subnet
# According to fix_subnet, delete the unchosen part of supernet
load_fix_subnet(self.architecture, fix_subnet)
self.is_supernet = False
else:
assert mutator is not None, \
'mutator cannot be None when fix_subnet is None.'
if isinstance(mutator, DiffModuleMutator):
self.mutator = mutator
elif isinstance(mutator, dict):
self.mutator = MODELS.build(mutator)
else:
raise TypeError('mutator should be a `dict` or '
f'`DiffModuleMutator` instance, but got '
f'{type(mutator)}')
self.mutable_module_resources = self._get_module_resources()
# Mutator is an essential component of the NAS algorithm. It
# provides some APIs commonly used by NAS.
# Before using it, you must do some preparations according to
# the supernet.
self.mutator.prepare_from_supernet(self.architecture)
self.is_supernet = True
self.search_space_name_list = list(
self.mutator.name2mutable.keys())
self.norm_training = norm_training
self.pretrain_epochs = pretrain_epochs
self.finetune_epochs = finetune_epochs
if pretrain_epochs >= finetune_epochs:
raise ValueError(f'Pretrain stage (optional) must be done before '
f'finetuning stage. Got `{pretrain_epochs}` >= '
f'`{finetune_epochs}`.')
self.flops_loss_coef = 1e-2
self.flops_constraints = flops_constraints
_, self.world_size = get_dist_info()
def search_subnet(self):
"""Search subnet by mutator."""
# Avoid circular import
from mmrazor.structures import export_fix_subnet
subnet = self.mutator.sample_choices()
self.mutator.set_choices(subnet)
return export_fix_subnet(self)
def fix_subnet(self):
"""Fix subnet when finetuning."""
subnet = self.mutator.sample_choices()
self.mutator.set_choices(subnet)
for module in self.architecture.modules():
if isinstance(module, BaseMutable):
if not module.is_fixed:
module.fix_chosen(module.current_choice)
self.is_supernet = False
def train(self, mode=True):
"""Convert the model into eval mode while keep normalization layer
unfreezed."""
super().train(mode)
if self.norm_training and not mode:
for module in self.architecture.modules():
if isinstance(module, _BatchNorm):
module.training = True
def train_step(self, data: List[dict],
optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
"""The iteration step during training.
Args:
data (List[dict]): The output of dataloader.
optim_wrapper (OptimWrapper | OptimWrapperDict): The optimizer
wrapper(s) used to update the model parameters.
"""
if isinstance(optim_wrapper, OptimWrapperDict):
log_vars = dict()
self.message_hub = MessageHub.get_current_instance()
cur_epoch = self.message_hub.get_info('epoch')
need_update_mutator = self.need_update_mutator(cur_epoch)
# TODO process the input
if cur_epoch == self.finetune_epochs and self.is_supernet:
# synchronize arch params to start the finetune stage.
for k, v in self.mutator.arch_params.items():
dist.broadcast(v, src=0)
self.fix_subnet()
# 1. update architecture
with optim_wrapper['architecture'].optim_context(self):
pseudo_data = self.data_preprocessor(data, True)
supernet_batch_inputs = pseudo_data['inputs']
supernet_data_samples = pseudo_data['data_samples']
supernet_loss = self(
supernet_batch_inputs, supernet_data_samples, mode='loss')
supernet_losses, supernet_log_vars = self.parse_losses(
supernet_loss)
optim_wrapper['architecture'].backward(
supernet_losses, retain_graph=need_update_mutator)
optim_wrapper['architecture'].step()
optim_wrapper['architecture'].zero_grad()
log_vars.update(add_prefix(supernet_log_vars, 'supernet'))
# 2. update mutator
if need_update_mutator:
with optim_wrapper['mutator'].optim_context(self):
mutator_loss = self.compute_mutator_loss()
mutator_losses, mutator_log_vars = \
self.parse_losses(mutator_loss)
optim_wrapper['mutator'].update_params(mutator_losses)
log_vars.update(add_prefix(mutator_log_vars, 'mutator'))
# handle the grad of arch params & weights
self.handle_grads()
else:
# Enable automatic mixed precision training context.
with optim_wrapper.optim_context(self):
pseudo_data = self.data_preprocessor(data, True)
batch_inputs = pseudo_data['inputs']
data_samples = pseudo_data['data_samples']
losses = self(batch_inputs, data_samples, mode='loss')
parsed_losses, log_vars = self.parse_losses(losses)
optim_wrapper.update_params(parsed_losses)
return log_vars
def _get_module_resources(self):
"""Get resources of spec modules."""
spec_modules = []
for name, module in self.architecture.named_modules():
if isinstance(module, BaseMutable):
for choice in module.choices:
spec_modules.append(name + '._candidates.' + choice)
mutable_module_resources = self.estimator.estimate_separation_modules(
self.architecture, dict(spec_modules=spec_modules))
return mutable_module_resources
def need_update_mutator(self, cur_epoch: int) -> bool:
"""Whether to update mutator."""
if cur_epoch >= self.pretrain_epochs and \
cur_epoch < self.finetune_epochs:
return True
return False
def compute_mutator_loss(self) -> Dict[str, torch.Tensor]:
"""Compute mutator loss.
In this method, arch_loss & flops_loss[optional] are computed
by traversing arch_weights & probs in search groups.
Returns:
Dict: Loss of the mutator.
"""
arch_loss = 0.0
flops_loss = 0.0
for name, module in self.architecture.named_modules():
if isinstance(module, BaseMutable):
k = str(self.search_space_name_list.index(name))
probs = F.softmax(self.mutator.arch_params[k], -1)
arch_loss += torch.log(
(module.arch_weights * probs).sum(-1)).sum()
# get the index of op with max arch weights.
index = (module.arch_weights == 1).nonzero().item()
_module_key = name + '._candidates.' + module.choices[index]
flops_loss += probs[index] * \
self.mutable_module_resources[_module_key]['flops']
mutator_loss = dict(arch_loss=arch_loss / self.world_size)
copied_model = copy.deepcopy(self)
fix_mutable = copied_model.search_subnet()
load_fix_subnet(copied_model, fix_mutable)
subnet_flops = self.estimator.estimate(copied_model)['flops']
if subnet_flops >= self.flops_constraints:
mutator_loss['flops_loss'] = \
(flops_loss * self.flops_loss_coef) / self.world_size
return mutator_loss
def handle_grads(self):
"""Handle grads of arch params & arch weights."""
for name, module in self.architecture.named_modules():
if isinstance(module, BaseMutable):
k = str(self.search_space_name_list.index(name))
self.mutator.arch_params[k].grad.data.mul_(
module.arch_weights.grad.data.sum())
module.arch_weights.grad.zero_()
@MODEL_WRAPPERS.register_module()
class DsnasDDP(MMDistributedDataParallel):
def __init__(self,
*,
device_ids: Optional[Union[List, int, torch.device]] = None,
**kwargs) -> None:
if device_ids is None:
if os.environ.get('LOCAL_RANK') is not None:
device_ids = [int(os.environ['LOCAL_RANK'])]
super().__init__(device_ids=device_ids, **kwargs)
def train_step(self, data: List[dict],
optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
"""The iteration step during training.
This method defines an iteration step during training, except for the
back propagation and optimizer updating, which are done in an optimizer
hook. Note that in some complicated cases or models, the whole process,
including back propagation and optimizer updating, is also defined in
this method, such as GAN.
"""
if isinstance(optim_wrapper, OptimWrapperDict):
log_vars = dict()
self.message_hub = MessageHub.get_current_instance()
cur_epoch = self.message_hub.get_info('epoch')
need_update_mutator = self.module.need_update_mutator(cur_epoch)
# TODO process the input
if cur_epoch == self.module.finetune_epochs and \
self.module.is_supernet:
# synchronize arch params to start the finetune stage.
for k, v in self.module.mutator.arch_params.items():
dist.broadcast(v, src=0)
self.module.fix_subnet()
# 1. update architecture
with optim_wrapper['architecture'].optim_context(self):
pseudo_data = self.module.data_preprocessor(data, True)
supernet_batch_inputs = pseudo_data['inputs']
supernet_data_samples = pseudo_data['data_samples']
supernet_loss = self(
supernet_batch_inputs, supernet_data_samples, mode='loss')
supernet_losses, supernet_log_vars = self.module.parse_losses(
supernet_loss)
optim_wrapper['architecture'].backward(
supernet_losses, retain_graph=need_update_mutator)
optim_wrapper['architecture'].step()
optim_wrapper['architecture'].zero_grad()
log_vars.update(add_prefix(supernet_log_vars, 'supernet'))
# 2. update mutator
if need_update_mutator:
with optim_wrapper['mutator'].optim_context(self):
mutator_loss = self.module.compute_mutator_loss()
mutator_losses, mutator_log_vars = \
self.module.parse_losses(mutator_loss)
optim_wrapper['mutator'].update_params(mutator_losses)
log_vars.update(add_prefix(mutator_log_vars, 'mutator'))
# handle the grad of arch params & weights
self.module.handle_grads()
else:
# Enable automatic mixed precision training context.
with optim_wrapper.optim_context(self):
pseudo_data = self.module.data_preprocessor(data, True)
batch_inputs = pseudo_data['inputs']
data_samples = pseudo_data['data_samples']
losses = self(batch_inputs, data_samples, mode='loss')
parsed_losses, log_vars = self.module.parse_losses(losses)
optim_wrapper.update_params(parsed_losses)
return log_vars


@@ -3,12 +3,13 @@ from .derived_mutable import DerivedMutable
from .mutable_channel import (MutableChannel, OneShotMutableChannel,
SlimmableMutableChannel)
from .mutable_module import (DiffChoiceRoute, DiffMutableModule, DiffMutableOP,
OneShotMutableModule, OneShotMutableOP)
OneHotMutableOP, OneShotMutableModule,
OneShotMutableOP)
from .mutable_value import MutableValue, OneShotMutableValue
__all__ = [
'OneShotMutableOP', 'OneShotMutableModule', 'DiffMutableOP',
'DiffChoiceRoute', 'DiffMutableModule', 'OneShotMutableChannel',
'SlimmableMutableChannel', 'MutableChannel', 'DerivedMutable',
'MutableValue', 'OneShotMutableValue'
'MutableValue', 'OneShotMutableValue', 'OneHotMutableOP'
]


@@ -1,10 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .diff_mutable_module import (DiffChoiceRoute, DiffMutableModule,
DiffMutableOP)
DiffMutableOP, OneHotMutableOP)
from .mutable_module import MutableModule
from .one_shot_mutable_module import OneShotMutableModule, OneShotMutableOP
__all__ = [
'DiffMutableModule', 'DiffMutableOP', 'DiffChoiceRoute',
'OneShotMutableOP', 'OneShotMutableModule', 'MutableModule'
'OneShotMutableOP', 'OneShotMutableModule', 'MutableModule',
'OneHotMutableOP'
]


@@ -37,8 +37,9 @@ class DiffMutableModule(MutableModule[CHOICE_TYPE, CHOSEN_TYPE]):
def forward(self,
x: Any,
arch_param: Optional[nn.Parameter] = None) -> Any:
"""Calls either :func:`forward_fixed` or :func:`forward_choice`
depending on whether :func:`is_fixed` is ``True``.
"""Calls either :func:`forward_fixed` or :func:`forward_arch_param`
depending on whether :func:`is_fixed` is ``True`` and whether
:func:`arch_param` is None.
To reduce the coupling between `Mutable` and `Mutator`, the
`arch_param` is generated by the `Mutator` and is passed to the
@@ -52,6 +53,9 @@ class DiffMutableModule(MutableModule[CHOICE_TYPE, CHOSEN_TYPE]):
x (Any): input data for forward computation.
arch_param (nn.Parameter, optional): the architecture parameters
for ``DiffMutableModule``.
Returns:
Any: the result of forward
"""
if self.is_fixed:
return self.forward_fixed(x)
@@ -97,6 +101,10 @@ class DiffMutableOP(DiffMutableModule[str, str]):
Args:
candidates (dict[str, dict]): the configs for the candidate
operations.
fix_threshold (float): The threshold that determines whether to fix
the choice of current module as the op with the maximum `probs`.
It happens when the maximum prob exceeds the second-largest prob
by at least `fix_threshold`. Defaults to 1.0.
module_kwargs (dict[str, dict], optional): Module initialization named
arguments. Defaults to None.
alias (str, optional): alias of the `MUTABLE`.
@@ -109,6 +117,7 @@ class DiffMutableOP(DiffMutableModule[str, str]):
def __init__(
self,
candidates: Dict[str, Dict],
fix_threshold: float = 1.0,
module_kwargs: Optional[Dict[str, Dict]] = None,
alias: Optional[str] = None,
init_cfg: Optional[Dict] = None,
@@ -120,6 +129,10 @@ class DiffMutableOP(DiffMutableModule[str, str]):
f'but got: {len(candidates)}'
self._is_fixed = False
if fix_threshold < 0 or fix_threshold > 1.0:
raise ValueError(
f'The fix_threshold should be in [0, 1]. Got {fix_threshold}.')
self.fix_threshold = fix_threshold
self._candidates = self._build_ops(candidates, self.module_kwargs)
@staticmethod
@@ -242,6 +255,94 @@ class DiffMutableOP(DiffMutableModule[str, str]):
return list(self._candidates.keys())
@MODELS.register_module()
class OneHotMutableOP(DiffMutableOP):
"""A type of ``MUTABLES`` for one-hot sample based architecture search,
such as DSNAS. Search the best module by learnable parameters `arch_param`.
Args:
candidates (dict[str, dict]): the configs for the candidate
operations.
module_kwargs (dict[str, dict], optional): Module initialization named
arguments. Defaults to None.
alias (str, optional): alias of the `MUTABLE`.
init_cfg (dict, optional): initialization configuration dict for
``BaseModule``. OpenMMLab has implemented 6 initializers including
`Constant`, `Xavier`, `Normal`, `Uniform`, `Kaiming`,
and `Pretrained`.
"""
def sample_weights(self,
arch_param: nn.Parameter,
probs: torch.Tensor,
random_sample: bool = False) -> Tensor:
"""Use one-hot distributions to sample the arch weights based on the
arch params.
Args:
arch_param (nn.Parameter): architecture parameters for
`DiffMutableModule`.
probs (Tensor): the probs of choice.
random_sample (bool): Whether to randomly sample arch weights or not.
Defaults to False.
Returns:
Tensor: Sampled one-hot arch weights.
"""
import torch.distributions as D
if random_sample:
uni = torch.ones_like(arch_param)
m = D.one_hot_categorical.OneHotCategorical(uni)
else:
m = D.one_hot_categorical.OneHotCategorical(probs=probs)
return m.sample()
def forward_arch_param(self,
x: Any,
arch_param: Optional[nn.Parameter] = None
) -> Tensor:
"""Forward with architecture parameters.
Args:
x (Any): x could be a Torch.tensor or a tuple of
Torch.tensor, containing input data for forward computation.
arch_param (str, optional): architecture parameters for
`DiffMutableModule`.
Returns:
Tensor: the result of forward with ``arch_param``.
"""
if arch_param is None:
return self.forward_all(x)
else:
# compute the probs of choice
probs = self.compute_arch_probs(arch_param=arch_param)
if not self.is_fixed:
self.arch_weights = self.sample_weights(arch_param, probs)
sorted_param = torch.topk(probs, 2)
index = (
sorted_param[0][0] - sorted_param[0][1] >=
self.fix_threshold)
if index:
self.fix_chosen(self.choices[index])
if self.is_fixed:
index = self.choices.index(self._chosen[0])
self.arch_weights.data.zero_()
self.arch_weights.data[index].fill_(1.0)
self.arch_weights.requires_grad_()
# forward based on self.arch_weights
outputs = list()
for prob, module in zip(self.arch_weights,
self._candidates.values()):
if prob > 0.:
outputs.append(prob * module(x))
return sum(outputs)
@MODELS.register_module()
class DiffChoiceRoute(DiffMutableModule[str, List[str]]):
"""A type of ``MUTABLES`` for Neural Architecture Search, which can select


@@ -88,8 +88,8 @@ class DiffModuleMutator(ModuleMutator):
choices = dict()
for group_id, mutables in self.search_groups.items():
arch_parm = self.arch_params[str(group_id)]
choice = mutables[0].sample_choice(arch_parm)
arch_param = self.arch_params[str(group_id)]
choice = mutables[0].sample_choice(arch_param)
choices[group_id] = choice
return choices


@@ -54,6 +54,24 @@ class ModuleMutator(BaseMutator[MUTABLE_TYPE]):
"""
self._build_search_groups(supernet)
@property
def name2mutable(self) -> Dict[str, MUTABLE_TYPE]:
"""Search space of supernet.
Note:
To get the mapping: module name to mutable.
Raises:
RuntimeError: Called before search space has been parsed.
Returns:
Dict[str, MUTABLE_TYPE]: The name2mutable dict.
"""
if self._name2mutable is None:
raise RuntimeError(
'Call `prepare_from_supernet` before access name2mutable!')
return self._name2mutable
@property
def search_groups(self) -> Dict[int, List[MUTABLE_TYPE]]:
"""Search group of supernet.
@@ -80,6 +98,8 @@ class ModuleMutator(BaseMutator[MUTABLE_TYPE]):
for name, module in supernet.named_modules():
if isinstance(module, self.mutable_class_type):
name2mutable[name] = module
self._name2mutable = name2mutable
return name2mutable
def _build_alias_names_mapping(self,
@@ -121,7 +141,7 @@ class ModuleMutator(BaseMutator[MUTABLE_TYPE]):
>>> import torch
>>> from mmrazor.models.mutables.diff_mutable import DiffMutableOP
>>> # Assume that a toy model consists of three mutabels
>>> # Assume that a toy model consists of three mutables
>>> # whose name are op1,op2,op3. The corresponding
>>> # alias names of the three mutables are a1, a1, a2.
>>> model = ToyModel()


@@ -50,21 +50,22 @@ def load_fix_subnet(model: nn.Module,
# In the corresponding mutable, it will check whether the `chosen`
# format is correct.
if isinstance(module, BaseMutable):
if getattr(module, 'alias', None):
alias = module.alias
assert alias in fix_mutable, \
f'The alias {alias} is not in fix_modules, ' \
'please check your `fix_mutable`.'
chosen = fix_mutable.get(alias, None)
else:
mutable_name = name.lstrip(prefix)
if mutable_name not in fix_mutable and \
not isinstance(module, DerivedMutable):
raise RuntimeError(
f'The module name {mutable_name} is not in '
'fix_mutable, please check your `fix_mutable`.')
chosen = fix_mutable.get(mutable_name, None)
module.fix_chosen(chosen)
if not module.is_fixed:
if getattr(module, 'alias', None):
alias = module.alias
assert alias in fix_mutable, \
f'The alias {alias} is not in fix_modules, ' \
'please check your `fix_mutable`.'
chosen = fix_mutable.get(alias, None)
else:
mutable_name = name.lstrip(prefix)
if mutable_name not in fix_mutable and \
not isinstance(module, DerivedMutable):
raise RuntimeError(
f'The module name {mutable_name} is not in '
'fix_mutable, please check your `fix_mutable`.')
chosen = fix_mutable.get(mutable_name, None)
module.fix_chosen(chosen)
# convert dynamic op to static op
_dynamic_to_static(model)
@@ -89,7 +90,6 @@ def export_fix_subnet(model: nn.Module,
if isinstance(module, DerivedMutable) and not dump_derived_mutable:
continue
assert not module.is_fixed
if module.alias:
fix_subnet[module.alias] = module.dump_chosen()
else:


@@ -0,0 +1,222 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
from unittest import TestCase
from unittest.mock import patch
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from mmcls.structures import ClsDataSample
from mmengine.model import BaseModel
from mmengine.optim import build_optim_wrapper
from mmengine.optim.optimizer import OptimWrapper, OptimWrapperDict
from torch import Tensor
from torch.optim import SGD
from mmrazor.models import DiffModuleMutator, Dsnas, OneHotMutableOP
from mmrazor.models.algorithms.nas.dsnas import DsnasDDP
from mmrazor.registry import MODELS
MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True)
MODELS.register_module(name='torchMaxPool2d', module=nn.MaxPool2d, force=True)
MODELS.register_module(name='torchAvgPool2d', module=nn.AvgPool2d, force=True)
@MODELS.register_module()
class ToyDiffModule(BaseModel):
def __init__(self, data_preprocessor=None):
super().__init__(data_preprocessor=data_preprocessor, init_cfg=None)
self.candidates = dict(
torch_conv2d_3x3=dict(
type='torchConv2d',
kernel_size=3,
padding=1,
),
torch_conv2d_5x5=dict(
type='torchConv2d',
kernel_size=5,
padding=2,
),
torch_conv2d_7x7=dict(
type='torchConv2d',
kernel_size=7,
padding=3,
),
)
module_kwargs = dict(in_channels=3, out_channels=8, stride=1)
self.mutable = OneHotMutableOP(
candidates=self.candidates, module_kwargs=module_kwargs)
self.bn = nn.BatchNorm2d(8)
def forward(self, batch_inputs, data_samples=None, mode='tensor'):
if mode == 'loss':
out = self.bn(self.mutable(batch_inputs))
return dict(loss=out)
elif mode == 'predict':
out = self.bn(self.mutable(batch_inputs)) + 1
return out
elif mode == 'tensor':
out = self.bn(self.mutable(batch_inputs)) + 2
return out
class TestDsnas(TestCase):
def setUp(self) -> None:
self.device: str = 'cpu'
OPTIMIZER_CFG = dict(
type='SGD',
lr=0.5,
momentum=0.9,
nesterov=True,
weight_decay=0.0001)
self.OPTIM_WRAPPER_CFG = dict(optimizer=OPTIMIZER_CFG)
def test_init(self) -> None:
# initiate dsnas when `norm_training` is True.
model = ToyDiffModule()
mutator = DiffModuleMutator()
algo = Dsnas(architecture=model, mutator=mutator, norm_training=True)
algo.eval()
self.assertTrue(model.bn.training)
# initiate Dsnas with built mutator
model = ToyDiffModule()
mutator = DiffModuleMutator()
algo = Dsnas(model, mutator)
self.assertIs(algo.mutator, mutator)
# initiate Dsnas with unbuilt mutator
mutator = dict(type='DiffModuleMutator')
algo = Dsnas(model, mutator)
self.assertIsInstance(algo.mutator, DiffModuleMutator)
# initiate Dsnas when `fix_subnet` is not None
fix_subnet = {'mutable': 'torch_conv2d_5x5'}
algo = Dsnas(model, mutator, fix_subnet=fix_subnet)
self.assertEqual(algo.architecture.mutable.num_choices, 1)
# initiate Dsnas with error type `mutator`
with self.assertRaisesRegex(TypeError, 'mutator should be'):
Dsnas(model, model)
def test_forward_loss(self) -> None:
inputs = torch.randn(1, 3, 8, 8)
model = ToyDiffModule()
# supernet
mutator = DiffModuleMutator()
mutator.prepare_from_supernet(model)
algo = Dsnas(model, mutator)
loss = algo(inputs, mode='loss')
self.assertIsInstance(loss, dict)
# subnet
fix_subnet = {'mutable': 'torch_conv2d_5x5'}
algo = Dsnas(model, fix_subnet=fix_subnet)
loss = algo(inputs, mode='loss')
self.assertIsInstance(loss, dict)
def _prepare_fake_data(self):
imgs = torch.randn(16, 3, 224, 224).to(self.device)
data_samples = [
ClsDataSample().set_gt_label(torch.randint(0, 1000,
(16, ))).to(self.device)
]
return {'inputs': imgs, 'data_samples': data_samples}
def test_search_subnet(self) -> None:
model = ToyDiffModule()
mutator = DiffModuleMutator()
mutator.prepare_from_supernet(model)
algo = Dsnas(model, mutator)
subnet = algo.search_subnet()
self.assertIsInstance(subnet, dict)
@patch('mmengine.logging.message_hub.MessageHub.get_info')
def test_dsnas_train_step(self, mock_get_info) -> None:
model = ToyDiffModule()
mutator = DiffModuleMutator()
mutator.prepare_from_supernet(model)
mock_get_info.return_value = 2
algo = Dsnas(model, mutator)
data = self._prepare_fake_data()
optim_wrapper = build_optim_wrapper(algo, self.OPTIM_WRAPPER_CFG)
loss = algo.train_step(data, optim_wrapper)
self.assertTrue(isinstance(loss['loss'], Tensor))
algo = Dsnas(model, mutator)
optim_wrapper_dict = OptimWrapperDict(
architecture=OptimWrapper(SGD(model.parameters(), lr=0.1)),
mutator=OptimWrapper(SGD(model.parameters(), lr=0.01)))
loss = algo.train_step(data, optim_wrapper_dict)
self.assertIsNotNone(loss)
class TestDsnasDDP(TestDsnas):
@classmethod
def setUpClass(cls) -> None:
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12345'
# initialize the process group
if torch.cuda.is_available():
backend = 'nccl'
cls.device = 'cuda'
else:
backend = 'gloo'
dist.init_process_group(backend, rank=0, world_size=1)
def prepare_model(self, device_ids=None) -> Dsnas:
model = ToyDiffModule().to(self.device)
mutator = DiffModuleMutator().to(self.device)
mutator.prepare_from_supernet(model)
algo = Dsnas(model, mutator)
return DsnasDDP(
module=algo, find_unused_parameters=True, device_ids=device_ids)
@classmethod
def tearDownClass(cls) -> None:
dist.destroy_process_group()
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='cuda device is not available')
def test_init(self) -> None:
ddp_model = self.prepare_model()
self.assertIsInstance(ddp_model, DsnasDDP)
@patch('mmengine.logging.message_hub.MessageHub.get_info')
def test_dsnasddp_train_step(self, mock_get_info) -> None:
model = ToyDiffModule()
mutator = DiffModuleMutator()
mutator.prepare_from_supernet(model)
mock_get_info.return_value = 2
algo = Dsnas(model, mutator)
ddp_model = DsnasDDP(module=algo, find_unused_parameters=True)
data = self._prepare_fake_data()
optim_wrapper = build_optim_wrapper(ddp_model, self.OPTIM_WRAPPER_CFG)
loss = ddp_model.train_step(data, optim_wrapper)
self.assertIsNotNone(loss)
algo = Dsnas(model, mutator)
ddp_model = DsnasDDP(module=algo, find_unused_parameters=True)
optim_wrapper_dict = OptimWrapperDict(
architecture=OptimWrapper(SGD(model.parameters(), lr=0.1)),
mutator=OptimWrapper(SGD(model.parameters(), lr=0.01)))
loss = ddp_model.train_step(data, optim_wrapper_dict)
self.assertIsNotNone(loss)


@@ -0,0 +1,203 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import pytest
import torch
import torch.nn as nn
from mmrazor.models import * # noqa:F403,F401
from mmrazor.registry import MODELS
MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True)
MODELS.register_module(name='torchMaxPool2d', module=nn.MaxPool2d, force=True)
MODELS.register_module(name='torchAvgPool2d', module=nn.AvgPool2d, force=True)
class TestOneHotOP(TestCase):
def test_forward_arch_param(self):
op_cfg = dict(
type='mmrazor.OneHotMutableOP',
candidates=dict(
torch_conv2d_3x3=dict(
type='torchConv2d',
kernel_size=3,
padding=1,
),
torch_conv2d_5x5=dict(
type='torchConv2d',
kernel_size=5,
padding=2,
),
torch_conv2d_7x7=dict(
type='torchConv2d',
kernel_size=7,
padding=3,
),
),
module_kwargs=dict(in_channels=32, out_channels=32, stride=1))
op = MODELS.build(op_cfg)
input = torch.randn(4, 32, 64, 64)
arch_param = nn.Parameter(torch.randn(len(op_cfg['candidates'])))
output = op.forward_arch_param(input, arch_param=arch_param)
assert output is not None
output = op.forward_arch_param(input, arch_param=None)
assert output is not None
# test when some element of arch_param is 0
arch_param = nn.Parameter(torch.ones(op.num_choices))
output = op.forward_arch_param(input, arch_param=arch_param)
assert output is not None
def test_forward_fixed(self):
op_cfg = dict(
type='mmrazor.OneHotMutableOP',
candidates=dict(
torch_conv2d_3x3=dict(
type='torchConv2d',
kernel_size=3,
),
torch_conv2d_5x5=dict(
type='torchConv2d',
kernel_size=5,
),
torch_conv2d_7x7=dict(
type='torchConv2d',
kernel_size=7,
),
),
module_kwargs=dict(in_channels=32, out_channels=32, stride=1))
op = MODELS.build(op_cfg)
input = torch.randn(4, 32, 64, 64)
op.fix_chosen('torch_conv2d_7x7')
output = op.forward_fixed(input)
assert output is not None
assert op.is_fixed is True
def test_forward(self):
op_cfg = dict(
type='mmrazor.OneHotMutableOP',
candidates=dict(
torch_conv2d_3x3=dict(
type='torchConv2d',
kernel_size=3,
padding=1,
),
torch_conv2d_5x5=dict(
type='torchConv2d',
kernel_size=5,
padding=2,
),
torch_conv2d_7x7=dict(
type='torchConv2d',
kernel_size=7,
padding=3,
),
),
module_kwargs=dict(in_channels=32, out_channels=32, stride=1))
op = MODELS.build(op_cfg)
input = torch.randn(4, 32, 64, 64)
# test set_forward_args
arch_param = nn.Parameter(torch.randn(len(op_cfg['candidates'])))
op.set_forward_args(arch_param=arch_param)
output = op.forward(input)
assert output is not None
# test dump_chosen
with pytest.raises(AssertionError):
op.dump_chosen()
# test forward when is_fixed is True
op.fix_chosen('torch_conv2d_7x7')
output = op.forward(input)
def test_property(self):
op_cfg = dict(
type='mmrazor.OneHotMutableOP',
candidates=dict(
torch_conv2d_3x3=dict(
type='torchConv2d',
kernel_size=3,
padding=1,
),
torch_conv2d_5x5=dict(
type='torchConv2d',
kernel_size=5,
padding=2,
),
torch_conv2d_7x7=dict(
type='torchConv2d',
kernel_size=7,
padding=3,
),
),
module_kwargs=dict(in_channels=32, out_channels=32, stride=1))
op = MODELS.build(op_cfg)
assert len(op.choices) == 3
# test is_fixed property
assert op.is_fixed is False
# test is_fixed setting
op.fix_chosen('torch_conv2d_5x5')
with pytest.raises(AttributeError):
op.is_fixed = True
# test fix choice when is_fixed is True
with pytest.raises(AttributeError):
op.fix_chosen('torch_conv2d_3x3')
def test_module_kwargs(self):
op_cfg = dict(
type='mmrazor.OneHotMutableOP',
candidates=dict(
torch_conv2d_3x3=dict(
type='torchConv2d',
kernel_size=3,
in_channels=32,
out_channels=32,
stride=1,
),
torch_conv2d_5x5=dict(
type='torchConv2d',
kernel_size=5,
in_channels=32,
out_channels=32,
stride=1,
),
torch_conv2d_7x7=dict(
type='torchConv2d',
kernel_size=7,
in_channels=32,
out_channels=32,
stride=1,
),
torch_maxpool_3x3=dict(
type='torchMaxPool2d',
kernel_size=3,
stride=1,
),
torch_avgpool_3x3=dict(
type='torchAvgPool2d',
kernel_size=3,
stride=1,
),
),
)
op = MODELS.build(op_cfg)
input = torch.randn(4, 32, 64, 64)
op.fix_chosen('torch_avgpool_3x3')
output = op.forward(input)
assert output is not None