# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import itertools
import copy
import logging
import numpy as np
import operator
import pickle
import torch.utils.data
from fvcore.common.file_io import PathManager
from tabulate import tabulate
from termcolor import colored

from detectron2.structures import BoxMode
from detectron2.utils.comm import get_world_size
from detectron2.utils.env import seed_all_rng
from detectron2.utils.logger import log_first_n

from .catalog import DatasetCatalog, MetadataCatalog
from .common import AspectRatioGroupedDataset, DatasetFromList, MapDataset
from .dataset_mapper import DatasetMapper
from .detection_utils import check_metadata_consistency
from .samplers import InferenceSampler, RepeatFactorTrainingSampler, TrainingSampler

"""
This file contains the default logic to build a dataloader for training or testing.
"""

__all__ = [
    "build_batch_data_loader",
    "build_detection_train_loader",
    "build_detection_test_loader",
    "get_detection_dataset_dicts",
    "load_proposals_into_dataset",
    "print_instances_class_histogram",
]


def filter_images_with_only_crowd_annotations(dataset_dicts):
    """
    Filter out images with no annotations or only crowd annotations
    (i.e., images without non-crowd annotations).
    A common training-time preprocessing on COCO dataset.

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.

    Returns:
        list[dict]: the same format, but filtered.
    """
    num_before = len(dataset_dicts)

    def valid(anns):
        # An image is usable as soon as one annotation is not marked "iscrowd".
        # A missing "iscrowd" key counts as non-crowd.
        return any(ann.get("iscrowd", 0) == 0 for ann in anns)

    dataset_dicts = [x for x in dataset_dicts if valid(x["annotations"])]
    num_after = len(dataset_dicts)
    logger = logging.getLogger(__name__)
    logger.info(
        "Removed {} images with no usable annotations. {} images left.".format(
            num_before - num_after, num_after
        )
    )
    return dataset_dicts


def filter_images_with_few_keypoints(dataset_dicts, min_keypoints_per_image):
    """
    Filter out images with too few keypoints.

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
        min_keypoints_per_image (int): images whose annotations contain fewer
            visible keypoints (visibility > 0) than this are dropped.

    Returns:
        list[dict]: the same format as dataset_dicts, but filtered.
    """
    num_before = len(dataset_dicts)

    def visible_keypoints_in_image(dic):
        # Each keypoints field has the format [x1, y1, v1, ...], where v is visibility
        annotations = dic["annotations"]
        return sum(
            (np.array(ann["keypoints"][2::3]) > 0).sum()
            for ann in annotations
            if "keypoints" in ann
        )

    dataset_dicts = [
        x for x in dataset_dicts if visible_keypoints_in_image(x) >= min_keypoints_per_image
    ]
    num_after = len(dataset_dicts)
    logger = logging.getLogger(__name__)
    logger.info(
        "Removed {} images with fewer than {} keypoints.".format(
            num_before - num_after, min_keypoints_per_image
        )
    )
    return dataset_dicts


