mirror of https://github.com/open-mmlab/mmcv.git
Refine rfsearch and fix a typo
parent
1f9e5b57c2
commit
c458829763
|
@ -40,6 +40,7 @@ Module
|
|||
NonLocal3d
|
||||
Scale
|
||||
Swish
|
||||
Conv2dRFSearchOp
|
||||
|
||||
Build Function
|
||||
----------------
|
||||
|
|
|
@ -40,6 +40,7 @@ Module
|
|||
NonLocal3d
|
||||
Scale
|
||||
Swish
|
||||
Conv2dRFSearchOp
|
||||
|
||||
Build Function
|
||||
----------------
|
||||
|
|
|
@ -4,14 +4,12 @@ import copy
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.logging import MMLogger
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.model import BaseModule
|
||||
from torch import Tensor
|
||||
|
||||
from .utils import expand_rates, get_single_padding
|
||||
|
||||
logger = MMLogger.get_current_instance()
|
||||
|
||||
|
||||
class BaseConvRFSearchOp(BaseModule):
|
||||
"""Based class of ConvRFSearchOp.
|
||||
|
@ -84,7 +82,7 @@ class Conv2dRFSearchOp(BaseConvRFSearchOp):
|
|||
|
||||
self.branch_weights = nn.Parameter(torch.Tensor(self.num_branches))
|
||||
if self.verbose:
|
||||
logger.info(f'Expand as {self.dilation_rates}')
|
||||
print_log(f'Expand as {self.dilation_rates}', 'current')
|
||||
nn.init.constant_(self.branch_weights, global_config['init_alphas'])
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
|
@ -118,13 +116,14 @@ class Conv2dRFSearchOp(BaseConvRFSearchOp):
|
|||
output += outputs[i]
|
||||
return output
|
||||
|
||||
def estimate_rates(self):
|
||||
def estimate_rates(self) -> None:
|
||||
"""Estimate new dilation rate based on trained branch_weights."""
|
||||
norm_w = self.normlize(self.branch_weights[:len(self.dilation_rates)])
|
||||
if self.verbose:
|
||||
logger.info('Estimate dilation {} with weight {}.'.format(
|
||||
self.dilation_rates,
|
||||
norm_w.detach().cpu().numpy().tolist()))
|
||||
print_log(
|
||||
'Estimate dilation {} with weight {}.'.format(
|
||||
self.dilation_rates,
|
||||
norm_w.detach().cpu().numpy().tolist()), 'current')
|
||||
|
||||
sum0, sum1, w_sum = 0, 0, 0
|
||||
for i in range(len(self.dilation_rates)):
|
||||
|
@ -143,9 +142,9 @@ class Conv2dRFSearchOp(BaseConvRFSearchOp):
|
|||
self.op_layer.padding = self.get_padding(self.op_layer.dilation)
|
||||
self.dilation_rates = [tuple(estimated)]
|
||||
if self.verbose:
|
||||
logger.info(f'Estimate as {tuple(estimated)}')
|
||||
print_log(f'Estimate as {tuple(estimated)}', 'current')
|
||||
|
||||
def expand_rates(self):
|
||||
def expand_rates(self) -> None:
|
||||
"""Expand dilation rate."""
|
||||
dilation = self.op_layer.dilation
|
||||
dilation_rates = expand_rates(dilation, self.global_config)
|
||||
|
@ -158,11 +157,11 @@ class Conv2dRFSearchOp(BaseConvRFSearchOp):
|
|||
|
||||
self.dilation_rates = copy.deepcopy(dilation_rates)
|
||||
if self.verbose:
|
||||
logger.info(f'Expand as {self.dilation_rates}')
|
||||
print_log(f'Expand as {self.dilation_rates}', 'current')
|
||||
nn.init.constant_(self.branch_weights,
|
||||
self.global_config['init_alphas'])
|
||||
|
||||
def get_padding(self, dilation):
|
||||
def get_padding(self, dilation) -> tuple:
|
||||
padding = (get_single_padding(self.op_layer.kernel_size[0],
|
||||
self.op_layer.stride[0], dilation[0]),
|
||||
get_single_padding(self.op_layer.kernel_size[1],
|
||||
|
|
|
@ -3,15 +3,14 @@ import os
|
|||
from typing import Dict, Optional
|
||||
|
||||
import mmengine
|
||||
import torch # noqa
|
||||
import torch.nn as nn
|
||||
from mmengine.hooks import Hook
|
||||
from mmengine.logging import MMLogger
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.registry import HOOKS
|
||||
|
||||
from mmcv.cnn.rfsearch.utils import get_single_padding, write_to_json
|
||||
from .operator import BaseConvRFSearchOp
|
||||
|
||||
logger = MMLogger.get_current_instance()
|
||||
from .operator import BaseConvRFSearchOp, Conv2dRFSearchOp # noqa
|
||||
from .utils import get_single_padding, write_to_json
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
|
@ -82,7 +81,7 @@ class RFSearchHook(Hook):
|
|||
search/fixed_single_branch/fixed_multi_branch
|
||||
"""
|
||||
if self.verbose:
|
||||
logger.info('RFSearch init begin.')
|
||||
print_log('RFSearch init begin.', 'current')
|
||||
if self.mode == 'search':
|
||||
if self.config['structure']:
|
||||
self.set_model(model, search_op='Conv2d')
|
||||
|
@ -95,19 +94,19 @@ class RFSearchHook(Hook):
|
|||
else:
|
||||
raise NotImplementedError
|
||||
if self.verbose:
|
||||
logger.info('RFSearch init end.')
|
||||
print_log('RFSearch init end.', 'current')
|
||||
|
||||
def after_train_epoch(self, runner):
|
||||
"""Performs a dilation searching step after one training epoch."""
|
||||
if self.by_epoch and self.mode == 'search':
|
||||
self.step(runner.model, runner.work_dir)
|
||||
|
||||
def after_train_iter(self, runner):
|
||||
def after_train_iter(self, runner, batch_idx, data_batch, outputs):
|
||||
"""Performs a dilation searching step after one training iteration."""
|
||||
if not self.by_epoch and self.mode == 'search':
|
||||
self.step(runner.model, runner.work_dir)
|
||||
|
||||
def step(self, model: nn.Module, work_dir: str):
|
||||
def step(self, model: nn.Module, work_dir: str) -> None:
|
||||
"""Performs a dilation searching step.
|
||||
|
||||
Args:
|
||||
|
@ -132,7 +131,7 @@ class RFSearchHook(Hook):
|
|||
),
|
||||
)
|
||||
|
||||
def estimate_and_expand(self, model: nn.Module):
|
||||
def estimate_and_expand(self, model: nn.Module) -> None:
|
||||
"""estimate and search for RFConvOp.
|
||||
|
||||
Args:
|
||||
|
@ -146,7 +145,7 @@ class RFSearchHook(Hook):
|
|||
def wrap_model(self,
|
||||
model: nn.Module,
|
||||
search_op: str = 'Conv2d',
|
||||
prefix: str = ''):
|
||||
prefix: str = '') -> None:
|
||||
"""wrap model to support searchable conv op.
|
||||
|
||||
Args:
|
||||
|
@ -176,8 +175,9 @@ class RFSearchHook(Hook):
|
|||
module, self.config['search'], self.verbose)
|
||||
moduleWrap = moduleWrap.to(module.weight.device)
|
||||
if self.verbose:
|
||||
logger.info('Wrap model %s to %s.' %
|
||||
(str(module), str(moduleWrap)))
|
||||
print_log(
|
||||
'Wrap model %s to %s.' %
|
||||
(str(module), str(moduleWrap)), 'current')
|
||||
setattr(model, name, moduleWrap)
|
||||
elif not isinstance(module, BaseConvRFSearchOp):
|
||||
self.wrap_model(module, search_op, fullname)
|
||||
|
@ -186,7 +186,7 @@ class RFSearchHook(Hook):
|
|||
model: nn.Module,
|
||||
search_op: str = 'Conv2d',
|
||||
init_rates: Optional[int] = None,
|
||||
prefix: str = ''):
|
||||
prefix: str = '') -> None:
|
||||
"""set model based on config.
|
||||
|
||||
Args:
|
||||
|
@ -231,8 +231,9 @@ class RFSearchHook(Hook):
|
|||
self.config['structure'][fullname][1]))
|
||||
setattr(model, name, module)
|
||||
if self.verbose:
|
||||
logger.info(
|
||||
print_log(
|
||||
'Set module %s dilation as: [%d %d]' %
|
||||
(fullname, module.dilation[0], module.dilation[1]))
|
||||
(fullname, module.dilation[0], module.dilation[1]),
|
||||
'current')
|
||||
elif not isinstance(module, BaseConvRFSearchOp):
|
||||
self.set_model(module, search_op, init_rates, fullname)
|
||||
|
|
|
@ -440,7 +440,7 @@ def imcrop(
|
|||
img (ndarray): Image to be cropped.
|
||||
bboxes (ndarray): Shape (k, 4) or (4, ), location of cropped bboxes.
|
||||
scale (float, optional): Scale ratio of bboxes, the default value
|
||||
1.0 means no padding.
|
||||
1.0 means no scaling.
|
||||
pad_fill (Number | list[Number]): Value to be filled for padding.
|
||||
Default: None, which means no padding.
|
||||
|
||||
|
|
|
@ -1,17 +1,8 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
"""Tests the rfsearch with runners.
|
||||
|
||||
CommandLine:
|
||||
pytest tests/test_runner/test_hooks.py
|
||||
xdoctest tests/test_hooks.py zero
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from mmcv.cnn.rfsearch import Conv2dRFSearchOp, RFSearchHook
|
||||
from tests.test_runner.test_hooks import _build_demo_runner
|
||||
|
||||
|
||||
def test_rfsearchhook():
|
||||
|
@ -114,20 +105,6 @@ def test_rfsearchhook():
|
|||
assert model.conv2.dilation_rates == [(1, 1), (3, 3)]
|
||||
assert model.conv3.dilation_rates == [(1, 1), (1, 2)]
|
||||
|
||||
# 1. test step() with mode of search
|
||||
loader = DataLoader(torch.ones((1, 1, 1, 1)))
|
||||
runner = _build_demo_runner()
|
||||
runner.model = model
|
||||
runner.register_hook(rfsearchhook_search)
|
||||
runner.run([loader], [('train', 1)])
|
||||
|
||||
test_skip_layer()
|
||||
assert not isinstance(model.conv1, Conv2dRFSearchOp)
|
||||
assert isinstance(model.conv2, Conv2dRFSearchOp)
|
||||
assert isinstance(model.conv3, Conv2dRFSearchOp)
|
||||
assert model.conv2.dilation_rates == [(1, 1), (3, 3)]
|
||||
assert model.conv3.dilation_rates == [(1, 1), (1, 3)]
|
||||
|
||||
# 2. test init_model() with mode of fixed_single_branch
|
||||
model = Model()
|
||||
rfsearchhook_fixed_single_branch.init_model(model)
|
||||
|
@ -139,19 +116,6 @@ def test_rfsearchhook():
|
|||
assert model.conv2.dilation == (2, 2)
|
||||
assert model.conv3.dilation == (1, 1)
|
||||
|
||||
# 2. test step() with mode of fixed_single_branch
|
||||
runner = _build_demo_runner()
|
||||
runner.model = model
|
||||
runner.register_hook(rfsearchhook_fixed_single_branch)
|
||||
runner.run([loader], [('train', 1)])
|
||||
|
||||
assert not isinstance(model.conv1, Conv2dRFSearchOp)
|
||||
assert not isinstance(model.conv2, Conv2dRFSearchOp)
|
||||
assert not isinstance(model.conv3, Conv2dRFSearchOp)
|
||||
assert model.conv1.dilation == (1, 1)
|
||||
assert model.conv2.dilation == (2, 2)
|
||||
assert model.conv3.dilation == (1, 1)
|
||||
|
||||
# 3. test init_model() with mode of fixed_multi_branch
|
||||
model = Model()
|
||||
rfsearchhook_fixed_multi_branch.init_model(model)
|
||||
|
@ -162,16 +126,3 @@ def test_rfsearchhook():
|
|||
assert isinstance(model.conv3, Conv2dRFSearchOp)
|
||||
assert model.conv2.dilation_rates == [(1, 1), (3, 3)]
|
||||
assert model.conv3.dilation_rates == [(1, 1), (1, 2)]
|
||||
|
||||
# 3. test step() with mode of fixed_single_branch
|
||||
runner = _build_demo_runner()
|
||||
runner.model = model
|
||||
runner.register_hook(rfsearchhook_fixed_multi_branch)
|
||||
runner.run([loader], [('train', 1)])
|
||||
|
||||
test_skip_layer()
|
||||
assert not isinstance(model.conv1, Conv2dRFSearchOp)
|
||||
assert isinstance(model.conv2, Conv2dRFSearchOp)
|
||||
assert isinstance(model.conv3, Conv2dRFSearchOp)
|
||||
assert model.conv2.dilation_rates == [(1, 1), (3, 3)]
|
||||
assert model.conv3.dilation_rates == [(1, 1), (1, 2)]
|
||||
|
|
Loading…
Reference in New Issue