mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
Add a segmentation example (#1282)
This commit is contained in:
parent
d772ad0962
commit
5c5ec8b168
288
docs/en/examples/train_seg.md
Normal file
288
docs/en/examples/train_seg.md
Normal file
@ -0,0 +1,288 @@
|
||||
# Train a Segmentation Model
|
||||
|
||||
This segmentation task example will be divided into the following steps:
|
||||
|
||||
- [Download Camvid Dataset](#download-camvid-dataset)
|
||||
- [Implement Camvid Dataset](#implement-the-camvid-dataset)
|
||||
- [Implement a Segmentation Model](#implement-the-segmentation-model)
|
||||
- [Train with Runner](#training-with-runner)
|
||||
|
||||
```{note}
|
||||
You can also experience the notebook example [here](https://colab.research.google.com/github/open-mmlab/mmengine/blob/main/examples/segmentation/train.ipynb).
|
||||
```
|
||||
|
||||
## Download Camvid Dataset
|
||||
|
||||
First, you should download the Camvid dataset from OpenDataLab:
|
||||
|
||||
```bash
|
||||
# https://opendatalab.com/CamVid
|
||||
# Configure install
|
||||
pip install opendatalab
|
||||
# Upgraded version
|
||||
pip install -U opendatalab
|
||||
# Login
|
||||
odl login
|
||||
# Download this dataset
|
||||
mkdir data
|
||||
odl get CamVid -d data
|
||||
# Preprocess data in Linux. You should extract the files to data manually in
|
||||
# Windows
|
||||
tar -xzvf data/CamVid/raw/CamVid.tar.gz.00 -C ./data
|
||||
```
|
||||
|
||||
## Implement the Camvid Dataset
|
||||
|
||||
We have implemented the CamVid class here, which inherits from VisionDataset. Within this class, we have overridden the `__getitem__` and `__len__` methods to ensure that each index returns a dict of images and labels. Additionally, we have implemented the color_to_class dictionary to map the mask's color to the class index.
|
||||
|
||||
```python
|
||||
import os
|
||||
import numpy as np
|
||||
from torchvision.datasets import VisionDataset
|
||||
from PIL import Image
|
||||
import csv
|
||||
|
||||
|
||||
def create_palette(csv_filepath):
|
||||
color_to_class = {}
|
||||
with open(csv_filepath, newline='') as csvfile:
|
||||
reader = csv.DictReader(csvfile)
|
||||
for idx, row in enumerate(reader):
|
||||
r, g, b = int(row['r']), int(row['g']), int(row['b'])
|
||||
color_to_class[(r, g, b)] = idx
|
||||
return color_to_class
|
||||
|
||||
class CamVid(VisionDataset):
|
||||
|
||||
def __init__(self,
|
||||
root,
|
||||
img_folder,
|
||||
mask_folder,
|
||||
transform=None,
|
||||
target_transform=None):
|
||||
super().__init__(
|
||||
root, transform=transform, target_transform=target_transform)
|
||||
self.img_folder = img_folder
|
||||
self.mask_folder = mask_folder
|
||||
self.images = list(
|
||||
sorted(os.listdir(os.path.join(self.root, img_folder))))
|
||||
self.masks = list(
|
||||
sorted(os.listdir(os.path.join(self.root, mask_folder))))
|
||||
self.color_to_class = create_palette(
|
||||
os.path.join(self.root, 'class_dict.csv'))
|
||||
|
||||
def __getitem__(self, index):
|
||||
img_path = os.path.join(self.root, self.img_folder, self.images[index])
|
||||
mask_path = os.path.join(self.root, self.mask_folder,
|
||||
self.masks[index])
|
||||
|
||||
img = Image.open(img_path).convert('RGB')
|
||||
mask = Image.open(mask_path).convert('RGB') # Convert to RGB
|
||||
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
|
||||
# Convert the RGB values to class indices
|
||||
mask = np.array(mask)
|
||||
mask = mask[:, :, 0] * 65536 + mask[:, :, 1] * 256 + mask[:, :, 2]
|
||||
labels = np.zeros_like(mask, dtype=np.int64)
|
||||
for color, class_index in self.color_to_class.items():
|
||||
rgb = color[0] * 65536 + color[1] * 256 + color[2]
|
||||
labels[mask == rgb] = class_index
|
||||
|
||||
if self.target_transform is not None:
|
||||
labels = self.target_transform(labels)
|
||||
data_samples = dict(
|
||||
labels=labels, img_path=img_path, mask_path=mask_path)
|
||||
return img, data_samples
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images)
|
||||
|
||||
```
|
||||
|
||||
We utilize the Camvid dataset to create the `train_dataloader` and `val_dataloader`, which serve as the data loaders for training and validation in the subsequent Runner.
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
transform = transforms.Compose(
|
||||
[transforms.ToTensor(),
|
||||
transforms.Normalize(**norm_cfg)])
|
||||
|
||||
target_transform = transforms.Lambda(
|
||||
lambda x: torch.tensor(np.array(x), dtype=torch.long))
|
||||
|
||||
train_set = CamVid(
|
||||
'data/CamVid',
|
||||
img_folder='train',
|
||||
mask_folder='train_labels',
|
||||
transform=transform,
|
||||
target_transform=target_transform)
|
||||
|
||||
valid_set = CamVid(
|
||||
'data/CamVid',
|
||||
img_folder='val',
|
||||
mask_folder='val_labels',
|
||||
transform=transform,
|
||||
target_transform=target_transform)
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=3,
|
||||
dataset=train_set,
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
collate_fn=dict(type='default_collate'))
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=3,
|
||||
dataset=valid_set,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
collate_fn=dict(type='default_collate'))
|
||||
```
|
||||
|
||||
## Implement the Segmentation Model
|
||||
|
||||
The provided code defines a model class named `MMDeeplabV3`. This class is derived from `BaseModel` and incorporates the segmentation model of the DeepLabV3 architecture. It overrides the `forward` method to handle both input images and labels and supports computing losses and returning predictions in both training and prediction modes.
|
||||
|
||||
For additional information about `BaseModel`, you can refer to the [Model tutorial](../tutorials/model.md).
|
||||
|
||||
```python
|
||||
from mmengine.model import BaseModel
|
||||
from torchvision.models.segmentation import deeplabv3_resnet50
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class MMDeeplabV3(BaseModel):
|
||||
|
||||
def __init__(self, num_classes):
|
||||
super().__init__()
|
||||
self.deeplab = deeplabv3_resnet50(num_classes=num_classes)
|
||||
|
||||
def forward(self, imgs, data_samples=None, mode='tensor'):
|
||||
x = self.deeplab(imgs)['out']
|
||||
if mode == 'loss':
|
||||
return {'loss': F.cross_entropy(x, data_samples['labels'])}
|
||||
elif mode == 'predict':
|
||||
return x, data_samples
|
||||
```
|
||||
|
||||
## Training with Runner
|
||||
|
||||
Before training with the Runner, we need to implement the IoU (Intersection over Union) metric to evaluate the model's performance.
|
||||
|
||||
```python
|
||||
from mmengine.evaluator import BaseMetric
|
||||
|
||||
class IoU(BaseMetric):
|
||||
|
||||
def process(self, data_batch, data_samples):
|
||||
preds, labels = data_samples[0], data_samples[1]['labels']
|
||||
preds = torch.argmax(preds, dim=1)
|
||||
intersect = (labels == preds).sum()
|
||||
union = (torch.logical_or(preds, labels)).sum()
|
||||
iou = (intersect / union).cpu()
|
||||
self.results.append(
|
||||
dict(batch_size=len(labels), iou=iou * len(labels)))
|
||||
|
||||
def compute_metrics(self, results):
|
||||
total_iou = sum(result['iou'] for result in self.results)
|
||||
num_samples = sum(result['batch_size'] for result in self.results)
|
||||
return dict(iou=total_iou / num_samples)
|
||||
```
|
||||
|
||||
Implementing a visualization hook is also important to facilitate easier comparison between predictions and labels.
|
||||
|
||||
```python
|
||||
from mmengine.hooks import Hook
|
||||
import shutil
|
||||
import cv2
|
||||
import os.path as osp
|
||||
|
||||
|
||||
class SegVisHook(Hook):
|
||||
|
||||
def __init__(self, data_root, vis_num=1) -> None:
|
||||
super().__init__()
|
||||
self.vis_num = vis_num
|
||||
self.palette = create_palette(osp.join(data_root, 'class_dict.csv'))
|
||||
|
||||
def after_val_iter(self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch=None,
|
||||
outputs=None) -> None:
|
||||
if batch_idx > self.vis_num:
|
||||
return
|
||||
preds, data_samples = outputs
|
||||
img_paths = data_samples['img_path']
|
||||
mask_paths = data_samples['mask_path']
|
||||
_, C, H, W = preds.shape
|
||||
preds = torch.argmax(preds, dim=1)
|
||||
for idx, (pred, img_path,
|
||||
mask_path) in enumerate(zip(preds, img_paths, mask_paths)):
|
||||
pred_mask = np.zeros((H, W, 3), dtype=np.uint8)
|
||||
runner.visualizer.set_image(pred_mask)
|
||||
for color, class_id in self.palette.items():
|
||||
runner.visualizer.draw_binary_masks(
|
||||
pred == class_id,
|
||||
colors=[color],
|
||||
alphas=1.0,
|
||||
)
|
||||
# Convert RGB to BGR
|
||||
pred_mask = runner.visualizer.get_image()[..., ::-1]
|
||||
saved_dir = osp.join(runner.log_dir, 'vis_data', str(idx))
|
||||
os.makedirs(saved_dir, exist_ok=True)
|
||||
|
||||
shutil.copyfile(img_path,
|
||||
osp.join(saved_dir, osp.basename(img_path)))
|
||||
shutil.copyfile(mask_path,
|
||||
osp.join(saved_dir, osp.basename(mask_path)))
|
||||
cv2.imwrite(
|
||||
osp.join(saved_dir, f'pred_{osp.basename(img_path)}'),
|
||||
pred_mask)
|
||||
```
|
||||
|
||||
Finally, just train the model with Runner!
|
||||
|
||||
```python
|
||||
from torch.optim import AdamW
|
||||
from mmengine.optim import AmpOptimWrapper
|
||||
from mmengine.runner import Runner
|
||||
|
||||
|
||||
num_classes = 32 # Modify to actual number of categories.
|
||||
|
||||
runner = Runner(
|
||||
model=MMDeeplabV3(num_classes),
|
||||
work_dir='./work_dir',
|
||||
train_dataloader=train_dataloader,
|
||||
optim_wrapper=dict(
|
||||
type=AmpOptimWrapper, optimizer=dict(type=AdamW, lr=2e-4)),
|
||||
train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=10),
|
||||
val_dataloader=val_dataloader,
|
||||
val_cfg=dict(),
|
||||
val_evaluator=dict(type=IoU),
|
||||
custom_hooks=[SegVisHook('data/CamVid')],
|
||||
default_hooks=dict(checkpoint=dict(type='CheckpointHook', interval=1)),
|
||||
)
|
||||
runner.train()
|
||||
```
|
||||
|
||||
Finally, you can check the training results in the folder `./work_dir/{timestamp}/vis_data`.
|
||||
|
||||
<table class="docutils">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>image</th>
|
||||
<th>prediction</th>
|
||||
<th>label</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<th><img src="https://github.com/open-mmlab/mmengine/assets/57566630/de70c138-fb8e-402c-9497-574b01725b6c" width="200"></th>
|
||||
<th><img src="https://github.com/open-mmlab/mmengine/assets/57566630/ea9221e7-48ca-4515-8815-56b5ff091f53" width="200"></th>
|
||||
<th><img src="https://github.com/open-mmlab/mmengine/assets/57566630/dcb2324f-a2df-4e5c-a038-df896dde2471" width="200"></th>
|
||||
</tr>
|
||||
</thead>
|
||||
</table>
|
@ -15,6 +15,7 @@ You can switch between Chinese and English documents in the lower-left corner of
|
||||
:caption: Examples
|
||||
|
||||
examples/train_a_gan.md
|
||||
examples/train_seg.md
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
290
docs/zh_cn/examples/train_seg.md
Normal file
290
docs/zh_cn/examples/train_seg.md
Normal file
@ -0,0 +1,290 @@
|
||||
# 训练一个语义分割模型
|
||||
|
||||
语义分割的样例大体可以分成四个步骤:
|
||||
|
||||
- [下载 Camvid 数据集](#下载-camvid-数据集)
|
||||
- [实现 Camvid 数据类](#实现-camvid-数据类)
|
||||
- [实现语义分割模型](#实现语义分割模型)
|
||||
- [使用 Runner 训练模型](#使用-runner-训练模型)
|
||||
|
||||
```{note}
|
||||
如果你更喜欢 notebook 风格的样例,也可以在[此处](https://colab.research.google.com/github/open-mmlab/mmengine/blob/main/examples/segmentation/train.ipynb) 体验。
|
||||
```
|
||||
|
||||
## 下载 Camvid 数据集
|
||||
|
||||
首先,从 opendatalab 下载 Camvid 数据集:
|
||||
|
||||
```bash
|
||||
# https://opendatalab.com/CamVid
|
||||
# Configure install
|
||||
pip install opendatalab
|
||||
# Upgraded version
|
||||
pip install -U opendatalab
|
||||
# Login
|
||||
odl login
|
||||
# Download this dataset
|
||||
mkdir data
|
||||
odl get CamVid -d data
|
||||
# Preprocess data in Linux. You should extract the files to data manually in
|
||||
# Windows
|
||||
tar -xzvf data/CamVid/raw/CamVid.tar.gz.00 -C ./data
|
||||
```
|
||||
|
||||
## 实现 Camvid 数据类
|
||||
|
||||
实现继承自 VisionDataset 的 CamVid 数据类。在这个类中,我们重写了`__getitem__`和`__len__`方法,以确保每个索引返回一个包含图像和标签的字典。此外,我们还实现了color_to_class字典,将 mask 的颜色映射到类别索引。
|
||||
|
||||
```python
|
||||
import os
|
||||
import numpy as np
|
||||
from torchvision.datasets import VisionDataset
|
||||
from PIL import Image
|
||||
import csv
|
||||
|
||||
|
||||
def create_palette(csv_filepath):
|
||||
color_to_class = {}
|
||||
with open(csv_filepath, newline='') as csvfile:
|
||||
reader = csv.DictReader(csvfile)
|
||||
for idx, row in enumerate(reader):
|
||||
r, g, b = int(row['r']), int(row['g']), int(row['b'])
|
||||
color_to_class[(r, g, b)] = idx
|
||||
return color_to_class
|
||||
|
||||
class CamVid(VisionDataset):
|
||||
|
||||
def __init__(self,
|
||||
root,
|
||||
img_folder,
|
||||
mask_folder,
|
||||
transform=None,
|
||||
target_transform=None):
|
||||
super().__init__(
|
||||
root, transform=transform, target_transform=target_transform)
|
||||
self.img_folder = img_folder
|
||||
self.mask_folder = mask_folder
|
||||
self.images = list(
|
||||
sorted(os.listdir(os.path.join(self.root, img_folder))))
|
||||
self.masks = list(
|
||||
sorted(os.listdir(os.path.join(self.root, mask_folder))))
|
||||
self.color_to_class = create_palette(
|
||||
os.path.join(self.root, 'class_dict.csv'))
|
||||
|
||||
def __getitem__(self, index):
|
||||
img_path = os.path.join(self.root, self.img_folder, self.images[index])
|
||||
mask_path = os.path.join(self.root, self.mask_folder,
|
||||
self.masks[index])
|
||||
|
||||
img = Image.open(img_path).convert('RGB')
|
||||
mask = Image.open(mask_path).convert('RGB') # Convert to RGB
|
||||
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
|
||||
# Convert the RGB values to class indices
|
||||
mask = np.array(mask)
|
||||
mask = mask[:, :, 0] * 65536 + mask[:, :, 1] * 256 + mask[:, :, 2]
|
||||
labels = np.zeros_like(mask, dtype=np.int64)
|
||||
for color, class_index in self.color_to_class.items():
|
||||
rgb = color[0] * 65536 + color[1] * 256 + color[2]
|
||||
labels[mask == rgb] = class_index
|
||||
|
||||
if self.target_transform is not None:
|
||||
labels = self.target_transform(labels)
|
||||
data_samples = dict(
|
||||
labels=labels, img_path=img_path, mask_path=mask_path)
|
||||
return img, data_samples
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images)
|
||||
|
||||
```
|
||||
|
||||
基于 CamVid 数据类,选择相应的数据增强方式,构建 train_dataloader 和 val_dataloader,供后续 runner 使用
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
transform = transforms.Compose(
|
||||
[transforms.ToTensor(),
|
||||
transforms.Normalize(**norm_cfg)])
|
||||
|
||||
target_transform = transforms.Lambda(
|
||||
lambda x: torch.tensor(np.array(x), dtype=torch.long))
|
||||
|
||||
train_set = CamVid(
|
||||
'data/CamVid',
|
||||
img_folder='train',
|
||||
mask_folder='train_labels',
|
||||
transform=transform,
|
||||
target_transform=target_transform)
|
||||
|
||||
valid_set = CamVid(
|
||||
'data/CamVid',
|
||||
img_folder='val',
|
||||
mask_folder='val_labels',
|
||||
transform=transform,
|
||||
target_transform=target_transform)
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=3,
|
||||
dataset=train_set,
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
collate_fn=dict(type='default_collate'))
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=3,
|
||||
dataset=valid_set,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
collate_fn=dict(type='default_collate'))
|
||||
```
|
||||
|
||||
## 实现语义分割模型
|
||||
|
||||
定义一个名为`MMDeeplabV3`的模型类。该类继承自`BaseModel`,并集成了DeepLabV3架构的分割模型。`MMDeeplabV3` 重写了`forward`方法,以处理输入图像和标签,并支持在训练和预测模式下计算损失和返回预测结果。
|
||||
|
||||
关于`BaseModel`的更多信息,请参考[模型教程](../tutorials/model.md)。
|
||||
|
||||
```python
|
||||
from mmengine.model import BaseModel
|
||||
from torchvision.models.segmentation import deeplabv3_resnet50
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class MMDeeplabV3(BaseModel):
|
||||
|
||||
def __init__(self, num_classes):
|
||||
super().__init__()
|
||||
self.deeplab = deeplabv3_resnet50()
|
||||
self.deeplab.classifier[4] = torch.nn.Conv2d(
|
||||
256, num_classes, kernel_size=(1, 1), stride=(1, 1))
|
||||
|
||||
def forward(self, imgs, data_samples=None, mode='tensor'):
|
||||
x = self.deeplab(imgs)['out']
|
||||
if mode == 'loss':
|
||||
return {'loss': F.cross_entropy(x, data_samples['labels'])}
|
||||
elif mode == 'predict':
|
||||
return x, data_samples
|
||||
```
|
||||
|
||||
## 使用 Runner 训练模型
|
||||
|
||||
在使用 Runner 进行训练之前,我们需要实现 IoU(交并比)指标来评估模型的性能。
|
||||
|
||||
```python
|
||||
from mmengine.evaluator import BaseMetric
|
||||
|
||||
class IoU(BaseMetric):
|
||||
|
||||
def process(self, data_batch, data_samples):
|
||||
preds, labels = data_samples[0], data_samples[1]['labels']
|
||||
preds = torch.argmax(preds, dim=1)
|
||||
intersect = (labels == preds).sum()
|
||||
union = (torch.logical_or(preds, labels)).sum()
|
||||
iou = (intersect / union).cpu()
|
||||
self.results.append(
|
||||
dict(batch_size=len(labels), iou=iou * len(labels)))
|
||||
|
||||
def compute_metrics(self, results):
|
||||
total_iou = sum(result['iou'] for result in self.results)
|
||||
num_samples = sum(result['batch_size'] for result in self.results)
|
||||
return dict(iou=total_iou / num_samples)
|
||||
```
|
||||
|
||||
实现可视化钩子(Hook)也很重要,它可以便于更轻松地比较模型预测的好坏。
|
||||
|
||||
```python
|
||||
from mmengine.hooks import Hook
|
||||
import shutil
|
||||
import cv2
|
||||
import os.path as osp
|
||||
|
||||
|
||||
class SegVisHook(Hook):
|
||||
|
||||
def __init__(self, data_root, vis_num=1) -> None:
|
||||
super().__init__()
|
||||
self.vis_num = vis_num
|
||||
self.palette = create_palette(osp.join(data_root, 'class_dict.csv'))
|
||||
|
||||
def after_val_iter(self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch=None,
|
||||
outputs=None) -> None:
|
||||
if batch_idx > self.vis_num:
|
||||
return
|
||||
preds, data_samples = outputs
|
||||
img_paths = data_samples['img_path']
|
||||
mask_paths = data_samples['mask_path']
|
||||
_, C, H, W = preds.shape
|
||||
preds = torch.argmax(preds, dim=1)
|
||||
for idx, (pred, img_path,
|
||||
mask_path) in enumerate(zip(preds, img_paths, mask_paths)):
|
||||
pred_mask = np.zeros((H, W, 3), dtype=np.uint8)
|
||||
runner.visualizer.set_image(pred_mask)
|
||||
for color, class_id in self.palette.items():
|
||||
runner.visualizer.draw_binary_masks(
|
||||
pred == class_id,
|
||||
colors=[color],
|
||||
alphas=1.0,
|
||||
)
|
||||
# Convert RGB to BGR
|
||||
pred_mask = runner.visualizer.get_image()[..., ::-1]
|
||||
saved_dir = osp.join(runner.log_dir, 'vis_data', str(idx))
|
||||
os.makedirs(saved_dir, exist_ok=True)
|
||||
|
||||
shutil.copyfile(img_path,
|
||||
osp.join(saved_dir, osp.basename(img_path)))
|
||||
shutil.copyfile(mask_path,
|
||||
osp.join(saved_dir, osp.basename(mask_path)))
|
||||
cv2.imwrite(
|
||||
osp.join(saved_dir, f'pred_{osp.basename(img_path)}'),
|
||||
pred_mask)
|
||||
```
|
||||
|
||||
准备完毕,让我们用 Runner 开始训练吧!
|
||||
|
||||
```python
|
||||
from torch.optim import AdamW
|
||||
from mmengine.optim import AmpOptimWrapper
|
||||
from mmengine.runner import Runner
|
||||
|
||||
|
||||
num_classes = 32 # Modify to actual number of categories.
|
||||
|
||||
runner = Runner(
|
||||
model=MMDeeplabV3(num_classes),
|
||||
work_dir='./work_dir',
|
||||
train_dataloader=train_dataloader,
|
||||
optim_wrapper=dict(
|
||||
type=AmpOptimWrapper, optimizer=dict(type=AdamW, lr=2e-4)),
|
||||
train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=10),
|
||||
val_dataloader=val_dataloader,
|
||||
val_cfg=dict(),
|
||||
val_evaluator=dict(type=IoU),
|
||||
custom_hooks=[SegVisHook('data/CamVid')],
|
||||
default_hooks=dict(checkpoint=dict(type='CheckpointHook', interval=1)),
|
||||
)
|
||||
runner.train()
|
||||
```
|
||||
|
||||
训练完成后,你可以在 `./work_dir/{timestamp}/vis_data` 文件夹中找到可视化结果,如下图所示:
|
||||
|
||||
<table class="docutils">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>原图</th>
|
||||
<th>预测结果</th>
|
||||
<th>标签</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<th><img src="https://github.com/open-mmlab/mmengine/assets/57566630/de70c138-fb8e-402c-9497-574b01725b6c" width="200"></th>
|
||||
<th><img src="https://github.com/open-mmlab/mmengine/assets/57566630/ea9221e7-48ca-4515-8815-56b5ff091f53" width="200"></th>
|
||||
<th><img src="https://github.com/open-mmlab/mmengine/assets/57566630/dcb2324f-a2df-4e5c-a038-df896dde2471" width="200"></th>
|
||||
</tr>
|
||||
</thead>
|
||||
</table>
|
@ -15,6 +15,7 @@
|
||||
:caption: 示例
|
||||
|
||||
examples/train_a_gan.md
|
||||
examples/train_seg.md
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
35
examples/segmentation/README.md
Normal file
35
examples/segmentation/README.md
Normal file
@ -0,0 +1,35 @@
|
||||
# Train a Segmentation Model
|
||||
|
||||
## Download Camvid Dataset
|
||||
|
||||
First, you should get the collated Camvid dataset on OpenDataLab to use for the segmentation training example. The official download steps are shown below.
|
||||
|
||||
```bash
|
||||
# https://opendatalab.com/CamVid
|
||||
# Configure install
|
||||
pip install opendatalab
|
||||
# Upgraded version
|
||||
pip install -U opendatalab
|
||||
# Login
|
||||
odl login
|
||||
# Download this dataset
|
||||
mkdir data
|
||||
odl get CamVid -d data
|
||||
# Preprocess data in Linux. You should extract the files to data manually in
|
||||
# Windows
|
||||
tar -xzvf data/CamVid/raw/CamVid.tar.gz.00 -C ./data
|
||||
```
|
||||
|
||||
## Run the example
|
||||
|
||||
Single device training
|
||||
|
||||
```bash
|
||||
python examples/segmentation/segmentation_training.py
|
||||
```
|
||||
|
||||
Distributed data parallel training
|
||||
|
||||
```bash
|
||||
torchrun --nnodes 1 --nproc_per_node 8 examples/segmentation/segmentation_training.py --launcher pytorch
|
||||
```
|
385
examples/segmentation/train.ipynb
Normal file
385
examples/segmentation/train.ipynb
Normal file
@ -0,0 +1,385 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Segmentation Task Example\n",
|
||||
"\n",
|
||||
"This segmentation task example will be divided into the following steps:\n",
|
||||
"\n",
|
||||
"- [Download Camvid Dataset](#download-camvid-dataset)\n",
|
||||
"- [Implement Camvid Dataset](#implement-the-camvid-dataset)\n",
|
||||
"- [Implement a Segmentation Model](#implement-the-segmentation-model)\n",
|
||||
"- [Train with Runner](#training-with-runner)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Download Camvid Dataset\n",
|
||||
"\n",
|
||||
"First, you should download the Camvid dataset from OpenDataLab:\n",
|
||||
"\n",
|
||||
"```bash\n",
|
||||
"# https://opendatalab.com/CamVid\n",
|
||||
"# Configure install\n",
|
||||
"pip install opendatalab\n",
|
||||
"# Upgraded version\n",
|
||||
"pip install -U opendatalab\n",
|
||||
"# Login\n",
|
||||
"odl login\n",
|
||||
"# Download this dataset\n",
|
||||
"mkdir data\n",
|
||||
"odl get CamVid -d data\n",
|
||||
"# Preprocess data in Linux. You should extract the files to data manually in\n",
|
||||
"# Windows\n",
|
||||
"tar -xzvf data/CamVid/raw/CamVid.tar.gz.00 -C ./data\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Implement the Camvid Dataset\n",
|
||||
"\n",
|
||||
"We have implemented the CamVid class here, which inherits from VisionDataset. Within this class, we have overridden the `__getitem__` and `__len__` methods to ensure that each index returns a dict of images and labels. Additionally, we have implemented the color_to_class dictionary to map the mask's color to the class index."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import numpy as np\n",
|
||||
"from torchvision.datasets import VisionDataset\n",
|
||||
"from PIL import Image\n",
|
||||
"import csv\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_palette(csv_filepath):\n",
|
||||
" color_to_class = {}\n",
|
||||
" with open(csv_filepath, newline='') as csvfile:\n",
|
||||
" reader = csv.DictReader(csvfile)\n",
|
||||
" for idx, row in enumerate(reader):\n",
|
||||
" r, g, b = int(row['r']), int(row['g']), int(row['b'])\n",
|
||||
" color_to_class[(r, g, b)] = idx\n",
|
||||
" return color_to_class\n",
|
||||
"\n",
|
||||
"class CamVid(VisionDataset):\n",
|
||||
"\n",
|
||||
" def __init__(self,\n",
|
||||
" root,\n",
|
||||
" img_folder,\n",
|
||||
" mask_folder,\n",
|
||||
" transform=None,\n",
|
||||
" target_transform=None):\n",
|
||||
" super().__init__(\n",
|
||||
" root, transform=transform, target_transform=target_transform)\n",
|
||||
" self.img_folder = img_folder\n",
|
||||
" self.mask_folder = mask_folder\n",
|
||||
" self.images = list(\n",
|
||||
" sorted(os.listdir(os.path.join(self.root, img_folder))))\n",
|
||||
" self.masks = list(\n",
|
||||
" sorted(os.listdir(os.path.join(self.root, mask_folder))))\n",
|
||||
" self.color_to_class = create_palette(\n",
|
||||
" os.path.join(self.root, 'class_dict.csv'))\n",
|
||||
"\n",
|
||||
" def __getitem__(self, index):\n",
|
||||
" img_path = os.path.join(self.root, self.img_folder, self.images[index])\n",
|
||||
" mask_path = os.path.join(self.root, self.mask_folder,\n",
|
||||
" self.masks[index])\n",
|
||||
"\n",
|
||||
" img = Image.open(img_path).convert('RGB')\n",
|
||||
" mask = Image.open(mask_path).convert('RGB') # Convert to RGB\n",
|
||||
"\n",
|
||||
" if self.transform is not None:\n",
|
||||
" img = self.transform(img)\n",
|
||||
"\n",
|
||||
" # Convert the RGB values to class indices\n",
|
||||
" mask = np.array(mask)\n",
|
||||
" mask = mask[:, :, 0] * 65536 + mask[:, :, 1] * 256 + mask[:, :, 2]\n",
|
||||
" labels = np.zeros_like(mask, dtype=np.int64)\n",
|
||||
" for color, class_index in self.color_to_class.items():\n",
|
||||
" rgb = color[0] * 65536 + color[1] * 256 + color[2]\n",
|
||||
" labels[mask == rgb] = class_index\n",
|
||||
"\n",
|
||||
" if self.target_transform is not None:\n",
|
||||
" labels = self.target_transform(labels)\n",
|
||||
" data_samples = dict(\n",
|
||||
" labels=labels, img_path=img_path, mask_path=mask_path)\n",
|
||||
" return img, data_samples\n",
|
||||
"\n",
|
||||
" def __len__(self):\n",
|
||||
" return len(self.images)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We utilize the Camvid dataset to create the `train_dataloader` and `val_dataloader`, which serve as the data loaders for training and validation in the subsequent Runner."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import torchvision.transforms as transforms\n",
|
||||
"\n",
|
||||
"norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
|
||||
"transform = transforms.Compose(\n",
|
||||
" [transforms.ToTensor(),\n",
|
||||
" transforms.Normalize(**norm_cfg)])\n",
|
||||
"\n",
|
||||
"target_transform = transforms.Lambda(\n",
|
||||
" lambda x: torch.tensor(np.array(x), dtype=torch.long))\n",
|
||||
"\n",
|
||||
"train_set = CamVid(\n",
|
||||
" 'data/CamVid',\n",
|
||||
" img_folder='train',\n",
|
||||
" mask_folder='train_labels',\n",
|
||||
" transform=transform,\n",
|
||||
" target_transform=target_transform)\n",
|
||||
"\n",
|
||||
"valid_set = CamVid(\n",
|
||||
" 'data/CamVid',\n",
|
||||
" img_folder='val',\n",
|
||||
" mask_folder='val_labels',\n",
|
||||
" transform=transform,\n",
|
||||
" target_transform=target_transform)\n",
|
||||
"\n",
|
||||
"train_dataloader = dict(\n",
|
||||
" batch_size=3,\n",
|
||||
" dataset=train_set,\n",
|
||||
" sampler=dict(type='DefaultSampler', shuffle=True),\n",
|
||||
" collate_fn=dict(type='default_collate'))\n",
|
||||
"\n",
|
||||
"val_dataloader = dict(\n",
|
||||
" batch_size=3,\n",
|
||||
" dataset=valid_set,\n",
|
||||
" sampler=dict(type='DefaultSampler', shuffle=False),\n",
|
||||
" collate_fn=dict(type='default_collate'))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Implement the Segmentation Model\n",
|
||||
"\n",
|
||||
"The provided code defines a model class named `MMDeeplabV3`. This class is derived from `BaseModel` and incorporates the segmentation model of the DeepLabV3 architecture. It overrides the `forward` method to handle both input images and labels and supports computing losses and returning predictions in both training and prediction modes.\n",
|
||||
"\n",
|
||||
"For additional information about `BaseModel`, you can refer to the [Model tutorial](../tutorials/model.md)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from mmengine.model import BaseModel\n",
|
||||
"from torchvision.models.segmentation import deeplabv3_resnet50\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class MMDeeplabV3(BaseModel):\n",
|
||||
"\n",
|
||||
" def __init__(self, num_classes):\n",
|
||||
" super().__init__()\n",
|
||||
" self.deeplab = deeplabv3_resnet50(num_classes=num_classes)\n",
|
||||
"\n",
|
||||
" def forward(self, imgs, data_samples=None, mode='tensor'):\n",
|
||||
" x = self.deeplab(imgs)['out']\n",
|
||||
" if mode == 'loss':\n",
|
||||
" return {'loss': F.cross_entropy(x, data_samples['labels'])}\n",
|
||||
" elif mode == 'predict':\n",
|
||||
" return x, data_samples"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Training with Runner\n",
|
||||
"\n",
|
||||
"Before training with the Runner, we need to implement the IoU (Intersection over Union) metric to evaluate the model's performance."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from mmengine.evaluator import BaseMetric\n",
|
||||
"\n",
|
||||
"class IoU(BaseMetric):\n",
|
||||
"\n",
|
||||
" def process(self, data_batch, data_samples):\n",
|
||||
" preds, labels = data_samples[0], data_samples[1]['labels']\n",
|
||||
" preds = torch.argmax(preds, dim=1)\n",
|
||||
" intersect = (labels == preds).sum()\n",
|
||||
" union = (torch.logical_or(preds, labels)).sum()\n",
|
||||
" iou = (intersect / union).cpu()\n",
|
||||
" self.results.append(\n",
|
||||
" dict(batch_size=len(labels), iou=iou * len(labels)))\n",
|
||||
"\n",
|
||||
" def compute_metrics(self, results):\n",
|
||||
" total_iou = sum(result['iou'] for result in self.results)\n",
|
||||
" num_samples = sum(result['batch_size'] for result in self.results)\n",
|
||||
" return dict(iou=total_iou / num_samples)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Implementing a visualization hook is also important to facilitate easier comparison between predictions and labels."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from mmengine.hooks import Hook\n",
|
||||
"import shutil\n",
|
||||
"import cv2\n",
|
||||
"import os.path as osp\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class SegVisHook(Hook):\n",
|
||||
"\n",
|
||||
" def __init__(self, data_root, vis_num=1) -> None:\n",
|
||||
" super().__init__()\n",
|
||||
" self.vis_num = vis_num\n",
|
||||
" self.palette = create_palette(osp.join(data_root, 'class_dict.csv'))\n",
|
||||
"\n",
|
||||
" def after_val_iter(self,\n",
|
||||
" runner,\n",
|
||||
" batch_idx: int,\n",
|
||||
" data_batch=None,\n",
|
||||
" outputs=None) -> None:\n",
|
||||
" if batch_idx > self.vis_num:\n",
|
||||
" return\n",
|
||||
" preds, data_samples = outputs\n",
|
||||
" img_paths = data_samples['img_path']\n",
|
||||
" mask_paths = data_samples['mask_path']\n",
|
||||
" _, C, H, W = preds.shape\n",
|
||||
" preds = torch.argmax(preds, dim=1)\n",
|
||||
" for idx, (pred, img_path,\n",
|
||||
" mask_path) in enumerate(zip(preds, img_paths, mask_paths)):\n",
|
||||
" pred_mask = np.zeros((H, W, 3), dtype=np.uint8)\n",
|
||||
" runner.visualizer.set_image(pred_mask)\n",
|
||||
" for color, class_id in self.palette.items():\n",
|
||||
" runner.visualizer.draw_binary_masks(\n",
|
||||
" pred == class_id,\n",
|
||||
" colors=[color],\n",
|
||||
" alphas=1.0,\n",
|
||||
" )\n",
|
||||
" # Convert RGB to BGR\n",
|
||||
" pred_mask = runner.visualizer.get_image()[..., ::-1]\n",
|
||||
" saved_dir = osp.join(runner.log_dir, 'vis_data', str(idx))\n",
|
||||
" os.makedirs(saved_dir, exist_ok=True)\n",
|
||||
"\n",
|
||||
" shutil.copyfile(img_path,\n",
|
||||
" osp.join(saved_dir, osp.basename(img_path)))\n",
|
||||
" shutil.copyfile(mask_path,\n",
|
||||
" osp.join(saved_dir, osp.basename(mask_path)))\n",
|
||||
" cv2.imwrite(\n",
|
||||
" osp.join(saved_dir, f'pred_{osp.basename(img_path)}'),\n",
|
||||
" pred_mask)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Finally, just train the model with Runner!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from torch.optim import AdamW\n",
|
||||
"from mmengine.optim import AmpOptimWrapper\n",
|
||||
"from mmengine.runner import Runner\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"num_classes = 32 # Modify to actual number of categories.\n",
|
||||
"\n",
|
||||
"runner = Runner(\n",
|
||||
" model=MMDeeplabV3(num_classes),\n",
|
||||
" work_dir='./work_dir',\n",
|
||||
" train_dataloader=train_dataloader,\n",
|
||||
" optim_wrapper=dict(\n",
|
||||
" type=AmpOptimWrapper, optimizer=dict(type=AdamW, lr=2e-4)),\n",
|
||||
" train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=10),\n",
|
||||
" val_dataloader=val_dataloader,\n",
|
||||
" val_cfg=dict(),\n",
|
||||
" val_evaluator=dict(type=IoU),\n",
|
||||
" custom_hooks=[SegVisHook('data/CamVid')],\n",
|
||||
" default_hooks=dict(checkpoint=dict(type='CheckpointHook', interval=1)),\n",
|
||||
")\n",
|
||||
"runner.train()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Finally, you can check the training results in the folder `./work_dir/{timestamp}/vis_data`.\n",
|
||||
"\n",
|
||||
"<table class=\"docutils\">\n",
|
||||
"<thead>\n",
|
||||
"<tr>\n",
|
||||
" <th>image</th>\n",
|
||||
" <th>prediction</th>\n",
|
||||
" <th>label</th>\n",
|
||||
"</tr>\n",
|
||||
"<tr>\n",
|
||||
" <th><img src=\"https://github.com/open-mmlab/mmengine/assets/57566630/de70c138-fb8e-402c-9497-574b01725b6c\" width=\"200\"></th>\n",
|
||||
" <th><img src=\"https://github.com/open-mmlab/mmengine/assets/57566630/ea9221e7-48ca-4515-8815-56b5ff091f53\" width=\"200\"></th>\n",
|
||||
" <th><img src=\"https://github.com/open-mmlab/mmengine/assets/57566630/dcb2324f-a2df-4e5c-a038-df896dde2471\" width=\"200\"></th>\n",
|
||||
"</tr>\n",
|
||||
"</thead>\n",
|
||||
"</table>"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "py310torch20",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.9"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
227
examples/segmentation/train.py
Normal file
227
examples/segmentation/train.py
Normal file
@ -0,0 +1,227 @@
|
||||
import argparse
|
||||
import csv
|
||||
import os
|
||||
import os.path as osp
|
||||
import shutil
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms as transforms
|
||||
from PIL import Image
|
||||
from torch.optim import AdamW
|
||||
from torchvision.datasets import VisionDataset
|
||||
from torchvision.models.segmentation import deeplabv3_resnet50
|
||||
|
||||
from mmengine.dist import master_only
|
||||
from mmengine.evaluator import BaseMetric
|
||||
from mmengine.hooks import Hook
|
||||
from mmengine.model import BaseModel
|
||||
from mmengine.optim import AmpOptimWrapper
|
||||
from mmengine.runner import Runner
|
||||
|
||||
|
||||
def create_palette(csv_filepath):
    """Build a mapping from an RGB color tuple to a class index.

    Reads the CamVid ``class_dict.csv`` file (columns ``r``, ``g``, ``b``)
    and assigns each row a consecutive class index in file order.

    Args:
        csv_filepath: Path to the ``class_dict.csv`` palette file.

    Returns:
        dict: ``{(r, g, b): class_index}`` with one entry per CSV row.
    """
    with open(csv_filepath, newline='') as fp:
        return {
            (int(row['r']), int(row['g']), int(row['b'])): class_idx
            for class_idx, row in enumerate(csv.DictReader(fp))
        }
|
||||
|
||||
|
||||
class CamVid(VisionDataset):
    """CamVid semantic-segmentation dataset.

    Pairs each image in ``img_folder`` with the mask at the same (sorted)
    position in ``mask_folder`` and converts the RGB color-coded mask into a
    per-pixel class-index label map using ``class_dict.csv``.

    Args:
        root: Dataset root containing the image/mask folders and
            ``class_dict.csv``.
        img_folder: Folder (relative to ``root``) with the input images.
        mask_folder: Folder (relative to ``root``) with the RGB masks.
        transform: Optional transform applied to the image.
        target_transform: Optional transform applied to the label map.
    """

    def __init__(self,
                 root,
                 img_folder,
                 mask_folder,
                 transform=None,
                 target_transform=None):
        super().__init__(
            root, transform=transform, target_transform=target_transform)
        self.img_folder = img_folder
        self.mask_folder = mask_folder
        # Images and masks are matched purely by sorted order, so the two
        # folders must contain correspondingly named files.
        self.images = list(
            sorted(os.listdir(os.path.join(self.root, img_folder))))
        self.masks = list(
            sorted(os.listdir(os.path.join(self.root, mask_folder))))
        # Maps an (r, g, b) tuple to its class index.
        self.color_to_class = create_palette(
            os.path.join(self.root, 'class_dict.csv'))

    def __getitem__(self, index):
        """Return ``(img, data_samples)`` where ``data_samples`` carries the
        class-index label map plus the source image/mask paths."""
        img_path = os.path.join(self.root, self.img_folder, self.images[index])
        mask_path = os.path.join(self.root, self.mask_folder,
                                 self.masks[index])

        img = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('RGB')  # Convert to RGB

        if self.transform is not None:
            img = self.transform(img)

        # Pack each RGB pixel into a single integer so each palette color can
        # be matched in one vectorized comparison. Cast to int64 FIRST: with
        # NumPy >= 2.0 (NEP 50 promotion rules), multiplying the uint8 array
        # returned by PIL by 65536 raises OverflowError instead of upcasting.
        mask = np.asarray(mask, dtype=np.int64)
        mask = mask[:, :, 0] * 65536 + mask[:, :, 1] * 256 + mask[:, :, 2]
        labels = np.zeros_like(mask, dtype=np.int64)
        for color, class_index in self.color_to_class.items():
            rgb = color[0] * 65536 + color[1] * 256 + color[2]
            labels[mask == rgb] = class_index

        if self.target_transform is not None:
            labels = self.target_transform(labels)
        data_samples = dict(
            labels=labels, img_path=img_path, mask_path=mask_path)
        return img, data_samples

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.images)
|
||||
|
||||
|
||||
class MMDeeplabV3(BaseModel):
    """DeepLabV3-ResNet50 wrapped in MMEngine's ``BaseModel`` interface.

    Args:
        num_classes: Number of segmentation classes.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.deeplab = deeplabv3_resnet50(num_classes=num_classes)

    def forward(self, imgs, data_samples=None, mode='tensor'):
        """Run the model in one of MMEngine's three modes.

        Returns a loss dict in ``'loss'`` mode, ``(logits, data_samples)``
        in ``'predict'`` mode, and the raw logits otherwise.
        """
        x = self.deeplab(imgs)['out']
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, data_samples['labels'])}
        elif mode == 'predict':
            return x, data_samples
        # Fix: the original fell through and implicitly returned None in the
        # default 'tensor' mode; return the raw logits instead.
        return x
||||
|
||||
|
||||
class IoU(BaseMetric):
    """Simplified IoU metric for the CamVid example.

    NOTE(review): ``intersect`` counts every pixel where the prediction
    equals the label (including class 0) and ``union`` counts pixels where
    either tensor is non-zero, so this is a coarse score rather than a true
    per-class mean IoU.
    """

    def process(self, data_batch, data_samples):
        """Accumulate one batch's IoU, weighted by its batch size."""
        preds, labels = data_samples[0], data_samples[1]['labels']
        # Per-pixel class index from the (N, C, H, W) logits.
        preds = torch.argmax(preds, dim=1)
        intersect = (labels == preds).sum()
        union = (torch.logical_or(preds, labels)).sum()
        iou = (intersect / union).cpu()
        self.results.append(
            dict(batch_size=len(labels), iou=iou * len(labels)))

    def compute_metrics(self, results):
        """Average the accumulated IoU over all processed samples.

        Fix: use the ``results`` argument — the list gathered across all
        ranks by ``BaseMetric.evaluate`` — rather than ``self.results``,
        which only holds the local rank's entries in distributed runs.
        """
        total_iou = sum(result['iou'] for result in results)
        num_samples = sum(result['batch_size'] for result in results)
        return dict(iou=total_iou / num_samples)
|
||||
|
||||
|
||||
class SegVisHook(Hook):
    """Validation-time hook that saves qualitative segmentation results.

    For the first ``vis_num`` validation iterations it renders the predicted
    class map as a color mask (using the CamVid palette) and copies the
    source image and ground-truth mask next to it under
    ``{runner.log_dir}/vis_data/{idx}``.
    """

    def __init__(self, data_root, vis_num=1) -> None:
        super().__init__()
        # Number of validation iterations to visualize.
        self.vis_num = vis_num
        # Maps an (r, g, b) tuple to its class index.
        self.palette = create_palette(osp.join(data_root, 'class_dict.csv'))

    @master_only
    def after_val_iter(self,
                       runner,
                       batch_idx: int,
                       data_batch=None,
                       outputs=None) -> None:
        """Render and save predictions for the first few validation batches.

        ``outputs`` is the ``(preds, data_samples)`` pair produced by the
        model's ``predict`` mode. Runs on the master rank only.
        """
        if batch_idx > self.vis_num:
            return

        preds, data_samples = outputs
        img_paths = data_samples['img_path']
        mask_paths = data_samples['mask_path']
        _, C, H, W = preds.shape  # C (class dimension) is unused here.
        # Per-pixel class index from the logits.
        preds = torch.argmax(preds, dim=1)
        for idx, (pred, img_path,
                  mask_path) in enumerate(zip(preds, img_paths, mask_paths)):
            # Paint each class's binary mask in its palette color onto a
            # black canvas via the runner's visualizer.
            pred_mask = np.zeros((H, W, 3), dtype=np.uint8)
            runner.visualizer.set_image(pred_mask)
            for color, class_id in self.palette.items():
                runner.visualizer.draw_binary_masks(
                    pred == class_id,
                    colors=[color],
                    alphas=1.0,
                )
            # Convert RGB to BGR
            pred_mask = runner.visualizer.get_image()[..., ::-1]
            saved_dir = osp.join(runner.log_dir, 'vis_data', str(idx))
            os.makedirs(saved_dir, exist_ok=True)

            # Save the input image and ground-truth mask alongside the
            # rendered prediction for easy side-by-side comparison.
            shutil.copyfile(img_path,
                            osp.join(saved_dir, osp.basename(img_path)))
            shutil.copyfile(mask_path,
                            osp.join(saved_dir, osp.basename(mask_path)))
            cv2.imwrite(
                osp.join(saved_dir, f'pred_{osp.basename(img_path)}'),
                pred_mask)
|
||||
|
||||
|
||||
def parse_args():
    """Parse the command line flags for (optionally distributed) training."""
    arg_parser = argparse.ArgumentParser(description='Distributed Training')
    arg_parser.add_argument(
        '--launcher',
        default='none',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        help='job launcher')
    arg_parser.add_argument('--local_rank', type=int, default=0)
    return arg_parser.parse_args()
|
||||
|
||||
|
||||
def main():
    """Build the CamVid datasets/dataloaders and train DeepLabV3 with
    MMEngine's Runner."""
    args = parse_args()
    num_classes = 32  # Modify to actual number of categories.
    # Standard ImageNet channel-wise mean/std used to normalize inputs.
    norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(**norm_cfg)])

    # Masks become long tensors of class indices (no normalization).
    target_transform = transforms.Lambda(
        lambda x: torch.tensor(np.array(x), dtype=torch.long))

    train_set = CamVid(
        'data/CamVid',
        img_folder='train',
        mask_folder='train_labels',
        transform=transform,
        target_transform=target_transform)

    valid_set = CamVid(
        'data/CamVid',
        img_folder='val',
        mask_folder='val_labels',
        transform=transform,
        target_transform=target_transform)

    # Dataloaders are passed as configs; the Runner materializes them (the
    # DefaultSampler handles distributed sharding when a launcher is used).
    train_dataloader = dict(
        batch_size=3,
        dataset=train_set,
        sampler=dict(type='DefaultSampler', shuffle=True),
        collate_fn=dict(type='default_collate'))
    val_dataloader = dict(
        batch_size=3,
        dataset=valid_set,
        sampler=dict(type='DefaultSampler', shuffle=False),
        collate_fn=dict(type='default_collate'))

    runner = Runner(
        model=MMDeeplabV3(num_classes),
        work_dir='./work_dir',
        train_dataloader=train_dataloader,
        # Mixed-precision (AMP) optimizer wrapper around AdamW.
        optim_wrapper=dict(
            type=AmpOptimWrapper, optimizer=dict(type=AdamW, lr=2e-4)),
        train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=10),
        val_dataloader=val_dataloader,
        val_cfg=dict(),
        val_evaluator=dict(type=IoU),
        launcher=args.launcher,
        custom_hooks=[SegVisHook('data/CamVid')],
        default_hooks=dict(checkpoint=dict(type='CheckpointHook', interval=1)),
    )
    runner.train()
|
||||
|
||||
|
||||
# Script entry point: run training when executed directly.
if __name__ == '__main__':
    main()
|
Loading…
x
Reference in New Issue
Block a user