[Feature] Support unroll with MMDDP in darts algorithm (#210)

* support unroll in darts

* fix bugs in optimizer; add docstring

* update darts algorithm [untested]

* replace autograd.grad with optim_wrapper.backward

* add amp in train.py; support constructor

* rename mmcls.data to mmcls.structures

* modify darts algo to support apex [not done]

* fix spelling error in diff_mutable_module

* modify optim_context of dartsddp

* add testcase for dartsddp

* fix bugs of apex in dartsddp

* standardize the unit tests of darts

* adapt to the new data_preprocessor

* fix ut bugs

* remove unnecessary code

Co-authored-by: gaoyang07 <1546308416@qq.com>
PJDong 2022-10-14 17:41:11 +08:00 committed by GitHub
parent 0409adc31f
commit dd51ab8ca0
6 changed files with 611 additions and 53 deletions
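
For context, the unroll path added in this commit follows the second-order update from the DARTS paper: `train_step` now expects two batches (one for the supernet weights, one for the mutator's architecture parameters) and an `OptimWrapperDict` with `architecture` and `mutator` entries. A minimal sketch of driving it, mirroring the unit test added below (`ToyDiffModule2` and `_prepare_fake_data` are names from that test, not a public API):

from mmengine.optim import OptimWrapper, OptimWrapperDict
from torch.optim import SGD

from mmrazor.models import Darts, DiffModuleMutator

# `ToyDiffModule2` and `_prepare_fake_data` refer to the test file in this commit.
model = ToyDiffModule2()
mutator = DiffModuleMutator()
mutator.prepare_from_supernet(model)

# unroll=True enables the second-order (unrolled) mutator update.
algo = Darts(model, mutator, unroll=True)

optim_wrapper_dict = OptimWrapperDict(
    architecture=OptimWrapper(SGD(model.parameters(), lr=0.1)),
    mutator=OptimWrapper(SGD(model.parameters(), lr=0.01)))

# data is a list of two batches: [supernet_data, mutator_data].
data = [_prepare_fake_data() for _ in range(2)]
log_vars = algo.train_step(data, optim_wrapper_dict)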


@@ -33,8 +33,11 @@ optim_wrapper = dict(
_delete_=True,
constructor='mmrazor.SeparateOptimWrapperConstructor',
architecture=dict(
type='OptimWrapper',
optimizer=dict(type='SGD', lr=0.025, momentum=0.9, weight_decay=3e-4),
clip_grad=dict(max_norm=5, norm_type=2)),
mutator=dict(optimizer=dict(type='Adam', lr=3e-4, weight_decay=1e-3)))
mutator=dict(
type='OptimWrapper',
optimizer=dict(type='Adam', lr=3e-4, weight_decay=1e-3)))
find_unused_parameter = False
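
For reference, the `SeparateOptimWrapperConstructor` above is expected to produce an `OptimWrapperDict` holding one wrapper per submodule. A rough manual equivalent (a sketch only, assuming an already-built Darts instance named `algo`):

from mmengine.optim import OptimWrapper, OptimWrapperDict
from torch.optim import SGD, Adam

# `algo` is assumed to be a built Darts algorithm with .architecture and .mutator.
optim_wrapper = OptimWrapperDict(
    architecture=OptimWrapper(
        SGD(algo.architecture.parameters(),
            lr=0.025, momentum=0.9, weight_decay=3e-4),
        clip_grad=dict(max_norm=5, norm_type=2)),
    mutator=OptimWrapper(
        Adam(algo.mutator.parameters(), lr=3e-4, weight_decay=1e-3)))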


@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
from typing import Dict, List, Optional, Union
@@ -82,12 +83,10 @@ class Darts(BaseAlgorithm):
self.is_supernet = True
self.norm_training = norm_training
# TODO support unroll
self.unroll = unroll
def search_subnet(self):
"""Search subnet by mutator."""
# Avoid circular import
from mmrazor.structures import export_fix_subnet
@@ -136,23 +135,38 @@ class Darts(BaseAlgorithm):
assert len(data) == len(optim_wrapper), \
f'The length of data ({len(data)}) should be equal to that '\
f'of optimizers ({len(optim_wrapper)}).'
# TODO check the order of data
supernet_data, mutator_data = data
log_vars = dict()
# TODO support unroll
with optim_wrapper['mutator'].optim_context(self):
mutator_batch_inputs, mutator_data_samples = \
self.data_preprocessor(mutator_data, True)
mutator_loss = self(
mutator_batch_inputs, mutator_data_samples, mode='loss')
mutator_losses, mutator_log_vars = self.parse_losses(mutator_loss)
optim_wrapper['mutator'].update_params(mutator_losses)
log_vars.update(add_prefix(mutator_log_vars, 'mutator'))
# Update the parameter of mutator
if self.unroll:
with optim_wrapper['mutator'].optim_context(self):
optim_wrapper['mutator'].zero_grad()
mutator_log_vars = self._unrolled_backward(
mutator_data, supernet_data, optim_wrapper)
optim_wrapper['mutator'].step()
log_vars.update(add_prefix(mutator_log_vars, 'mutator'))
else:
with optim_wrapper['mutator'].optim_context(self):
pseudo_data = self.data_preprocessor(mutator_data, True)
mutator_batch_inputs = pseudo_data['inputs']
mutator_data_samples = pseudo_data['data_samples']
mutator_loss = self(
mutator_batch_inputs,
mutator_data_samples,
mode='loss')
mutator_losses, mutator_log_vars = self.parse_losses(
mutator_loss)
optim_wrapper['mutator'].update_params(mutator_losses)
log_vars.update(add_prefix(mutator_log_vars, 'mutator'))
# Update the parameter of supernet
with optim_wrapper['architecture'].optim_context(self):
supernet_batch_inputs, supernet_data_samples = \
self.data_preprocessor(supernet_data, True)
pseudo_data = self.data_preprocessor(supernet_data, True)
supernet_batch_inputs = pseudo_data['inputs']
supernet_data_samples = pseudo_data['data_samples']
supernet_loss = self(
supernet_batch_inputs, supernet_data_samples, mode='loss')
supernet_losses, supernet_log_vars = self.parse_losses(
@@ -163,16 +177,150 @@ class Darts(BaseAlgorithm):
else:
# Enable automatic mixed precision training context.
with optim_wrapper.optim_context(self):
batch_inputs, data_samples = self.data_preprocessor(data, True)
pseudo_data = self.data_preprocessor(data, True)
batch_inputs = pseudo_data['inputs']
data_samples = pseudo_data['data_samples']
losses = self(batch_inputs, data_samples, mode='loss')
parsed_losses, log_vars = self.parse_losses(losses)
optim_wrapper.update_params(parsed_losses)
return log_vars
def _unrolled_backward(self, mutator_data, supernet_data, optim_wrapper):
"""Compute unrolled loss and backward its gradients."""
backup_params = copy.deepcopy(tuple(self.architecture.parameters()))
# Do virtual step on training data
lr = optim_wrapper['architecture'].param_groups[0]['lr']
momentum = optim_wrapper['architecture'].param_groups[0]['momentum']
weight_decay = optim_wrapper['architecture'].param_groups[0][
'weight_decay']
self._compute_virtual_model(supernet_data, lr, momentum, weight_decay,
optim_wrapper['architecture'])
# Calculate unrolled loss on validation data
# Keep gradients for model here for compute hessian
pseudo_data = self.data_preprocessor(mutator_data, True)
mutator_batch_inputs = pseudo_data['inputs']
mutator_data_samples = pseudo_data['data_samples']
mutator_loss = self(
mutator_batch_inputs, mutator_data_samples, mode='loss')
mutator_losses, mutator_log_vars = self.parse_losses(mutator_loss)
# Here we use the backward function of OptimWrapper to calculate
# the gradients of the mutator loss. The gradients of the model and
# arch params can be obtained directly. For more information, please refer to
# https://github.com/open-mmlab/mmengine/blob/main/mmengine/optim/optimizer/optimizer_wrapper.py
optim_wrapper['mutator'].backward(mutator_losses)
d_model = [param.grad for param in self.architecture.parameters()]
d_arch = [param.grad for param in self.mutator.parameters()]
# compute hessian and final gradients
hessian = self._compute_hessian(backup_params, d_model, supernet_data,
optim_wrapper['architecture'])
w_arch = tuple(self.mutator.parameters())
with torch.no_grad():
for param, d, h in zip(w_arch, d_arch, hessian):
# gradient = dalpha - lr * hessian
param.grad = d - lr * h
# restore weights
self._restore_weights(backup_params)
return mutator_log_vars
def _compute_virtual_model(self, supernet_data, lr, momentum, weight_decay,
optim_wrapper):
"""Compute unrolled weights w`"""
# don't need zero_grad, using autograd to calculate gradients
pseudo_data = self.data_preprocessor(supernet_data, True)
supernet_batch_inputs = pseudo_data['inputs']
supernet_data_samples = pseudo_data['data_samples']
supernet_loss = self(
supernet_batch_inputs, supernet_data_samples, mode='loss')
supernet_loss, _ = self.parse_losses(supernet_loss)
optim_wrapper.backward(supernet_loss)
gradients = [param.grad for param in self.architecture.parameters()]
with torch.no_grad():
for w, g in zip(self.architecture.parameters(), gradients):
m = optim_wrapper.optimizer.state[w].get('momentum_buffer', 0.)
w = w - lr * (momentum * m + g + weight_decay * w)
def _restore_weights(self, backup_params):
"""restore weight from backup params."""
with torch.no_grad():
for param, backup in zip(self.architecture.parameters(),
backup_params):
param.copy_(backup)
def _compute_hessian(self, backup_params, dw, supernet_data,
optim_wrapper) -> List:
"""compute hession metric
dw = dw` { L_val(w`, alpha) }
w+ = w + eps * dw
w- = w - eps * dw
hessian = (dalpha { L_trn(w+, alpha) } \
- dalpha { L_trn(w-, alpha) }) / (2*eps)
eps = 0.01 / ||dw||
"""
self._restore_weights(backup_params)
norm = torch.cat([w.view(-1) for w in dw]).norm()
eps = 0.01 / norm
if norm < 1E-8:
print('In computing hessian, norm is smaller than 1E-8 '
f'({norm.item():.6f}), which may cause eps to be very large.')
dalphas = []
for e in [eps, -2. * eps]:
# w+ = w + eps*dw`, w- = w - eps*dw`
with torch.no_grad():
for p, d in zip(self.architecture.parameters(), dw):
p += e * d
pseudo_data = self.data_preprocessor(supernet_data, True)
supernet_batch_inputs = pseudo_data['inputs']
supernet_data_samples = pseudo_data['data_samples']
supernet_loss = self(
supernet_batch_inputs, supernet_data_samples, mode='loss')
supernet_loss, _ = self.parse_losses(supernet_loss)
optim_wrapper.backward(supernet_loss)
dalpha = [param.grad for param in self.mutator.parameters()]
dalphas.append(dalpha)
# dalpha { L_trn(w+) }, # dalpha { L_trn(w-) }
dalpha_pos, dalpha_neg = dalphas
hessian = [(p - n) / (2. * eps)
for p, n in zip(dalpha_pos, dalpha_neg)]
return hessian
class BatchNormWrapper(nn.Module):
"""Wrapper for BatchNorm.
For more information, please refer to
https://github.com/NVIDIA/apex/issues/121
"""
def __init__(self, m):
super(BatchNormWrapper, self).__init__()
self.m = m
# Set the batch norm to eval mode
self.m.eval()
def forward(self, x):
"""Convert fp16 to fp32 when forward."""
input_type = x.dtype
x = self.m(x.float())
return x.to(input_type)
@MODEL_WRAPPERS.register_module()
class DartsDDP(MMDistributedDataParallel):
"""DDP for Darts and rewrite train_step of MMDDP."""
def __init__(self,
*,
@@ -183,6 +331,18 @@ class DartsDDP(MMDistributedDataParallel):
device_ids = [int(os.environ['LOCAL_RANK'])]
super().__init__(device_ids=device_ids, **kwargs)
fp16 = True
if fp16:
def add_fp16_bn_wrapper(model):
for child_name, child in model.named_children():
if isinstance(child, nn.BatchNorm2d):
setattr(model, child_name, BatchNormWrapper(child))
else:
add_fp16_bn_wrapper(child)
add_fp16_bn_wrapper(self.module)
def train_step(self, data: List[dict],
optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
"""The iteration step during training.
@@ -214,39 +374,176 @@ class DartsDDP(MMDistributedDataParallel):
assert len(data) == len(optim_wrapper), \
f'The length of data ({len(data)}) should be equal to that '\
f'of optimizers ({len(optim_wrapper)}).'
# TODO check the order of data
supernet_data, mutator_data = data
log_vars = dict()
# TODO process the input
with optim_wrapper['mutator'].optim_context(self):
mutator_batch_inputs, mutator_data_samples = \
self.module.data_preprocessor(mutator_data, True)
mutator_loss = self(
mutator_batch_inputs, mutator_data_samples, mode='loss')
mutator_losses, mutator_log_vars = self.module.parse_losses(
mutator_loss)
optim_wrapper['mutator'].update_params(mutator_losses)
log_vars.update(add_prefix(mutator_log_vars, 'mutator'))
# Update the parameter of mutator
if self.module.unroll:
with optim_wrapper['mutator'].optim_context(self):
optim_wrapper['mutator'].zero_grad()
mutator_log_vars = self._unrolled_backward(
mutator_data, supernet_data, optim_wrapper)
optim_wrapper['mutator'].step()
log_vars.update(add_prefix(mutator_log_vars, 'mutator'))
else:
with optim_wrapper['mutator'].optim_context(self):
pseudo_data = self.module.data_preprocessor(
mutator_data, True)
mutator_batch_inputs = pseudo_data['inputs']
mutator_data_samples = pseudo_data['data_samples']
mutator_loss = self(
mutator_batch_inputs,
mutator_data_samples,
mode='loss')
mutator_losses, mutator_log_vars = self.module.parse_losses( # noqa: E501
mutator_loss)
optim_wrapper['mutator'].update_params(mutator_losses)
log_vars.update(add_prefix(mutator_log_vars, 'mutator'))
# Update the parameter of supernet
with optim_wrapper['architecture'].optim_context(self):
supernet_batch_inputs, supernet_data_samples = \
self.module.data_preprocessor(supernet_data, True)
pseudo_data = self.module.data_preprocessor(
supernet_data, True)
supernet_batch_inputs = pseudo_data['inputs']
supernet_data_samples = pseudo_data['data_samples']
supernet_loss = self(
supernet_batch_inputs, supernet_data_samples, mode='loss')
supernet_losses, supernet_log_vars = self.module.parse_losses(
supernet_loss)
optim_wrapper['architecture'].update_params(supernet_losses)
log_vars.update(add_prefix(supernet_log_vars, 'supernet'))
supernet_losses, supernet_log_vars = self.module.parse_losses(
supernet_loss)
optim_wrapper['architecture'].update_params(supernet_losses)
log_vars.update(add_prefix(supernet_log_vars, 'supernet'))
else:
# Enable automatic mixed precision training context.
with optim_wrapper.optim_context(self):
batch_inputs, data_samples = self.module.data_preprocessor(
data, True)
pseudo_data = self.module.data_preprocessor(data, True)
batch_inputs = pseudo_data['inputs']
data_samples = pseudo_data['data_samples']
losses = self(batch_inputs, data_samples, mode='loss')
parsed_losses, log_vars = self.module.parse_losses(losses)
optim_wrapper.update_params(parsed_losses)
return log_vars
def _unrolled_backward(self, mutator_data, supernet_data, optim_wrapper):
"""Compute unrolled loss and backward its gradients."""
backup_params = copy.deepcopy(
tuple(self.module.architecture.parameters()))
# do virtual step on training data
lr = optim_wrapper['architecture'].param_groups[0]['lr']
momentum = optim_wrapper['architecture'].param_groups[0]['momentum']
weight_decay = optim_wrapper['architecture'].param_groups[0][
'weight_decay']
self._compute_virtual_model(supernet_data, lr, momentum, weight_decay,
optim_wrapper['architecture'])
# calculate unrolled loss on validation data
# keep gradients for model here for compute hessian
pseudo_data = self.module.data_preprocessor(mutator_data, True)
mutator_batch_inputs = pseudo_data['inputs']
mutator_data_samples = pseudo_data['data_samples']
mutator_loss = self(
mutator_batch_inputs, mutator_data_samples, mode='loss')
mutator_losses, mutator_log_vars = self.module.parse_losses(
mutator_loss)
# Here we use the backward function of OptimWrapper to calculate
# the gradients of the mutator loss. The gradients of the model and
# arch params can be obtained directly. For more information, please refer to
# https://github.com/open-mmlab/mmengine/blob/main/mmengine/optim/optimizer/optimizer_wrapper.py
optim_wrapper['mutator'].backward(mutator_losses)
d_model = [
param.grad for param in self.module.architecture.parameters()
]
d_arch = [param.grad for param in self.module.mutator.parameters()]
# compute hessian and final gradients
hessian = self._compute_hessian(backup_params, d_model, supernet_data,
optim_wrapper['architecture'])
w_arch = tuple(self.module.mutator.parameters())
with torch.no_grad():
for param, da, he in zip(w_arch, d_arch, hessian):
# gradient = dalpha - lr * hessian
param.grad = da - lr * he
# restore weights
self._restore_weights(backup_params)
return mutator_log_vars
def _compute_virtual_model(self, supernet_data, lr, momentum, weight_decay,
optim_wrapper):
"""Compute unrolled weights w`"""
# don't need zero_grad, using autograd to calculate gradients
pseudo_data = self.module.data_preprocessor(supernet_data, True)
supernet_batch_inputs = pseudo_data['inputs']
supernet_data_samples = pseudo_data['data_samples']
supernet_loss = self(
supernet_batch_inputs, supernet_data_samples, mode='loss')
supernet_loss, _ = self.module.parse_losses(supernet_loss)
optim_wrapper.backward(supernet_loss)
gradients = [
param.grad for param in self.module.architecture.parameters()
]
with torch.no_grad():
for w, g in zip(self.module.architecture.parameters(), gradients):
m = optim_wrapper.optimizer.state[w].get('momentum_buffer', 0.)
w = w - lr * (momentum * m + g + weight_decay * w)
def _restore_weights(self, backup_params):
"""restore weight from backup params."""
with torch.no_grad():
for param, backup in zip(self.module.architecture.parameters(),
backup_params):
param.copy_(backup)
def _compute_hessian(self, backup_params, dw, supernet_data,
optim_wrapper) -> List:
"""compute hession metric
dw = dw` { L_val(w`, alpha) }
w+ = w + eps * dw
w- = w - eps * dw
hessian = (dalpha { L_trn(w+, alpha) } \
- dalpha { L_trn(w-, alpha) }) / (2*eps)
eps = 0.01 / ||dw||
"""
self._restore_weights(backup_params)
norm = torch.cat([w.view(-1) for w in dw]).norm()
eps = 0.01 / norm
if norm < 1E-8:
print('In computing hessian, norm is smaller than 1E-8 '
f'({norm.item():.6f}), which may cause eps to be very large.')
dalphas = []
for e in [eps, -2. * eps]:
# w+ = w + eps*dw`, w- = w - eps*dw`
with torch.no_grad():
for p, d in zip(self.module.architecture.parameters(), dw):
p += e * d
pseudo_data = self.module.data_preprocessor(supernet_data, True)
supernet_batch_inputs = pseudo_data['inputs']
supernet_data_samples = pseudo_data['data_samples']
supernet_loss = self(
supernet_batch_inputs, supernet_data_samples, mode='loss')
supernet_loss, _ = self.module.parse_losses(supernet_loss)
optim_wrapper.backward(supernet_loss)
dalpha = [param.grad for param in self.module.mutator.parameters()]
dalphas.append(dalpha)
# dalpha { L_trn(w+) }, # dalpha { L_trn(w-) }
dalpha_pos, dalpha_neg = dalphas
hessian = [(p - n) / (2. * eps)
for p, n in zip(dalpha_pos, dalpha_neg)]
return hessian
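
As an aside, the ASCII notes in the `_compute_hessian` docstrings correspond to the standard second-order DARTS approximation. Restated in LaTeX (a sketch of the math the code above implements, with the architecture learning rate `lr` playing the role of \xi):

\begin{aligned}
w' &= w - \xi \nabla_{w} \mathcal{L}_{\mathrm{train}}(w, \alpha), \qquad d = \nabla_{w'} \mathcal{L}_{\mathrm{val}}(w', \alpha) \\
\nabla_{\alpha} \mathcal{L}_{\mathrm{val}}\bigl(w - \xi \nabla_{w} \mathcal{L}_{\mathrm{train}}(w,\alpha),\, \alpha\bigr) &= \nabla_{\alpha} \mathcal{L}_{\mathrm{val}}(w', \alpha) - \xi\, \nabla^{2}_{\alpha, w} \mathcal{L}_{\mathrm{train}}(w, \alpha)\, d \\
\nabla^{2}_{\alpha, w} \mathcal{L}_{\mathrm{train}}(w, \alpha)\, d &\approx \frac{\nabla_{\alpha} \mathcal{L}_{\mathrm{train}}(w^{+}, \alpha) - \nabla_{\alpha} \mathcal{L}_{\mathrm{train}}(w^{-}, \alpha)}{2\epsilon}, \qquad w^{\pm} = w \pm \epsilon d,\ \ \epsilon = \frac{0.01}{\lVert d \rVert}
\end{aligned}

Here `_compute_virtual_model` performs the step to w', the mutator backward yields d (`d_model`) and the direct term (`d_arch`), `_compute_hessian` evaluates the finite difference, and `param.grad = d_arch - lr * hessian` applies the second line.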


@@ -62,7 +62,6 @@ class DartsSubnetClsHead(LinearClsHead):
def loss(self, feats: Tuple[torch.Tensor],
data_samples: List[ClsDataSample], **kwargs) -> dict:
"""Calculate losses from the classification score.
Args:
feats (tuple[Tensor]): The features extracted from the backbone.
Multiple stage inputs are acceptable but only the last stage
@@ -71,7 +70,6 @@ class DartsSubnetClsHead(LinearClsHead):
data_samples (List[ClsDataSample]): The annotation data of
every samples.
**kwargs: Other keyword arguments to forward the loss module.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""


@@ -389,7 +389,7 @@ class DiffChoiceRoute(DiffMutableModule[str, List[str]]):
def __init__(
self,
edges: nn.ModuleDict,
num_chsoen: int = 2,
num_chosen: int = 2,
with_arch_param: bool = False,
alias: Optional[str] = None,
init_cfg: Optional[Dict] = None,
@@ -402,7 +402,7 @@ class DiffChoiceRoute(DiffMutableModule[str, List[str]]):
self._with_arch_param = with_arch_param
self._is_fixed = False
self._candidates: nn.ModuleDict = edges
self.num_chosen = num_chsoen
self.num_chosen = num_chosen
def forward_fixed(self, inputs: Union[List, Tuple]) -> Tensor:
"""Forward when the mutable is in `fixed` mode.


@@ -1 +1,250 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Dict
from unittest import TestCase
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from mmcls.structures import ClsDataSample
from mmengine.model import BaseModel
from mmengine.optim import build_optim_wrapper
from mmengine.optim.optimizer import OptimWrapper, OptimWrapperDict
from torch import Tensor
from torch.optim import SGD
from mmrazor.models import Darts, DiffModuleMutator, DiffMutableOP
from mmrazor.models.algorithms.nas.darts import DartsDDP
from mmrazor.registry import MODELS
MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True)
MODELS.register_module(name='torchMaxPool2d', module=nn.MaxPool2d, force=True)
MODELS.register_module(name='torchAvgPool2d', module=nn.AvgPool2d, force=True)
@MODELS.register_module()
class ToyDiffModule2(BaseModel):
def __init__(self, data_preprocessor=None):
super().__init__(data_preprocessor=data_preprocessor, init_cfg=None)
self.candidates = dict(
torch_conv2d_3x3=dict(
type='torchConv2d',
kernel_size=3,
padding=1,
),
torch_conv2d_5x5=dict(
type='torchConv2d',
kernel_size=5,
padding=2,
),
torch_conv2d_7x7=dict(
type='torchConv2d',
kernel_size=7,
padding=3,
),
)
module_kwargs = dict(
in_channels=3,
out_channels=8,
stride=1,
)
self.mutable = DiffMutableOP(
candidates=self.candidates,
module_kwargs=module_kwargs,
alias='normal')
self.bn = nn.BatchNorm2d(8)
def forward(self, batch_inputs, data_samples=None, mode='tensor'):
if mode == 'loss':
out = self.bn(self.mutable(batch_inputs))
return dict(loss=out)
elif mode == 'predict':
out = self.bn(self.mutable(batch_inputs)) + 1
return out
elif mode == 'tensor':
out = self.bn(self.mutable(batch_inputs)) + 2
return out
class TestDarts(TestCase):
def setUp(self) -> None:
self.device: str = 'cpu'
OPTIMIZER_CFG = dict(
type='SGD',
lr=0.5,
momentum=0.9,
nesterov=True,
weight_decay=0.0001)
self.OPTIM_WRAPPER_CFG = dict(optimizer=OPTIMIZER_CFG)
def test_init(self) -> None:
# initiate darts when `norm_training` is True.
model = ToyDiffModule2()
mutator = DiffModuleMutator()
algo = Darts(architecture=model, mutator=mutator, norm_training=True)
algo.eval()
self.assertTrue(model.bn.training)
# initiate darts with built mutator
model = ToyDiffModule2()
mutator = DiffModuleMutator()
algo = Darts(model, mutator)
self.assertIs(algo.mutator, mutator)
# initiate darts with unbuilt mutator
mutator = dict(type='DiffModuleMutator')
algo = Darts(model, mutator)
self.assertIsInstance(algo.mutator, DiffModuleMutator)
# initiate darts when `fix_subnet` is not None
fix_subnet = {'normal': ['torch_conv2d_3x3', 'torch_conv2d_7x7']}
algo = Darts(model, mutator, fix_subnet=fix_subnet)
self.assertEqual(algo.architecture.mutable.num_choices, 2)
# initiate darts with error type `mutator`
with self.assertRaisesRegex(TypeError, 'mutator should be'):
Darts(model, model)
def test_forward_loss(self) -> None:
inputs = torch.randn(1, 3, 8, 8)
model = ToyDiffModule2()
# supernet
mutator = DiffModuleMutator()
mutator.prepare_from_supernet(model)
algo = Darts(model, mutator)
loss = algo(inputs, mode='loss')
self.assertIsInstance(loss, dict)
# subnet
fix_subnet = {'normal': ['torch_conv2d_3x3', 'torch_conv2d_7x7']}
algo = Darts(model, fix_subnet=fix_subnet)
loss = algo(inputs, mode='loss')
self.assertIsInstance(loss, dict)
def _prepare_fake_data(self) -> Dict:
imgs = torch.randn(16, 3, 224, 224).to(self.device)
data_samples = [
ClsDataSample().set_gt_label(torch.randint(0, 1000,
(16, ))).to(self.device)
]
return {'inputs': imgs, 'data_samples': data_samples}
def test_search_subnet(self) -> None:
model = ToyDiffModule2()
mutator = DiffModuleMutator()
mutator.prepare_from_supernet(model)
algo = Darts(model, mutator)
subnet = algo.search_subnet()
self.assertIsInstance(subnet, dict)
def test_darts_train_step(self) -> None:
model = ToyDiffModule2()
mutator = DiffModuleMutator()
mutator.prepare_from_supernet(model)
# data is tensor
algo = Darts(model, mutator)
data = self._prepare_fake_data()
optim_wrapper = build_optim_wrapper(algo, self.OPTIM_WRAPPER_CFG)
loss = algo.train_step(data, optim_wrapper)
self.assertTrue(isinstance(loss['loss'], Tensor))
# data is tuple or list
algo = Darts(model, mutator)
data = [self._prepare_fake_data() for _ in range(2)]
optim_wrapper_dict = OptimWrapperDict(
architecture=OptimWrapper(SGD(model.parameters(), lr=0.1)),
mutator=OptimWrapper(SGD(model.parameters(), lr=0.01)))
loss = algo.train_step(data, optim_wrapper_dict)
self.assertIsNotNone(loss)
def test_darts_with_unroll(self) -> None:
model = ToyDiffModule2()
mutator = DiffModuleMutator()
mutator.prepare_from_supernet(model)
# data is tuple or list
algo = Darts(model, mutator, unroll=True)
data = [self._prepare_fake_data() for _ in range(2)]
optim_wrapper_dict = OptimWrapperDict(
architecture=OptimWrapper(SGD(model.parameters(), lr=0.1)),
mutator=OptimWrapper(SGD(model.parameters(), lr=0.01)))
loss = algo.train_step(data, optim_wrapper_dict)
self.assertIsNotNone(loss)
class TestDartsDDP(TestDarts):
@classmethod
def setUpClass(cls) -> None:
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12345'
# initialize the process group
backend = 'nccl' if torch.cuda.is_available() else 'gloo'
dist.init_process_group(backend, rank=0, world_size=1)
def prepare_model(self, unroll=False, device_ids=None) -> Darts:
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ToyDiffModule2()
mutator = DiffModuleMutator()
mutator.prepare_from_supernet(model)
algo = Darts(model, mutator, unroll=unroll).to(self.device)
return DartsDDP(
module=algo, find_unused_parameters=True, device_ids=device_ids)
@classmethod
def tearDownClass(cls) -> None:
dist.destroy_process_group()
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='cuda device is not available')
def test_init(self) -> None:
ddp_model = self.prepare_model()
self.assertIsInstance(ddp_model, DartsDDP)
def test_dartsddp_train_step(self) -> None:
# data is tensor
ddp_model = self.prepare_model()
data = self._prepare_fake_data()
optim_wrapper = build_optim_wrapper(ddp_model, self.OPTIM_WRAPPER_CFG)
loss = ddp_model.train_step(data, optim_wrapper)
self.assertIsNotNone(loss)
# data is tuple or list
ddp_model = self.prepare_model()
data = [self._prepare_fake_data() for _ in range(2)]
optim_wrapper_dict = OptimWrapperDict(
architecture=OptimWrapper(SGD(ddp_model.parameters(), lr=0.1)),
mutator=OptimWrapper(SGD(ddp_model.parameters(), lr=0.01)))
loss = ddp_model.train_step(data, optim_wrapper_dict)
self.assertIsNotNone(loss)
def test_dartsddp_with_unroll(self) -> None:
# data is tuple or list
ddp_model = self.prepare_model(unroll=True)
data = [self._prepare_fake_data() for _ in range(2)]
optim_wrapper_dict = OptimWrapperDict(
architecture=OptimWrapper(SGD(ddp_model.parameters(), lr=0.1)),
mutator=OptimWrapper(SGD(ddp_model.parameters(), lr=0.01)))
loss = ddp_model.train_step(data, optim_wrapper_dict)
self.assertIsNotNone(loss)


@@ -71,19 +71,30 @@ def main():
osp.splitext(osp.basename(args.config))[0])
# enable automatic-mixed-precision training
if args.amp is True:
optim_wrapper = cfg.optim_wrapper.type
if optim_wrapper == 'AmpOptimWrapper':
print_log(
'AMP training is already enabled in your config.',
logger='current',
level=logging.WARNING)
else:
assert optim_wrapper == 'OptimWrapper', (
'`--amp` is only supported when the optimizer wrapper type is '
f'`OptimWrapper` but got {optim_wrapper}.')
cfg.optim_wrapper.type = 'AmpOptimWrapper'
cfg.optim_wrapper.loss_scale = 'dynamic'
if args.amp:
if getattr(cfg.optim_wrapper, 'type', None):
optim_wrapper = cfg.optim_wrapper.type
if optim_wrapper == 'AmpOptimWrapper':
print_log(
'AMP training is already enabled in your config.',
logger='current',
level=logging.WARNING)
else:
assert optim_wrapper == 'OptimWrapper', (
'`--amp` is only supported when the optimizer wrapper '
f'type is `OptimWrapper` but got {optim_wrapper}.')
cfg.optim_wrapper.type = 'AmpOptimWrapper'
cfg.optim_wrapper.loss_scale = 'dynamic'
if getattr(cfg.optim_wrapper, 'constructor', None):
if cfg.optim_wrapper.architecture.type == 'OptimWrapper':
cfg.optim_wrapper.architecture.type = 'AmpOptimWrapper'
cfg.optim_wrapper.architecture.loss_scale = 'dynamic'
# TODO: support amp training for mutator
# if cfg.optim_wrapper.mutator.type == 'OptimWrapper':
# cfg.optim_wrapper.mutator.type = 'AmpOptimWrapper'
# cfg.optim_wrapper.mutator.loss_scale = 'dynamic'
# enable automatically scaling LR
if args.auto_scale_lr:
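
Putting the two branches together, passing `--amp` with the constructor-based config at the top of this commit should leave the mutator wrapper untouched and only switch the architecture wrapper; the expected result is roughly the following (a sketch of the intended outcome, not output captured from a run):

optim_wrapper = dict(
    _delete_=True,
    constructor='mmrazor.SeparateOptimWrapperConstructor',
    architecture=dict(
        type='AmpOptimWrapper',  # rewritten from 'OptimWrapper' by --amp
        loss_scale='dynamic',
        optimizer=dict(type='SGD', lr=0.025, momentum=0.9, weight_decay=3e-4),
        clip_grad=dict(max_norm=5, norm_type=2)),
    mutator=dict(
        type='OptimWrapper',  # unchanged; AMP for the mutator is still a TODO
        optimizer=dict(type='Adam', lr=3e-4, weight_decay=1e-3)))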