Compare commits
3 Commits
58ab028f97
...
72b0da8bd7
Author | SHA1 | Date |
---|---|---|
|
72b0da8bd7 | |
|
d232912391 | |
|
abd7001bd9 |
|
@ -13,7 +13,8 @@ RUN apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build
|
|||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install MMCV
|
||||
RUN pip install mmcv-full==1.3.16 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.8.1/index.html
|
||||
RUN pip install openmim
|
||||
RUN mim install mmcv-full
|
||||
|
||||
# Install MMClassification
|
||||
RUN conda clean --all
|
||||
|
|
|
@ -118,8 +118,8 @@ number).
|
|||
We provide a [Dockerfile](https://github.com/open-mmlab/mmclassification/blob/master/docker/Dockerfile) to build an image.
|
||||
|
||||
```shell
|
||||
# build an image with PyTorch 1.6.0, CUDA 10.1, CUDNN 7.
|
||||
docker build -f ./docker/Dockerfile --rm -t mmcls:torch1.6.0-cuda10.1-cudnn7 .
|
||||
# build an image with PyTorch 1.8.1, CUDA 10.2, CUDNN 7 and MMCV-full latest version released.
|
||||
docker build -f ./docker/Dockerfile --rm -t mmcls:latest .
|
||||
```
|
||||
|
||||
```{important}
|
||||
|
@ -129,7 +129,7 @@ Make sure you've installed the [nvidia-container-toolkit](https://docs.nvidia.co
|
|||
Run a container built from mmcls image with command:
|
||||
|
||||
```shell
|
||||
docker run --gpus all --shm-size=8g -it -v {DATA_DIR}:/workspace/mmclassification/data mmcls:torch1.6.0-cuda10.1-cudnn7 /bin/bash
|
||||
docker run --gpus all --shm-size=8g -it -v {DATA_DIR}:/workspace/mmclassification/data mmcls:latest /bin/bash
|
||||
```
|
||||
|
||||
## Using multiple MMClassification versions
|
||||
|
|
|
@ -110,8 +110,8 @@ pip install -e . # 或者 "python setup.py develop"
|
|||
MMClassification 提供 [Dockerfile](https://github.com/open-mmlab/mmclassification/blob/master/docker/Dockerfile) ,可以通过以下命令创建 docker 镜像。
|
||||
|
||||
```shell
|
||||
# 创建基于 PyTorch 1.6.0, CUDA 10.1, CUDNN 7 的镜像。
|
||||
docker build -f ./docker/Dockerfile --rm -t mmcls:torch1.6.0-cuda10.1-cudnn7 .
|
||||
# 创建基于 PyTorch 1.8.1, CUDA 10.2, CUDNN 7 以及最近版本的 MMCV-full 的镜像 。
|
||||
docker build -f ./docker/Dockerfile --rm -t mmcls:latest .
|
||||
```
|
||||
|
||||
```{important}
|
||||
|
@ -121,7 +121,7 @@ docker build -f ./docker/Dockerfile --rm -t mmcls:torch1.6.0-cuda10.1-cudnn7 .
|
|||
运行一个基于上述镜像的容器:
|
||||
|
||||
```shell
|
||||
docker run --gpus all --shm-size=8g -it -v {DATA_DIR}:/workspace/mmclassification/data mmcls:torch1.6.0-cuda10.1-cudnn7 /bin/bash
|
||||
docker run --gpus all --shm-size=8g -it -v {DATA_DIR}:/workspace/mmclassification/data mmcls:latest /bin/bash
|
||||
```
|
||||
|
||||
## 在多个 MMClassification 版本下进行开发
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .class_num_check_hook import ClassNumCheckHook
|
||||
|
||||
__all__ = ['ClassNumCheckHook']
|
|
@ -0,0 +1,73 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved
|
||||
from mmcv.runner import IterBasedRunner
|
||||
from mmcv.runner.hooks import HOOKS, Hook
|
||||
from mmcv.utils import is_seq_of
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
class ClassNumCheckHook(Hook):
|
||||
|
||||
def _check_head(self, runner, dataset):
|
||||
"""Check whether the `num_classes` in head matches the length of
|
||||
`CLASSES` in `dataset`.
|
||||
|
||||
Args:
|
||||
runner (obj:`EpochBasedRunner`, `IterBasedRunner`): runner object.
|
||||
dataset (obj: `BaseDataset`): the dataset to check.
|
||||
"""
|
||||
model = runner.model
|
||||
if dataset.CLASSES is None:
|
||||
runner.logger.warning(
|
||||
f'Please set `CLASSES` '
|
||||
f'in the {dataset.__class__.__name__} and'
|
||||
f'check if it is consistent with the `num_classes` '
|
||||
f'of head')
|
||||
else:
|
||||
assert is_seq_of(dataset.CLASSES, str), \
|
||||
(f'`CLASSES` in {dataset.__class__.__name__}'
|
||||
f'should be a tuple of str.')
|
||||
for name, module in model.named_modules():
|
||||
if hasattr(module, 'num_classes'):
|
||||
assert module.num_classes == len(dataset.CLASSES), \
|
||||
(f'The `num_classes` ({module.num_classes}) in '
|
||||
f'{module.__class__.__name__} of '
|
||||
f'{model.__class__.__name__} does not matches '
|
||||
f'the length of `CLASSES` '
|
||||
f'{len(dataset.CLASSES)}) in '
|
||||
f'{dataset.__class__.__name__}')
|
||||
|
||||
def before_train_iter(self, runner):
|
||||
"""Check whether the training dataset is compatible with head.
|
||||
|
||||
Args:
|
||||
runner (obj: `IterBasedRunner`): Iter based Runner.
|
||||
"""
|
||||
if not isinstance(runner, IterBasedRunner):
|
||||
return
|
||||
self._check_head(runner, runner.data_loader._dataloader.dataset)
|
||||
|
||||
def before_val_iter(self, runner):
|
||||
"""Check whether the eval dataset is compatible with head.
|
||||
|
||||
Args:
|
||||
runner (obj:`IterBasedRunner`): Iter based Runner.
|
||||
"""
|
||||
if not isinstance(runner, IterBasedRunner):
|
||||
return
|
||||
self._check_head(runner, runner.data_loader._dataloader.dataset)
|
||||
|
||||
def before_train_epoch(self, runner):
|
||||
"""Check whether the training dataset is compatible with head.
|
||||
|
||||
Args:
|
||||
runner (obj:`EpochBasedRunner`): Epoch based Runner.
|
||||
"""
|
||||
self._check_head(runner, runner.data_loader.dataset)
|
||||
|
||||
def before_val_epoch(self, runner):
|
||||
"""Check whether the eval dataset is compatible with head.
|
||||
|
||||
Args:
|
||||
runner (obj:`EpochBasedRunner`): Epoch based Runner.
|
||||
"""
|
||||
self._check_head(runner, runner.data_loader.dataset)
|
16
setup.py
16
setup.py
|
@ -131,8 +131,20 @@ def add_mim_extension():
|
|||
|
||||
if mode == 'symlink':
|
||||
src_relpath = osp.relpath(src_path, osp.dirname(tar_path))
|
||||
os.symlink(src_relpath, tar_path)
|
||||
elif mode == 'copy':
|
||||
try:
|
||||
os.symlink(src_relpath, tar_path)
|
||||
except OSError:
|
||||
# Creating a symbolic link on windows may raise an
|
||||
# `OSError: [WinError 1314]` due to privilege. If
|
||||
# the error happens, the src file will be copied
|
||||
mode = 'copy'
|
||||
warnings.warn(
|
||||
f'Failed to create a symbolic link for {src_relpath}, '
|
||||
f'and it will be copied to {tar_path}')
|
||||
else:
|
||||
continue
|
||||
|
||||
if mode == 'copy':
|
||||
if osp.isfile(src_path):
|
||||
shutil.copyfile(src_path, tar_path)
|
||||
elif osp.isdir(src_path):
|
||||
|
|
|
@ -0,0 +1,83 @@
|
|||
import logging
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import mmcv.runner as mmcv_runner
|
||||
import pytest
|
||||
import torch
|
||||
from mmcv.runner import obj_from_dict
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from mmcls.core.hook import ClassNumCheckHook
|
||||
from mmcls.models.heads.base_head import BaseHead
|
||||
|
||||
|
||||
class ExampleDataset(Dataset):
|
||||
|
||||
def __init__(self, CLASSES):
|
||||
self.CLASSES = CLASSES
|
||||
|
||||
def __getitem__(self, idx):
|
||||
results = dict(img=torch.tensor([1]), img_metas=dict())
|
||||
return results
|
||||
|
||||
def __len__(self):
|
||||
return 1
|
||||
|
||||
|
||||
class ExampleHead(BaseHead):
|
||||
|
||||
def __init__(self, init_cfg=None):
|
||||
super(BaseHead, self).__init__(init_cfg)
|
||||
self.num_classes = 4
|
||||
|
||||
def forward_train(self, x, gt_label=None, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class ExampleModel(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(ExampleModel, self).__init__()
|
||||
self.test_cfg = None
|
||||
self.conv = torch.nn.Conv2d(3, 3, 3)
|
||||
self.head = ExampleHead()
|
||||
|
||||
def forward(self, img, img_metas, test_mode=False, **kwargs):
|
||||
return img
|
||||
|
||||
def train_step(self, data_batch, optimizer):
|
||||
loss = self.forward(**data_batch)
|
||||
return dict(loss=loss)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('runner_type',
|
||||
['EpochBasedRunner', 'IterBasedRunner'])
|
||||
@pytest.mark.parametrize(
|
||||
'CLASSES', [None, ('A', 'B', 'C', 'D', 'E'), ('A', 'B', 'C', 'D')])
|
||||
def test_num_class_hook(runner_type, CLASSES):
|
||||
test_dataset = ExampleDataset(CLASSES)
|
||||
loader = DataLoader(test_dataset, batch_size=1)
|
||||
model = ExampleModel()
|
||||
optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
|
||||
optimizer = obj_from_dict(optim_cfg, torch.optim,
|
||||
dict(params=model.parameters()))
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
num_class_hook = ClassNumCheckHook()
|
||||
logger_mock = MagicMock(spec=logging.Logger)
|
||||
runner = getattr(mmcv_runner, runner_type)(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
work_dir=tmpdir,
|
||||
logger=logger_mock,
|
||||
max_epochs=1)
|
||||
runner.register_hook(num_class_hook)
|
||||
if CLASSES is None:
|
||||
runner.run([loader], [('train', 1)], 1)
|
||||
logger_mock.warning.assert_called()
|
||||
elif len(CLASSES) != 4:
|
||||
with pytest.raises(AssertionError):
|
||||
runner.run([loader], [('train', 1)], 1)
|
||||
else:
|
||||
runner.run([loader], [('train', 1)], 1)
|
Loading…
Reference in New Issue