add 'set_reweight' to hook

pull/352/head
fangyixiao18 2022-05-24 15:47:00 +08:00
parent 99de0c1aeb
commit 1f69aa84f7
2 changed files with 49 additions and 2 deletions

View File

@ -105,7 +105,7 @@ class DeepClusterHook(Hook):
# step 4 (b): set loss reweight
if self.reweight:
runner.model.module.set_reweight(new_labels, self.reweight_pow)
self.set_reweight(runner, new_labels, self.reweight_pow)
# step 5: randomize classifier
runner.model.module.head._is_init = False
@ -119,6 +119,7 @@ class DeepClusterHook(Hook):
runner.model.module.memory_bank.init_memory(features, new_labels)
def evaluate(self, runner, new_labels: np.ndarray) -> None:
"""Evaluate with labels histogram."""
histogram = np.bincount(new_labels, minlength=self.clustering_cfg.k)
empty_cls = (histogram == 0).sum()
minimal_cls_size, maximal_cls_size = histogram.min(), histogram.max()
@ -128,3 +129,23 @@ class DeepClusterHook(Hook):
f'min_cluster: {minimal_cls_size.item()}\t'
f'max_cluster:{maximal_cls_size.item()}',
logger='root')
def set_reweight(self,
runner,
labels: np.ndarray,
reweight_pow: float = 0.5):
"""Loss re-weighting.
Re-weighting the loss according to the number of samples in each class.
Args:
runner (mmengine.Runner): mmengine Runner.
labels (numpy.ndarray): Label assignments.
reweight_pow (float, optional): The power of re-weighting. Defaults
to 0.5.
"""
histogram = np.bincount(
labels, minlength=self.num_classes).astype(np.float32)
inv_histogram = (1. / (histogram + 1e-10))**reweight_pow
weight = inv_histogram / inv_histogram.sum()
runner.model.module.loss_weight.copy_(torch.from_numpy(weight))

View File

@ -2,6 +2,7 @@
from typing import Optional
import numpy as np
import torch
from mmengine.hooks import Hook
from mmengine.logging import print_log
@ -50,7 +51,7 @@ class ODCHook(Hook):
runner.model.module.memory_bank.deal_with_small_clusters()
# reweight
runner.model.module.set_reweight()
self.set_reweight(runner)
# evaluate
if self.every_n_iters(runner, self.evaluate_interval):
@ -79,3 +80,28 @@ class ODCHook(Hook):
f'min_cluster: {minimal_cls_size.item()}\t'
f'max_cluster:{maximal_cls_size.item()}',
logger='root')
def set_reweight(self,
runner,
labels: Optional[np.ndarray] = None,
reweight_pow: float = 0.5):
"""Loss re-weighting.
Re-weighting the loss according to the number of samples in each class.
Args:
runner (mmengine.Runner): mmengine Runner.
labels (numpy.ndarray): Label assignments.
reweight_pow (float, optional): The power of re-weighting. Defaults
to 0.5.
"""
if labels is None:
if self.memory_bank.label_bank.is_cuda:
labels = self.memory_bank.label_bank.cpu().numpy()
else:
labels = self.memory_bank.label_bank.numpy()
histogram = np.bincount(
labels, minlength=self.num_classes).astype(np.float32)
inv_histogram = (1. / (histogram + 1e-10))**reweight_pow
weight = inv_histogram / inv_histogram.sum()
runner.model.module.loss_weight.copy_(torch.from_numpy(weight))