# Mirror of https://github.com/YifanXu74/MQ-Det.git
# 507 lines, 24 KiB, Python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
|
import bisect
|
|
import copy
|
|
import logging
|
|
import os
|
|
|
|
import torch.utils.data
|
|
import torch.distributed as dist
|
|
from maskrcnn_benchmark.utils.comm import get_world_size
|
|
from maskrcnn_benchmark.utils.imports import import_file
|
|
|
|
from . import datasets as D
|
|
from . import samplers
|
|
|
|
from .collate_batch import BatchCollator, BBoxAugCollator
|
|
from .transforms import build_transforms
|
|
|
|
from transformers import AutoTokenizer
|
|
from .datasets.duplicate_dataset import create_duplicate_dataset
|
|
|
|
# Dataset names (after the TRAIN/TEST name suffix is appended) that have a
# dedicated duplication factor, mapped to the cfg.DATASETS attribute holding it.
_DATASET_COPY_CFG_KEYS = {
    "flickr30k_train": "FLICKR_COPY",
    "mixed_train": "MIXED_COPY",
    "mixed_train_no_coco": "MIXED_COPY",
    "COCO_odinw_train_8copy_dt_train": "COCO_COPY",
    "LVIS_odinw_train_8copy_dt_train": "LVIS_COPY",
    "object365_odinw_2copy_dt_train": "OBJECT365_COPY",
    "vg_odinw_clipped_8copy_dt_train": "VG_COPY",
    "vg_vgoi6_clipped_8copy_dt_train": "VG_COPY",
    "imagenetod_train_odinw_2copy_dt": "IN_COPY",
    "oi_train_odinw_dt": "OI_COPY",
}


def _num_dataset_copies(cfg, dataset_name, is_train):
    """Return the configured copy count for dataset_name (-1 means "do not duplicate").

    Datasets listed in _DATASET_COPY_CFG_KEYS use their dedicated setting; all
    others fall back to DATASETS.GENERAL_COPY (training) or
    DATASETS.GENERAL_COPY_TEST (testing).
    """
    cfg_key = _DATASET_COPY_CFG_KEYS.get(dataset_name)
    if cfg_key is None:
        cfg_key = "GENERAL_COPY" if is_train else "GENERAL_COPY_TEST"
    return getattr(cfg.DATASETS, cfg_key)


