mmengine/examples/segmentation/train.ipynb

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Train a Segmentation Model\n",
"\n",
"This segmentation task example will be divided into the following steps:\n",
"\n",
"- [Download Camvid Dataset](#download-camvid-dataset)\n",
"- [Implement Camvid Dataset](#implement-the-camvid-dataset)\n",
"- [Implement a Segmentation Model](#implement-the-segmentation-model)\n",
"- [Train with Runner](#training-with-runner)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Download Camvid Dataset\n",
"\n",
"First, you should download the Camvid dataset from OpenDataLab:\n",
"\n",
"```bash\n",
"# https://opendatalab.com/CamVid\n",
"# Configure install\n",
"pip install opendatalab\n",
"# Upgraded version\n",
"pip install -U opendatalab\n",
"# Login\n",
"odl login\n",
"# Download this dataset\n",
"mkdir data\n",
"odl get CamVid -d data\n",
"# Preprocess data in Linux. You should extract the files to data manually in\n",
"# Windows\n",
"tar -xzvf data/CamVid/raw/CamVid.tar.gz.00 -C ./data\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Implement the Camvid Dataset\n",
"\n",
"We have implemented the CamVid class here, which inherits from VisionDataset. Within this class, we have overridden the `__getitem__` and `__len__` methods to ensure that each index returns a dict of images and labels. Additionally, we have implemented the color_to_class dictionary to map the mask's color to the class index.\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import numpy as np\n",
"from torchvision.datasets import VisionDataset\n",
"from PIL import Image\n",
"import csv\n",
"\n",
"\n",
"def create_palette(csv_filepath):\n",
" color_to_class = {}\n",
" with open(csv_filepath, newline='') as csvfile:\n",
" reader = csv.DictReader(csvfile)\n",
" for idx, row in enumerate(reader):\n",
" r, g, b = int(row['r']), int(row['g']), int(row['b'])\n",
" color_to_class[(r, g, b)] = idx\n",
" return color_to_class\n",
"\n",
"class CamVid(VisionDataset):\n",
"\n",
" def __init__(self,\n",
" root,\n",
" img_folder,\n",
" mask_folder,\n",
" transform=None,\n",
" target_transform=None):\n",
" super().__init__(\n",
" root, transform=transform, target_transform=target_transform)\n",
" self.img_folder = img_folder\n",
" self.mask_folder = mask_folder\n",
" self.images = list(\n",
" sorted(os.listdir(os.path.join(self.root, img_folder))))\n",
" self.masks = list(\n",
" sorted(os.listdir(os.path.join(self.root, mask_folder))))\n",
" self.color_to_class = create_palette(\n",
" os.path.join(self.root, 'class_dict.csv'))\n",
"\n",
" def __getitem__(self, index):\n",
" img_path = os.path.join(self.root, self.img_folder, self.images[index])\n",
" mask_path = os.path.join(self.root, self.mask_folder,\n",
" self.masks[index])\n",
"\n",
" img = Image.open(img_path).convert('RGB')\n",
" mask = Image.open(mask_path).convert('RGB') # Convert to RGB\n",
"\n",
" if self.transform is not None:\n",
" img = self.transform(img)\n",
"\n",
" # Convert the RGB values to class indices\n",
" mask = np.array(mask)\n",
" mask = mask[:, :, 0] * 65536 + mask[:, :, 1] * 256 + mask[:, :, 2]\n",
" labels = np.zeros_like(mask, dtype=np.int64)\n",
" for color, class_index in self.color_to_class.items():\n",
" rgb = color[0] * 65536 + color[1] * 256 + color[2]\n",
" labels[mask == rgb] = class_index\n",
"\n",
" if self.target_transform is not None:\n",
" labels = self.target_transform(labels)\n",
" data_samples = dict(\n",
" labels=labels, img_path=img_path, mask_path=mask_path)\n",
" return img, data_samples\n",
"\n",
" def __len__(self):\n",
" return len(self.images)\n"
]
},
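{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (illustrative, not part of the training pipeline), we can build the dataset and inspect one sample. This assumes the `data/CamVid` layout produced by the download step, with `train`, `train_labels`, and `class_dict.csv` inside it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sanity check; assumes the data/CamVid layout described above.\n",
"dataset = CamVid(\n",
"    'data/CamVid', img_folder='train', mask_folder='train_labels')\n",
"\n",
"img, data_sample = dataset[0]\n",
"print(len(dataset))                 # number of training images\n",
"print(img.size)                     # PIL image size; no transform is set here\n",
"print(data_sample['labels'].shape)  # (H, W) array of per-pixel class indices\n",
"print(len(dataset.color_to_class))  # number of classes in class_dict.csv"
]
},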
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We utilize the Camvid dataset to create the `train_dataloader` and `val_dataloader`, which serve as the data loaders for training and validation in the subsequent Runner."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchvision.transforms as transforms\n",
"\n",
"norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
"transform = transforms.Compose(\n",
" [transforms.ToTensor(),\n",
" transforms.Normalize(**norm_cfg)])\n",
"\n",
"target_transform = transforms.Lambda(\n",
" lambda x: torch.tensor(np.array(x), dtype=torch.long))\n",
"\n",
"train_set = CamVid(\n",
" 'data/CamVid',\n",
" img_folder='train',\n",
" mask_folder='train_labels',\n",
" transform=transform,\n",
" target_transform=target_transform)\n",
"\n",
"valid_set = CamVid(\n",
" 'data/CamVid',\n",
" img_folder='val',\n",
" mask_folder='val_labels',\n",
" transform=transform,\n",
" target_transform=target_transform)\n",
"\n",
"train_dataloader = dict(\n",
" batch_size=3,\n",
" dataset=train_set,\n",
" sampler=dict(type='DefaultSampler', shuffle=True),\n",
" collate_fn=dict(type='default_collate'))\n",
"\n",
"val_dataloader = dict(\n",
" batch_size=3,\n",
" dataset=valid_set,\n",
" sampler=dict(type='DefaultSampler', shuffle=False),\n",
" collate_fn=dict(type='default_collate'))"
]
},
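{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that the two dicts above are configs, not data loaders: the Runner builds `torch.utils.data.DataLoader` objects from them. In the single-process case the manual equivalent is roughly the sketch below, where `DefaultSampler` reduces to ordinary shuffling and `'default_collate'` is PyTorch's default collate function."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Rough manual equivalent of what the Runner builds (illustrative sketch).\n",
"from torch.utils.data import DataLoader\n",
"\n",
"manual_loader = DataLoader(train_set, batch_size=3, shuffle=True)\n",
"\n",
"imgs, data_samples = next(iter(manual_loader))\n",
"print(imgs.shape)                    # (3, 3, H, W) normalized image batch\n",
"print(data_samples['labels'].shape)  # (3, H, W) class-index labels\n",
"print(data_samples['img_path'][:1])  # string fields are collated into lists"
]
},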
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Implement the Segmentation Model\n",
"\n",
"The provided code defines a model class named `MMDeeplabV3`. This class is derived from `BaseModel` and incorporates the segmentation model of the DeepLabV3 architecture. It overrides the `forward` method to handle both input images and labels and supports computing losses and returning predictions in both training and prediction modes.\n",
"\n",
"For additional information about `BaseModel`, you can refer to the [Model tutorial](../tutorials/model.md)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from mmengine.model import BaseModel\n",
"from torchvision.models.segmentation import deeplabv3_resnet50\n",
"import torch.nn.functional as F\n",
"\n",
"\n",
"class MMDeeplabV3(BaseModel):\n",
"\n",
" def __init__(self, num_classes):\n",
" super().__init__()\n",
" self.deeplab = deeplabv3_resnet50(num_classes=num_classes)\n",
"\n",
" def forward(self, imgs, data_samples=None, mode='tensor'):\n",
" x = self.deeplab(imgs)['out']\n",
" if mode == 'loss':\n",
" return {'loss': F.cross_entropy(x, data_samples['labels'])}\n",
" elif mode == 'predict':\n",
" return x, data_samples"
]
},
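{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick smoke test on random tensors shows what each mode returns (illustrative only; the shapes and class count here are arbitrary)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Smoke test of the three forward modes on random data (illustrative).\n",
"model = MMDeeplabV3(num_classes=32)\n",
"dummy_imgs = torch.randn(2, 3, 64, 64)\n",
"dummy_samples = dict(labels=torch.randint(0, 32, (2, 64, 64)))\n",
"\n",
"print(model(dummy_imgs, dummy_samples, mode='loss'))  # {'loss': tensor(...)}\n",
"preds, samples = model(dummy_imgs, dummy_samples, mode='predict')\n",
"print(preds.shape)                        # (2, 32, 64, 64) per-class logits\n",
"print(model(dummy_imgs, mode='tensor').shape)  # same logits, no labels needed"
]
},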
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training with Runner\n",
"\n",
"Before training with the Runner, we need to implement the IoU (Intersection over Union) metric to evaluate the model's performance."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from mmengine.evaluator import BaseMetric\n",
"\n",
"class IoU(BaseMetric):\n",
"\n",
" def process(self, data_batch, data_samples):\n",
" preds, labels = data_samples[0], data_samples[1]['labels']\n",
" preds = torch.argmax(preds, dim=1)\n",
" intersect = (labels == preds).sum()\n",
" union = (torch.logical_or(preds, labels)).sum()\n",
" iou = (intersect / union).cpu()\n",
" self.results.append(\n",
" dict(batch_size=len(labels), iou=iou * len(labels)))\n",
"\n",
" def compute_metrics(self, results):\n",
" total_iou = sum(result['iou'] for result in self.results)\n",
" num_samples = sum(result['batch_size'] for result in self.results)\n",
" return dict(iou=total_iou / num_samples)"
]
},
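{
"cell_type": "markdown",
"metadata": {},
"source": [
"On a toy batch the metric behaves as follows (illustrative). Because correct background pixels count toward the intersection but not the union, the score can be inflated relative to a per-class mean IoU."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy example of the simplified IoU metric (illustrative).\n",
"metric = IoU()\n",
"# One sample, two classes, 2x2 pixels; argmax over dim=1 is the prediction.\n",
"toy_preds = torch.tensor([[[[0.9, 0.1],\n",
"                            [0.2, 0.7]],\n",
"                           [[0.1, 0.9],\n",
"                            [0.8, 0.3]]]])  # shape (1, 2, 2, 2)\n",
"toy_labels = torch.tensor([[[0, 1],\n",
"                            [1, 1]]])       # shape (1, 2, 2)\n",
"metric.process(None, (toy_preds, dict(labels=toy_labels)))\n",
"# 3 matching pixels / 3 non-zero pixels -> iou = 1.0 despite one mismatch,\n",
"# because the background match counts in the intersection but not the union.\n",
"print(metric.compute_metrics(metric.results))"
]
},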
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Implementing a visualization hook is also important to facilitate easier comparison between predictions and labels."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from mmengine.hooks import Hook\n",
"import shutil\n",
"import cv2\n",
"import os.path as osp\n",
"\n",
"\n",
"class SegVisHook(Hook):\n",
"\n",
" def __init__(self, data_root, vis_num=1) -> None:\n",
" super().__init__()\n",
" self.vis_num = vis_num\n",
" self.palette = create_palette(osp.join(data_root, 'class_dict.csv'))\n",
"\n",
" def after_val_iter(self,\n",
" runner,\n",
" batch_idx: int,\n",
" data_batch=None,\n",
" outputs=None) -> None:\n",
" if batch_idx > self.vis_num:\n",
" return\n",
" preds, data_samples = outputs\n",
" img_paths = data_samples['img_path']\n",
" mask_paths = data_samples['mask_path']\n",
" _, C, H, W = preds.shape\n",
" preds = torch.argmax(preds, dim=1)\n",
" for idx, (pred, img_path,\n",
" mask_path) in enumerate(zip(preds, img_paths, mask_paths)):\n",
" pred_mask = np.zeros((H, W, 3), dtype=np.uint8)\n",
" runner.visualizer.set_image(pred_mask)\n",
" for color, class_id in self.palette.items():\n",
" runner.visualizer.draw_binary_masks(\n",
" pred == class_id,\n",
" colors=[color],\n",
" alphas=1.0,\n",
" )\n",
" # Convert RGB to BGR\n",
" pred_mask = runner.visualizer.get_image()[..., ::-1]\n",
" saved_dir = osp.join(runner.log_dir, 'vis_data', str(idx))\n",
" os.makedirs(saved_dir, exist_ok=True)\n",
"\n",
" shutil.copyfile(img_path,\n",
" osp.join(saved_dir, osp.basename(img_path)))\n",
" shutil.copyfile(mask_path,\n",
" osp.join(saved_dir, osp.basename(mask_path)))\n",
" cv2.imwrite(\n",
" osp.join(saved_dir, f'pred_{osp.basename(img_path)}'),\n",
" pred_mask)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finnaly, just train the model with Runner!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torch.optim import AdamW\n",
"from mmengine.optim import AmpOptimWrapper\n",
"from mmengine.runner import Runner\n",
"\n",
"\n",
"num_classes = 32 # Modify to actual number of categories.\n",
"\n",
"runner = Runner(\n",
" model=MMDeeplabV3(num_classes),\n",
" work_dir='./work_dir',\n",
" train_dataloader=train_dataloader,\n",
" optim_wrapper=dict(\n",
" type=AmpOptimWrapper, optimizer=dict(type=AdamW, lr=2e-4)),\n",
" train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=10),\n",
" val_dataloader=val_dataloader,\n",
" val_cfg=dict(),\n",
" val_evaluator=dict(type=IoU),\n",
" custom_hooks=[SegVisHook('data/CamVid')],\n",
" default_hooks=dict(checkpoint=dict(type='CheckpointHook', interval=1)),\n",
")\n",
"runner.train()"
]
},
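{
"cell_type": "markdown",
"metadata": {},
"source": [
"After training, you can run inference with the saved weights. The sketch below assumes the final checkpoint is `./work_dir/epoch_10.pth` and that, as usual for mmengine checkpoints, the weights are stored under the `state_dict` key."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal inference sketch (assumes ./work_dir/epoch_10.pth exists).\n",
"checkpoint = torch.load('./work_dir/epoch_10.pth', map_location='cpu')\n",
"model = MMDeeplabV3(num_classes)\n",
"model.load_state_dict(checkpoint['state_dict'])\n",
"model.eval()\n",
"\n",
"img, data_sample = valid_set[0]\n",
"with torch.no_grad():\n",
"    logits = model(img.unsqueeze(0), mode='tensor')\n",
"pred = torch.argmax(logits, dim=1)[0]  # (H, W) predicted class indices\n",
"print(pred.shape)"
]
},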
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finnaly, you can check the training results in the folder `./work_dir/{timestamp}/vis_data`.\n",
"\n",
"<table class=\"docutils\">\n",
"<thead>\n",
"<tr>\n",
" <th>image</th>\n",
" <th>prediction</th>\n",
" <th>label</th>\n",
"</tr>\n",
"<tr>\n",
" <th><img src=\"https://github.com/open-mmlab/mmengine/assets/57566630/de70c138-fb8e-402c-9497-574b01725b6c\" width=\"200\"></th>\n",
" <th><img src=\"https://github.com/open-mmlab/mmengine/assets/57566630/ea9221e7-48ca-4515-8815-56b5ff091f53\" width=\"200\"></th>\n",
" <th><img src=\"https://github.com/open-mmlab/mmengine/assets/57566630/dcb2324f-a2df-4e5c-a038-df896dde2471\" width=\"200\"></th>\n",
"</tr>\n",
"</thead>\n",
"</table>"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "py310torch20",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}