def load_proposals_into_dataset(dataset_dicts, proposal_file):
    """
    Load precomputed object proposals into the dataset.

    The proposal file should be a pickled dict with the following keys:

    - "ids": list[int] or list[str], the image ids
    - "boxes": list[np.ndarray], each is an Nx4 array of boxes corresponding to the image id
    - "objectness_logits": list[np.ndarray], each is an N sized array of objectness scores
      corresponding to the boxes.
    - "bbox_mode": the BoxMode of the boxes array. Defaults to ``BoxMode.XYXY_ABS``.

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
        proposal_file (str): file path of pre-computed proposals, in pkl format.

    Returns:
        list[dict]: the same format as dataset_dicts, but added proposal field.

    Raises:
        KeyError: if a record's "image_id" has no entry in the proposal file.
    """
    logger = logging.getLogger(__name__)
    logger.info("Loading proposals from: {}".format(proposal_file))

    # NOTE(review): pickle.load executes arbitrary code embedded in the file;
    # only point `proposal_file` at trusted sources.
    with PathManager.open(proposal_file, "rb") as f:
        proposals = pickle.load(f, encoding="latin1")

    # Rename the key names in D1 proposal files
    rename_keys = {"indexes": "ids", "scores": "objectness_logits"}
    for key in rename_keys:
        if key in proposals:
            proposals[rename_keys[key]] = proposals.pop(key)

    # Fetch the indexes of all proposals that are in the dataset
    # Convert image_id to str since they could be int.
    img_ids = {str(record["image_id"]) for record in dataset_dicts}
    id_to_index = {
        str(img_id): i for i, img_id in enumerate(proposals["ids"]) if str(img_id) in img_ids
    }

    # Assuming default bbox_mode of precomputed proposals are 'XYXY_ABS'
    bbox_mode = BoxMode(proposals["bbox_mode"]) if "bbox_mode" in proposals else BoxMode.XYXY_ABS

    for record in dataset_dicts:
        # Get the index of the proposal
        i = id_to_index[str(record["image_id"])]

        boxes = proposals["boxes"][i]
        objectness_logits = proposals["objectness_logits"][i]
        # Sort the proposals in descending order of the scores
        inds = objectness_logits.argsort()[::-1]
        record["proposal_boxes"] = boxes[inds]
        record["proposal_objectness_logits"] = objectness_logits[inds]
        record["proposal_bbox_mode"] = bbox_mode

    return dataset_dicts


def print_instances_class_histogram(dataset_dicts, class_names):
    """
    Log a table with the number of non-crowd instances per category.

    Args:
        dataset_dicts (list[dict]): list of dataset dicts.
        class_names (list[str]): list of class names (zero-indexed).
    """
    num_classes = len(class_names)
    hist_bins = np.arange(num_classes + 1)
    # `np.int` was only an alias of the builtin `int`; it was deprecated in
    # NumPy 1.20 and removed in 1.24, so use the builtin directly.
    histogram = np.zeros((num_classes,), dtype=int)
    for entry in dataset_dicts:
        annos = entry["annotations"]
        classes = [x["category_id"] for x in annos if not x.get("iscrowd", 0)]
        histogram += np.histogram(classes, bins=hist_bins)[0]

    # Each category occupies two cells (name, count); show at most 3 per row.
    N_COLS = min(6, len(class_names) * 2)

    def short_name(x):
        # make long class names shorter. useful for lvis
        if len(x) > 13:
            return x[:11] + ".."
        return x

    # Flat list [name0, count0, name1, count1, ...]
    data = list(
        itertools.chain(*[[short_name(class_names[i]), int(v)] for i, v in enumerate(histogram)])
    )
    total_num_instances = sum(data[1::2])
    # Pad the last row so that the "total" entry starts on a fresh row.
    data.extend([None] * (N_COLS - (len(data) % N_COLS)))
    if num_classes > 1:
        data.extend(["total", total_num_instances])
    # Re-chunk the flat list into rows of N_COLS cells.
    data = itertools.zip_longest(*[data[i::N_COLS] for i in range(N_COLS)])
    table = tabulate(
        data,
        headers=["category", "#instances"] * (N_COLS // 2),
        tablefmt="pipe",
        numalign="left",
        stralign="center",
    )
    log_first_n(
        logging.INFO,
        "Distribution of instances among all {} categories:\n".format(num_classes)
        + colored(table, "cyan"),
        key="message",
    )


def get_detection_dataset_dicts(
    dataset_names, filter_empty=True, min_keypoints=0, proposal_files=None, cfg=None
):
    """
    Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.

    Args:
        dataset_names (list[str]): a list of dataset names
        filter_empty (bool): whether to filter out images without instance annotations
        min_keypoints (int): filter out images with fewer keypoints than
            `min_keypoints`. Set to 0 to do nothing.
        proposal_files (list[str]): if given, a list of object proposal files
            that match each dataset in `dataset_names`.
        cfg (CfgNode): config providing ``OWOD.PREV_INTRODUCED_CLS``,
            ``OWOD.CUR_INTRODUCED_CLS`` and ``MODEL.ROI_HEADS.NUM_CLASSES``.
            It is only read when the first dataset name contains ``voc_coco``,
            and must not be None in that case.

    Returns:
        list[dict]: a list of dicts following the standard dataset dict format.
    """
    assert len(dataset_names)
    dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]
    for dataset_name, dicts in zip(dataset_names, dataset_dicts):
        assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)

    if proposal_files is not None:
        assert len(dataset_names) == len(proposal_files)
        # load precomputed proposals from proposal files
        dataset_dicts = [
            load_proposals_into_dataset(dataset_i_dicts, proposal_file)
            for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
        ]

    dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))

    # Only the first record is inspected to decide whether annotations exist.
    has_instances = "annotations" in dataset_dicts[0]
    if filter_empty and has_instances:
        dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)
    if min_keypoints > 0 and has_instances:
        dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)

    # Open-world relabelling/filtering is keyed on the name of the first dataset.
    # It reads entry["annotations"], hence the `has_instances` guard.
    d_name = dataset_names[0]
    if has_instances and "voc_coco" in d_name:
        if "train" in d_name:
            dataset_dicts = remove_prev_class_and_unk_instances(cfg, dataset_dicts)
        elif "test" in d_name or "val" in d_name:
            dataset_dicts = label_known_class_and_unknown(cfg, dataset_dicts)
        elif "ft" in d_name:
            dataset_dicts = remove_unknown_instances(cfg, dataset_dicts)

    if has_instances:
        try:
            class_names = MetadataCatalog.get(dataset_names[0]).thing_classes
            check_metadata_consistency("thing_classes", dataset_names)
            print_instances_class_histogram(dataset_dicts, class_names)
        except AttributeError:  # class names are not available for this dataset
            pass

    assert len(dataset_dicts), "No valid data found in {}.".format(",".join(dataset_names))
    return dataset_dicts


