# MQ-Det/maskrcnn_benchmark/data/datasets/custom_distributed_sampler.py

import math
import random
from typing import Optional

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import Sampler, Dataset


class DistributedSamplerChunkByNode(Sampler):
    def __init__(self,
                 dataset,
                 all_datasets,
                 chunk_or_not,
                 num_replicas: Optional[int] = None,
                 rank: Optional[int] = None,
                 shuffle: bool = True,
                 seed: int = 0,
                 drop_last: bool = False,
                 node_rank=0,
                 node_number=1,
                 process_num_per_node=1,
                 rank_within_local_node=0) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1))
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.node_number = node_number
        self.node_rank = node_rank
        self.chunk_or_not = chunk_or_not
        self.process_num_per_node = process_num_per_node
        self.rank_within_local_node = rank_within_local_node
        assert (self.process_num_per_node * self.node_number == self.num_replicas)
        # 1. Divide the datasets into two groups: chunked datasets (each node
        #    only sees its own slice) and normal datasets (shared by everyone).
        normal_datasets = []
        chunked_datasets = []
        for dataset_i, chunk_i in zip(all_datasets, chunk_or_not):
            if chunk_i:
                chunked_datasets.append(dataset_i)
            else:
                normal_datasets.append(dataset_i)

        # 2. Calculate dataset sizes. The normal part follows the conventional
        #    distributed sampler.
        self.normal_dataset_size = sum([len(i) for i in normal_datasets])

        # 3. Divide the chunked datasets evenly across nodes. The indices assume
        #    the concatenated `dataset` places the normal datasets first, followed
        #    by the chunked datasets in order, so this node's chunk occupies the
        #    half-open index range [start, end).
        self.current_node_start_range = -1
        self.current_node_end_range = -1
        assert (len(chunked_datasets) >= self.node_number)
        chunk_size = len(chunked_datasets) // self.node_number
        current_example_num = self.normal_dataset_size
        for index in range(len(chunked_datasets)):
            if index == self.node_rank * chunk_size:
                self.current_node_start_range = current_example_num
            current_example_num += len(chunked_datasets[index])
            if index == (self.node_rank + 1) * chunk_size - 1:
                self.current_node_end_range = current_example_num
        if self.current_node_end_range == -1:  # boundary: the last node absorbs any leftover datasets
            self.current_node_end_range = current_example_num
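
        # Worked example (illustrative numbers, not from the original code):
        # with normal_dataset_size = 5000, chunked dataset lengths
        # [1000, 2000, 3000, 4000] and node_number = 2, chunk_size is 2, so
        # node 0 owns indices [5000, 8000) and node 1 owns [8000, 15000).
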
        self.drop_last = drop_last
        # If the dataset length is evenly divisible by the number of replicas,
        # there is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
            # Split to the nearest available length that is evenly divisible.
            # This ensures each rank receives the same amount of data when
            # using this sampler.
            self.num_samples = math.ceil(
                # `type: ignore` is required because Dataset cannot provide a default __len__
                # see NOTE in pytorch/torch/utils/data/sampler.py
                (len(self.dataset) - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)  # type: ignore[arg-type]
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self):
        # Stage 1: sample from the normal (non-chunked) part; these indices are
        # distributed among all processes, as in the conventional DistributedSampler.
        indices = self.generate_indices_within_range_with_rank(
            seed=self.seed,
            epoch=self.epoch,
            # NOTE: distribute among all processes
            process_num=self.num_replicas,
            rank=self.rank,
            generate_length=-1,
            valid_indices=list(range(self.normal_dataset_size)),
            prefix="Normal "
        )
        # Stage 2: top up to num_samples from this node's chunk; these indices
        # are distributed only among the processes on the local node.
        addition_indices = self.generate_indices_within_range_with_rank(
            seed=self.seed,
            epoch=self.epoch,
            # NOTE: very important arguments; distribute among local processes only
            process_num=self.process_num_per_node,
            rank=self.rank_within_local_node,
            generate_length=self.num_samples - len(indices),
            valid_indices=list(range(self.current_node_start_range, self.current_node_end_range)),
            prefix="Distribute "
        )
        indices.extend(addition_indices)
        random.seed(self.seed + self.epoch + 10 * self.rank)  # set the seed to maximize randomness
        random.shuffle(indices)  # reshuffle
        assert len(indices) == self.num_samples
        return iter(indices)
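
    # Worked example for __iter__ (illustrative numbers, not from the original
    # code): with 16 replicas (2 nodes x 8 processes), a normal part of 8000
    # examples, a 12000-example chunk per node, and num_samples = 2000, stage 1
    # yields ceil((8000 - 16) / 16) = 499 indices per rank and stage 2 must
    # supply the remaining 1501 from the local chunk. Among 8 local processes
    # the strided subsample gives 1499 indices, so the final 2 are drawn
    # randomly (with replacement) from the chunk's range.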

    def generate_indices_within_range_with_rank(self, seed, epoch, process_num, generate_length, valid_indices,
                                                rank=-1, shuffle=True, prefix=""):
        '''
        Usage scenario: sample e.g. 2500 examples out of 10000 for this process,
        without overlapping the examples sampled by the other three processes.
        Modified from DistributedSampler.
        '''
        dataset_size = len(valid_indices)
        if shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(seed + epoch)
            indices = torch.randperm(dataset_size, generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(dataset_size))  # type: ignore[arg-type]
        indices = [valid_indices[i] for i in indices]
        num_samples_normal = math.ceil(
            (dataset_size - process_num) / process_num  # type: ignore[arg-type]
        )
        # remove the tail of the data to make it evenly divisible
        indices = indices[:num_samples_normal * process_num]
        print("\n")
        print(prefix,
              "Global Rank {} Local Rank {} generate_length {} valid_indices {} process_num {} indices_before_subsample {} {}".format(
                  self.rank, rank, generate_length, len(valid_indices), process_num, len(indices), indices[:10]))
        # subsample: each rank keeps every process_num-th index, starting at its own offset
        indices = indices[rank:num_samples_normal * process_num:process_num]
        print(prefix,
              "Global Rank {} Local Rank {} generate_length {} valid_indices {} process_num {} indices_after_subsample {} {}".format(
                  self.rank, rank, generate_length, len(valid_indices), process_num, len(indices), indices[:10]))
        print("\n")
        if generate_length != -1:
            if len(indices) > generate_length:
                # too many: truncate to the requested length
                indices = indices[:generate_length]
            else:
                # too few: pad by sampling (with replacement) from the valid range
                indices.extend(np.random.choice(valid_indices, generate_length - len(indices)).tolist())
        return indices
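
    # Tracing the docstring scenario (assumed numbers): for 10000 valid indices
    # and process_num = 4, num_samples_normal = ceil((10000 - 4) / 4) = 2499,
    # the shuffled list is cut to 9996 entries, and rank r keeps every 4th
    # entry starting at offset r, i.e. 2499 disjoint indices per process. If
    # generate_length is 2500, one extra index is then drawn at random.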

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        r"""
        Set the epoch for this sampler. When :attr:`shuffle=True`, this ensures all
        replicas use a different random ordering for each epoch. Otherwise, the next
        iteration of this sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
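

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). The process topology
# (2 nodes x 2 processes), the toy `RangeDataset`, and the dataset sizes are
# all illustrative assumptions; real training would pass detection datasets
# and derive the rank/node arguments from the launcher's environment.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import ConcatDataset

    class RangeDataset(Dataset):
        """Hypothetical toy dataset: item i is simply the integer i."""

        def __init__(self, n):
            self.n = n

        def __len__(self):
            return self.n

        def __getitem__(self, i):
            return i

    normal = RangeDataset(100)                       # shared by all processes
    chunked = [RangeDataset(40), RangeDataset(60)]   # one chunk per node
    full = ConcatDataset([normal] + chunked)         # normal part first, then chunks

    sampler = DistributedSamplerChunkByNode(
        dataset=full,
        all_datasets=[normal] + chunked,
        chunk_or_not=[False, True, True],
        num_replicas=4,              # 2 nodes x 2 processes per node
        rank=0,                      # global rank of this process
        node_rank=0,
        node_number=2,
        process_num_per_node=2,
        rank_within_local_node=0,
    )
    sampler.set_epoch(0)
    indices = list(iter(sampler))
    print(len(indices), "indices for rank 0")        # == len(sampler)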