code formatting

pull/1753/head
bobo0810 2023-08-09 11:12:29 +08:00
parent cd71c12a94
commit ec3a59d23e
2 changed files with 27 additions and 22 deletions

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .batch_balance import BatchBalanceSampler
from .repeat_aug import RepeatAugSampler
from .sequential import SequentialSampler
from .batch_balance import BatchBalanceSampler
__all__ = ['RepeatAugSampler', 'SequentialSampler', 'BatchBalanceSampler']

View File

@ -1,16 +1,21 @@
from typing import Iterator
from mmengine.dataset import DefaultSampler
from mmpretrain.registry import DATA_SAMPLERS
import numpy as np
import math
# Copyright (c) OpenMMLab. All rights reserved.
import collections
import math
from typing import Iterator
import numpy as np
import torch
from mmengine.dataset import DefaultSampler
from mmpretrain.registry import DATA_SAMPLERS
@DATA_SAMPLERS.register_module()
class BatchBalanceSampler(DefaultSampler):
"""
refer: https://github.com/KevinMusgrave/pytorch-metric-learning/blob/v2.3.0/src/pytorch_metric_learning/samplers/num_per_class_sampler.py
refer: https://github.com/KevinMusgrave/pytorch-metric-learning/
blob/v2.3.0/src/pytorch_metric_learning/samplers/num_per_class_sampler.py
At every iteration, this will return m samples per class. For example,
if dataloader's batchsize is 100, and m = 5, then 20 classes with 5 samples
@ -31,9 +36,8 @@ class BatchBalanceSampler(DefaultSampler):
super().__init__(**kwargs)
self.num_per_class = int(num_per_class)
self.labels_to_indices = self.get_labels_to_indices(
self.dataset.get_gt_labels()
)
self.labels = list(self.labels_to_indices.keys()) # labels index list
self.dataset.get_gt_labels())
self.labels = list(self.labels_to_indices.keys()) # labels index list
self.length_of_single_pass = self.num_per_class * len(self.labels)
self.total_size = len(self.dataset)
@ -41,7 +45,8 @@ class BatchBalanceSampler(DefaultSampler):
if self.length_of_single_pass < self.total_size:
self.total_size -= (self.total_size) % (self.length_of_single_pass)
# The number of samples in this rank
self.num_samples = math.ceil((self.total_size - self.rank) / self.world_size)
self.num_samples = math.ceil(
(self.total_size - self.rank) / self.world_size)
def __len__(self) -> int:
"""The number of samples in this rank."""
@ -55,13 +60,13 @@ class BatchBalanceSampler(DefaultSampler):
np.random.shuffle(self.labels)
curr_label_set = self.labels
for label in curr_label_set:
t = self.labels_to_indices[label] # List of all sample indexes corresponding to the current label
indices[i : i + self.num_per_class] = self.safe_random_choice(
t, size=self.num_per_class
)
# List of all sample indexes corresponding to the current label
t = self.labels_to_indices[label]
indices[i:i + self.num_per_class] = self.safe_random_choice(
t, size=self.num_per_class)
i += self.num_per_class
# subsample
indices = indices[self.rank : self.total_size : self.world_size]
indices = indices[self.rank:self.total_size:self.world_size]
return iter(indices)
@ -70,8 +75,9 @@ class BatchBalanceSampler(DefaultSampler):
return self.total_size // divisor if divisor < self.total_size else 1
def safe_random_choice(self, input_data, size):
"""
Randomly samples without replacement from a sequence. It is "safe" because
"""Randomly samples without replacement from a sequence.
It is "safe" because
if len(input_data) < size, it will randomly sample WITH replacement
Args:
input_data is a sequence, like a torch tensor, numpy array,
@ -84,9 +90,9 @@ class BatchBalanceSampler(DefaultSampler):
return np.random.choice(input_data, size=size, replace=replace)
def get_labels_to_indices(self, labels):
"""
Creates labels_to_indices, which is a dictionary mapping each label
to a numpy array of indices that will be used to index into self.dataset
"""Creates labels_to_indices, which is a dictionary mapping each label
to a numpy array of indices that will be used to index into
self.dataset.
{label_index: indices of the samples belonging to that category}
@ -95,7 +101,6 @@ class BatchBalanceSampler(DefaultSampler):
"1":[2,4,5,7],
"2":[0,9,10]
}
"""
if torch.is_tensor(labels):
labels = labels.cpu().numpy()