mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Update the ClassBalancedDataset logic to keep len(repeat_factors) = len(dataset) (#1048)
This commit is contained in:
parent
093068e4ff
commit
fd84c210e5
@ -429,12 +429,16 @@ class ClassBalancedDataset:
|
||||
# r(I) = max_{c in L(I)} r(c)
|
||||
repeat_factors = []
|
||||
for idx in range(num_images):
|
||||
# the length of `repeat_factors` need equal to the length of
|
||||
# dataset. Hence, if the `cat_ids` is empty,
|
||||
# the repeat_factor should be 1.
|
||||
repeat_factor: float = 1.
|
||||
cat_ids = set(self.dataset.get_cat_ids(idx))
|
||||
if len(cat_ids) != 0:
|
||||
repeat_factor = max(
|
||||
{category_repeat[cat_id]
|
||||
for cat_id in cat_ids})
|
||||
repeat_factors.append(repeat_factor)
|
||||
repeat_factors.append(repeat_factor)
|
||||
|
||||
return repeat_factors
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user