Merge remote-tracking branch 'origin/dev'

pull/953/head v0.23.2
mzr1996 2022-07-28 14:15:52 +08:00
commit 71ef7bae85
19 changed files with 220 additions and 98 deletions

.github/workflows/test-mim.yml vendored 100644
View File

@@ -0,0 +1,44 @@
name: test-mim

on:
  push:
    paths:
      - 'model-index.yml'
      - 'configs/**'

  pull_request:
    paths:
      - 'model-index.yml'
      - 'configs/**'

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  build_cpu:
    runs-on: ubuntu-18.04
    strategy:
      matrix:
        python-version: [3.7]
        torch: [1.8.0]
        include:
          - torch: 1.8.0
            torch_version: torch1.8
            torchvision: 0.9.0
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Upgrade pip
        run: pip install pip --upgrade
      - name: Install PyTorch
        run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
      - name: Install openmim
        run: pip install openmim
      - name: Build and install
        run: rm -rf .eggs && mim install -e .
      - name: test commands of mim
        run: mim search mmcls

View File

@@ -298,44 +298,6 @@ Models:
          Top 5 Accuracy: 93.80
    Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb256-rsb-a3-100e_in1k_20211228-3493673c.pth
    Config: configs/resnet/resnet50_8xb256-rsb-a3-100e_in1k.py
  - Name: wide-resnet50_3rdparty_8xb32_in1k
    Metadata:
      FLOPs: 11440000000  # 11.44G
      Parameters: 68880000  # 68.88M
      Training Techniques:
        - SGD with Momentum
        - Weight Decay
    In Collection: ResNet
    Results:
      - Task: Image Classification
        Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 78.48
          Top 5 Accuracy: 94.08
    Weights: https://download.openmmlab.com/mmclassification/v0/resnet/wide-resnet50_3rdparty_8xb32_in1k_20220304-66678344.pth
    Config: configs/resnet/wide-resnet50_8xb32_in1k.py
    Converted From:
      Weights: https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth
      Code: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
  - Name: wide-resnet101_3rdparty_8xb32_in1k
    Metadata:
      FLOPs: 22810000000  # 22.81G
      Parameters: 126890000  # 126.89M
      Training Techniques:
        - SGD with Momentum
        - Weight Decay
    In Collection: ResNet
    Results:
      - Task: Image Classification
        Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 78.84
          Top 5 Accuracy: 94.28
    Weights: https://download.openmmlab.com/mmclassification/v0/resnet/wide-resnet101_3rdparty_8xb32_in1k_20220304-8d5f9d61.pth
    Config: configs/resnet/wide-resnet101_8xb32_in1k.py
    Converted From:
      Weights: https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth
      Code: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
  - Name: resnetv1c50_8xb32_in1k
    Metadata:
      FLOPs: 4360000000

View File

@@ -4,7 +4,7 @@ ARG CUDNN="7"
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
ARG MMCV="1.4.2"
ARG MMCLS="0.23.1"
ARG MMCLS="0.23.2"
ENV PYTHONUNBUFFERED TRUE

View File

@@ -1,5 +1,15 @@
# Changelog
## v0.23.2(28/7/2022)
### New Features
- Support MPS device. ([#894](https://github.com/open-mmlab/mmclassification/pull/894))
### Bug Fixes
- Fix a bug in Albu which caused crashing. ([#918](https://github.com/open-mmlab/mmclassification/pull/918))
## v0.23.1(2/6/2022)
### New Features

View File

@@ -17,8 +17,8 @@ and make sure you fill in all required information in the template.
| MMClassification version | MMCV version |
| :----------------------: | :--------------------: |
| dev | mmcv>=1.5.0, \<1.6.0 |
| 0.23.1 (master) | mmcv>=1.4.2, \<1.6.0 |
| dev | mmcv>=1.6.0, \<1.7.0 |
| 0.23.2 (master) | mmcv>=1.4.2, \<1.7.0 |
| 0.22.1 | mmcv>=1.4.2, \<1.6.0 |
| 0.21.0 | mmcv>=1.4.2, \<=1.5.0 |
| 0.20.1 | mmcv>=1.4.2, \<=1.5.0 |

View File

@@ -15,8 +15,8 @@
| MMClassification version | MMCV version |
| :----------------------: | :--------------------: |
| dev | mmcv>=1.5.0, \<1.6.0 |
| 0.23.1 (master) | mmcv>=1.4.2, \<1.6.0 |
| dev | mmcv>=1.6.0, \<1.7.0 |
| 0.23.2 (master) | mmcv>=1.4.2, \<1.6.0 |
| 0.22.1 | mmcv>=1.4.2, \<1.6.0 |
| 0.21.0 | mmcv>=1.4.2, \<=1.5.0 |
| 0.20.1 | mmcv>=1.4.2, \<=1.5.0 |

View File

@@ -48,7 +48,7 @@ def digit_version(version_str: str, length: int = 4):
mmcv_minimum_version = '1.4.2'
mmcv_maximum_version = '1.6.0'
mmcv_maximum_version = '1.7.0'
mmcv_version = digit_version(mmcv.__version__)
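For context, these two constants normally feed a range check a few lines further down (not included in this hunk); the following is a minimal sketch of that pattern, under the assumption that the standard OpenMMLab version guard is used:

# Sketch of the assumed version guard; not part of the diff above.
import mmcv
from mmcv.utils import digit_version

mmcv_minimum_version = '1.4.2'
mmcv_maximum_version = '1.7.0'
assert (digit_version(mmcv_minimum_version) <= digit_version(mmcv.__version__)
        < digit_version(mmcv_maximum_version)), \
    f'MMCV=={mmcv.__version__} is incompatible; please install ' \
    f'mmcv>={mmcv_minimum_version}, <{mmcv_maximum_version}.'

Raising the upper bound is what allows this release to run with mmcv 1.6.x.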

View File

@@ -5,13 +5,13 @@ import warnings
import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, Fp16OptimizerHook,
                         build_optimizer, build_runner, get_dist_info)

from mmcls.core import DistEvalHook, DistOptimizerHook, EvalHook
from mmcls.datasets import build_dataloader, build_dataset
from mmcls.utils import get_root_logger
from mmcls.utils import (get_root_logger, wrap_distributed_model,
                         wrap_non_distributed_model)
def init_random_seed(seed=None, device='cuda'):
@@ -128,27 +128,15 @@ def train_model(model,
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = MMDistributedDataParallel(
            model.cuda(),
        model = wrap_distributed_model(
            model,
            cfg.device,
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        if device == 'cpu':
            warnings.warn(
                'The argument `device` is deprecated. To use cpu to train, '
                'please refers to https://mmclassification.readthedocs.io/en'
                '/latest/getting_started.html#train-a-model')
            model = model.cpu()
        elif device == 'ipu':
            model = model.cpu()
        else:
            model = MMDataParallel(model, device_ids=cfg.gpu_ids)
            if not model.device_ids:
                from mmcv import __version__, digit_version
                assert digit_version(__version__) >= (1, 4, 4), \
                    'To train with CPU, please confirm your mmcv version ' \
                    'is not lower than v1.4.4'
        model = wrap_non_distributed_model(
            model, cfg.device, device_ids=cfg.gpu_ids)

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)

View File

@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import inspect
import math
import random
@@ -1117,19 +1118,23 @@ class Albu(object):
        return updated_dict

    def __call__(self, results):
        # back up gt_label in case Albu modifies it.
        _gt_label = copy.deepcopy(results.get('gt_label', None))

        # dict to albumentations format
        results = self.mapper(results, self.keymap_to_albu)
        # process aug
        results = self.aug(**results)

        if 'gt_labels' in results:
            if isinstance(results['gt_labels'], list):
                results['gt_labels'] = np.array(results['gt_labels'])
            results['gt_labels'] = results['gt_labels'].astype(np.int64)

        # back to the original format
        results = self.mapper(results, self.keymap_back)

        if _gt_label is not None:
            # restore the backed-up gt_label
            results.update({'gt_label': _gt_label})

        # update final shape
        if self.update_pad_shape:
            results['pad_shape'] = results['img'].shape
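A minimal usage sketch of the behaviour this change guarantees (illustrative values, requires albumentations; it duplicates what the updated unit test further below exercises): a gt_label entering the transform comes back unchanged.

# Hypothetical snippet, not taken from the repository.
import numpy as np
from mmcv.utils import build_from_cfg
from mmcls.datasets import PIPELINES

albu = build_from_cfg(
    dict(type='Albu', transforms=[dict(type='ChannelShuffle', p=1)]), PIPELINES)
results = dict(
    img=np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8),
    gt_label=np.array(1))
results = albu(results)
assert results['gt_label'] == 1  # restored from the backup taken in __call__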

View File

@@ -1,8 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .collect_env import collect_env
from .device import auto_select_device
from .distribution import wrap_distributed_model, wrap_non_distributed_model
from .logger import get_root_logger, load_json_log
from .setup_env import setup_multi_processes

__all__ = [
    'collect_env', 'get_root_logger', 'load_json_log', 'setup_multi_processes'
    'collect_env', 'get_root_logger', 'load_json_log', 'setup_multi_processes',
    'wrap_non_distributed_model', 'wrap_distributed_model',
    'auto_select_device'
]

View File

@@ -0,0 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch
from mmcv.utils import digit_version


def auto_select_device() -> str:
    mmcv_version = digit_version(mmcv.__version__)
    if mmcv_version >= digit_version('1.6.0'):
        from mmcv.device import get_device
        return get_device()
    elif torch.cuda.is_available():
        return 'cuda'
    else:
        return 'cpu'
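For reference, a minimal usage sketch (assumed, not part of the diff) showing how the helper plugs into ordinary torch code; the exact string it returns depends on the installed mmcv and the available hardware:

# Hypothetical usage example.
import torch
from mmcls.utils import auto_select_device

device = auto_select_device()          # e.g. 'cuda', 'cpu', or 'mps' with mmcv>=1.6.0
x = torch.zeros(2, 3, device=device)   # any torch API that accepts a device string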

View File

@@ -0,0 +1,58 @@
# Copyright (c) OpenMMLab. All rights reserved.


def wrap_non_distributed_model(model, device='cuda', dim=0, *args, **kwargs):
    """Wrap module in non-distributed environment by device type.

    - For CUDA, wrap as :obj:`mmcv.parallel.MMDataParallel`.
    - For MPS, wrap as :obj:`mmcv.device.mps.MPSDataParallel`.
    - For CPU & IPU, the model is not wrapped.

    Args:
        model(:class:`nn.Module`): model to be parallelized.
        device(str): device type; one of 'cuda', 'cpu', 'ipu' or 'mps'.
            Defaults to 'cuda'.
        dim(int): Dimension used to scatter the data. Defaults to 0.

    Returns:
        model(nn.Module): the wrapped model.
    """
    if device == 'cuda':
        from mmcv.parallel import MMDataParallel
        model = MMDataParallel(model.cuda(), dim=dim, *args, **kwargs)
    elif device == 'cpu':
        model = model.cpu()
    elif device == 'ipu':
        model = model.cpu()
    elif device == 'mps':
        from mmcv.device import mps
        model = mps.MPSDataParallel(model.to('mps'), dim=dim, *args, **kwargs)
    else:
        raise RuntimeError(f'Unavailable device "{device}"')

    return model


def wrap_distributed_model(model, device='cuda', *args, **kwargs):
    """Build DistributedDataParallel module by device type.

    - For CUDA, wrap as :obj:`mmcv.parallel.MMDistributedDataParallel`.
    - Other device types are not supported for now.

    Args:
        model(:class:`nn.Module`): module to be parallelized.
        device(str): device type; only 'cuda' is supported. Defaults to 'cuda'.

    Returns:
        model(:class:`nn.Module`): the parallelized module.

    References:
        .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
               DistributedDataParallel.html
    """
    if device == 'cuda':
        from mmcv.parallel import MMDistributedDataParallel
        model = MMDistributedDataParallel(model.cuda(), *args, **kwargs)
    else:
        raise RuntimeError(f'Unavailable device "{device}"')

    return model
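A minimal sketch (toy module, single process; names taken from the hunks above) of how the two wrappers are meant to be called. wrap_distributed_model additionally expects torch.distributed to be initialized, so only the non-distributed path is shown:

# Hypothetical usage example; a real run would pass a classifier built from a config.
import torch.nn as nn

from mmcls.utils import auto_select_device, wrap_non_distributed_model

model = nn.Linear(8, 2)                        # stand-in for a real classifier
device = auto_select_device()
model = wrap_non_distributed_model(model, device=device, device_ids=[0])
# device_ids only matters on CUDA; on CPU/IPU the model is returned unwrapped.

This mirrors how apis/train.py and tools/test.py call the helpers elsewhere in this commit.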

View File

@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved
__version__ = '0.23.1'
__version__ = '0.23.2'
def parse_version_info(version_str):

View File

@@ -1 +1 @@
mmcv-full>=1.4.2,<1.6.0
mmcv-full>=1.4.2,<1.7.0

View File

@@ -1268,14 +1268,25 @@ def test_lighting():
def test_albu_transform():
    results = dict(
        img_prefix=osp.join(osp.dirname(__file__), '../../data'),
        img_info=dict(filename='color.jpg'))
        img_info=dict(filename='color.jpg'),
        gt_label=np.array(1))

    # Define simple pipeline
    load = dict(type='LoadImageFromFile')
    load = build_from_cfg(load, PIPELINES)

    albu_transform = dict(
        type='Albu', transforms=[dict(type='ChannelShuffle', p=1)])
        type='Albu',
        transforms=[
            dict(type='ChannelShuffle', p=1),
            dict(
                type='ShiftScaleRotate',
                shift_limit=0.0625,
                scale_limit=0.0,
                rotate_limit=0,
                interpolation=1,
                p=1)
        ])
    albu_transform = build_from_cfg(albu_transform, PIPELINES)

    normalize = dict(type='Normalize', mean=[0] * 3, std=[0] * 3, to_rgb=True)
@@ -1287,3 +1298,4 @@ def test_albu_transform():
    results = normalize(results)
    assert results['img'].dtype == np.float32
    assert results['gt_label'].shape == np.array(1).shape

View File

@@ -60,7 +60,8 @@ def test_timm_backbone():
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 1
    assert feat[0].shape == torch.Size((1, 192))
    # Disable the test since TIMM's behavior changes between 0.5.4 and 0.5.5
    # assert feat[0].shape == torch.Size((1, 197, 192))


def test_timm_backbone_features_only():

View File

@@ -0,0 +1,28 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import patch

import mmcv

from mmcls.utils import auto_select_device


class TestAutoSelectDevice(TestCase):

    @patch.object(mmcv, '__version__', '1.6.0')
    @patch('mmcv.device.get_device', create=True)
    def test_mmcv(self, mock):
        auto_select_device()
        mock.assert_called_once()

    @patch.object(mmcv, '__version__', '1.5.0')
    @patch('torch.cuda.is_available', return_value=True)
    def test_cuda(self, mock):
        device = auto_select_device()
        self.assertEqual(device, 'cuda')

    @patch.object(mmcv, '__version__', '1.5.0')
    @patch('torch.cuda.is_available', return_value=False)
    def test_cpu(self, mock):
        device = auto_select_device()
        self.assertEqual(device, 'cpu')

View File

@@ -8,14 +8,15 @@ import mmcv
import numpy as np
import torch
from mmcv import DictAction
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                         wrap_fp16_model)

from mmcls.apis import multi_gpu_test, single_gpu_test
from mmcls.datasets import build_dataloader, build_dataset
from mmcls.models import build_classifier
from mmcls.utils import get_root_logger, setup_multi_processes
from mmcls.utils import (auto_select_device, get_root_logger,
                         setup_multi_processes, wrap_distributed_model,
                         wrap_non_distributed_model)


def parse_args():
@@ -92,11 +93,7 @@ def parse_args():
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument(
        '--device',
        choices=['cpu', 'cuda', 'ipu'],
        default='cuda',
        help='device used for testing')
    parser.add_argument('--device', help='device used for testing')
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
@@ -130,6 +127,7 @@ def main():
                      'in `gpu_ids` now.')
    else:
        cfg.gpu_ids = [args.gpu_id]
    cfg.device = args.device or auto_select_device()

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
@@ -144,7 +142,7 @@ def main():
    # The default loader config
    loader_cfg = dict(
        # cfg.gpus will be ignored if distributed
        num_gpus=1 if args.device == 'ipu' else len(cfg.gpu_ids),
        num_gpus=1 if cfg.device == 'ipu' else len(cfg.gpu_ids),
        dist=distributed,
        round_up=True,
    )
@@ -182,29 +180,24 @@ def main():
        CLASSES = ImageNet.CLASSES

    if not distributed:
        if args.device == 'cpu':
            model = model.cpu()
        elif args.device == 'ipu':
        model = wrap_non_distributed_model(
            model, device=cfg.device, device_ids=cfg.gpu_ids)
        if cfg.device == 'ipu':
            from mmcv.device.ipu import cfg2options, ipu_model_wrapper
            opts = cfg2options(cfg.runner.get('options_cfg', {}))
            if fp16_cfg is not None:
                model.half()
            model = ipu_model_wrapper(model, opts, fp16_cfg=fp16_cfg)
            data_loader.init(opts['inference'])
        else:
            model = MMDataParallel(model, device_ids=cfg.gpu_ids)
            if not model.device_ids:
                assert mmcv.digit_version(mmcv.__version__) >= (1, 4, 4), \
                    'To test with CPU, please confirm your mmcv version ' \
                    'is not lower than v1.4.4'
        model.CLASSES = CLASSES
        show_kwargs = {} if args.show_options is None else args.show_options
        show_kwargs = args.show_options or {}
        outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
                                  **show_kwargs)
    else:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
        model = wrap_distributed_model(
            model,
            device=cfg.device,
            device_ids=[int(os.environ['LOCAL_RANK'])],
            broadcast_buffers=False)
        outputs = multi_gpu_test(model, data_loader, args.tmpdir,
                                 args.gpu_collect)

View File

@@ -16,7 +16,8 @@ from mmcls import __version__
from mmcls.apis import init_random_seed, set_random_seed, train_model
from mmcls.datasets import build_dataset
from mmcls.models import build_classifier
from mmcls.utils import collect_env, get_root_logger, setup_multi_processes
from mmcls.utils import (auto_select_device, collect_env, get_root_logger,
                         setup_multi_processes)


def parse_args():
@@ -162,7 +163,8 @@ def main():
    logger.info(f'Config:\n{cfg.pretty_text}')

    # set random seeds
    seed = init_random_seed(args.seed)
    cfg.device = args.device or auto_select_device()
    seed = init_random_seed(args.seed, device=cfg.device)
    seed = seed + dist.get_rank() if args.diff_seed else seed
    logger.info(f'Set random seed to {seed}, '
                f'deterministic: {args.deterministic}')
@@ -195,7 +197,7 @@ def main():
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        device=args.device,
        device=cfg.device,
        meta=meta)