mirror of
https://github.com/YifanXu74/MQ-Det.git
synced 2025-06-03 15:03:07 +08:00
135 lines
4.2 KiB
Python
135 lines
4.2 KiB
Python
import os
|
|
|
|
import torch
|
|
import torch.utils.data
|
|
from PIL import Image
|
|
import sys
|
|
|
|
if sys.version_info[0] == 2:
|
|
import xml.etree.cElementTree as ET
|
|
else:
|
|
import xml.etree.ElementTree as ET
|
|
|
|
|
|
from maskrcnn_benchmark.structures.bounding_box import BoxList
|
|
|
|
|
|
class PascalVOCDataset(torch.utils.data.Dataset):
    """Pascal VOC detection dataset read from a VOC-style directory tree.

    The constructor builds these paths under ``data_dir``:

    * ``Annotations/<id>.xml``      -- per-image XML annotation
    * ``JPEGImages/<id>.jpg``       -- the image itself
    * ``ImageSets/Main/<split>.txt`` -- one image id per line

    Indexing returns ``(image, target, index)`` where ``target`` is a
    ``BoxList`` in ``xyxy`` mode carrying ``labels`` and ``difficult`` fields.
    """

    # Index in this tuple is the integer class label. The first entry is kept
    # byte-for-byte (including its trailing space) because it is used as a key
    # of ``class_to_ind`` and returned by ``map_class_id_to_class_name``.
    CLASSES = (
        "__background__ ",
        "aeroplane",
        "bicycle",
        "bird",
        "boat",
        "bottle",
        "bus",
        "car",
        "cat",
        "chair",
        "cow",
        "diningtable",
        "dog",
        "horse",
        "motorbike",
        "person",
        "pottedplant",
        "sheep",
        "sofa",
        "train",
        "tvmonitor",
    )

    def __init__(self, data_dir, split, use_difficult=False, transforms=None):
        """Index the image ids of one split.

        Args:
            data_dir: root directory of the VOC-style tree.
            split: name of the image-set file (without ``.txt``) under
                ``ImageSets/Main``.
            use_difficult: when false, objects flagged ``difficult`` in the
                annotation are dropped from the ground truth.
            transforms: optional callable ``(img, target) -> (img, target)``
                applied in ``__getitem__``.

        Raises:
            IOError/OSError: if the image-set file cannot be opened.
        """
        self.root = data_dir
        self.image_set = split
        self.keep_difficult = use_difficult
        self.transforms = transforms

        # ``%s`` placeholders are filled with an image id (or the split name).
        self._annopath = os.path.join(self.root, "Annotations", "%s.xml")
        self._imgpath = os.path.join(self.root, "JPEGImages", "%s.jpg")
        self._imgsetpath = os.path.join(self.root, "ImageSets", "Main", "%s.txt")

        with open(self._imgsetpath % self.image_set) as f:
            # Strip all surrounding whitespace and skip blank lines so that a
            # trailing empty line or stray spaces do not become bogus ids.
            stripped = (line.strip() for line in f)
            self.ids = [img_id for img_id in stripped if img_id]
        self.id_to_img_map = dict(enumerate(self.ids))

        cls = PascalVOCDataset.CLASSES
        self.class_to_ind = dict(zip(cls, range(len(cls))))

    def __getitem__(self, index):
        """Load image ``index`` and its ground truth.

        Returns:
            ``(img, target, index)``: the RGB image, the ground-truth boxes
            clipped to the image (empty boxes removed), and the index itself.
            ``img`` and ``target`` are passed through ``self.transforms``
            when it is set.
        """
        img_id = self.ids[index]
        img = Image.open(self._imgpath % img_id).convert("RGB")

        target = self.get_groundtruth(index)
        target = target.clip_to_image(remove_empty=True)

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target, index

    def __len__(self):
        """Return the number of image ids in the split."""
        return len(self.ids)

    def get_groundtruth(self, index):
        """Parse the annotation of image ``index`` into a ``BoxList``.

        The result is in ``xyxy`` mode with image size ``(width, height)`` and
        has the extra fields ``labels`` and ``difficult``.
        """
        img_id = self.ids[index]
        anno = ET.parse(self._annopath % img_id).getroot()
        anno = self._preprocess_annotation(anno)

        height, width = anno["im_info"]
        target = BoxList(anno["boxes"], (width, height), mode="xyxy")
        target.add_field("labels", anno["labels"])
        target.add_field("difficult", anno["difficult"])
        return target

    @staticmethod
    def _read_image_size(annotation):
        """Return ``(height, width)`` as ints from the ``<size>`` element."""
        size = annotation.find("size")
        return (int(size.find("height").text), int(size.find("width").text))

    def _preprocess_annotation(self, target):
        """Convert a parsed annotation root element into tensors.

        Args:
            target: root ``Element`` of one annotation XML file.

        Returns:
            dict with
            ``boxes``     -- float32 tensor of shape ``(N, 4)``, 0-based
                             ``xmin, ymin, xmax, ymax``;
            ``labels``    -- int64 tensor of shape ``(N,)``;
            ``difficult`` -- bool tensor of shape ``(N,)``;
            ``im_info``   -- ``(height, width)`` tuple of ints.
            ``N`` may be 0 when no object is kept.

        Raises:
            KeyError: if an object name is not in ``CLASSES``.
        """
        boxes = []
        gt_classes = []
        difficult_boxes = []
        TO_REMOVE = 1

        for obj in target.iter("object"):
            # A missing or empty <difficult> tag is treated as "not difficult"
            # instead of crashing on ``None.text``.
            flag = (obj.findtext("difficult") or "").strip()
            difficult = flag != "" and int(flag) == 1
            if not self.keep_difficult and difficult:
                continue
            name = obj.find("name").text.lower().strip()
            bb = obj.find("bndbox")
            # Make pixel indexes 0-based
            # Refer to "https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/datasets/pascal_voc.py#L208-L211"
            bndbox = tuple(
                int(bb.find(tag).text) - TO_REMOVE
                for tag in ("xmin", "ymin", "xmax", "ymax")
            )

            boxes.append(bndbox)
            gt_classes.append(self.class_to_ind[name])
            difficult_boxes.append(difficult)

        im_info = self._read_image_size(target)

        res = {
            # reshape keeps the tensor 2-D, i.e. (0, 4) rather than (0,), when
            # no object was kept; explicit dtypes keep labels/difficult from
            # defaulting to float32 for an empty list.
            "boxes": torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.tensor(gt_classes, dtype=torch.int64),
            "difficult": torch.tensor(difficult_boxes, dtype=torch.bool),
            "im_info": im_info,
        }
        return res

    def get_img_info(self, index):
        """Return ``{"height": h, "width": w}`` read from the annotation XML."""
        img_id = self.ids[index]
        anno = ET.parse(self._annopath % img_id).getroot()
        height, width = self._read_image_size(anno)
        return {"height": height, "width": width}

    def map_class_id_to_class_name(self, class_id):
        """Return the class name stored at position ``class_id`` of ``CLASSES``."""
        return PascalVOCDataset.CLASSES[class_id]
|