mirror of https://github.com/UX-Decoder/DINOv.git
341 lines
12 KiB
Python
341 lines
12 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# Part of the code is from https://github.com/xingyizhou/UniDet/blob/master/projects/UniDet/unidet/data/multi_dataset_dataloader.py (Apache-2.0 License)
|
|
import copy
|
|
import logging
|
|
import numpy as np
|
|
import operator
|
|
from typing import Any, Callable, Dict, List, Optional, Union
|
|
import torch
|
|
import torch.utils.data as torchdata
|
|
import json
|
|
from detectron2.utils.comm import get_world_size
|
|
from detectron2.utils.logger import _log_api_usage, log_first_n
|
|
|
|
from detectron2.config import configurable
|
|
from detectron2.data import samplers
|
|
from torch.utils.data.sampler import BatchSampler, Sampler
|
|
from detectron2.data.common import DatasetFromList, MapDataset
|
|
from detectron2.data.dataset_mapper import DatasetMapper
|
|
from detectron2.data.build import get_detection_dataset_dicts, build_batch_data_loader
|
|
from detectron2.data.samplers import TrainingSampler, RepeatFactorTrainingSampler, InferenceSampler
|
|
from detectron2.data.build import worker_init_reset_seed, print_instances_class_histogram
|
|
from detectron2.data.build import filter_images_with_only_crowd_annotations
|
|
from detectron2.data.build import filter_images_with_few_keypoints
|
|
from detectron2.data.build import check_metadata_consistency
|
|
from detectron2.data.catalog import MetadataCatalog, DatasetCatalog
|
|
from detectron2.utils import comm
|
|
import itertools
|
|
import math
|
|
from collections import defaultdict
|
|
from typing import Optional
|
|
|
|
logger = logging.getLogger('detectron2.vlpart.data.custom_dataset_dataloader')
|
|
|
|
|
|
def trivial_batch_collator(batch):
    """
    Identity collate function: hand the batch back exactly as it was received,
    without stacking or otherwise combining its elements.
    """
    return batch
|
|
|
|
|
|
def _custom_train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
    """
    Translate a config node into the keyword arguments of
    ``build_custom_train_loader``.

    Args:
        cfg: config node. The keys read here are ``DATALOADER.SAMPLER_TRAIN``,
            ``DATALOADER.FILTER_EMPTY_ANNOTATIONS``, ``DATALOADER.ASPECT_RATIO_GROUPING``,
            ``DATALOADER.NUM_WORKERS``, ``DATALOADER.MULTI_DATASET_GROUPING``,
            ``DATALOADER.USE_DIFF_BS_SIZE``, ``DATALOADER.DATASET_BS``,
            ``DATALOADER.REPEAT_THRESHOLD``, ``DATASETS.TRAIN``,
            ``DATASETS.PROPOSAL_FILES_TRAIN``, ``MODEL.KEYPOINT_ON``,
            ``MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE``, ``MODEL.LOAD_PROPOSALS``,
            ``SOLVER.IMS_PER_BATCH`` and, for ``MultiDatasetSampler`` only,
            ``DATALOADER.DATASET_RATIO``, ``DATALOADER.USE_RFS``, ``DATALOADER.DATASET_ANN``.
        mapper: callable applied to every dataset dict; defaults to
            ``DatasetMapper(cfg, True)``.
        dataset: accepted for signature compatibility but not used -- the
            dataset dicts are always loaded from ``cfg.DATASETS.TRAIN``.
        sampler: optional ready-made sampler. When given, it is used as is and
            ``cfg.DATALOADER.SAMPLER_TRAIN`` is only consulted to decide which
            dataset loading function to call.

    Returns:
        dict: keyword arguments for ``build_custom_train_loader``.

    Raises:
        ValueError: if no sampler was passed and the configured sampler name is
            not one of the supported ones.
    """
    sampler_name = cfg.DATALOADER.SAMPLER_TRAIN

    # Multi-dataset samplers need every dict tagged with its 'dataset_source'.
    if 'MultiDataset' in sampler_name:
        load_dataset_dicts = get_detection_dataset_dicts_with_source
    else:
        load_dataset_dicts = get_detection_dataset_dicts
    dataset_dicts = load_dataset_dicts(
        cfg.DATASETS.TRAIN,
        filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
        min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
        if cfg.MODEL.KEYPOINT_ON else 0,
        proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
    )

    if mapper is None:
        mapper = DatasetMapper(cfg, True)

    if sampler is None:
        if sampler_name == "TrainingSampler":
            # Size the sampler from the dicts that are actually returned below;
            # the `dataset` argument is None by default and cannot be measured.
            sampler = TrainingSampler(len(dataset_dicts))
        elif sampler_name == "MultiDatasetSampler":
            sampler = MultiDatasetSampler(
                dataset_dicts,
                dataset_ratio=cfg.DATALOADER.DATASET_RATIO,
                use_rfs=cfg.DATALOADER.USE_RFS,
                dataset_ann=cfg.DATALOADER.DATASET_ANN,
                repeat_threshold=cfg.DATALOADER.REPEAT_THRESHOLD,
            )
        elif sampler_name == "RepeatFactorTrainingSampler":
            repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
                dataset_dicts, cfg.DATALOADER.REPEAT_THRESHOLD
            )
            sampler = RepeatFactorTrainingSampler(repeat_factors)
        else:
            raise ValueError("Unknown training sampler: {}".format(sampler_name))

    return {
        "dataset": dataset_dicts,
        "sampler": sampler,
        "mapper": mapper,
        "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
        "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
        'multi_dataset_grouping': cfg.DATALOADER.MULTI_DATASET_GROUPING,
        'use_diff_bs_size': cfg.DATALOADER.USE_DIFF_BS_SIZE,
        'dataset_bs': cfg.DATALOADER.DATASET_BS,
        'num_datasets': len(cfg.DATASETS.TRAIN),
    }
