[Refactor] Use `mmengine` instead of `mmcv` and refactor some transforms. (#986)
* [Refactor] Refactor the rest data transforms and fix some docstring. * Use `mmengine` instead of `mmcv` * Refactor KFold dataset tools * Fix docstring according to commentspull/917/merge
parent
5665b8349a
commit
b7d0d521eb
|
@ -1,5 +1,8 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
|
||||
import numpy as np
|
||||
from mmengine.dataset import BaseDataset, force_full_init
|
||||
|
||||
from mmcls.registry import DATASETS
|
||||
|
||||
|
@ -13,7 +16,8 @@ class KFoldDataset:
|
|||
and use the fold left to do validation.
|
||||
|
||||
Args:
|
||||
dataset (:obj:`BaseDataset`): The dataset to be divided.
|
||||
dataset (:obj:`mmengine.dataset.BaseDataset` | dict): The dataset to be
|
||||
divided
|
||||
fold (int): The fold used to do validation. Defaults to 0.
|
||||
num_splits (int): The number of all folds. Defaults to 5.
|
||||
test_mode (bool): Use the training dataset or validation dataset.
|
||||
|
@ -28,38 +32,145 @@ class KFoldDataset:
|
|||
num_splits=5,
|
||||
test_mode=False,
|
||||
seed=None):
|
||||
self.dataset = dataset
|
||||
self.CLASSES = dataset.CLASSES
|
||||
self.test_mode = test_mode
|
||||
self.num_splits = num_splits
|
||||
if isinstance(dataset, dict):
|
||||
self.dataset = DATASETS.build(dataset)
|
||||
# Init the dataset wrapper lazily according to the dataset setting.
|
||||
lazy_init = dataset.get('lazy_init', False)
|
||||
elif isinstance(dataset, BaseDataset):
|
||||
self.dataset = dataset
|
||||
else:
|
||||
raise TypeError(f'Unsupported dataset type {type(dataset)}.')
|
||||
|
||||
length = len(dataset)
|
||||
indices = list(range(length))
|
||||
if isinstance(seed, int):
|
||||
rng = np.random.default_rng(seed)
|
||||
self._metainfo = getattr(self.dataset, 'metainfo', {})
|
||||
self.fold = fold
|
||||
self.num_splits = num_splits
|
||||
self.test_mode = test_mode
|
||||
self.seed = seed
|
||||
|
||||
self._fully_initialized = False
|
||||
if not lazy_init:
|
||||
self.full_init()
|
||||
|
||||
@property
|
||||
def metainfo(self) -> dict:
|
||||
"""Get the meta information of ``self.dataset``.
|
||||
|
||||
Returns:
|
||||
dict: Meta information of the dataset.
|
||||
"""
|
||||
# Prevent `self._metainfo` from being modified by outside.
|
||||
return copy.deepcopy(self._metainfo)
|
||||
|
||||
def full_init(self):
|
||||
"""fully initialize the dataset."""
|
||||
if self._fully_initialized:
|
||||
return
|
||||
|
||||
self.dataset.full_init()
|
||||
ori_len = len(self.dataset)
|
||||
indices = list(range(ori_len))
|
||||
if self.seed is not None:
|
||||
rng = np.random.default_rng(self.seed)
|
||||
rng.shuffle(indices)
|
||||
|
||||
test_start = length * fold // num_splits
|
||||
test_end = length * (fold + 1) // num_splits
|
||||
if test_mode:
|
||||
self.indices = indices[test_start:test_end]
|
||||
test_start = ori_len * self.fold // self.num_splits
|
||||
test_end = ori_len * (self.fold + 1) // self.num_splits
|
||||
if self.test_mode:
|
||||
indices = indices[test_start:test_end]
|
||||
else:
|
||||
self.indices = indices[:test_start] + indices[test_end:]
|
||||
indices = indices[:test_start] + indices[test_end:]
|
||||
|
||||
def get_cat_ids(self, idx):
|
||||
return self.dataset.get_cat_ids(self.indices[idx])
|
||||
self._ori_indices = indices
|
||||
self.dataset = self.dataset.get_subset(indices)
|
||||
|
||||
def get_gt_labels(self):
|
||||
dataset_gt_labels = self.dataset.get_gt_labels()
|
||||
gt_labels = np.array([dataset_gt_labels[idx] for idx in self.indices])
|
||||
return gt_labels
|
||||
self._fully_initialized = True
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.dataset[self.indices[idx]]
|
||||
@force_full_init
|
||||
def _get_ori_dataset_idx(self, idx: int) -> int:
|
||||
"""Convert global idx to local index.
|
||||
|
||||
Args:
|
||||
idx (int): Global index of ``KFoldDataset``.
|
||||
|
||||
Returns:
|
||||
int: The original index in the whole dataset.
|
||||
"""
|
||||
return self._ori_indices[idx]
|
||||
|
||||
@force_full_init
|
||||
def get_data_info(self, idx: int) -> dict:
|
||||
"""Get annotation by index.
|
||||
|
||||
Args:
|
||||
idx (int): Global index of ``KFoldDataset``.
|
||||
|
||||
Returns:
|
||||
dict: The idx-th annotation of the datasets.
|
||||
"""
|
||||
return self.dataset.get_data_info(idx)
|
||||
|
||||
@force_full_init
|
||||
def __len__(self):
|
||||
return len(self.indices)
|
||||
return len(self.dataset)
|
||||
|
||||
def evaluate(self, *args, **kwargs):
|
||||
kwargs['indices'] = self.indices
|
||||
return self.dataset.evaluate(*args, **kwargs)
|
||||
@force_full_init
|
||||
def __getitem__(self, idx):
|
||||
return self.dataset[idx]
|
||||
|
||||
@force_full_init
|
||||
def get_cat_ids(self, idx):
|
||||
return self.dataset.get_cat_ids(idx)
|
||||
|
||||
@force_full_init
|
||||
def get_gt_labels(self):
|
||||
return self.dataset.get_gt_labels()
|
||||
|
||||
@property
|
||||
def CLASSES(self):
|
||||
"""Return all categories names."""
|
||||
return self._metainfo.get('classes', None)
|
||||
|
||||
@property
|
||||
def class_to_idx(self):
|
||||
"""Map mapping class name to class index.
|
||||
|
||||
Returns:
|
||||
dict: mapping from class name to class index.
|
||||
"""
|
||||
|
||||
return {cat: i for i, cat in enumerate(self.CLASSES)}
|
||||
|
||||
def __repr__(self):
|
||||
"""Print the basic information of the dataset.
|
||||
|
||||
Returns:
|
||||
str: Formatted string.
|
||||
"""
|
||||
head = 'Dataset ' + self.__class__.__name__
|
||||
body = []
|
||||
type_ = 'test' if self.test_mode else 'training'
|
||||
body.append(f'Type: \t{type_}')
|
||||
body.append(f'Seed: \t{self.seed}')
|
||||
|
||||
def ordinal(n):
|
||||
# Copy from https://codegolf.stackexchange.com/a/74047
|
||||
suffix = 'tsnrhtdd'[(n // 10 % 10 != 1) * (n % 10 < 4) * n % 10::4]
|
||||
return f'{n}{suffix}'
|
||||
|
||||
body.append(
|
||||
f'Fold: \t{ordinal(self.fold+1)} of {self.num_splits}-fold')
|
||||
if self._fully_initialized:
|
||||
body.append(f'Number of samples: \t{self.__len__()}')
|
||||
else:
|
||||
body.append("Haven't been initialized")
|
||||
|
||||
if self.CLASSES is not None:
|
||||
body.append(f'Number of categories: \t{len(self.CLASSES)}')
|
||||
else:
|
||||
body.append('The `CLASSES` meta info is not set.')
|
||||
|
||||
body.append(
|
||||
f'Original dataset type:\t{self.dataset.__class__.__name__}')
|
||||
|
||||
lines = [head] + [' ' * 4 + line for line in body]
|
||||
return '\n'.join(lines)
|
||||
|
|
|
@ -3,21 +3,17 @@ from .auto_augment import (AutoAugment, AutoContrast, Brightness,
|
|||
ColorTransform, Contrast, Cutout, Equalize, Invert,
|
||||
Posterize, RandAugment, Rotate, Sharpness, Shear,
|
||||
Solarize, SolarizeAdd, Translate)
|
||||
from .compose import Compose
|
||||
from .formatting import (Collect, ImageToTensor, PackClsInputs, ToNumpy, ToPIL,
|
||||
ToTensor, Transpose, to_tensor)
|
||||
from .formatting import Collect, PackClsInputs, ToNumpy, ToPIL, Transpose
|
||||
from .processing import (Albumentations, ColorJitter, EfficientNetCenterCrop,
|
||||
EfficientNetRandomCrop, Lighting, Normalize, Pad,
|
||||
RandomCrop, RandomErasing, RandomGrayscale,
|
||||
RandomResizedCrop)
|
||||
EfficientNetRandomCrop, Lighting, RandomCrop,
|
||||
RandomErasing, RandomResizedCrop)
|
||||
|
||||
__all__ = [
|
||||
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToPIL', 'ToNumpy',
|
||||
'Transpose', 'Collect', 'Normalize', 'RandomCrop', 'RandomResizedCrop',
|
||||
'RandomGrayscale', 'Shear', 'Translate', 'Rotate', 'Invert',
|
||||
'ToPIL', 'ToNumpy', 'Transpose', 'Collect', 'RandomCrop',
|
||||
'RandomResizedCrop', 'Shear', 'Translate', 'Rotate', 'Invert',
|
||||
'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', 'Equalize',
|
||||
'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd',
|
||||
'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', 'RandomErasing', 'Pad',
|
||||
'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', 'RandomErasing',
|
||||
'PackClsInputs', 'Albumentations', 'EfficientNetRandomCrop',
|
||||
'EfficientNetCenterCrop'
|
||||
]
|
||||
|
|
|
@ -8,11 +8,11 @@ from typing import List, Optional, Sequence, Tuple, Union
|
|||
import mmcv
|
||||
import numpy as np
|
||||
from mmcv import BaseTransform, RandomChoice
|
||||
from mmcv.transforms import Compose
|
||||
from mmcv.transforms.utils import cache_randomness
|
||||
from mmengine import is_list_of, is_seq_of
|
||||
|
||||
from mmcls.registry import TRANSFORMS
|
||||
from .compose import Compose
|
||||
|
||||
|
||||
def merge_hparams(policy: dict, hparams: dict) -> dict:
|
||||
|
|
|
@ -1,41 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from collections.abc import Sequence
|
||||
|
||||
from mmcls.registry import TRANSFORMS
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class Compose(object):
|
||||
"""Compose a data pipeline with a sequence of transforms.
|
||||
|
||||
Args:
|
||||
transforms (list[dict | callable]):
|
||||
Either config dicts of transforms or transform objects.
|
||||
"""
|
||||
|
||||
def __init__(self, transforms):
|
||||
assert isinstance(transforms, Sequence)
|
||||
self.transforms = []
|
||||
for transform in transforms:
|
||||
if isinstance(transform, dict):
|
||||
transform = TRANSFORMS.build(transform)
|
||||
self.transforms.append(transform)
|
||||
elif callable(transform):
|
||||
self.transforms.append(transform)
|
||||
else:
|
||||
raise TypeError('transform must be callable or a dict, but got'
|
||||
f' {type(transform)}')
|
||||
|
||||
def __call__(self, data):
|
||||
for t in self.transforms:
|
||||
data = t(data)
|
||||
if data is None:
|
||||
return None
|
||||
return data
|
||||
|
||||
def __repr__(self):
|
||||
format_string = self.__class__.__name__ + '('
|
||||
for t in self.transforms:
|
||||
format_string += f'\n {t}'
|
||||
format_string += '\n)'
|
||||
return format_string
|
|
@ -2,11 +2,10 @@
|
|||
import warnings
|
||||
from collections.abc import Sequence
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmcv.parallel import DataContainer as DC
|
||||
from mmcv.transforms.base import BaseTransform
|
||||
from mmcv.transforms import BaseTransform
|
||||
from mmengine.utils import is_str
|
||||
from PIL import Image
|
||||
|
||||
from mmcls.registry import TRANSFORMS
|
||||
|
@ -23,7 +22,7 @@ def to_tensor(data):
|
|||
return data
|
||||
elif isinstance(data, np.ndarray):
|
||||
return torch.from_numpy(data)
|
||||
elif isinstance(data, Sequence) and not mmcv.is_str(data):
|
||||
elif isinstance(data, Sequence) and not is_str(data):
|
||||
return torch.tensor(data)
|
||||
elif isinstance(data, int):
|
||||
return torch.LongTensor([data])
|
||||
|
@ -40,30 +39,36 @@ def to_tensor(data):
|
|||
class PackClsInputs(BaseTransform):
|
||||
"""Pack the inputs data for the classification.
|
||||
|
||||
The ``img_meta`` item is always populated. The contents of the
|
||||
``img_meta`` dictionary depends on ``meta_keys``. By default this includes:
|
||||
**Required Keys:**
|
||||
|
||||
- ``sample_idx``: id of the image sample
|
||||
- img
|
||||
- gt_label (optional)
|
||||
- ``*meta_keys`` (optional)
|
||||
|
||||
- ``img_path``: path to the image file
|
||||
**Deleted Keys:**
|
||||
|
||||
- ``ori_shape``: original shape of the image as a tuple (H, W).
|
||||
All keys in the dict.
|
||||
|
||||
- ``img_shape``: shape of the image input to the network as a tuple
|
||||
(H, W). Note that images may be zero padded on the bottom/right
|
||||
if the batch tensor is larger than this shape.
|
||||
**Added Keys:**
|
||||
|
||||
- ``scale_factor``: a float indicating the preprocessing scale
|
||||
|
||||
- ``flip``: a boolean indicating if image flip transform was used
|
||||
|
||||
- ``flip_direction``: the flipping direction
|
||||
- inputs (:obj:`torch.Tensor`): The forward data of models.
|
||||
- data_sample (:obj:`~mmcls.structures.ClsDataSample`): The annotation info
|
||||
of the sample.
|
||||
|
||||
Args:
|
||||
meta_keys (Sequence[str], optional): The meta keys to saved in the
|
||||
meta_keys (Sequence[str]): The meta keys to be saved in the
|
||||
``metainfo`` of the packed ``data_sample``.
|
||||
Default: ``('sample_idx', 'img_path', 'ori_shape', 'img_shape',
|
||||
'scale_factor', 'flip', 'flip_direction')``
|
||||
Defaults to a tuple includes keys:
|
||||
|
||||
- ``sample_idx``: The id of the image sample.
|
||||
- ``img_path``: The path to the image file.
|
||||
- ``ori_shape``: The original shape of the image as a tuple (H, W).
|
||||
- ``img_shape``: The shape of the image after the pipeline as a
|
||||
tuple (H, W).
|
||||
- ``scale_factor``: The scale factor between the resized image and
|
||||
the original image.
|
||||
- ``flip``: A boolean indicating if image flip transform was used.
|
||||
- ``flip_direction``: The flipping direction.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -72,17 +77,7 @@ class PackClsInputs(BaseTransform):
|
|||
self.meta_keys = meta_keys
|
||||
|
||||
def transform(self, results: dict) -> dict:
|
||||
"""Method to pack the input data.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from the data pipeline.
|
||||
|
||||
Returns:
|
||||
dict:
|
||||
- 'inputs' (obj:`torch.Tensor`): The forward data of models.
|
||||
- 'data_sample' (obj:`ClsDataSample`): The annotation info of the
|
||||
sample.
|
||||
"""
|
||||
"""Method to pack the input data."""
|
||||
packed_results = dict()
|
||||
if 'img' in results:
|
||||
img = results['img']
|
||||
|
@ -115,49 +110,28 @@ class PackClsInputs(BaseTransform):
|
|||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class ToTensor(object):
|
||||
"""Convert objects of various python types to :obj:`torch.Tensor`."""
|
||||
class Transpose(BaseTransform):
|
||||
"""Transpose numpy array.
|
||||
|
||||
def __init__(self, keys):
|
||||
self.keys = keys
|
||||
**Required Keys:**
|
||||
|
||||
def __call__(self, results):
|
||||
for key in self.keys:
|
||||
results[key] = to_tensor(results[key])
|
||||
return results
|
||||
- ``*keys``
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + f'(keys={self.keys})'
|
||||
**Modified Keys:**
|
||||
|
||||
- ``*keys``
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class ImageToTensor(object):
|
||||
"""Convert objects :obj:`PIL.Image` to :obj:`torch.Tensor`."""
|
||||
|
||||
def __init__(self, keys):
|
||||
self.keys = keys
|
||||
|
||||
def __call__(self, results):
|
||||
for key in self.keys:
|
||||
img = results[key]
|
||||
if len(img.shape) < 3:
|
||||
img = np.expand_dims(img, -1)
|
||||
results[key] = to_tensor(img.transpose(2, 0, 1))
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + f'(keys={self.keys})'
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class Transpose(object):
|
||||
"""matrix transpose."""
|
||||
Args:
|
||||
keys (List[str]): The fields to convert to tensor.
|
||||
order (List[int]): The output dimensions order.
|
||||
"""
|
||||
|
||||
def __init__(self, keys, order):
|
||||
self.keys = keys
|
||||
self.order = order
|
||||
|
||||
def __call__(self, results):
|
||||
def transform(self, results):
|
||||
"""Method to transpose array."""
|
||||
for key in self.keys:
|
||||
results[key] = results[key].transpose(self.order)
|
||||
return results
|
||||
|
@ -168,114 +142,80 @@ class Transpose(object):
|
|||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class ToPIL(object):
|
||||
"""Convert tensor to :obj:`PIL.Image`."""
|
||||
class ToPIL(BaseTransform):
|
||||
"""Convert the image from OpenCV format to :obj:`PIL.Image.Image`.
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
**Required Keys:**
|
||||
|
||||
def __call__(self, results):
|
||||
- img
|
||||
|
||||
**Modified Keys:**
|
||||
|
||||
- img
|
||||
"""
|
||||
|
||||
def transform(self, results):
|
||||
"""Method to convert images to :obj:`PIL.Image.Image`."""
|
||||
results['img'] = Image.fromarray(results['img'])
|
||||
return results
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class ToNumpy(object):
|
||||
"""Convert tensor to :obj:`np.ndarray`."""
|
||||
class ToNumpy(BaseTransform):
|
||||
"""Convert object to :obj:`numpy.ndarray`.
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
**Required Keys:**
|
||||
|
||||
def __call__(self, results):
|
||||
results['img'] = np.array(results['img'], dtype=np.float32)
|
||||
- ``*keys**
|
||||
|
||||
**Modified Keys:**
|
||||
|
||||
- ``*keys**
|
||||
|
||||
Args:
|
||||
dtype (str, optional): The dtype of the converted numpy array.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, keys, dtype=None):
|
||||
self.keys = keys
|
||||
self.dtype = dtype
|
||||
|
||||
def transform(self, results):
|
||||
"""Method to convert object to :obj:`numpy.ndarray`."""
|
||||
for key in self.keys:
|
||||
results[key] = np.array(results[key], dtype=self.dtype)
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + \
|
||||
f'(keys={self.keys}, dtype={self.dtype})'
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class Collect(object):
|
||||
"""Collect data from the loader relevant to the specific task.
|
||||
class Collect(BaseTransform):
|
||||
"""Collect and only reserve the specified fields.
|
||||
|
||||
This is usually the last stage of the data loader pipeline. Typically keys
|
||||
is set to some subset of "img" and "gt_label".
|
||||
**Required Keys:**
|
||||
|
||||
- ``*keys``
|
||||
|
||||
**Deleted Keys:**
|
||||
|
||||
All keys except those in the argument ``*keys``.
|
||||
|
||||
Args:
|
||||
keys (Sequence[str]): Keys of results to be collected in ``data``.
|
||||
meta_keys (Sequence[str], optional): Meta keys to be converted to
|
||||
``mmcv.DataContainer`` and collected in ``data[img_metas]``.
|
||||
Default: ('filename', 'ori_shape', 'img_shape', 'flip',
|
||||
'flip_direction', 'img_norm_cfg')
|
||||
|
||||
Returns:
|
||||
dict: The result dict contains the following keys
|
||||
|
||||
- keys in ``self.keys``
|
||||
- ``img_metas`` if available
|
||||
keys (Sequence[str]): The keys of the fields to be collected.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
keys,
|
||||
meta_keys=('filename', 'ori_filename', 'ori_shape',
|
||||
'img_shape', 'flip', 'flip_direction',
|
||||
'img_norm_cfg')):
|
||||
def __init__(self, keys):
|
||||
self.keys = keys
|
||||
self.meta_keys = meta_keys
|
||||
|
||||
def __call__(self, results):
|
||||
def transform(self, results):
|
||||
data = {}
|
||||
img_meta = {}
|
||||
for key in self.meta_keys:
|
||||
if key in results:
|
||||
img_meta[key] = results[key]
|
||||
data['img_metas'] = DC(img_meta, cpu_only=True)
|
||||
for key in self.keys:
|
||||
data[key] = results[key]
|
||||
return data
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + \
|
||||
f'(keys={self.keys}, meta_keys={self.meta_keys})'
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class WrapFieldsToLists(object):
|
||||
"""Wrap fields of the data dictionary into lists for evaluation.
|
||||
|
||||
This class can be used as a last step of a test or validation
|
||||
pipeline for single image evaluation or inference.
|
||||
|
||||
Example:
|
||||
>>> test_pipeline = [
|
||||
>>> dict(type='LoadImageFromFile'),
|
||||
>>> dict(type='Normalize',
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
to_rgb=True),
|
||||
>>> dict(type='ImageToTensor', keys=['img']),
|
||||
>>> dict(type='Collect', keys=['img']),
|
||||
>>> dict(type='WrapIntoLists')
|
||||
>>> ]
|
||||
"""
|
||||
|
||||
def __call__(self, results):
|
||||
# Wrap dict fields into lists
|
||||
for key, val in results.items():
|
||||
results[key] = [val]
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}()'
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class ToHalf(object):
|
||||
|
||||
def __init__(self, keys):
|
||||
self.keys = keys
|
||||
|
||||
def __call__(self, results):
|
||||
for k in self.keys:
|
||||
if isinstance(results[k], torch.Tensor):
|
||||
results[k] = results[k].to(torch.half)
|
||||
else:
|
||||
results[k] = results[k].astype(np.float16)
|
||||
return results
|
||||
return self.__class__.__name__ + f'(keys={self.keys})'
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
import inspect
|
||||
import math
|
||||
import numbers
|
||||
import random
|
||||
from numbers import Number
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
|
@ -24,25 +23,25 @@ except ImportError:
|
|||
class RandomCrop(BaseTransform):
|
||||
"""Crop the given Image at a random location.
|
||||
|
||||
Required Keys:
|
||||
**Required Keys:**
|
||||
|
||||
- img
|
||||
|
||||
Modified Keys:
|
||||
**Modified Keys:**
|
||||
|
||||
- img
|
||||
- img_shape
|
||||
|
||||
Args:
|
||||
crop_size (sequence or int): Desired output size of the crop. If
|
||||
crop_size (int | Sequence): Desired output size of the crop. If
|
||||
crop_size is an int instead of sequence like (h, w), a square crop
|
||||
(crop_size, crop_size) is made.
|
||||
padding (int or sequence, optional): Optional padding on each border
|
||||
padding (int | Sequence, optional): Optional padding on each border
|
||||
of the image. If a sequence of length 4 is provided, it is used to
|
||||
pad left, top, right, bottom borders respectively. If a sequence
|
||||
of length 2 is provided, it is used to pad left/right, top/bottom
|
||||
borders, respectively. Default: None, which means no padding.
|
||||
pad_if_needed (boolean): It will pad the image if smaller than the
|
||||
pad_if_needed (bool): It will pad the image if smaller than the
|
||||
desired size to avoid raising an exception. Since cropping is done
|
||||
after padding, the padding seems to be done at a random offset.
|
||||
Default: False.
|
||||
|
@ -52,17 +51,17 @@ class RandomCrop(BaseTransform):
|
|||
padding_mode (str): Type of padding. Defaults to "constant". Should
|
||||
be one of the following:
|
||||
|
||||
- constant: Pads with a constant value, this value is specified \
|
||||
with pad_val.
|
||||
- edge: pads with the last value at the edge of the image.
|
||||
- reflect: Pads with reflection of image without repeating the \
|
||||
last value on the edge. For example, padding [1, 2, 3, 4] \
|
||||
with 2 elements on both sides in reflect mode will result \
|
||||
in [3, 2, 1, 2, 3, 4, 3, 2].
|
||||
- symmetric: Pads with reflection of image repeating the last \
|
||||
value on the edge. For example, padding [1, 2, 3, 4] with \
|
||||
2 elements on both sides in symmetric mode will result in \
|
||||
[2, 1, 1, 2, 3, 4, 4, 3].
|
||||
- ``constant``: Pads with a constant value, this value is specified
|
||||
with pad_val.
|
||||
- ``edge``: pads with the last value at the edge of the image.
|
||||
- ``reflect``: Pads with reflection of image without repeating the
|
||||
last value on the edge. For example, padding [1, 2, 3, 4]
|
||||
with 2 elements on both sides in reflect mode will result
|
||||
in [3, 2, 1, 2, 3, 4, 3, 2].
|
||||
- ``symmetric``: Pads with reflection of image repeating the last
|
||||
value on the edge. For example, padding [1, 2, 3, 4] with
|
||||
2 elements on both sides in symmetric mode will result in
|
||||
[2, 1, 1, 2, 3, 4, 4, 3].
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -70,7 +69,7 @@ class RandomCrop(BaseTransform):
|
|||
padding: Optional[Union[Sequence, int]] = None,
|
||||
pad_if_needed: bool = False,
|
||||
pad_val: Union[Number, Sequence[Number]] = 0,
|
||||
padding_mode: str = 'constant') -> None:
|
||||
padding_mode: str = 'constant'):
|
||||
if isinstance(crop_size, Sequence):
|
||||
assert len(crop_size) == 2
|
||||
assert crop_size[0] > 0 and crop_size[1] > 0
|
||||
|
@ -170,11 +169,11 @@ class RandomResizedCrop(BaseTransform):
|
|||
random aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio
|
||||
is made. This crop is finally resized to given size.
|
||||
|
||||
Required Keys:
|
||||
**Required Keys:**
|
||||
|
||||
- img
|
||||
|
||||
Modified Keys:
|
||||
**Modified Keys:**
|
||||
|
||||
- img
|
||||
- img_shape
|
||||
|
@ -194,7 +193,7 @@ class RandomResizedCrop(BaseTransform):
|
|||
'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to
|
||||
'bilinear'.
|
||||
backend (str): The image resize backend type, accepted values are
|
||||
`cv2` and `pillow`. Defaults to `cv2`.
|
||||
'cv2' and 'pillow'. Defaults to 'cv2'.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -320,11 +319,11 @@ class RandomResizedCrop(BaseTransform):
|
|||
class EfficientNetRandomCrop(RandomResizedCrop):
|
||||
"""EfficientNet style RandomResizedCrop.
|
||||
|
||||
Required Keys:
|
||||
**Required Keys:**
|
||||
|
||||
- img
|
||||
|
||||
Modified Keys:
|
||||
**Modified Keys:**
|
||||
|
||||
- img
|
||||
- img_shape
|
||||
|
@ -347,7 +346,7 @@ class EfficientNetRandomCrop(RandomResizedCrop):
|
|||
'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Defaults to
|
||||
'bicubic'.
|
||||
backend (str): The image resize backend type, accepted values are
|
||||
`cv2` and `pillow`. Defaults to `cv2`.
|
||||
'cv2' and 'pillow'. Defaults to 'cv2'.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -435,54 +434,18 @@ class EfficientNetRandomCrop(RandomResizedCrop):
|
|||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class RandomGrayscale(object):
|
||||
"""Randomly convert image to grayscale with a probability of gray_prob.
|
||||
|
||||
Args:
|
||||
gray_prob (float): Probability that image should be converted to
|
||||
grayscale. Default: 0.1.
|
||||
|
||||
Returns:
|
||||
ndarray: Image after randomly grayscale transform.
|
||||
|
||||
Notes:
|
||||
- If input image is 1 channel: grayscale version is 1 channel.
|
||||
- If input image is 3 channel: grayscale version is 3 channel
|
||||
with r == g == b.
|
||||
"""
|
||||
|
||||
def __init__(self, gray_prob=0.1):
|
||||
self.gray_prob = gray_prob
|
||||
|
||||
def __call__(self, results):
|
||||
"""
|
||||
Args:
|
||||
img (ndarray): Image to be converted to grayscale.
|
||||
|
||||
Returns:
|
||||
ndarray: Randomly grayscaled image.
|
||||
"""
|
||||
for key in results.get('img_fields', ['img']):
|
||||
img = results[key]
|
||||
num_output_channels = img.shape[2]
|
||||
if random.random() < self.gray_prob:
|
||||
if num_output_channels > 1:
|
||||
img = mmcv.rgb2gray(img)[:, :, None]
|
||||
results[key] = np.dstack(
|
||||
[img for _ in range(num_output_channels)])
|
||||
return results
|
||||
results[key] = img
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + f'(gray_prob={self.gray_prob})'
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class RandomErasing(BaseTransform):
|
||||
"""Randomly selects a rectangle region in an image and erase pixels.
|
||||
|
||||
**Required Keys:**
|
||||
|
||||
- img
|
||||
|
||||
**Modified Keys:**
|
||||
|
||||
- img
|
||||
|
||||
Args:
|
||||
erase_prob (float): Probability that image will be randomly erased.
|
||||
Default: 0.5
|
||||
|
@ -632,62 +595,19 @@ class RandomErasing(BaseTransform):
|
|||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class Pad(object):
|
||||
"""Pad images.
|
||||
|
||||
Args:
|
||||
size (tuple[int] | None): Expected padding size (h, w). Conflicts with
|
||||
pad_to_square. Defaults to None.
|
||||
pad_to_square (bool): Pad any image to square shape. Defaults to False.
|
||||
pad_val (Number | Sequence[Number]): Values to be filled in padding
|
||||
areas when padding_mode is 'constant'. Defaults to 0.
|
||||
padding_mode (str): Type of padding. Should be: constant, edge,
|
||||
reflect or symmetric. Defaults to "constant".
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
size=None,
|
||||
pad_to_square=False,
|
||||
pad_val=0,
|
||||
padding_mode='constant'):
|
||||
assert (size is None) ^ (pad_to_square is False), \
|
||||
'Only one of [size, pad_to_square] should be given, ' \
|
||||
f'but get {(size is not None) + (pad_to_square is not False)}'
|
||||
self.size = size
|
||||
self.pad_to_square = pad_to_square
|
||||
self.pad_val = pad_val
|
||||
self.padding_mode = padding_mode
|
||||
|
||||
def __call__(self, results):
|
||||
for key in results.get('img_fields', ['img']):
|
||||
img = results[key]
|
||||
if self.pad_to_square:
|
||||
target_size = tuple(
|
||||
max(img.shape[0], img.shape[1]) for _ in range(2))
|
||||
else:
|
||||
target_size = self.size
|
||||
img = mmcv.impad(
|
||||
img,
|
||||
shape=target_size,
|
||||
pad_val=self.pad_val,
|
||||
padding_mode=self.padding_mode)
|
||||
results[key] = img
|
||||
results['img_shape'] = img.shape
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(size={self.size}, '
|
||||
repr_str += f'(pad_val={self.pad_val}, '
|
||||
repr_str += f'padding_mode={self.padding_mode})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class EfficientNetCenterCrop(BaseTransform):
|
||||
"""EfficientNet style center crop.
|
||||
|
||||
**Required Keys:**
|
||||
|
||||
- img
|
||||
|
||||
**Modified Keys:**
|
||||
|
||||
- img
|
||||
- img_shape
|
||||
|
||||
Args:
|
||||
crop_size (int): Expected size after cropping with the format
|
||||
of (h, w).
|
||||
|
@ -781,16 +701,16 @@ class EfficientNetCenterCrop(BaseTransform):
|
|||
class ResizeEdge(BaseTransform):
|
||||
"""Resize images along the specified edge.
|
||||
|
||||
Required Keys:
|
||||
**Required Keys:**
|
||||
|
||||
- img
|
||||
|
||||
Modified Keys:
|
||||
**Modified Keys:**
|
||||
|
||||
- img
|
||||
- img_shape
|
||||
|
||||
Added Keys:
|
||||
**Added Keys:**
|
||||
|
||||
- scale
|
||||
- scale_factor
|
||||
|
@ -877,38 +797,6 @@ class ResizeEdge(BaseTransform):
|
|||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class Normalize(object):
|
||||
"""Normalize the image.
|
||||
|
||||
Args:
|
||||
mean (sequence): Mean values of 3 channels.
|
||||
std (sequence): Std values of 3 channels.
|
||||
to_rgb (bool): Whether to convert the image from BGR to RGB,
|
||||
default is true.
|
||||
"""
|
||||
|
||||
def __init__(self, mean, std, to_rgb=True):
|
||||
self.mean = np.array(mean, dtype=np.float32)
|
||||
self.std = np.array(std, dtype=np.float32)
|
||||
self.to_rgb = to_rgb
|
||||
|
||||
def __call__(self, results):
|
||||
for key in results.get('img_fields', ['img']):
|
||||
results[key] = mmcv.imnormalize(results[key], self.mean, self.std,
|
||||
self.to_rgb)
|
||||
results['img_norm_cfg'] = dict(
|
||||
mean=self.mean, std=self.std, to_rgb=self.to_rgb)
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(mean={list(self.mean)}, '
|
||||
repr_str += f'std={list(self.std)}, '
|
||||
repr_str += f'to_rgb={self.to_rgb})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class ColorJitter(BaseTransform):
|
||||
"""Randomly change the brightness, contrast and saturation of an image.
|
||||
|
@ -917,11 +805,11 @@ class ColorJitter(BaseTransform):
|
|||
https://github.com/pytorch/vision/blob/main/torchvision/transforms/transforms.py
|
||||
Licensed under the BSD 3-Clause License.
|
||||
|
||||
Required Keys:
|
||||
**Required Keys:**
|
||||
|
||||
- img
|
||||
|
||||
Modified Keys:
|
||||
**Modified Keys:**
|
||||
|
||||
- img
|
||||
|
||||
|
@ -1043,11 +931,11 @@ class ColorJitter(BaseTransform):
|
|||
class Lighting(BaseTransform):
|
||||
"""Adjust images lighting using AlexNet-style PCA jitter.
|
||||
|
||||
Required Keys:
|
||||
**Required Keys:**
|
||||
|
||||
- img
|
||||
|
||||
Modified Keys:
|
||||
**Modified Keys:**
|
||||
|
||||
- img
|
||||
|
||||
|
@ -1122,11 +1010,11 @@ class Lighting(BaseTransform):
|
|||
class Albumentations(BaseTransform):
|
||||
"""Wrapper to use augmentation from albumentations library.
|
||||
|
||||
Required Keys:
|
||||
**Required Keys:**
|
||||
|
||||
- img
|
||||
|
||||
Modified Keys:
|
||||
**Modified Keys:**
|
||||
|
||||
- img
|
||||
- img_shape
|
||||
|
@ -1134,37 +1022,43 @@ class Albumentations(BaseTransform):
|
|||
Adds custom transformations from albumentations library.
|
||||
More details can be found in
|
||||
`Albumentations <https://albumentations.readthedocs.io>`_.
|
||||
An example of ``transforms`` is as followed:
|
||||
|
||||
.. code-block::
|
||||
[
|
||||
dict(
|
||||
type='ShiftScaleRotate',
|
||||
shift_limit=0.0625,
|
||||
scale_limit=0.0,
|
||||
rotate_limit=0,
|
||||
interpolation=1,
|
||||
p=0.5),
|
||||
dict(
|
||||
type='RandomBrightnessContrast',
|
||||
brightness_limit=[0.1, 0.3],
|
||||
contrast_limit=[0.1, 0.3],
|
||||
p=0.2),
|
||||
dict(type='ChannelShuffle', p=0.1),
|
||||
dict(
|
||||
type='OneOf',
|
||||
transforms=[
|
||||
dict(type='Blur', blur_limit=3, p=1.0),
|
||||
dict(type='MedianBlur', blur_limit=3, p=1.0)
|
||||
],
|
||||
p=0.1),
|
||||
]
|
||||
|
||||
Args:
|
||||
transforms (List[Dict]): List of albumentations transform configs.
|
||||
keymap (Optional[Dict]): Mapping of mmcls to albumentations fields,
|
||||
in format {'input key':'albumentation-style key'}. Defaults to
|
||||
None.
|
||||
|
||||
Example:
|
||||
>>> import mmcv
|
||||
>>> from mmcls.datasets import Albumentations
|
||||
>>> transforms = [
|
||||
... dict(
|
||||
... type='ShiftScaleRotate',
|
||||
... shift_limit=0.0625,
|
||||
... scale_limit=0.0,
|
||||
... rotate_limit=0,
|
||||
... interpolation=1,
|
||||
... p=0.5),
|
||||
... dict(
|
||||
... type='RandomBrightnessContrast',
|
||||
... brightness_limit=[0.1, 0.3],
|
||||
... contrast_limit=[0.1, 0.3],
|
||||
... p=0.2),
|
||||
... dict(type='ChannelShuffle', p=0.1),
|
||||
... dict(
|
||||
... type='OneOf',
|
||||
... transforms=[
|
||||
... dict(type='Blur', blur_limit=3, p=1.0),
|
||||
... dict(type='MedianBlur', blur_limit=3, p=1.0)
|
||||
... ],
|
||||
... p=0.1),
|
||||
... ]
|
||||
>>> albu = Albumentations(transforms)
|
||||
>>> data = {'img': mmcv.imread('./demo/demo.JPEG')}
|
||||
>>> data = albu(data)
|
||||
>>> print(data['img'].shape)
|
||||
(375, 500, 3)
|
||||
"""
|
||||
|
||||
def __init__(self, transforms: List[Dict], keymap: Optional[Dict] = None):
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved
|
||||
from mmcv.utils import is_seq_of
|
||||
from mmengine.hooks import Hook
|
||||
from mmengine.utils import is_seq_of
|
||||
|
||||
from mmcls.registry import HOOKS
|
||||
|
||||
|
|
|
@ -6,7 +6,6 @@ import itertools
|
|||
import logging
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
import mmcv
|
||||
import mmengine
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -14,6 +13,7 @@ from mmengine.hooks import Hook
|
|||
from mmengine.logging import print_log
|
||||
from mmengine.model import is_model_wrapper
|
||||
from mmengine.runner import EpochBasedTrainLoop, IterBasedTrainLoop, Runner
|
||||
from mmengine.utils import ProgressBar
|
||||
from torch.functional import Tensor
|
||||
from torch.nn import GroupNorm
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
@ -120,7 +120,7 @@ def update_bn_stats(
|
|||
bn.momentum = 1.0
|
||||
# Average the BN stats for each BN layer over the batches
|
||||
if rank == 0:
|
||||
prog_bar = mmcv.ProgressBar(num_iter)
|
||||
prog_bar = ProgressBar(num_iter)
|
||||
|
||||
for data in itertools.islice(loader, num_iter):
|
||||
batch_inputs, data_samples = model.data_preprocessor(data, False)
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from mmcv.cnn.bricks import (Conv2dAdaptivePadding, build_activation_layer,
|
||||
build_norm_layer)
|
||||
from mmcv.utils import digit_version
|
||||
from mmengine.utils import digit_version
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from .base_backbone import BaseBackbone
|
||||
|
|
|
@ -8,8 +8,8 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
from mmcv.cnn.bricks import DropPath, build_activation_layer, build_norm_layer
|
||||
from mmengine.model import BaseModule, ModuleList, Sequential
|
||||
from mmengine.registry import MODELS
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
||||
|
|
|
@ -3,8 +3,8 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
from mmengine.model import BaseModule, Sequential
|
||||
from mmengine.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from ..utils.se_layer import SELayer
|
||||
|
|
|
@ -5,9 +5,9 @@ import torch.utils.checkpoint as cp
|
|||
from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
|
||||
build_norm_layer)
|
||||
from mmcv.cnn.bricks import DropPath
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.model.utils import constant_init
|
||||
from mmengine.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from .base_backbone import BaseBackbone
|
||||
|
|
|
@ -8,9 +8,9 @@ import torch.nn as nn
|
|||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed, PatchMerging
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
from mmengine.model import BaseModule, ModuleList
|
||||
from mmengine.model.utils import trunc_normal_
|
||||
from mmengine.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from ..utils import (ShiftWindowMSA, resize_pos_embed,
|
||||
|
|
|
@ -4,8 +4,8 @@ import torch.nn as nn
|
|||
from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
|
||||
from mmcv.cnn.bricks import DropPath
|
||||
from mmcv.cnn.bricks.transformer import PatchEmbed
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
from mmengine.model import BaseModule, ModuleList
|
||||
from mmengine.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from .base_backbone import BaseBackbone
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
from mmengine.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from .base_backbone import BaseBackbone
|
||||
|
|
|
@ -4,7 +4,7 @@ import warnings
|
|||
from itertools import repeat
|
||||
|
||||
import torch
|
||||
from mmcv.utils import digit_version
|
||||
from mmengine.utils import digit_version
|
||||
|
||||
|
||||
def is_tracing() -> bool:
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import mmcv
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.utils import is_tuple_of
|
||||
|
||||
from .make_divisible import make_divisible
|
||||
|
||||
|
@ -45,7 +45,7 @@ class SELayer(BaseModule):
|
|||
if isinstance(act_cfg, dict):
|
||||
act_cfg = (act_cfg, act_cfg)
|
||||
assert len(act_cfg) == 2
|
||||
assert mmcv.is_tuple_of(act_cfg, dict)
|
||||
assert is_tuple_of(act_cfg, dict)
|
||||
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
|
||||
if squeeze_channels is None:
|
||||
squeeze_channels = make_divisible(channels // ratio, divisor)
|
||||
|
|
|
@ -3,10 +3,10 @@
|
|||
from numbers import Number
|
||||
from typing import Sequence, Union
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine.data import BaseDataElement, LabelData
|
||||
from mmengine.utils import is_str
|
||||
|
||||
|
||||
def format_label(value: Union[torch.Tensor, np.ndarray, Sequence, int],
|
||||
|
@ -31,7 +31,7 @@ def format_label(value: Union[torch.Tensor, np.ndarray, Sequence, int],
|
|||
|
||||
if isinstance(value, np.ndarray):
|
||||
value = torch.from_numpy(value).to(torch.long)
|
||||
elif isinstance(value, Sequence) and not mmcv.is_str(value):
|
||||
elif isinstance(value, Sequence) and not is_str(value):
|
||||
value = torch.tensor(value).to(torch.long)
|
||||
elif isinstance(value, int):
|
||||
value = torch.LongTensor([value])
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmcv.utils import collect_env as collect_base_env
|
||||
from mmengine.utils import collect_env as collect_base_env
|
||||
from mmengine.utils import get_git_hash
|
||||
|
||||
import mmcls
|
||||
|
|
|
@ -3,52 +3,130 @@ import copy
|
|||
import os.path as osp
|
||||
import unittest
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine.data import LabelData
|
||||
from PIL import Image
|
||||
|
||||
from mmcls.datasets.transforms import PackClsInputs
|
||||
from mmcls.registry import TRANSFORMS
|
||||
from mmcls.structures import ClsDataSample
|
||||
from mmcls.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
|
||||
|
||||
class TestPackClsInputs(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
"""Setup the model and optimizer which are used in every test method.
|
||||
|
||||
TestCase calls functions in this order: setUp() -> testMethod() ->
|
||||
tearDown() -> cleanUp()
|
||||
"""
|
||||
data_prefix = osp.join(osp.dirname(__file__), '../../data')
|
||||
img_path = osp.join(data_prefix, 'color.jpg')
|
||||
rng = np.random.RandomState(0)
|
||||
self.results1 = {
|
||||
def test_transform(self):
|
||||
img_path = osp.join(osp.dirname(__file__), '../../data/color.jpg')
|
||||
data = {
|
||||
'sample_idx': 1,
|
||||
'img_path': img_path,
|
||||
'ori_height': 300,
|
||||
'ori_width': 400,
|
||||
'height': 600,
|
||||
'width': 800,
|
||||
'scale_factor': 2.0,
|
||||
'ori_shape': (300, 400),
|
||||
'img_shape': (300, 400),
|
||||
'scale_factor': 1.0,
|
||||
'flip': False,
|
||||
'img': rng.rand(300, 400),
|
||||
'gt_label': rng.randint(3, )
|
||||
'img': mmcv.imread(img_path),
|
||||
'gt_label': 2,
|
||||
}
|
||||
self.meta_keys = ('sample_idx', 'img_path', 'ori_shape', 'img_shape',
|
||||
'scale_factor', 'flip')
|
||||
|
||||
def test_transform(self):
|
||||
transform = PackClsInputs(meta_keys=self.meta_keys)
|
||||
results = transform(copy.deepcopy(self.results1))
|
||||
cfg = dict(type='PackClsInputs')
|
||||
transform = TRANSFORMS.build(cfg)
|
||||
results = transform(copy.deepcopy(data))
|
||||
self.assertIn('inputs', results)
|
||||
self.assertIsInstance(results['inputs'], torch.Tensor)
|
||||
self.assertIn('data_sample', results)
|
||||
self.assertIsInstance(results['data_sample'], ClsDataSample)
|
||||
self.assertIn('flip', results['data_sample'].metainfo_keys())
|
||||
self.assertIsInstance(results['data_sample'].gt_label, LabelData)
|
||||
|
||||
data_sample = results['data_sample']
|
||||
self.assertIsInstance(data_sample.gt_label, LabelData)
|
||||
# Test grayscale image
|
||||
data['img'] = data['img'].mean(-1)
|
||||
results = transform(copy.deepcopy(data))
|
||||
self.assertIn('inputs', results)
|
||||
self.assertIsInstance(results['inputs'], torch.Tensor)
|
||||
self.assertEqual(results['inputs'].shape, (1, 300, 400))
|
||||
|
||||
# Test without `img` and `gt_label`
|
||||
del data['img']
|
||||
del data['gt_label']
|
||||
with self.assertWarnsRegex(Warning, 'Cannot get "img"'):
|
||||
results = transform(copy.deepcopy(data))
|
||||
self.assertNotIn('gt_label', results['data_sample'])
|
||||
|
||||
def test_repr(self):
|
||||
transform = PackClsInputs(meta_keys=self.meta_keys)
|
||||
cfg = dict(type='PackClsInputs', meta_keys=['flip', 'img_shape'])
|
||||
transform = TRANSFORMS.build(cfg)
|
||||
self.assertEqual(
|
||||
repr(transform), f'PackClsInputs(meta_keys={self.meta_keys})')
|
||||
repr(transform), "PackClsInputs(meta_keys=['flip', 'img_shape'])")
|
||||
|
||||
|
||||
class TestTranspose(unittest.TestCase):
|
||||
|
||||
def test_transform(self):
|
||||
cfg = dict(type='Transpose', keys=['img'], order=[2, 0, 1])
|
||||
transform = TRANSFORMS.build(cfg)
|
||||
|
||||
data = {'img': np.random.randint(0, 256, (224, 224, 3), dtype='uint8')}
|
||||
|
||||
results = transform(copy.deepcopy(data))
|
||||
self.assertEqual(results['img'].shape, (3, 224, 224))
|
||||
|
||||
def test_repr(self):
|
||||
cfg = dict(type='Transpose', keys=['img'], order=(2, 0, 1))
|
||||
transform = TRANSFORMS.build(cfg)
|
||||
self.assertEqual(
|
||||
repr(transform), "Transpose(keys=['img'], order=(2, 0, 1))")
|
||||
|
||||
|
||||
class TestToPIL(unittest.TestCase):
|
||||
|
||||
def test_transform(self):
|
||||
cfg = dict(type='ToPIL')
|
||||
transform = TRANSFORMS.build(cfg)
|
||||
|
||||
data = {'img': np.random.randint(0, 256, (224, 224, 3), dtype='uint8')}
|
||||
|
||||
results = transform(copy.deepcopy(data))
|
||||
self.assertIsInstance(results['img'], Image.Image)
|
||||
|
||||
|
||||
class TestToNumpy(unittest.TestCase):
|
||||
|
||||
def test_transform(self):
|
||||
img_path = osp.join(osp.dirname(__file__), '../../data/color.jpg')
|
||||
data = {
|
||||
'tensor': torch.tensor([1, 2, 3]),
|
||||
'Image': Image.open(img_path),
|
||||
}
|
||||
|
||||
cfg = dict(type='ToNumpy', keys=['tensor', 'Image'], dtype='uint8')
|
||||
transform = TRANSFORMS.build(cfg)
|
||||
results = transform(copy.deepcopy(data))
|
||||
self.assertIsInstance(results['tensor'], np.ndarray)
|
||||
self.assertEqual(results['tensor'].dtype, 'uint8')
|
||||
self.assertIsInstance(results['Image'], np.ndarray)
|
||||
self.assertEqual(results['Image'].dtype, 'uint8')
|
||||
|
||||
def test_repr(self):
|
||||
cfg = dict(type='ToNumpy', keys=['img'], dtype='uint8')
|
||||
transform = TRANSFORMS.build(cfg)
|
||||
self.assertEqual(repr(transform), "ToNumpy(keys=['img'], dtype=uint8)")
|
||||
|
||||
|
||||
class TestCollect(unittest.TestCase):
|
||||
|
||||
def test_transform(self):
|
||||
data = {'img': [1, 2, 3], 'gt_label': 1}
|
||||
|
||||
cfg = dict(type='Collect', keys=['img'])
|
||||
transform = TRANSFORMS.build(cfg)
|
||||
results = transform(copy.deepcopy(data))
|
||||
self.assertIn('img', results)
|
||||
self.assertNotIn('gt_label', results)
|
||||
|
||||
def test_repr(self):
|
||||
cfg = dict(type='Collect', keys=['img'])
|
||||
transform = TRANSFORMS.build(cfg)
|
||||
self.assertEqual(repr(transform), "Collect(keys=['img'])")
|
||||
|
|
|
@ -628,7 +628,7 @@ class TestLighting(TestCase):
|
|||
TRANSFORMS.build(cfg)
|
||||
|
||||
def test_transform(self):
|
||||
ori_img = np.random.randint(0, 256, (256, 256, 3), np.uint8)
|
||||
ori_img = np.ones((256, 256, 3), np.uint8) * 127
|
||||
results = dict(img=copy.deepcopy(ori_img))
|
||||
|
||||
# Test transform with non-img-keyword result
|
||||
|
|
|
@ -5,7 +5,7 @@ from unittest import TestCase
|
|||
|
||||
import torch
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
from mmengine.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.models.backbones import CSPDarkNet, CSPResNet, CSPResNeXt
|
||||
from mmcls.models.backbones.cspnet import (CSPNet, DarknetBottleneck,
|
||||
|
|
|
@ -6,7 +6,7 @@ from copy import deepcopy
|
|||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
from mmcv.runner import load_checkpoint, save_checkpoint
|
||||
from mmengine.runner import load_checkpoint, save_checkpoint
|
||||
|
||||
from mmcls.models.backbones import DistilledVisionTransformer
|
||||
from .utils import timm_resize_pos_embed
|
||||
|
@ -42,7 +42,7 @@ class TestDeiT(TestCase):
|
|||
pretrain_pos_embed = model.pos_embed.clone().detach()
|
||||
tmpdir = tempfile.gettempdir()
|
||||
checkpoint = os.path.join(tmpdir, 'test.pth')
|
||||
save_checkpoint(model, checkpoint)
|
||||
save_checkpoint(model.state_dict(), checkpoint)
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = DistilledVisionTransformer(**cfg)
|
||||
load_checkpoint(model, checkpoint, strict=True)
|
||||
|
|
|
@ -5,7 +5,7 @@ from copy import deepcopy
|
|||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
from mmcv.runner import load_checkpoint, save_checkpoint
|
||||
from mmengine.runner import load_checkpoint, save_checkpoint
|
||||
|
||||
from mmcls.models.backbones import RepMLPNet
|
||||
|
||||
|
@ -163,7 +163,7 @@ class TestRepMLP(TestCase):
|
|||
cfg['deploy'] = True
|
||||
model_deploy = RepMLPNet(**cfg)
|
||||
model_deploy.eval()
|
||||
save_checkpoint(model, self.ckpt_path)
|
||||
save_checkpoint(model.state_dict(), self.ckpt_path)
|
||||
load_checkpoint(model_deploy, self.ckpt_path, strict=True)
|
||||
feats__ = model_deploy(imgs)
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ import tempfile
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
from mmcv.runner import load_checkpoint, save_checkpoint
|
||||
from mmengine.runner import load_checkpoint, save_checkpoint
|
||||
from torch import nn
|
||||
from torch.nn.modules import GroupNorm
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
@ -286,7 +286,7 @@ def test_repvgg_load():
|
|||
outputs = model(inputs)
|
||||
|
||||
model_deploy = RepVGG('A1', out_indices=(0, 1, 2, 3), deploy=True)
|
||||
save_checkpoint(model, ckpt_path)
|
||||
save_checkpoint(model.state_dict(), ckpt_path)
|
||||
load_checkpoint(model_deploy, ckpt_path, strict=True)
|
||||
|
||||
outputs_load = model_deploy(inputs)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
from mmengine.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.models.backbones import Res2Net
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ import pytest
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
from mmengine.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.models.backbones import ResNet, ResNetV1c, ResNetV1d
|
||||
from mmcls.models.backbones.resnet import (BasicBlock, Bottleneck, ResLayer,
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
from mmengine.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.models.backbones import ResNet_CIFAR
|
||||
|
||||
|
|
|
@ -7,8 +7,8 @@ from itertools import chain
|
|||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
from mmcv.runner import load_checkpoint, save_checkpoint
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
from mmengine.runner import load_checkpoint, save_checkpoint
|
||||
from mmengine.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.models.backbones import SwinTransformer
|
||||
from mmcls.models.backbones.swin_transformer import SwinBlock
|
||||
|
@ -90,7 +90,7 @@ class TestSwinTransformer(TestCase):
|
|||
tmpdir = tempfile.gettempdir()
|
||||
# Save v3 checkpoints
|
||||
checkpoint_v2 = os.path.join(tmpdir, 'v3.pth')
|
||||
save_checkpoint(model, checkpoint_v2)
|
||||
save_checkpoint(model.state_dict(), checkpoint_v2)
|
||||
# Save v1 checkpoints
|
||||
setattr(model, 'norm', model.norm3)
|
||||
setattr(model.stages[0].blocks[1].attn, 'attn_mask',
|
||||
|
@ -98,7 +98,7 @@ class TestSwinTransformer(TestCase):
|
|||
model._version = 1
|
||||
del model.norm3
|
||||
checkpoint_v1 = os.path.join(tmpdir, 'v1.pth')
|
||||
save_checkpoint(model, checkpoint_v1)
|
||||
save_checkpoint(model.state_dict(), checkpoint_v1)
|
||||
|
||||
# test load v1 checkpoint
|
||||
cfg = deepcopy(self.cfg)
|
||||
|
|
|
@ -6,7 +6,7 @@ from copy import deepcopy
|
|||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
from mmcv.runner import load_checkpoint, save_checkpoint
|
||||
from mmengine.runner import load_checkpoint, save_checkpoint
|
||||
|
||||
from mmcls.models.backbones import T2T_ViT
|
||||
from .utils import timm_resize_pos_embed
|
||||
|
@ -71,7 +71,7 @@ class TestT2TViT(TestCase):
|
|||
pretrain_pos_embed = model.pos_embed.clone().detach()
|
||||
tmpdir = tempfile.gettempdir()
|
||||
checkpoint = os.path.join(tmpdir, 'test.pth')
|
||||
save_checkpoint(model, checkpoint)
|
||||
save_checkpoint(model.state_dict(), checkpoint)
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = T2T_ViT(**cfg)
|
||||
load_checkpoint(model, checkpoint, strict=True)
|
||||
|
|
|
@ -5,7 +5,7 @@ from itertools import chain
|
|||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
from mmengine.utils.parrots_wrapper import _BatchNorm
|
||||
from torch import nn
|
||||
|
||||
from mmcls.models.backbones import VAN
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
from mmengine.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.models.backbones import VGG
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ from copy import deepcopy
|
|||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
from mmcv.runner import load_checkpoint, save_checkpoint
|
||||
from mmengine.runner import load_checkpoint, save_checkpoint
|
||||
|
||||
from mmcls.models.backbones import VisionTransformer
|
||||
from .utils import timm_resize_pos_embed
|
||||
|
@ -97,7 +97,7 @@ class TestVisionTransformer(TestCase):
|
|||
pretrain_pos_embed = model.pos_embed.clone().detach()
|
||||
tmpdir = tempfile.gettempdir()
|
||||
checkpoint = os.path.join(tmpdir, 'test.pth')
|
||||
save_checkpoint(model, checkpoint)
|
||||
save_checkpoint(model.state_dict(), checkpoint)
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = VisionTransformer(**cfg)
|
||||
load_checkpoint(model, checkpoint, strict=True)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
from mmcv.utils import digit_version
|
||||
from mmengine.utils import digit_version
|
||||
|
||||
from mmcls.models.utils import channel_shuffle, is_tracing, make_divisible
|
||||
|
||||
|
|
|
@ -3,24 +3,16 @@ import argparse
|
|||
import copy
|
||||
import os
|
||||
import os.path as osp
|
||||
import time
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import mmcv
|
||||
import torch
|
||||
from mmcv import Config, DictAction
|
||||
from mmcv.runner import get_dist_info, init_dist
|
||||
from mmengine.config import Config, DictAction
|
||||
from mmengine.dist import sync_random_seed
|
||||
from mmengine.fileio import dump, load
|
||||
from mmengine.hooks import Hook
|
||||
from mmengine.runner import Runner, find_latest_checkpoint
|
||||
|
||||
from mmcls import __version__
|
||||
from mmcls.datasets import build_dataset
|
||||
from mmcls.models import build_classifier
|
||||
from mmcls.utils import (collect_env, get_root_logger, init_random_seed,
|
||||
load_json_log, set_random_seed, train_model)
|
||||
from mmcls.utils import register_all_modules
|
||||
|
||||
TEST_METRICS = ('precision', 'recall', 'f1_score', 'support', 'mAP', 'CP',
|
||||
'CR', 'CF1', 'OP', 'OR', 'OF1', 'accuracy')
|
||||
EXP_INFO_FILE = 'kfold_exp.json'
|
||||
|
||||
prog_description = """K-Fold cross-validation.
|
||||
|
||||
|
@ -28,10 +20,7 @@ To start a 5-fold cross-validation experiment:
|
|||
python tools/kfold-cross-valid.py $CONFIG --num-splits 5
|
||||
|
||||
To resume a 5-fold cross-validation from an interrupted experiment:
|
||||
python tools/kfold-cross-valid.py $CONFIG --num-splits 5 --resume-from work_dirs/fold2/latest.pth
|
||||
|
||||
To summarize a 5-fold cross-validation:
|
||||
python tools/kfold-cross-valid.py $CONFIG --num-splits 5 --summary
|
||||
python tools/kfold-cross-valid.py $CONFIG --num-splits 5 --resume
|
||||
""" # noqa: E501
|
||||
|
||||
|
||||
|
@ -41,47 +30,34 @@ def parse_args():
|
|||
description=prog_description)
|
||||
parser.add_argument('config', help='train config file path')
|
||||
parser.add_argument(
|
||||
'--num-splits', type=int, help='The number of all folds.')
|
||||
'--num-splits',
|
||||
type=int,
|
||||
help='The number of all folds.',
|
||||
required=True)
|
||||
parser.add_argument(
|
||||
'--fold',
|
||||
type=int,
|
||||
help='The fold used to do validation. '
|
||||
'If specify, only do an experiment of the specified fold.')
|
||||
parser.add_argument(
|
||||
'--summary',
|
||||
action='store_true',
|
||||
help='Summarize the k-fold cross-validation results.')
|
||||
parser.add_argument('--work-dir', help='the dir to save logs and models')
|
||||
parser.add_argument('--seed', type=int, default=None, help='random seed')
|
||||
parser.add_argument(
|
||||
'--resume-from', help='the checkpoint file to resume from')
|
||||
'--resume',
|
||||
action='store_true',
|
||||
help='Resume the previous experiment.')
|
||||
parser.add_argument(
|
||||
'--amp',
|
||||
action='store_true',
|
||||
help='enable automatic-mixed-precision training')
|
||||
parser.add_argument(
|
||||
'--no-validate',
|
||||
action='store_true',
|
||||
help='whether not to evaluate the checkpoint during training')
|
||||
group_gpus = parser.add_mutually_exclusive_group()
|
||||
group_gpus.add_argument('--device', help='device used for training')
|
||||
group_gpus.add_argument(
|
||||
'--gpus',
|
||||
type=int,
|
||||
help='(Deprecated, please use --gpu-id) number of gpus to use '
|
||||
'(only applicable to non-distributed training)')
|
||||
group_gpus.add_argument(
|
||||
'--gpu-ids',
|
||||
type=int,
|
||||
nargs='+',
|
||||
help='(Deprecated, please use --gpu-id) ids of gpus to use '
|
||||
'(only applicable to non-distributed training)')
|
||||
group_gpus.add_argument(
|
||||
'--gpu-id',
|
||||
type=int,
|
||||
default=0,
|
||||
help='id of gpu to use '
|
||||
'(only applicable to non-distributed training)')
|
||||
parser.add_argument('--seed', type=int, default=None, help='random seed')
|
||||
parser.add_argument(
|
||||
'--deterministic',
|
||||
'--auto-scale-lr',
|
||||
action='store_true',
|
||||
help='whether to set deterministic options for CUDNN backend.')
|
||||
help='whether to auto scale the learning rate according to the '
|
||||
'actual batch size and the original batch size.')
|
||||
parser.add_argument(
|
||||
'--cfg-options',
|
||||
nargs='+',
|
||||
|
@ -105,195 +81,14 @@ def parse_args():
|
|||
return args
|
||||
|
||||
|
||||
def copy_config(old_cfg):
|
||||
"""deepcopy a Config object."""
|
||||
new_cfg = Config()
|
||||
_cfg_dict = copy.deepcopy(old_cfg._cfg_dict)
|
||||
_filename = copy.deepcopy(old_cfg._filename)
|
||||
_text = copy.deepcopy(old_cfg._text)
|
||||
super(Config, new_cfg).__setattr__('_cfg_dict', _cfg_dict)
|
||||
super(Config, new_cfg).__setattr__('_filename', _filename)
|
||||
super(Config, new_cfg).__setattr__('_text', _text)
|
||||
return new_cfg
|
||||
def merge_args(cfg, args):
|
||||
"""Merge CLI arguments to config."""
|
||||
if args.no_validate:
|
||||
cfg.val_cfg = None
|
||||
cfg.val_dataloader = None
|
||||
cfg.val_evaluator = None
|
||||
|
||||
|
||||
def train_single_fold(args, cfg, fold, distributed, seed):
|
||||
# create the work_dir for the fold
|
||||
work_dir = osp.join(cfg.work_dir, f'fold{fold}')
|
||||
cfg.work_dir = work_dir
|
||||
|
||||
# create work_dir
|
||||
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
|
||||
|
||||
# wrap the dataset cfg
|
||||
train_dataset = dict(
|
||||
type='KFoldDataset',
|
||||
fold=fold,
|
||||
dataset=cfg.data.train,
|
||||
num_splits=args.num_splits,
|
||||
seed=seed,
|
||||
)
|
||||
val_dataset = dict(
|
||||
type='KFoldDataset',
|
||||
fold=fold,
|
||||
# Use the same dataset with training.
|
||||
dataset=copy.deepcopy(cfg.data.train),
|
||||
num_splits=args.num_splits,
|
||||
seed=seed,
|
||||
test_mode=True,
|
||||
)
|
||||
val_dataset['dataset']['pipeline'] = cfg.data.val.pipeline
|
||||
cfg.data.train = train_dataset
|
||||
cfg.data.val = val_dataset
|
||||
cfg.data.test = val_dataset
|
||||
|
||||
# dump config
|
||||
stem, suffix = osp.basename(args.config).rsplit('.', 1)
|
||||
cfg.dump(osp.join(cfg.work_dir, f'{stem}_fold{fold}.{suffix}'))
|
||||
# init the logger before other steps
|
||||
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
|
||||
log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
|
||||
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
|
||||
|
||||
# init the meta dict to record some important information such as
|
||||
# environment info and seed, which will be logged
|
||||
meta = dict()
|
||||
# log env info
|
||||
env_info_dict = collect_env()
|
||||
env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
|
||||
dash_line = '-' * 60 + '\n'
|
||||
logger.info('Environment info:\n' + dash_line + env_info + '\n' +
|
||||
dash_line)
|
||||
meta['env_info'] = env_info
|
||||
|
||||
# log some basic info
|
||||
logger.info(f'Distributed training: {distributed}')
|
||||
logger.info(f'Config:\n{cfg.pretty_text}')
|
||||
logger.info(
|
||||
f'-------- Cross-validation: [{fold+1}/{args.num_splits}] -------- ')
|
||||
|
||||
# set random seeds
|
||||
# Use different seed in different folds
|
||||
logger.info(f'Set random seed to {seed + fold}, '
|
||||
f'deterministic: {args.deterministic}')
|
||||
set_random_seed(seed + fold, deterministic=args.deterministic)
|
||||
cfg.seed = seed + fold
|
||||
meta['seed'] = seed + fold
|
||||
|
||||
model = build_classifier(cfg.model)
|
||||
model.init_weights()
|
||||
|
||||
datasets = [build_dataset(cfg.data.train)]
|
||||
if len(cfg.workflow) == 2:
|
||||
val_dataset = copy.deepcopy(cfg.data.val)
|
||||
val_dataset.pipeline = cfg.data.train.pipeline
|
||||
datasets.append(build_dataset(val_dataset))
|
||||
meta.update(
|
||||
dict(
|
||||
mmcls_version=__version__,
|
||||
config=cfg.pretty_text,
|
||||
CLASSES=datasets[0].CLASSES,
|
||||
kfold=dict(fold=fold, num_splits=args.num_splits)))
|
||||
# add an attribute for visualization convenience
|
||||
train_model(
|
||||
model,
|
||||
datasets,
|
||||
cfg,
|
||||
distributed=distributed,
|
||||
validate=(not args.no_validate),
|
||||
timestamp=timestamp,
|
||||
device='cpu' if args.device == 'cpu' else 'cuda',
|
||||
meta=meta)
|
||||
|
||||
|
||||
def summary(args, cfg):
|
||||
summary = dict()
|
||||
for fold in range(args.num_splits):
|
||||
work_dir = Path(cfg.work_dir) / f'fold{fold}'
|
||||
|
||||
# Find the latest training log
|
||||
log_files = list(work_dir.glob('*.log.json'))
|
||||
if len(log_files) == 0:
|
||||
continue
|
||||
log_file = sorted(log_files)[-1]
|
||||
|
||||
date = datetime.fromtimestamp(log_file.lstat().st_mtime)
|
||||
summary[fold] = {'date': date.strftime('%Y-%m-%d %H:%M:%S')}
|
||||
|
||||
# Find the latest eval log
|
||||
json_log = load_json_log(log_file)
|
||||
epochs = sorted(list(json_log.keys()))
|
||||
eval_log = {}
|
||||
|
||||
def is_metric_key(key):
|
||||
for metric in TEST_METRICS:
|
||||
if metric in key:
|
||||
return True
|
||||
return False
|
||||
|
||||
for epoch in epochs[::-1]:
|
||||
if any(is_metric_key(k) for k in json_log[epoch].keys()):
|
||||
eval_log = json_log[epoch]
|
||||
break
|
||||
|
||||
summary[fold]['epoch'] = epoch
|
||||
summary[fold]['metric'] = {
|
||||
k: v[0] # the value is a list with only one item.
|
||||
for k, v in eval_log.items() if is_metric_key(k)
|
||||
}
|
||||
show_summary(args, summary)
|
||||
|
||||
|
||||
def show_summary(args, summary_data):
|
||||
try:
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
except ImportError:
|
||||
raise ImportError('Please run `pip install rich` to install '
|
||||
'package `rich` to draw the table.')
|
||||
|
||||
console = Console()
|
||||
table = Table(title=f'{args.num_splits}-fold Cross-validation Summary')
|
||||
table.add_column('Fold')
|
||||
metrics = summary_data[0]['metric'].keys()
|
||||
for metric in metrics:
|
||||
table.add_column(metric)
|
||||
table.add_column('Epoch')
|
||||
table.add_column('Date')
|
||||
|
||||
for fold in range(args.num_splits):
|
||||
row = [f'{fold+1}']
|
||||
if fold not in summary_data:
|
||||
table.add_row(*row)
|
||||
continue
|
||||
for metric in metrics:
|
||||
metric_value = summary_data[fold]['metric'].get(metric, '')
|
||||
|
||||
def format_value(value):
|
||||
if isinstance(value, float):
|
||||
return f'{value:.2f}'
|
||||
if isinstance(value, (list, tuple)):
|
||||
return str([format_value(i) for i in value])
|
||||
else:
|
||||
return str(value)
|
||||
|
||||
row.append(format_value(metric_value))
|
||||
row.append(str(summary_data[fold]['epoch']))
|
||||
row.append(summary_data[fold]['date'])
|
||||
table.add_row(*row)
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
cfg = Config.fromfile(args.config)
|
||||
if args.cfg_options is not None:
|
||||
cfg.merge_from_dict(args.cfg_options)
|
||||
# set cudnn_benchmark
|
||||
if cfg.get('cudnn_benchmark', False):
|
||||
torch.backends.cudnn.benchmark = True
|
||||
cfg.launcher = args.launcher
|
||||
|
||||
# work_dir is determined in this priority: CLI > segment in file > filename
|
||||
if args.work_dir is not None:
|
||||
|
@ -304,53 +99,119 @@ def main():
|
|||
cfg.work_dir = osp.join('./work_dirs',
|
||||
osp.splitext(osp.basename(args.config))[0])
|
||||
|
||||
if args.summary:
|
||||
summary(args, cfg)
|
||||
return
|
||||
# enable automatic-mixed-precision training
|
||||
if args.amp is True:
|
||||
optim_wrapper = cfg.optim_wrapper.get('type', 'OptimWrapper')
|
||||
assert optim_wrapper in ['OptimWrapper', 'AmpOptimWrapper'], \
|
||||
'`--amp` is not supported custom optimizer wrapper type ' \
|
||||
f'`{optim_wrapper}.'
|
||||
cfg.optim_wrapper.type = 'AmpOptimWrapper'
|
||||
cfg.optim_wrapper.setdefault('loss_scale', 'dynamic')
|
||||
|
||||
# enable auto scale learning rate
|
||||
if args.auto_scale_lr:
|
||||
cfg.auto_scale_lr.enable = True
|
||||
|
||||
if args.cfg_options is not None:
|
||||
cfg.merge_from_dict(args.cfg_options)
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
def train_single_fold(cfg, num_splits, fold, resume_ckpt=None):
|
||||
root_dir = cfg.work_dir
|
||||
cfg.work_dir = osp.join(root_dir, f'fold{fold}')
|
||||
if resume_ckpt is not None:
|
||||
cfg.resume = True
|
||||
cfg.load_from = resume_ckpt
|
||||
dataset = cfg.train_dataloader.dataset
|
||||
|
||||
# wrap the dataset cfg
|
||||
def wrap_dataset(dataset, test_mode):
|
||||
return dict(
|
||||
type='KFoldDataset',
|
||||
dataset=dataset,
|
||||
fold=fold,
|
||||
num_splits=num_splits,
|
||||
seed=cfg.seed,
|
||||
test_mode=test_mode,
|
||||
)
|
||||
|
||||
train_dataset = copy.deepcopy(dataset)
|
||||
cfg.train_dataloader.dataset = wrap_dataset(train_dataset, False)
|
||||
|
||||
if cfg.val_dataloader is not None:
|
||||
if 'pipeline' not in cfg.val_dataloader.dataset:
|
||||
raise ValueError(
|
||||
'Cannot find `pipeline` in the validation dataset. '
|
||||
"If you are using dataset wrapper, please don't use this "
|
||||
'tool to act kfold cross validation. '
|
||||
'Please write config files manually.')
|
||||
val_dataset = copy.deepcopy(dataset)
|
||||
val_dataset['pipeline'] = cfg.val_dataloader.dataset.pipeline
|
||||
cfg.val_dataloader.dataset = wrap_dataset(val_dataset, True)
|
||||
if cfg.test_dataloader is not None:
|
||||
if 'pipeline' not in cfg.test_dataloader.dataset:
|
||||
raise ValueError(
|
||||
'Cannot find `pipeline` in the test dataset. '
|
||||
"If you are using dataset wrapper, please don't use this "
|
||||
'tool to act kfold cross validation. '
|
||||
'Please write config files manually.')
|
||||
test_dataset = copy.deepcopy(dataset)
|
||||
test_dataset['pipeline'] = cfg.test_dataloader.dataset.pipeline
|
||||
cfg.test_dataloader.dataset = wrap_dataset(test_dataset, True)
|
||||
|
||||
# build the runner from config
|
||||
runner = Runner.from_cfg(cfg)
|
||||
runner.logger.info(
|
||||
f'----------- Cross-validation: [{fold+1}/{num_splits}] ----------- ')
|
||||
runner.logger.info(f'Train dataset: \n{runner.train_dataloader.dataset}')
|
||||
|
||||
class SaveInfoHook(Hook):
|
||||
|
||||
def after_train_epoch(self, runner):
|
||||
try:
|
||||
last_ckpt = find_latest_checkpoint(cfg.work_dir)
|
||||
exp_info = dict(
|
||||
fold=fold, last_ckpt=last_ckpt, seed=runner.seed)
|
||||
dump(exp_info, osp.join(root_dir, EXP_INFO_FILE))
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
runner.register_hook(SaveInfoHook(), 'LOWEST')
|
||||
|
||||
# start training
|
||||
runner.train()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# register all modules in mmcls into the registries
|
||||
# do not init the default scope here because it will be init in the runner
|
||||
register_all_modules(init_default_scope=False)
|
||||
|
||||
# load config
|
||||
cfg = Config.fromfile(args.config)
|
||||
|
||||
# merge cli arguments to config
|
||||
cfg = merge_args(cfg, args)
|
||||
|
||||
# set preprocess configs to model
|
||||
cfg.model.setdefault('data_preprocessor', cfg.get('preprocess_cfg', {}))
|
||||
|
||||
# set the unify random seed
|
||||
cfg.seed = args.seed or sync_random_seed()
|
||||
|
||||
# resume from the previous experiment
|
||||
if args.resume_from is not None:
|
||||
cfg.resume_from = args.resume_from
|
||||
resume_kfold = torch.load(cfg.resume_from).get('meta',
|
||||
{}).get('kfold', None)
|
||||
if resume_kfold is None:
|
||||
raise RuntimeError(
|
||||
'No "meta" key in checkpoints or no "kfold" in the meta dict. '
|
||||
'Please check if the resume checkpoint from a k-fold '
|
||||
'cross-valid experiment.')
|
||||
resume_fold = resume_kfold['fold']
|
||||
assert args.num_splits == resume_kfold['num_splits']
|
||||
if args.resume:
|
||||
experiment_info = load(osp.join(cfg.work_dir, EXP_INFO_FILE))
|
||||
resume_fold = experiment_info['fold']
|
||||
cfg.seed = experiment_info['seed']
|
||||
resume_ckpt = experiment_info.get('last_ckpt', None)
|
||||
else:
|
||||
resume_fold = 0
|
||||
|
||||
if args.gpus is not None:
|
||||
cfg.gpu_ids = range(1)
|
||||
warnings.warn('`--gpus` is deprecated because we only support '
|
||||
'single GPU mode in non-distributed training. '
|
||||
'Use `gpus=1` now.')
|
||||
if args.gpu_ids is not None:
|
||||
cfg.gpu_ids = args.gpu_ids[0:1]
|
||||
warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
|
||||
'Because we only support single GPU mode in '
|
||||
'non-distributed training. Use the first GPU '
|
||||
'in `gpu_ids` now.')
|
||||
if args.gpus is None and args.gpu_ids is None:
|
||||
cfg.gpu_ids = [args.gpu_id]
|
||||
|
||||
# init distributed env first, since logger depends on the dist info.
|
||||
if args.launcher == 'none':
|
||||
distributed = False
|
||||
else:
|
||||
distributed = True
|
||||
init_dist(args.launcher, **cfg.dist_params)
|
||||
_, world_size = get_dist_info()
|
||||
cfg.gpu_ids = range(world_size)
|
||||
|
||||
# init a unified random seed
|
||||
seed = init_random_seed(args.seed)
|
||||
|
||||
# create work_dir
|
||||
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
|
||||
resume_ckpt = None
|
||||
|
||||
if args.fold is not None:
|
||||
folds = [args.fold]
|
||||
|
@ -358,13 +219,9 @@ def main():
|
|||
folds = range(resume_fold, args.num_splits)
|
||||
|
||||
for fold in folds:
|
||||
cfg_ = copy_config(cfg)
|
||||
if fold != resume_fold:
|
||||
cfg_.resume_from = None
|
||||
train_single_fold(args, cfg_, fold, distributed, seed)
|
||||
|
||||
if args.fold is None:
|
||||
summary(args, cfg)
|
||||
cfg_ = copy.deepcopy(cfg)
|
||||
train_single_fold(cfg_, args.num_splits, fold, resume_ckpt)
|
||||
resume_ckpt = None
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue