# Mirror of https://github.com/YifanXu74/MQ-Det.git (synced 2025-06-03 15:03:07 +08:00).
# 270 lines, 9.0 KiB, Python.
# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
|
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
import json
|
|
import os
|
|
import time
|
|
from collections import defaultdict
|
|
|
|
import pycocotools.mask as mask_utils
|
|
import torchvision
|
|
from PIL import Image
|
|
|
|
# from .coco import ConvertCocoPolysToMask, make_coco_transforms
|
|
from .modulated_coco import ConvertCocoPolysToMask
|
|
|
|
|
|
def _isArrayLike(obj):
|
|
return hasattr(obj, "__iter__") and hasattr(obj, "__len__")
|
|
|
|
|
|
class LVIS:
    """In-memory index over an annotation file with COCO/LVIS-style layout.

    The loaded JSON must be a dict with ``"annotations"``, ``"images"`` and
    ``"categories"`` lists. After indexing, the following lookups exist:

    * ``anns``: annotation id -> annotation dict
    * ``cats``: category id -> category dict
    * ``imgs``: image id -> image dict
    * ``img_ann_map``: image id -> list of annotation dicts on that image
    * ``cat_img_map``: category id -> list of image ids (one entry per
      annotation, so an image id may appear more than once)
    """

    def __init__(self, annotation_path=None):
        """Class for reading and visualizing annotations.

        Args:
            annotation_path (str): location of annotation file. When ``None``
                the instance starts empty and no index is built.
        """
        self.anns = {}
        self.cats = {}
        self.imgs = {}
        self.img_ann_map = defaultdict(list)
        self.cat_img_map = defaultdict(list)
        self.dataset = {}

        if annotation_path is not None:
            print("Loading annotations.")

            tic = time.time()
            self.dataset = self._load_json(annotation_path)
            print("Done (t={:0.2f}s)".format(time.time() - tic))

            # Kept as an assert so callers still see AssertionError; isinstance
            # additionally accepts dict subclasses.
            assert isinstance(self.dataset, dict), "Annotation file format {} not supported.".format(
                type(self.dataset)
            )
            self._create_index()

    def _load_json(self, path):
        """Read and parse the JSON file at ``path``."""
        # JSON interchange text is UTF-8; do not depend on the locale default.
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    def _create_index(self):
        """(Re)build every lookup table from ``self.dataset``."""
        print("Creating index.")

        self.img_ann_map = defaultdict(list)
        self.cat_img_map = defaultdict(list)

        self.anns = {}
        self.cats = {}
        self.imgs = {}

        # One pass over the annotations fills all annotation-derived maps.
        for ann in self.dataset["annotations"]:
            self.img_ann_map[ann["image_id"]].append(ann)
            self.anns[ann["id"]] = ann
            self.cat_img_map[ann["category_id"]].append(ann["image_id"])

        for img in self.dataset["images"]:
            self.imgs[img["id"]] = img

        for cat in self.dataset["categories"]:
            self.cats[cat["id"]] = cat

        print("Index created.")

    def get_ann_ids(self, img_ids=None, cat_ids=None, area_rng=None):
        """Get ann ids that satisfy given filter conditions.

        Args:
            img_ids (int array): get anns for given imgs; a single id is
                also accepted. ``None`` means all images.
            cat_ids (int array): get anns for given cats; a single id is
                also accepted. ``None`` means all categories.
            area_rng (float array): get anns for a given area range,
                e.g. [0, inf]. Both bounds are exclusive.

        Returns:
            ids (int array): integer array of ann ids
        """
        if img_ids is not None:
            img_ids = img_ids if _isArrayLike(img_ids) else [img_ids]
        if cat_ids is not None:
            cat_ids = cat_ids if _isArrayLike(cat_ids) else [cat_ids]

        if img_ids is not None:
            anns = []
            for img_id in img_ids:
                anns.extend(self.img_ann_map[img_id])
        else:
            anns = self.dataset["annotations"]

        # return early if no more filtering required
        if cat_ids is None and area_rng is None:
            return [_ann["id"] for _ann in anns]

        # ``cat_ids=None`` with an area range used to crash on ``set(None)``;
        # treat it as "no category restriction" instead.
        wanted_cats = None if cat_ids is None else set(cat_ids)

        if area_rng is None:
            area_rng = [0, float("inf")]
        min_area, max_area = area_rng[0], area_rng[1]

        return [
            _ann["id"]
            for _ann in anns
            if (wanted_cats is None or _ann["category_id"] in wanted_cats)
            and min_area < _ann["area"] < max_area
        ]

    def get_cat_ids(self):
        """Get all category ids.

        Returns:
            ids (int array): integer array of category ids
        """
        return list(self.cats.keys())

    def get_img_ids(self):
        """Get all img ids.

        Returns:
            ids (int array): integer array of image ids
        """
        return list(self.imgs.keys())

    def _load_helper(self, _dict, ids):
        """Look up ``ids`` in ``_dict`` and always return a list.

        ``ids`` may be ``None`` (all values), an array-like of keys, or a
        single key. Unknown keys raise ``KeyError``.
        """
        if ids is None:
            return list(_dict.values())
        if _isArrayLike(ids):
            return [_dict[key] for key in ids]
        return [_dict[ids]]

    def load_anns(self, ids=None):
        """Load anns with the specified ids. If ids=None load all anns.

        Args:
            ids (int array): integer array of annotation ids

        Returns:
            anns (dict array) : loaded annotation objects
        """
        return self._load_helper(self.anns, ids)

    def load_cats(self, ids=None):
        """Load categories with the specified ids. If ids=None load all
        categories.

        Args:
            ids (int array): integer array of category ids

        Returns:
            cats (dict array) : loaded category dicts
        """
        return self._load_helper(self.cats, ids)

    def load_imgs(self, ids=None):
        """Load images with the specified ids. If ids=None load all images.

        Args:
            ids (int array): integer array of image ids

        Returns:
            imgs (dict array) : loaded image dicts
        """
        return self._load_helper(self.imgs, ids)

    def download(self, save_dir, img_ids=None):
        """Download images from each image's ``coco_url``.

        Files that already exist under ``save_dir`` are not fetched again.

        Args:
            save_dir (str): dir to save downloaded images
            img_ids (int array): img ids of images to download; ``None``
                downloads every indexed image
        """
        from urllib.request import urlretrieve

        imgs = self.load_imgs(img_ids)

        # exist_ok avoids the check-then-create race of a separate exists().
        os.makedirs(save_dir, exist_ok=True)

        for img in imgs:
            file_name = os.path.join(save_dir, img["file_name"])
            if not os.path.exists(file_name):
                urlretrieve(img["coco_url"], file_name)

    def ann_to_rle(self, ann):
        """Convert annotation which can be polygons, uncompressed RLE to RLE.

        Args:
            ann (dict) : annotation object

        Returns:
            ann (rle)
        """
        img_data = self.imgs[ann["image_id"]]
        h, w = img_data["height"], img_data["width"]
        segm = ann["segmentation"]
        if isinstance(segm, list):
            # polygon -- a single object might consist of multiple parts
            # we merge all parts into one mask rle code
            rles = mask_utils.frPyObjects(segm, h, w)
            rle = mask_utils.merge(rles)
        elif isinstance(segm["counts"], list):
            # uncompressed RLE
            rle = mask_utils.frPyObjects(segm, h, w)
        else:
            # already RLE: returned as stored
            rle = ann["segmentation"]
        return rle

    def ann_to_mask(self, ann):
        """Convert annotation which can be polygons, uncompressed RLE, or RLE
        to binary mask.

        Args:
            ann (dict) : annotation object

        Returns:
            binary mask (numpy 2D array)
        """
        rle = self.ann_to_rle(ann)
        return mask_utils.decode(rle)