def _filter_annotations_by_class(dataset_dicts, valid_classes):
    """
    Keep only annotations whose "category_id" is in `valid_classes`, then drop
    images left without annotations. Both steps mutate their lists in place.

    Args:
        dataset_dicts (list[dict]): dataset dicts, each with an "annotations" list.
        valid_classes (range): category ids to keep.

    Returns:
        list[dict]: the same (mutated) `dataset_dicts` list.
    """
    for entry in dataset_dicts:
        annos = entry["annotations"]
        # Slice assignment keeps the same list object, like the in-place
        # removal it replaces, but without one linear scan per removed item.
        annos[:] = [anno for anno in annos if anno["category_id"] in valid_classes]
    dataset_dicts[:] = [entry for entry in dataset_dicts if entry["annotations"]]
    return dataset_dicts


def remove_prev_class_and_unk_instances(cfg, dataset_dicts):
    """
    For training data: keep only instances of the currently introduced classes.

    Args:
        cfg (CfgNode): provides ``OWOD.PREV_INTRODUCED_CLS`` and
            ``OWOD.CUR_INTRODUCED_CLS``.
        dataset_dicts (list[dict]): dataset dicts; modified in place.

    Returns:
        list[dict]: `dataset_dicts` without previously seen / unknown instances
        and without images that end up with no annotations.
    """
    prev_intro_cls = cfg.OWOD.PREV_INTRODUCED_CLS
    curr_intro_cls = cfg.OWOD.CUR_INTRODUCED_CLS
    # CUR_INTRODUCED_CLS is used as a *count* by the sibling functions
    # (known classes are range(0, prev + curr)), so the current task's classes
    # are [prev, prev + curr). Using `curr` as the end index gave an empty
    # range whenever prev >= curr.
    valid_classes = range(prev_intro_cls, prev_intro_cls + curr_intro_cls)

    logger = logging.getLogger(__name__)
    logger.info("Valid classes: " + str(valid_classes))
    logger.info("Removing earlier seen class objects and the unknown objects...")

    return _filter_annotations_by_class(dataset_dicts, valid_classes)


