[Docs] Refine add-datasets and add-transform docs (#436)
* update add datasets doc * refactor add transform * refine * add `step` to subtitle * fix typo * update linkpull/444/head
parent
09d4e0d1b9
commit
3c2fe162b0
|
@ -1,117 +1,83 @@
|
|||
# Add Datasets
|
||||
|
||||
In this tutorial, we introduce the basic steps to create your customized dataset:
|
||||
In this tutorial, we introduce the basic steps to create your customized dataset. Before learning to create your customized datasets, it is recommended to learn the basic concept of datasets in file [datasets.md](datasets.md).
|
||||
|
||||
- [Add Datasets](#add-datasets)
|
||||
- [An example of customized dataset](#an-example-of-customized-dataset)
|
||||
- [Creating the `DataSource`](#creating-the-datasource)
|
||||
- [Creating the `Dataset`](#creating-the-dataset)
|
||||
- [Modify config file](#modify-config-file)
|
||||
- [Step 1: Creating the Dataset](#step-1-creating-the-dataset)
|
||||
- [Step 2: Add NewDataset to \_\_init\_\_py](#step-2-add-newdataset-to-__init__py)
|
||||
- [Step 3: Modify the config file](#step-3-modify-the-config-file)
|
||||
|
||||
If your algorithm does not need any customized dataset, you can use these off-the-shelf datasets under [datasets](../../mmselfsup/datasets). But to use these existing datasets, you have to convert your dataset to existing dataset format.
|
||||
If your algorithm does not need any customized dataset, you can use these off-the-shelf datasets under [datasets directory](mmselfsup.datasets). But to use these existing datasets, you have to convert your dataset to existing dataset format.
|
||||
|
||||
## An example of customized dataset
|
||||
As for image pretraining, it is recommended to follow the format of MMClassification.
|
||||
|
||||
Assuming the format of your dataset's annotation file is:
|
||||
## Step 1: Creating the Dataset
|
||||
|
||||
```text
|
||||
000001.jpg 0
|
||||
000002.jpg 1
|
||||
```
|
||||
|
||||
To write a new dataset, you need to implement:
|
||||
|
||||
- `DataSource`: inherited from `BaseDataSource` and responsible for loading the annotation files and reading images.
|
||||
- `Dataset`: inherited from `BaseDataset` and responsible for applying transformation to images and packing these images.
|
||||
|
||||
## Creating the `DataSource`
|
||||
|
||||
Assume the name of your `DataSource` is `NewDataSource`, you can create a file, named `new_data_source.py` under `mmselfsup/datasets/data_sources` and implement `NewDataSource` in it.
|
||||
|
||||
```python
|
||||
import mmcv
|
||||
import numpy as np
|
||||
|
||||
from ..builder import DATASOURCES
|
||||
from .base import BaseDataSource
|
||||
|
||||
|
||||
@DATASOURCES.register_module()
|
||||
class NewDataSource(BaseDataSource):
|
||||
|
||||
def load_annotations(self):
|
||||
|
||||
assert isinstance(self.ann_file, str)
|
||||
data_infos = []
|
||||
# writing your code here.
|
||||
return data_infos
|
||||
```
|
||||
|
||||
Then, add `NewDataSource` in `mmselfsup/dataset/data_sources/__init__.py`.
|
||||
|
||||
```python
|
||||
from .base import BaseDataSource
|
||||
...
|
||||
from .new_data_source import NewDataSource
|
||||
|
||||
__all__ = [
|
||||
'BaseDataSource', ..., 'NewDataSource'
|
||||
]
|
||||
```
|
||||
|
||||
## Creating the `Dataset`
|
||||
You could implement a new dataset class, inherited from `CustomDataset` from MMClassification for image pretraining.
|
||||
|
||||
Assume the name of your `Dataset` is `NewDataset`, you can create a file, named `new_dataset.py` under `mmselfsup/datasets` and implement `NewDataset` in it.
|
||||
|
||||
```python
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from mmcv.utils import build_from_cfg
|
||||
from torchvision.transforms import Compose
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from .base import BaseDataset
|
||||
from .builder import DATASETS, PIPELINES, build_datasource
|
||||
from .utils import to_numpy
|
||||
from mmcls.datasets import CustomDataset
|
||||
|
||||
from mmselfsup.registry import DATASETS
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class NewDataset(BaseDataset):
|
||||
class NewDataset(CustomDataset):
|
||||
|
||||
def __init__(self, data_source, num_views, pipelines, prefetch=False):
|
||||
# writing your code here
|
||||
def __getitem__(self, idx):
|
||||
# writing your code here
|
||||
return dict(img=img)
|
||||
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')
|
||||
|
||||
def __init__(self,
|
||||
ann_file: str = '',
|
||||
metainfo: Optional[dict] = None,
|
||||
data_root: str = '',
|
||||
data_prefix: Union[str, dict] = '',
|
||||
**kwargs) -> None:
|
||||
kwargs = {'extensions': self.IMG_EXTENSIONS, **kwargs}
|
||||
super().__init__(
|
||||
ann_file=ann_file,
|
||||
metainfo=metainfo,
|
||||
data_root=data_root,
|
||||
data_prefix=data_prefix,
|
||||
**kwargs)
|
||||
|
||||
def load_data_list(self) -> List[dict]:
|
||||
# Rewrite load_data_list() to satisfy your specific requirement.
|
||||
# The returned data_list could include any information you need from
|
||||
# data or transforms.
|
||||
|
||||
# writing your code here
|
||||
return data_list
|
||||
|
||||
def evaluate(self, results, logger=None):
|
||||
return NotImplemented
|
||||
```
|
||||
|
||||
Then, add `NewDataset` in `mmselfsup/dataset/__init__.py`.
|
||||
## Step 2: Add NewDataset to \_\_init\_\_py
|
||||
|
||||
Then, add `NewDataset` in `mmselfsup/dataset/__init__.py`. If it is not imported, the `NewDataset` will not be registered successfully.
|
||||
|
||||
```python
|
||||
from .base import BaseDataset
|
||||
...
|
||||
from .new_dataset import NewDataset
|
||||
|
||||
__all__ = [
|
||||
'BaseDataset', ..., 'NewDataset'
|
||||
..., 'NewDataset'
|
||||
]
|
||||
```
|
||||
|
||||
## Modify config file
|
||||
## Step 3: Modify the config file
|
||||
|
||||
To use `NewDataset`, you can modify the config as the following:
|
||||
|
||||
```python
|
||||
train=dict(
|
||||
train_dataloader = dict(
|
||||
...
|
||||
dataset=dict(
|
||||
type='NewDataset',
|
||||
data_source=dict(
|
||||
type='NewDataSource',
|
||||
),
|
||||
num_views=[2],
|
||||
pipelines=[train_pipeline],
|
||||
prefetch=prefetch,
|
||||
))
|
||||
|
||||
data_root=your_data_root,
|
||||
ann_file=your_data_root,
|
||||
data_prefix=dict(img_path='train/'),
|
||||
pipeline=train_pipeline))
|
||||
```
|
||||
|
|
|
@ -1,21 +1,26 @@
|
|||
# Add Transforms
|
||||
|
||||
In this tutorial, we introduce the basic steps to create your customized transforms. Before learning to create your customized transforms, it is recommended to learn the basic concept of transforms in file [transforms.md](transforms.md).
|
||||
|
||||
- [Add Transforms](#add-transforms)
|
||||
- [Overview of `Pipeline`](#overview-of-pipeline)
|
||||
- [Creating new augmentations in `Pipeline`](#creating-new-augmentations-in-pipeline)
|
||||
- [Overview of Pipeline](#overview-of-pipeline)
|
||||
- [Creating a new transform in Pipeline](#creating-a-new-transform-in-pipeline)
|
||||
- [Step 1: Creating the transform](#step-1-creating-the-transform)
|
||||
- [Step 2: Add NewTransform to \_\_init\_\_py](#step-2-add-newtransform-to-__init__py)
|
||||
- [Step 3: Modify the config file](#step-3-modify-the-config-file)
|
||||
|
||||
## Overview of `Pipeline`
|
||||
## Overview of Pipeline
|
||||
|
||||
`DataSource` and `Pipeline` are two important components in `Dataset`. We have introduced `DataSource` in [add_new_dataset](./1_new_dataset.md). And the `Pipeline` is responsible for applying a series of data augmentations to images, such as random flip.
|
||||
`Pipeline` is an important component in `Dataset`, which is responsible for applying a series of data augmentations to images, such as `RandomResizedCrop`, `RandomFlip`, etc.
|
||||
|
||||
Here is a config example of `Pipeline` for `SimCLR` training:
|
||||
|
||||
```python
|
||||
train_pipeline = [
|
||||
dict(type='RandomResizedCrop', size=224),
|
||||
dict(type='RandomHorizontalFlip'),
|
||||
view_pipeline = [
|
||||
dict(type='RandomResizedCrop', size=224, backend='pillow'),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(
|
||||
type='RandomAppliedTrans',
|
||||
type='RandomApply',
|
||||
transforms=[
|
||||
dict(
|
||||
type='ColorJitter',
|
||||
|
@ -24,36 +29,70 @@ train_pipeline = [
|
|||
saturation=0.8,
|
||||
hue=0.2)
|
||||
],
|
||||
p=0.8),
|
||||
dict(type='RandomGrayscale', p=0.2),
|
||||
dict(type='GaussianBlur', sigma_min=0.1, sigma_max=2.0, p=0.5)
|
||||
prob=0.8),
|
||||
dict(
|
||||
type='RandomGrayscale',
|
||||
prob=0.2,
|
||||
keep_channels=True,
|
||||
channel_weights=(0.114, 0.587, 0.2989)),
|
||||
dict(type='RandomGaussianBlur', sigma_min=0.1, sigma_max=2.0, prob=0.5),
|
||||
]
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile', file_client_args=file_client_args),
|
||||
dict(type='MultiView', num_views=2, transforms=[view_pipeline]),
|
||||
dict(type='PackSelfSupInputs', meta_keys=['img_path'])
|
||||
]
|
||||
```
|
||||
|
||||
Every augmentation in the `Pipeline` receives an image as input and outputs an augmented image.
|
||||
Every augmentation in the `Pipeline` receives a `dict` as input and outputs a `dict` containing the augmented image and other related information.
|
||||
|
||||
## Creating new augmentations in `Pipeline`
|
||||
## Creating a new transform in Pipeline
|
||||
|
||||
1.Write a new transformation function in [transforms.py](../../mmselfsup/datasets/pipelines/transforms.py) and overwrite the `__call__` function, which takes a `Pillow` image as input:
|
||||
Here are the steps to create a new transform.
|
||||
|
||||
### Step 1: Creating the transform
|
||||
|
||||
Write a new transform in [processing.py](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/mmselfsup/datasets/transforms/processing.py) and overwrite the `transform` function, which takes a `dict` as input:
|
||||
|
||||
```python
|
||||
@PIPELINES.register_module()
|
||||
class MyTransform(object):
|
||||
@TRANSFORMS.register_module()
|
||||
class NewTransform(BaseTransform):
|
||||
"""Docstring for transform.
|
||||
"""
|
||||
|
||||
def __call__(self, img):
|
||||
# apply transforms on img
|
||||
return img
|
||||
def transform(self, results: dict) -> dict:
|
||||
# apply transform
|
||||
return results
|
||||
```
|
||||
|
||||
2.Use it in config files. We reuse the config file shown above and add `MyTransform` to it.
|
||||
**Note:** For the implementation of transforms, you could apply functions in [mmcv](https://github.com/open-mmlab/mmcv/tree/dev-2.x/mmcv/image).
|
||||
|
||||
### Step 2: Add NewTransform to \_\_init\_\_py
|
||||
|
||||
Then, add the transform to [\_\_init\_\_.py](https://github.com/open-mmlab/mmselfsup/blob/1.x/mmselfsup/datasets/transforms/__init__.py).
|
||||
|
||||
```python
|
||||
train_pipeline = [
|
||||
dict(type='RandomResizedCrop', size=224),
|
||||
dict(type='RandomHorizontalFlip'),
|
||||
dict(type='MyTransform'),
|
||||
...
|
||||
from .processing import NewTransform, ...
|
||||
|
||||
__all__ = [
|
||||
..., 'NewTransform'
|
||||
]
|
||||
```
|
||||
|
||||
### Step 3: Modify the config file
|
||||
|
||||
To use `NewTransform`, you can modify the config as the following:
|
||||
|
||||
```python
|
||||
view_pipeline = [
|
||||
dict(type='RandomResizedCrop', size=224, backend='pillow'),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
# add `NewTransform`
|
||||
dict(type='NewTransform'),
|
||||
dict(
|
||||
type='RandomAppliedTrans',
|
||||
type='RandomApply',
|
||||
transforms=[
|
||||
dict(
|
||||
type='ColorJitter',
|
||||
|
@ -62,8 +101,18 @@ train_pipeline = [
|
|||
saturation=0.8,
|
||||
hue=0.2)
|
||||
],
|
||||
p=0.8),
|
||||
dict(type='RandomGrayscale', p=0.2),
|
||||
dict(type='GaussianBlur', sigma_min=0.1, sigma_max=2.0, p=0.5)
|
||||
prob=0.8),
|
||||
dict(
|
||||
type='RandomGrayscale',
|
||||
prob=0.2,
|
||||
keep_channels=True,
|
||||
channel_weights=(0.114, 0.587, 0.2989)),
|
||||
dict(type='RandomGaussianBlur', sigma_min=0.1, sigma_max=2.0, prob=0.5),
|
||||
]
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile', file_client_args=file_client_args),
|
||||
dict(type='MultiView', num_views=2, transforms=[view_pipeline]),
|
||||
dict(type='PackSelfSupInputs', meta_keys=['img_path'])
|
||||
]
|
||||
```
|
||||
|
|
Loading…
Reference in New Issue