# Train a Segmentation Model

This segmentation task example is divided into the following steps:

- [Download CamVid Dataset](#download-camvid-dataset)
- [Implement the CamVid Dataset](#implement-the-camvid-dataset)
- [Implement the Segmentation Model](#implement-the-segmentation-model)
- [Training with Runner](#training-with-runner)

```{note}
You can also run this example as a notebook [here](https://colab.research.google.com/github/open-mmlab/mmengine/blob/main/examples/segmentation/train.ipynb).
```

## Download CamVid Dataset

First, download the CamVid dataset from OpenDataLab:

```bash
# https://opendatalab.com/CamVid
# Install (or upgrade) the OpenDataLab CLI
pip install -U opendatalab
# Log in
odl login
# Download the dataset
mkdir data
odl get CamVid -d data
# Preprocess the data on Linux. On Windows, extract the files into `data`
# manually.
tar -xzvf data/CamVid/raw/CamVid.tar.gz.00 -C ./data
```

## Implement the CamVid Dataset

The `CamVid` class below inherits from `VisionDataset` and overrides the `__getitem__` and `__len__` methods, so that each index returns the image together with a dict holding its label map and file paths. The `create_palette` helper builds the `color_to_class` dictionary, which maps each mask color to a class index.

```python
import csv
import os

import numpy as np
from PIL import Image
from torchvision.datasets import VisionDataset


def create_palette(csv_filepath):
    """Map each (r, g, b) color in class_dict.csv to a class index."""
    color_to_class = {}
    with open(csv_filepath, newline='') as csvfile:
        reader = csv.DictReader(csvfile)
        for idx, row in enumerate(reader):
            r, g, b = int(row['r']), int(row['g']), int(row['b'])
            color_to_class[(r, g, b)] = idx
    return color_to_class


class CamVid(VisionDataset):

    def __init__(self,
                 root,
                 img_folder,
                 mask_folder,
                 transform=None,
                 target_transform=None):
        super().__init__(
            root, transform=transform, target_transform=target_transform)
        self.img_folder = img_folder
        self.mask_folder = mask_folder
        self.images = list(
            sorted(os.listdir(os.path.join(self.root, img_folder))))
        self.masks = list(
            sorted(os.listdir(os.path.join(self.root, mask_folder))))
        self.color_to_class = create_palette(
            os.path.join(self.root, 'class_dict.csv'))

    def __getitem__(self, index):
        img_path = os.path.join(self.root, self.img_folder,
                                self.images[index])
        mask_path = os.path.join(self.root, self.mask_folder,
                                 self.masks[index])

        img = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('RGB')  # Convert to RGB

        if self.transform is not None:
            img = self.transform(img)

        # Pack each pixel's RGB values into a single integer (cast first to
        # avoid uint8 overflow), then convert the colors to class indices.
        mask = np.array(mask).astype(np.int64)
        mask = mask[:, :, 0] * 65536 + mask[:, :, 1] * 256 + mask[:, :, 2]
        labels = np.zeros_like(mask, dtype=np.int64)
        for color, class_index in self.color_to_class.items():
            rgb = color[0] * 65536 + color[1] * 256 + color[2]
            labels[mask == rgb] = class_index

        if self.target_transform is not None:
            labels = self.target_transform(labels)
        data_samples = dict(
            labels=labels, img_path=img_path, mask_path=mask_path)
        return img, data_samples

    def __len__(self):
        return len(self.images)
```
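As a quick sanity check, we can fetch one sample and inspect it. This snippet is only a minimal sketch, assuming the dataset has been extracted to `data/CamVid` as above:

```python
# Minimal sanity check (assumes data/CamVid was extracted as above).
dataset = CamVid(
    'data/CamVid', img_folder='train', mask_folder='train_labels')
print(len(dataset))  # number of training images

img, data_samples = dataset[0]
print(img.size)  # PIL image size (width, height)
print(data_samples['labels'].shape)  # (H, W) array of class indices
print(data_samples['img_path'])  # path of the source image
```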
We then use the `CamVid` dataset to build the `train_dataloader` and `val_dataloader`, which serve as the data loaders for training and validation in the subsequent Runner.

```python
import torch
import torchvision.transforms as transforms

norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize(**norm_cfg)])
target_transform = transforms.Lambda(
    lambda x: torch.tensor(np.array(x), dtype=torch.long))

train_set = CamVid(
    'data/CamVid',
    img_folder='train',
    mask_folder='train_labels',
    transform=transform,
    target_transform=target_transform)

valid_set = CamVid(
    'data/CamVid',
    img_folder='val',
    mask_folder='val_labels',
    transform=transform,
    target_transform=target_transform)

train_dataloader = dict(
    batch_size=3,
    dataset=train_set,
    sampler=dict(type='DefaultSampler', shuffle=True),
    collate_fn=dict(type='default_collate'))

val_dataloader = dict(
    batch_size=3,
    dataset=valid_set,
    sampler=dict(type='DefaultSampler', shuffle=False),
    collate_fn=dict(type='default_collate'))
```

## Implement the Segmentation Model

The code below defines a model class named `MMDeeplabV3`. It derives from `BaseModel` and wraps the DeepLabV3 segmentation architecture. Its overridden `forward` method handles both input images and labels, computing the loss in training mode and returning predictions in prediction mode. For additional information about `BaseModel`, refer to the [Model tutorial](../tutorials/model.md).

```python
import torch.nn.functional as F
from mmengine.model import BaseModel
from torchvision.models.segmentation import deeplabv3_resnet50


class MMDeeplabV3(BaseModel):

    def __init__(self, num_classes):
        super().__init__()
        self.deeplab = deeplabv3_resnet50(num_classes=num_classes)

    def forward(self, imgs, data_samples=None, mode='tensor'):
        x = self.deeplab(imgs)['out']
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, data_samples['labels'])}
        elif mode == 'predict':
            return x, data_samples
```

## Training with Runner

Before training with the Runner, we need to implement an IoU (Intersection over Union) metric to evaluate the model's performance.

```python
from mmengine.evaluator import BaseMetric


class IoU(BaseMetric):

    def process(self, data_batch, data_samples):
        # `data_samples` is what the model returns in predict mode:
        # (predictions, dict with 'labels').
        preds, labels = data_samples[0], data_samples[1]['labels']
        preds = torch.argmax(preds, dim=1)
        intersect = (labels == preds).sum()
        union = (torch.logical_or(preds, labels)).sum()
        iou = (intersect / union).cpu()
        self.results.append(
            dict(batch_size=len(labels), iou=iou * len(labels)))

    def compute_metrics(self, results):
        total_iou = sum(result['iou'] for result in self.results)
        num_samples = sum(result['batch_size'] for result in self.results)
        return dict(iou=total_iou / num_samples)
```
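The metric accumulates per-batch results in `process` and aggregates them in `compute_metrics`. The following toy check illustrates that flow; the tensors are made up for demonstration and are not from CamVid:

```python
# Toy check of the metric flow with hypothetical tensors.
metric = IoU()
dummy_preds = torch.randn(2, 32, 8, 8)          # (N, num_classes, H, W) logits
dummy_labels = torch.randint(0, 32, (2, 8, 8))  # (N, H, W) class indices
metric.process(None, (dummy_preds, dict(labels=dummy_labels)))
print(metric.compute_metrics(metric.results))   # a dict like {'iou': ...}
```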
Implementing a visualization hook also makes it easier to compare the model's predictions with the ground-truth labels.

```python
import os
import os.path as osp
import shutil

import cv2

from mmengine.hooks import Hook


class SegVisHook(Hook):

    def __init__(self, data_root, vis_num=1) -> None:
        super().__init__()
        self.vis_num = vis_num
        self.palette = create_palette(osp.join(data_root, 'class_dict.csv'))

    def after_val_iter(self,
                       runner,
                       batch_idx: int,
                       data_batch=None,
                       outputs=None) -> None:
        if batch_idx > self.vis_num:
            return
        preds, data_samples = outputs
        img_paths = data_samples['img_path']
        mask_paths = data_samples['mask_path']
        _, C, H, W = preds.shape
        preds = torch.argmax(preds, dim=1)
        for idx, (pred, img_path,
                  mask_path) in enumerate(zip(preds, img_paths, mask_paths)):
            pred_mask = np.zeros((H, W, 3), dtype=np.uint8)
            runner.visualizer.set_image(pred_mask)
            for color, class_id in self.palette.items():
                runner.visualizer.draw_binary_masks(
                    pred == class_id,
                    colors=[color],
                    alphas=1.0,
                )
            # Convert RGB to BGR for cv2.imwrite
            pred_mask = runner.visualizer.get_image()[..., ::-1]
            saved_dir = osp.join(runner.log_dir, 'vis_data', str(idx))
            os.makedirs(saved_dir, exist_ok=True)

            shutil.copyfile(img_path,
                            osp.join(saved_dir, osp.basename(img_path)))
            shutil.copyfile(mask_path,
                            osp.join(saved_dir, osp.basename(mask_path)))
            cv2.imwrite(
                osp.join(saved_dir, f'pred_{osp.basename(img_path)}'),
                pred_mask)
```

Finally, just train the model with the Runner!

```python
from torch.optim import AdamW
from mmengine.optim import AmpOptimWrapper
from mmengine.runner import Runner

num_classes = 32  # Modify to the actual number of categories.

runner = Runner(
    model=MMDeeplabV3(num_classes),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(
        type=AmpOptimWrapper, optimizer=dict(type=AdamW, lr=2e-4)),
    train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=10),
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=IoU),
    custom_hooks=[SegVisHook('data/CamVid')],
    default_hooks=dict(checkpoint=dict(type='CheckpointHook', interval=1)),
)
runner.train()
```

After training, you can check the results in the folder `./work_dir/{timestamp}/vis_data`:
| image | prediction | label |
| ----- | ---------- | ----- |
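If you want to go further (this is not part of the original example), you can reload a saved checkpoint and run a single prediction. The sketch below assumes the default checkpoint name written by `CheckpointHook` (`epoch_10.pth`) and the `valid_set` defined earlier:

```python
# Hypothetical follow-up: reload the last checkpoint and predict one image.
import torch

model = MMDeeplabV3(num_classes=32)
ckpt = torch.load('./work_dir/epoch_10.pth', map_location='cpu')
model.load_state_dict(ckpt['state_dict'])
model.eval()

img, data_samples = valid_set[0]  # normalized image tensor and metadata
with torch.no_grad():
    pred, _ = model(img.unsqueeze(0), data_samples, mode='predict')
print(pred.argmax(dim=1).shape)  # (1, H, W) map of predicted class indices
```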