[Fix] Update the ClassBalancedDataset logic to keep len(repeat_factors) = len(dataset) (#1048)

This commit is contained in:
BigDong 2023-04-04 14:27:27 +08:00 committed by GitHub
parent 093068e4ff
commit fd84c210e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -429,12 +429,16 @@ class ClassBalancedDataset:
# r(I) = max_{c in L(I)} r(c)
repeat_factors = []
for idx in range(num_images):
# the length of `repeat_factors` need equal to the length of
# dataset. Hence, if the `cat_ids` is empty,
# the repeat_factor should be 1.
repeat_factor: float = 1.
cat_ids = set(self.dataset.get_cat_ids(idx))
if len(cat_ids) != 0:
repeat_factor = max(
{category_repeat[cat_id]
for cat_id in cat_ids})
repeat_factors.append(repeat_factor)
repeat_factors.append(repeat_factor)
return repeat_factors