[Fix] Fix odc pipeline (#444)

* fix odc pipeline

* update faiss requirement
pull/466/head
Yixiao Fang 2022-08-31 19:38:42 +08:00 committed by GitHub
parent 3c2fe162b0
commit 4f2bfecc4d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 13 additions and 9 deletions

View File

@ -115,4 +115,5 @@ class ODCHook(Hook):
inv_histogram = (1. / (histogram + 1e-10))**reweight_pow
weight = inv_histogram / inv_histogram.sum()
runner.model.module.loss_weight.copy_(torch.from_numpy(weight))
runner.model.module.head.loss.class_weight = self.loss_weight
runner.model.module.head.loss.class_weight = \
runner.model.module.loss_weight

View File

@ -2,6 +2,8 @@
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from mmengine.device import get_device
from mmengine.structures import LabelData
from mmselfsup.registry import MODELS
@ -95,7 +97,8 @@ class ODC(BaseModel):
if self.with_neck:
feature = self.neck(feature)
loss_inputs = (feature, self.memory_bank.label_bank[idx])
loss_inputs = (feature,
self.memory_bank.label_bank[idx].to(get_device()))
loss = self.head(*loss_inputs)
losses = dict(loss=loss)

View File

@ -32,12 +32,10 @@ class ODCMemory(BaseModule):
super().__init__()
self.rank, self.num_replicas = get_dist_info()
if self.rank == 0:
self.register_buffer(
'feature_bank',
torch.zeros((length, feat_dim), dtype=torch.float32))
self.feature_bank = torch.zeros((length, feat_dim),
dtype=torch.float32)
self.register_buffer('label_bank',
torch.zeros((length, ), dtype=torch.long))
self.label_bank = torch.zeros((length, ), dtype=torch.long)
self.register_buffer(
'centroids',
torch.zeros((num_classes, feat_dim), dtype=torch.float32))

View File

@ -116,7 +116,9 @@ def run_kmeans(x: np.ndarray,
# perform the training
clus.train(x, index)
_, I = index.search(x, 1) # noqa E741
losses = faiss.vector_to_array(clus.obj)
stats = clus.iteration_stats
losses = np.array([stats.at(i).obj for i in range(stats.size())])
if verbose:
print(f'k-means loss evolution: {losses}')

View File

@ -1 +1 @@
faiss-gpu==1.6.1
faiss-gpu==1.7.2