68 lines
2.4 KiB
Python
68 lines
2.4 KiB
Python
import numpy as np
|
|
|
|
from mmcv.runner import Hook
|
|
|
|
from openselfsup.utils import print_log
|
|
from .registry import HOOKS
|
|
|
|
|
|
@HOOKS.register_module
class ODCHook(Hook):
    """Hook that schedules memory-bank maintenance and cluster reporting.

    After each training iteration the hook, at the configured intervals,
    asks the model's memory bank to update its centroids and to deal with
    small clusters, calls the model's ``set_reweight()``, and logs
    cluster-size statistics. After a training epoch for which
    ``every_n_epochs(runner, 10)`` is true, rank 0 saves the current label
    bank to ``<work_dir>/cluster_epoch_<epoch>.npy``.

    Args:
        centroids_update_interval: Interval (in iterations) handed to
            ``every_n_iters`` to decide when
            ``memory_bank.update_centroids_memory()`` is called.
        deal_with_small_clusters_interval: Interval (in iterations) handed
            to ``every_n_iters`` to decide when
            ``memory_bank.deal_with_small_clusters()`` is called.
        evaluate_interval: Interval (in iterations) handed to
            ``every_n_iters`` to decide when cluster statistics are logged.
        reweight: Stored on the instance; not read by any method of this
            class.
        reweight_pow: Stored on the instance; not read by any method of
            this class.
        dist_mode: Must be truthy; the constructor asserts on it because
            non-distributed mode is not implemented.
    """

    def __init__(self,
                 centroids_update_interval,
                 deal_with_small_clusters_interval,
                 evaluate_interval,
                 reweight,
                 reweight_pow,
                 dist_mode=True):
        assert dist_mode, "non-dist mode is not implemented"
        self.centroids_update_interval = centroids_update_interval
        self.deal_with_small_clusters_interval = \
            deal_with_small_clusters_interval
        self.evaluate_interval = evaluate_interval
        self.reweight = reweight
        self.reweight_pow = reweight_pow

    @staticmethod
    def _label_bank_as_numpy(runner):
        """Return the memory bank's ``label_bank`` as a NumPy array.

        The labels are moved off the GPU first when ``is_cuda`` is set,
        since ``numpy()`` is only called on the CPU copy.
        """
        labels = runner.model.module.memory_bank.label_bank
        if labels.is_cuda:
            labels = labels.cpu()
        return labels.numpy()

    def after_train_iter(self, runner):
        """Run the per-iteration maintenance steps on ``runner``'s model."""
        memory_bank = runner.model.module.memory_bank

        # centroids update
        if self.every_n_iters(runner, self.centroids_update_interval):
            memory_bank.update_centroids_memory()

        # deal with small clusters
        if self.every_n_iters(runner, self.deal_with_small_clusters_interval):
            memory_bank.deal_with_small_clusters()

        # reweight: called on every iteration, with no arguments.
        # NOTE(review): self.reweight / self.reweight_pow are not consulted
        # here -- confirm that is intended.
        runner.model.module.set_reweight()

        # evaluate
        if self.every_n_iters(runner, self.evaluate_interval):
            self.evaluate(runner, self._label_bank_as_numpy(runner))

    def after_train_epoch(self, runner):
        """Save the current cluster assignments from rank 0."""
        # save cluster
        # ``runner`` must be passed explicitly, exactly as for
        # ``every_n_iters`` above; omitting it raises TypeError.
        if self.every_n_epochs(runner, 10) and runner.rank == 0:
            np.save(
                "{}/cluster_epoch_{}.npy".format(runner.work_dir,
                                                 runner.epoch),
                self._label_bank_as_numpy(runner))

    def evaluate(self, runner, new_labels):
        """Log the number of empty clusters and the min/max cluster size.

        Args:
            runner: Training runner; supplies ``memory_bank.num_classes``
                and ``rank``.
            new_labels: Array of non-negative integer cluster labels, as
                required by ``np.bincount``.
        """
        # minlength ensures classes with no assigned sample still get a
        # zero-count bin, so they are counted as empty.
        hist = np.bincount(
            new_labels, minlength=runner.model.module.memory_bank.num_classes)
        empty_cls = (hist == 0).sum()
        minimal_cls_size, maximal_cls_size = hist.min(), hist.max()
        # Only rank 0 writes the log line.
        if runner.rank == 0:
            print_log(
                "empty_num: {}\tmin_cluster: {}\tmax_cluster:{}".format(
                    empty_cls.item(), minimal_cls_size.item(),
                    maximal_cls_size.item()),
                logger='root')
|