[Fix] Fixed evolution search pending (#330)

* fixed evolution search pending

* update mmcv_maximum_version

* remove timm in CI

* remove timm in requirements

* add es unittest

* update mmcv version
pull/325/head
humu789 2022-10-25 17:59:33 +08:00 committed by GitHub
parent 8b57a07b5e
commit 12455f2470
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 82 additions and 9 deletions

View File

@ -75,9 +75,9 @@ jobs:
- name: Install unittest dependencies
run: |
pip install -r requirements.txt
- name: Install timm
run: |
pip install timm
# - name: Install timm
# run: |
# pip install timm
- name: Build and install
run: |
rm -rf .eggs
@ -125,9 +125,9 @@ jobs:
- name: Install unittest dependencies
run: |
pip install -r requirements.txt
- name: Install timm
run: |
pip install timm
# - name: Install timm
# run: |
# pip install timm
- name: Build and install
run: |
pip install -e . -U

View File

@ -48,7 +48,7 @@ def digit_version(version_str: str, length: int = 4):
mmcv_minimum_version = '1.3.13'
mmcv_maximum_version = '1.6.0'
mmcv_maximum_version = '1.8.0'
mmcv_version = digit_version(mmcv.__version__)

View File

@ -56,6 +56,11 @@ class EvolutionSearcher():
if not hasattr(algorithm, 'module'):
raise NotImplementedError('Do not support searching with cpu.')
if num_mutation + num_crossover > candidate_pool_size:
raise ValueError(
f'The sum of num_mutation({num_mutation}) and '
f'num_crossover({num_crossover}) should not be '
f'greater than candidate_pool_size({candidate_pool_size}).')
self.algorithm = algorithm.module
self.algorithm_for_test = algorithm
self.dataloader = dataloader
@ -210,7 +215,6 @@ class EvolutionSearcher():
self.logger.info(
f'Epoch:[{epoch + 1}/{self.max_epoch}], top1_score: '
f'{list(self.top_k_candidates_with_score.keys())[0]}')
broadcast_object_list(self.candidate_pool)
if rank == 0:
final_subnet_dict = list(

View File

@ -1,4 +1,3 @@
albumentations>=0.3.2
mmdet
mmsegmentation
timm

View File

@ -0,0 +1,70 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import shutil
import tempfile
from unittest import TestCase
from unittest.mock import MagicMock, patch
import torch
from torch.utils.data import DataLoader, Dataset
from mmrazor.core.builder import SEARCHERS
def collate_fn(data_batch):
return data_batch
class ToyDataset(Dataset):
METAINFO = dict() # type: ignore
data = torch.randn(12, 2)
label = torch.ones(12)
@property
def metainfo(self):
return self.METAINFO
def __len__(self):
return self.data.size(0)
def __getitem__(self, index):
return dict(inputs=self.data[index], data_sample=self.label[index])
class TestEvolutionSearcher(TestCase):
def setUp(self):
self.work_dir = tempfile.mkdtemp()
self.searcher_cfg = dict(
type='EvolutionSearcher',
metrics='bbox',
score_key='bbox_mAP',
constraints=dict(flops=300 * 1e6),
candidate_pool_size=50,
candidate_top_k=10,
max_epoch=20,
num_mutation=20,
num_crossover=20,
)
self.dataloader = DataLoader(ToyDataset(), collate_fn=collate_fn)
self.test_fn = MagicMock()
self.logger = MagicMock()
def tearDown(self):
shutil.rmtree(self.work_dir)
@patch('mmrazor.models.algorithms.DetNAS')
def test_init(self, mock_algo):
cfg = copy.deepcopy(self.searcher_cfg)
cfg['algorithm'] = mock_algo
cfg['dataloader'] = self.dataloader
cfg['test_fn'] = self.test_fn
cfg['work_dir'] = self.work_dir
cfg['logger'] = self.logger
searcher = SEARCHERS.build(cfg)
assert hasattr(searcher, 'algorithm')
assert hasattr(searcher, 'logger')
cfg['num_mutation'] = 40
with self.assertRaises(ValueError):
searcher = SEARCHERS.build(cfg)