import argparse
import csv
import os
import os.path as osp
import shutil

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from torch.optim import AdamW
from torchvision.datasets import VisionDataset
from torchvision.models.segmentation import deeplabv3_resnet50

from mmengine.dist import master_only
from mmengine.evaluator import BaseMetric
from mmengine.hooks import Hook
from mmengine.model import BaseModel
from mmengine.optim import AmpOptimWrapper
from mmengine.runner import Runner


def create_palette(csv_filepath):
    """Build a colour-to-class lookup from a palette CSV file.

    The file is read with ``csv.DictReader`` and must provide ``r``, ``g``
    and ``b`` columns holding integer values.

    Args:
        csv_filepath: Path of the CSV file to read.

    Returns:
        dict: Maps each ``(r, g, b)`` tuple to the zero-based index of the
        row it was read from. If a colour appears in several rows, the last
        row wins.
    """
    color_to_class = {}
    with open(csv_filepath, newline='') as csvfile:
        reader = csv.DictReader(csvfile)
        for idx, row in enumerate(reader):
            r, g, b = int(row['r']), int(row['g']), int(row['b'])
            color_to_class[(r, g, b)] = idx
    return color_to_class


class CamVid(VisionDataset):
    """Segmentation dataset pairing images with colour-coded masks.

    Images are listed from ``<root>/<img_folder>`` and masks from
    ``<root>/<mask_folder>``. Both listings are sorted, and the i-th image is
    paired with the i-th mask, so the two folders are expected to sort into
    the same order. The colour palette is read from ``<root>/class_dict.csv``.

    Args:
        root: Dataset root directory.
        img_folder: Sub-folder of ``root`` containing the input images.
        mask_folder: Sub-folder of ``root`` containing the RGB masks.
        transform: Optional callable applied to the RGB PIL image.
        target_transform: Optional callable applied to the ``int64`` label
            array produced from the mask.
    """

    def __init__(self,
                 root,
                 img_folder,
                 mask_folder,
                 transform=None,
                 target_transform=None):
        super().__init__(
            root, transform=transform, target_transform=target_transform)
        self.img_folder = img_folder
        self.mask_folder = mask_folder
        self.images = list(
            sorted(os.listdir(os.path.join(self.root, img_folder))))
        self.masks = list(
            sorted(os.listdir(os.path.join(self.root, mask_folder))))
        self.color_to_class = create_palette(
            os.path.join(self.root, 'class_dict.csv'))

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            tuple: ``(img, data_samples)`` where ``data_samples`` is a dict
            with keys ``labels`` (class-index map), ``img_path`` and
            ``mask_path``.
        """
        img_path = os.path.join(self.root, self.img_folder,
                                self.images[index])
        mask_path = os.path.join(self.root, self.mask_folder,
                                 self.masks[index])

        img = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('RGB')  # Convert to RGB

        if self.transform is not None:
            img = self.transform(img)

        # Convert the RGB values to class indices.
        # Widen before packing: the decoded mask is uint8, and multiplying a
        # uint8 array by 65536 / 256 raises OverflowError under NumPy 2
        # promotion rules (it only worked on NumPy 1.x via value-based
        # casting). int32 comfortably holds the 24-bit packed colour.
        mask = np.array(mask, dtype=np.int32)
        mask = mask[:, :, 0] * 65536 + mask[:, :, 1] * 256 + mask[:, :, 2]
        # Pixels whose colour is not in the palette keep the initial label 0.
        labels = np.zeros_like(mask, dtype=np.int64)
        for color, class_index in self.color_to_class.items():
            rgb = color[0] * 65536 + color[1] * 256 + color[2]
            labels[mask == rgb] = class_index

        if self.target_transform is not None:
            labels = self.target_transform(labels)

        data_samples = dict(
            labels=labels, img_path=img_path, mask_path=mask_path)
        return img, data_samples

    def __len__(self):
        """Return the number of images found in the image folder."""
        return len(self.images)


class MMDeeplabV3(BaseModel):
    """``BaseModel`` wrapper around torchvision's ``deeplabv3_resnet50``.

    Args:
        num_classes: Number of output classes passed to the backbone
            constructor.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.deeplab = deeplabv3_resnet50(num_classes=num_classes)

    def forward(self, imgs, data_samples=None, mode='tensor'):
        """Run the network and post-process according to ``mode``.

        Args:
            imgs: Batched input images fed to the network.
            data_samples: Dict carrying at least ``labels`` when
                ``mode == 'loss'``; passed through untouched in
                ``'predict'`` mode.
            mode: One of ``'loss'``, ``'predict'`` or ``'tensor'``.

        Returns:
            ``{'loss': ...}`` in ``'loss'`` mode, ``(logits, data_samples)``
            in ``'predict'`` mode, and the raw logits in ``'tensor'`` mode.

        Raises:
            ValueError: If ``mode`` is not one of the supported values.
        """
        x = self.deeplab(imgs)['out']
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, data_samples['labels'])}
        elif mode == 'predict':
            return x, data_samples
        elif mode == 'tensor':
            # Previously the default mode fell through and returned None.
            return x
        raise ValueError(
            f'Invalid mode "{mode}". Only supports loss, predict and tensor '
            'mode')


