[Fix] Fixed evolution search pending (#330)
* fixed evolution search pending * update mmcv_maximum_version * remove timm in CI * remove timm in requirements * add es unittest * update mmcv versionpull/325/head
parent
8b57a07b5e
commit
12455f2470
|
@ -75,9 +75,9 @@ jobs:
|
|||
- name: Install unittest dependencies
|
||||
run: |
|
||||
pip install -r requirements.txt
|
||||
- name: Install timm
|
||||
run: |
|
||||
pip install timm
|
||||
# - name: Install timm
|
||||
# run: |
|
||||
# pip install timm
|
||||
- name: Build and install
|
||||
run: |
|
||||
rm -rf .eggs
|
||||
|
@ -125,9 +125,9 @@ jobs:
|
|||
- name: Install unittest dependencies
|
||||
run: |
|
||||
pip install -r requirements.txt
|
||||
- name: Install timm
|
||||
run: |
|
||||
pip install timm
|
||||
# - name: Install timm
|
||||
# run: |
|
||||
# pip install timm
|
||||
- name: Build and install
|
||||
run: |
|
||||
pip install -e . -U
|
||||
|
|
|
@ -48,7 +48,7 @@ def digit_version(version_str: str, length: int = 4):
|
|||
|
||||
|
||||
mmcv_minimum_version = '1.3.13'
|
||||
mmcv_maximum_version = '1.6.0'
|
||||
mmcv_maximum_version = '1.8.0'
|
||||
mmcv_version = digit_version(mmcv.__version__)
|
||||
|
||||
|
||||
|
|
|
@ -56,6 +56,11 @@ class EvolutionSearcher():
|
|||
|
||||
if not hasattr(algorithm, 'module'):
|
||||
raise NotImplementedError('Do not support searching with cpu.')
|
||||
if num_mutation + num_crossover > candidate_pool_size:
|
||||
raise ValueError(
|
||||
f'The sum of num_mutation({num_mutation}) and '
|
||||
f'num_crossover({num_crossover}) should not be '
|
||||
f'greater than candidate_pool_size({candidate_pool_size}).')
|
||||
self.algorithm = algorithm.module
|
||||
self.algorithm_for_test = algorithm
|
||||
self.dataloader = dataloader
|
||||
|
@ -210,7 +215,6 @@ class EvolutionSearcher():
|
|||
self.logger.info(
|
||||
f'Epoch:[{epoch + 1}/{self.max_epoch}], top1_score: '
|
||||
f'{list(self.top_k_candidates_with_score.keys())[0]}')
|
||||
broadcast_object_list(self.candidate_pool)
|
||||
|
||||
if rank == 0:
|
||||
final_subnet_dict = list(
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
albumentations>=0.3.2
|
||||
mmdet
|
||||
mmsegmentation
|
||||
timm
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import shutil
|
||||
import tempfile
|
||||
from unittest import TestCase
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from mmrazor.core.builder import SEARCHERS
|
||||
|
||||
|
||||
def collate_fn(data_batch):
|
||||
return data_batch
|
||||
|
||||
|
||||
class ToyDataset(Dataset):
|
||||
METAINFO = dict() # type: ignore
|
||||
data = torch.randn(12, 2)
|
||||
label = torch.ones(12)
|
||||
|
||||
@property
|
||||
def metainfo(self):
|
||||
return self.METAINFO
|
||||
|
||||
def __len__(self):
|
||||
return self.data.size(0)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return dict(inputs=self.data[index], data_sample=self.label[index])
|
||||
|
||||
|
||||
class TestEvolutionSearcher(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.work_dir = tempfile.mkdtemp()
|
||||
self.searcher_cfg = dict(
|
||||
type='EvolutionSearcher',
|
||||
metrics='bbox',
|
||||
score_key='bbox_mAP',
|
||||
constraints=dict(flops=300 * 1e6),
|
||||
candidate_pool_size=50,
|
||||
candidate_top_k=10,
|
||||
max_epoch=20,
|
||||
num_mutation=20,
|
||||
num_crossover=20,
|
||||
)
|
||||
self.dataloader = DataLoader(ToyDataset(), collate_fn=collate_fn)
|
||||
self.test_fn = MagicMock()
|
||||
self.logger = MagicMock()
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.work_dir)
|
||||
|
||||
@patch('mmrazor.models.algorithms.DetNAS')
|
||||
def test_init(self, mock_algo):
|
||||
cfg = copy.deepcopy(self.searcher_cfg)
|
||||
cfg['algorithm'] = mock_algo
|
||||
cfg['dataloader'] = self.dataloader
|
||||
cfg['test_fn'] = self.test_fn
|
||||
cfg['work_dir'] = self.work_dir
|
||||
cfg['logger'] = self.logger
|
||||
searcher = SEARCHERS.build(cfg)
|
||||
assert hasattr(searcher, 'algorithm')
|
||||
assert hasattr(searcher, 'logger')
|
||||
|
||||
cfg['num_mutation'] = 40
|
||||
with self.assertRaises(ValueError):
|
||||
searcher = SEARCHERS.build(cfg)
|
Loading…
Reference in New Issue