{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Train a Segmentation Model\n", "\n", "This segmentation task example will be divided into the following steps:\n", "\n", "- [Download Camvid Dataset](#download-camvid-dataset)\n", "- [Implement Camvid Dataset](#implement-the-camvid-dataset)\n", "- [Implement a Segmentation Model](#implement-the-segmentation-model)\n", "- [Train with Runner](#training-with-runner)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Download Camvid Dataset\n", "\n", "First, you should download the Camvid dataset from OpenDataLab:\n", "\n", "```bash\n", "# https://opendatalab.com/CamVid\n", "# Configure install\n", "pip install opendatalab\n", "# Upgraded version\n", "pip install -U opendatalab\n", "# Login\n", "odl login\n", "# Download this dataset\n", "mkdir data\n", "odl get CamVid -d data\n", "# Preprocess data in Linux. You should extract the files to data manually in\n", "# Windows\n", "tar -xzvf data/CamVid/raw/CamVid.tar.gz.00 -C ./data\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Implement the Camvid Dataset\n", "\n", "We have implemented the CamVid class here, which inherits from VisionDataset. Within this class, we have overridden the `__getitem__` and `__len__` methods to ensure that each index returns a dict of images and labels. Additionally, we have implemented the color_to_class dictionary to map the mask's color to the class index.\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import numpy as np\n", "from torchvision.datasets import VisionDataset\n", "from PIL import Image\n", "import csv\n", "\n", "\n", "def create_palette(csv_filepath):\n", " color_to_class = {}\n", " with open(csv_filepath, newline='') as csvfile:\n", " reader = csv.DictReader(csvfile)\n", " for idx, row in enumerate(reader):\n", " r, g, b = int(row['r']), int(row['g']), int(row['b'])\n", " color_to_class[(r, g, b)] = idx\n", " return color_to_class\n", "\n", "class CamVid(VisionDataset):\n", "\n", " def __init__(self,\n", " root,\n", " img_folder,\n", " mask_folder,\n", " transform=None,\n", " target_transform=None):\n", " super().__init__(\n", " root, transform=transform, target_transform=target_transform)\n", " self.img_folder = img_folder\n", " self.mask_folder = mask_folder\n", " self.images = list(\n", " sorted(os.listdir(os.path.join(self.root, img_folder))))\n", " self.masks = list(\n", " sorted(os.listdir(os.path.join(self.root, mask_folder))))\n", " self.color_to_class = create_palette(\n", " os.path.join(self.root, 'class_dict.csv'))\n", "\n", " def __getitem__(self, index):\n", " img_path = os.path.join(self.root, self.img_folder, self.images[index])\n", " mask_path = os.path.join(self.root, self.mask_folder,\n", " self.masks[index])\n", "\n", " img = Image.open(img_path).convert('RGB')\n", " mask = Image.open(mask_path).convert('RGB') # Convert to RGB\n", "\n", " if self.transform is not None:\n", " img = self.transform(img)\n", "\n", " # Convert the RGB values to class indices\n", " mask = np.array(mask)\n", " mask = mask[:, :, 0] * 65536 + mask[:, :, 1] * 256 + mask[:, :, 2]\n", " labels = np.zeros_like(mask, dtype=np.int64)\n", " for color, class_index in self.color_to_class.items():\n", " rgb = color[0] * 65536 + color[1] * 256 + color[2]\n", " labels[mask == rgb] = class_index\n", "\n", " if self.target_transform is not None:\n", " labels = self.target_transform(labels)\n", " data_samples = dict(\n", " labels=labels, 
"            labels=labels, img_path=img_path, mask_path=mask_path)\n", "        return img, data_samples\n", "\n", "    def __len__(self):\n", "        return len(self.images)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use the `CamVid` dataset to create `train_dataloader` and `val_dataloader`, which serve as the data loaders for training and validation in the subsequent Runner." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torchvision.transforms as transforms\n", "\n", "# Normalize with ImageNet statistics\n", "norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n", "transform = transforms.Compose(\n", "    [transforms.ToTensor(),\n", "     transforms.Normalize(**norm_cfg)])\n", "\n", "target_transform = transforms.Lambda(\n", "    lambda x: torch.tensor(np.array(x), dtype=torch.long))\n", "\n", "train_set = CamVid(\n", "    'data/CamVid',\n", "    img_folder='train',\n", "    mask_folder='train_labels',\n", "    transform=transform,\n", "    target_transform=target_transform)\n", "\n", "valid_set = CamVid(\n", "    'data/CamVid',\n", "    img_folder='val',\n", "    mask_folder='val_labels',\n", "    transform=transform,\n", "    target_transform=target_transform)\n", "\n", "train_dataloader = dict(\n", "    batch_size=3,\n", "    dataset=train_set,\n", "    sampler=dict(type='DefaultSampler', shuffle=True),\n", "    collate_fn=dict(type='default_collate'))\n", "\n", "val_dataloader = dict(\n", "    batch_size=3,\n", "    dataset=valid_set,\n", "    sampler=dict(type='DefaultSampler', shuffle=False),\n", "    collate_fn=dict(type='default_collate'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Implement the Segmentation Model\n", "\n", "The code below defines a model class named `MMDeeplabV3`. It is derived from `BaseModel` and wraps the DeepLabV3 segmentation model from torchvision. `MMDeeplabV3` overrides the `forward` method so that it accepts both input images and labels, computes the loss in `loss` mode, returns predictions in `predict` mode, and returns the raw logits in the default `tensor` mode.\n", "\n", "For additional information about `BaseModel`, you can refer to the [Model tutorial](../tutorials/model.md)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from mmengine.model import BaseModel\n", "from torchvision.models.segmentation import deeplabv3_resnet50\n", "import torch.nn.functional as F\n", "\n", "\n", "class MMDeeplabV3(BaseModel):\n", "\n", "    def __init__(self, num_classes):\n", "        super().__init__()\n", "        self.deeplab = deeplabv3_resnet50(num_classes=num_classes)\n", "\n", "    def forward(self, imgs, data_samples=None, mode='tensor'):\n", "        x = self.deeplab(imgs)['out']\n", "        if mode == 'loss':\n", "            return {'loss': F.cross_entropy(x, data_samples['labels'])}\n", "        elif mode == 'predict':\n", "            return x, data_samples\n", "        else:\n", "            # The default 'tensor' mode returns the raw logits\n", "            return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training with Runner\n", "\n", "Before training with the Runner, we need to implement the IoU (Intersection over Union) metric to evaluate the model's performance."
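, "\n", "\n", "In MMEngine, a custom metric inherits from `BaseMetric`: `process()` is called after each validation iteration and appends per-batch results to `self.results`, while `compute_metrics()` aggregates those accumulated results into the final metric once the whole validation set has been processed. As a refresher, IoU is the ratio between the intersection and the union of the predicted and ground-truth masks. The next cell is a toy illustration of that definition on two small, hypothetical binary masks; it is only a sketch to build intuition, not part of the training pipeline." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Toy illustration of IoU on two hypothetical 2x2 binary masks\n", "import torch\n", "\n", "pred = torch.tensor([[0, 1], [1, 1]])\n", "gt = torch.tensor([[0, 1], [0, 1]])\n", "intersect = torch.logical_and(pred, gt).sum()  # 2 pixels overlap\n", "union = torch.logical_or(pred, gt).sum()  # 3 pixels are set in either mask\n", "print(intersect / union)  # tensor(0.6667)"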
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from mmengine.evaluator import BaseMetric\n", "\n", "class IoU(BaseMetric):\n", "\n", " def process(self, data_batch, data_samples):\n", " preds, labels = data_samples[0], data_samples[1]['labels']\n", " preds = torch.argmax(preds, dim=1)\n", " intersect = (labels == preds).sum()\n", " union = (torch.logical_or(preds, labels)).sum()\n", " iou = (intersect / union).cpu()\n", " self.results.append(\n", " dict(batch_size=len(labels), iou=iou * len(labels)))\n", "\n", " def compute_metrics(self, results):\n", " total_iou = sum(result['iou'] for result in self.results)\n", " num_samples = sum(result['batch_size'] for result in self.results)\n", " return dict(iou=total_iou / num_samples)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Implementing a visualization hook is also important to facilitate easier comparison between predictions and labels." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from mmengine.hooks import Hook\n", "import shutil\n", "import cv2\n", "import os.path as osp\n", "\n", "\n", "class SegVisHook(Hook):\n", "\n", " def __init__(self, data_root, vis_num=1) -> None:\n", " super().__init__()\n", " self.vis_num = vis_num\n", " self.palette = create_palette(osp.join(data_root, 'class_dict.csv'))\n", "\n", " def after_val_iter(self,\n", " runner,\n", " batch_idx: int,\n", " data_batch=None,\n", " outputs=None) -> None:\n", " if batch_idx > self.vis_num:\n", " return\n", " preds, data_samples = outputs\n", " img_paths = data_samples['img_path']\n", " mask_paths = data_samples['mask_path']\n", " _, C, H, W = preds.shape\n", " preds = torch.argmax(preds, dim=1)\n", " for idx, (pred, img_path,\n", " mask_path) in enumerate(zip(preds, img_paths, mask_paths)):\n", " pred_mask = np.zeros((H, W, 3), dtype=np.uint8)\n", " runner.visualizer.set_image(pred_mask)\n", " for color, class_id in self.palette.items():\n", " runner.visualizer.draw_binary_masks(\n", " pred == class_id,\n", " colors=[color],\n", " alphas=1.0,\n", " )\n", " # Convert RGB to BGR\n", " pred_mask = runner.visualizer.get_image()[..., ::-1]\n", " saved_dir = osp.join(runner.log_dir, 'vis_data', str(idx))\n", " os.makedirs(saved_dir, exist_ok=True)\n", "\n", " shutil.copyfile(img_path,\n", " osp.join(saved_dir, osp.basename(img_path)))\n", " shutil.copyfile(mask_path,\n", " osp.join(saved_dir, osp.basename(mask_path)))\n", " cv2.imwrite(\n", " osp.join(saved_dir, f'pred_{osp.basename(img_path)}'),\n", " pred_mask)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finnaly, just train the model with Runner!" 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from torch.optim import AdamW\n", "from mmengine.optim import AmpOptimWrapper\n", "from mmengine.runner import Runner\n", "\n", "\n", "num_classes = 32 # Modify to actual number of categories.\n", "\n", "runner = Runner(\n", " model=MMDeeplabV3(num_classes),\n", " work_dir='./work_dir',\n", " train_dataloader=train_dataloader,\n", " optim_wrapper=dict(\n", " type=AmpOptimWrapper, optimizer=dict(type=AdamW, lr=2e-4)),\n", " train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=10),\n", " val_dataloader=val_dataloader,\n", " val_cfg=dict(),\n", " val_evaluator=dict(type=IoU),\n", " custom_hooks=[SegVisHook('data/CamVid')],\n", " default_hooks=dict(checkpoint=dict(type='CheckpointHook', interval=1)),\n", ")\n", "runner.train()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finnaly, you can check the training results in the folder `./work_dir/{timestamp}/vis_data`.\n", "\n", "<table class=\"docutils\">\n", "<thead>\n", "<tr>\n", " <th>image</th>\n", " <th>prediction</th>\n", " <th>label</th>\n", "</tr>\n", "<tr>\n", " <th><img src=\"https://github.com/open-mmlab/mmengine/assets/57566630/de70c138-fb8e-402c-9497-574b01725b6c\" width=\"200\"></th>\n", " <th><img src=\"https://github.com/open-mmlab/mmengine/assets/57566630/ea9221e7-48ca-4515-8815-56b5ff091f53\" width=\"200\"></th>\n", " <th><img src=\"https://github.com/open-mmlab/mmengine/assets/57566630/dcb2324f-a2df-4e5c-a038-df896dde2471\" width=\"200\"></th>\n", "</tr>\n", "</thead>\n", "</table>" ] } ], "metadata": { "kernelspec": { "display_name": "py310torch20", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.9" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }