mirror of
https://github.com/YifanXu74/MQ-Det.git
synced 2025-06-03 15:03:07 +08:00
96 lines
3.9 KiB
Python
96 lines
3.9 KiB
Python
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||
|
import torch
|
||
|
from maskrcnn_benchmark.structures.image_list import to_image_list
|
||
|
|
||
|
import pdb
|
||
|
class BatchCollator(object):
    """
    From a list of samples from the dataset,
    returns the batched images and targets.
    This should be passed to the DataLoader.

    Samples are read positionally: element 0 is the image, element 1 the
    target, element 2 the image id, and the last element is passed through
    as ``path`` (marked as debug-only in the original code).
    """

    def __init__(self, size_divisible=0):
        # Forwarded to ``to_image_list`` as its second positional argument.
        self.size_divisible = size_divisible

    @staticmethod
    def _batch_positive_maps(targets, field):
        """Concatenate each target's ``field`` map into one 2-D float tensor.

        Each ``target.get_field(field)`` is indexed as a 2-D tensor whose rows
        correspond to that image's boxes. Rows from all targets are stacked
        along dim 0 (collapsing the per-image batch dimension, since images
        have different box counts); narrower maps are right-padded with zeros
        up to the widest map's column count.

        Returns:
            float tensor of shape (total_rows, max_columns). Values are staged
            in a bool buffer first, so every nonzero input entry becomes 1.0.
        """
        maps = [t.get_field(field) for t in targets]
        max_len = max(m.shape[1] for m in maps)
        nb_boxes = sum(m.shape[0] for m in maps)
        batched = torch.zeros((nb_boxes, max_len), dtype=torch.bool)

        cur_count = 0
        for m in maps:
            batched[cur_count: cur_count + len(m), : m.shape[1]] = m
            cur_count += len(m)

        # Every row of the preallocated buffer must have been filled.
        assert cur_count == len(batched)
        return batched.float()

    def __call__(self, batch):
        """Collate a list of dataset samples into a batch.

        Returns:
            ``(images, targets, img_ids, positive_map, positive_map_eval)``
            when the targets are plain dicts; otherwise
            ``(images, targets, img_ids, positive_map, positive_map_eval,
            greenlight_map, path)``. Each map entry is ``None`` when the first
            target does not carry the corresponding field.
        """
        transposed_batch = list(zip(*batch))

        images = to_image_list(transposed_batch[0], self.size_divisible)
        targets = transposed_batch[1]
        img_ids = transposed_batch[2]
        positive_map = None
        positive_map_eval = None
        greenlight_map = None

        # NOTE(review): the original marks this as debug ("delete path"); it
        # is kept so the 7-element return tuple stays unchanged.
        path = transposed_batch[-1]

        if isinstance(targets[0], dict):
            return images, targets, img_ids, positive_map, positive_map_eval

        # Field presence is decided from the first target only; the other
        # targets are assumed to carry the same fields.
        present = targets[0].fields()

        if "greenlight_map" in present:
            greenlight_map = torch.stack(
                [t.get_field("greenlight_map") for t in targets], dim=0
            )

        if "positive_map" in present:
            positive_map = self._batch_positive_maps(targets, "positive_map")

        if "positive_map_eval" in present:
            positive_map_eval = self._batch_positive_maps(
                targets, "positive_map_eval"
            )

        return (
            images,
            targets,
            img_ids,
            positive_map,
            positive_map_eval,
            greenlight_map,
            path,
        )
|
||
|
|
||
|
|
||
|
class BBoxAugCollator(object):
    """
    From a list of samples from the dataset,
    returns the images and targets.
    Images should be converted to batched images in `im_detect_bbox_aug`.
    """

    def __call__(self, batch):
        """Transpose ``batch`` without batching the images.

        Returns:
            ``(images, targets, img_ids, None, None)`` where the first three
            entries are tuples gathered from positions 0, 1 and 2 of every
            sample. The trailing ``None`` slots presumably keep the same
            5-tuple shape as ``BatchCollator``'s dict-target return — confirm
            against callers.
        """
        transposed_batch = list(zip(*batch))

        images = transposed_batch[0]
        targets = transposed_batch[1]
        img_ids = transposed_batch[2]
        # No positive maps are built for the bbox-augmentation path.
        positive_map = None
        positive_map_eval = None

        return images, targets, img_ids, positive_map, positive_map_eval
|
||
|
|
||
|
|
||
|
|