[Refactor] Use `mmengine` instead of `mmcv` and refactor some transforms. (#986)

* [Refactor] Refactor the remaining data transforms and fix some docstrings.

* Use `mmengine` instead of `mmcv`

* Refactor KFold dataset tools

* Fix docstrings according to review comments
Ma Zerun 2022-08-24 15:59:02 +08:00 committed by GitHub
parent 5665b8349a
commit b7d0d521eb
35 changed files with 598 additions and 763 deletions

View File

@ -1,5 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import numpy as np
from mmengine.dataset import BaseDataset, force_full_init
from mmcls.registry import DATASETS
@ -13,7 +16,8 @@ class KFoldDataset:
and use the remaining fold for validation.
Args:
dataset (:obj:`BaseDataset`): The dataset to be divided.
dataset (:obj:`mmengine.dataset.BaseDataset` | dict): The dataset to be
divided.
fold (int): The fold used to do validation. Defaults to 0.
num_splits (int): The number of all folds. Defaults to 5.
test_mode (bool): Use the training dataset or validation dataset.
@ -28,38 +32,145 @@ class KFoldDataset:
num_splits=5,
test_mode=False,
seed=None):
self.dataset = dataset
self.CLASSES = dataset.CLASSES
self.test_mode = test_mode
self.num_splits = num_splits
if isinstance(dataset, dict):
self.dataset = DATASETS.build(dataset)
# Init the dataset wrapper lazily according to the dataset setting.
lazy_init = dataset.get('lazy_init', False)
elif isinstance(dataset, BaseDataset):
self.dataset = dataset
else:
raise TypeError(f'Unsupported dataset type {type(dataset)}.')
length = len(dataset)
indices = list(range(length))
if isinstance(seed, int):
rng = np.random.default_rng(seed)
self._metainfo = getattr(self.dataset, 'metainfo', {})
self.fold = fold
self.num_splits = num_splits
self.test_mode = test_mode
self.seed = seed
self._fully_initialized = False
if not lazy_init:
self.full_init()
@property
def metainfo(self) -> dict:
"""Get the meta information of ``self.dataset``.
Returns:
dict: Meta information of the dataset.
"""
# Prevent `self._metainfo` from being modified from outside.
return copy.deepcopy(self._metainfo)
def full_init(self):
"""fully initialize the dataset."""
if self._fully_initialized:
return
self.dataset.full_init()
ori_len = len(self.dataset)
indices = list(range(ori_len))
if self.seed is not None:
rng = np.random.default_rng(self.seed)
rng.shuffle(indices)
test_start = length * fold // num_splits
test_end = length * (fold + 1) // num_splits
if test_mode:
self.indices = indices[test_start:test_end]
test_start = ori_len * self.fold // self.num_splits
test_end = ori_len * (self.fold + 1) // self.num_splits
if self.test_mode:
indices = indices[test_start:test_end]
else:
self.indices = indices[:test_start] + indices[test_end:]
indices = indices[:test_start] + indices[test_end:]
def get_cat_ids(self, idx):
return self.dataset.get_cat_ids(self.indices[idx])
self._ori_indices = indices
self.dataset = self.dataset.get_subset(indices)
def get_gt_labels(self):
dataset_gt_labels = self.dataset.get_gt_labels()
gt_labels = np.array([dataset_gt_labels[idx] for idx in self.indices])
return gt_labels
self._fully_initialized = True
def __getitem__(self, idx):
return self.dataset[self.indices[idx]]
@force_full_init
def _get_ori_dataset_idx(self, idx: int) -> int:
"""Convert global idx to local index.
Args:
idx (int): Global index of ``KFoldDataset``.
Returns:
int: The original index in the whole dataset.
"""
return self._ori_indices[idx]
@force_full_init
def get_data_info(self, idx: int) -> dict:
"""Get annotation by index.
Args:
idx (int): Global index of ``KFoldDataset``.
Returns:
dict: The idx-th annotation of the dataset.
"""
return self.dataset.get_data_info(idx)
@force_full_init
def __len__(self):
return len(self.indices)
return len(self.dataset)
def evaluate(self, *args, **kwargs):
kwargs['indices'] = self.indices
return self.dataset.evaluate(*args, **kwargs)
@force_full_init
def __getitem__(self, idx):
return self.dataset[idx]
@force_full_init
def get_cat_ids(self, idx):
return self.dataset.get_cat_ids(idx)
@force_full_init
def get_gt_labels(self):
return self.dataset.get_gt_labels()
@property
def CLASSES(self):
"""Return all categories names."""
return self._metainfo.get('classes', None)
@property
def class_to_idx(self):
"""Map mapping class name to class index.
Returns:
dict: mapping from class name to class index.
"""
return {cat: i for i, cat in enumerate(self.CLASSES)}
def __repr__(self):
"""Print the basic information of the dataset.
Returns:
str: Formatted string.
"""
head = 'Dataset ' + self.__class__.__name__
body = []
type_ = 'test' if self.test_mode else 'training'
body.append(f'Type: \t{type_}')
body.append(f'Seed: \t{self.seed}')
def ordinal(n):
# Copy from https://codegolf.stackexchange.com/a/74047
suffix = 'tsnrhtdd'[(n // 10 % 10 != 1) * (n % 10 < 4) * n % 10::4]
return f'{n}{suffix}'
body.append(
f'Fold: \t{ordinal(self.fold+1)} of {self.num_splits}-fold')
if self._fully_initialized:
body.append(f'Number of samples: \t{self.__len__()}')
else:
body.append("Haven't been initialized")
if self.CLASSES is not None:
body.append(f'Number of categories: \t{len(self.CLASSES)}')
else:
body.append('The `CLASSES` meta info is not set.')
body.append(
f'Original dataset type:\t{self.dataset.__class__.__name__}')
lines = [head] + [' ' * 4 + line for line in body]
return '\n'.join(lines)
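
For context, a minimal usage sketch of the refactored wrapper. The `CustomDataset` config below is illustrative and not part of this commit:

from mmcls.datasets import KFoldDataset
from mmcls.utils import register_all_modules

register_all_modules()

# Hypothetical dataset config; any registered dataset works here.
dataset_cfg = dict(
    type='CustomDataset',
    data_root='data/my_dataset',
    pipeline=[dict(type='LoadImageFromFile'), dict(type='PackClsInputs')],
)

# Fold 0 is held out: the train split covers folds 1-4, and the
# validation split (test_mode=True) covers fold 0 only.
train_set = KFoldDataset(dataset_cfg, fold=0, num_splits=5, seed=2022)
val_set = KFoldDataset(
    dataset_cfg, fold=0, num_splits=5, seed=2022, test_mode=True)
print(train_set)  # the new __repr__ reports type, seed, fold and counts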

View File

@ -3,21 +3,17 @@ from .auto_augment import (AutoAugment, AutoContrast, Brightness,
ColorTransform, Contrast, Cutout, Equalize, Invert,
Posterize, RandAugment, Rotate, Sharpness, Shear,
Solarize, SolarizeAdd, Translate)
from .compose import Compose
from .formatting import (Collect, ImageToTensor, PackClsInputs, ToNumpy, ToPIL,
ToTensor, Transpose, to_tensor)
from .formatting import Collect, PackClsInputs, ToNumpy, ToPIL, Transpose
from .processing import (Albumentations, ColorJitter, EfficientNetCenterCrop,
EfficientNetRandomCrop, Lighting, Normalize, Pad,
RandomCrop, RandomErasing, RandomGrayscale,
RandomResizedCrop)
EfficientNetRandomCrop, Lighting, RandomCrop,
RandomErasing, RandomResizedCrop)
__all__ = [
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToPIL', 'ToNumpy',
'Transpose', 'Collect', 'Normalize', 'RandomCrop', 'RandomResizedCrop',
'RandomGrayscale', 'Shear', 'Translate', 'Rotate', 'Invert',
'ToPIL', 'ToNumpy', 'Transpose', 'Collect', 'RandomCrop',
'RandomResizedCrop', 'Shear', 'Translate', 'Rotate', 'Invert',
'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', 'Equalize',
'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd',
'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', 'RandomErasing', 'Pad',
'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', 'RandomErasing',
'PackClsInputs', 'Albumentations', 'EfficientNetRandomCrop',
'EfficientNetCenterCrop'
]
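
In mmcls 1.x, normalization typically moves into the model's data preprocessor; if a pipeline-level normalize is still needed, the mmcv counterpart can presumably stand in for the removed `Normalize`. A sketch under that assumption (not shown in this commit):

import numpy as np
from mmcv.transforms import Normalize  # assumed replacement, not in this diff

norm = Normalize(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
results = norm({'img': np.zeros((8, 8, 3), dtype=np.float32)})
print(results['img_norm_cfg'])  # records the mean/std/to_rgb used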

View File

@ -8,11 +8,11 @@ from typing import List, Optional, Sequence, Tuple, Union
import mmcv
import numpy as np
from mmcv import BaseTransform, RandomChoice
from mmcv.transforms import Compose
from mmcv.transforms.utils import cache_randomness
from mmengine import is_list_of, is_seq_of
from mmcls.registry import TRANSFORMS
from .compose import Compose
def merge_hparams(policy: dict, hparams: dict) -> dict:

View File

@ -1,41 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from collections.abc import Sequence
from mmcls.registry import TRANSFORMS
@TRANSFORMS.register_module()
class Compose(object):
"""Compose a data pipeline with a sequence of transforms.
Args:
transforms (list[dict | callable]):
Either config dicts of transforms or transform objects.
"""
def __init__(self, transforms):
assert isinstance(transforms, Sequence)
self.transforms = []
for transform in transforms:
if isinstance(transform, dict):
transform = TRANSFORMS.build(transform)
self.transforms.append(transform)
elif callable(transform):
self.transforms.append(transform)
else:
raise TypeError('transform must be callable or a dict, but got'
f' {type(transform)}')
def __call__(self, data):
for t in self.transforms:
data = t(data)
if data is None:
return None
return data
def __repr__(self):
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += f'\n {t}'
format_string += '\n)'
return format_string
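
The deleted class above is superseded by the mmcv implementation; a minimal drop-in sketch (a plain callable sidesteps any registry-scope concerns):

from mmcv.transforms import Compose

def add_flag(results):
    """Toy transform: tag the results dict."""
    results['flag'] = True
    return results

# Compose still accepts config dicts or plain callables.
pipeline = Compose([add_flag])
print(pipeline({'img': None}))  # {'img': None, 'flag': True}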

View File

@ -2,11 +2,10 @@
import warnings
from collections.abc import Sequence
import mmcv
import numpy as np
import torch
from mmcv.parallel import DataContainer as DC
from mmcv.transforms.base import BaseTransform
from mmcv.transforms import BaseTransform
from mmengine.utils import is_str
from PIL import Image
from mmcls.registry import TRANSFORMS
@ -23,7 +22,7 @@ def to_tensor(data):
return data
elif isinstance(data, np.ndarray):
return torch.from_numpy(data)
elif isinstance(data, Sequence) and not mmcv.is_str(data):
elif isinstance(data, Sequence) and not is_str(data):
return torch.tensor(data)
elif isinstance(data, int):
return torch.LongTensor([data])
@ -40,30 +39,36 @@ def to_tensor(data):
class PackClsInputs(BaseTransform):
"""Pack the inputs data for the classification.
The ``img_meta`` item is always populated. The contents of the
``img_meta`` dictionary depends on ``meta_keys``. By default this includes:
**Required Keys:**
- ``sample_idx``: id of the image sample
- img
- gt_label (optional)
- ``*meta_keys`` (optional)
- ``img_path``: path to the image file
**Deleted Keys:**
- ``ori_shape``: original shape of the image as a tuple (H, W).
All keys in the dict.
- ``img_shape``: shape of the image input to the network as a tuple
(H, W). Note that images may be zero padded on the bottom/right
if the batch tensor is larger than this shape.
**Added Keys:**
- ``scale_factor``: a float indicating the preprocessing scale
- ``flip``: a boolean indicating if image flip transform was used
- ``flip_direction``: the flipping direction
- inputs (:obj:`torch.Tensor`): The forward data of models.
- data_sample (:obj:`~mmcls.structures.ClsDataSample`): The annotation info
of the sample.
Args:
meta_keys (Sequence[str], optional): The meta keys to saved in the
meta_keys (Sequence[str]): The meta keys to be saved in the
``metainfo`` of the packed ``data_sample``.
Default: ``('sample_idx', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'flip', 'flip_direction')``
Defaults to a tuple that includes the following keys:
- ``sample_idx``: The id of the image sample.
- ``img_path``: The path to the image file.
- ``ori_shape``: The original shape of the image as a tuple (H, W).
- ``img_shape``: The shape of the image after the pipeline as a
tuple (H, W).
- ``scale_factor``: The scale factor between the resized image and
the original image.
- ``flip``: A boolean indicating if image flip transform was used.
- ``flip_direction``: The flipping direction.
"""
def __init__(self,
@ -72,17 +77,7 @@ class PackClsInputs(BaseTransform):
self.meta_keys = meta_keys
def transform(self, results: dict) -> dict:
"""Method to pack the input data.
Args:
results (dict): Result dict from the data pipeline.
Returns:
dict:
- 'inputs' (obj:`torch.Tensor`): The forward data of models.
- 'data_sample' (obj:`ClsDataSample`): The annotation info of the
sample.
"""
"""Method to pack the input data."""
packed_results = dict()
if 'img' in results:
img = results['img']
@ -115,49 +110,28 @@ class PackClsInputs(BaseTransform):
@TRANSFORMS.register_module()
class ToTensor(object):
"""Convert objects of various python types to :obj:`torch.Tensor`."""
class Transpose(BaseTransform):
"""Transpose numpy array.
def __init__(self, keys):
self.keys = keys
**Required Keys:**
def __call__(self, results):
for key in self.keys:
results[key] = to_tensor(results[key])
return results
- ``*keys``
def __repr__(self):
return self.__class__.__name__ + f'(keys={self.keys})'
**Modified Keys:**
- ``*keys``
@TRANSFORMS.register_module()
class ImageToTensor(object):
"""Convert objects :obj:`PIL.Image` to :obj:`torch.Tensor`."""
def __init__(self, keys):
self.keys = keys
def __call__(self, results):
for key in self.keys:
img = results[key]
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
results[key] = to_tensor(img.transpose(2, 0, 1))
return results
def __repr__(self):
return self.__class__.__name__ + f'(keys={self.keys})'
@TRANSFORMS.register_module()
class Transpose(object):
"""matrix transpose."""
Args:
keys (List[str]): The fields to convert to tensor.
order (List[int]): The output dimensions order.
"""
def __init__(self, keys, order):
self.keys = keys
self.order = order
def __call__(self, results):
def transform(self, results):
"""Method to transpose array."""
for key in self.keys:
results[key] = results[key].transpose(self.order)
return results
@ -168,114 +142,80 @@ class Transpose(object):
@TRANSFORMS.register_module()
class ToPIL(object):
"""Convert tensor to :obj:`PIL.Image`."""
class ToPIL(BaseTransform):
"""Convert the image from OpenCV format to :obj:`PIL.Image.Image`.
def __init__(self):
pass
**Required Keys:**
def __call__(self, results):
- img
**Modified Keys:**
- img
"""
def transform(self, results):
"""Method to convert images to :obj:`PIL.Image.Image`."""
results['img'] = Image.fromarray(results['img'])
return results
@TRANSFORMS.register_module()
class ToNumpy(object):
"""Convert tensor to :obj:`np.ndarray`."""
class ToNumpy(BaseTransform):
"""Convert object to :obj:`numpy.ndarray`.
def __init__(self):
pass
**Required Keys:**
def __call__(self, results):
results['img'] = np.array(results['img'], dtype=np.float32)
- ``*keys``
**Modified Keys:**
- ``*keys``
Args:
keys (List[str]): The fields to be converted to :obj:`numpy.ndarray`.
dtype (str, optional): The dtype of the converted numpy array.
Defaults to None.
"""
def __init__(self, keys, dtype=None):
self.keys = keys
self.dtype = dtype
def transform(self, results):
"""Method to convert object to :obj:`numpy.ndarray`."""
for key in self.keys:
results[key] = np.array(results[key], dtype=self.dtype)
return results
def __repr__(self):
return self.__class__.__name__ + \
f'(keys={self.keys}, dtype={self.dtype})'
@TRANSFORMS.register_module()
class Collect(object):
"""Collect data from the loader relevant to the specific task.
class Collect(BaseTransform):
"""Collect and only reserve the specified fields.
This is usually the last stage of the data loader pipeline. Typically keys
is set to some subset of "img" and "gt_label".
**Required Keys:**
- ``*keys``
**Deleted Keys:**
All keys except those in the argument ``*keys``.
Args:
keys (Sequence[str]): Keys of results to be collected in ``data``.
meta_keys (Sequence[str], optional): Meta keys to be converted to
``mmcv.DataContainer`` and collected in ``data[img_metas]``.
Default: ('filename', 'ori_shape', 'img_shape', 'flip',
'flip_direction', 'img_norm_cfg')
Returns:
dict: The result dict contains the following keys
- keys in ``self.keys``
- ``img_metas`` if available
keys (Sequence[str]): The keys of the fields to be collected.
"""
def __init__(self,
keys,
meta_keys=('filename', 'ori_filename', 'ori_shape',
'img_shape', 'flip', 'flip_direction',
'img_norm_cfg')):
def __init__(self, keys):
self.keys = keys
self.meta_keys = meta_keys
def __call__(self, results):
def transform(self, results):
data = {}
img_meta = {}
for key in self.meta_keys:
if key in results:
img_meta[key] = results[key]
data['img_metas'] = DC(img_meta, cpu_only=True)
for key in self.keys:
data[key] = results[key]
return data
def __repr__(self):
return self.__class__.__name__ + \
f'(keys={self.keys}, meta_keys={self.meta_keys})'
@TRANSFORMS.register_module()
class WrapFieldsToLists(object):
"""Wrap fields of the data dictionary into lists for evaluation.
This class can be used as a last step of a test or validation
pipeline for single image evaluation or inference.
Example:
>>> test_pipeline = [
>>> dict(type='LoadImageFromFile'),
>>> dict(type='Normalize',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True),
>>> dict(type='ImageToTensor', keys=['img']),
>>> dict(type='Collect', keys=['img']),
>>> dict(type='WrapIntoLists')
>>> ]
"""
def __call__(self, results):
# Wrap dict fields into lists
for key, val in results.items():
results[key] = [val]
return results
def __repr__(self):
return f'{self.__class__.__name__}()'
@TRANSFORMS.register_module()
class ToHalf(object):
def __init__(self, keys):
self.keys = keys
def __call__(self, results):
for k in self.keys:
if isinstance(results[k], torch.Tensor):
results[k] = results[k].to(torch.half)
else:
results[k] = results[k].astype(np.float16)
return results
return self.__class__.__name__ + f'(keys={self.keys})'
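
A short sketch of chaining the refactored formatting transforms, mirroring the unit tests added in this commit (the random image is a stand-in for real data):

import numpy as np
from mmcls.registry import TRANSFORMS
from mmcls.utils import register_all_modules

register_all_modules()

results = {'img': np.random.randint(0, 256, (224, 224, 3), dtype='uint8')}
for cfg in [
        dict(type='Transpose', keys=['img'], order=[2, 0, 1]),
        dict(type='Collect', keys=['img']),
]:
    results = TRANSFORMS.build(cfg)(results)
print(results['img'].shape)  # (3, 224, 224)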

View File

@ -2,7 +2,6 @@
import inspect
import math
import numbers
import random
from numbers import Number
from typing import Dict, List, Optional, Sequence, Tuple, Union
@ -24,25 +23,25 @@ except ImportError:
class RandomCrop(BaseTransform):
"""Crop the given Image at a random location.
Required Keys:
**Required Keys:**
- img
Modified Keys:
**Modified Keys:**
- img
- img_shape
Args:
crop_size (sequence or int): Desired output size of the crop. If
crop_size (int | Sequence): Desired output size of the crop. If
crop_size is an int instead of sequence like (h, w), a square crop
(crop_size, crop_size) is made.
padding (int or sequence, optional): Optional padding on each border
padding (int | Sequence, optional): Optional padding on each border
of the image. If a sequence of length 4 is provided, it is used to
pad left, top, right, bottom borders respectively. If a sequence
of length 2 is provided, it is used to pad left/right, top/bottom
borders, respectively. Default: None, which means no padding.
pad_if_needed (boolean): It will pad the image if smaller than the
pad_if_needed (bool): Whether to pad the image if it is smaller than the
desired size, to avoid raising an exception. Since cropping is done
after padding, the padding seems to be done at a random offset.
Default: False.
@ -52,17 +51,17 @@ class RandomCrop(BaseTransform):
padding_mode (str): Type of padding. Defaults to "constant". Should
be one of the following:
- constant: Pads with a constant value, this value is specified \
with pad_val.
- edge: pads with the last value at the edge of the image.
- reflect: Pads with reflection of image without repeating the \
last value on the edge. For example, padding [1, 2, 3, 4] \
with 2 elements on both sides in reflect mode will result \
in [3, 2, 1, 2, 3, 4, 3, 2].
- symmetric: Pads with reflection of image repeating the last \
value on the edge. For example, padding [1, 2, 3, 4] with \
2 elements on both sides in symmetric mode will result in \
[2, 1, 1, 2, 3, 4, 4, 3].
- ``constant``: Pads with a constant value, this value is specified
with pad_val.
- ``edge``: Pads with the last value at the edge of the image.
- ``reflect``: Pads with reflection of image without repeating the
last value on the edge. For example, padding [1, 2, 3, 4]
with 2 elements on both sides in reflect mode will result
in [3, 2, 1, 2, 3, 4, 3, 2].
- ``symmetric``: Pads with reflection of image repeating the last
value on the edge. For example, padding [1, 2, 3, 4] with
2 elements on both sides in symmetric mode will result in
[2, 1, 1, 2, 3, 4, 4, 3].
"""
def __init__(self,
@ -70,7 +69,7 @@ class RandomCrop(BaseTransform):
padding: Optional[Union[Sequence, int]] = None,
pad_if_needed: bool = False,
pad_val: Union[Number, Sequence[Number]] = 0,
padding_mode: str = 'constant') -> None:
padding_mode: str = 'constant'):
if isinstance(crop_size, Sequence):
assert len(crop_size) == 2
assert crop_size[0] > 0 and crop_size[1] > 0
@ -170,11 +169,11 @@ class RandomResizedCrop(BaseTransform):
random aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio
is made. This crop is finally resized to given size.
Required Keys:
**Required Keys:**
- img
Modified Keys:
**Modified Keys:**
- img
- img_shape
@ -194,7 +193,7 @@ class RandomResizedCrop(BaseTransform):
'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to
'bilinear'.
backend (str): The image resize backend type, accepted values are
`cv2` and `pillow`. Defaults to `cv2`.
'cv2' and 'pillow'. Defaults to 'cv2'.
"""
def __init__(self,
@ -320,11 +319,11 @@ class RandomResizedCrop(BaseTransform):
class EfficientNetRandomCrop(RandomResizedCrop):
"""EfficientNet style RandomResizedCrop.
Required Keys:
**Required Keys:**
- img
Modified Keys:
**Modified Keys:**
- img
- img_shape
@ -347,7 +346,7 @@ class EfficientNetRandomCrop(RandomResizedCrop):
'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to
'bicubic'.
backend (str): The image resize backend type, accepted values are
`cv2` and `pillow`. Defaults to `cv2`.
'cv2' and 'pillow'. Defaults to 'cv2'.
"""
def __init__(self,
@ -435,54 +434,18 @@ class EfficientNetRandomCrop(RandomResizedCrop):
return repr_str
@TRANSFORMS.register_module()
class RandomGrayscale(object):
"""Randomly convert image to grayscale with a probability of gray_prob.
Args:
gray_prob (float): Probability that image should be converted to
grayscale. Default: 0.1.
Returns:
ndarray: Image after randomly grayscale transform.
Notes:
- If input image is 1 channel: grayscale version is 1 channel.
- If input image is 3 channel: grayscale version is 3 channel
with r == g == b.
"""
def __init__(self, gray_prob=0.1):
self.gray_prob = gray_prob
def __call__(self, results):
"""
Args:
img (ndarray): Image to be converted to grayscale.
Returns:
ndarray: Randomly grayscaled image.
"""
for key in results.get('img_fields', ['img']):
img = results[key]
num_output_channels = img.shape[2]
if random.random() < self.gray_prob:
if num_output_channels > 1:
img = mmcv.rgb2gray(img)[:, :, None]
results[key] = np.dstack(
[img for _ in range(num_output_channels)])
return results
results[key] = img
return results
def __repr__(self):
return self.__class__.__name__ + f'(gray_prob={self.gray_prob})'
@TRANSFORMS.register_module()
class RandomErasing(BaseTransform):
"""Randomly selects a rectangle region in an image and erase pixels.
**Required Keys:**
- img
**Modified Keys:**
- img
Args:
erase_prob (float): Probability that image will be randomly erased.
Default: 0.5
@ -632,62 +595,19 @@ class RandomErasing(BaseTransform):
return repr_str
@TRANSFORMS.register_module()
class Pad(object):
"""Pad images.
Args:
size (tuple[int] | None): Expected padding size (h, w). Conflicts with
pad_to_square. Defaults to None.
pad_to_square (bool): Pad any image to square shape. Defaults to False.
pad_val (Number | Sequence[Number]): Values to be filled in padding
areas when padding_mode is 'constant'. Defaults to 0.
padding_mode (str): Type of padding. Should be: constant, edge,
reflect or symmetric. Defaults to "constant".
"""
def __init__(self,
size=None,
pad_to_square=False,
pad_val=0,
padding_mode='constant'):
assert (size is None) ^ (pad_to_square is False), \
'Only one of [size, pad_to_square] should be given, ' \
f'but get {(size is not None) + (pad_to_square is not False)}'
self.size = size
self.pad_to_square = pad_to_square
self.pad_val = pad_val
self.padding_mode = padding_mode
def __call__(self, results):
for key in results.get('img_fields', ['img']):
img = results[key]
if self.pad_to_square:
target_size = tuple(
max(img.shape[0], img.shape[1]) for _ in range(2))
else:
target_size = self.size
img = mmcv.impad(
img,
shape=target_size,
pad_val=self.pad_val,
padding_mode=self.padding_mode)
results[key] = img
results['img_shape'] = img.shape
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(size={self.size}, '
repr_str += f'(pad_val={self.pad_val}, '
repr_str += f'padding_mode={self.padding_mode})'
return repr_str
@TRANSFORMS.register_module()
class EfficientNetCenterCrop(BaseTransform):
"""EfficientNet style center crop.
**Required Keys:**
- img
**Modified Keys:**
- img
- img_shape
Args:
crop_size (int): Expected size after cropping with the format
of (h, w).
@ -781,16 +701,16 @@ class EfficientNetCenterCrop(BaseTransform):
class ResizeEdge(BaseTransform):
"""Resize images along the specified edge.
Required Keys:
**Required Keys:**
- img
Modified Keys:
**Modified Keys:**
- img
- img_shape
Added Keys:
**Added Keys:**
- scale
- scale_factor
@ -877,38 +797,6 @@ class ResizeEdge(BaseTransform):
return repr_str
@TRANSFORMS.register_module()
class Normalize(object):
"""Normalize the image.
Args:
mean (sequence): Mean values of 3 channels.
std (sequence): Std values of 3 channels.
to_rgb (bool): Whether to convert the image from BGR to RGB,
default is true.
"""
def __init__(self, mean, std, to_rgb=True):
self.mean = np.array(mean, dtype=np.float32)
self.std = np.array(std, dtype=np.float32)
self.to_rgb = to_rgb
def __call__(self, results):
for key in results.get('img_fields', ['img']):
results[key] = mmcv.imnormalize(results[key], self.mean, self.std,
self.to_rgb)
results['img_norm_cfg'] = dict(
mean=self.mean, std=self.std, to_rgb=self.to_rgb)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(mean={list(self.mean)}, '
repr_str += f'std={list(self.std)}, '
repr_str += f'to_rgb={self.to_rgb})'
return repr_str
@TRANSFORMS.register_module()
class ColorJitter(BaseTransform):
"""Randomly change the brightness, contrast and saturation of an image.
@ -917,11 +805,11 @@ class ColorJitter(BaseTransform):
https://github.com/pytorch/vision/blob/main/torchvision/transforms/transforms.py
Licensed under the BSD 3-Clause License.
Required Keys:
**Required Keys:**
- img
Modified Keys:
**Modified Keys:**
- img
@ -1043,11 +931,11 @@ class ColorJitter(BaseTransform):
class Lighting(BaseTransform):
"""Adjust images lighting using AlexNet-style PCA jitter.
Required Keys:
**Required Keys:**
- img
Modified Keys:
**Modified Keys:**
- img
@ -1122,11 +1010,11 @@ class Lighting(BaseTransform):
class Albumentations(BaseTransform):
"""Wrapper to use augmentation from albumentations library.
Required Keys:
**Required Keys:**
- img
Modified Keys:
**Modified Keys:**
- img
- img_shape
@ -1134,37 +1022,43 @@ class Albumentations(BaseTransform):
Adds custom transformations from albumentations library.
More details can be found in
`Albumentations <https://albumentations.readthedocs.io>`_.
An example of ``transforms`` is as follows:
.. code-block::
[
dict(
type='ShiftScaleRotate',
shift_limit=0.0625,
scale_limit=0.0,
rotate_limit=0,
interpolation=1,
p=0.5),
dict(
type='RandomBrightnessContrast',
brightness_limit=[0.1, 0.3],
contrast_limit=[0.1, 0.3],
p=0.2),
dict(type='ChannelShuffle', p=0.1),
dict(
type='OneOf',
transforms=[
dict(type='Blur', blur_limit=3, p=1.0),
dict(type='MedianBlur', blur_limit=3, p=1.0)
],
p=0.1),
]
Args:
transforms (List[Dict]): List of albumentations transform configs.
keymap (Optional[Dict]): Mapping of mmcls to albumentations fields,
in format {'input key':'albumentation-style key'}. Defaults to
None.
Example:
>>> import mmcv
>>> from mmcls.datasets import Albumentations
>>> transforms = [
... dict(
... type='ShiftScaleRotate',
... shift_limit=0.0625,
... scale_limit=0.0,
... rotate_limit=0,
... interpolation=1,
... p=0.5),
... dict(
... type='RandomBrightnessContrast',
... brightness_limit=[0.1, 0.3],
... contrast_limit=[0.1, 0.3],
... p=0.2),
... dict(type='ChannelShuffle', p=0.1),
... dict(
... type='OneOf',
... transforms=[
... dict(type='Blur', blur_limit=3, p=1.0),
... dict(type='MedianBlur', blur_limit=3, p=1.0)
... ],
... p=0.1),
... ]
>>> albu = Albumentations(transforms)
>>> data = {'img': mmcv.imread('./demo/demo.JPEG')}
>>> data = albu(data)
>>> print(data['img'].shape)
(375, 500, 3)
"""
def __init__(self, transforms: List[Dict], keymap: Optional[Dict] = None):
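
As a quick illustration of the padding behaviour documented above (a sketch, not part of the diff):

import numpy as np
from mmcls.datasets.transforms import RandomCrop

# Reflect-pad 16 pixels on every border, then take a random 224x224 crop.
crop = RandomCrop(crop_size=224, padding=16, padding_mode='reflect')
results = crop({'img': np.random.randint(0, 256, (256, 256, 3), dtype='uint8')})
print(results['img'].shape)  # (224, 224, 3)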

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved
from mmcv.utils import is_seq_of
from mmengine.hooks import Hook
from mmengine.utils import is_seq_of
from mmcls.registry import HOOKS

View File

@ -6,7 +6,6 @@ import itertools
import logging
from typing import List, Optional, Sequence, Union
import mmcv
import mmengine
import torch
import torch.nn as nn
@ -14,6 +13,7 @@ from mmengine.hooks import Hook
from mmengine.logging import print_log
from mmengine.model import is_model_wrapper
from mmengine.runner import EpochBasedTrainLoop, IterBasedTrainLoop, Runner
from mmengine.utils import ProgressBar
from torch.functional import Tensor
from torch.nn import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
@ -120,7 +120,7 @@ def update_bn_stats(
bn.momentum = 1.0
# Average the BN stats for each BN layer over the batches
if rank == 0:
prog_bar = mmcv.ProgressBar(num_iter)
prog_bar = ProgressBar(num_iter)
for data in itertools.islice(loader, num_iter):
batch_inputs, data_samples = model.data_preprocessor(data, False)

View File

@ -5,7 +5,7 @@ import torch
import torch.nn as nn
from mmcv.cnn.bricks import (Conv2dAdaptivePadding, build_activation_layer,
build_norm_layer)
from mmcv.utils import digit_version
from mmengine.utils import digit_version
from mmcls.registry import MODELS
from .base_backbone import BaseBackbone

View File

@ -8,8 +8,8 @@ import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks import DropPath, build_activation_layer, build_norm_layer
from mmengine.model import BaseModule, ModuleList, Sequential
from mmengine.registry import MODELS
from mmcls.registry import MODELS
from .base_backbone import BaseBackbone

View File

@ -3,8 +3,8 @@ import torch
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmengine.model import BaseModule, Sequential
from mmengine.utils.parrots_wrapper import _BatchNorm
from mmcls.registry import MODELS
from ..utils.se_layer import SELayer

View File

@ -5,9 +5,9 @@ import torch.utils.checkpoint as cp
from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
build_norm_layer)
from mmcv.cnn.bricks import DropPath
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmengine.model import BaseModule
from mmengine.model.utils import constant_init
from mmengine.utils.parrots_wrapper import _BatchNorm
from mmcls.registry import MODELS
from .base_backbone import BaseBackbone

View File

@ -8,9 +8,9 @@ import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed, PatchMerging
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmengine.model import BaseModule, ModuleList
from mmengine.model.utils import trunc_normal_
from mmengine.utils.parrots_wrapper import _BatchNorm
from mmcls.registry import MODELS
from ..utils import (ShiftWindowMSA, resize_pos_embed,

View File

@ -4,8 +4,8 @@ import torch.nn as nn
from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
from mmcv.cnn.bricks import DropPath
from mmcv.cnn.bricks.transformer import PatchEmbed
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmengine.model import BaseModule, ModuleList
from mmengine.utils.parrots_wrapper import _BatchNorm
from mmcls.registry import MODELS
from .base_backbone import BaseBackbone

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmengine.utils.parrots_wrapper import _BatchNorm
from mmcls.registry import MODELS
from .base_backbone import BaseBackbone

View File

@ -4,7 +4,7 @@ import warnings
from itertools import repeat
import torch
from mmcv.utils import digit_version
from mmengine.utils import digit_version
def is_tracing() -> bool:

View File

@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from mmengine.utils import is_tuple_of
from .make_divisible import make_divisible
@ -45,7 +45,7 @@ class SELayer(BaseModule):
if isinstance(act_cfg, dict):
act_cfg = (act_cfg, act_cfg)
assert len(act_cfg) == 2
assert mmcv.is_tuple_of(act_cfg, dict)
assert is_tuple_of(act_cfg, dict)
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
if squeeze_channels is None:
squeeze_channels = make_divisible(channels // ratio, divisor)

View File

@ -3,10 +3,10 @@
from numbers import Number
from typing import Sequence, Union
import mmcv
import numpy as np
import torch
from mmengine.data import BaseDataElement, LabelData
from mmengine.utils import is_str
def format_label(value: Union[torch.Tensor, np.ndarray, Sequence, int],
@ -31,7 +31,7 @@ def format_label(value: Union[torch.Tensor, np.ndarray, Sequence, int],
if isinstance(value, np.ndarray):
value = torch.from_numpy(value).to(torch.long)
elif isinstance(value, Sequence) and not mmcv.is_str(value):
elif isinstance(value, Sequence) and not is_str(value):
value = torch.tensor(value).to(torch.long)
elif isinstance(value, int):
value = torch.LongTensor([value])
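
For reference, a sketch of the conversion rules above; the import path is assumed from this file's location:

import numpy as np
from mmcls.structures.cls_data_sample import format_label  # assumed path

print(format_label(1))              # tensor([1])
print(format_label([0, 2]))         # tensor([0, 2])
print(format_label(np.array([3])))  # tensor([3])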

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import collect_env as collect_base_env
from mmengine.utils import collect_env as collect_base_env
from mmengine.utils import get_git_hash
import mmcls

View File

@ -3,52 +3,130 @@ import copy
import os.path as osp
import unittest
import mmcv
import numpy as np
import torch
from mmengine.data import LabelData
from PIL import Image
from mmcls.datasets.transforms import PackClsInputs
from mmcls.registry import TRANSFORMS
from mmcls.structures import ClsDataSample
from mmcls.utils import register_all_modules
register_all_modules()
class TestPackClsInputs(unittest.TestCase):
def setUp(self):
"""Setup the model and optimizer which are used in every test method.
TestCase calls functions in this order: setUp() -> testMethod() ->
tearDown() -> cleanUp()
"""
data_prefix = osp.join(osp.dirname(__file__), '../../data')
img_path = osp.join(data_prefix, 'color.jpg')
rng = np.random.RandomState(0)
self.results1 = {
def test_transform(self):
img_path = osp.join(osp.dirname(__file__), '../../data/color.jpg')
data = {
'sample_idx': 1,
'img_path': img_path,
'ori_height': 300,
'ori_width': 400,
'height': 600,
'width': 800,
'scale_factor': 2.0,
'ori_shape': (300, 400),
'img_shape': (300, 400),
'scale_factor': 1.0,
'flip': False,
'img': rng.rand(300, 400),
'gt_label': rng.randint(3, )
'img': mmcv.imread(img_path),
'gt_label': 2,
}
self.meta_keys = ('sample_idx', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'flip')
def test_transform(self):
transform = PackClsInputs(meta_keys=self.meta_keys)
results = transform(copy.deepcopy(self.results1))
cfg = dict(type='PackClsInputs')
transform = TRANSFORMS.build(cfg)
results = transform(copy.deepcopy(data))
self.assertIn('inputs', results)
self.assertIsInstance(results['inputs'], torch.Tensor)
self.assertIn('data_sample', results)
self.assertIsInstance(results['data_sample'], ClsDataSample)
self.assertIn('flip', results['data_sample'].metainfo_keys())
self.assertIsInstance(results['data_sample'].gt_label, LabelData)
data_sample = results['data_sample']
self.assertIsInstance(data_sample.gt_label, LabelData)
# Test grayscale image
data['img'] = data['img'].mean(-1)
results = transform(copy.deepcopy(data))
self.assertIn('inputs', results)
self.assertIsInstance(results['inputs'], torch.Tensor)
self.assertEqual(results['inputs'].shape, (1, 300, 400))
# Test without `img` and `gt_label`
del data['img']
del data['gt_label']
with self.assertWarnsRegex(Warning, 'Cannot get "img"'):
results = transform(copy.deepcopy(data))
self.assertNotIn('gt_label', results['data_sample'])
def test_repr(self):
transform = PackClsInputs(meta_keys=self.meta_keys)
cfg = dict(type='PackClsInputs', meta_keys=['flip', 'img_shape'])
transform = TRANSFORMS.build(cfg)
self.assertEqual(
repr(transform), f'PackClsInputs(meta_keys={self.meta_keys})')
repr(transform), "PackClsInputs(meta_keys=['flip', 'img_shape'])")
class TestTranspose(unittest.TestCase):
def test_transform(self):
cfg = dict(type='Transpose', keys=['img'], order=[2, 0, 1])
transform = TRANSFORMS.build(cfg)
data = {'img': np.random.randint(0, 256, (224, 224, 3), dtype='uint8')}
results = transform(copy.deepcopy(data))
self.assertEqual(results['img'].shape, (3, 224, 224))
def test_repr(self):
cfg = dict(type='Transpose', keys=['img'], order=(2, 0, 1))
transform = TRANSFORMS.build(cfg)
self.assertEqual(
repr(transform), "Transpose(keys=['img'], order=(2, 0, 1))")
class TestToPIL(unittest.TestCase):
def test_transform(self):
cfg = dict(type='ToPIL')
transform = TRANSFORMS.build(cfg)
data = {'img': np.random.randint(0, 256, (224, 224, 3), dtype='uint8')}
results = transform(copy.deepcopy(data))
self.assertIsInstance(results['img'], Image.Image)
class TestToNumpy(unittest.TestCase):
def test_transform(self):
img_path = osp.join(osp.dirname(__file__), '../../data/color.jpg')
data = {
'tensor': torch.tensor([1, 2, 3]),
'Image': Image.open(img_path),
}
cfg = dict(type='ToNumpy', keys=['tensor', 'Image'], dtype='uint8')
transform = TRANSFORMS.build(cfg)
results = transform(copy.deepcopy(data))
self.assertIsInstance(results['tensor'], np.ndarray)
self.assertEqual(results['tensor'].dtype, 'uint8')
self.assertIsInstance(results['Image'], np.ndarray)
self.assertEqual(results['Image'].dtype, 'uint8')
def test_repr(self):
cfg = dict(type='ToNumpy', keys=['img'], dtype='uint8')
transform = TRANSFORMS.build(cfg)
self.assertEqual(repr(transform), "ToNumpy(keys=['img'], dtype=uint8)")
class TestCollect(unittest.TestCase):
def test_transform(self):
data = {'img': [1, 2, 3], 'gt_label': 1}
cfg = dict(type='Collect', keys=['img'])
transform = TRANSFORMS.build(cfg)
results = transform(copy.deepcopy(data))
self.assertIn('img', results)
self.assertNotIn('gt_label', results)
def test_repr(self):
cfg = dict(type='Collect', keys=['img'])
transform = TRANSFORMS.build(cfg)
self.assertEqual(repr(transform), "Collect(keys=['img'])")

View File

@ -628,7 +628,7 @@ class TestLighting(TestCase):
TRANSFORMS.build(cfg)
def test_transform(self):
ori_img = np.random.randint(0, 256, (256, 256, 3), np.uint8)
ori_img = np.ones((256, 256, 3), np.uint8) * 127
results = dict(img=copy.deepcopy(ori_img))
# Test transform with non-img-keyword result

View File

@ -5,7 +5,7 @@ from unittest import TestCase
import torch
from mmcv.cnn import ConvModule
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmengine.utils.parrots_wrapper import _BatchNorm
from mmcls.models.backbones import CSPDarkNet, CSPResNet, CSPResNeXt
from mmcls.models.backbones.cspnet import (CSPNet, DarknetBottleneck,

View File

@ -6,7 +6,7 @@ from copy import deepcopy
from unittest import TestCase
import torch
from mmcv.runner import load_checkpoint, save_checkpoint
from mmengine.runner import load_checkpoint, save_checkpoint
from mmcls.models.backbones import DistilledVisionTransformer
from .utils import timm_resize_pos_embed
@ -42,7 +42,7 @@ class TestDeiT(TestCase):
pretrain_pos_embed = model.pos_embed.clone().detach()
tmpdir = tempfile.gettempdir()
checkpoint = os.path.join(tmpdir, 'test.pth')
save_checkpoint(model, checkpoint)
save_checkpoint(model.state_dict(), checkpoint)
cfg = deepcopy(self.cfg)
model = DistilledVisionTransformer(**cfg)
load_checkpoint(model, checkpoint, strict=True)

View File

@ -5,7 +5,7 @@ from copy import deepcopy
from unittest import TestCase
import torch
from mmcv.runner import load_checkpoint, save_checkpoint
from mmengine.runner import load_checkpoint, save_checkpoint
from mmcls.models.backbones import RepMLPNet
@ -163,7 +163,7 @@ class TestRepMLP(TestCase):
cfg['deploy'] = True
model_deploy = RepMLPNet(**cfg)
model_deploy.eval()
save_checkpoint(model, self.ckpt_path)
save_checkpoint(model.state_dict(), self.ckpt_path)
load_checkpoint(model_deploy, self.ckpt_path, strict=True)
feats__ = model_deploy(imgs)

View File

@ -4,7 +4,7 @@ import tempfile
import pytest
import torch
from mmcv.runner import load_checkpoint, save_checkpoint
from mmengine.runner import load_checkpoint, save_checkpoint
from torch import nn
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
@ -286,7 +286,7 @@ def test_repvgg_load():
outputs = model(inputs)
model_deploy = RepVGG('A1', out_indices=(0, 1, 2, 3), deploy=True)
save_checkpoint(model, ckpt_path)
save_checkpoint(model.state_dict(), ckpt_path)
load_checkpoint(model_deploy, ckpt_path, strict=True)
outputs_load = model_deploy(inputs)

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmengine.utils.parrots_wrapper import _BatchNorm
from mmcls.models.backbones import Res2Net

View File

@ -3,7 +3,7 @@ import pytest
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmengine.utils.parrots_wrapper import _BatchNorm
from mmcls.models.backbones import ResNet, ResNetV1c, ResNetV1d
from mmcls.models.backbones.resnet import (BasicBlock, Bottleneck, ResLayer,

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmengine.utils.parrots_wrapper import _BatchNorm
from mmcls.models.backbones import ResNet_CIFAR

View File

@ -7,8 +7,8 @@ from itertools import chain
from unittest import TestCase
import torch
from mmcv.runner import load_checkpoint, save_checkpoint
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmengine.runner import load_checkpoint, save_checkpoint
from mmengine.utils.parrots_wrapper import _BatchNorm
from mmcls.models.backbones import SwinTransformer
from mmcls.models.backbones.swin_transformer import SwinBlock
@ -90,7 +90,7 @@ class TestSwinTransformer(TestCase):
tmpdir = tempfile.gettempdir()
# Save v3 checkpoints
checkpoint_v2 = os.path.join(tmpdir, 'v3.pth')
save_checkpoint(model, checkpoint_v2)
save_checkpoint(model.state_dict(), checkpoint_v2)
# Save v1 checkpoints
setattr(model, 'norm', model.norm3)
setattr(model.stages[0].blocks[1].attn, 'attn_mask',
@ -98,7 +98,7 @@ class TestSwinTransformer(TestCase):
model._version = 1
del model.norm3
checkpoint_v1 = os.path.join(tmpdir, 'v1.pth')
save_checkpoint(model, checkpoint_v1)
save_checkpoint(model.state_dict(), checkpoint_v1)
# test load v1 checkpoint
cfg = deepcopy(self.cfg)

View File

@ -6,7 +6,7 @@ from copy import deepcopy
from unittest import TestCase
import torch
from mmcv.runner import load_checkpoint, save_checkpoint
from mmengine.runner import load_checkpoint, save_checkpoint
from mmcls.models.backbones import T2T_ViT
from .utils import timm_resize_pos_embed
@ -71,7 +71,7 @@ class TestT2TViT(TestCase):
pretrain_pos_embed = model.pos_embed.clone().detach()
tmpdir = tempfile.gettempdir()
checkpoint = os.path.join(tmpdir, 'test.pth')
save_checkpoint(model, checkpoint)
save_checkpoint(model.state_dict(), checkpoint)
cfg = deepcopy(self.cfg)
model = T2T_ViT(**cfg)
load_checkpoint(model, checkpoint, strict=True)

View File

@ -5,7 +5,7 @@ from itertools import chain
from unittest import TestCase
import torch
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmengine.utils.parrots_wrapper import _BatchNorm
from torch import nn
from mmcls.models.backbones import VAN

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmengine.utils.parrots_wrapper import _BatchNorm
from mmcls.models.backbones import VGG

View File

@ -6,7 +6,7 @@ from copy import deepcopy
from unittest import TestCase
import torch
from mmcv.runner import load_checkpoint, save_checkpoint
from mmengine.runner import load_checkpoint, save_checkpoint
from mmcls.models.backbones import VisionTransformer
from .utils import timm_resize_pos_embed
@ -97,7 +97,7 @@ class TestVisionTransformer(TestCase):
pretrain_pos_embed = model.pos_embed.clone().detach()
tmpdir = tempfile.gettempdir()
checkpoint = os.path.join(tmpdir, 'test.pth')
save_checkpoint(model, checkpoint)
save_checkpoint(model.state_dict(), checkpoint)
cfg = deepcopy(self.cfg)
model = VisionTransformer(**cfg)
load_checkpoint(model, checkpoint, strict=True)

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmcv.utils import digit_version
from mmengine.utils import digit_version
from mmcls.models.utils import channel_shuffle, is_tracing, make_divisible

View File

@ -3,24 +3,16 @@ import argparse
import copy
import os
import os.path as osp
import time
import warnings
from datetime import datetime
from pathlib import Path
import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist
from mmengine.config import Config, DictAction
from mmengine.dist import sync_random_seed
from mmengine.fileio import dump, load
from mmengine.hooks import Hook
from mmengine.runner import Runner, find_latest_checkpoint
from mmcls import __version__
from mmcls.datasets import build_dataset
from mmcls.models import build_classifier
from mmcls.utils import (collect_env, get_root_logger, init_random_seed,
load_json_log, set_random_seed, train_model)
from mmcls.utils import register_all_modules
TEST_METRICS = ('precision', 'recall', 'f1_score', 'support', 'mAP', 'CP',
'CR', 'CF1', 'OP', 'OR', 'OF1', 'accuracy')
EXP_INFO_FILE = 'kfold_exp.json'
prog_description = """K-Fold cross-validation.
@ -28,10 +20,7 @@ To start a 5-fold cross-validation experiment:
python tools/kfold-cross-valid.py $CONFIG --num-splits 5
To resume a 5-fold cross-validation from an interrupted experiment:
python tools/kfold-cross-valid.py $CONFIG --num-splits 5 --resume-from work_dirs/fold2/latest.pth
To summarize a 5-fold cross-validation:
python tools/kfold-cross-valid.py $CONFIG --num-splits 5 --summary
python tools/kfold-cross-valid.py $CONFIG --num-splits 5 --resume
""" # noqa: E501
@ -41,47 +30,34 @@ def parse_args():
description=prog_description)
parser.add_argument('config', help='train config file path')
parser.add_argument(
'--num-splits', type=int, help='The number of all folds.')
'--num-splits',
type=int,
help='The number of all folds.',
required=True)
parser.add_argument(
'--fold',
type=int,
help='The fold used to do validation. '
'If specified, only run the experiment on the specified fold.')
parser.add_argument(
'--summary',
action='store_true',
help='Summarize the k-fold cross-validation results.')
parser.add_argument('--work-dir', help='the dir to save logs and models')
parser.add_argument('--seed', type=int, default=None, help='random seed')
parser.add_argument(
'--resume-from', help='the checkpoint file to resume from')
'--resume',
action='store_true',
help='Resume the previous experiment.')
parser.add_argument(
'--amp',
action='store_true',
help='enable automatic-mixed-precision training')
parser.add_argument(
'--no-validate',
action='store_true',
help='whether not to evaluate the checkpoint during training')
group_gpus = parser.add_mutually_exclusive_group()
group_gpus.add_argument('--device', help='device used for training')
group_gpus.add_argument(
'--gpus',
type=int,
help='(Deprecated, please use --gpu-id) number of gpus to use '
'(only applicable to non-distributed training)')
group_gpus.add_argument(
'--gpu-ids',
type=int,
nargs='+',
help='(Deprecated, please use --gpu-id) ids of gpus to use '
'(only applicable to non-distributed training)')
group_gpus.add_argument(
'--gpu-id',
type=int,
default=0,
help='id of gpu to use '
'(only applicable to non-distributed training)')
parser.add_argument('--seed', type=int, default=None, help='random seed')
parser.add_argument(
'--deterministic',
'--auto-scale-lr',
action='store_true',
help='whether to set deterministic options for CUDNN backend.')
help='whether to auto scale the learning rate according to the '
'actual batch size and the original batch size.')
parser.add_argument(
'--cfg-options',
nargs='+',
@ -105,195 +81,14 @@ def parse_args():
return args
def copy_config(old_cfg):
"""deepcopy a Config object."""
new_cfg = Config()
_cfg_dict = copy.deepcopy(old_cfg._cfg_dict)
_filename = copy.deepcopy(old_cfg._filename)
_text = copy.deepcopy(old_cfg._text)
super(Config, new_cfg).__setattr__('_cfg_dict', _cfg_dict)
super(Config, new_cfg).__setattr__('_filename', _filename)
super(Config, new_cfg).__setattr__('_text', _text)
return new_cfg
def merge_args(cfg, args):
"""Merge CLI arguments to config."""
if args.no_validate:
cfg.val_cfg = None
cfg.val_dataloader = None
cfg.val_evaluator = None
def train_single_fold(args, cfg, fold, distributed, seed):
# create the work_dir for the fold
work_dir = osp.join(cfg.work_dir, f'fold{fold}')
cfg.work_dir = work_dir
# create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
# wrap the dataset cfg
train_dataset = dict(
type='KFoldDataset',
fold=fold,
dataset=cfg.data.train,
num_splits=args.num_splits,
seed=seed,
)
val_dataset = dict(
type='KFoldDataset',
fold=fold,
# Use the same dataset with training.
dataset=copy.deepcopy(cfg.data.train),
num_splits=args.num_splits,
seed=seed,
test_mode=True,
)
val_dataset['dataset']['pipeline'] = cfg.data.val.pipeline
cfg.data.train = train_dataset
cfg.data.val = val_dataset
cfg.data.test = val_dataset
# dump config
stem, suffix = osp.basename(args.config).rsplit('.', 1)
cfg.dump(osp.join(cfg.work_dir, f'{stem}_fold{fold}.{suffix}'))
# init the logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
# init the meta dict to record some important information such as
# environment info and seed, which will be logged
meta = dict()
# log env info
env_info_dict = collect_env()
env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
dash_line = '-' * 60 + '\n'
logger.info('Environment info:\n' + dash_line + env_info + '\n' +
dash_line)
meta['env_info'] = env_info
# log some basic info
logger.info(f'Distributed training: {distributed}')
logger.info(f'Config:\n{cfg.pretty_text}')
logger.info(
f'-------- Cross-validation: [{fold+1}/{args.num_splits}] -------- ')
# set random seeds
# Use different seed in different folds
logger.info(f'Set random seed to {seed + fold}, '
f'deterministic: {args.deterministic}')
set_random_seed(seed + fold, deterministic=args.deterministic)
cfg.seed = seed + fold
meta['seed'] = seed + fold
model = build_classifier(cfg.model)
model.init_weights()
datasets = [build_dataset(cfg.data.train)]
if len(cfg.workflow) == 2:
val_dataset = copy.deepcopy(cfg.data.val)
val_dataset.pipeline = cfg.data.train.pipeline
datasets.append(build_dataset(val_dataset))
meta.update(
dict(
mmcls_version=__version__,
config=cfg.pretty_text,
CLASSES=datasets[0].CLASSES,
kfold=dict(fold=fold, num_splits=args.num_splits)))
# add an attribute for visualization convenience
train_model(
model,
datasets,
cfg,
distributed=distributed,
validate=(not args.no_validate),
timestamp=timestamp,
device='cpu' if args.device == 'cpu' else 'cuda',
meta=meta)
def summary(args, cfg):
summary = dict()
for fold in range(args.num_splits):
work_dir = Path(cfg.work_dir) / f'fold{fold}'
# Find the latest training log
log_files = list(work_dir.glob('*.log.json'))
if len(log_files) == 0:
continue
log_file = sorted(log_files)[-1]
date = datetime.fromtimestamp(log_file.lstat().st_mtime)
summary[fold] = {'date': date.strftime('%Y-%m-%d %H:%M:%S')}
# Find the latest eval log
json_log = load_json_log(log_file)
epochs = sorted(list(json_log.keys()))
eval_log = {}
def is_metric_key(key):
for metric in TEST_METRICS:
if metric in key:
return True
return False
for epoch in epochs[::-1]:
if any(is_metric_key(k) for k in json_log[epoch].keys()):
eval_log = json_log[epoch]
break
summary[fold]['epoch'] = epoch
summary[fold]['metric'] = {
k: v[0] # the value is a list with only one item.
for k, v in eval_log.items() if is_metric_key(k)
}
show_summary(args, summary)
def show_summary(args, summary_data):
try:
from rich.console import Console
from rich.table import Table
except ImportError:
raise ImportError('Please run `pip install rich` to install '
'package `rich` to draw the table.')
console = Console()
table = Table(title=f'{args.num_splits}-fold Cross-validation Summary')
table.add_column('Fold')
metrics = summary_data[0]['metric'].keys()
for metric in metrics:
table.add_column(metric)
table.add_column('Epoch')
table.add_column('Date')
for fold in range(args.num_splits):
row = [f'{fold+1}']
if fold not in summary_data:
table.add_row(*row)
continue
for metric in metrics:
metric_value = summary_data[fold]['metric'].get(metric, '')
def format_value(value):
if isinstance(value, float):
return f'{value:.2f}'
if isinstance(value, (list, tuple)):
return str([format_value(i) for i in value])
else:
return str(value)
row.append(format_value(metric_value))
row.append(str(summary_data[fold]['epoch']))
row.append(summary_data[fold]['date'])
table.add_row(*row)
console.print(table)
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
cfg.launcher = args.launcher
# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
@ -304,53 +99,119 @@ def main():
cfg.work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.config))[0])
if args.summary:
summary(args, cfg)
return
# enable automatic-mixed-precision training
if args.amp is True:
optim_wrapper = cfg.optim_wrapper.get('type', 'OptimWrapper')
assert optim_wrapper in ['OptimWrapper', 'AmpOptimWrapper'], \
'`--amp` is not supported for custom optimizer wrapper type ' \
f'`{optim_wrapper}`.'
cfg.optim_wrapper.type = 'AmpOptimWrapper'
cfg.optim_wrapper.setdefault('loss_scale', 'dynamic')
# enable auto scale learning rate
if args.auto_scale_lr:
cfg.auto_scale_lr.enable = True
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
return cfg
def train_single_fold(cfg, num_splits, fold, resume_ckpt=None):
root_dir = cfg.work_dir
cfg.work_dir = osp.join(root_dir, f'fold{fold}')
if resume_ckpt is not None:
cfg.resume = True
cfg.load_from = resume_ckpt
dataset = cfg.train_dataloader.dataset
# wrap the dataset cfg
def wrap_dataset(dataset, test_mode):
return dict(
type='KFoldDataset',
dataset=dataset,
fold=fold,
num_splits=num_splits,
seed=cfg.seed,
test_mode=test_mode,
)
train_dataset = copy.deepcopy(dataset)
cfg.train_dataloader.dataset = wrap_dataset(train_dataset, False)
if cfg.val_dataloader is not None:
if 'pipeline' not in cfg.val_dataloader.dataset:
raise ValueError(
'Cannot find `pipeline` in the validation dataset. '
"If you are using dataset wrapper, please don't use this "
'tool to act kfold cross validation. '
'Please write config files manually.')
val_dataset = copy.deepcopy(dataset)
val_dataset['pipeline'] = cfg.val_dataloader.dataset.pipeline
cfg.val_dataloader.dataset = wrap_dataset(val_dataset, True)
if cfg.test_dataloader is not None:
if 'pipeline' not in cfg.test_dataloader.dataset:
raise ValueError(
'Cannot find `pipeline` in the test dataset. '
"If you are using dataset wrapper, please don't use this "
'tool to act kfold cross validation. '
'Please write config files manually.')
test_dataset = copy.deepcopy(dataset)
test_dataset['pipeline'] = cfg.test_dataloader.dataset.pipeline
cfg.test_dataloader.dataset = wrap_dataset(test_dataset, True)
# build the runner from config
runner = Runner.from_cfg(cfg)
runner.logger.info(
f'----------- Cross-validation: [{fold+1}/{num_splits}] ----------- ')
runner.logger.info(f'Train dataset: \n{runner.train_dataloader.dataset}')
class SaveInfoHook(Hook):
def after_train_epoch(self, runner):
try:
last_ckpt = find_latest_checkpoint(cfg.work_dir)
exp_info = dict(
fold=fold, last_ckpt=last_ckpt, seed=runner.seed)
dump(exp_info, osp.join(root_dir, EXP_INFO_FILE))
except OSError:
pass
runner.register_hook(SaveInfoHook(), 'LOWEST')
# start training
runner.train()
def main():
args = parse_args()
# register all modules in mmcls into the registries
# do not init the default scope here because it will be init in the runner
register_all_modules(init_default_scope=False)
# load config
cfg = Config.fromfile(args.config)
# merge cli arguments to config
cfg = merge_args(cfg, args)
# set preprocess configs to model
cfg.model.setdefault('data_preprocessor', cfg.get('preprocess_cfg', {}))
# set the unify random seed
cfg.seed = args.seed or sync_random_seed()
# resume from the previous experiment
if args.resume_from is not None:
cfg.resume_from = args.resume_from
resume_kfold = torch.load(cfg.resume_from).get('meta',
{}).get('kfold', None)
if resume_kfold is None:
raise RuntimeError(
'No "meta" key in checkpoints or no "kfold" in the meta dict. '
'Please check if the resume checkpoint from a k-fold '
'cross-valid experiment.')
resume_fold = resume_kfold['fold']
assert args.num_splits == resume_kfold['num_splits']
if args.resume:
experiment_info = load(osp.join(cfg.work_dir, EXP_INFO_FILE))
resume_fold = experiment_info['fold']
cfg.seed = experiment_info['seed']
resume_ckpt = experiment_info.get('last_ckpt', None)
else:
resume_fold = 0
if args.gpus is not None:
cfg.gpu_ids = range(1)
warnings.warn('`--gpus` is deprecated because we only support '
'single GPU mode in non-distributed training. '
'Use `gpus=1` now.')
if args.gpu_ids is not None:
cfg.gpu_ids = args.gpu_ids[0:1]
warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
'Because we only support single GPU mode in '
'non-distributed training. Use the first GPU '
'in `gpu_ids` now.')
if args.gpus is None and args.gpu_ids is None:
cfg.gpu_ids = [args.gpu_id]
# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
distributed = False
else:
distributed = True
init_dist(args.launcher, **cfg.dist_params)
_, world_size = get_dist_info()
cfg.gpu_ids = range(world_size)
# init a unified random seed
seed = init_random_seed(args.seed)
# create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
resume_ckpt = None
if args.fold is not None:
folds = [args.fold]
@ -358,13 +219,9 @@ def main():
folds = range(resume_fold, args.num_splits)
for fold in folds:
cfg_ = copy_config(cfg)
if fold != resume_fold:
cfg_.resume_from = None
train_single_fold(args, cfg_, fold, distributed, seed)
if args.fold is None:
summary(args, cfg)
cfg_ = copy.deepcopy(cfg)
train_single_fold(cfg_, args.num_splits, fold, resume_ckpt)
resume_ckpt = None
if __name__ == '__main__':