code formatting
parent
cd71c12a94
commit
ec3a59d23e
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .batch_balance import BatchBalanceSampler
|
||||
from .repeat_aug import RepeatAugSampler
|
||||
from .sequential import SequentialSampler
|
||||
from .batch_balance import BatchBalanceSampler
|
||||
|
||||
__all__ = ['RepeatAugSampler', 'SequentialSampler', 'BatchBalanceSampler']
|
||||
|
|
|
@ -1,16 +1,21 @@
|
|||
from typing import Iterator
|
||||
from mmengine.dataset import DefaultSampler
|
||||
from mmpretrain.registry import DATA_SAMPLERS
|
||||
import numpy as np
|
||||
import math
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import collections
|
||||
import math
|
||||
from typing import Iterator
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine.dataset import DefaultSampler
|
||||
|
||||
from mmpretrain.registry import DATA_SAMPLERS
|
||||
|
||||
|
||||
@DATA_SAMPLERS.register_module()
|
||||
class BatchBalanceSampler(DefaultSampler):
|
||||
"""
|
||||
refer: https://github.com/KevinMusgrave/pytorch-metric-learning/blob/v2.3.0/src/pytorch_metric_learning/samplers/num_per_class_sampler.py
|
||||
refer: https://github.com/KevinMusgrave/pytorch-metric-learning/
|
||||
blob/v2.3.0/src/pytorch_metric_learning/samplers/num_per_class_sampler.py
|
||||
|
||||
|
||||
At every iteration, this will return m samples per class. For example,
|
||||
if dataloader's batchsize is 100, and m = 5, then 20 classes with 5 samples
|
||||
|
@ -31,9 +36,8 @@ class BatchBalanceSampler(DefaultSampler):
|
|||
super().__init__(**kwargs)
|
||||
self.num_per_class = int(num_per_class)
|
||||
self.labels_to_indices = self.get_labels_to_indices(
|
||||
self.dataset.get_gt_labels()
|
||||
)
|
||||
self.labels = list(self.labels_to_indices.keys()) # labels index list
|
||||
self.dataset.get_gt_labels())
|
||||
self.labels = list(self.labels_to_indices.keys()) # labels index list
|
||||
self.length_of_single_pass = self.num_per_class * len(self.labels)
|
||||
|
||||
self.total_size = len(self.dataset)
|
||||
|
@ -41,7 +45,8 @@ class BatchBalanceSampler(DefaultSampler):
|
|||
if self.length_of_single_pass < self.total_size:
|
||||
self.total_size -= (self.total_size) % (self.length_of_single_pass)
|
||||
# The number of samples in this rank
|
||||
self.num_samples = math.ceil((self.total_size - self.rank) / self.world_size)
|
||||
self.num_samples = math.ceil(
|
||||
(self.total_size - self.rank) / self.world_size)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""The number of samples in this rank."""
|
||||
|
@ -55,13 +60,13 @@ class BatchBalanceSampler(DefaultSampler):
|
|||
np.random.shuffle(self.labels)
|
||||
curr_label_set = self.labels
|
||||
for label in curr_label_set:
|
||||
t = self.labels_to_indices[label] # List of all sample indexes corresponding to the current label
|
||||
indices[i : i + self.num_per_class] = self.safe_random_choice(
|
||||
t, size=self.num_per_class
|
||||
)
|
||||
# List of all sample indexes corresponding to the current label
|
||||
t = self.labels_to_indices[label]
|
||||
indices[i:i + self.num_per_class] = self.safe_random_choice(
|
||||
t, size=self.num_per_class)
|
||||
i += self.num_per_class
|
||||
# subsample
|
||||
indices = indices[self.rank : self.total_size : self.world_size]
|
||||
indices = indices[self.rank:self.total_size:self.world_size]
|
||||
|
||||
return iter(indices)
|
||||
|
||||
|
@ -70,8 +75,9 @@ class BatchBalanceSampler(DefaultSampler):
|
|||
return self.total_size // divisor if divisor < self.total_size else 1
|
||||
|
||||
def safe_random_choice(self, input_data, size):
|
||||
"""
|
||||
Randomly samples without replacement from a sequence. It is "safe" because
|
||||
"""Randomly samples without replacement from a sequence.
|
||||
|
||||
It is "safe" because
|
||||
if len(input_data) < size, it will randomly sample WITH replacement
|
||||
Args:
|
||||
input_data is a sequence, like a torch tensor, numpy array,
|
||||
|
@ -84,9 +90,9 @@ class BatchBalanceSampler(DefaultSampler):
|
|||
return np.random.choice(input_data, size=size, replace=replace)
|
||||
|
||||
def get_labels_to_indices(self, labels):
|
||||
"""
|
||||
Creates labels_to_indices, which is a dictionary mapping each label
|
||||
to a numpy array of indices that will be used to index into self.dataset
|
||||
"""Creates labels_to_indices, which is a dictionary mapping each label
|
||||
to a numpy array of indices that will be used to index into
|
||||
self.dataset.
|
||||
|
||||
{labels_index:Index of samples belonging to the category}
|
||||
|
||||
|
@ -95,7 +101,6 @@ class BatchBalanceSampler(DefaultSampler):
|
|||
"1":[2,4,5,7],
|
||||
"2":[0,9,10]
|
||||
}
|
||||
|
||||
"""
|
||||
if torch.is_tensor(labels):
|
||||
labels = labels.cpu().numpy()
|
||||
|
|
Loading…
Reference in New Issue