[Docs] Refine add-datasets and add-transform docs (#436)

* update add datasets doc

* refactor add transform

* refine

* add `step` to subtitle

* fix typo

* update link
Yixiao Fang 2022-08-31 19:20:49 +08:00 committed by GitHub
parent 09d4e0d1b9
commit 3c2fe162b0
2 changed files with 125 additions and 110 deletions

# Add Datasets
In this tutorial, we introduce the basic steps to create your customized dataset. Before creating your customized dataset, it is recommended to learn the basic concepts of datasets in [datasets.md](datasets.md).

- [Add Datasets](#add-datasets)
  - [Step 1: Creating the Dataset](#step-1-creating-the-dataset)
  - [Step 2: Add NewDataset to \_\_init\_\_.py](#step-2-add-newdataset-to-__init__py)
  - [Step 3: Modify the config file](#step-3-modify-the-config-file)

If your algorithm does not need any customized dataset, you can use these off-the-shelf datasets under the [datasets directory](mmselfsup.datasets). But to use these existing datasets, you have to convert your dataset to one of the existing dataset formats.
As for image pretraining, it is recommended to follow the format of MMClassification. Assuming the format of your dataset's annotation file is:

```text
000001.jpg 0
000002.jpg 1
```

## Step 1: Creating the Dataset
You could implement a new dataset class, inherited from `CustomDataset` of MMClassification, for image pretraining.

Assuming the name of your `Dataset` is `NewDataset`, you can create a file named `new_dataset.py` under `mmselfsup/datasets` and implement `NewDataset` in it.
```python
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union

from mmcls.datasets import CustomDataset

from mmselfsup.registry import DATASETS


@DATASETS.register_module()
class NewDataset(CustomDataset):

    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')

    def __init__(self,
                 ann_file: str = '',
                 metainfo: Optional[dict] = None,
                 data_root: str = '',
                 data_prefix: Union[str, dict] = '',
                 **kwargs) -> None:
        kwargs = {'extensions': self.IMG_EXTENSIONS, **kwargs}
        super().__init__(
            ann_file=ann_file,
            metainfo=metainfo,
            data_root=data_root,
            data_prefix=data_prefix,
            **kwargs)

    def load_data_list(self) -> List[dict]:
        # Rewrite load_data_list() to satisfy your specific requirement.
        # The returned data_list could include any information you need from
        # data or transforms.

        # write your code here
        return data_list
```
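For reference, here is a rough, illustrative sketch of the kind of parsing `load_data_list()` might do, assuming the `filename label` annotation format shown above. The helper name and its arguments are made up for this example and are not part of MMSelfSup; inside `NewDataset.load_data_list` you would typically use `self.ann_file` and `self.data_prefix['img_path']` instead.

```python
import os.path as osp
from typing import List


def load_data_list_from_ann_file(ann_file: str, img_prefix: str) -> List[dict]:
    """Parse annotation lines like ``000001.jpg 0`` into info dicts."""
    data_list = []
    with open(ann_file) as f:
        for line in f:
            filename, gt_label = line.strip().rsplit(' ', maxsplit=1)
            # Only `img_path` is strictly required for pretraining; extra keys
            # are passed through to the transforms untouched.
            data_list.append(
                dict(
                    img_path=osp.join(img_prefix, filename),
                    gt_label=int(gt_label)))
    return data_list
```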
## Step 2: Add NewDataset to \_\_init\_\_.py

Then, add `NewDataset` to `mmselfsup/datasets/__init__.py`. If it is not imported there, `NewDataset` will not be registered successfully.
```python
...
from .new_dataset import NewDataset

__all__ = [
    ..., 'NewDataset'
]
```
## Step 3: Modify the config file

To use `NewDataset`, you can modify the config as follows:
```python
train_dataloader = dict(
    ...
    dataset=dict(
        type='NewDataset',
        data_root=your_data_root,
        ann_file=your_ann_file,
        data_prefix=dict(img_path='train/'),
        pipeline=train_pipeline))
```
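To quickly check that the new class is registered and the config is well-formed, you can build the dataset directly from a config dict. This is only a sketch: the paths below are placeholders, and the empty pipeline is used just for the check.

```python
from mmselfsup.registry import DATASETS

# Placeholder paths; replace them with your own. The empty pipeline is only
# for this quick check; real configs pass `pipeline=train_pipeline`.
dataset_cfg = dict(
    type='NewDataset',
    data_root='data/new_dataset/',
    ann_file='meta/train.txt',
    data_prefix=dict(img_path='train/'),
    pipeline=[])

# Building succeeds only if NewDataset was imported in
# `mmselfsup/datasets/__init__.py` (Step 2).
dataset = DATASETS.build(dataset_cfg)
print(len(dataset))
```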

# Add Transforms
In this tutorial, we introduce the basic steps to create your customized transforms. Before creating your customized transforms, it is recommended to learn the basic concepts of transforms in [transforms.md](transforms.md).

- [Add Transforms](#add-transforms)
  - [Overview of Pipeline](#overview-of-pipeline)
  - [Creating a new transform in Pipeline](#creating-a-new-transform-in-pipeline)
    - [Step 1: Creating the transform](#step-1-creating-the-transform)
    - [Step 2: Add NewTransform to \_\_init\_\_.py](#step-2-add-newtransform-to-__init__py)
    - [Step 3: Modify the config file](#step-3-modify-the-config-file)

## Overview of Pipeline

`Pipeline` is an important component in `Dataset`, which is responsible for applying a series of data augmentations to images, such as `RandomResizedCrop` and `RandomFlip`.

Here is a config example of `Pipeline` for `SimCLR` training:
```python
view_pipeline = [
    dict(type='RandomResizedCrop', size=224, backend='pillow'),
    dict(type='RandomFlip', prob=0.5),
    dict(
        type='RandomApply',
        transforms=[
            dict(
                type='ColorJitter',
                brightness=0.8,
                contrast=0.8,
                saturation=0.8,
                hue=0.2)
        ],
        prob=0.8),
    dict(
        type='RandomGrayscale',
        prob=0.2,
        keep_channels=True,
        channel_weights=(0.114, 0.587, 0.2989)),
    dict(type='RandomGaussianBlur', sigma_min=0.1, sigma_max=2.0, prob=0.5),
]

train_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(type='MultiView', num_views=2, transforms=[view_pipeline]),
    dict(type='PackSelfSupInputs', meta_keys=['img_path'])
]
```
Every augmentation in the `Pipeline` receives a `dict` as input and outputs a `dict` containing the augmented image and other related information.
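As a small illustration (not from the original docs), you can call a single transform on such a `dict` directly. The sketch below uses `RandomFlip` from mmcv and assumes the image is stored under the conventional `img` key.

```python
import numpy as np
from mmcv.transforms import RandomFlip

# Transforms are callables: dict in, dict out. Extra keys such as `flip`
# and `flip_direction` are added by the transform itself.
flip = RandomFlip(prob=1.0)
results = dict(img=np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
results = flip(results)
print(results['img'].shape, results['flip_direction'])
```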
## Creating a new transform in Pipeline

Here are the steps to create a new transform.

### Step 1: Creating the transform

Write a new transform in [processing.py](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/mmselfsup/datasets/transforms/processing.py) and override the `transform` function, which takes a `dict` as input and returns a `dict`:
```python
@TRANSFORMS.register_module()
class NewTransform(BaseTransform):
    """Docstring for transform."""

    def transform(self, results: dict) -> dict:
        # apply transform
        return results
```
**Note:** For the implementation of transforms, you could apply functions in [mmcv](https://github.com/open-mmlab/mmcv/tree/dev-2.x/mmcv/image).
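For a concrete, purely illustrative example of such a transform, the sketch below wraps `mmcv.solarize`. The class name `SolarizeExample` and the `thr` parameter are invented for this example and are not part of MMSelfSup; the image is assumed to be a numpy array stored under the `img` key.

```python
import mmcv
from mmcv.transforms import BaseTransform

from mmselfsup.registry import TRANSFORMS


@TRANSFORMS.register_module()
class SolarizeExample(BaseTransform):
    """Invert pixel values above a threshold (illustrative only)."""

    def __init__(self, thr: int = 128) -> None:
        self.thr = thr

    def transform(self, results: dict) -> dict:
        # `mmcv.solarize` operates on the numpy image stored under 'img'.
        results['img'] = mmcv.solarize(results['img'], thr=self.thr)
        return results
```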
### Step 2: Add NewTransform to \_\_init\_\_.py
Then, add the transform to [\_\_init\_\_.py](https://github.com/open-mmlab/mmselfsup/blob/1.x/mmselfsup/datasets/transforms/__init__.py).
```python
...
from .processing import NewTransform, ...

__all__ = [
    ..., 'NewTransform'
]
```
### Step 3: Modify the config file
To use `NewTransform`, you can modify the config as follows:
```python
view_pipeline = [
    dict(type='RandomResizedCrop', size=224, backend='pillow'),
    dict(type='RandomFlip', prob=0.5),
    # add `NewTransform`
    dict(type='NewTransform'),
    dict(
        type='RandomApply',
        transforms=[
            dict(
                type='ColorJitter',
                brightness=0.8,
                contrast=0.8,
                saturation=0.8,
                hue=0.2)
        ],
        prob=0.8),
    dict(
        type='RandomGrayscale',
        prob=0.2,
        keep_channels=True,
        channel_weights=(0.114, 0.587, 0.2989)),
    dict(type='RandomGaussianBlur', sigma_min=0.1, sigma_max=2.0, prob=0.5),
]

train_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(type='MultiView', num_views=2, transforms=[view_pipeline]),
    dict(type='PackSelfSupInputs', meta_keys=['img_path'])
]
```
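If you want to sanity-check a pipeline outside of training, you can compose a few transforms and run them on a single `dict`. This is a rough sketch using `Compose` from `mmcv.transforms` with a trimmed pipeline and a placeholder image path, not the full config above.

```python
from mmcv.transforms import Compose

# Trimmed pipeline for a local check; the full view_pipeline / train_pipeline
# from the config above can be plugged in the same way once every transform
# in it is registered.
pipeline = Compose([
    dict(type='LoadImageFromFile'),
    dict(type='RandomFlip', prob=0.5),
])

results = pipeline(dict(img_path='path/to/your/image.jpg'))
print(results['img'].shape)
```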