# Source: mmselfsup/openselfsup/hooks/deepcluster_hook.py
# (110 lines, 3.8 KiB, Python)

import numpy as np
from mmcv.runner import Hook
import torch
import torch.distributed as dist
from openselfsup.third_party import clustering as _clustering
from openselfsup.utils import print_log
from .registry import HOOKS
from .extractor import Extractor
@HOOKS.register_module
class DeepClusterHook(Hook):
    """Hook that re-clusters extracted features and relabels the dataset.

    Clustering runs once before training (when ``initial`` is true) and
    then after every ``interval`` training epochs. Each run extracts
    features with the configured extractor, clusters them, stores the
    cluster ids on the first data loader's dataset and re-initializes the
    model head.

    Args:
        extractor (dict): Keyword arguments forwarded to ``Extractor``.
        clustering (dict): Clustering config. ``'type'`` names the
            algorithm looked up in ``openselfsup.third_party.clustering``;
            the remaining items are passed to its constructor. A ``'k'``
            entry is read as the number of clusters. The dict passed in
            is not modified.
        unif_sampling (bool): If true, call ``set_uniform_indices`` on
            the first data loader's sampler with the new labels.
        reweight (bool): If true, call ``set_reweight`` on the model.
        reweight_pow: Second argument passed to ``set_reweight``.
        init_memory (bool): If true, call ``memory_bank.init_memory`` on
            the model after clustering (used for ODC).
        initial (bool): Whether to run clustering in ``before_run``.
        interval (int): Number of epochs between two clustering runs.
        dist_mode (bool): Whether to broadcast the labels and the head
            parameters from rank 0 with ``torch.distributed``.
        data_loaders (list | None): Data loaders; only the first one is
            used and it must be set before clustering runs.
    """

    def __init__(
            self,
            extractor,
            clustering,
            unif_sampling,
            reweight,
            reweight_pow,
            init_memory=False,  # for ODC
            initial=True,
            interval=1,
            dist_mode=True,
            data_loaders=None):
        self.extractor = Extractor(dist_mode=dist_mode, **extractor)
        # Pop 'type' from a shallow copy so the caller's config keeps the
        # key and can be reused to build another hook.
        clustering = clustering.copy()
        self.clustering_type = clustering.pop('type')
        self.clustering_cfg = clustering
        self.unif_sampling = unif_sampling
        self.reweight = reweight
        self.reweight_pow = reweight_pow
        self.init_memory = init_memory
        self.initial = initial
        self.interval = interval
        self.dist_mode = dist_mode
        self.data_loaders = data_loaders

    def before_run(self, runner):
        """Run an initial clustering pass when ``initial`` is set."""
        if self.initial:
            self.deepcluster(runner)

    def after_train_epoch(self, runner):
        """Run clustering after every ``interval`` training epochs."""
        if not self.every_n_epochs(runner, self.interval):
            return
        self.deepcluster(runner)

    def deepcluster(self, runner):
        """Extract features, cluster them and install the new labels.

        Args:
            runner: The runner driving training; its ``model``, ``rank``,
                ``work_dir`` and ``epoch`` attributes are used.
        """
        # Resolve the loader first so a missing ``data_loaders`` fails
        # before the costly extraction and clustering steps.
        loader = self.data_loaders[0]

        # step 1: get features
        runner.model.eval()
        try:
            features = self.extractor(runner)
        finally:
            # Always put the model back into training mode.
            runner.model.train()

        # step 2: get labels
        if not self.dist_mode or runner.rank == 0:
            clustering_algo = _clustering.__dict__[self.clustering_type](
                **self.clustering_cfg)
            # Features are normalized during clustering
            clustering_algo.cluster(features, verbose=True)
            assert isinstance(clustering_algo.labels, np.ndarray)
            new_labels = clustering_algo.labels.astype(np.int64)
            np.save(
                "{}/cluster_epoch_{}.npy".format(runner.work_dir,
                                                 runner.epoch), new_labels)
            self.evaluate(runner, new_labels)
        else:
            # Placeholder buffer that the broadcast below overwrites.
            new_labels = np.zeros((len(loader.dataset), ), dtype=np.int64)

        if self.dist_mode:
            new_labels_tensor = torch.from_numpy(new_labels).cuda()
            dist.broadcast(new_labels_tensor, 0)
            new_labels = new_labels_tensor.cpu().numpy()
        new_labels_list = list(new_labels)

        # step 3: assign new labels
        loader.dataset.assign_labels(new_labels_list)

        # step 4 (a): set uniform sampler
        if self.unif_sampling:
            loader.sampler.set_uniform_indices(new_labels_list,
                                               self.clustering_cfg['k'])

        # step 4 (b): set loss reweight
        if self.reweight:
            runner.model.module.set_reweight(new_labels, self.reweight_pow)

        # step 5: randomize classifier
        runner.model.module.head.init_weights(init_linear='normal')
        if self.dist_mode:
            # Overwrite every rank's freshly initialized head with rank 0's.
            for p in runner.model.module.head.state_dict().values():
                dist.broadcast(p, 0)

        # step 6: init memory for ODC
        if self.init_memory:
            runner.model.module.memory_bank.init_memory(features, new_labels)

    def evaluate(self, runner, new_labels):
        """Log cluster-size statistics for ``new_labels`` on rank 0.

        Args:
            runner: The runner; only its ``rank`` is read.
            new_labels (np.ndarray): Integer cluster id per sample.
        """
        if runner.rank != 0:
            return
        # Index the config by key so plain dicts work as well as ConfigDict.
        hist = np.bincount(new_labels, minlength=self.clustering_cfg['k'])
        empty_cls = (hist == 0).sum()
        minimal_cls_size, maximal_cls_size = hist.min(), hist.max()
        print_log(
            "empty_num: {}\tmin_cluster: {}\tmax_cluster:{}".format(
                empty_cls.item(), minimal_cls_size.item(),
                maximal_cls_size.item()),
            logger='root')