fix ci (#284)
* fix ci for circle ci * fix bug in test_metafiles * add pr_stage_test for github ci * add multiple version * fix ut * fix lint * Temporarily skip dataset UT * update github ci * add github lint ci * install wheel * remove timm from requirements * install wheel when test on windows * fix error * fix bug * remove github windows ci * fix device error of arch_params when DsnasDDP * fix CRD dataset ut * fix scope error * rm test_cuda in workflows of github * [Doc] fix typos in en/usr_guides Co-authored-by: liukai <liukai@pjlab.org.cn> Co-authored-by: pppppM <gjf_mail@126.com> Co-authored-by: gaoyang07 <1546308416@qq.com> Co-authored-by: huangpengsheng <huangpengsheng@sensetime.com> Co-authored-by: SheffieldCao <1751899@tongji.edu.cn>pull/311/head^2
parent
6cd8c68d0f
commit
f98ac3416b
.circleci
configs/nas/mmcls/darts
docs/en/user_guides
mmrazor
datasets
models/task_modules/tracer
requirements
tests
test_datasets
test_transforms
test_models
test_algorithms
test_losses
test_mutators/test_classical_models
|
@ -26,7 +26,6 @@ jobs:
|
|||
command: |
|
||||
pip install interrogate
|
||||
interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-magic --ignore-regex "__repr__" --fail-under 80 mmrazor
|
||||
|
||||
build_cpu:
|
||||
parameters:
|
||||
# The python version must match available image tags in
|
||||
|
@ -37,8 +36,6 @@ jobs:
|
|||
type: string
|
||||
torchvision:
|
||||
type: string
|
||||
mmcv:
|
||||
type: string
|
||||
docker:
|
||||
- image: cimg/python:<< parameters.python >>
|
||||
resource_class: large
|
||||
|
@ -58,20 +55,21 @@ jobs:
|
|||
name: Install PyTorch
|
||||
command: |
|
||||
python -V
|
||||
python -m pip install torch==<< parameters.torch >>+cpu torchvision==<< parameters.torchvision >>+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
pip install torch==<< parameters.torch >>+cpu torchvision==<< parameters.torchvision >>+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
- when:
|
||||
condition:
|
||||
equal: [ "3.9.0", << parameters.python >> ]
|
||||
equal: ["3.9.0", << parameters.python >>]
|
||||
steps:
|
||||
- run: pip install "protobuf <= 3.20.1" && sudo apt-get update && sudo apt-get -y install libprotobuf-dev protobuf-compiler cmake
|
||||
- run:
|
||||
name: Install mmrazor dependencies
|
||||
command: |
|
||||
python -m pip install git+ssh://git@github.com/open-mmlab/mmengine.git@main
|
||||
python -m pip install << parameters.mmcv >>
|
||||
python -m pip install git+ssh://git@github.com/open-mmlab/mmclassification.git@dev-1.x
|
||||
python -m pip install git+ssh://git@github.com/open-mmlab/mmdetection.git@dev-3.x
|
||||
python -m pip install git+ssh://git@github.com/open-mmlab/mmsegmentation.git@dev-1.x
|
||||
pip install git+https://github.com/open-mmlab/mmengine.git@main
|
||||
pip install -U openmim
|
||||
mim install 'mmcv >= 2.0.0rc1'
|
||||
pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x
|
||||
pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x
|
||||
pip install git+https://github.com/open-mmlab/mmsegmentation.git@dev-1.x
|
||||
pip install -r requirements.txt
|
||||
- run:
|
||||
name: Build and install
|
||||
|
@ -80,10 +78,9 @@ jobs:
|
|||
- run:
|
||||
name: Run unittests
|
||||
command: |
|
||||
python -m coverage run --branch --source mmrazor -m pytest tests/
|
||||
python -m coverage xml
|
||||
python -m coverage report -m
|
||||
|
||||
coverage run --branch --source mmrazor -m pytest tests/
|
||||
coverage xml
|
||||
coverage report -m
|
||||
build_cuda:
|
||||
parameters:
|
||||
torch:
|
||||
|
@ -94,8 +91,6 @@ jobs:
|
|||
cudnn:
|
||||
type: integer
|
||||
default: 7
|
||||
mmcv:
|
||||
type: string
|
||||
machine:
|
||||
image: ubuntu-2004-cuda-11.4:202110-01
|
||||
# docker_layer_caching: true
|
||||
|
@ -103,13 +98,13 @@ jobs:
|
|||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
# CLoning repos in VM since Docker doesn't have access to the private key
|
||||
# Cloning repos in VM since Docker doesn't have access to the private key
|
||||
name: Clone Repos
|
||||
command: |
|
||||
git clone -b main --depth 1 ssh://git@github.com/open-mmlab/mmengine.git /home/circleci/mmengine
|
||||
git clone -b dev-3.x --depth 1 ssh://git@github.com/open-mmlab/mmdetection.git /home/circleci/mmdetection
|
||||
git clone -b dev-1.x --depth 1 ssh://git@github.com/open-mmlab/mmclassification.git /home/circleci/mmclassification
|
||||
git clone -b dev-1.x --depth 1 ssh://git@github.com/open-mmlab/mmsegmentation.git /home/circleci/mmsegmentation
|
||||
git clone -b main --depth 1 https://github.com/open-mmlab/mmengine.git /home/circleci/mmengine
|
||||
git clone -b dev-3.x --depth 1 https://github.com/open-mmlab/mmdetection.git /home/circleci/mmdetection
|
||||
git clone -b dev-1.x --depth 1 https://github.com/open-mmlab/mmclassification.git /home/circleci/mmclassification
|
||||
git clone -b dev-1.x --depth 1 https://github.com/open-mmlab/mmsegmentation.git /home/circleci/mmsegmentation
|
||||
- run:
|
||||
name: Build Docker image
|
||||
command: |
|
||||
|
@ -117,10 +112,10 @@ jobs:
|
|||
docker run --gpus all -t -d -v /home/circleci/project:/mmrazor -v /home/circleci/mmengine:/mmengine -v /home/circleci/mmdetection:/mmdetection -v /home/circleci/mmclassification:/mmclassification -v /home/circleci/mmsegmentation:/mmsegmentation -w /mmrazor --name mmrazor mmrazor:gpu
|
||||
- run:
|
||||
name: Install mmrazor dependencies
|
||||
# pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu101/torch${{matrix.torch_version}}/index.html
|
||||
command: |
|
||||
docker exec mmrazor pip install -e /mmengine
|
||||
docker exec mmrazor pip install << parameters.mmcv >>
|
||||
docker exec mmrazor pip install -U openmim
|
||||
docker exec mmrazor mim install 'mmcv >= 2.0.0rc1'
|
||||
docker exec mmrazor pip install -e /mmdetection
|
||||
docker exec mmrazor pip install -e /mmclassification
|
||||
docker exec mmrazor pip install -e /mmsegmentation
|
||||
|
@ -132,7 +127,7 @@ jobs:
|
|||
- run:
|
||||
name: Run unittests
|
||||
command: |
|
||||
docker exec mmrazor python -m pytest tests/
|
||||
docker exec mmrazor pytest tests/
|
||||
|
||||
workflows:
|
||||
pr_stage_lint:
|
||||
|
@ -144,10 +139,10 @@ workflows:
|
|||
branches:
|
||||
ignore:
|
||||
- dev-1.x
|
||||
- 1.x
|
||||
pr_stage_test:
|
||||
when:
|
||||
not:
|
||||
<< pipeline.parameters.lint_only >>
|
||||
not: << pipeline.parameters.lint_only >>
|
||||
jobs:
|
||||
- lint:
|
||||
name: lint
|
||||
|
@ -159,16 +154,14 @@ workflows:
|
|||
name: minimum_version_cpu
|
||||
torch: 1.6.0
|
||||
torchvision: 0.7.0
|
||||
python: 3.6.9 # The lowest python 3.6.x version available on CircleCI images
|
||||
mmcv: https://download.openmmlab.com/mmcv/dev-2.x/cpu/torch1.6.0/mmcv_full-2.0.0rc1-cp36-cp36m-manylinux1_x86_64.whl
|
||||
python: 3.6.9 # The lowest python 3.6.x version available on CircleCI images
|
||||
requires:
|
||||
- lint
|
||||
- build_cpu:
|
||||
name: maximum_version_cpu
|
||||
torch: 1.9.0
|
||||
torchvision: 0.10.0
|
||||
torch: 1.12.1
|
||||
torchvision: 0.13.1
|
||||
python: 3.9.0
|
||||
mmcv: https://download.openmmlab.com/mmcv/dev-2.x/cpu/torch1.9.0/mmcv_full-2.0.0rc1-cp39-cp39-manylinux1_x86_64.whl
|
||||
requires:
|
||||
- minimum_version_cpu
|
||||
- hold:
|
||||
|
@ -181,20 +174,17 @@ workflows:
|
|||
# Use double quotation mark to explicitly specify its type
|
||||
# as string instead of number
|
||||
cuda: "10.2"
|
||||
mmcv: https://download.openmmlab.com/mmcv/dev-2.x/cu102/torch1.8.0/mmcv_full-2.0.0rc1-cp37-cp37m-manylinux1_x86_64.whl
|
||||
requires:
|
||||
- hold
|
||||
merge_stage_test:
|
||||
when:
|
||||
not:
|
||||
<< pipeline.parameters.lint_only >>
|
||||
not: << pipeline.parameters.lint_only >>
|
||||
jobs:
|
||||
- build_cuda:
|
||||
name: minimum_version_gpu
|
||||
torch: 1.6.0
|
||||
# Use double quotation mark to explicitly specify its type
|
||||
# as string instead of number
|
||||
mmcv: https://download.openmmlab.com/mmcv/dev-2.x/cu101/torch1.6.0/mmcv_full-2.0.0rc1-cp37-cp37m-manylinux1_x86_64.whl
|
||||
cuda: "10.1"
|
||||
filters:
|
||||
branches:
|
||||
|
|
|
@ -0,0 +1,159 @@
|
|||
name: build
|
||||
|
||||
on:
|
||||
push:
|
||||
paths-ignore:
|
||||
- "README.md"
|
||||
- "README_zh-CN.md"
|
||||
- "model-index.yml"
|
||||
- "configs/**"
|
||||
- "docs/**"
|
||||
- ".dev_scripts/**"
|
||||
|
||||
pull_request:
|
||||
paths-ignore:
|
||||
- "README.md"
|
||||
- "README_zh-CN.md"
|
||||
- "docs/**"
|
||||
- "demo/**"
|
||||
- ".dev_scripts/**"
|
||||
- ".circleci/**"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test_linux:
|
||||
runs-on: ubuntu-18.04
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.7]
|
||||
torch: [1.6.0, 1.7.0, 1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0]
|
||||
include:
|
||||
- torch: 1.6.0
|
||||
torch_version: 1.6
|
||||
torchvision: 0.7.0
|
||||
- torch: 1.7.0
|
||||
torch_version: 1.7
|
||||
torchvision: 0.8.1
|
||||
- torch: 1.7.0
|
||||
torch_version: 1.7
|
||||
torchvision: 0.8.1
|
||||
python-version: 3.8
|
||||
- torch: 1.8.0
|
||||
torch_version: 1.8
|
||||
torchvision: 0.9.0
|
||||
- torch: 1.8.0
|
||||
torch_version: 1.8
|
||||
torchvision: 0.9.0
|
||||
python-version: 3.8
|
||||
- torch: 1.9.0
|
||||
torch_version: 1.9
|
||||
torchvision: 0.10.0
|
||||
- torch: 1.9.0
|
||||
torch_version: 1.9
|
||||
torchvision: 0.10.0
|
||||
python-version: 3.8
|
||||
- torch: 1.10.0
|
||||
torch_version: 1.10
|
||||
torchvision: 0.11.0
|
||||
- torch: 1.10.0
|
||||
torch_version: 1.10
|
||||
torchvision: 0.11.0
|
||||
python-version: 3.8
|
||||
- torch: 1.11.0
|
||||
torch_version: 1.11
|
||||
torchvision: 0.12.0
|
||||
- torch: 1.11.0
|
||||
torch_version: 1.11
|
||||
torchvision: 0.12.0
|
||||
python-version: 3.8
|
||||
- torch: 1.12.0
|
||||
torch_version: 1.12
|
||||
torchvision: 0.13.0
|
||||
- torch: 1.12.0
|
||||
torch_version: 1.12
|
||||
torchvision: 0.13.0
|
||||
python-version: 3.8
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Upgrade pip
|
||||
run: |
|
||||
pip install pip --upgrade
|
||||
pip install wheel
|
||||
- name: Install PyTorch
|
||||
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
- name: Install MMEngine
|
||||
run: pip install git+https://github.com/open-mmlab/mmengine.git@main
|
||||
- name: Install MMCV
|
||||
run: |
|
||||
pip install -U openmim
|
||||
mim install 'mmcv >= 2.0.0rc1'
|
||||
- name: Install MMCls
|
||||
run: pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x
|
||||
- name: Install MMDet
|
||||
run: pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x
|
||||
- name: Install MMSeg
|
||||
run: pip install git+https://github.com/open-mmlab/mmsegmentation.git@dev-1.x
|
||||
- name: Install other dependencies
|
||||
run: pip install -r requirements.txt
|
||||
- name: Build and install
|
||||
run: rm -rf .eggs && pip install -e .
|
||||
- name: Run unittests and generate coverage report
|
||||
run: |
|
||||
coverage run --branch --source mmrazor -m pytest tests/
|
||||
coverage xml
|
||||
coverage report -m
|
||||
# Upload coverage report for python3.8 && pytorch1.12.0 cpu
|
||||
- name: Upload coverage to Codecov
|
||||
if: ${{matrix.torch == '1.12.0' && matrix.python-version == '3.8'}}
|
||||
uses: codecov/codecov-action@v2
|
||||
with:
|
||||
file: ./coverage.xml
|
||||
flags: unittests
|
||||
env_vars: OS,PYTHON
|
||||
name: codecov-umbrella
|
||||
fail_ci_if_error: false
|
||||
|
||||
# test_windows:
|
||||
# runs-on: ${{ matrix.os }}
|
||||
# strategy:
|
||||
# matrix:
|
||||
# os: [windows-2022]
|
||||
# python: [3.7]
|
||||
# platform: [cpu]
|
||||
# steps:
|
||||
# - uses: actions/checkout@v2
|
||||
# - name: Set up Python ${{ matrix.python-version }}
|
||||
# uses: actions/setup-python@v2
|
||||
# with:
|
||||
# python-version: ${{ matrix.python-version }}
|
||||
# - name: Upgrade pip
|
||||
# run: |
|
||||
# pip install pip --upgrade
|
||||
# pip install wheel
|
||||
# - name: Install lmdb
|
||||
# run: pip install lmdb
|
||||
# - name: Install PyTorch
|
||||
# run: pip install torch==1.8.1+${{matrix.platform}} torchvision==0.9.1+${{matrix.platform}} -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
|
||||
# - name: Install mmrazor dependencies
|
||||
# run: |
|
||||
# pip install git+https://github.com/open-mmlab/mmengine.git@main
|
||||
# pip install -U openmim
|
||||
# mim install 'mmcv >= 2.0.0rc1'
|
||||
# pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x
|
||||
# pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x
|
||||
# pip install git+https://github.com/open-mmlab/mmsegmentation.git@dev-1.x
|
||||
# pip install -r requirements.txt
|
||||
# - name: Build and install
|
||||
# run: |
|
||||
# pip install -e .
|
||||
# - name: Run unittests and generate coverage report
|
||||
# run: |
|
||||
# pytest tests/
|
|
@ -0,0 +1,27 @@
|
|||
name: lint
|
||||
|
||||
on: [push, pull_request]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python 3.7
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: 3.7
|
||||
- name: Install pre-commit hook
|
||||
run: |
|
||||
pip install pre-commit
|
||||
pre-commit install
|
||||
- name: Linting
|
||||
run: pre-commit run --all-files
|
||||
- name: Check docstring coverage
|
||||
run: |
|
||||
pip install interrogate
|
||||
interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-regex "__repr__" --fail-under 80 mmrazor
|
|
@ -24,5 +24,5 @@ Models:
|
|||
Metrics:
|
||||
Top 1 Accuracy: 97.32
|
||||
Top 5 Accuracy: 99.94
|
||||
Config: configs/nas/mmcls/darts/darts_subnet_1xb96_cifar10_2.0.py
|
||||
Config: configs/nas/darts/darts_subnet_1xb96_cifar10_2.0.py
|
||||
Weights: https://download.openmmlab.com/mmrazor/v0.1/nas/darts/darts_subnetnet_1xb96_cifar10/darts_subnetnet_1xb96_cifar10_acc-97.32_20211222-e5727921.pth
|
||||
|
|
|
@ -86,7 +86,7 @@ For example, the default `_channel_cfg_paths` is set in the config below.
|
|||
|
||||
```Python
|
||||
python ./tools/train.py \
|
||||
configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_subnet_8xb256_in1k_flops-530M \
|
||||
configs/pruning/mmcls/autoslim/autoslim_mbv2_1.5x_subnet_8xb256_in1k_flops-530M.py \
|
||||
--work-dir your_work_dir
|
||||
```
|
||||
|
||||
|
|
|
@ -74,7 +74,7 @@ class CRDDataset:
|
|||
# e.g. [2, 3, 5].
|
||||
num_classes: int = self.num_classes # type: ignore
|
||||
if num_classes is None:
|
||||
num_classes = len(self.dataset.CLASSES)
|
||||
num_classes = max(self.dataset.get_gt_labels()) + 1
|
||||
|
||||
if not self.dataset.test_mode: # type: ignore
|
||||
# Parse info.
|
||||
|
|
|
@ -118,13 +118,20 @@ def parse_cat(tracer, grad_fn, module2name, param2module, cur_path,
|
|||
>>> # ``out`` is obtained by concatenating two tensors
|
||||
"""
|
||||
parents = grad_fn.next_functions
|
||||
concat_id = '_'.join([str(id(p)) for p in parents])
|
||||
concat_id_list = [str(id(p)) for p in parents]
|
||||
concat_id_list.sort()
|
||||
concat_id = '_'.join(concat_id_list)
|
||||
name = f'concat_{concat_id}'
|
||||
|
||||
visited[name] = True
|
||||
sub_path_lists = list()
|
||||
for i, parent in enumerate(parents):
|
||||
for _, parent in enumerate(parents):
|
||||
sub_path_list = PathList()
|
||||
tracer.backward_trace(parent, module2name, param2module, Path(),
|
||||
sub_path_list, visited, shared_module)
|
||||
sub_path_lists.append(sub_path_list)
|
||||
cur_path.append(PathConcatNode('CatNode', sub_path_lists))
|
||||
cur_path.append(PathConcatNode(name, sub_path_lists))
|
||||
|
||||
result_paths.append(copy.deepcopy(cur_path))
|
||||
cur_path.pop(-1)
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
albumentations>=0.3.2
|
||||
scipy
|
||||
timm
|
||||
# timm
|
||||
|
|
|
@ -6,6 +6,7 @@ import tempfile
|
|||
from unittest import TestCase
|
||||
|
||||
import numpy as np
|
||||
from mmcls.registry import DATASETS as CLS_DATASETS
|
||||
|
||||
from mmrazor.registry import DATASETS
|
||||
from mmrazor.utils import register_all_modules
|
||||
|
@ -15,7 +16,8 @@ ASSETS_ROOT = osp.abspath(osp.join(osp.dirname(__file__), '../data/dataset'))
|
|||
|
||||
|
||||
class Test_CRD_CIFAR10(TestCase):
|
||||
DATASET_TYPE = 'CRD_CIFAR10'
|
||||
ORI_DATASET_TYPE = 'CIFAR10'
|
||||
DATASET_TYPE = 'CRDDataset'
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
|
@ -24,10 +26,11 @@ class Test_CRD_CIFAR10(TestCase):
|
|||
tmpdir = tempfile.TemporaryDirectory()
|
||||
cls.tmpdir = tmpdir
|
||||
data_prefix = tmpdir.name
|
||||
cls.DEFAULT_ARGS = dict(
|
||||
cls.ORI_DEFAULT_ARGS = dict(
|
||||
data_prefix=data_prefix, pipeline=[], test_mode=False)
|
||||
cls.DEFAULT_ARGS = dict(neg_num=1, percent=0.5)
|
||||
|
||||
dataset_class = DATASETS.get(cls.DATASET_TYPE)
|
||||
dataset_class = CLS_DATASETS.get(cls.ORI_DATASET_TYPE)
|
||||
base_folder = osp.join(data_prefix, dataset_class.base_folder)
|
||||
os.mkdir(base_folder)
|
||||
|
||||
|
@ -65,25 +68,16 @@ class Test_CRD_CIFAR10(TestCase):
|
|||
dataset_class = DATASETS.get(self.DATASET_TYPE)
|
||||
|
||||
# Test overriding metainfo by `metainfo` argument
|
||||
cfg = {**self.DEFAULT_ARGS, 'metainfo': {'classes': ('bus', 'car')}}
|
||||
ori_cfg = {
|
||||
**self.ORI_DEFAULT_ARGS, 'metainfo': {
|
||||
'classes': ('bus', 'car')
|
||||
},
|
||||
'type': self.ORI_DATASET_TYPE,
|
||||
'_scope_': 'mmcls'
|
||||
}
|
||||
cfg = {'dataset': ori_cfg, **self.DEFAULT_ARGS}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
|
||||
|
||||
# Test overriding metainfo by `classes` argument
|
||||
cfg = {**self.DEFAULT_ARGS, 'classes': ['bus', 'car']}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
|
||||
|
||||
classes_file = osp.join(ASSETS_ROOT, 'classes.txt')
|
||||
cfg = {**self.DEFAULT_ARGS, 'classes': classes_file}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
|
||||
self.assertEqual(dataset.class_to_idx, {'bus': 0, 'car': 1})
|
||||
|
||||
# Test invalid classes
|
||||
cfg = {**self.DEFAULT_ARGS, 'classes': dict(classes=1)}
|
||||
with self.assertRaisesRegex(ValueError, "type <class 'dict'>"):
|
||||
dataset_class(**cfg)
|
||||
self.assertEqual(dataset.dataset.CLASSES, ('bus', 'car'))
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
|
@ -91,4 +85,4 @@ class Test_CRD_CIFAR10(TestCase):
|
|||
|
||||
|
||||
class Test_CRD_CIFAR100(Test_CRD_CIFAR10):
|
||||
DATASET_TYPE = 'CRD_CIFAR100'
|
||||
ORI_DATASET_TYPE = 'CIFAR100'
|
||||
|
|
|
@ -6,7 +6,7 @@ import unittest
|
|||
import numpy as np
|
||||
import torch
|
||||
from mmcls.structures import ClsDataSample
|
||||
from mmengine.data import LabelData
|
||||
from mmengine.structures import LabelData
|
||||
|
||||
from mmrazor.datasets.transforms import PackCRDClsInputs
|
||||
|
||||
|
@ -34,7 +34,7 @@ class TestPackClsInputs(unittest.TestCase):
|
|||
'img': rng.rand(300, 400),
|
||||
'gt_label': rng.randint(3, ),
|
||||
# TODO.
|
||||
'contrast_sample_idxs': rng.randint()
|
||||
'contrast_sample_idxs': rng.randint(3, )
|
||||
}
|
||||
self.meta_keys = ('sample_idx', 'img_path', 'ori_shape', 'img_shape',
|
||||
'scale_factor', 'flip')
|
||||
|
@ -44,13 +44,13 @@ class TestPackClsInputs(unittest.TestCase):
|
|||
results = transform(copy.deepcopy(self.results1))
|
||||
self.assertIn('inputs', results)
|
||||
self.assertIsInstance(results['inputs'], torch.Tensor)
|
||||
self.assertIn('data_sample', results)
|
||||
self.assertIsInstance(results['data_sample'], ClsDataSample)
|
||||
self.assertIn('data_samples', results)
|
||||
self.assertIsInstance(results['data_samples'], ClsDataSample)
|
||||
|
||||
data_sample = results['data_sample']
|
||||
data_sample = results['data_samples']
|
||||
self.assertIsInstance(data_sample.gt_label, LabelData)
|
||||
|
||||
def test_repr(self):
|
||||
transform = PackCRDClsInputs(meta_keys=self.meta_keys)
|
||||
self.assertEqual(
|
||||
repr(transform), f'PackClsInputs(meta_keys={self.meta_keys})')
|
||||
repr(transform), f'PackCRDClsInputs(meta_keys={self.meta_keys})')
|
||||
|
|
|
@ -18,7 +18,8 @@ MUTATOR_TYPE = Union[torch.nn.Module, Dict]
|
|||
DISTILLER_TYPE = Union[torch.nn.Module, Dict]
|
||||
|
||||
ARCHITECTURE_CFG = dict(
|
||||
type='mmcls.ImageClassifier',
|
||||
_scope_='mmcls',
|
||||
type='ImageClassifier',
|
||||
backbone=dict(type='MobileNetV2', widen_factor=1.5),
|
||||
neck=dict(type='GlobalAveragePooling'),
|
||||
head=dict(
|
||||
|
|
|
@ -170,19 +170,17 @@ class TestDsnasDDP(TestDsnas):
|
|||
os.environ['MASTER_PORT'] = '12345'
|
||||
|
||||
# initialize the process group
|
||||
if torch.cuda.is_available():
|
||||
backend = 'nccl'
|
||||
cls.device = 'cuda'
|
||||
else:
|
||||
backend = 'gloo'
|
||||
backend = 'nccl' if torch.cuda.is_available() else 'gloo'
|
||||
dist.init_process_group(backend, rank=0, world_size=1)
|
||||
|
||||
def prepare_model(self, device_ids=None) -> Dsnas:
|
||||
model = ToyDiffModule().to(self.device)
|
||||
mutator = DiffModuleMutator().to(self.device)
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
model = ToyDiffModule()
|
||||
mutator = DiffModuleMutator()
|
||||
mutator.prepare_from_supernet(model)
|
||||
|
||||
algo = Dsnas(model, mutator)
|
||||
algo = Dsnas(model, mutator).to(self.device)
|
||||
|
||||
return DsnasDDP(
|
||||
module=algo, find_unused_parameters=True, device_ids=device_ids)
|
||||
|
@ -199,24 +197,19 @@ class TestDsnasDDP(TestDsnas):
|
|||
|
||||
@patch('mmengine.logging.message_hub.MessageHub.get_info')
|
||||
def test_dsnasddp_train_step(self, mock_get_info) -> None:
|
||||
model = ToyDiffModule()
|
||||
mutator = DiffModuleMutator()
|
||||
mutator.prepare_from_supernet(model)
|
||||
ddp_model = self.prepare_model()
|
||||
mock_get_info.return_value = 2
|
||||
|
||||
algo = Dsnas(model, mutator)
|
||||
ddp_model = DsnasDDP(module=algo, find_unused_parameters=True)
|
||||
data = self._prepare_fake_data()
|
||||
optim_wrapper = build_optim_wrapper(ddp_model, self.OPTIM_WRAPPER_CFG)
|
||||
loss = ddp_model.train_step(data, optim_wrapper)
|
||||
|
||||
self.assertIsNotNone(loss)
|
||||
|
||||
algo = Dsnas(model, mutator)
|
||||
ddp_model = DsnasDDP(module=algo, find_unused_parameters=True)
|
||||
ddp_model = self.prepare_model()
|
||||
optim_wrapper_dict = OptimWrapperDict(
|
||||
architecture=OptimWrapper(SGD(model.parameters(), lr=0.1)),
|
||||
mutator=OptimWrapper(SGD(model.parameters(), lr=0.01)))
|
||||
architecture=OptimWrapper(SGD(ddp_model.parameters(), lr=0.1)),
|
||||
mutator=OptimWrapper(SGD(ddp_model.parameters(), lr=0.01)))
|
||||
loss = ddp_model.train_step(data, optim_wrapper_dict)
|
||||
|
||||
self.assertIsNotNone(loss)
|
||||
|
|
|
@ -15,7 +15,8 @@ from mmengine.optim import build_optim_wrapper
|
|||
from mmrazor.models.algorithms import SlimmableNetwork, SlimmableNetworkDDP
|
||||
|
||||
MODEL_CFG = dict(
|
||||
type='mmcls.ImageClassifier',
|
||||
_scope_='mmcls',
|
||||
type='ImageClassifier',
|
||||
backbone=dict(type='MobileNetV2', widen_factor=1.5),
|
||||
neck=dict(type='GlobalAveragePooling'),
|
||||
head=dict(
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
from mmengine.data import BaseDataElement
|
||||
from mmengine.structures import BaseDataElement
|
||||
|
||||
from mmrazor import digit_version
|
||||
from mmrazor.models import (ABLoss, ActivationLoss, ATLoss, CRDLoss, DKDLoss,
|
||||
|
|
|
@ -15,6 +15,7 @@ from mmrazor.registry import MODELS
|
|||
from ..utils import load_and_merge_channel_cfgs
|
||||
|
||||
MODEL_CFG = dict(
|
||||
_scope_='mmcls',
|
||||
type='mmcls.ImageClassifier',
|
||||
backbone=dict(type='MobileNetV2', widen_factor=1.5),
|
||||
neck=dict(type='GlobalAveragePooling'),
|
||||
|
|
Loading…
Reference in New Issue