[Docs] further detail for the doc for `ClassBalancedDataset`. (#901)
* further detail for the doc for datasets/dataset_wrappers/ClassBalancedDataset
* fix
pull/1171/head
parent
8c63bb55a5
commit
aacaa7316c
|
@ -175,17 +175,20 @@ class ClassBalancedDataset(object):
|
|||
|
||||
1. For each category c, compute the fraction :math:`f(c)` of images that
|
||||
contain it.
|
||||
2. For each category c, compute the category-level repeat factor
|
||||
2. For each category c, compute the category-level repeat factor.
|
||||
|
||||
.. math::
|
||||
r(c) = \max(1, \sqrt{\frac{t}{f(c)}})
|
||||
|
||||
|
||||
where :math:`t` is `oversample_thr`.
|
||||
3. For each image I and its labels :math:`L(I)`, compute the image-level
|
||||
repeat factor
|
||||
repeat factor.
|
||||
|
||||
.. math::
|
||||
r(I) = \max_{c \in L(I)} r(c)
|
||||
|
||||
Each image repeats :math:`\lceil r(I) \rceil` times.
|
||||
|
||||
Args:
|
||||
dataset (:obj:`BaseDataset`): The dataset to be repeated.
|
||||
oversample_thr (float): frequency threshold below which data is
|
||||
|
@ -214,8 +217,8 @@ class ClassBalancedDataset(object):
|
|||
self.flag = np.asarray(flags, dtype=np.uint8)
|
||||
|
||||
def _get_repeat_factors(self, dataset, repeat_thr):
|
||||
# 1. For each category c, compute the fraction # of images
|
||||
# that contain it: f(c)
|
||||
# 1. For each category c, compute the fraction of images
|
||||
# that contain it: f(c)
|
||||
category_freq = defaultdict(int)
|
||||
num_images = len(dataset)
|
||||
for idx in range(num_images):
|
||||
|
@ -227,15 +230,15 @@ class ClassBalancedDataset(object):
|
|||
category_freq[k] = v / num_images
|
||||
|
||||
# 2. For each category c, compute the category-level repeat factor:
|
||||
# r(c) = max(1, sqrt(t/f(c)))
|
||||
# r(c) = max(1, sqrt(t/f(c)))
|
||||
category_repeat = {
|
||||
cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
|
||||
for cat_id, cat_freq in category_freq.items()
|
||||
}
|
||||
|
||||
# 3. For each image I and its labels L(I), compute the image-level
|
||||
# repeat factor:
|
||||
# r(I) = max_{c in L(I)} r(c)
|
||||
# repeat factor:
|
||||
# r(I) = max_{c in L(I)} r(c)
|
||||
repeat_factors = []
|
||||
for idx in range(num_images):
|
||||
cat_ids = set(self.dataset.get_cat_ids(idx))
|
||||
|
|
Loading…
Reference in New Issue