From c458829763e311b8232ec7afa68c8fb9e7cdb1b6 Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Tue, 17 Jan 2023 23:33:13 +0800 Subject: [PATCH] Refine rfsearch and fix a typo --- docs/en/api/cnn.rst | 1 + docs/zh_cn/api/cnn.rst | 1 + mmcv/cnn/rfsearch/operator.py | 23 +++++----- mmcv/cnn/rfsearch/search.py | 33 +++++++------- mmcv/image/geometric.py | 2 +- tests/test_cnn/test_rfsearch/test_search.py | 49 --------------------- 6 files changed, 31 insertions(+), 78 deletions(-) diff --git a/docs/en/api/cnn.rst b/docs/en/api/cnn.rst index 5cbcb191e..022191f17 100644 --- a/docs/en/api/cnn.rst +++ b/docs/en/api/cnn.rst @@ -40,6 +40,7 @@ Module NonLocal3d Scale Swish + Conv2dRFSearchOp Build Function ---------------- diff --git a/docs/zh_cn/api/cnn.rst b/docs/zh_cn/api/cnn.rst index 5cbcb191e..022191f17 100644 --- a/docs/zh_cn/api/cnn.rst +++ b/docs/zh_cn/api/cnn.rst @@ -40,6 +40,7 @@ Module NonLocal3d Scale Swish + Conv2dRFSearchOp Build Function ---------------- diff --git a/mmcv/cnn/rfsearch/operator.py b/mmcv/cnn/rfsearch/operator.py index 3d3416f59..2fa45abb0 100644 --- a/mmcv/cnn/rfsearch/operator.py +++ b/mmcv/cnn/rfsearch/operator.py @@ -4,14 +4,12 @@ import copy import numpy as np import torch import torch.nn as nn -from mmengine.logging import MMLogger +from mmengine.logging import print_log from mmengine.model import BaseModule from torch import Tensor from .utils import expand_rates, get_single_padding -logger = MMLogger.get_current_instance() - class BaseConvRFSearchOp(BaseModule): """Based class of ConvRFSearchOp. @@ -84,7 +82,7 @@ class Conv2dRFSearchOp(BaseConvRFSearchOp): self.branch_weights = nn.Parameter(torch.Tensor(self.num_branches)) if self.verbose: - logger.info(f'Expand as {self.dilation_rates}') + print_log(f'Expand as {self.dilation_rates}', 'current') nn.init.constant_(self.branch_weights, global_config['init_alphas']) def forward(self, input: Tensor) -> Tensor: @@ -118,13 +116,14 @@ class Conv2dRFSearchOp(BaseConvRFSearchOp): output += outputs[i] return output - def estimate_rates(self): + def estimate_rates(self) -> None: """Estimate new dilation rate based on trained branch_weights.""" norm_w = self.normlize(self.branch_weights[:len(self.dilation_rates)]) if self.verbose: - logger.info('Estimate dilation {} with weight {}.'.format( - self.dilation_rates, - norm_w.detach().cpu().numpy().tolist())) + print_log( + 'Estimate dilation {} with weight {}.'.format( + self.dilation_rates, + norm_w.detach().cpu().numpy().tolist()), 'current') sum0, sum1, w_sum = 0, 0, 0 for i in range(len(self.dilation_rates)): @@ -143,9 +142,9 @@ class Conv2dRFSearchOp(BaseConvRFSearchOp): self.op_layer.padding = self.get_padding(self.op_layer.dilation) self.dilation_rates = [tuple(estimated)] if self.verbose: - logger.info(f'Estimate as {tuple(estimated)}') + print_log(f'Estimate as {tuple(estimated)}', 'current') - def expand_rates(self): + def expand_rates(self) -> None: """Expand dilation rate.""" dilation = self.op_layer.dilation dilation_rates = expand_rates(dilation, self.global_config) @@ -158,11 +157,11 @@ class Conv2dRFSearchOp(BaseConvRFSearchOp): self.dilation_rates = copy.deepcopy(dilation_rates) if self.verbose: - logger.info(f'Expand as {self.dilation_rates}') + print_log(f'Expand as {self.dilation_rates}', 'current') nn.init.constant_(self.branch_weights, self.global_config['init_alphas']) - def get_padding(self, dilation): + def get_padding(self, dilation) -> tuple: padding = (get_single_padding(self.op_layer.kernel_size[0], self.op_layer.stride[0], dilation[0]), get_single_padding(self.op_layer.kernel_size[1], diff --git a/mmcv/cnn/rfsearch/search.py b/mmcv/cnn/rfsearch/search.py index d54021a0c..f4add4b23 100644 --- a/mmcv/cnn/rfsearch/search.py +++ b/mmcv/cnn/rfsearch/search.py @@ -3,15 +3,14 @@ import os from typing import Dict, Optional import mmengine +import torch # noqa import torch.nn as nn from mmengine.hooks import Hook -from mmengine.logging import MMLogger +from mmengine.logging import print_log from mmengine.registry import HOOKS -from mmcv.cnn.rfsearch.utils import get_single_padding, write_to_json -from .operator import BaseConvRFSearchOp - -logger = MMLogger.get_current_instance() +from .operator import BaseConvRFSearchOp, Conv2dRFSearchOp # noqa +from .utils import get_single_padding, write_to_json @HOOKS.register_module() @@ -82,7 +81,7 @@ class RFSearchHook(Hook): search/fixed_single_branch/fixed_multi_branch """ if self.verbose: - logger.info('RFSearch init begin.') + print_log('RFSearch init begin.', 'current') if self.mode == 'search': if self.config['structure']: self.set_model(model, search_op='Conv2d') @@ -95,19 +94,19 @@ class RFSearchHook(Hook): else: raise NotImplementedError if self.verbose: - logger.info('RFSearch init end.') + print_log('RFSearch init end.', 'current') def after_train_epoch(self, runner): """Performs a dilation searching step after one training epoch.""" if self.by_epoch and self.mode == 'search': self.step(runner.model, runner.work_dir) - def after_train_iter(self, runner): + def after_train_iter(self, runner, batch_idx, data_batch, outputs): """Performs a dilation searching step after one training iteration.""" if not self.by_epoch and self.mode == 'search': self.step(runner.model, runner.work_dir) - def step(self, model: nn.Module, work_dir: str): + def step(self, model: nn.Module, work_dir: str) -> None: """Performs a dilation searching step. Args: @@ -132,7 +131,7 @@ class RFSearchHook(Hook): ), ) - def estimate_and_expand(self, model: nn.Module): + def estimate_and_expand(self, model: nn.Module) -> None: """estimate and search for RFConvOp. Args: @@ -146,7 +145,7 @@ class RFSearchHook(Hook): def wrap_model(self, model: nn.Module, search_op: str = 'Conv2d', - prefix: str = ''): + prefix: str = '') -> None: """wrap model to support searchable conv op. Args: @@ -176,8 +175,9 @@ class RFSearchHook(Hook): module, self.config['search'], self.verbose) moduleWrap = moduleWrap.to(module.weight.device) if self.verbose: - logger.info('Wrap model %s to %s.' % - (str(module), str(moduleWrap))) + print_log( + 'Wrap model %s to %s.' % + (str(module), str(moduleWrap)), 'current') setattr(model, name, moduleWrap) elif not isinstance(module, BaseConvRFSearchOp): self.wrap_model(module, search_op, fullname) @@ -186,7 +186,7 @@ class RFSearchHook(Hook): model: nn.Module, search_op: str = 'Conv2d', init_rates: Optional[int] = None, - prefix: str = ''): + prefix: str = '') -> None: """set model based on config. Args: @@ -231,8 +231,9 @@ class RFSearchHook(Hook): self.config['structure'][fullname][1])) setattr(model, name, module) if self.verbose: - logger.info( + print_log( 'Set module %s dilation as: [%d %d]' % - (fullname, module.dilation[0], module.dilation[1])) + (fullname, module.dilation[0], module.dilation[1]), + 'current') elif not isinstance(module, BaseConvRFSearchOp): self.set_model(module, search_op, init_rates, fullname) diff --git a/mmcv/image/geometric.py b/mmcv/image/geometric.py index 88fb63693..f35299bf9 100644 --- a/mmcv/image/geometric.py +++ b/mmcv/image/geometric.py @@ -440,7 +440,7 @@ def imcrop( img (ndarray): Image to be cropped. bboxes (ndarray): Shape (k, 4) or (4, ), location of cropped bboxes. scale (float, optional): Scale ratio of bboxes, the default value - 1.0 means no padding. + 1.0 means no scaling. pad_fill (Number | list[Number]): Value to be filled for padding. Default: None, which means no padding. diff --git a/tests/test_cnn/test_rfsearch/test_search.py b/tests/test_cnn/test_rfsearch/test_search.py index 182134981..5614e3c1c 100644 --- a/tests/test_cnn/test_rfsearch/test_search.py +++ b/tests/test_cnn/test_rfsearch/test_search.py @@ -1,17 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -"""Tests the rfsearch with runners. -CommandLine: - pytest tests/test_runner/test_hooks.py - xdoctest tests/test_hooks.py zero -""" - -import torch import torch.nn as nn -from torch.utils.data import DataLoader from mmcv.cnn.rfsearch import Conv2dRFSearchOp, RFSearchHook -from tests.test_runner.test_hooks import _build_demo_runner def test_rfsearchhook(): @@ -114,20 +105,6 @@ def test_rfsearchhook(): assert model.conv2.dilation_rates == [(1, 1), (3, 3)] assert model.conv3.dilation_rates == [(1, 1), (1, 2)] - # 1. test step() with mode of search - loader = DataLoader(torch.ones((1, 1, 1, 1))) - runner = _build_demo_runner() - runner.model = model - runner.register_hook(rfsearchhook_search) - runner.run([loader], [('train', 1)]) - - test_skip_layer() - assert not isinstance(model.conv1, Conv2dRFSearchOp) - assert isinstance(model.conv2, Conv2dRFSearchOp) - assert isinstance(model.conv3, Conv2dRFSearchOp) - assert model.conv2.dilation_rates == [(1, 1), (3, 3)] - assert model.conv3.dilation_rates == [(1, 1), (1, 3)] - # 2. test init_model() with mode of fixed_single_branch model = Model() rfsearchhook_fixed_single_branch.init_model(model) @@ -139,19 +116,6 @@ def test_rfsearchhook(): assert model.conv2.dilation == (2, 2) assert model.conv3.dilation == (1, 1) - # 2. test step() with mode of fixed_single_branch - runner = _build_demo_runner() - runner.model = model - runner.register_hook(rfsearchhook_fixed_single_branch) - runner.run([loader], [('train', 1)]) - - assert not isinstance(model.conv1, Conv2dRFSearchOp) - assert not isinstance(model.conv2, Conv2dRFSearchOp) - assert not isinstance(model.conv3, Conv2dRFSearchOp) - assert model.conv1.dilation == (1, 1) - assert model.conv2.dilation == (2, 2) - assert model.conv3.dilation == (1, 1) - # 3. test init_model() with mode of fixed_multi_branch model = Model() rfsearchhook_fixed_multi_branch.init_model(model) @@ -162,16 +126,3 @@ def test_rfsearchhook(): assert isinstance(model.conv3, Conv2dRFSearchOp) assert model.conv2.dilation_rates == [(1, 1), (3, 3)] assert model.conv3.dilation_rates == [(1, 1), (1, 2)] - - # 3. test step() with mode of fixed_single_branch - runner = _build_demo_runner() - runner.model = model - runner.register_hook(rfsearchhook_fixed_multi_branch) - runner.run([loader], [('train', 1)]) - - test_skip_layer() - assert not isinstance(model.conv1, Conv2dRFSearchOp) - assert isinstance(model.conv2, Conv2dRFSearchOp) - assert isinstance(model.conv3, Conv2dRFSearchOp) - assert model.conv2.dilation_rates == [(1, 1), (3, 3)] - assert model.conv3.dilation_rates == [(1, 1), (1, 2)]