def build_dataset(cfg, dataset_list, transforms, dataset_catalog, is_train=True, class_concat=False, extra_args=None):
    """
    Build the datasets named in dataset_list.

    Arguments:
        cfg: config node; DATASETS, MODEL and TEST options are read from it.
        dataset_list (list[str]): Contains the names of the datasets, i.e.,
            coco_2014_train, coco_2014_val, etc. DATASETS.TRAIN_DATASETNAME_SUFFIX
            (training) or DATASETS.TEST_DATASETNAME_SUFFIX (testing) is appended
            to each name before it is looked up in dataset_catalog.
        transforms (callable): transforms to apply to each (image, target) sample
        dataset_catalog (DatasetCatalog): contains the information on how to
            construct a dataset.
        is_train (bool): whether to setup the dataset for training or testing
        class_concat (bool): if True, renumber every dataset's categories so the
            contiguous ids are unique across all datasets, starting at 1.
        extra_args (dict, optional): extra keyword arguments merged into every
            factory call (overriding catalog args). An "is_cache" entry that is
            truthy disables dataset duplication.

    Returns:
        list: for testing, one dataset per name; for training, a single-element
        list holding the dataset (wrapped in D.ConcatDataset when more than one).

    Raises:
        RuntimeError: if dataset_list is not a list or tuple.
    """
    if not isinstance(dataset_list, (list, tuple)):
        raise RuntimeError(
            "dataset_list should be a list of strings, got {}".format(dataset_list)
        )
    if extra_args is None:
        extra_args = {}

    datasets = []
    num_category = 1
    for dataset_id, dataset_name in enumerate(dataset_list, 1):
        if is_train:
            dataset_name = dataset_name + cfg.DATASETS.TRAIN_DATASETNAME_SUFFIX
        else:
            dataset_name = dataset_name + cfg.DATASETS.TEST_DATASETNAME_SUFFIX
        data = dataset_catalog.get(dataset_name)
        factory_name = data["factory"]
        factory = getattr(D, factory_name)
        args = data["args"]
        # for COCODataset, we want to remove images without annotations
        # during training
        if factory_name == "COCODataset":
            args["remove_images_without_annotations"] = is_train

        if factory_name == "PascalVOCDataset":
            args["use_difficult"] = not is_train
        if factory_name in ("VGTSVDataset", "CocoDetectionTSV", "ODTSVDataset"):
            args["extra_fields"] = ["class"]
            if cfg.MODEL.MASK_ON:
                args["extra_fields"].append("mask")

        if factory_name in ("CocoGrounding", "CocoGrounding_New", "CocoDetectionTSV", "CaptionTSV",
                            "MixedDataset", "FlickrDataset", "RefExpDataset", "GQADataset",
                            "PseudoData", "PhrasecutDetection"):
            args["return_masks"] = cfg.MODEL.MASK_ON
            args["return_tokens"] = True
            args["max_num_labels"] = cfg.TEST.MDETR_STYLE_AGGREGATE_CLASS_NUM
            args["max_query_len"] = cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN

        args["transforms"] = transforms
        args.update(extra_args)

        # Named num_copy (not "copy") so the imported copy module is not shadowed.
        num_copy = _num_dataset_copies(cfg, dataset_name, is_train)
        # "is_cache" normally comes from extra_args; treat a missing key as False
        # instead of raising KeyError.
        if num_copy != -1 and not args.get("is_cache", False):
            duplicated_factory = create_duplicate_dataset(factory)
            dataset = duplicated_factory(copy=num_copy, **args)
        else:
            # make dataset from factory
            dataset = factory(**args)

        print(dataset_name, 'has the {} data points'.format(len(dataset)), data["factory"])

        if class_concat:
            # Give this dataset's categories the next block of contiguous ids.
            category = list(dataset.contiguous_category_id_to_json_id.values())
            dataset.contiguous_category_id_to_json_id = {}
            dataset.json_category_id_to_contiguous_id = {}
            for contiguous_id, cat in enumerate(category, start=num_category):
                dataset.json_category_id_to_contiguous_id[cat] = contiguous_id
                dataset.contiguous_category_id_to_json_id[contiguous_id] = cat
            num_category += len(category)
            print("Found {} #category after group {}, concating ...".format(num_category, dataset_id))
        datasets.append(dataset)

    # for testing, return a list of datasets
    if not is_train:
        return datasets

    # for training, concatenate all datasets into a single one
    dataset = datasets[0]
    if len(datasets) > 1:
        dataset = D.ConcatDataset(datasets)

    return [dataset]
|
|
|
|
|
|
def build_dataset_by_group(dataset_list, transforms, dataset_catalog, is_train=True, class_by_group=True,
                           class_concat=False, extra_args=None):
    """
    Build datasets from groups of dataset names.

    Arguments:
        dataset_list (list[list[str]]): groups of dataset names, e.g.
            [["coco_2014_train"], ["coco_2014_val"]]; names are looked up in
            dataset_catalog as-is (no name suffix is appended).
        transforms (callable): transforms to apply to each (image, target) sample
        dataset_catalog (DatasetCatalog): contains the information on how to
            construct a dataset.
        is_train (bool): whether to setup the dataset for training or testing
        class_by_group (bool): if True, the contiguous category 1 of every
            dataset whose factory is not "Background" is re-keyed to the
            1-based group index (this assumes one class per group).
        class_concat (bool): if True, renumber the categories of all datasets
            so that contiguous ids are unique across groups, starting at 1.
        extra_args (dict, optional): extra keyword arguments merged into every
            factory call.

    Returns:
        Testing: a flat list of every dataset from every group.
        Training with class_concat: [D.ConcatDataset of the per-group ConcatDatasets].
        Training otherwise: a list with one D.ConcatDataset per group.

    Raises:
        RuntimeError: if dataset_list is not a list or tuple.
    """
    if not isinstance(dataset_list, (list, tuple)):
        raise RuntimeError(
            "dataset_list should be a list of strings, got {}".format(dataset_list)
        )
    if extra_args is None:
        extra_args = {}

    num_category = 1
    grouped_datasets = []
    for group_id, group in enumerate(dataset_list, 1):
        datasets = []
        for dataset_name in group:
            data = dataset_catalog.get(dataset_name)
            factory = getattr(D, data["factory"])
            args = data["args"]
            # for COCODataset, we want to remove images without annotations
            # during training
            if data["factory"] == "COCODataset":
                args["remove_images_without_annotations"] = is_train
            if data["factory"] == "PascalVOCDataset":
                args["use_difficult"] = not is_train
            args["transforms"] = transforms
            args.update(extra_args)
            # make dataset from factory
            dataset = factory(**args)

            # check if dataset is grouped by task, assume one class per task
            if class_by_group and data["factory"] != "Background":
                category = dataset.contiguous_category_id_to_json_id[1]
                del dataset.contiguous_category_id_to_json_id[1]
                dataset.json_category_id_to_contiguous_id[category] = group_id
                dataset.contiguous_category_id_to_json_id[group_id] = category

            datasets.append(dataset)

        if class_concat:
            # Give each dataset of this group the next block of contiguous ids.
            for dataset in datasets:
                category = list(dataset.contiguous_category_id_to_json_id.values())
                dataset.contiguous_category_id_to_json_id = {}
                dataset.json_category_id_to_contiguous_id = {}
                for contiguous_id, cat in enumerate(category, start=num_category):
                    dataset.json_category_id_to_contiguous_id[cat] = contiguous_id
                    dataset.contiguous_category_id_to_json_id[contiguous_id] = cat
                num_category += len(category)
            print("Found {} #category after group {}, concating ...".format(num_category, group_id))

        if is_train:
            datasets = D.ConcatDataset(datasets)

        grouped_datasets.append(datasets)

    # for testing, return a list of datasets
    if not is_train:
        datasets = [dataset for group in grouped_datasets for dataset in group]
        return datasets
    if class_concat:
        grouped_datasets = D.ConcatDataset(grouped_datasets)
        return [grouped_datasets]

    # for training without class_concat, return one concatenated dataset per group
    return grouped_datasets