|
|
|
|
|
|
@configurable(from_config=_custom_train_loader_from_config)
def build_custom_train_loader(
    dataset, *, mapper, sampler,
    total_batch_size=16,
    aspect_ratio_grouping=True,
    num_workers=0,
    num_datasets=1,
    multi_dataset_grouping=False,
    use_diff_bs_size=False,
    dataset_bs=None
):
    """
    Modified from detectron2.data.build.build_detection_train_loader, but supports
    different samplers and per-dataset batch grouping.

    Args:
        dataset: a list of dataset dicts, or an already constructed dataset object.
            A list is wrapped in ``DatasetFromList``.
        mapper: callable applied to each dataset element; ``None`` disables mapping.
        sampler: index sampler; ``None`` falls back to ``TrainingSampler``.
        total_batch_size: batch size summed over all workers.
        aspect_ratio_grouping: forwarded to ``build_batch_data_loader``; only used
            when ``multi_dataset_grouping`` is False.
        num_workers: number of data loading workers.
        num_datasets: number of source datasets; only used when
            ``multi_dataset_grouping`` is True.
        multi_dataset_grouping: if True, build the loader with
            ``build_multi_dataset_batch_data_loader`` instead of
            ``build_batch_data_loader``.
        use_diff_bs_size: forwarded to ``build_multi_dataset_batch_data_loader``.
        dataset_bs: per-dataset batch sizes, forwarded to
            ``build_multi_dataset_batch_data_loader``. ``None`` means an empty list.

    Returns:
        The iterable returned by the selected loader builder.
    """
    # A literal [] default would be one list object shared by every call.
    if dataset_bs is None:
        dataset_bs = []

    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)
    if sampler is None:
        sampler = TrainingSampler(len(dataset))
    assert isinstance(sampler, torch.utils.data.sampler.Sampler), \
        "Expect a Sampler but got {}".format(type(sampler))

    if multi_dataset_grouping:
        return build_multi_dataset_batch_data_loader(
            use_diff_bs_size,
            dataset_bs,
            dataset,
            sampler,
            total_batch_size,
            num_datasets=num_datasets,
            num_workers=num_workers,
        )
    return build_batch_data_loader(
        dataset,
        sampler,
        total_batch_size,
        aspect_ratio_grouping=aspect_ratio_grouping,
        num_workers=num_workers,
    )