|
|
|
|
|
|
class LvisDetectionBase(torchvision.datasets.VisionDataset):
    """Vision dataset that pairs images on disk with LVIS annotations.

    Image ids are exposed in ascending order through ``self.ids``; the
    parsed annotation index is available as ``self.lvis``.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):
        """Load ``annFile`` and record the sorted list of image ids.

        Args:
            root: directory that image paths are joined onto
            annFile: path of the annotation JSON handed to ``LVIS``
            transform, target_transform, transforms: forwarded unchanged to
                ``torchvision.datasets.VisionDataset``
        """
        super().__init__(root, transforms, transform, target_transform)
        self.lvis = LVIS(annFile)
        self.ids = sorted(self.lvis.imgs)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is the list returned by
            ``LVIS.load_anns`` for the image, after ``self.transforms`` (if
            any) has been applied to the pair.
        """
        image_id = self.ids[index]
        annotation_ids = self.lvis.get_ann_ids(img_ids=image_id)
        target = self.lvis.load_anns(annotation_ids)

        # The on-disk location is the last two components of the image's
        # ``coco_url`` (presumably "<split>/<file name>" -- verify against data).
        image_info = self.lvis.load_imgs(image_id)[0]
        url_parts = image_info["coco_url"].split("/")
        path = "/".join(url_parts[-2:])

        img = Image.open(os.path.join(self.root, path)).convert("RGB")
        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Number of image ids in ``self.ids``."""
        return len(self.ids)
|
|
|
|
|
|
class LvisDetection(LvisDetectionBase):
    """LVIS detection dataset yielding ``(image, target, index)`` triples."""

    def __init__(self, img_folder, ann_file, transforms, return_masks=False, cumtom_ids=None, **kwargs):
        """Set up the dataset.

        Args:
            img_folder: image root directory, passed to the base class
            ann_file: annotation JSON path, passed to the base class and
                kept as ``self.ann_file``
            transforms: optional callable applied to the image only in
                ``__getitem__``
            return_masks: forwarded to ``ConvertCocoPolysToMask``
            cumtom_ids: when not ``None``, replaces the base class's sorted
                image-id list (parameter name kept as-is for compatibility)
            **kwargs: accepted and ignored
        """
        super().__init__(img_folder, ann_file)
        self.ann_file = ann_file
        self._transforms = transforms
        self.prepare = ConvertCocoPolysToMask(return_masks)
        if cumtom_ids is not None:
            self.ids = cumtom_ids

    def __getitem__(self, idx):
        """Return the prepared ``(image, target, idx)`` for position ``idx``."""
        img, annotations = super().__getitem__(idx)
        target = {"image_id": self.ids[idx], "annotations": annotations}
        img, target = self.prepare(img, target)
        # Only the image goes through the user transform; target is untouched.
        if self._transforms is not None:
            img = self._transforms(img)
        return img, target, idx

    def get_raw_image(self, idx):
        """Return the image from the base class's ``__getitem__``, discarding the target."""
        img, _ = super().__getitem__(idx)
        return img

    def categories(self):
        """Return a dict mapping category id to category name, in ascending id order."""
        id2cat = {cat["id"]: cat for cat in self.lvis.dataset["categories"]}
        return {cat_id: id2cat[cat_id]['name'] for cat_id in sorted(id2cat)}