add 'set_reweight' to hook
parent
99de0c1aeb
commit
1f69aa84f7
|
@ -105,7 +105,7 @@ class DeepClusterHook(Hook):
|
|||
|
||||
# step 4 (b): set loss reweight
|
||||
if self.reweight:
|
||||
runner.model.module.set_reweight(new_labels, self.reweight_pow)
|
||||
self.set_reweight(runner, new_labels, self.reweight_pow)
|
||||
|
||||
# step 5: randomize classifier
|
||||
runner.model.module.head._is_init = False
|
||||
|
@ -119,6 +119,7 @@ class DeepClusterHook(Hook):
|
|||
runner.model.module.memory_bank.init_memory(features, new_labels)
|
||||
|
||||
def evaluate(self, runner, new_labels: np.ndarray) -> None:
|
||||
"""Evaluate with labels histogram."""
|
||||
histogram = np.bincount(new_labels, minlength=self.clustering_cfg.k)
|
||||
empty_cls = (histogram == 0).sum()
|
||||
minimal_cls_size, maximal_cls_size = histogram.min(), histogram.max()
|
||||
|
@ -128,3 +129,23 @@ class DeepClusterHook(Hook):
|
|||
f'min_cluster: {minimal_cls_size.item()}\t'
|
||||
f'max_cluster:{maximal_cls_size.item()}',
|
||||
logger='root')
|
||||
|
||||
def set_reweight(self,
|
||||
runner,
|
||||
labels: np.ndarray,
|
||||
reweight_pow: float = 0.5):
|
||||
"""Loss re-weighting.
|
||||
|
||||
Re-weighting the loss according to the number of samples in each class.
|
||||
|
||||
Args:
|
||||
runner (mmengine.Runner): mmengine Runner.
|
||||
labels (numpy.ndarray): Label assignments.
|
||||
reweight_pow (float, optional): The power of re-weighting. Defaults
|
||||
to 0.5.
|
||||
"""
|
||||
histogram = np.bincount(
|
||||
labels, minlength=self.num_classes).astype(np.float32)
|
||||
inv_histogram = (1. / (histogram + 1e-10))**reweight_pow
|
||||
weight = inv_histogram / inv_histogram.sum()
|
||||
runner.model.module.loss_weight.copy_(torch.from_numpy(weight))
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine.hooks import Hook
|
||||
from mmengine.logging import print_log
|
||||
|
||||
|
@ -50,7 +51,7 @@ class ODCHook(Hook):
|
|||
runner.model.module.memory_bank.deal_with_small_clusters()
|
||||
|
||||
# reweight
|
||||
runner.model.module.set_reweight()
|
||||
self.set_reweight(runner)
|
||||
|
||||
# evaluate
|
||||
if self.every_n_iters(runner, self.evaluate_interval):
|
||||
|
@ -79,3 +80,28 @@ class ODCHook(Hook):
|
|||
f'min_cluster: {minimal_cls_size.item()}\t'
|
||||
f'max_cluster:{maximal_cls_size.item()}',
|
||||
logger='root')
|
||||
|
||||
def set_reweight(self,
|
||||
runner,
|
||||
labels: Optional[np.ndarray] = None,
|
||||
reweight_pow: float = 0.5):
|
||||
"""Loss re-weighting.
|
||||
|
||||
Re-weighting the loss according to the number of samples in each class.
|
||||
|
||||
Args:
|
||||
runner (mmengine.Runner): mmengine Runner.
|
||||
labels (numpy.ndarray): Label assignments.
|
||||
reweight_pow (float, optional): The power of re-weighting. Defaults
|
||||
to 0.5.
|
||||
"""
|
||||
if labels is None:
|
||||
if self.memory_bank.label_bank.is_cuda:
|
||||
labels = self.memory_bank.label_bank.cpu().numpy()
|
||||
else:
|
||||
labels = self.memory_bank.label_bank.numpy()
|
||||
histogram = np.bincount(
|
||||
labels, minlength=self.num_classes).astype(np.float32)
|
||||
inv_histogram = (1. / (histogram + 1e-10))**reweight_pow
|
||||
weight = inv_histogram / inv_histogram.sum()
|
||||
runner.model.module.loss_weight.copy_(torch.from_numpy(weight))
|
||||
|
|
Loading…
Reference in New Issue