[Feature] Support unroll with MMDDP in darts algorithm (#210)
* support unroll in darts
* fix bugs in optimizer; add docstring
* update darts algorithm [untested]
* replace autograd.grad with optim_wrapper.backward
* add AMP in train.py; support constructor
* rename mmcls.data to mmcls.structures
* modify darts algo to support apex [not done]
* fix codespell in diff_mutable_module
* modify optim_context of dartsddp
* add testcase for dartsddp
* fix bugs of apex in dartsddp
* standardize the unittest of darts
* adapt to the new data_preprocessor
* fix unit test bugs
* remove unnecessary code

Co-authored-by: gaoyang07 <1546308416@qq.com>
parent 0409adc31f · commit dd51ab8ca0
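Taken together, the changes below add an optional second-order ("unrolled") mutator update to the Darts algorithm and a dedicated DDP wrapper (DartsDDP) that performs the same two-optimizer step. As a rough sketch of how this could be switched on from a config: the `architecture`, `mutator`, and `unroll` names come from the `Darts` constructor and tests in this diff, while the surrounding keys and the `mmrazor.` scope prefixes are illustrative assumptions, not a verified config.

```python
# Illustrative sketch only -- not a config file from this PR.
model = dict(
    type='mmrazor.Darts',            # algorithm touched by this diff
    architecture=dict(...),          # supernet to be searched (placeholder)
    mutator=dict(type='mmrazor.DiffModuleMutator'),
    unroll=True)                     # enable the second-order (unrolled) update

# DartsDDP is registered as a model wrapper below; distributed runs would
# presumably select it through a model wrapper config such as this (assumption).
model_wrapper_cfg = dict(type='mmrazor.DartsDDP', find_unused_parameters=True)
```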
@@ -33,8 +33,11 @@ optim_wrapper = dict(
    _delete_=True,
    constructor='mmrazor.SeparateOptimWrapperConstructor',
    architecture=dict(
        type='OptimWrapper',
        optimizer=dict(type='SGD', lr=0.025, momentum=0.9, weight_decay=3e-4),
        clip_grad=dict(max_norm=5, norm_type=2)),
    mutator=dict(optimizer=dict(type='Adam', lr=3e-4, weight_decay=1e-3)))
    mutator=dict(
        type='OptimWrapper',
        optimizer=dict(type='Adam', lr=3e-4, weight_decay=1e-3)))

find_unused_parameter = False
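Read as a whole, the optimizer section after this hunk comes out roughly as below. This is a sketch assembled from the lines shown above (values copied from the hunk, indentation reconstructed), not a verbatim copy of the final config file: `SeparateOptimWrapperConstructor` builds one `OptimWrapper` per key, so the supernet weights (`architecture`) and the architecture parameters (`mutator`) are stepped by separate optimizers.

```python
# Sketch of the resulting optimizer config; values are the ones in the hunk above.
optim_wrapper = dict(
    _delete_=True,
    constructor='mmrazor.SeparateOptimWrapperConstructor',
    architecture=dict(  # optimizer for the supernet weights
        type='OptimWrapper',
        optimizer=dict(type='SGD', lr=0.025, momentum=0.9, weight_decay=3e-4),
        clip_grad=dict(max_norm=5, norm_type=2)),
    mutator=dict(  # optimizer for the architecture parameters
        type='OptimWrapper',
        optimizer=dict(type='Adam', lr=3e-4, weight_decay=1e-3)))

find_unused_parameter = False
```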

@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
from typing import Dict, List, Optional, Union

@@ -82,12 +83,10 @@ class Darts(BaseAlgorithm):
        self.is_supernet = True

        self.norm_training = norm_training
        # TODO support unroll
        self.unroll = unroll

    def search_subnet(self):
        """Search subnet by mutator."""

        # Avoid circular import
        from mmrazor.structures import export_fix_subnet

@@ -136,23 +135,38 @@ class Darts(BaseAlgorithm):
            assert len(data) == len(optim_wrapper), \
                f'The length of data ({len(data)}) should be equal to that '\
                f'of optimizers ({len(optim_wrapper)}).'
            # TODO check the order of data

            supernet_data, mutator_data = data

            log_vars = dict()
            # TODO support unroll
            with optim_wrapper['mutator'].optim_context(self):
                mutator_batch_inputs, mutator_data_samples = \
                    self.data_preprocessor(mutator_data, True)
                mutator_loss = self(
                    mutator_batch_inputs, mutator_data_samples, mode='loss')
                mutator_losses, mutator_log_vars = self.parse_losses(mutator_loss)
                optim_wrapper['mutator'].update_params(mutator_losses)
                log_vars.update(add_prefix(mutator_log_vars, 'mutator'))

            # Update the parameter of mutator
            if self.unroll:
                with optim_wrapper['mutator'].optim_context(self):
                    optim_wrapper['mutator'].zero_grad()
                    mutator_log_vars = self._unrolled_backward(
                        mutator_data, supernet_data, optim_wrapper)
                    optim_wrapper['mutator'].step()
                    log_vars.update(add_prefix(mutator_log_vars, 'mutator'))
            else:
                with optim_wrapper['mutator'].optim_context(self):
                    pseudo_data = self.data_preprocessor(mutator_data, True)
                    mutator_batch_inputs = pseudo_data['inputs']
                    mutator_data_samples = pseudo_data['data_samples']
                    mutator_loss = self(
                        mutator_batch_inputs,
                        mutator_data_samples,
                        mode='loss')
                    mutator_losses, mutator_log_vars = self.parse_losses(
                        mutator_loss)
                    optim_wrapper['mutator'].update_params(mutator_losses)
                    log_vars.update(add_prefix(mutator_log_vars, 'mutator'))

            # Update the parameter of supernet
            with optim_wrapper['architecture'].optim_context(self):
                supernet_batch_inputs, supernet_data_samples = \
                    self.data_preprocessor(supernet_data, True)
                pseudo_data = self.data_preprocessor(supernet_data, True)
                supernet_batch_inputs = pseudo_data['inputs']
                supernet_data_samples = pseudo_data['data_samples']
                supernet_loss = self(
                    supernet_batch_inputs, supernet_data_samples, mode='loss')
                supernet_losses, supernet_log_vars = self.parse_losses(

@@ -163,16 +177,150 @@ class Darts(BaseAlgorithm):
        else:
            # Enable automatic mixed precision training context.
            with optim_wrapper.optim_context(self):
                batch_inputs, data_samples = self.data_preprocessor(data, True)
                pseudo_data = self.data_preprocessor(data, True)
                batch_inputs = pseudo_data['inputs']
                data_samples = pseudo_data['data_samples']
                losses = self(batch_inputs, data_samples, mode='loss')
                parsed_losses, log_vars = self.parse_losses(losses)
                optim_wrapper.update_params(parsed_losses)

        return log_vars

    def _unrolled_backward(self, mutator_data, supernet_data, optim_wrapper):
        """Compute unrolled loss and backward its gradients."""
        backup_params = copy.deepcopy(tuple(self.architecture.parameters()))

        # Do virtual step on training data
        lr = optim_wrapper['architecture'].param_groups[0]['lr']
        momentum = optim_wrapper['architecture'].param_groups[0]['momentum']
        weight_decay = optim_wrapper['architecture'].param_groups[0][
            'weight_decay']
        self._compute_virtual_model(supernet_data, lr, momentum, weight_decay,
                                    optim_wrapper['architecture'])

        # Calculate unrolled loss on validation data
        # Keep gradients for model here for computing hessian
        pseudo_data = self.data_preprocessor(mutator_data, True)
        mutator_batch_inputs = pseudo_data['inputs']
        mutator_data_samples = pseudo_data['data_samples']
        mutator_loss = self(
            mutator_batch_inputs, mutator_data_samples, mode='loss')
        mutator_losses, mutator_log_vars = self.parse_losses(mutator_loss)

        # Here we use the backward function of optimWrapper to calculate
        # the gradients of mutator loss. The gradients of model and arch
        # can be directly obtained. For more information, please refer to
        # https://github.com/open-mmlab/mmengine/blob/main/mmengine/optim/optimizer/optimizer_wrapper.py
        optim_wrapper['mutator'].backward(mutator_losses)
        d_model = [param.grad for param in self.architecture.parameters()]
        d_arch = [param.grad for param in self.mutator.parameters()]

        # compute hessian and final gradients
        hessian = self._compute_hessian(backup_params, d_model, supernet_data,
                                        optim_wrapper['architecture'])

        w_arch = tuple(self.mutator.parameters())

        with torch.no_grad():
            for param, d, h in zip(w_arch, d_arch, hessian):
                # gradient = dalpha - lr * hessian
                param.grad = d - lr * h

        # restore weights
        self._restore_weights(backup_params)
        return mutator_log_vars

    def _compute_virtual_model(self, supernet_data, lr, momentum, weight_decay,
                               optim_wrapper):
        """Compute unrolled weights w`"""
        # don't need zero_grad, using autograd to calculate gradients
        pseudo_data = self.data_preprocessor(supernet_data, True)
        supernet_batch_inputs = pseudo_data['inputs']
        supernet_data_samples = pseudo_data['data_samples']
        supernet_loss = self(
            supernet_batch_inputs, supernet_data_samples, mode='loss')
        supernet_loss, _ = self.parse_losses(supernet_loss)

        optim_wrapper.backward(supernet_loss)
        gradients = [param.grad for param in self.architecture.parameters()]

        with torch.no_grad():
            for w, g in zip(self.architecture.parameters(), gradients):
                m = optim_wrapper.optimizer.state[w].get('momentum_buffer', 0.)
                w = w - lr * (momentum * m + g + weight_decay * w)

    def _restore_weights(self, backup_params):
        """Restore weights from backup params."""
        with torch.no_grad():
            for param, backup in zip(self.architecture.parameters(),
                                     backup_params):
                param.copy_(backup)

    def _compute_hessian(self, backup_params, dw, supernet_data,
                         optim_wrapper) -> List:
        """Compute hessian metric.
        dw = dw` { L_val(w`, alpha) }
        w+ = w + eps * dw
        w- = w - eps * dw
        hessian = (dalpha { L_trn(w+, alpha) } \
            - dalpha { L_trn(w-, alpha) }) / (2*eps)
        eps = 0.01 / ||dw||
        """
        self._restore_weights(backup_params)
        norm = torch.cat([w.view(-1) for w in dw]).norm()
        eps = 0.01 / norm
        if norm < 1E-8:
            print(
                'In computing hessian, norm is smaller than 1E-8, \
                cause eps to be %.6f.', norm.item())

        dalphas = []
        for e in [eps, -2. * eps]:
            # w+ = w + eps*dw`, w- = w - eps*dw`
            with torch.no_grad():
                for p, d in zip(self.architecture.parameters(), dw):
                    p += e * d

            pseudo_data = self.data_preprocessor(supernet_data, True)
            supernet_batch_inputs = pseudo_data['inputs']
            supernet_data_samples = pseudo_data['data_samples']
            supernet_loss = self(
                supernet_batch_inputs, supernet_data_samples, mode='loss')
            supernet_loss, _ = self.parse_losses(supernet_loss)

            optim_wrapper.backward(supernet_loss)
            dalpha = [param.grad for param in self.mutator.parameters()]
            dalphas.append(dalpha)

        # dalpha { L_trn(w+) }, # dalpha { L_trn(w-) }
        dalpha_pos, dalpha_neg = dalphas
        hessian = [(p - n) / (2. * eps)
                   for p, n in zip(dalpha_pos, dalpha_neg)]
        return hessian


class BatchNormWrapper(nn.Module):
    """Wrapper for BatchNorm.

    For more information, please refer to
    https://github.com/NVIDIA/apex/issues/121
    """

    def __init__(self, m):
        super(BatchNormWrapper, self).__init__()
        self.m = m
        # Set the batch norm to eval mode
        self.m.eval()

    def forward(self, x):
        """Convert fp16 to fp32 for the forward pass."""
        input_type = x.dtype
        x = self.m(x.float())
        return x.to(input_type)
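As a quick illustration of what the wrapper above does (a minimal sketch with local names, exercising only the forward logic shown): the BatchNorm statistics are computed in fp32 even when the surrounding network runs in fp16, and the output is cast back to the input dtype.

```python
import torch
import torch.nn as nn

# Hypothetical standalone usage of BatchNormWrapper from above.
bn = BatchNormWrapper(nn.BatchNorm2d(8))   # wraps the BN layer and puts it in eval mode
x = torch.randn(2, 8, 4, 4).half()         # fp16 activations
y = bn(x)                                  # BN itself runs in fp32 internally
assert y.dtype == torch.float16            # result is cast back to the input dtype
```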


@MODEL_WRAPPERS.register_module()
class DartsDDP(MMDistributedDataParallel):
    """DDP for Darts and rewrite train_step of MMDDP."""

    def __init__(self,
                 *,

@@ -183,6 +331,18 @@ class DartsDDP(MMDistributedDataParallel):
        device_ids = [int(os.environ['LOCAL_RANK'])]
        super().__init__(device_ids=device_ids, **kwargs)

        fp16 = True
        if fp16:

            def add_fp16_bn_wrapper(model):
                for child_name, child in model.named_children():
                    if isinstance(child, nn.BatchNorm2d):
                        setattr(model, child_name, BatchNormWrapper(child))
                    else:
                        add_fp16_bn_wrapper(child)

            add_fp16_bn_wrapper(self.module)

    def train_step(self, data: List[dict],
                   optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
        """The iteration step during training.

@@ -214,39 +374,176 @@ class DartsDDP(MMDistributedDataParallel):
            assert len(data) == len(optim_wrapper), \
                f'The length of data ({len(data)}) should be equal to that '\
                f'of optimizers ({len(optim_wrapper)}).'
            # TODO check the order of data

            supernet_data, mutator_data = data

            log_vars = dict()
            # TODO process the input

            with optim_wrapper['mutator'].optim_context(self):
                mutator_batch_inputs, mutator_data_samples = \
                    self.module.data_preprocessor(mutator_data, True)
                mutator_loss = self(
                    mutator_batch_inputs, mutator_data_samples, mode='loss')
                mutator_losses, mutator_log_vars = self.module.parse_losses(
                    mutator_loss)
                optim_wrapper['mutator'].update_params(mutator_losses)
                log_vars.update(add_prefix(mutator_log_vars, 'mutator'))
            # Update the parameter of mutator
            if self.module.unroll:
                with optim_wrapper['mutator'].optim_context(self):
                    optim_wrapper['mutator'].zero_grad()
                    mutator_log_vars = self._unrolled_backward(
                        mutator_data, supernet_data, optim_wrapper)
                    optim_wrapper['mutator'].step()
                    log_vars.update(add_prefix(mutator_log_vars, 'mutator'))
            else:
                with optim_wrapper['mutator'].optim_context(self):
                    pseudo_data = self.module.data_preprocessor(
                        mutator_data, True)
                    mutator_batch_inputs = pseudo_data['inputs']
                    mutator_data_samples = pseudo_data['data_samples']
                    mutator_loss = self(
                        mutator_batch_inputs,
                        mutator_data_samples,
                        mode='loss')

                    mutator_losses, mutator_log_vars = self.module.parse_losses(  # noqa: E501
                        mutator_loss)
                    optim_wrapper['mutator'].update_params(mutator_losses)
                    log_vars.update(add_prefix(mutator_log_vars, 'mutator'))

            # Update the parameter of supernet
            with optim_wrapper['architecture'].optim_context(self):
                supernet_batch_inputs, supernet_data_samples = \
                    self.module.data_preprocessor(supernet_data, True)
                pseudo_data = self.module.data_preprocessor(
                    supernet_data, True)
                supernet_batch_inputs = pseudo_data['inputs']
                supernet_data_samples = pseudo_data['data_samples']
                supernet_loss = self(
                    supernet_batch_inputs, supernet_data_samples, mode='loss')
                supernet_losses, supernet_log_vars = self.module.parse_losses(
                    supernet_loss)
                optim_wrapper['architecture'].update_params(supernet_losses)
                log_vars.update(add_prefix(supernet_log_vars, 'supernet'))

                supernet_losses, supernet_log_vars = self.module.parse_losses(
                    supernet_loss)

                optim_wrapper['architecture'].update_params(supernet_losses)
                log_vars.update(add_prefix(supernet_log_vars, 'supernet'))

        else:
            # Enable automatic mixed precision training context.
            with optim_wrapper.optim_context(self):
                batch_inputs, data_samples = self.module.data_preprocessor(
                    data, True)
                pseudo_data = self.module.data_preprocessor(data, True)
                batch_inputs = pseudo_data['inputs']
                data_samples = pseudo_data['data_samples']
                losses = self(batch_inputs, data_samples, mode='loss')
                parsed_losses, log_vars = self.module.parse_losses(losses)
                optim_wrapper.update_params(parsed_losses)

        return log_vars

    def _unrolled_backward(self, mutator_data, supernet_data, optim_wrapper):
        """Compute unrolled loss and backward its gradients."""
        backup_params = copy.deepcopy(
            tuple(self.module.architecture.parameters()))

        # do virtual step on training data
        lr = optim_wrapper['architecture'].param_groups[0]['lr']
        momentum = optim_wrapper['architecture'].param_groups[0]['momentum']
        weight_decay = optim_wrapper['architecture'].param_groups[0][
            'weight_decay']
        self._compute_virtual_model(supernet_data, lr, momentum, weight_decay,
                                    optim_wrapper['architecture'])

        # calculate unrolled loss on validation data
        # keep gradients for model here for computing hessian
        pseudo_data = self.module.data_preprocessor(mutator_data, True)
        mutator_batch_inputs = pseudo_data['inputs']
        mutator_data_samples = pseudo_data['data_samples']
        mutator_loss = self(
            mutator_batch_inputs, mutator_data_samples, mode='loss')
        mutator_losses, mutator_log_vars = self.module.parse_losses(
            mutator_loss)

        # Here we use the backward function of optimWrapper to calculate
        # the gradients of mutator loss. The gradients of model and arch
        # can be directly obtained. For more information, please refer to
        # https://github.com/open-mmlab/mmengine/blob/main/mmengine/optim/optimizer/optimizer_wrapper.py
        optim_wrapper['mutator'].backward(mutator_losses)
        d_model = [
            param.grad for param in self.module.architecture.parameters()
        ]
        d_arch = [param.grad for param in self.module.mutator.parameters()]

        # compute hessian and final gradients
        hessian = self._compute_hessian(backup_params, d_model, supernet_data,
                                        optim_wrapper['architecture'])

        w_arch = tuple(self.module.mutator.parameters())

        with torch.no_grad():
            for param, da, he in zip(w_arch, d_arch, hessian):
                # gradient = dalpha - lr * hessian
                param.grad = da - lr * he

        # restore weights
        self._restore_weights(backup_params)
        return mutator_log_vars

    def _compute_virtual_model(self, supernet_data, lr, momentum, weight_decay,
                               optim_wrapper):
        """Compute unrolled weights w`"""
        # don't need zero_grad, using autograd to calculate gradients
        pseudo_data = self.module.data_preprocessor(supernet_data, True)
        supernet_batch_inputs = pseudo_data['inputs']
        supernet_data_samples = pseudo_data['data_samples']
        supernet_loss = self(
            supernet_batch_inputs, supernet_data_samples, mode='loss')
        supernet_loss, _ = self.module.parse_losses(supernet_loss)

        optim_wrapper.backward(supernet_loss)
        gradients = [
            param.grad for param in self.module.architecture.parameters()
        ]

        with torch.no_grad():
            for w, g in zip(self.module.architecture.parameters(), gradients):
                m = optim_wrapper.optimizer.state[w].get('momentum_buffer', 0.)
                w = w - lr * (momentum * m + g + weight_decay * w)

    def _restore_weights(self, backup_params):
        """Restore weights from backup params."""
        with torch.no_grad():
            for param, backup in zip(self.module.architecture.parameters(),
                                     backup_params):
                param.copy_(backup)

    def _compute_hessian(self, backup_params, dw, supernet_data,
                         optim_wrapper) -> List:
        """Compute hessian metric.
        dw = dw` { L_val(w`, alpha) }
        w+ = w + eps * dw
        w- = w - eps * dw
        hessian = (dalpha { L_trn(w+, alpha) } \
            - dalpha { L_trn(w-, alpha) }) / (2*eps)
        eps = 0.01 / ||dw||
        """
        self._restore_weights(backup_params)
        norm = torch.cat([w.view(-1) for w in dw]).norm()
        eps = 0.01 / norm
        if norm < 1E-8:
            print(
                'In computing hessian, norm is smaller than 1E-8, \
                cause eps to be %.6f.', norm.item())

        dalphas = []
        for e in [eps, -2. * eps]:
            # w+ = w + eps*dw`, w- = w - eps*dw`
            with torch.no_grad():
                for p, d in zip(self.module.architecture.parameters(), dw):
                    p += e * d

            pseudo_data = self.module.data_preprocessor(supernet_data, True)
            supernet_batch_inputs = pseudo_data['inputs']
            supernet_data_samples = pseudo_data['data_samples']
            supernet_loss = self(
                supernet_batch_inputs, supernet_data_samples, mode='loss')
            supernet_loss, _ = self.module.parse_losses(supernet_loss)

            optim_wrapper.backward(supernet_loss)
            dalpha = [param.grad for param in self.module.mutator.parameters()]
            dalphas.append(dalpha)

        # dalpha { L_trn(w+) }, # dalpha { L_trn(w-) }
        dalpha_pos, dalpha_neg = dalphas
        hessian = [(p - n) / (2. * eps)
                   for p, n in zip(dalpha_pos, dalpha_neg)]
        return hessian
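For reference, `_compute_virtual_model`, `_unrolled_backward`, and `_compute_hessian` above follow the standard second-order DARTS approximation. In the notation of the docstrings, with w the supernet weights, alpha the architecture parameters, and xi taken from the architecture optimizer's learning rate, the update they compute is (reviewer's summary, not text from the diff):

```latex
% Virtual step on the supernet (training) split:
w' = w - \xi \, \nabla_{w} L_{\mathrm{trn}}(w, \alpha)

% Gradient used to update \alpha on the mutator (validation) split:
g_{\alpha} = \nabla_{\alpha} L_{\mathrm{val}}(w', \alpha)
           - \xi \, \nabla^{2}_{\alpha, w} L_{\mathrm{trn}}(w, \alpha)\,
                    \nabla_{w'} L_{\mathrm{val}}(w', \alpha)

% Finite-difference Hessian-vector product, as in _compute_hessian, with
% d_w = \nabla_{w'} L_{\mathrm{val}}(w', \alpha),\;
% w^{\pm} = w \pm \epsilon\, d_w,\; \epsilon = 0.01 / \lVert d_w \rVert:
\nabla^{2}_{\alpha, w} L_{\mathrm{trn}}\, d_w \approx
    \frac{\nabla_{\alpha} L_{\mathrm{trn}}(w^{+}, \alpha)
        - \nabla_{\alpha} L_{\mathrm{trn}}(w^{-}, \alpha)}{2\,\epsilon}
```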

@@ -62,7 +62,6 @@ class DartsSubnetClsHead(LinearClsHead):
    def loss(self, feats: Tuple[torch.Tensor],
             data_samples: List[ClsDataSample], **kwargs) -> dict:
        """Calculate losses from the classification score.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
                Multiple stage inputs are acceptable but only the last stage

@@ -71,7 +70,6 @@ class DartsSubnetClsHead(LinearClsHead):
            data_samples (List[ClsDataSample]): The annotation data of
                every samples.
            **kwargs: Other keyword arguments to forward the loss module.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """

@@ -389,7 +389,7 @@ class DiffChoiceRoute(DiffMutableModule[str, List[str]]):
    def __init__(
        self,
        edges: nn.ModuleDict,
        num_chsoen: int = 2,
        num_chosen: int = 2,
        with_arch_param: bool = False,
        alias: Optional[str] = None,
        init_cfg: Optional[Dict] = None,

@@ -402,7 +402,7 @@ class DiffChoiceRoute(DiffMutableModule[str, List[str]]):
        self._with_arch_param = with_arch_param
        self._is_fixed = False
        self._candidates: nn.ModuleDict = edges
        self.num_chosen = num_chsoen
        self.num_chosen = num_chosen

    def forward_fixed(self, inputs: Union[List, Tuple]) -> Tensor:
        """Forward when the mutable is in `fixed` mode.

@@ -1 +1,250 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Dict
from unittest import TestCase

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from mmcls.structures import ClsDataSample
from mmengine.model import BaseModel
from mmengine.optim import build_optim_wrapper
from mmengine.optim.optimizer import OptimWrapper, OptimWrapperDict
from torch import Tensor
from torch.optim import SGD

from mmrazor.models import Darts, DiffModuleMutator, DiffMutableOP
from mmrazor.models.algorithms.nas.darts import DartsDDP
from mmrazor.registry import MODELS

MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True)
MODELS.register_module(name='torchMaxPool2d', module=nn.MaxPool2d, force=True)
MODELS.register_module(name='torchAvgPool2d', module=nn.AvgPool2d, force=True)


@MODELS.register_module()
class ToyDiffModule2(BaseModel):

    def __init__(self, data_preprocessor=None):
        super().__init__(data_preprocessor=data_preprocessor, init_cfg=None)

        self.candidates = dict(
            torch_conv2d_3x3=dict(
                type='torchConv2d',
                kernel_size=3,
                padding=1,
            ),
            torch_conv2d_5x5=dict(
                type='torchConv2d',
                kernel_size=5,
                padding=2,
            ),
            torch_conv2d_7x7=dict(
                type='torchConv2d',
                kernel_size=7,
                padding=3,
            ),
        )
        module_kwargs = dict(
            in_channels=3,
            out_channels=8,
            stride=1,
        )
        self.mutable = DiffMutableOP(
            candidates=self.candidates,
            module_kwargs=module_kwargs,
            alias='normal')

        self.bn = nn.BatchNorm2d(8)

    def forward(self, batch_inputs, data_samples=None, mode='tensor'):
        if mode == 'loss':
            out = self.bn(self.mutable(batch_inputs))
            return dict(loss=out)
        elif mode == 'predict':
            out = self.bn(self.mutable(batch_inputs)) + 1
            return out
        elif mode == 'tensor':
            out = self.bn(self.mutable(batch_inputs)) + 2
            return out

class TestDarts(TestCase):

    def setUp(self) -> None:
        self.device: str = 'cpu'

        OPTIMIZER_CFG = dict(
            type='SGD',
            lr=0.5,
            momentum=0.9,
            nesterov=True,
            weight_decay=0.0001)

        self.OPTIM_WRAPPER_CFG = dict(optimizer=OPTIMIZER_CFG)

    def test_init(self) -> None:
        # initiate darts when `norm_training` is True.
        model = ToyDiffModule2()
        mutator = DiffModuleMutator()
        algo = Darts(architecture=model, mutator=mutator, norm_training=True)
        algo.eval()
        self.assertTrue(model.bn.training)

        # initiate darts with built mutator
        model = ToyDiffModule2()
        mutator = DiffModuleMutator()
        algo = Darts(model, mutator)
        self.assertIs(algo.mutator, mutator)

        # initiate darts with unbuilt mutator
        mutator = dict(type='DiffModuleMutator')
        algo = Darts(model, mutator)
        self.assertIsInstance(algo.mutator, DiffModuleMutator)

        # initiate darts when `fix_subnet` is not None
        fix_subnet = {'normal': ['torch_conv2d_3x3', 'torch_conv2d_7x7']}
        algo = Darts(model, mutator, fix_subnet=fix_subnet)
        self.assertEqual(algo.architecture.mutable.num_choices, 2)

        # initiate darts with error type `mutator`
        with self.assertRaisesRegex(TypeError, 'mutator should be'):
            Darts(model, model)

    def test_forward_loss(self) -> None:
        inputs = torch.randn(1, 3, 8, 8)
        model = ToyDiffModule2()

        # supernet
        mutator = DiffModuleMutator()
        mutator.prepare_from_supernet(model)
        algo = Darts(model, mutator)
        loss = algo(inputs, mode='loss')
        self.assertIsInstance(loss, dict)

        # subnet
        fix_subnet = {'normal': ['torch_conv2d_3x3', 'torch_conv2d_7x7']}
        algo = Darts(model, fix_subnet=fix_subnet)
        loss = algo(inputs, mode='loss')
        self.assertIsInstance(loss, dict)

    def _prepare_fake_data(self) -> Dict:
        imgs = torch.randn(16, 3, 224, 224).to(self.device)
        data_samples = [
            ClsDataSample().set_gt_label(torch.randint(0, 1000,
                                                       (16, ))).to(self.device)
        ]

        return {'inputs': imgs, 'data_samples': data_samples}

    def test_search_subnet(self) -> None:
        model = ToyDiffModule2()

        mutator = DiffModuleMutator()
        mutator.prepare_from_supernet(model)
        algo = Darts(model, mutator)
        subnet = algo.search_subnet()
        self.assertIsInstance(subnet, dict)

    def test_darts_train_step(self) -> None:
        model = ToyDiffModule2()
        mutator = DiffModuleMutator()
        mutator.prepare_from_supernet(model)

        # data is tensor
        algo = Darts(model, mutator)
        data = self._prepare_fake_data()
        optim_wrapper = build_optim_wrapper(algo, self.OPTIM_WRAPPER_CFG)
        loss = algo.train_step(data, optim_wrapper)

        self.assertTrue(isinstance(loss['loss'], Tensor))

        # data is tuple or list
        algo = Darts(model, mutator)
        data = [self._prepare_fake_data() for _ in range(2)]
        optim_wrapper_dict = OptimWrapperDict(
            architecture=OptimWrapper(SGD(model.parameters(), lr=0.1)),
            mutator=OptimWrapper(SGD(model.parameters(), lr=0.01)))
        loss = algo.train_step(data, optim_wrapper_dict)

        self.assertIsNotNone(loss)

    def test_darts_with_unroll(self) -> None:
        model = ToyDiffModule2()
        mutator = DiffModuleMutator()
        mutator.prepare_from_supernet(model)

        # data is tuple or list
        algo = Darts(model, mutator, unroll=True)
        data = [self._prepare_fake_data() for _ in range(2)]
        optim_wrapper_dict = OptimWrapperDict(
            architecture=OptimWrapper(SGD(model.parameters(), lr=0.1)),
            mutator=OptimWrapper(SGD(model.parameters(), lr=0.01)))
        loss = algo.train_step(data, optim_wrapper_dict)

        self.assertIsNotNone(loss)


class TestDartsDDP(TestDarts):

    @classmethod
    def setUpClass(cls) -> None:
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12345'

        # initialize the process group
        backend = 'nccl' if torch.cuda.is_available() else 'gloo'
        dist.init_process_group(backend, rank=0, world_size=1)

    def prepare_model(self, unroll=False, device_ids=None) -> Darts:
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        model = ToyDiffModule2()
        mutator = DiffModuleMutator()
        mutator.prepare_from_supernet(model)

        algo = Darts(model, mutator, unroll=unroll).to(self.device)

        return DartsDDP(
            module=algo, find_unused_parameters=True, device_ids=device_ids)

    @classmethod
    def tearDownClass(cls) -> None:
        dist.destroy_process_group()

    @pytest.mark.skipif(
        not torch.cuda.is_available(), reason='cuda device is not available')
    def test_init(self) -> None:
        ddp_model = self.prepare_model()
        self.assertIsInstance(ddp_model, DartsDDP)

    def test_dartsddp_train_step(self) -> None:
        # data is tensor
        ddp_model = self.prepare_model()
        data = self._prepare_fake_data()
        optim_wrapper = build_optim_wrapper(ddp_model, self.OPTIM_WRAPPER_CFG)
        loss = ddp_model.train_step(data, optim_wrapper)

        self.assertIsNotNone(loss)

        # data is tuple or list
        ddp_model = self.prepare_model()
        data = [self._prepare_fake_data() for _ in range(2)]
        optim_wrapper_dict = OptimWrapperDict(
            architecture=OptimWrapper(SGD(ddp_model.parameters(), lr=0.1)),
            mutator=OptimWrapper(SGD(ddp_model.parameters(), lr=0.01)))
        loss = ddp_model.train_step(data, optim_wrapper_dict)

        self.assertIsNotNone(loss)

    def test_dartsddp_with_unroll(self) -> None:
        # data is tuple or list
        ddp_model = self.prepare_model(unroll=True)
        data = [self._prepare_fake_data() for _ in range(2)]
        optim_wrapper_dict = OptimWrapperDict(
            architecture=OptimWrapper(SGD(ddp_model.parameters(), lr=0.1)),
            mutator=OptimWrapper(SGD(ddp_model.parameters(), lr=0.01)))
        loss = ddp_model.train_step(data, optim_wrapper_dict)

        self.assertIsNotNone(loss)

@@ -71,19 +71,30 @@ def main():
        osp.splitext(osp.basename(args.config))[0])

    # enable automatic-mixed-precision training
    if args.amp is True:
        optim_wrapper = cfg.optim_wrapper.type
        if optim_wrapper == 'AmpOptimWrapper':
            print_log(
                'AMP training is already enabled in your config.',
                logger='current',
                level=logging.WARNING)
        else:
            assert optim_wrapper == 'OptimWrapper', (
                '`--amp` is only supported when the optimizer wrapper type is '
                f'`OptimWrapper` but got {optim_wrapper}.')
            cfg.optim_wrapper.type = 'AmpOptimWrapper'
            cfg.optim_wrapper.loss_scale = 'dynamic'
    if args.amp:
        if getattr(cfg.optim_wrapper, 'type', None):
            optim_wrapper = cfg.optim_wrapper.type
            if optim_wrapper == 'AmpOptimWrapper':
                print_log(
                    'AMP training is already enabled in your config.',
                    logger='current',
                    level=logging.WARNING)
            else:
                assert optim_wrapper == 'OptimWrapper', (
                    '`--amp` is only supported when the optimizer wrapper '
                    f'type is `OptimWrapper` but got {optim_wrapper}.')
                cfg.optim_wrapper.type = 'AmpOptimWrapper'
                cfg.optim_wrapper.loss_scale = 'dynamic'

        if getattr(cfg.optim_wrapper, 'constructor', None):
            if cfg.optim_wrapper.architecture.type == 'OptimWrapper':
                cfg.optim_wrapper.architecture.type = 'AmpOptimWrapper'
                cfg.optim_wrapper.architecture.loss_scale = 'dynamic'

            # TODO: support amp training for mutator
            # if cfg.optim_wrapper.mutator.type == 'OptimWrapper':
            #     cfg.optim_wrapper.mutator.type = 'AmpOptimWrapper'
            #     cfg.optim_wrapper.mutator.loss_scale = 'dynamic'

    # enable automatically scaling LR
    if args.auto_scale_lr:
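To make the new branch concrete: with the SeparateOptimWrapperConstructor config shown earlier, passing `--amp` rewrites only the architecture wrapper and leaves the mutator wrapper untouched (AMP for the mutator is still a TODO above). A sketch of the resulting in-memory `cfg.optim_wrapper`, with values carried over from the earlier hunk:

```python
# Sketch of cfg.optim_wrapper after `--amp` when a constructor is configured.
optim_wrapper = dict(
    constructor='mmrazor.SeparateOptimWrapperConstructor',
    architecture=dict(
        type='AmpOptimWrapper',      # switched from 'OptimWrapper' by --amp
        loss_scale='dynamic',        # added by the --amp branch
        optimizer=dict(type='SGD', lr=0.025, momentum=0.9, weight_decay=3e-4),
        clip_grad=dict(max_norm=5, norm_type=2)),
    mutator=dict(
        type='OptimWrapper',         # unchanged; mutator AMP is still a TODO
        optimizer=dict(type='Adam', lr=3e-4, weight_decay=1e-3)))
```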