Refine rfsearch and fix a typo

pull/2695/head
Mashiro 2023-01-17 23:33:13 +08:00 committed by Zaida Zhou
parent 1f9e5b57c2
commit c458829763
6 changed files with 31 additions and 78 deletions

View File

@ -40,6 +40,7 @@ Module
NonLocal3d
Scale
Swish
Conv2dRFSearchOp
Build Function
----------------

View File

@ -40,6 +40,7 @@ Module
NonLocal3d
Scale
Swish
Conv2dRFSearchOp
Build Function
----------------

View File

@ -4,14 +4,12 @@ import copy
import numpy as np
import torch
import torch.nn as nn
from mmengine.logging import MMLogger
from mmengine.logging import print_log
from mmengine.model import BaseModule
from torch import Tensor
from .utils import expand_rates, get_single_padding
logger = MMLogger.get_current_instance()
class BaseConvRFSearchOp(BaseModule):
"""Based class of ConvRFSearchOp.
@ -84,7 +82,7 @@ class Conv2dRFSearchOp(BaseConvRFSearchOp):
self.branch_weights = nn.Parameter(torch.Tensor(self.num_branches))
if self.verbose:
logger.info(f'Expand as {self.dilation_rates}')
print_log(f'Expand as {self.dilation_rates}', 'current')
nn.init.constant_(self.branch_weights, global_config['init_alphas'])
def forward(self, input: Tensor) -> Tensor:
@ -118,13 +116,14 @@ class Conv2dRFSearchOp(BaseConvRFSearchOp):
output += outputs[i]
return output
def estimate_rates(self):
def estimate_rates(self) -> None:
"""Estimate new dilation rate based on trained branch_weights."""
norm_w = self.normlize(self.branch_weights[:len(self.dilation_rates)])
if self.verbose:
logger.info('Estimate dilation {} with weight {}.'.format(
self.dilation_rates,
norm_w.detach().cpu().numpy().tolist()))
print_log(
'Estimate dilation {} with weight {}.'.format(
self.dilation_rates,
norm_w.detach().cpu().numpy().tolist()), 'current')
sum0, sum1, w_sum = 0, 0, 0
for i in range(len(self.dilation_rates)):
@ -143,9 +142,9 @@ class Conv2dRFSearchOp(BaseConvRFSearchOp):
self.op_layer.padding = self.get_padding(self.op_layer.dilation)
self.dilation_rates = [tuple(estimated)]
if self.verbose:
logger.info(f'Estimate as {tuple(estimated)}')
print_log(f'Estimate as {tuple(estimated)}', 'current')
def expand_rates(self):
def expand_rates(self) -> None:
"""Expand dilation rate."""
dilation = self.op_layer.dilation
dilation_rates = expand_rates(dilation, self.global_config)
@ -158,11 +157,11 @@ class Conv2dRFSearchOp(BaseConvRFSearchOp):
self.dilation_rates = copy.deepcopy(dilation_rates)
if self.verbose:
logger.info(f'Expand as {self.dilation_rates}')
print_log(f'Expand as {self.dilation_rates}', 'current')
nn.init.constant_(self.branch_weights,
self.global_config['init_alphas'])
def get_padding(self, dilation):
def get_padding(self, dilation) -> tuple:
padding = (get_single_padding(self.op_layer.kernel_size[0],
self.op_layer.stride[0], dilation[0]),
get_single_padding(self.op_layer.kernel_size[1],

View File

@ -3,15 +3,14 @@ import os
from typing import Dict, Optional
import mmengine
import torch # noqa
import torch.nn as nn
from mmengine.hooks import Hook
from mmengine.logging import MMLogger
from mmengine.logging import print_log
from mmengine.registry import HOOKS
from mmcv.cnn.rfsearch.utils import get_single_padding, write_to_json
from .operator import BaseConvRFSearchOp
logger = MMLogger.get_current_instance()
from .operator import BaseConvRFSearchOp, Conv2dRFSearchOp # noqa
from .utils import get_single_padding, write_to_json
@HOOKS.register_module()
@ -82,7 +81,7 @@ class RFSearchHook(Hook):
search/fixed_single_branch/fixed_multi_branch
"""
if self.verbose:
logger.info('RFSearch init begin.')
print_log('RFSearch init begin.', 'current')
if self.mode == 'search':
if self.config['structure']:
self.set_model(model, search_op='Conv2d')
@ -95,19 +94,19 @@ class RFSearchHook(Hook):
else:
raise NotImplementedError
if self.verbose:
logger.info('RFSearch init end.')
print_log('RFSearch init end.', 'current')
def after_train_epoch(self, runner):
"""Performs a dilation searching step after one training epoch."""
if self.by_epoch and self.mode == 'search':
self.step(runner.model, runner.work_dir)
def after_train_iter(self, runner):
def after_train_iter(self, runner, batch_idx, data_batch, outputs):
"""Performs a dilation searching step after one training iteration."""
if not self.by_epoch and self.mode == 'search':
self.step(runner.model, runner.work_dir)
def step(self, model: nn.Module, work_dir: str):
def step(self, model: nn.Module, work_dir: str) -> None:
"""Performs a dilation searching step.
Args:
@ -132,7 +131,7 @@ class RFSearchHook(Hook):
),
)
def estimate_and_expand(self, model: nn.Module):
def estimate_and_expand(self, model: nn.Module) -> None:
"""estimate and search for RFConvOp.
Args:
@ -146,7 +145,7 @@ class RFSearchHook(Hook):
def wrap_model(self,
model: nn.Module,
search_op: str = 'Conv2d',
prefix: str = ''):
prefix: str = '') -> None:
"""wrap model to support searchable conv op.
Args:
@ -176,8 +175,9 @@ class RFSearchHook(Hook):
module, self.config['search'], self.verbose)
moduleWrap = moduleWrap.to(module.weight.device)
if self.verbose:
logger.info('Wrap model %s to %s.' %
(str(module), str(moduleWrap)))
print_log(
'Wrap model %s to %s.' %
(str(module), str(moduleWrap)), 'current')
setattr(model, name, moduleWrap)
elif not isinstance(module, BaseConvRFSearchOp):
self.wrap_model(module, search_op, fullname)
@ -186,7 +186,7 @@ class RFSearchHook(Hook):
model: nn.Module,
search_op: str = 'Conv2d',
init_rates: Optional[int] = None,
prefix: str = ''):
prefix: str = '') -> None:
"""set model based on config.
Args:
@ -231,8 +231,9 @@ class RFSearchHook(Hook):
self.config['structure'][fullname][1]))
setattr(model, name, module)
if self.verbose:
logger.info(
print_log(
'Set module %s dilation as: [%d %d]' %
(fullname, module.dilation[0], module.dilation[1]))
(fullname, module.dilation[0], module.dilation[1]),
'current')
elif not isinstance(module, BaseConvRFSearchOp):
self.set_model(module, search_op, init_rates, fullname)