|
|
|
|
|
|
def make_data_sampler(dataset, shuffle, distributed, num_replicas=None, rank=None, use_random_seed=True):
    """Create the index sampler for ``dataset``.

    When ``distributed`` is true, a ``samplers.DistributedSampler`` is returned
    (``num_replicas``, ``rank`` and ``use_random_seed`` are forwarded to it).
    Otherwise a torch ``RandomSampler`` (``shuffle`` true) or
    ``SequentialSampler`` (``shuffle`` false) over ``dataset`` is returned.
    """
    if not distributed:
        sampler_cls = (
            torch.utils.data.sampler.RandomSampler
            if shuffle
            else torch.utils.data.sampler.SequentialSampler
        )
        return sampler_cls(dataset)
    return samplers.DistributedSampler(
        dataset,
        shuffle=shuffle,
        num_replicas=num_replicas,
        rank=rank,
        use_random=use_random_seed,
    )
|
|
|
|
|
|
def _quantize(x, bins):
|
|
bins = copy.copy(bins)
|
|
bins = sorted(bins)
|
|
quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
|
|
return quantized
|
|
|
|
|
|
def _compute_aspect_ratios(dataset):
|
|
aspect_ratios = []
|
|
for i in range(len(dataset)):
|
|
img_info = dataset.get_img_info(i)
|
|
aspect_ratio = float(img_info["height"]) / float(img_info["width"])
|
|
aspect_ratios.append(aspect_ratio)
|
|
return aspect_ratios
|
|
|
|
|
|
def make_batch_data_sampler(
    dataset, sampler, aspect_grouping, images_per_batch, num_iters=None, start_iter=0, drop_last=False
):
    """Wrap ``sampler`` into a batch sampler.

    If ``aspect_grouping`` is truthy (a single edge or a list/tuple of edges),
    images are bucketed by quantized height/width ratio and batched per bucket
    with ``samplers.GroupedBatchSampler``. Otherwise a plain torch
    ``BatchSampler`` is used. ``drop_last`` is forwarded to either one.
    When ``num_iters`` is given, the result is further wrapped in
    ``samplers.IterationBasedBatchSampler`` starting at ``start_iter``.
    """
    if not aspect_grouping:
        batch_sampler = torch.utils.data.sampler.BatchSampler(
            sampler, images_per_batch, drop_last=drop_last
        )
    else:
        edges = aspect_grouping if isinstance(aspect_grouping, (list, tuple)) else [aspect_grouping]
        group_ids = _quantize(_compute_aspect_ratios(dataset), edges)
        batch_sampler = samplers.GroupedBatchSampler(
            sampler, group_ids, images_per_batch, drop_uneven=drop_last
        )

    if num_iters is None:
        return batch_sampler
    return samplers.IterationBasedBatchSampler(batch_sampler, num_iters, start_iter)
