[Docs] further detail for the doc for `ClassBalancedDataset`. (#901)

* futher detail for the doc for datasets/dataset_wrappers/ClassBalancedDataset

* fix
pull/1171/head
JayChen 2022-11-02 17:52:58 +08:00 committed by GitHub
parent 8c63bb55a5
commit aacaa7316c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 11 additions and 8 deletions

View File

@ -175,17 +175,20 @@ class ClassBalancedDataset(object):
1. For each category c, compute the fraction :math:`f(c)` of images that
contain it.
2. For each category c, compute the category-level repeat factor
2. For each category c, compute the category-level repeat factor.
.. math::
r(c) = \max(1, \sqrt{\frac{t}{f(c)}})
where :math:`t` is `oversample_thr`.
3. For each image I and its labels :math:`L(I)`, compute the image-level
repeat factor
repeat factor.
.. math::
r(I) = \max_{c \in L(I)} r(c)
Each image repeats :math:`\lceil r(I) \rceil` times.
Args:
dataset (:obj:`BaseDataset`): The dataset to be repeated.
oversample_thr (float): frequency threshold below which data is
@ -214,8 +217,8 @@ class ClassBalancedDataset(object):
self.flag = np.asarray(flags, dtype=np.uint8)
def _get_repeat_factors(self, dataset, repeat_thr):
# 1. For each category c, compute the fraction # of images
# that contain it: f(c)
# 1. For each category c, compute the fraction of images
# that contain it: f(c)
category_freq = defaultdict(int)
num_images = len(dataset)
for idx in range(num_images):
@ -227,15 +230,15 @@ class ClassBalancedDataset(object):
category_freq[k] = v / num_images
# 2. For each category c, compute the category-level repeat factor:
# r(c) = max(1, sqrt(t/f(c)))
# r(c) = max(1, sqrt(t/f(c)))
category_repeat = {
cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
for cat_id, cat_freq in category_freq.items()
}
# 3. For each image I and its labels L(I), compute the image-level
# repeat factor:
# r(I) = max_{c in L(I)} r(c)
# repeat factor:
# r(I) = max_{c in L(I)} r(c)
repeat_factors = []
for idx in range(num_images):
cat_ids = set(self.dataset.get_cat_ids(idx))