|
|
|
|
|
|
def build_multi_dataset_batch_data_loader(
    use_diff_bs_size, dataset_bs,
    dataset, sampler, total_batch_size, num_datasets, num_workers=0
):
    """
    Build an iterable of batches in which every batch holds samples of a single
    source dataset and a single aspect-ratio orientation.

    Args:
        use_diff_bs_size: if true, group with ``DIFFMDAspectRatioGroupedDataset``
            using ``dataset_bs``; otherwise group with ``MDAspectRatioGroupedDataset``
            using ``total_batch_size // world_size``.
        dataset_bs: per-dataset batch sizes, only used when ``use_diff_bs_size`` is true.
        dataset: map-style dataset handed to ``torch.utils.data.DataLoader``.
        sampler: index sampler handed to the same ``DataLoader``.
        total_batch_size: batch size summed over all workers; must be positive and
            divisible by the world size.
        num_datasets: number of source datasets.
        num_workers: number of data loading workers.

    Returns:
        The grouping iterable dataset wrapping the per-sample ``DataLoader``.
    """
    num_gpus = get_world_size()
    batch_size_is_valid = total_batch_size > 0 and total_batch_size % num_gpus == 0
    assert batch_size_is_valid, \
        "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
            total_batch_size, num_gpus
        )
    per_gpu_batch_size = total_batch_size // num_gpus

    # The loader emits one mapped dict at a time; batching happens in the
    # grouping wrapper returned below.
    loader_options = dict(
        sampler=sampler,
        num_workers=num_workers,
        batch_sampler=None,
        collate_fn=operator.itemgetter(0),  # unwrap the single-element batch
        worker_init_fn=worker_init_reset_seed,
    )
    sample_loader = torch.utils.data.DataLoader(dataset, **loader_options)

    if not use_diff_bs_size:
        return MDAspectRatioGroupedDataset(
            sample_loader, per_gpu_batch_size, num_datasets)
    return DIFFMDAspectRatioGroupedDataset(
        sample_loader, dataset_bs, num_datasets)
|
|
|
|
|
|
def get_detection_dataset_dicts_with_source(
    dataset_names, filter_empty=True, min_keypoints=0, proposal_files=None
):
    """
    Load several registered datasets, tag every dict with the index of the
    dataset it came from, and return them as one flat list.

    Args:
        dataset_names: a dataset name or a sequence of dataset names registered
            in ``DatasetCatalog``.
        filter_empty: if true and the dicts carry "annotations", pass them through
            ``filter_images_with_only_crowd_annotations``.
        min_keypoints: if positive and the dicts carry "annotations", pass them
            through ``filter_images_with_few_keypoints`` with this threshold.
        proposal_files: must be None; precomputed proposals are not supported here.

    Returns:
        list[dict]: the dicts of all datasets, concatenated in the order of
        ``dataset_names``. Each dict gains a ``'dataset_source'`` key holding the
        position of its dataset in ``dataset_names``. The dicts returned by
        ``DatasetCatalog.get`` are modified in place.

    Raises:
        AssertionError: if ``proposal_files`` is given, if no dataset name is
            given, or if any dataset is empty.
    """
    # Reject unsupported arguments before doing any loading work.
    assert proposal_files is None, \
        "proposal_files is not supported by get_detection_dataset_dicts_with_source"

    # A bare string would otherwise be iterated character by character.
    if isinstance(dataset_names, str):
        dataset_names = [dataset_names]
    assert len(dataset_names)

    dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]
    # Validate every dataset before any of them is tagged.
    for dataset_name, dicts in zip(dataset_names, dataset_dicts):
        assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)

    for source_id, (dataset_name, dicts) in enumerate(zip(dataset_names, dataset_dicts)):
        for d in dicts:
            d['dataset_source'] = source_id

        if "annotations" in dicts[0]:
            try:
                class_names = MetadataCatalog.get(dataset_name).thing_classes
                check_metadata_consistency("thing_classes", dataset_name)
                print_instances_class_histogram(dicts, class_names)
            except AttributeError:  # class names are not available for this dataset
                pass

    dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))

    has_instances = "annotations" in dataset_dicts[0]
    if filter_empty and has_instances:
        dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)
    if min_keypoints > 0 and has_instances:
        dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)

    return dataset_dicts