|
|
|
|
def _expand_bing_caption_datasets(dataset_list, bing_index_list):
    """Replace umbrella "bing_caption_train*" names with their per-index shard names.

    For each umbrella name present in dataset_list, its first occurrence is
    removed and "bing_caption_{index}_train[...]" is appended for every index
    in bing_index_list. If bing_index_list is empty or no umbrella name is
    present, dataset_list is returned unchanged; otherwise a tuple is returned.
    """
    # Haotian: expand bing dataset
    for umbrella, shard_fmt in (
        ("bing_caption_train", "bing_caption_{}_train"),
        ("bing_caption_train_no_coco", "bing_caption_{}_train_no_coco"),
    ):
        if umbrella in dataset_list and len(bing_index_list) > 0:
            expanded = list(dataset_list)
            expanded.remove(umbrella)
            expanded.extend(shard_fmt.format(bing_index) for bing_index in bing_index_list)
            dataset_list = tuple(expanded)
    return dataset_list


def _build_dataset_extra_args(cfg, is_train, is_cache):
    """Collect the keyword arguments forwarded to every dataset factory.

    Most options are only forwarded when training (is_train) or when training
    or caching (is_train or is_cache). The returned dict also carries
    "is_cache" and a "tokenizer" built from MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE.
    """
    train_or_cache = is_train or is_cache
    extra_args = {}
    if is_train and cfg.DATASETS.USE_CROWD:
        extra_args['ignore_crowd'] = False
    if is_train and cfg.DATASETS.MAX_BOX > 0:
        extra_args['max_box'] = cfg.DATASETS.MAX_BOX
    if train_or_cache and cfg.DATASETS.FEW_SHOT > 0:
        extra_args['few_shot'] = cfg.DATASETS.FEW_SHOT
    if is_train and cfg.DATASETS.SHUFFLE_SEED != 0:
        extra_args['shuffle_seed'] = cfg.DATASETS.SHUFFLE_SEED

    # od to grounding
    if train_or_cache and cfg.DATASETS.RANDOM_SAMPLE_NEG > 0:
        extra_args['random_sample_negative'] = cfg.DATASETS.RANDOM_SAMPLE_NEG
    if is_train and cfg.DATASETS.ADD_DET_PROMPT:
        extra_args["add_detection_prompt"] = True
    if is_train and cfg.DATASETS.USE_OD_AUG:
        extra_args["use_od_data_aug"] = True
    if train_or_cache and cfg.DATASETS.DISABLE_SHUFFLE:
        extra_args["disable_shuffle"] = True
    if cfg.DATASETS.ONE_HOT:
        extra_args["one_hot"] = True
    if is_train and len(cfg.DATASETS.PROMPT_VERSION) > 0:
        extra_args["prompt_engineer_version"] = cfg.DATASETS.PROMPT_VERSION
    if train_or_cache and len(cfg.DATASETS.CONTROL_PROB) == 4:
        extra_args["control_probabilities"] = cfg.DATASETS.CONTROL_PROB
    if is_train and cfg.DATASETS.DISABLE_CLIP_TO_IMAGE:
        extra_args["disable_clip_to_image"] = cfg.DATASETS.DISABLE_CLIP_TO_IMAGE
    if is_train and cfg.DATASETS.NO_MINUS_ONE_FOR_ONE_HOT:
        extra_args["no_minus_one_for_one_hot"] = cfg.DATASETS.NO_MINUS_ONE_FOR_ONE_HOT
    if train_or_cache:
        extra_args["separation_tokens"] = cfg.DATASETS.SEPARATION_TOKENS

    # caption
    if train_or_cache and cfg.DATASETS.CAPTION_MIN_BOX > 0:
        extra_args["caption_min_box"] = cfg.DATASETS.CAPTION_MIN_BOX
    if is_train and cfg.DATASETS.REPLACE_CLEAN_LABEL:
        extra_args["replace_clean_label"] = True
    if is_train and cfg.DATASETS.FURTHER_SCREEN:
        extra_args["further_screen"] = True
    if is_train and cfg.DATASETS.CAPTION_CONF > 0.0:
        extra_args["caption_conf"] = cfg.DATASETS.CAPTION_CONF
    if train_or_cache:
        extra_args["caption_nms"] = cfg.DATASETS.CAPTION_NMS
    if is_train and cfg.DATASETS.PACK_RANDOM_CAPTION_NUMBER > 0:
        extra_args["pack_random_caption_number"] = cfg.DATASETS.PACK_RANDOM_CAPTION_NUMBER
    if is_train and cfg.DATASETS.INFERENCE_CAPTION:
        extra_args["inference_caption"] = True
    if is_train and cfg.DATASETS.SAMPLE_NEGATIVE_FOR_GROUNDING_DATA > 0:
        extra_args["sample_negative_for_grounding_data"] = cfg.DATASETS.SAMPLE_NEGATIVE_FOR_GROUNDING_DATA
    if is_train and cfg.DATASETS.RANDOM_PACK_PROB > 0:
        extra_args["random_pack_prob"] = cfg.DATASETS.RANDOM_PACK_PROB
    if is_train and cfg.DATASETS.NO_RANDOM_PACK_PROBABILITY > 0:
        extra_args["no_random_pack_probability"] = cfg.DATASETS.NO_RANDOM_PACK_PROBABILITY
    if train_or_cache:
        fuse_cfg = cfg.MODEL.DYHEAD.FUSE_CONFIG
        extra_args["safeguard_positive_caption"] = cfg.DATASETS.SAFEGUARD_POSITIVE_CAPTION
        extra_args["local_debug"] = cfg.DATASETS.LOCAL_DEBUG
        extra_args["no_mask_for_od"] = fuse_cfg.NO_MASK_FOR_OD
        extra_args["no_mask_for_gold"] = fuse_cfg.NO_MASK_FOR_GOLD
        extra_args["mlm_obj_for_only_positive"] = fuse_cfg.MLM_OBJ_FOR_ONLY_POSITIVE
    if cfg.DATASETS.OVERRIDE_CATEGORY and cfg.DATASETS.USE_OVERRIDE_CATEGORY:
        extra_args["override_category"] = cfg.DATASETS.OVERRIDE_CATEGORY
    if train_or_cache:
        extra_args["caption_format_version"] = cfg.DATASETS.CAPTION_FORMAT_VERSION
        extra_args["special_safeguard_for_coco_grounding"] = cfg.DATASETS.SPECIAL_SAFEGUARD_FOR_COCO_GROUNDING
        extra_args["diver_box_for_vqa"] = cfg.DATASETS.DIVER_BOX_FOR_VQA
        extra_args["exclude_crowd"] = cfg.DATASETS.EXCLUDE_CROWD
    extra_args["is_cache"] = bool(is_cache)
    extra_args["caption_prompt"] = cfg.DATASETS.CAPTION_PROMPT
    extra_args["use_caption_prompt"] = cfg.DATASETS.USE_CAPTION_PROMPT
    extra_args['sep_at_last'] = cfg.DATASETS.SEP_AT_LAST
    extra_args['add_normed_cxcy'] = cfg.DATASETS.ADD_NORMED_CXCY
    extra_args["custom_category_ids"] = cfg.VISION_QUERY.CUSTOM_CAT_IDS
    # NOTE: the key spelling "cumtom_ids" is what the dataset factories receive; keep it.
    extra_args["cumtom_ids"] = cfg.VISION_QUERY.CUSTOM_DATA_IDS

    if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip":
        from transformers import CLIPTokenizerFast
        if cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS:
            # The mask token string is an existing CLIP BPE vocabulary entry; do not "fix" its characters.
            extra_args["tokenizer"] = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32", from_slow=True, mask_token='ðŁĴij</w>')
        else:
            extra_args["tokenizer"] = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32", from_slow=True)
    else:
        extra_args['tokenizer'] = AutoTokenizer.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE)
    return extra_args