class IoU(BaseMetric):
    """Batch-weighted overlap ratio between predictions and labels."""

    def process(self, data_batch, data_samples):
        """Accumulate the ratio for one validation batch.

        ``data_samples`` is the ``(logits, data_samples)`` pair returned by
        the model in ``'predict'`` mode.
        """
        preds, labels = data_samples[0], data_samples[1]['labels']
        preds = torch.argmax(preds, dim=1)
        # NOTE(review): "intersect" counts every pixel where prediction and
        # label agree, while "union" counts pixels where either one is
        # non-zero. This is not a standard per-class IoU (the ratio can
        # exceed 1 and is undefined when both are all-zero) -- confirm that
        # this simplified score is intended before relying on it.
        intersect = (labels == preds).sum()
        union = (torch.logical_or(preds, labels)).sum()
        iou = (intersect / union).cpu()
        # Store the ratio pre-multiplied by the batch size so that
        # compute_metrics can form a sample-weighted mean.
        self.results.append(
            dict(batch_size=len(labels), iou=iou * len(labels)))

    def compute_metrics(self, results):
        """Return the sample-weighted mean of the per-batch ratios.

        Args:
            results: The processed results handed in by the evaluator. Use
                this argument rather than ``self.results``: in distributed
                evaluation it holds the entries gathered from all ranks,
                whereas ``self.results`` only holds the local ones.

        Returns:
            dict: ``{'iou': mean_ratio}``.
        """
        total_iou = sum(result['iou'] for result in results)
        num_samples = sum(result['batch_size'] for result in results)
        return dict(iou=total_iou / num_samples)


class SegVisHook(Hook):
    """Hook that dumps predicted masks next to inputs during validation.

    Args:
        data_root: Directory containing ``class_dict.csv``, used to colour
            the predicted class map.
        vis_num: Upper bound on the batch index that is visualised.
    """

    def __init__(self, data_root, vis_num=1) -> None:
        super().__init__()
        self.vis_num = vis_num
        self.palette = create_palette(osp.join(data_root, 'class_dict.csv'))

    @master_only
    def after_val_iter(self,
                       runner,
                       batch_idx: int,
                       data_batch=None,
                       outputs=None) -> None:
        """Render and save predictions for the first validation batches.

        For every sample the source image, the ground-truth mask and the
        colourised prediction are written to
        ``<runner.log_dir>/vis_data/<idx>``.
        """
        # NOTE(review): with ``>`` batches 0..vis_num (i.e. vis_num + 1
        # batches) are visualised -- confirm whether ``>=`` was intended.
        if batch_idx > self.vis_num:
            return
        preds, data_samples = outputs
        img_paths = data_samples['img_path']
        mask_paths = data_samples['mask_path']
        _, _, H, W = preds.shape
        preds = torch.argmax(preds, dim=1)
        for idx, (pred, img_path,
                  mask_path) in enumerate(zip(preds, img_paths, mask_paths)):
            # Start from a black canvas and paint one class at a time.
            pred_mask = np.zeros((H, W, 3), dtype=np.uint8)
            runner.visualizer.set_image(pred_mask)
            for color, class_id in self.palette.items():
                runner.visualizer.draw_binary_masks(
                    pred == class_id,
                    colors=[color],
                    alphas=1.0,
                )
            # Convert RGB to BGR
            pred_mask = runner.visualizer.get_image()[..., ::-1]
            # NOTE(review): ``idx`` is the position inside the batch, so
            # samples from different batches share the same directory.
            saved_dir = osp.join(runner.log_dir, 'vis_data', str(idx))
            os.makedirs(saved_dir, exist_ok=True)

            shutil.copyfile(img_path,
                            osp.join(saved_dir, osp.basename(img_path)))
            shutil.copyfile(mask_path,
                            osp.join(saved_dir, osp.basename(mask_path)))
            cv2.imwrite(
                osp.join(saved_dir, f'pred_{osp.basename(img_path)}'),
                pred_mask)


def parse_args():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace: Holds ``launcher`` and ``local_rank``.
    """
    parser = argparse.ArgumentParser(description='Distributed Training')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    # PyTorch >= 2.0 launchers pass ``--local-rank`` instead of
    # ``--local_rank``; accept both spellings (dest stays ``local_rank``).
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)

    args = parser.parse_args()
    return args


def main():
    """Build the datasets, model and runner, then start training.

    Data is read from ``data/CamVid`` and outputs go to ``./work_dir``.
    """
    args = parse_args()
    num_classes = 32  # Modify to actual number of categories.
    norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(**norm_cfg)])

    target_transform = transforms.Lambda(
        lambda x: torch.tensor(np.array(x), dtype=torch.long))

    train_set = CamVid(
        'data/CamVid',
        img_folder='train',
        mask_folder='train_labels',
        transform=transform,
        target_transform=target_transform)

    valid_set = CamVid(
        'data/CamVid',
        img_folder='val',
        mask_folder='val_labels',
        transform=transform,
        target_transform=target_transform)

    train_dataloader = dict(
        batch_size=3,
        dataset=train_set,
        sampler=dict(type='DefaultSampler', shuffle=True),
        collate_fn=dict(type='default_collate'))

    val_dataloader = dict(
        batch_size=3,
        dataset=valid_set,
        sampler=dict(type='DefaultSampler', shuffle=False),
        collate_fn=dict(type='default_collate'))

    runner = Runner(
        model=MMDeeplabV3(num_classes),
        work_dir='./work_dir',
        train_dataloader=train_dataloader,
        optim_wrapper=dict(
            type=AmpOptimWrapper, optimizer=dict(type=AdamW, lr=2e-4)),
        train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=10),
        val_dataloader=val_dataloader,
        val_cfg=dict(),
        val_evaluator=dict(type=IoU),
        launcher=args.launcher,
        custom_hooks=[SegVisHook('data/CamVid')],
        default_hooks=dict(checkpoint=dict(type='CheckpointHook', interval=1)),
    )
    runner.train()


if __name__ == '__main__':
    main()