View File

@ -440,7 +440,7 @@ def imcrop(
img (ndarray): Image to be cropped.
bboxes (ndarray): Shape (k, 4) or (4, ), location of cropped bboxes.
scale (float, optional): Scale ratio of bboxes, the default value
1.0 means no padding.
1.0 means no scaling.
pad_fill (Number | list[Number]): Value to be filled for padding.
Default: None, which means no padding.

View File

@ -1,17 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Tests the rfsearch with runners.
CommandLine:
pytest tests/test_runner/test_hooks.py
xdoctest tests/test_hooks.py zero
"""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from mmcv.cnn.rfsearch import Conv2dRFSearchOp, RFSearchHook
from tests.test_runner.test_hooks import _build_demo_runner
def test_rfsearchhook():
@ -114,20 +105,6 @@ def test_rfsearchhook():
assert model.conv2.dilation_rates == [(1, 1), (3, 3)]
assert model.conv3.dilation_rates == [(1, 1), (1, 2)]
# 1. test step() with mode of search
loader = DataLoader(torch.ones((1, 1, 1, 1)))
runner = _build_demo_runner()
runner.model = model
runner.register_hook(rfsearchhook_search)
runner.run([loader], [('train', 1)])
test_skip_layer()
assert not isinstance(model.conv1, Conv2dRFSearchOp)
assert isinstance(model.conv2, Conv2dRFSearchOp)
assert isinstance(model.conv3, Conv2dRFSearchOp)
assert model.conv2.dilation_rates == [(1, 1), (3, 3)]
assert model.conv3.dilation_rates == [(1, 1), (1, 3)]
# 2. test init_model() with mode of fixed_single_branch
model = Model()
rfsearchhook_fixed_single_branch.init_model(model)
@ -139,19 +116,6 @@ def test_rfsearchhook():
assert model.conv2.dilation == (2, 2)
assert model.conv3.dilation == (1, 1)
# 2. test step() with mode of fixed_single_branch
runner = _build_demo_runner()
runner.model = model
runner.register_hook(rfsearchhook_fixed_single_branch)
runner.run([loader], [('train', 1)])
assert not isinstance(model.conv1, Conv2dRFSearchOp)
assert not isinstance(model.conv2, Conv2dRFSearchOp)
assert not isinstance(model.conv3, Conv2dRFSearchOp)
assert model.conv1.dilation == (1, 1)
assert model.conv2.dilation == (2, 2)
assert model.conv3.dilation == (1, 1)
# 3. test init_model() with mode of fixed_multi_branch
model = Model()
rfsearchhook_fixed_multi_branch.init_model(model)
@ -162,16 +126,3 @@ def test_rfsearchhook():
assert isinstance(model.conv3, Conv2dRFSearchOp)
assert model.conv2.dilation_rates == [(1, 1), (3, 3)]
assert model.conv3.dilation_rates == [(1, 1), (1, 2)]
# 3. test step() with mode of fixed_single_branch
runner = _build_demo_runner()
runner.model = model
runner.register_hook(rfsearchhook_fixed_multi_branch)
runner.run([loader], [('train', 1)])
test_skip_layer()
assert not isinstance(model.conv1, Conv2dRFSearchOp)
assert isinstance(model.conv2, Conv2dRFSearchOp)
assert isinstance(model.conv3, Conv2dRFSearchOp)
assert model.conv2.dilation_rates == [(1, 1), (3, 3)]
assert model.conv3.dilation_rates == [(1, 1), (1, 2)]