|
|
|
|
|
|
class MultiDatasetSampler(Sampler):
    """
    Infinite sampler over the concatenation of several datasets.

    Every sample gets a weight that combines a per-dataset ratio with an optional
    repeat factor; indices are then drawn with replacement from those weights.
    Each rank consumes every ``world_size``-th index of the same seeded stream,
    starting at its own rank.
    """

    def __init__(
        self,
        dataset_dicts,
        dataset_ratio,
        use_rfs,
        dataset_ann,
        repeat_threshold=0.001,
        seed: Optional[int] = None,
    ):
        """
        Args:
            dataset_dicts: list of dicts, each with an integer ``'dataset_source'``
                key. The repeat-factor slicing below assumes the dicts are grouped
                by ``'dataset_source'`` in increasing order.
            dataset_ratio: one sampling ratio per dataset.
            use_rfs: one flag per dataset; true enables repeat-factor weighting
                for that dataset.
            dataset_ann: one string per dataset; ``'box'`` selects
                ``RepeatFactorTrainingSampler.repeat_factors_from_category_frequency``,
                anything else selects ``repeat_factors_from_tag_frequency``.
            repeat_threshold: threshold forwarded to the repeat-factor function.
            seed: seed of the index stream; when None it is taken from
                ``comm.shared_random_seed()``.
        """
        sizes = [0 for _ in range(len(dataset_ratio))]
        for d in dataset_dicts:
            sizes[d['dataset_source']] += 1
        logger.info('dataset sizes {}'.format(sizes))

        self.sizes = sizes
        assert len(dataset_ratio) == len(sizes), \
            'length of dataset ratio {} should be equal to number if dataset {}'.format(
                len(dataset_ratio), len(sizes)
            )
        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)
        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

        self.dataset_ids = torch.tensor(
            [d['dataset_source'] for d in dataset_dicts], dtype=torch.long)

        # Per-sample weight max(sizes) / size * ratio / sum(ratios): the total
        # weight of a dataset depends on its ratio, not on its size.
        dataset_weight = [
            torch.ones(s) * max(sizes) / s * r / sum(dataset_ratio)
            for r, s in zip(dataset_ratio, sizes)
        ]
        dataset_weight = torch.cat(dataset_weight)

        rfs_factors = []
        start = 0
        for i, size in enumerate(sizes):
            if use_rfs[i]:
                if dataset_ann[i] == 'box':
                    rfs_func = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency
                else:
                    rfs_func = repeat_factors_from_tag_frequency
                rfs_factor = rfs_func(
                    dataset_dicts[start: start + size],
                    repeat_thresh=repeat_threshold)
                # Rescale so the factors of one dataset sum to its size, which
                # leaves the dataset's total weight unchanged.
                rfs_factor = rfs_factor * (size / rfs_factor.sum())
            else:
                rfs_factor = torch.ones(size)
            rfs_factors.append(rfs_factor)
            start = start + size
        rfs_factors = torch.cat(rfs_factors)

        self.weights = dataset_weight * rfs_factors
        self.sample_epoch_size = len(self.weights)

    def __iter__(self):
        """
        Yield this rank's share of the infinite index stream.
        """
        start = self._rank
        yield from itertools.islice(
            self._infinite_indices(), start, None, self._world_size)

    def _infinite_indices(self):
        """
        Yield sample indices forever as Python ints, ``sample_epoch_size`` draws
        at a time, from a generator seeded with ``self._seed``.
        """
        g = torch.Generator()
        g.manual_seed(self._seed)
        while True:
            ids = torch.multinomial(
                self.weights, self.sample_epoch_size, generator=g,
                replacement=True)
            # Plain ints, not 0-d tensors, so the indices work with any
            # map-style dataset.
            yield from ids.tolist()
