mirror of
https://github.com/UX-Decoder/DINOv.git
synced 2025-06-03 14:59:16 +08:00
226 lines
9.7 KiB
Python
226 lines
9.7 KiB
Python
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||
|
import copy
|
||
|
import logging
|
||
|
import numpy as np
|
||
|
from typing import List, Optional, Union
|
||
|
import torch
|
||
|
|
||
|
from detectron2.config import configurable
|
||
|
|
||
|
from detectron2.data import detection_utils as utils
|
||
|
from detectron2.data import transforms as T
|
||
|
from detectron2.structures import BitMasks, Instances, Boxes
|
||
|
from pycocotools import mask
|
||
|
import os
|
||
|
import random
|
||
|
from tqdm import tqdm
|
||
|
# import lvis
|
||
|
|
||
|
# Dataset mapper (LVISInferenceMapperWithGT) that reads an image, applies the
# configured augmentations, and attaches ground-truth instances decoded from the
# dataset dict, producing the per-image dicts consumed by the model.
__all__ = [
    "LVISInferenceMapperWithGT",
]
|
||
|
|
||
|
class LVISInferenceMapperWithGT:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and maps it into a format used by the model, attaching ground-truth
    instances decoded from the dict's own "instance" and "labels" fields.

    Note that :meth:`__call__` does not read the "annotations" field; it expects
    "instance" (one segmentation per object) and "labels" (one category id per
    object) instead. See :doc:`/tutorials/data_loading` for details on mappers.

    The callable currently does the following:

    1. Read the image from "file_name".
    2. Apply the configured augmentations (e.g. cropping/geometric transforms)
       to the image and to every ground-truth mask.
    3. Convert the image to a CHW tensor and the ground truth to an
       :class:`Instances` with ``gt_classes``, ``gt_masks`` and ``gt_boxes``.
    """

    @configurable
    def __init__(
        self,
        is_train: bool,
        *,
        augmentations: List[Union[T.Augmentation, T.Transform]],
        image_format: str,
        use_instance_mask: bool = False,
        use_keypoint: bool = False,
        instance_mask_format: str = "polygon",
        keypoint_hflip_indices: Optional[np.ndarray] = None,
        precomputed_proposal_topk: Optional[int] = None,
        recompute_boxes: bool = False,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            is_train: whether it's used in training or inference
            augmentations: a list of augmentations or deterministic transforms to apply
            image_format: an image format supported by :func:`detection_utils.read_image`.
            use_instance_mask: whether to process instance segmentation annotations, if available
            use_keypoint: whether to process keypoint annotations if available
            instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation
                masks into this format.
            keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices`
            precomputed_proposal_topk: if given, will load pre-computed
                proposals from dataset_dict and keep the top k proposals for each image.
            recompute_boxes: whether to overwrite bounding box annotations
                by computing tight bounding boxes from instance mask annotations.
        """
        if recompute_boxes:
            assert use_instance_mask, "recompute_boxes requires instance masks"
        # fmt: off
        self.is_train               = is_train
        self.augmentations          = T.AugmentationList(augmentations)
        self.image_format           = image_format
        self.use_instance_mask      = use_instance_mask
        self.instance_mask_format   = instance_mask_format
        self.use_keypoint           = use_keypoint
        self.keypoint_hflip_indices = keypoint_hflip_indices
        self.proposal_topk          = precomputed_proposal_topk
        self.recompute_boxes        = recompute_boxes
        # fmt: on
        logger = logging.getLogger(__name__)
        mode = "training" if is_train else "inference"
        logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}")

    @classmethod
    def from_config(cls, cfg, is_train: bool = True):
        """Build the keyword arguments of ``__init__`` from a detectron2 config ``cfg``."""
        augs = utils.build_augmentation(cfg, is_train)
        if cfg.INPUT.CROP.ENABLED and is_train:
            augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))
            # Cropping can leave boxes loose, so recompute them from masks when masks are on.
            recompute_boxes = cfg.MODEL.MASK_ON
        else:
            recompute_boxes = False

        ret = {
            "is_train": is_train,
            "augmentations": augs,
            "image_format": cfg.INPUT.FORMAT,
            "use_instance_mask": cfg.MODEL.MASK_ON,
            "instance_mask_format": cfg.INPUT.MASK_FORMAT,
            "use_keypoint": cfg.MODEL.KEYPOINT_ON,
            "recompute_boxes": recompute_boxes,
        }

        if cfg.MODEL.KEYPOINT_ON:
            ret["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)

        if cfg.MODEL.LOAD_PROPOSALS:
            ret["precomputed_proposal_topk"] = (
                cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
                if is_train
                else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
            )
        return ret

    def _transform_annotations(self, dataset_dict, transforms, image_shape):
        """
        Transform ``dataset_dict["annotations"]`` with ``transforms`` and store the
        resulting non-crowd instances (after ``utils.filter_empty_instances``) in
        ``dataset_dict["instances"]``. Pops "annotations" from ``dataset_dict``.

        Note: :meth:`__call__` of this class does not call this helper.
        """
        # USER: Modify this if you want to keep them for some reason.
        for anno in dataset_dict["annotations"]:
            if not self.use_instance_mask:
                anno.pop("segmentation", None)
            if not self.use_keypoint:
                anno.pop("keypoints", None)

        # USER: Implement additional transformations if you have other types of data
        annos = [
            utils.transform_instance_annotations(
                obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
            )
            for obj in dataset_dict.pop("annotations")
            if obj.get("iscrowd", 0) == 0
        ]
        instances = utils.annotations_to_instances(
            annos, image_shape, mask_format=self.instance_mask_format
        )

        # After transforms such as cropping are applied, the bounding box may no longer
        # tightly bound the object. As an example, imagine a triangle object
        # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
        # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
        # the intersection of original bounding box and the cropping box.
        if self.recompute_boxes:
            instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
        dataset_dict["instances"] = utils.filter_empty_instances(instances)

    @staticmethod
    def _decode_segmentation(segm, height, width):
        """
        Decode one segmentation accepted by ``pycocotools.mask.frPyObjects``
        (e.g. a list of polygons or an uncompressed RLE dict) into an
        (height, width) uint8 array of 0/1 values.
        """
        decoded = mask.decode(mask.frPyObjects(segm, height, width))
        # A list input decodes to (H, W, N), one map per polygon/part; take their
        # union. A single RLE dict already decodes to (H, W). ``any`` (rather than
        # summing) keeps the result binary and cannot overflow uint8.
        if decoded.ndim == 3:
            decoded = decoded.any(axis=2)
        return decoded.astype(np.uint8)

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
                Besides "file_name", "height" and "width", it must provide
                "instance" (a non-empty sequence with one segmentation per object, in
                a form accepted by ``pycocotools.mask.frPyObjects``) and "labels"
                (the category id of each object, in the same order).

        Returns:
            dict: a deep copy of ``dataset_dict`` with an "image" tensor (CHW) and an
            "instances" :class:`Instances` added, in a format builtin models in
            detectron2 accept.
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        # USER: Write your own image loading if it's not from a file
        image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
        utils.check_image_size(dataset_dict, image)

        aug_input = T.AugInput(image, sem_seg=None)
        transforms = self.augmentations(aug_input)
        image = aug_input.image

        image_shape = image.shape[:2]  # h, w
        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))

        assert len(dataset_dict["instance"]) > 0, (
            "LVISInferenceMapperWithGT requires at least one ground-truth instance"
        )

        # Masks are decoded at the original resolution ("height"/"width") and then
        # pushed through the same transforms as the image.
        height, width = dataset_dict["height"], dataset_dict["width"]
        masks = []
        classes = []
        for segm, label in zip(dataset_dict["instance"], dataset_dict["labels"]):
            binary = self._decode_segmentation(segm, height, width)
            masks.append(transforms.apply_segmentation(binary[:, :, None])[:, :, 0])
            classes.append(label)

        instances = Instances(image_shape)
        instances.gt_classes = torch.tensor(np.array(classes), dtype=torch.int64)
        if masks:
            # ascontiguousarray copies arrays with negative strides (e.g. after a
            # horizontal flip), which torch.from_numpy cannot wrap.
            bitmasks = BitMasks(
                torch.stack([torch.from_numpy(np.ascontiguousarray(m)) for m in masks])
            )
            instances.gt_masks = bitmasks.tensor
            instances.gt_boxes = bitmasks.get_bounding_boxes()
        else:
            # Reached only when no (segmentation, label) pairs were produced, e.g. an
            # empty "labels" list or assertions disabled with ``python -O``. Use bool
            # to match the dtype of ``BitMasks.tensor`` in the branch above.
            instances.gt_masks = torch.zeros(
                (0, image_shape[0], image_shape[1]), dtype=torch.bool
            )
            instances.gt_boxes = Boxes(torch.zeros((0, 4)))
        dataset_dict["instances"] = instances

        # NOTE: a hard-coded step that loaded pre-computed per-category embeddings and
        # averaged up to 16 randomly sampled ones per category into
        # dataset_dict["input_query_label"] is disabled in this mapper.
        return dataset_dict
|