mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
[CI] Update circle-ci and github workflow. (#1018)
* Add deploy workflow. * [CI] Update circle-ci and github workflow. * Fix windows CI * Update unit tests to save memory
This commit is contained in:
parent
85b1eae7f1
commit
61e9d890a6
@ -31,7 +31,7 @@ jobs:
|
|||||||
name: Check docstring coverage
|
name: Check docstring coverage
|
||||||
command: |
|
command: |
|
||||||
pip install interrogate
|
pip install interrogate
|
||||||
interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-regex "__repr__" --fail-under 60 mmcls
|
interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-magic --ignore-regex "__repr__" --fail-under 60 mmcls
|
||||||
build_cpu:
|
build_cpu:
|
||||||
parameters:
|
parameters:
|
||||||
# The python version must match available image tags in
|
# The python version must match available image tags in
|
||||||
@ -42,8 +42,6 @@ jobs:
|
|||||||
type: string
|
type: string
|
||||||
torchvision:
|
torchvision:
|
||||||
type: string
|
type: string
|
||||||
mmcv:
|
|
||||||
type: string
|
|
||||||
docker:
|
docker:
|
||||||
- image: cimg/python:<< parameters.python >>
|
- image: cimg/python:<< parameters.python >>
|
||||||
resource_class: large
|
resource_class: large
|
||||||
@ -57,31 +55,32 @@ jobs:
|
|||||||
- run:
|
- run:
|
||||||
name: Configure Python & pip
|
name: Configure Python & pip
|
||||||
command: |
|
command: |
|
||||||
python -m pip install --upgrade pip
|
pip install --upgrade pip
|
||||||
python -m pip install wheel
|
pip install wheel
|
||||||
- run:
|
- run:
|
||||||
name: Install PyTorch
|
name: Install PyTorch
|
||||||
command: |
|
command: |
|
||||||
python -V
|
python -V
|
||||||
python -m pip install torch==<< parameters.torch >>+cpu torchvision==<< parameters.torchvision >>+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
pip install torch==<< parameters.torch >>+cpu torchvision==<< parameters.torchvision >>+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||||
- run:
|
- run:
|
||||||
name: Install mmcls dependencies
|
name: Install mmcls dependencies
|
||||||
command: |
|
command: |
|
||||||
python -m pip install git+ssh://git@github.com/open-mmlab/mmengine.git@main
|
pip install git+https://github.com/open-mmlab/mmengine.git@main
|
||||||
python -m pip install << parameters.mmcv >>
|
pip install -U openmim
|
||||||
python -m pip install timm
|
mim install 'mmcv >= 2.0.0rc1'
|
||||||
python -m pip install -r requirements.txt
|
pip install timm
|
||||||
|
pip install -r requirements.txt
|
||||||
python -c 'import mmcv; print(mmcv.__version__)'
|
python -c 'import mmcv; print(mmcv.__version__)'
|
||||||
- run:
|
- run:
|
||||||
name: Build and install
|
name: Build and install
|
||||||
command: |
|
command: |
|
||||||
python -m pip install -e .
|
pip install -e .
|
||||||
- run:
|
- run:
|
||||||
name: Run unittests
|
name: Run unittests
|
||||||
command: |
|
command: |
|
||||||
python -m coverage run --branch --source mmcls -m pytest tests/
|
coverage run --branch --source mmcls -m pytest tests/
|
||||||
python -m coverage xml
|
coverage xml
|
||||||
python -m coverage report -m
|
coverage report -m
|
||||||
|
|
||||||
build_cuda:
|
build_cuda:
|
||||||
machine:
|
machine:
|
||||||
@ -96,15 +95,13 @@ jobs:
|
|||||||
cudnn:
|
cudnn:
|
||||||
type: integer
|
type: integer
|
||||||
default: 7
|
default: 7
|
||||||
mmcv:
|
|
||||||
type: string
|
|
||||||
steps:
|
steps:
|
||||||
- checkout
|
- checkout
|
||||||
- run:
|
- run:
|
||||||
# Cloning repos in VM since Docker doesn't have access to the private key
|
# Cloning repos in VM since Docker doesn't have access to the private key
|
||||||
name: Clone Repos
|
name: Clone Repos
|
||||||
command: |
|
command: |
|
||||||
git clone -b main --depth 1 ssh://git@github.com/open-mmlab/mmengine.git /home/circleci/mmengine
|
git clone -b main --depth 1 https://github.com/open-mmlab/mmengine.git /home/circleci/mmengine
|
||||||
- run:
|
- run:
|
||||||
name: Build Docker image
|
name: Build Docker image
|
||||||
command: |
|
command: |
|
||||||
@ -114,7 +111,8 @@ jobs:
|
|||||||
name: Install mmcls dependencies
|
name: Install mmcls dependencies
|
||||||
command: |
|
command: |
|
||||||
docker exec mmcls pip install -e /mmengine
|
docker exec mmcls pip install -e /mmengine
|
||||||
docker exec mmcls pip install << parameters.mmcv >>
|
docker exec mmcls pip install -U openmim
|
||||||
|
docker exec mmcls mim install 'mmcv >= 2.0.0rc1'
|
||||||
docker exec mmcls pip install -r requirements.txt
|
docker exec mmcls pip install -r requirements.txt
|
||||||
docker exec mmcls python -c 'import mmcv; print(mmcv.__version__)'
|
docker exec mmcls python -c 'import mmcv; print(mmcv.__version__)'
|
||||||
- run:
|
- run:
|
||||||
@ -124,7 +122,7 @@ jobs:
|
|||||||
- run:
|
- run:
|
||||||
name: Run unittests
|
name: Run unittests
|
||||||
command: |
|
command: |
|
||||||
docker exec mmcls python -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py
|
docker exec mmcls python -m pytest tests/ -k 'not timm'
|
||||||
|
|
||||||
# Invoke jobs via workflows
|
# Invoke jobs via workflows
|
||||||
# See: https://circleci.com/docs/2.0/configuration-reference/#workflows
|
# See: https://circleci.com/docs/2.0/configuration-reference/#workflows
|
||||||
@ -138,6 +136,7 @@ workflows:
|
|||||||
branches:
|
branches:
|
||||||
ignore:
|
ignore:
|
||||||
- dev-1.x
|
- dev-1.x
|
||||||
|
- 1.x
|
||||||
pr_stage_test:
|
pr_stage_test:
|
||||||
when:
|
when:
|
||||||
not:
|
not:
|
||||||
@ -154,15 +153,13 @@ workflows:
|
|||||||
torch: 1.6.0
|
torch: 1.6.0
|
||||||
torchvision: 0.7.0
|
torchvision: 0.7.0
|
||||||
python: 3.6.9 # The lowest python 3.6.x version available on CircleCI images
|
python: 3.6.9 # The lowest python 3.6.x version available on CircleCI images
|
||||||
mmcv: https://download.openmmlab.com/mmcv/dev-2.x/cpu/torch1.6.0/mmcv_full-2.0.0rc0-cp36-cp36m-manylinux1_x86_64.whl
|
|
||||||
requires:
|
requires:
|
||||||
- lint
|
- lint
|
||||||
- build_cpu:
|
- build_cpu:
|
||||||
name: maximum_version_cpu
|
name: maximum_version_cpu
|
||||||
torch: 1.9.0 # TODO: Update the version after mmcv provides more pre-compiled packages.
|
torch: 1.12.1
|
||||||
torchvision: 0.10.0
|
torchvision: 0.13.1
|
||||||
python: 3.9.0
|
python: 3.9.0
|
||||||
mmcv: https://download.openmmlab.com/mmcv/dev-2.x/cpu/torch1.9.0/mmcv_full-2.0.0rc0-cp39-cp39-manylinux1_x86_64.whl
|
|
||||||
requires:
|
requires:
|
||||||
- minimum_version_cpu
|
- minimum_version_cpu
|
||||||
- hold:
|
- hold:
|
||||||
@ -175,7 +172,6 @@ workflows:
|
|||||||
# Use double quotation mark to explicitly specify its type
|
# Use double quotation mark to explicitly specify its type
|
||||||
# as string instead of number
|
# as string instead of number
|
||||||
cuda: "10.2"
|
cuda: "10.2"
|
||||||
mmcv: https://download.openmmlab.com/mmcv/dev-2.x/cu102/torch1.8.0/mmcv_full-2.0.0rc0-cp37-cp37m-manylinux1_x86_64.whl
|
|
||||||
requires:
|
requires:
|
||||||
- hold
|
- hold
|
||||||
merge_stage_test:
|
merge_stage_test:
|
||||||
@ -188,7 +184,6 @@ workflows:
|
|||||||
torch: 1.6.0
|
torch: 1.6.0
|
||||||
# Use double quotation mark to explicitly specify its type
|
# Use double quotation mark to explicitly specify its type
|
||||||
# as string instead of number
|
# as string instead of number
|
||||||
mmcv: https://download.openmmlab.com/mmcv/dev-2.x/cu101/torch1.6.0/mmcv_full-2.0.0rc0-cp37-cp37m-manylinux1_x86_64.whl
|
|
||||||
cuda: "10.1"
|
cuda: "10.1"
|
||||||
filters:
|
filters:
|
||||||
branches:
|
branches:
|
||||||
|
27
.github/workflows/lint.yml
vendored
Normal file
27
.github/workflows/lint.yml
vendored
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
name: lint
|
||||||
|
|
||||||
|
on: [push, pull_request]
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
lint:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
- name: Set up Python 3.7
|
||||||
|
uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: 3.7
|
||||||
|
- name: Install pre-commit hook
|
||||||
|
run: |
|
||||||
|
pip install pre-commit
|
||||||
|
pre-commit install
|
||||||
|
- name: Linting
|
||||||
|
run: pre-commit run --all-files
|
||||||
|
- name: Check docstring coverage
|
||||||
|
run: |
|
||||||
|
pip install interrogate
|
||||||
|
interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-magic --ignore-regex "__repr__" --fail-under 60 mmcls
|
87
.github/workflows/pr_stage_test.yml
vendored
Normal file
87
.github/workflows/pr_stage_test.yml
vendored
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
name: pr_stage_test
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
paths-ignore:
|
||||||
|
- 'README.md'
|
||||||
|
- 'README_zh-CN.md'
|
||||||
|
- 'docs/**'
|
||||||
|
- 'demo/**'
|
||||||
|
- 'tools/**'
|
||||||
|
- 'configs/**'
|
||||||
|
- '.dev_scripts/**'
|
||||||
|
- '.circleci/**'
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: ubuntu-18.04
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version: [3.7]
|
||||||
|
include:
|
||||||
|
- torch: 1.8.1
|
||||||
|
torchvision: 0.9.1
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
- name: Upgrade pip
|
||||||
|
run: pip install pip --upgrade
|
||||||
|
- name: Install PyTorch
|
||||||
|
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||||
|
- name: Install mmcls dependencies
|
||||||
|
run: |
|
||||||
|
pip install git+https://github.com/open-mmlab/mmengine.git@main
|
||||||
|
pip install -U openmim
|
||||||
|
mim install 'mmcv >= 2.0.0rc1'
|
||||||
|
pip install -r requirements.txt
|
||||||
|
- name: Build and install
|
||||||
|
run: pip install -e .
|
||||||
|
- name: Run unittests and generate coverage report
|
||||||
|
run: |
|
||||||
|
coverage run --branch --source mmcls -m pytest tests/ -k 'not timm'
|
||||||
|
coverage xml
|
||||||
|
coverage report -m
|
||||||
|
# Upload coverage report for python3.7 && pytorch1.8.1 cpu
|
||||||
|
- name: Upload coverage to Codecov
|
||||||
|
uses: codecov/codecov-action@v1.0.14
|
||||||
|
with:
|
||||||
|
file: ./coverage.xml
|
||||||
|
flags: unittests
|
||||||
|
env_vars: OS,PYTHON
|
||||||
|
name: codecov-umbrella
|
||||||
|
fail_ci_if_error: false
|
||||||
|
|
||||||
|
build_windows:
|
||||||
|
runs-on: windows-2022
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python: [3.7]
|
||||||
|
platform: [cu111]
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
- name: Upgrade pip
|
||||||
|
run: pip install pip --upgrade
|
||||||
|
- name: Install PyTorch
|
||||||
|
run: pip install torch==1.8.2+${{matrix.platform}} torchvision==0.9.2+${{matrix.platform}} -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
|
||||||
|
- name: Install mmcls dependencies
|
||||||
|
run: |
|
||||||
|
pip install git+https://github.com/open-mmlab/mmengine.git@main
|
||||||
|
pip install -U openmim
|
||||||
|
mim install 'mmcv >= 2.0.0rc1'
|
||||||
|
pip install -r requirements.txt
|
||||||
|
- name: Build and install
|
||||||
|
run: pip install -e .
|
||||||
|
- name: Run unittests
|
||||||
|
run: |
|
||||||
|
pytest tests/ -k 'not timm' --ignore tests/test_models/test_backbones
|
22
.github/workflows/publish-to-pypi.yml
vendored
Normal file
22
.github/workflows/publish-to-pypi.yml
vendored
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
name: deploy
|
||||||
|
|
||||||
|
on: push
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build-n-publish:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
if: startsWith(github.event.ref, 'refs/tags')
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
- name: Set up Python 3.7
|
||||||
|
uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: 3.7
|
||||||
|
- name: Build MMClassification
|
||||||
|
run: |
|
||||||
|
pip install wheel
|
||||||
|
python setup.py sdist bdist_wheel
|
||||||
|
- name: Publish distribution to PyPI
|
||||||
|
run: |
|
||||||
|
pip install twine
|
||||||
|
twine upload dist/* -u __token__ -p ${{ secrets.pypi_password }}
|
44
.github/workflows/test_mim.yml
vendored
Normal file
44
.github/workflows/test_mim.yml
vendored
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
name: test-mim
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
paths:
|
||||||
|
- 'model-index.yml'
|
||||||
|
- 'configs/**'
|
||||||
|
|
||||||
|
pull_request:
|
||||||
|
paths:
|
||||||
|
- 'model-index.yml'
|
||||||
|
- 'configs/**'
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build_cpu:
|
||||||
|
runs-on: ubuntu-18.04
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version: [3.7]
|
||||||
|
torch: [1.8.0]
|
||||||
|
include:
|
||||||
|
- torch: 1.8.0
|
||||||
|
torch_version: torch1.8
|
||||||
|
torchvision: 0.9.0
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
- name: Upgrade pip
|
||||||
|
run: pip install pip --upgrade
|
||||||
|
- name: Install PyTorch
|
||||||
|
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||||
|
- name: Install openmim
|
||||||
|
run: pip install openmim
|
||||||
|
- name: Build and install
|
||||||
|
run: mim install -e .
|
||||||
|
- name: test commands of mim
|
||||||
|
run: mim search mmcls
|
@ -504,7 +504,8 @@ class RandomErasing(BaseTransform):
|
|||||||
'aspect_range should be positive.'
|
'aspect_range should be positive.'
|
||||||
assert aspect_range[0] <= aspect_range[1], \
|
assert aspect_range[0] <= aspect_range[1], \
|
||||||
'In aspect_range (min, max), min should be smaller than max.'
|
'In aspect_range (min, max), min should be smaller than max.'
|
||||||
assert mode in ['const', 'rand']
|
assert mode in ['const', 'rand'], \
|
||||||
|
'Please select `mode` from ["const", "rand"].'
|
||||||
if isinstance(fill_color, Number):
|
if isinstance(fill_color, Number):
|
||||||
fill_color = [fill_color] * 3
|
fill_color = [fill_color] * 3
|
||||||
assert isinstance(fill_color, Sequence) and len(fill_color) == 3 \
|
assert isinstance(fill_color, Sequence) and len(fill_color) == 3 \
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
import os
|
import os
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import pickle
|
import pickle
|
||||||
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
from unittest.mock import MagicMock, call, patch
|
from unittest.mock import MagicMock, call, patch
|
||||||
@ -141,12 +142,12 @@ class TestCustomDataset(TestBaseDataset):
|
|||||||
self.assertEqual(dataset.CLASSES, ('a', 'b')) # auto infer classes
|
self.assertEqual(dataset.CLASSES, ('a', 'b')) # auto infer classes
|
||||||
self.assertGreaterEqual(
|
self.assertGreaterEqual(
|
||||||
dataset.get_data_info(0).items(), {
|
dataset.get_data_info(0).items(), {
|
||||||
'img_path': osp.join(ASSETS_ROOT, 'a/1.JPG'),
|
'img_path': osp.join(ASSETS_ROOT, 'a', '1.JPG'),
|
||||||
'gt_label': 0
|
'gt_label': 0
|
||||||
}.items())
|
}.items())
|
||||||
self.assertGreaterEqual(
|
self.assertGreaterEqual(
|
||||||
dataset.get_data_info(2).items(), {
|
dataset.get_data_info(2).items(), {
|
||||||
'img_path': osp.join(ASSETS_ROOT, 'b/subb/3.jpg'),
|
'img_path': osp.join(ASSETS_ROOT, 'b', 'subb', '3.jpg'),
|
||||||
'gt_label': 1
|
'gt_label': 1
|
||||||
}.items())
|
}.items())
|
||||||
|
|
||||||
@ -225,7 +226,7 @@ class TestCustomDataset(TestBaseDataset):
|
|||||||
self.assertEqual(len(dataset), 1)
|
self.assertEqual(len(dataset), 1)
|
||||||
self.assertGreaterEqual(
|
self.assertGreaterEqual(
|
||||||
dataset.get_data_info(0).items(), {
|
dataset.get_data_info(0).items(), {
|
||||||
'img_path': osp.join(ASSETS_ROOT, 'b/2.jpeg'),
|
'img_path': osp.join(ASSETS_ROOT, 'b', '2.jpeg'),
|
||||||
'gt_label': 1
|
'gt_label': 1
|
||||||
}.items())
|
}.items())
|
||||||
|
|
||||||
@ -631,12 +632,12 @@ class TestVOC(TestBaseDataset):
|
|||||||
# Test different backend
|
# Test different backend
|
||||||
cfg = {
|
cfg = {
|
||||||
**self.DEFAULT_ARGS, 'lazy_init': True,
|
**self.DEFAULT_ARGS, 'lazy_init': True,
|
||||||
'data_root': 's3:/openmmlab/voc'
|
'data_root': 's3://openmmlab/voc'
|
||||||
}
|
}
|
||||||
|
petrel_mock = MagicMock()
|
||||||
|
sys.modules['petrel_client'] = petrel_mock
|
||||||
dataset = dataset_class(**cfg)
|
dataset = dataset_class(**cfg)
|
||||||
dataset._check_integrity = MagicMock(return_value=False)
|
petrel_mock.client.Client.assert_called()
|
||||||
with self.assertRaisesRegex(FileNotFoundError, 's3:/openmmlab/voc'):
|
|
||||||
dataset.full_init()
|
|
||||||
|
|
||||||
def test_extra_repr(self):
|
def test_extra_repr(self):
|
||||||
dataset_class = DATASETS.get(self.DATASET_TYPE)
|
dataset_class = DATASETS.get(self.DATASET_TYPE)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import copy
|
import copy
|
||||||
import shutil
|
import logging
|
||||||
import tempfile
|
import tempfile
|
||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from mmengine.logging import MMLogger
|
||||||
from mmengine.model import BaseDataPreprocessor, BaseModel
|
from mmengine.model import BaseDataPreprocessor, BaseModel
|
||||||
from mmengine.runner import Runner
|
from mmengine.runner import Runner
|
||||||
from torch.utils.data import DataLoader, Dataset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
@ -115,7 +116,7 @@ class TestPreciseBNHookHook(TestCase):
|
|||||||
)
|
)
|
||||||
self.epoch_train_cfg = dict(by_epoch=True, max_epochs=1)
|
self.epoch_train_cfg = dict(by_epoch=True, max_epochs=1)
|
||||||
self.iter_train_cfg = dict(by_epoch=False, max_iters=5)
|
self.iter_train_cfg = dict(by_epoch=False, max_iters=5)
|
||||||
self.tmpdir = tempfile.mkdtemp()
|
self.tmpdir = tempfile.TemporaryDirectory()
|
||||||
self.preciseBN_cfg = copy.deepcopy(self.DEFAULT_ARGS)
|
self.preciseBN_cfg = copy.deepcopy(self.DEFAULT_ARGS)
|
||||||
|
|
||||||
test_dataset = ExampleDataset()
|
test_dataset = ExampleDataset()
|
||||||
@ -125,7 +126,7 @@ class TestPreciseBNHookHook(TestCase):
|
|||||||
def test_construct(self):
|
def test_construct(self):
|
||||||
self.runner = Runner(
|
self.runner = Runner(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
work_dir=self.tmpdir,
|
work_dir=self.tmpdir.name,
|
||||||
train_dataloader=self.loader,
|
train_dataloader=self.loader,
|
||||||
train_cfg=self.epoch_train_cfg,
|
train_cfg=self.epoch_train_cfg,
|
||||||
log_level='WARNING',
|
log_level='WARNING',
|
||||||
@ -160,7 +161,7 @@ class TestPreciseBNHookHook(TestCase):
|
|||||||
self.preciseBN_cfg['priority'] = 'ABOVE_NORMAL'
|
self.preciseBN_cfg['priority'] = 'ABOVE_NORMAL'
|
||||||
self.runner = Runner(
|
self.runner = Runner(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
work_dir=self.tmpdir,
|
work_dir=self.tmpdir.name,
|
||||||
train_dataloader=self.loader,
|
train_dataloader=self.loader,
|
||||||
train_cfg=self.epoch_train_cfg,
|
train_cfg=self.epoch_train_cfg,
|
||||||
log_level='WARNING',
|
log_level='WARNING',
|
||||||
@ -176,7 +177,7 @@ class TestPreciseBNHookHook(TestCase):
|
|||||||
self.preciseBN_cfg['priority'] = 'ABOVE_NORMAL'
|
self.preciseBN_cfg['priority'] = 'ABOVE_NORMAL'
|
||||||
self.runner = Runner(
|
self.runner = Runner(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
work_dir=self.tmpdir,
|
work_dir=self.tmpdir.name,
|
||||||
train_dataloader=self.loader,
|
train_dataloader=self.loader,
|
||||||
train_cfg=self.epoch_train_cfg,
|
train_cfg=self.epoch_train_cfg,
|
||||||
log_level='WARNING',
|
log_level='WARNING',
|
||||||
@ -213,7 +214,7 @@ class TestPreciseBNHookHook(TestCase):
|
|||||||
self.loader = DataLoader(test_dataset, batch_size=2)
|
self.loader = DataLoader(test_dataset, batch_size=2)
|
||||||
self.runner = Runner(
|
self.runner = Runner(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
work_dir=self.tmpdir,
|
work_dir=self.tmpdir.name,
|
||||||
train_dataloader=self.loader,
|
train_dataloader=self.loader,
|
||||||
train_cfg=self.iter_train_cfg,
|
train_cfg=self.iter_train_cfg,
|
||||||
log_level='WARNING',
|
log_level='WARNING',
|
||||||
@ -226,4 +227,8 @@ class TestPreciseBNHookHook(TestCase):
|
|||||||
self.runner.train()
|
self.runner.train()
|
||||||
|
|
||||||
def tearDown(self) -> None:
|
def tearDown(self) -> None:
|
||||||
shutil.rmtree(self.tmpdir)
|
# `FileHandler` should be closed in Windows, otherwise we cannot
|
||||||
|
# delete the temporary directory.
|
||||||
|
logging.shutdown()
|
||||||
|
MMLogger._instance_dict.clear()
|
||||||
|
self.tmpdir.cleanup()
|
||||||
|
@ -25,6 +25,7 @@ def check_norm_state(modules, train_state):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad() # To save memory
|
||||||
def test_conformer_backbone():
|
def test_conformer_backbone():
|
||||||
|
|
||||||
cfg_ori = dict(
|
cfg_ori = dict(
|
||||||
|
@ -18,6 +18,7 @@ def test_assertion():
|
|||||||
ConvMixer(out_indices=-100)
|
ConvMixer(out_indices=-100)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad() # To save memory
|
||||||
def test_convmixer():
|
def test_convmixer():
|
||||||
|
|
||||||
# Test forward
|
# Test forward
|
||||||
|
Loading…
x
Reference in New Issue
Block a user