# mirror of https://github.com/YifanXu74/MQ-Det.git
import torch
import maskrcnn_benchmark.utils.dist as dist
def normalized_positive_map(positive_map):
    """Normalize a positive map along its last dimension.

    Each row (indexed over dim 2) is divided by its own sum so that
    non-empty rows sum to 1; all-zero rows are divided by a tiny epsilon
    instead, leaving them (numerically) zero rather than NaN.
    """
    positive_map = positive_map.float()
    # Per-row totals along the last dimension; guard empty rows against
    # division by zero.
    row_totals = positive_map.sum(2)
    row_totals[row_totals == 0] = 1e-6
    return positive_map / row_totals.unsqueeze(-1)
def pad_tensor_given_dim_length(tensor, dim, length, padding_value=0, batch_first=True):
    """Pad `tensor` along dimension `dim` up to size `length`.

    Returns a new tensor whose shape equals `tensor`'s except that dimension
    `dim` has size `length`; the original contents occupy the leading slice
    along `dim` and the remainder is filled with `padding_value`.

    Args:
        tensor: input tensor; requires ``tensor.size(dim) <= length``.
        dim: dimension to pad.
        length: target size along `dim`.
        padding_value: fill value for the padded region.
        batch_first: kept for interface compatibility. The copy now follows
            `dim` directly, which matches the old behavior for the cases that
            previously worked (``dim == 1`` with ``batch_first=True`` and
            ``dim == 0`` with ``batch_first=False``).

    Bug fixed: the original hard-coded the copy slice to dim 1 (batch_first)
    or dim 0 (otherwise), raising a shape mismatch for any other `dim` even
    though `new_size` was computed from `dim`.
    """
    new_size = list(tensor.size()[:dim]) + [length] + list(tensor.size()[dim + 1:])
    out_tensor = tensor.data.new(*new_size).fill_(padding_value)
    # Copy the original data into the leading slice along `dim`.
    index = [slice(None)] * tensor.dim()
    index[dim] = slice(0, tensor.size(dim))
    out_tensor[tuple(index)] = tensor
    return out_tensor
def pad_random_negative_tensor_given_length(positive_tensor, negative_padding_tensor, length=None):
    """Concatenate positive and negative rows into one tensor of `length` rows.

    Args:
        positive_tensor: ``(P, ...)`` tensor of positive samples.
        negative_padding_tensor: ``(N, ...)`` tensor of negative padding
            samples (same trailing shape as `positive_tensor`).
        length: expected total row count, i.e. ``P + N``.

    Returns:
        Tensor of shape ``(P + N, ...)`` with the positives first.

    Raises:
        ValueError: if ``P + N != length``. The original used ``assert`` for
            this input validation, which ``python -O`` silently strips.
    """
    total = positive_tensor.shape[0] + negative_padding_tensor.shape[0]
    if total != length:
        raise ValueError(
            f"positive ({positive_tensor.shape[0]}) + negative "
            f"({negative_padding_tensor.shape[0]}) rows = {total}, expected {length}"
        )
    return torch.cat((positive_tensor, negative_padding_tensor), dim=0)
def gather_tensors(tensor):
    """All-gather `tensor` from every rank and stack along a new dim 0.

    *** Warning ***: ``torch.distributed.all_gather`` has no gradient, so the
    local rank's slot is overwritten with the original tensor afterwards to
    keep this process's gradient path intact.
    """
    if not dist.is_dist_avail_and_initialized():
        # Not running distributed: behave like a world of size 1.
        return torch.stack([tensor], dim=0)

    world_size = dist.get_world_size()
    rank = torch.distributed.get_rank()

    gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, tensor, async_op=False)

    # all_gather fills detached copies; restore gradient propagation for the
    # local contribution by swapping in the live tensor.
    gathered[rank] = tensor
    return torch.stack(gathered, dim=0)
def convert_to_roi_format(boxes):
    """Convert a box container into ROI layout: ``[batch_idx, x1, y1, x2, y2]``.

    Prepends a batch-index column of zeros (single-image assumption) to
    ``boxes.bbox``, matching the coordinate tensor's dtype and device.
    """
    box_coords = boxes.bbox
    # One zero batch index per box, on the same device/dtype as the coords.
    batch_idx = torch.full(
        (len(boxes), 1), 0, dtype=box_coords.dtype, device=box_coords.device
    )
    return torch.cat([batch_idx, box_coords], dim=1)