|
|
|
|
|
|
class MDAspectRatioGroupedDataset(torch.utils.data.IterableDataset):
    """
    Group samples into fixed-size batches such that every batch comes from one
    source dataset and one aspect-ratio orientation.
    """

    def __init__(self, dataset, batch_size, num_datasets):
        """
        Args:
            dataset: an iterable of dicts with "width", "height" and
                'dataset_source' keys.
            batch_size: number of samples per emitted batch.
            num_datasets: number of source datasets; two buckets are kept per dataset.
        """
        self.dataset = dataset
        self.batch_size = batch_size
        # Bucket 2 * k holds width > height samples of dataset k, bucket
        # 2 * k + 1 holds the rest.
        self._buckets = [[] for _ in range(2 * num_datasets)]

    def __iter__(self):
        for sample in self.dataset:
            width, height = sample["width"], sample["height"]
            orientation = int(not (width > height))
            pending = self._buckets[sample['dataset_source'] * 2 + orientation]
            pending.append(sample)
            if len(pending) != self.batch_size:
                continue
            # Emit a copy and empty the bucket in place for reuse.
            yield list(pending)
            pending.clear()
|
|
|
|
|
|
class DIFFMDAspectRatioGroupedDataset(torch.utils.data.IterableDataset):
    """
    Group samples into batches such that every batch comes from one source
    dataset and one aspect-ratio orientation, with a batch size chosen per dataset.
    """

    def __init__(self, dataset, batch_sizes, num_datasets):
        """
        Args:
            dataset: an iterable of dicts with "width", "height" and
                'dataset_source' keys.
            batch_sizes: sequence indexed by 'dataset_source' giving the batch
                size of each dataset.
            num_datasets: number of source datasets; two buckets are kept per dataset.
        """
        self.dataset = dataset
        self.batch_sizes = batch_sizes
        # Bucket 2 * k holds width > height samples of dataset k, bucket
        # 2 * k + 1 holds the rest.
        self._buckets = [[] for _ in range(2 * num_datasets)]

    def __iter__(self):
        for sample in self.dataset:
            width, height = sample["width"], sample["height"]
            orientation = int(not (width > height))
            pending = self._buckets[sample['dataset_source'] * 2 + orientation]
            pending.append(sample)
            if len(pending) != self.batch_sizes[sample['dataset_source']]:
                continue
            # Emit a copy and empty the bucket in place for reuse.
            yield list(pending)
            pending.clear()
|
|
|
|
|
|
def repeat_factors_from_tag_frequency(dataset_dicts, repeat_thresh):
    """
    Compute one repeat factor per image from the frequency of its image-level
    category tags.

    For each category c, f(c) is the fraction of images whose
    'pos_category_ids' contains c, and r(c) = max(1, sqrt(repeat_thresh / f(c))).
    The repeat factor of an image is the largest r(c) over its categories, or 1.0
    when it has none.

    Args:
        dataset_dicts: list of dicts, each with a 'pos_category_ids' iterable of
            hashable category ids.
        repeat_thresh: frequency threshold; categories with f(c) at or above it
            get a factor of 1.0.

    Returns:
        torch.Tensor: float32 tensor with one repeat factor per input dict, in
        input order.
    """
    num_images = len(dataset_dicts)

    # Count every category once per image, so that a repeated id within one
    # image cannot inflate its frequency.
    images_per_category = defaultdict(int)
    for dataset_dict in dataset_dicts:
        for cat_id in set(dataset_dict['pos_category_ids']):
            images_per_category[cat_id] += 1

    category_rep = {
        cat_id: max(1.0, math.sqrt(repeat_thresh / (count / num_images)))
        for cat_id, count in images_per_category.items()
    }

    rep_factors = [
        max({category_rep[cat_id] for cat_id in dataset_dict['pos_category_ids']}, default=1.0)
        for dataset_dict in dataset_dicts
    ]

    return torch.tensor(rep_factors, dtype=torch.float32)
|