def make_data_loader(cfg, is_train=True, is_distributed=False, num_replicas=None, rank=None, start_iter=0, is_cache=False):
    """
    Build the data loader(s) described by cfg.

    Arguments:
        cfg: config node (defrost()/freeze() are called on it). During training
            this function writes SOLVER.MAX_ITER and SOLVER.DATASET_LENGTH (when
            SOLVER.MAX_EPOCH > 0) and appends to SOLVER.MULTI_MAX_ITER (when
            SOLVER.MULTI_MAX_EPOCH is set).
        is_train (bool): build loaders over DATASETS.TRAIN with shuffling,
            drop_last and an iteration-based batch sampler.
        is_distributed (bool): use a distributed sampler.
        num_replicas (int, optional): number of processes; defaults to get_world_size().
        rank (int, optional): forwarded to the distributed sampler.
        start_iter (int): starting iteration of the training batch sampler
            (forced to 0 when not training).
        is_cache (bool): when not training, build loaders over DATASETS.TRAIN
            with SOLVER.IMS_PER_BATCH images per process, no shuffling, no
            iteration limit and transforms built with is_train=False.

    Returns:
        A single DataLoader when training without DATASETS.ALTERNATIVE_TRAINING
        or DATASETS.MULTISTAGE_TRAINING; otherwise a list of DataLoaders.
    """
    num_gpus = num_replicas or get_world_size()

    if is_train:
        images_per_batch = cfg.SOLVER.IMS_PER_BATCH
        # The message is parenthesised so both literals (and .format) belong to
        # the assert; previously the second line was a separate, discarded statement.
        assert images_per_batch % num_gpus == 0, (
            "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number "
            "of GPUs ({}) used.".format(images_per_batch, num_gpus)
        )
        images_per_gpu = images_per_batch // num_gpus
        shuffle = True
        num_iters = cfg.SOLVER.MAX_ITER
    elif is_cache:
        # The whole SOLVER batch goes to each process (no division by num_gpus).
        images_per_gpu = cfg.SOLVER.IMS_PER_BATCH
        shuffle = False
        num_iters = None
        start_iter = 0
    else:
        images_per_batch = cfg.TEST.IMS_PER_BATCH
        assert images_per_batch % num_gpus == 0, (
            "TEST.IMS_PER_BATCH ({}) must be divisible by the number "
            "of GPUs ({}) used.".format(images_per_batch, num_gpus)
        )
        images_per_gpu = images_per_batch // num_gpus
        # Distributed evaluation passes shuffle=True to the distributed sampler.
        shuffle = bool(is_distributed)
        num_iters = None
        start_iter = 0

    if images_per_gpu > 1:
        logger = logging.getLogger(__name__)
        logger.warning(
            "When using more than one image per GPU you may encounter "
            "an out-of-memory (OOM) error if your GPU does not have "
            "sufficient memory. If this happens, you can reduce "
            "SOLVER.IMS_PER_BATCH (for training) or "
            "TEST.IMS_PER_BATCH (for inference). For training, you must "
            "also adjust the learning rate and schedule length according "
            "to the linear scaling rule. See for example: "
            "https://github.com/facebookresearch/Detectron/blob/master/configs/getting_started/tutorial_1gpu_e2e_faster_rcnn_R-50-FPN.yaml#L14"
        )

    # group images which have similar aspect ratio. In this case, we only
    # group in two cases: those with width / height > 1, and the other way around,
    # but the code supports more general grouping strategy
    aspect_grouping = [1] if cfg.DATALOADER.ASPECT_RATIO_GROUPING else []

    paths_catalog = import_file(
        "maskrcnn_benchmark.config.paths_catalog", cfg.PATHS_CATALOG, True
    )

    DatasetCatalog = paths_catalog.DatasetCatalog
    # Register datasets declared in the config under their suffixed names.
    for new_dataset in cfg.DATASETS.REGISTER:
        attrs = dict(cfg.DATASETS.REGISTER[new_dataset])
        if is_train or is_cache:
            new_dataset = new_dataset + cfg.DATASETS.TRAIN_DATASETNAME_SUFFIX
        else:
            new_dataset = new_dataset + cfg.DATASETS.TEST_DATASETNAME_SUFFIX
        DatasetCatalog.set(new_dataset, attrs)

    dataset_list = cfg.DATASETS.TRAIN if is_train or is_cache else cfg.DATASETS.TEST
    dataset_list = _expand_bing_caption_datasets(dataset_list, cfg.DATASETS.BING_INDEX_LIST)

    print("The combined datasets are: {}.".format(dataset_list))

    transforms = None if not is_train and cfg.TEST.USE_MULTISCALE else build_transforms(cfg, is_train)

    extra_args = _build_dataset_extra_args(cfg, is_train, is_cache)

    if isinstance(dataset_list[0], (tuple, list)):
        datasets = build_dataset_by_group(dataset_list, transforms, DatasetCatalog, (is_train or is_cache),
                                          class_by_group=cfg.DATASETS.ALTERNATIVE_TRAINING,
                                          class_concat=cfg.DATASETS.CLASS_CONCAT,
                                          extra_args=extra_args)
    else:
        datasets = build_dataset(cfg, dataset_list, transforms, DatasetCatalog, (is_train or is_cache),
                                 class_concat=cfg.DATASETS.CLASS_CONCAT,
                                 extra_args=extra_args)

    data_loaders = []
    for di, dataset in enumerate(datasets):
        if is_train and cfg.SOLVER.MAX_EPOCH > 0:
            num_iters = cfg.SOLVER.MAX_EPOCH * len(dataset) // cfg.SOLVER.IMS_PER_BATCH
            print("Number of iterations are {}".format(num_iters))
            cfg.defrost()
            cfg.SOLVER.MAX_ITER = num_iters
            cfg.SOLVER.DATASET_LENGTH = len(dataset)
            cfg.freeze()
        if is_train and cfg.SOLVER.MULTI_MAX_EPOCH:
            num_iters = None
            cfg.defrost()
            cfg.SOLVER.MULTI_MAX_ITER += (cfg.SOLVER.MULTI_MAX_EPOCH[di] * len(dataset) // cfg.SOLVER.IMS_PER_BATCH,)
            cfg.freeze()

        if is_train and cfg.DATALOADER.DISTRIBUTE_CHUNK_AMONG_NODE:
            from .datasets.custom_distributed_sampler import DistributedSamplerChunkByNode
            # Only datasets whose name contains "bing_caption" are chunked by node.
            chunk_or_not = ["bing_caption" in name for name in dataset_list]
            # Assumes dataset is a ConcatDataset with one member per dataset_list entry.
            assert len(chunk_or_not) == len(dataset.datasets)
            # Example layout in the comments below: 4 nodes, each with 8 GPUs.
            num_nodes = int(os.getenv('NODE_COUNT', os.getenv('OMPI_COMM_WORLD_SIZE', 1)))
            local_size = cfg.num_gpus // num_nodes
            node_rank = int(os.getenv('NODE_RANK', os.getenv('OMPI_COMM_WORLD_RANK', 0)))
            local_rank = cfg.local_rank
            sampler = DistributedSamplerChunkByNode(
                dataset=dataset,
                all_datasets=dataset.datasets,  # Assuming dataset is a ConcatDataset instance
                chunk_or_not=chunk_or_not,
                num_replicas=cfg.num_gpus,  # total GPU number, e.g., 32
                rank=dist.get_rank(),  # Global Rank, e.g., 0~31
                node_rank=node_rank,  # Node Rank, e.g., 0~3
                node_number=num_nodes,  # how many node e.g., 4
                process_num_per_node=local_size,  # e.g., 8
                rank_within_local_node=local_rank,  # e.g., 0~7
            )
        else:
            sampler = make_data_sampler(dataset, shuffle, is_distributed, num_replicas=num_replicas, rank=rank,
                                        use_random_seed=cfg.DATALOADER.USE_RANDOM_SEED)
        batch_sampler = make_batch_data_sampler(
            dataset, sampler, aspect_grouping, images_per_gpu, num_iters, start_iter, drop_last=is_train
        )
        collator = BBoxAugCollator() if not is_train and cfg.TEST.USE_MULTISCALE else BatchCollator(
            cfg.DATALOADER.SIZE_DIVISIBILITY)
        num_workers = cfg.DATALOADER.NUM_WORKERS
        data_loader = torch.utils.data.DataLoader(
            dataset,
            num_workers=num_workers,
            batch_sampler=batch_sampler,
            collate_fn=collator,
        )
        data_loaders.append(data_loader)

    if is_train and cfg.SOLVER.MULTI_MAX_EPOCH:
        cfg.defrost()
        cfg.SOLVER.MULTI_MAX_ITER += (
            cfg.SOLVER.MULTI_MAX_EPOCH[-1] * min([len(dataset) // cfg.SOLVER.IMS_PER_BATCH for dataset in datasets]),)
        cfg.freeze()

    if is_train and not cfg.DATASETS.ALTERNATIVE_TRAINING and not cfg.DATASETS.MULTISTAGE_TRAINING:
        # during training, a single (possibly concatenated) data_loader is returned
        assert len(data_loaders) == 1
        return data_loaders[0]

    return data_loaders
|