def remove_unknown_instances(cfg, dataset_dicts):
    """
    For finetune data: keep instances of all classes known so far.

    Args:
        cfg (CfgNode): provides ``OWOD.PREV_INTRODUCED_CLS`` and
            ``OWOD.CUR_INTRODUCED_CLS``.
        dataset_dicts (list[dict]): dataset dicts; modified in place.

    Returns:
        list[dict]: `dataset_dicts` without unknown-class instances and without
        images that end up with no annotations.
    """
    prev_intro_cls = cfg.OWOD.PREV_INTRODUCED_CLS
    curr_intro_cls = cfg.OWOD.CUR_INTRODUCED_CLS
    valid_classes = range(0, prev_intro_cls+curr_intro_cls)

    logger = logging.getLogger(__name__)
    logger.info("Valid classes: " + str(valid_classes))
    logger.info("Removing the unknown objects...")

    return _filter_annotations_by_class(dataset_dicts, valid_classes)


def label_known_class_and_unknown(cfg, dataset_dicts):
    """
    For test and validation data: known instances keep their label, every other
    instance is relabelled as unknown (id ``MODEL.ROI_HEADS.NUM_CLASSES - 1``).

    Args:
        cfg (CfgNode): provides ``OWOD.PREV_INTRODUCED_CLS``,
            ``OWOD.CUR_INTRODUCED_CLS`` and ``MODEL.ROI_HEADS.NUM_CLASSES``.
        dataset_dicts (list[dict]): dataset dicts; annotations are modified in place.

    Returns:
        list[dict]: the same `dataset_dicts` list.
    """
    prev_intro_cls = cfg.OWOD.PREV_INTRODUCED_CLS
    curr_intro_cls = cfg.OWOD.CUR_INTRODUCED_CLS
    total_num_class = cfg.MODEL.ROI_HEADS.NUM_CLASSES

    known_classes = range(0, prev_intro_cls+curr_intro_cls)

    logger = logging.getLogger(__name__)
    logger.info("Known classes: " + str(known_classes))
    logger.info("Labelling known instances the corresponding label, and unknown instances as unknown...")

    for entry in dataset_dicts:
        for annotation in entry["annotations"]:
            if annotation["category_id"] not in known_classes:
                # The last class index is reserved for "unknown".
                annotation["category_id"] = total_num_class - 1

    return dataset_dicts


def build_batch_data_loader(
    dataset, sampler, total_batch_size, *, aspect_ratio_grouping=False, num_workers=0
):
    """
    Build a batched dataloader for training.

    Args:
        dataset (torch.utils.data.Dataset): map-style PyTorch dataset. Can be indexed.
        sampler (torch.utils.data.sampler.Sampler): a sampler that produces indices
        total_batch_size (int): total batch size across GPUs.
        aspect_ratio_grouping (bool): whether to group images with similar
            aspect ratio for efficiency. When enabled, it requires each
            element in dataset be a dict with keys "width" and "height".
        num_workers (int): number of parallel data loading workers

    Returns:
        iterable[list]. Length of each list is the batch size of the current
            GPU. Each element in the list comes from the dataset.
    """
    world_size = get_world_size()
    assert (
        total_batch_size > 0 and total_batch_size % world_size == 0
    ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
        total_batch_size, world_size
    )

    batch_size = total_batch_size // world_size
    if aspect_ratio_grouping:
        data_loader = torch.utils.data.DataLoader(
            dataset,
            sampler=sampler,
            num_workers=num_workers,
            batch_sampler=None,
            collate_fn=operator.itemgetter(0),  # don't batch, but yield individual elements
            worker_init_fn=worker_init_reset_seed,
        )  # yield individual mapped dict
        return AspectRatioGroupedDataset(data_loader, batch_size)
    else:
        batch_sampler = torch.utils.data.sampler.BatchSampler(
            sampler, batch_size, drop_last=True
        )  # drop_last so the batch always have the same size
        return torch.utils.data.DataLoader(
            dataset,
            num_workers=num_workers,
            batch_sampler=batch_sampler,
            collate_fn=trivial_batch_collator,
            worker_init_fn=worker_init_reset_seed,
        )


