diff --git a/mmselfsup/core/hooks/deepcluster_hook.py b/mmselfsup/core/hooks/deepcluster_hook.py
index 2e608149..13ad708d 100644
--- a/mmselfsup/core/hooks/deepcluster_hook.py
+++ b/mmselfsup/core/hooks/deepcluster_hook.py
@@ -105,7 +105,7 @@ class DeepClusterHook(Hook):
 
         # step 4 (b): set loss reweight
         if self.reweight:
-            runner.model.module.set_reweight(new_labels, self.reweight_pow)
+            self.set_reweight(runner, new_labels, self.reweight_pow)
 
         # step 5: randomize classifier
         runner.model.module.head._is_init = False
@@ -119,6 +119,7 @@ class DeepClusterHook(Hook):
         runner.model.module.memory_bank.init_memory(features, new_labels)
 
     def evaluate(self, runner, new_labels: np.ndarray) -> None:
+        """Evaluate with labels histogram."""
         histogram = np.bincount(new_labels, minlength=self.clustering_cfg.k)
         empty_cls = (histogram == 0).sum()
         minimal_cls_size, maximal_cls_size = histogram.min(), histogram.max()
@@ -128,3 +129,31 @@ class DeepClusterHook(Hook):
                 f'min_cluster: {minimal_cls_size.item()}\t'
                 f'max_cluster:{maximal_cls_size.item()}',
                 logger='root')
+
+    def set_reweight(self,
+                     runner,
+                     labels: np.ndarray,
+                     reweight_pow: float = 0.5) -> None:
+        """Loss re-weighting.
+
+        Re-weight the loss according to the number of samples in each class:
+        each class weight is ``(1 / count) ** reweight_pow``, normalized so
+        that the weights sum to 1, and copied in place into
+        ``runner.model.module.loss_weight``.
+
+        Args:
+            runner (mmengine.Runner): mmengine Runner.
+            labels (numpy.ndarray): Label assignments.
+            reweight_pow (float, optional): The power of re-weighting. Defaults
+                to 0.5.
+        """
+        # Size the histogram by the number of clusters, exactly as
+        # ``evaluate`` does, so empty clusters still get a bin.
+        histogram = np.bincount(
+            labels, minlength=self.clustering_cfg.k).astype(np.float32)
+        # The epsilon keeps empty clusters from dividing by zero.
+        inv_histogram = (1. / (histogram + 1e-10))**reweight_pow
+        weight = inv_histogram / inv_histogram.sum()
+        # NOTE(review): needs a module-level ``import torch``; this patch does
+        # not add one to this file -- confirm it is already imported.
+        runner.model.module.loss_weight.copy_(torch.from_numpy(weight))
diff --git a/mmselfsup/core/hooks/odc_hook.py b/mmselfsup/core/hooks/odc_hook.py
index a756bc95..36232251 100644
--- a/mmselfsup/core/hooks/odc_hook.py
+++ b/mmselfsup/core/hooks/odc_hook.py
@@ -2,6 +2,7 @@
 from typing import Optional
 
 import numpy as np
+import torch
 from mmengine.hooks import Hook
 from mmengine.logging import print_log
 
@@ -50,7 +51,7 @@ class ODCHook(Hook):
             runner.model.module.memory_bank.deal_with_small_clusters()
 
         # reweight
-        runner.model.module.set_reweight()
+        self.set_reweight(runner)
 
         # evaluate
         if self.every_n_iters(runner, self.evaluate_interval):
@@ -79,3 +80,38 @@ class ODCHook(Hook):
                 f'min_cluster: {minimal_cls_size.item()}\t'
                 f'max_cluster:{maximal_cls_size.item()}',
                 logger='root')
+
+    def set_reweight(self,
+                     runner,
+                     labels: Optional[np.ndarray] = None,
+                     reweight_pow: float = 0.5) -> None:
+        """Loss re-weighting.
+
+        Re-weight the loss according to the number of samples in each class:
+        each class weight is ``(1 / count) ** reweight_pow``, normalized so
+        that the weights sum to 1, and copied in place into
+        ``runner.model.module.loss_weight``.
+
+        Args:
+            runner (mmengine.Runner): mmengine Runner.
+            labels (numpy.ndarray, optional): Label assignments. If None, the
+                labels are read from the label bank of the model's memory
+                bank. Defaults to None.
+            reweight_pow (float, optional): The power of re-weighting. Defaults
+                to 0.5.
+        """
+        # The memory bank and the loss weight live on the model, not on this
+        # hook, so both are reached through the runner.
+        model = runner.model.module
+        if labels is None:
+            # ``Tensor.cpu()`` hands back the same tensor when it is already
+            # on the CPU, so no ``is_cuda`` branch is needed.
+            labels = model.memory_bank.label_bank.cpu().numpy()
+        # NOTE(review): assumes ``loss_weight`` holds exactly one entry per
+        # class, which the ``copy_`` below already relies on -- confirm.
+        histogram = np.bincount(
+            labels, minlength=model.loss_weight.numel()).astype(np.float32)
+        # The epsilon keeps empty classes from dividing by zero.
+        inv_histogram = (1. / (histogram + 1e-10))**reweight_pow
+        weight = inv_histogram / inv_histogram.sum()
+        model.loss_weight.copy_(torch.from_numpy(weight))