parent
3c2fe162b0
commit
4f2bfecc4d
|
@ -115,4 +115,5 @@ class ODCHook(Hook):
|
|||
inv_histogram = (1. / (histogram + 1e-10))**reweight_pow
|
||||
weight = inv_histogram / inv_histogram.sum()
|
||||
runner.model.module.loss_weight.copy_(torch.from_numpy(weight))
|
||||
runner.model.module.head.loss.class_weight = self.loss_weight
|
||||
runner.model.module.head.loss.class_weight = \
|
||||
runner.model.module.loss_weight
|
||||
|
|
|
@ -2,6 +2,8 @@
|
|||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.device import get_device
|
||||
from mmengine.structures import LabelData
|
||||
|
||||
from mmselfsup.registry import MODELS
|
||||
|
@ -95,7 +97,8 @@ class ODC(BaseModel):
|
|||
if self.with_neck:
|
||||
feature = self.neck(feature)
|
||||
|
||||
loss_inputs = (feature, self.memory_bank.label_bank[idx])
|
||||
loss_inputs = (feature,
|
||||
self.memory_bank.label_bank[idx].to(get_device()))
|
||||
loss = self.head(*loss_inputs)
|
||||
losses = dict(loss=loss)
|
||||
|
||||
|
|
|
@ -32,12 +32,10 @@ class ODCMemory(BaseModule):
|
|||
super().__init__()
|
||||
self.rank, self.num_replicas = get_dist_info()
|
||||
if self.rank == 0:
|
||||
self.register_buffer(
|
||||
'feature_bank',
|
||||
torch.zeros((length, feat_dim), dtype=torch.float32))
|
||||
self.feature_bank = torch.zeros((length, feat_dim),
|
||||
dtype=torch.float32)
|
||||
|
||||
self.register_buffer('label_bank',
|
||||
torch.zeros((length, ), dtype=torch.long))
|
||||
self.label_bank = torch.zeros((length, ), dtype=torch.long)
|
||||
self.register_buffer(
|
||||
'centroids',
|
||||
torch.zeros((num_classes, feat_dim), dtype=torch.float32))
|
||||
|
|
|
@ -116,7 +116,9 @@ def run_kmeans(x: np.ndarray,
|
|||
# perform the training
|
||||
clus.train(x, index)
|
||||
_, I = index.search(x, 1) # noqa E741
|
||||
losses = faiss.vector_to_array(clus.obj)
|
||||
|
||||
stats = clus.iteration_stats
|
||||
losses = np.array([stats.at(i).obj for i in range(stats.size())])
|
||||
if verbose:
|
||||
print(f'k-means loss evolution: {losses}')
|
||||
|
||||
|
|
|
@ -1 +1 @@
|
|||
faiss-gpu==1.6.1
|
||||
faiss-gpu==1.7.2
|
||||
|
|
Loading…
Reference in New Issue