Add files via upload

main
RE-OWOD 2022-01-04 17:27:22 +08:00 committed by GitHub
parent 39a272f3d3
commit fc1065333b
6 changed files with 653 additions and 0 deletions

@@ -0,0 +1,103 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import json
import numpy as np
import os
import tempfile
import unittest
import pycocotools.mask as mask_util
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets.coco import convert_to_coco_dict, load_coco_json
from detectron2.structures import BoxMode
def make_mask():
"""
    Makes a donut-shaped binary mask.
"""
H = 100
W = 100
mask = np.zeros([H, W], dtype=np.uint8)
for x in range(W):
for y in range(H):
d = np.linalg.norm(np.array([W, H]) / 2 - np.array([x, y]))
if d > 10 and d < 20:
mask[y, x] = 1
return mask
def uncompressed_rle(mask):
l = mask.flatten(order="F").tolist()
counts = []
p = False
cnt = 0
for i in l:
if i == p:
cnt += 1
else:
counts.append(cnt)
p = i
cnt = 1
counts.append(cnt)
return {"counts": counts, "size": [mask.shape[0], mask.shape[1]]}
def make_dataset_dicts(mask, compressed: bool = True):
"""
Returns a list of dicts that represents a single COCO data point for
object detection. The single instance given by `mask` is represented by
RLE, either compressed or uncompressed.
"""
record = {}
record["file_name"] = "test"
record["image_id"] = 0
record["height"] = mask.shape[0]
record["width"] = mask.shape[1]
y, x = np.nonzero(mask)
if compressed:
segmentation = mask_util.encode(np.asarray(mask, order="F"))
else:
segmentation = uncompressed_rle(mask)
min_x = np.min(x)
max_x = np.max(x)
min_y = np.min(y)
max_y = np.max(y)
obj = {
"bbox": [min_x, min_y, max_x, max_y],
"bbox_mode": BoxMode.XYXY_ABS,
"category_id": 0,
"iscrowd": 0,
"segmentation": segmentation,
}
record["annotations"] = [obj]
return [record]
class TestRLEToJson(unittest.TestCase):
def test(self):
# Make a dummy dataset.
mask = make_mask()
DatasetCatalog.register("test_dataset", lambda: make_dataset_dicts(mask))
MetadataCatalog.get("test_dataset").set(thing_classes=["test_label"])
# Dump to json.
json_dict = convert_to_coco_dict("test_dataset")
with tempfile.TemporaryDirectory() as tmpdir:
json_file_name = os.path.join(tmpdir, "test.json")
with open(json_file_name, "w") as f:
json.dump(json_dict, f)
# Load from json.
dicts = load_coco_json(json_file_name, "")
# Check the loaded mask matches the original.
anno = dicts[0]["annotations"][0]
loaded_mask = mask_util.decode(anno["segmentation"])
self.assertTrue(np.array_equal(loaded_mask, mask))
def test_uncompressed_RLE(self):
mask = make_mask()
rle = mask_util.encode(np.asarray(mask, order="F"))
uncompressed = uncompressed_rle(mask)
compressed = mask_util.frPyObjects(uncompressed, *rle["size"])
self.assertEqual(rle, compressed)

@@ -0,0 +1,120 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import contextlib
import copy
import io
import json
import numpy as np
import os
import tempfile
import unittest
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from detectron2.evaluation.fast_eval_api import COCOeval_opt
class TestCOCOeval(unittest.TestCase):
def test(self):
# A small set of images/categories from COCO val
# fmt: off
detections = [{"image_id": 139, "category_id": 1, "bbox": [417.3332824707031, 159.27003479003906, 47.66064453125, 143.00193786621094], "score": 0.9949821829795837, "segmentation": {"size": [426, 640], "counts": "Tc`52W=3N0N4aNN^E7]:4XE1g:8kDMT;U100000001O1gE[Nk8h1dFiNY9Z1aFkN]9g2J3NdN`FlN`9S1cFRN07]9g1bFoM6;X9c1cFoM=8R9g1bFQN>3U9Y30O01OO1O001N2O1N1O4L4L5UNoE3V:CVF6Q:@YF9l9@ZF<k9[O`F=];HYnX2"}}, {"image_id": 139, "category_id": 1, "bbox": [383.5909118652344, 172.0777587890625, 17.959075927734375, 36.94813537597656], "score": 0.7685421705245972, "segmentation": {"size": [426, 640], "counts": "lZP5m0Z<300O100O100000001O00]OlC0T<OnCOT<OnCNX<JnC2bQT3"}}, {"image_id": 139, "category_id": 1, "bbox": [457.8359069824219, 158.88027954101562, 9.89764404296875, 8.771820068359375], "score": 0.07092753797769547, "segmentation": {"size": [426, 640], "counts": "bSo54T=2N2O1001O006ImiW2"}}] # noqa
gt_annotations = {"categories": [{"supercategory": "person", "id": 1, "name": "person"}, {"supercategory": "furniture", "id": 65, "name": "bed"}], "images": [{"license": 4, "file_name": "000000000285.jpg", "coco_url": "http://images.cocodataset.org/val2017/000000000285.jpg", "height": 640, "width": 586, "date_captured": "2013-11-18 13:09:47", "flickr_url": "http://farm8.staticflickr.com/7434/9138147604_c6225224b8_z.jpg", "id": 285}, {"license": 2, "file_name": "000000000139.jpg", "coco_url": "http://images.cocodataset.org/val2017/000000000139.jpg", "height": 426, "width": 640, "date_captured": "2013-11-21 01:34:01", "flickr_url": "http://farm9.staticflickr.com/8035/8024364858_9c41dc1666_z.jpg", "id": 139}], "annotations": [{"segmentation": [[428.19, 219.47, 430.94, 209.57, 430.39, 210.12, 421.32, 216.17, 412.8, 217.27, 413.9, 214.24, 422.42, 211.22, 429.29, 201.6, 430.67, 181.8, 430.12, 175.2, 427.09, 168.06, 426.27, 164.21, 430.94, 159.26, 440.29, 157.61, 446.06, 163.93, 448.53, 168.06, 448.53, 173.01, 449.08, 174.93, 454.03, 185.1, 455.41, 188.4, 458.43, 195.0, 460.08, 210.94, 462.28, 226.61, 460.91, 233.76, 454.31, 234.04, 460.08, 256.85, 462.56, 268.13, 465.58, 290.67, 465.85, 293.14, 463.38, 295.62, 452.66, 295.34, 448.26, 294.52, 443.59, 282.7, 446.06, 235.14, 446.34, 230.19, 438.09, 232.39, 438.09, 221.67, 434.24, 221.12, 427.09, 219.74]], "area": 2913.1103999999987, "iscrowd": 0, "image_id": 139, "bbox": [412.8, 157.61, 53.05, 138.01], "category_id": 1, "id": 230831}, {"segmentation": [[384.98, 206.58, 384.43, 199.98, 385.25, 193.66, 385.25, 190.08, 387.18, 185.13, 387.18, 182.93, 386.08, 181.01, 385.25, 178.81, 385.25, 175.79, 388.0, 172.76, 394.88, 172.21, 398.72, 173.31, 399.27, 176.06, 399.55, 183.48, 397.9, 185.68, 395.15, 188.98, 396.8, 193.38, 398.45, 194.48, 399.0, 205.75, 395.43, 207.95, 388.83, 206.03]], "area": 435.1449499999997, "iscrowd": 0, "image_id": 139, "bbox": [384.43, 172.21, 15.12, 35.74], "category_id": 1, "id": 233201}]} # noqa
# fmt: on
# Test a small dataset for typical COCO format
experiments = {"full": (detections, gt_annotations, {})}
# Test what happens if the list of detections or ground truth annotations is empty
experiments["empty_dt"] = ([], gt_annotations, {})
gt = copy.deepcopy(gt_annotations)
gt["annotations"] = []
experiments["empty_gt"] = (detections, gt, {})
# Test changing parameter settings
experiments["no_categories"] = (detections, gt_annotations, {"useCats": 0})
experiments["no_ious"] = (detections, gt_annotations, {"iouThrs": []})
experiments["no_rec_thrs"] = (detections, gt_annotations, {"recThrs": []})
experiments["no_max_dets"] = (detections, gt_annotations, {"maxDets": []})
experiments["one_max_det"] = (detections, gt_annotations, {"maxDets": [1]})
experiments["no_area"] = (detections, gt_annotations, {"areaRng": [], "areaRngLbl": []})
# Test what happens if one omits different fields from the annotation structure
annotation_fields = [
"id",
"image_id",
"category_id",
"score",
"area",
"iscrowd",
"ignore",
"bbox",
"segmentation",
]
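        # Each field is omitted from only one side at a time: "omit_gt_*" keeps the
        # original detections, while "omit_dt_*" keeps the original ground truth.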
for a in annotation_fields:
gt = copy.deepcopy(gt_annotations)
for g in gt["annotations"]:
if a in g:
del g[a]
dt = copy.deepcopy(detections)
for d in dt:
if a in d:
del d[a]
experiments["omit_gt_" + a] = (detections, gt, {})
experiments["omit_dt_" + a] = (dt, gt_annotations, {})
# Compare precision/recall for original COCO PythonAPI to custom optimized one
for name, (dt, gt, params) in experiments.items():
# Dump to json.
try:
with tempfile.TemporaryDirectory() as tmpdir:
json_file_name = os.path.join(tmpdir, "gt_" + name + ".json")
with open(json_file_name, "w") as f:
json.dump(gt, f)
with contextlib.redirect_stdout(io.StringIO()):
coco_api = COCO(json_file_name)
except Exception:
pass
for iou_type in ["bbox", "segm", "keypoints"]:
# Run original COCOeval PythonAPI
api_exception = None
try:
with contextlib.redirect_stdout(io.StringIO()):
coco_dt = coco_api.loadRes(dt)
coco_eval = COCOeval(coco_api, coco_dt, iou_type)
for p, v in params.items():
setattr(coco_eval.params, p, v)
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
except Exception as ex:
api_exception = ex
# Run optimized COCOeval_opt API
opt_exception = None
try:
with contextlib.redirect_stdout(io.StringIO()):
coco_dt = coco_api.loadRes(dt)
coco_eval_opt = COCOeval_opt(coco_api, coco_dt, iou_type)
for p, v in params.items():
setattr(coco_eval_opt.params, p, v)
coco_eval_opt.evaluate()
coco_eval_opt.accumulate()
coco_eval_opt.summarize()
except Exception as ex:
opt_exception = ex
if api_exception is not None and opt_exception is not None:
# Original API and optimized API should throw the same exception if annotation
# format is bad
api_error = "" if api_exception is None else type(api_exception).__name__
opt_error = "" if opt_exception is None else type(opt_exception).__name__
msg = "%s: comparing COCO APIs, '%s' != '%s'" % (name, api_error, opt_error)
self.assertTrue(api_error == opt_error, msg=msg)
else:
# Original API and optimized API should produce the same precision/recalls
for k in ["precision", "recall"]:
diff = np.abs(coco_eval.eval[k] - coco_eval_opt.eval[k])
abs_diff = np.max(diff) if diff.size > 0 else 0.0
msg = "%s: comparing COCO APIs, %s differs by %f" % (name, k, abs_diff)
self.assertTrue(abs_diff < 1e-4, msg=msg)

@@ -0,0 +1,156 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import copy
import numpy as np
import os
import unittest
import pycocotools.mask as mask_util
from fvcore.common.file_io import PathManager
from detectron2.data import MetadataCatalog, detection_utils
from detectron2.data import transforms as T
from detectron2.structures import BitMasks, BoxMode
class TestTransformAnnotations(unittest.TestCase):
def test_transform_simple_annotation(self):
transforms = T.TransformList([T.HFlipTransform(400)])
anno = {
"bbox": np.asarray([10, 10, 200, 300]),
"bbox_mode": BoxMode.XYXY_ABS,
"category_id": 3,
"segmentation": [[10, 10, 100, 100, 100, 10], [150, 150, 200, 150, 200, 200]],
}
output = detection_utils.transform_instance_annotations(anno, transforms, (400, 400))
self.assertTrue(np.allclose(output["bbox"], [200, 10, 390, 300]))
self.assertEqual(len(output["segmentation"]), len(anno["segmentation"]))
self.assertTrue(np.allclose(output["segmentation"][0], [390, 10, 300, 100, 300, 10]))
detection_utils.annotations_to_instances([output, output], (400, 400))
def test_flip_keypoints(self):
transforms = T.TransformList([T.HFlipTransform(400)])
anno = {
"bbox": np.asarray([10, 10, 200, 300]),
"bbox_mode": BoxMode.XYXY_ABS,
"keypoints": np.random.rand(17, 3) * 50 + 15,
}
output = detection_utils.transform_instance_annotations(
copy.deepcopy(anno),
transforms,
(400, 400),
keypoint_hflip_indices=detection_utils.create_keypoint_hflip_indices(
["keypoints_coco_2017_train"]
),
)
        # The first keypoint is the nose
self.assertTrue(np.allclose(output["keypoints"][0, 0], 400 - anno["keypoints"][0, 0]))
# The last 16 keypoints are 8 left-right pairs
self.assertTrue(
np.allclose(
output["keypoints"][1:, 0].reshape(-1, 2)[:, ::-1],
400 - anno["keypoints"][1:, 0].reshape(-1, 2),
)
)
self.assertTrue(
np.allclose(
output["keypoints"][1:, 1:].reshape(-1, 2, 2)[:, ::-1, :],
anno["keypoints"][1:, 1:].reshape(-1, 2, 2),
)
)
def test_crop(self):
transforms = T.TransformList([T.CropTransform(300, 300, 10, 10)])
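        # CropTransform(x0, y0, w, h): a 10x10 window whose top-left corner is at
        # (300, 300). Its x-range [300, 310] does not intersect the box below, so
        # the transformed box collapses to zero width.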
keypoints = np.random.rand(17, 3) * 50 + 15
keypoints[:, 2] = 2
anno = {
"bbox": np.asarray([10, 10, 200, 400]),
"bbox_mode": BoxMode.XYXY_ABS,
"keypoints": keypoints,
}
output = detection_utils.transform_instance_annotations(
copy.deepcopy(anno), transforms, (10, 10)
)
# box is shifted and cropped
self.assertTrue((output["bbox"] == np.asarray([0, 0, 0, 10])).all())
# keypoints are no longer visible
self.assertTrue((output["keypoints"][:, 2] == 0).all())
def test_transform_RLE(self):
transforms = T.TransformList([T.HFlipTransform(400)])
mask = np.zeros((300, 400), order="F").astype("uint8")
mask[:, :200] = 1
anno = {
"bbox": np.asarray([10, 10, 200, 300]),
"bbox_mode": BoxMode.XYXY_ABS,
"segmentation": mask_util.encode(mask[:, :, None])[0],
"category_id": 3,
}
output = detection_utils.transform_instance_annotations(
copy.deepcopy(anno), transforms, (300, 400)
)
mask = output["segmentation"]
self.assertTrue((mask[:, 200:] == 1).all())
self.assertTrue((mask[:, :200] == 0).all())
inst = detection_utils.annotations_to_instances(
[output, output], (400, 400), mask_format="bitmask"
)
self.assertTrue(isinstance(inst.gt_masks, BitMasks))
def test_transform_RLE_resize(self):
transforms = T.TransformList(
[T.HFlipTransform(400), T.ScaleTransform(300, 400, 400, 400, "bilinear")]
)
mask = np.zeros((300, 400), order="F").astype("uint8")
mask[:, :200] = 1
anno = {
"bbox": np.asarray([10, 10, 200, 300]),
"bbox_mode": BoxMode.XYXY_ABS,
"segmentation": mask_util.encode(mask[:, :, None])[0],
"category_id": 3,
}
output = detection_utils.transform_instance_annotations(
copy.deepcopy(anno), transforms, (400, 400)
)
inst = detection_utils.annotations_to_instances(
[output, output], (400, 400), mask_format="bitmask"
)
self.assertTrue(isinstance(inst.gt_masks, BitMasks))
def test_gen_crop(self):
instance = {"bbox": [10, 10, 100, 100], "bbox_mode": BoxMode.XYXY_ABS}
t = detection_utils.gen_crop_transform_with_instance((10, 10), (150, 150), instance)
# the box center must fall into the cropped region
self.assertTrue(t.x0 <= 55 <= t.x0 + t.w)
def test_gen_crop_outside_boxes(self):
instance = {"bbox": [10, 10, 100, 100], "bbox_mode": BoxMode.XYXY_ABS}
with self.assertRaises(AssertionError):
detection_utils.gen_crop_transform_with_instance((10, 10), (15, 15), instance)
def test_read_sem_seg(self):
cityscapes_dir = MetadataCatalog.get("cityscapes_fine_sem_seg_val").gt_dir
sem_seg_gt_path = os.path.join(
cityscapes_dir, "frankfurt", "frankfurt_000001_083852_gtFine_labelIds.png"
)
if not PathManager.exists(sem_seg_gt_path):
raise unittest.SkipTest(
"Semantic segmentation ground truth {} not found.".format(sem_seg_gt_path)
)
sem_seg = detection_utils.read_image(sem_seg_gt_path, "L")
self.assertEqual(sem_seg.ndim, 3)
self.assertEqual(sem_seg.shape[2], 1)
self.assertEqual(sem_seg.dtype, np.uint8)
self.assertEqual(sem_seg.max(), 32)
self.assertEqual(sem_seg.min(), 1)
if __name__ == "__main__":
unittest.main()

@@ -0,0 +1,71 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import numpy as np
import unittest
from detectron2.data.transforms.transform import RotationTransform
class TestRotationTransform(unittest.TestCase):
def assertEqualsArrays(self, a1, a2):
self.assertTrue(np.allclose(a1, a2))
def randomData(self, h=5, w=5):
image = np.random.rand(h, w)
coords = np.array([[i, j] for j in range(h + 1) for i in range(w + 1)], dtype=float)
return image, coords, h, w
def test180(self):
image, coords, h, w = self.randomData(6, 6)
rot = RotationTransform(h, w, 180, expand=False, center=None)
self.assertEqualsArrays(rot.apply_image(image), image[::-1, ::-1])
rotated_coords = [[w - c[0], h - c[1]] for c in coords]
self.assertEqualsArrays(rot.apply_coords(coords), rotated_coords)
def test45_coords(self):
_, coords, h, w = self.randomData(4, 6)
rot = RotationTransform(h, w, 45, expand=False, center=None)
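        # Sketch of where the expected values come from: RotationTransform builds
        # its affine matrix via cv2.getRotationMatrix2D about the image center
        # (w/2, h/2), so with cos(45°) = sin(45°) = 1/sqrt(2):
        #   x' = (x + y - (w + h) / 2) / sqrt(2) + w / 2
        #   y' = (y - x + (w - h) / 2) / sqrt(2) + h / 2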
rotated_coords = [
[(x + y - (h + w) / 2) / np.sqrt(2) + w / 2, h / 2 + (y + (w - h) / 2 - x) / np.sqrt(2)]
for (x, y) in coords
]
self.assertEqualsArrays(rot.apply_coords(coords), rotated_coords)
def test90(self):
image, coords, h, w = self.randomData()
rot = RotationTransform(h, w, 90, expand=False, center=None)
self.assertEqualsArrays(rot.apply_image(image), image.T[::-1])
rotated_coords = [[c[1], w - c[0]] for c in coords]
self.assertEqualsArrays(rot.apply_coords(coords), rotated_coords)
def test90_expand(self): # non-square image
image, coords, h, w = self.randomData(h=5, w=8)
rot = RotationTransform(h, w, 90, expand=True, center=None)
self.assertEqualsArrays(rot.apply_image(image), image.T[::-1])
rotated_coords = [[c[1], w - c[0]] for c in coords]
self.assertEqualsArrays(rot.apply_coords(coords), rotated_coords)
def test_center_expand(self):
# center has no effect if expand=True because it only affects shifting
image, coords, h, w = self.randomData(h=5, w=8)
angle = np.random.randint(360)
rot1 = RotationTransform(h, w, angle, expand=True, center=None)
rot2 = RotationTransform(h, w, angle, expand=True, center=(0, 0))
rot3 = RotationTransform(h, w, angle, expand=True, center=(h, w))
rot4 = RotationTransform(h, w, angle, expand=True, center=(2, 5))
for r1 in [rot1, rot2, rot3, rot4]:
for r2 in [rot1, rot2, rot3, rot4]:
self.assertEqualsArrays(r1.apply_image(image), r2.apply_image(image))
self.assertEqualsArrays(r1.apply_coords(coords), r2.apply_coords(coords))
def test_inverse_transform(self):
image, coords, h, w = self.randomData(h=5, w=8)
rot = RotationTransform(h, w, 90, expand=True, center=None)
rot_image = rot.apply_image(image)
self.assertEqualsArrays(rot.inverse().apply_image(rot_image), image)
rot = RotationTransform(h, w, 65, expand=True, center=None)
rotated_coords = rot.apply_coords(coords)
self.assertEqualsArrays(rot.inverse().apply_coords(rotated_coords), coords)
if __name__ == "__main__":
unittest.main()

@@ -0,0 +1,23 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import unittest
from torch.utils.data.sampler import SequentialSampler
from detectron2.data.samplers import GroupedBatchSampler
class TestGroupedBatchSampler(unittest.TestCase):
def test_missing_group_id(self):
sampler = SequentialSampler(list(range(100)))
group_ids = [1] * 100
samples = GroupedBatchSampler(sampler, group_ids, 2)
for mini_batch in samples:
self.assertEqual(len(mini_batch), 2)
def test_groups(self):
sampler = SequentialSampler(list(range(100)))
group_ids = [1, 0] * 50
samples = GroupedBatchSampler(sampler, group_ids, 2)
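        # Even positions carry group id 1 and odd positions group id 0, so each
        # batch pairs two indices of equal parity and their sum is always even.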
for mini_batch in samples:
self.assertEqual((mini_batch[0] + mini_batch[1]) % 2, 0)

@@ -0,0 +1,180 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import numpy as np
import unittest
from unittest import mock
from PIL import Image, ImageOps
from detectron2.config import get_cfg
from detectron2.data import detection_utils
from detectron2.data import transforms as T
from detectron2.utils.logger import setup_logger
logger = logging.getLogger(__name__)
class TestTransforms(unittest.TestCase):
def setUp(self):
setup_logger()
def test_apply_rotated_boxes(self):
np.random.seed(125)
cfg = get_cfg()
is_train = True
augs = detection_utils.build_augmentation(cfg, is_train)
image = np.random.rand(200, 300)
image, transforms = T.apply_augmentations(augs, image)
image_shape = image.shape[:2] # h, w
assert image_shape == (800, 1200)
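        # With the default cfg, build_augmentation applies ResizeShortestEdge with
        # a target of 800, so the 200x300 input becomes 800x1200 (RandomFlip does
        # not change the shape).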
annotation = {"bbox": [179, 97, 62, 40, -56]}
boxes = np.array([annotation["bbox"]], dtype=np.float64) # boxes.shape = (1, 5)
transformed_bbox = transforms.apply_rotated_box(boxes)[0]
expected_bbox = np.array([484, 388, 248, 160, 56], dtype=np.float64)
err_msg = "transformed_bbox = {}, expected {}".format(transformed_bbox, expected_bbox)
assert np.allclose(transformed_bbox, expected_bbox), err_msg
def test_apply_rotated_boxes_unequal_scaling_factor(self):
np.random.seed(125)
h, w = 400, 200
newh, neww = 800, 800
image = np.random.rand(h, w)
augs = []
augs.append(T.Resize(shape=(newh, neww)))
image, transforms = T.apply_augmentations(augs, image)
image_shape = image.shape[:2] # h, w
assert image_shape == (newh, neww)
boxes = np.array(
[
[150, 100, 40, 20, 0],
[150, 100, 40, 20, 30],
[150, 100, 40, 20, 90],
[150, 100, 40, 20, -90],
],
dtype=np.float64,
)
transformed_boxes = transforms.apply_rotated_box(boxes)
expected_bboxes = np.array(
[
[600, 200, 160, 40, 0],
[600, 200, 144.22205102, 52.91502622, 49.10660535],
[600, 200, 80, 80, 90],
[600, 200, 80, 80, -90],
],
dtype=np.float64,
)
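        # Sanity check for the expected values: with scale factors (sx, sy) = (4, 2),
        # a rotated box (cx, cy, w, h, a) maps to
        #   cx' = sx * cx,  cy' = sy * cy,
        #   a'  = atan2(sx * sin(a), sy * cos(a)),
        #   w'  = w * sqrt((sx * cos(a))**2 + (sy * sin(a))**2),
        #   h'  = h * sqrt((sy * cos(a))**2 + (sx * sin(a))**2),
        # e.g. atan2(4 * sin(30°), 2 * cos(30°)) ≈ 49.10660535°.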
err_msg = "transformed_boxes = {}, expected {}".format(transformed_boxes, expected_bboxes)
assert np.allclose(transformed_boxes, expected_bboxes), err_msg
def test_print_augmentation(self):
t = T.RandomCrop("relative", (100, 100))
self.assertEqual(str(t), "RandomCrop(crop_type='relative', crop_size=(100, 100))")
t0 = T.RandomFlip(prob=0.5)
self.assertEqual(str(t0), "RandomFlip(prob=0.5)")
t1 = T.RandomFlip()
self.assertEqual(str(t1), "RandomFlip()")
t = T.AugmentationList([t0, t1])
self.assertEqual(str(t), f"AugmentationList[{t0}, {t1}]")
def test_random_apply_prob_out_of_range_check(self):
test_probabilities = {0.0: True, 0.5: True, 1.0: True, -0.01: False, 1.01: False}
for given_probability, is_valid in test_probabilities.items():
if not is_valid:
self.assertRaises(AssertionError, T.RandomApply, None, prob=given_probability)
else:
T.RandomApply(T.NoOpTransform(), prob=given_probability)
def test_random_apply_wrapping_aug_probability_occured_evaluation(self):
transform_mock = mock.MagicMock(name="MockTransform", spec=T.Augmentation)
image_mock = mock.MagicMock(name="MockImage")
random_apply = T.RandomApply(transform_mock, prob=0.001)
with mock.patch.object(random_apply, "_rand_range", return_value=0.0001):
transform = random_apply.get_transform(image_mock)
transform_mock.get_transform.assert_called_once_with(image_mock)
self.assertIsNot(transform, transform_mock)
def test_random_apply_wrapping_std_transform_probability_occured_evaluation(self):
transform_mock = mock.MagicMock(name="MockTransform", spec=T.Transform)
image_mock = mock.MagicMock(name="MockImage")
random_apply = T.RandomApply(transform_mock, prob=0.001)
with mock.patch.object(random_apply, "_rand_range", return_value=0.0001):
transform = random_apply.get_transform(image_mock)
self.assertIs(transform, transform_mock)
def test_random_apply_probability_not_occured_evaluation(self):
transform_mock = mock.MagicMock(name="MockTransform", spec=T.Augmentation)
image_mock = mock.MagicMock(name="MockImage")
random_apply = T.RandomApply(transform_mock, prob=0.001)
with mock.patch.object(random_apply, "_rand_range", return_value=0.9):
transform = random_apply.get_transform(image_mock)
transform_mock.get_transform.assert_not_called()
self.assertIsInstance(transform, T.NoOpTransform)
def test_augmentation_input_args(self):
input_shape = (100, 100)
output_shape = (50, 50)
# define two augmentations with different args
class TG1(T.Augmentation):
def get_transform(self, image, sem_seg):
return T.ResizeTransform(
input_shape[0], input_shape[1], output_shape[0], output_shape[1]
)
class TG2(T.Augmentation):
def get_transform(self, image):
assert image.shape[:2] == output_shape # check that TG1 is applied
return T.HFlipTransform(output_shape[1])
image = np.random.rand(*input_shape).astype("float32")
sem_seg = (np.random.rand(*input_shape) < 0.5).astype("uint8")
inputs = T.AugInput(image, sem_seg=sem_seg) # provide two args
tfms = inputs.apply_augmentations([TG1(), TG2()])
self.assertIsInstance(tfms[0], T.ResizeTransform)
self.assertIsInstance(tfms[1], T.HFlipTransform)
self.assertTrue(inputs.image.shape[:2] == output_shape)
self.assertTrue(inputs.sem_seg.shape[:2] == output_shape)
class TG3(T.Augmentation):
def get_transform(self, image, nonexist):
pass
with self.assertRaises(AttributeError):
inputs.apply_augmentations([TG3()])
def test_augmentation_list(self):
input_shape = (100, 100)
image = np.random.rand(*input_shape).astype("float32")
sem_seg = (np.random.rand(*input_shape) < 0.5).astype("uint8")
inputs = T.AugInput(image, sem_seg=sem_seg) # provide two args
augs = T.AugmentationList([T.RandomFlip(), T.Resize(20)])
_ = T.AugmentationList([augs, T.Resize(30)])(inputs)
# 3 in latest fvcore (flattened transformlist), 2 in older
# self.assertEqual(len(tfms), 3)
def test_color_transforms(self):
rand_img = np.random.random((100, 100, 3)) * 255
rand_img = rand_img.astype("uint8")
# Test no-op
noop_transform = T.ColorTransform(lambda img: img)
self.assertTrue(np.array_equal(rand_img, noop_transform.apply_image(rand_img)))
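        # ColorTransform applies the given function directly to the uint8 array;
        # PILColorTransform (below) first converts the array to a PIL Image, so
        # ImageOps operations can be wrapped unchanged.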
        # Test an ImageOps operation
magnitude = np.random.randint(0, 256)
solarize_transform = T.PILColorTransform(lambda img: ImageOps.solarize(img, magnitude))
expected_img = ImageOps.solarize(Image.fromarray(rand_img), magnitude)
self.assertTrue(np.array_equal(expected_img, solarize_transform.apply_image(rand_img)))