def build_detection_train_loader(cfg, mapper=None):
    """
    A data loader is created by the following steps:

    1. Use the dataset names in config to query :class:`DatasetCatalog`, and obtain a list of dicts.
    2. Coordinate a random shuffle order shared among all processes (all GPUs)
    3. Each process spawn another few workers to process the dicts. Each worker will:

       * Map each metadata dict into another format to be consumed by the model.
       * Batch them by simply putting dicts into a list.

    The batched ``list[mapped_dict]`` is what this dataloader will yield.

    Args:
        cfg (CfgNode): the config
        mapper (callable): a callable which takes a sample (dict) from dataset and
            returns the format to be consumed by the model.
            By default it will be ``DatasetMapper(cfg, True)``.

    Returns:
        an infinite iterator of training data

    Raises:
        ValueError: if ``cfg.DATALOADER.SAMPLER_TRAIN`` names an unknown sampler.
    """
    dataset_dicts = get_detection_dataset_dicts(
        cfg.DATASETS.TRAIN,
        filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
        min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
        if cfg.MODEL.KEYPOINT_ON
        else 0,
        proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
        cfg=cfg,
    )
    dataset = DatasetFromList(dataset_dicts, copy=False)

    if mapper is None:
        mapper = DatasetMapper(cfg, True)
    dataset = MapDataset(dataset, mapper)

    sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
    logger = logging.getLogger(__name__)
    logger.info("Using training sampler {}".format(sampler_name))
    # TODO avoid if-else?
    if sampler_name == "TrainingSampler":
        sampler = TrainingSampler(len(dataset))
    elif sampler_name == "RepeatFactorTrainingSampler":
        repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
            dataset_dicts, cfg.DATALOADER.REPEAT_THRESHOLD
        )
        sampler = RepeatFactorTrainingSampler(repeat_factors)
    else:
        raise ValueError("Unknown training sampler: {}".format(sampler_name))
    return build_batch_data_loader(
        dataset,
        sampler,
        cfg.SOLVER.IMS_PER_BATCH,
        aspect_ratio_grouping=cfg.DATALOADER.ASPECT_RATIO_GROUPING,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
    )


def build_detection_test_loader(cfg, dataset_name, mapper=None):
    """
    Similar to `build_detection_train_loader`.
    But this function uses the given `dataset_name` argument (instead of the names in cfg),
    and uses 1 image per batch.

    Args:
        cfg: a detectron2 CfgNode
        dataset_name (str): a name of the dataset that's available in the DatasetCatalog
        mapper (callable): a callable which takes a sample (dict) from dataset
           and returns the format to be consumed by the model.
           By default it will be `DatasetMapper(cfg, False)`.

    Returns:
        DataLoader: a torch DataLoader, that loads the given detection
        dataset, with test-time transformation and batching.
    """
    dataset_dicts = get_detection_dataset_dicts(
        [dataset_name],
        filter_empty=False,
        # The proposal file is looked up by the dataset's position in cfg.DATASETS.TEST.
        proposal_files=[
            cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(dataset_name)]
        ]
        if cfg.MODEL.LOAD_PROPOSALS
        else None,
        cfg=cfg
    )

    dataset = DatasetFromList(dataset_dicts)
    if mapper is None:
        mapper = DatasetMapper(cfg, False)
    dataset = MapDataset(dataset, mapper)

    sampler = InferenceSampler(len(dataset))
    # Always use 1 image per worker during inference since this is the
    # standard when reporting inference time in papers.
    batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, 1, drop_last=False)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        batch_sampler=batch_sampler,
        collate_fn=trivial_batch_collator,
    )
    return data_loader


def trivial_batch_collator(batch):
    """
    A batch collator that does nothing.
    """
    return batch


def worker_init_reset_seed(worker_id):
    """
    Seed the RNGs of a data loading worker via `seed_all_rng`, using a random
    base value offset by `worker_id`.
    """
    seed_all_rng(np.random.randint(2 ** 31) + worker_id)