mirror of https://github.com/RE-OWOD/RE-OWOD
191 lines
6.7 KiB
Python
191 lines
6.7 KiB
Python
# -*- coding: utf-8 -*-
|
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
|
|
import contextlib
|
|
import io
|
|
import numpy as np
|
|
import unittest
|
|
from collections import defaultdict
|
|
import torch
|
|
import tqdm
|
|
from fvcore.common.benchmark import benchmark
|
|
from fvcore.common.file_io import PathManager
|
|
from pycocotools.coco import COCO
|
|
from tabulate import tabulate
|
|
from torch.nn import functional as F
|
|
|
|
from detectron2.data import MetadataCatalog
|
|
from detectron2.layers.mask_ops import (
|
|
pad_masks,
|
|
paste_mask_in_image_old,
|
|
paste_masks_in_image,
|
|
scale_boxes,
|
|
)
|
|
from detectron2.structures import BitMasks, Boxes, BoxMode, PolygonMasks
|
|
from detectron2.structures.masks import polygons_to_bitmask
|
|
|
|
|
|
def iou_between_full_image_bit_masks(a, b):
|
|
intersect = (a & b).sum()
|
|
union = (a | b).sum()
|
|
return intersect / union
|
|
|
|
|
|
def rasterize_polygons_with_grid_sample(full_image_bit_mask, box, mask_size, threshold=0.5):
|
|
x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
|
|
|
|
img_h, img_w = full_image_bit_mask.shape
|
|
|
|
mask_y = np.arange(0.0, mask_size) + 0.5 # mask y sample coords in [0.5, mask_size - 0.5]
|
|
mask_x = np.arange(0.0, mask_size) + 0.5 # mask x sample coords in [0.5, mask_size - 0.5]
|
|
mask_y = mask_y / mask_size * (y1 - y0) + y0
|
|
mask_x = mask_x / mask_size * (x1 - x0) + x0
|
|
|
|
mask_x = (mask_x - 0.5) / (img_w - 1) * 2 + -1
|
|
mask_y = (mask_y - 0.5) / (img_h - 1) * 2 + -1
|
|
gy, gx = torch.meshgrid(torch.from_numpy(mask_y), torch.from_numpy(mask_x))
|
|
ind = torch.stack([gx, gy], dim=-1).to(dtype=torch.float32)
|
|
|
|
full_image_bit_mask = torch.from_numpy(full_image_bit_mask)
|
|
mask = F.grid_sample(
|
|
full_image_bit_mask[None, None, :, :].to(dtype=torch.float32),
|
|
ind[None, :, :, :],
|
|
align_corners=True,
|
|
)
|
|
|
|
return mask[0, 0] >= threshold
|
|
|
|
|
|
class TestMaskCropPaste(unittest.TestCase):
|
|
def setUp(self):
|
|
json_file = MetadataCatalog.get("coco_2017_val_100").json_file
|
|
if not PathManager.isfile(json_file):
|
|
raise unittest.SkipTest("{} not found".format(json_file))
|
|
with contextlib.redirect_stdout(io.StringIO()):
|
|
json_file = PathManager.get_local_path(json_file)
|
|
self.coco = COCO(json_file)
|
|
|
|
def test_crop_paste_consistency(self):
|
|
"""
|
|
rasterize_polygons_within_box (used in training)
|
|
and
|
|
paste_masks_in_image (used in inference)
|
|
should be inverse operations to each other.
|
|
|
|
This function runs several implementation of the above two operations and prints
|
|
the reconstruction error.
|
|
"""
|
|
|
|
anns = self.coco.loadAnns(self.coco.getAnnIds(iscrowd=False)) # avoid crowd annotations
|
|
|
|
selected_anns = anns[:100]
|
|
|
|
ious = []
|
|
for ann in tqdm.tqdm(selected_anns):
|
|
results = self.process_annotation(ann)
|
|
ious.append([k[2] for k in results])
|
|
|
|
ious = np.array(ious)
|
|
mean_ious = ious.mean(axis=0)
|
|
table = []
|
|
res_dic = defaultdict(dict)
|
|
for row, iou in zip(results, mean_ious):
|
|
table.append((row[0], row[1], iou))
|
|
res_dic[row[0]][row[1]] = iou
|
|
print(tabulate(table, headers=["rasterize", "paste", "iou"], tablefmt="simple"))
|
|
# assert that the reconstruction is good:
|
|
self.assertTrue(res_dic["polygon"]["aligned"] > 0.94)
|
|
self.assertTrue(res_dic["roialign"]["aligned"] > 0.95)
|
|
|
|
def process_annotation(self, ann, mask_side_len=28):
|
|
# Parse annotation data
|
|
img_info = self.coco.loadImgs(ids=[ann["image_id"]])[0]
|
|
height, width = img_info["height"], img_info["width"]
|
|
gt_polygons = [np.array(p, dtype=np.float64) for p in ann["segmentation"]]
|
|
gt_bbox = BoxMode.convert(ann["bbox"], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)
|
|
gt_bit_mask = polygons_to_bitmask(gt_polygons, height, width)
|
|
|
|
# Run rasterize ..
|
|
torch_gt_bbox = torch.tensor(gt_bbox).to(dtype=torch.float32).reshape(-1, 4)
|
|
box_bitmasks = {
|
|
"polygon": PolygonMasks([gt_polygons]).crop_and_resize(torch_gt_bbox, mask_side_len)[0],
|
|
"gridsample": rasterize_polygons_with_grid_sample(gt_bit_mask, gt_bbox, mask_side_len),
|
|
"roialign": BitMasks(torch.from_numpy(gt_bit_mask[None, :, :])).crop_and_resize(
|
|
torch_gt_bbox, mask_side_len
|
|
)[0],
|
|
}
|
|
|
|
# Run paste ..
|
|
results = defaultdict(dict)
|
|
for k, box_bitmask in box_bitmasks.items():
|
|
padded_bitmask, scale = pad_masks(box_bitmask[None, :, :], 1)
|
|
scaled_boxes = scale_boxes(torch_gt_bbox, scale)
|
|
|
|
r = results[k]
|
|
r["old"] = paste_mask_in_image_old(
|
|
padded_bitmask[0], scaled_boxes[0], height, width, threshold=0.5
|
|
)
|
|
r["aligned"] = paste_masks_in_image(
|
|
box_bitmask[None, :, :], Boxes(torch_gt_bbox), (height, width)
|
|
)[0]
|
|
|
|
table = []
|
|
for rasterize_method, r in results.items():
|
|
for paste_method, mask in r.items():
|
|
mask = np.asarray(mask)
|
|
iou = iou_between_full_image_bit_masks(gt_bit_mask.astype("uint8"), mask)
|
|
table.append((rasterize_method, paste_method, iou))
|
|
return table
|
|
|
|
def test_polygon_area(self):
|
|
# Draw polygon boxes
|
|
for d in [5.0, 10.0, 1000.0]:
|
|
polygon = PolygonMasks([[[0, 0, 0, d, d, d, d, 0]]])
|
|
area = polygon.area()[0]
|
|
target = d ** 2
|
|
self.assertEqual(area, target)
|
|
|
|
# Draw polygon triangles
|
|
for d in [5.0, 10.0, 1000.0]:
|
|
polygon = PolygonMasks([[[0, 0, 0, d, d, d]]])
|
|
area = polygon.area()[0]
|
|
target = d ** 2 / 2
|
|
self.assertEqual(area, target)
|
|
|
|
|
|
def benchmark_paste():
|
|
S = 800
|
|
H, W = image_shape = (S, S)
|
|
N = 64
|
|
torch.manual_seed(42)
|
|
masks = torch.rand(N, 28, 28)
|
|
|
|
center = torch.rand(N, 2) * 600 + 100
|
|
wh = torch.clamp(torch.randn(N, 2) * 40 + 200, min=50)
|
|
x0y0 = torch.clamp(center - wh * 0.5, min=0.0)
|
|
x1y1 = torch.clamp(center + wh * 0.5, max=S)
|
|
boxes = Boxes(torch.cat([x0y0, x1y1], axis=1))
|
|
|
|
def func(device, n=3):
|
|
m = masks.to(device=device)
|
|
b = boxes.to(device=device)
|
|
|
|
def bench():
|
|
for _ in range(n):
|
|
paste_masks_in_image(m, b, image_shape)
|
|
if device.type == "cuda":
|
|
torch.cuda.synchronize()
|
|
|
|
return bench
|
|
|
|
specs = [{"device": torch.device("cpu"), "n": 3}]
|
|
if torch.cuda.is_available():
|
|
specs.append({"device": torch.device("cuda"), "n": 3})
|
|
|
|
benchmark(func, "paste_masks", specs, num_iters=10, warmup_iters=2)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
benchmark_paste()
|
|
unittest.main()
|