[Fix] Fix retrieval multi gpu bug (#1319)
* fix multi-gpu bug in retrieval * fix bugs (cannot load vecs in dist and diff test-val recall) * load weight each process
pull/1386/head
parent
7ec6062415
commit
705ed2be49
|
@ -2,6 +2,7 @@
|
|||
import warnings
|
||||
|
||||
from mmengine.hooks import Hook
|
||||
from mmengine.model import is_model_wrapper
|
||||
|
||||
from mmcls.models import BaseRetriever
|
||||
from mmcls.registry import HOOKS
|
||||
|
@ -17,10 +18,14 @@ class PrepareProtoBeforeValLoopHook(Hook):
|
|||
"""
|
||||
|
||||
def before_val(self, runner) -> None:
|
||||
if isinstance(runner.model, BaseRetriever):
|
||||
if hasattr(runner.model, 'prepare_prototype'):
|
||||
runner.model.prepare_prototype()
|
||||
model = runner.model
|
||||
if is_model_wrapper(model):
|
||||
model = model.module
|
||||
|
||||
if isinstance(model, BaseRetriever):
|
||||
if hasattr(model, 'prepare_prototype'):
|
||||
model.prepare_prototype()
|
||||
else:
|
||||
warnings.warn(
|
||||
'Only the retrievers can execute PrepareRetrieverPrototypeHook'
|
||||
f', but got {type(runner.model)}')
|
||||
'Only the `mmcls.models.retrievers.BaseRetriever` can execute '
|
||||
f'`PrepareRetrieverPrototypeHook`, but got `{type(model)}`')
|
||||
|
|
|
@ -243,9 +243,9 @@ class ImageToImageRetriever(BaseRetriever):
|
|||
score).set_pred_label(label))
|
||||
return data_samples
|
||||
|
||||
def _get_prototype_vecs_from_dataloader(self):
|
||||
def _get_prototype_vecs_from_dataloader(self, data_loader):
|
||||
"""get prototype_vecs from dataloader."""
|
||||
data_loader = self.prototype
|
||||
self.eval()
|
||||
num = len(data_loader.dataset)
|
||||
|
||||
prototype_vecs = None
|
||||
|
@ -282,11 +282,9 @@ class ImageToImageRetriever(BaseRetriever):
|
|||
prototype_vecs = self.prototype
|
||||
elif isinstance(self.prototype, str):
|
||||
prototype_vecs = torch.load(self.prototype)
|
||||
elif isinstance(self.prototype, dict):
|
||||
self.prototype = Runner.build_dataloader(self.prototype)
|
||||
|
||||
if isinstance(self.prototype, DataLoader):
|
||||
prototype_vecs = self._get_prototype_vecs_from_dataloader()
|
||||
elif isinstance(self.prototype, (dict, DataLoader)):
|
||||
loader = Runner.build_dataloader(self.prototype)
|
||||
prototype_vecs = self._get_prototype_vecs_from_dataloader(loader)
|
||||
|
||||
self.register_buffer(
|
||||
'prototype_vecs', prototype_vecs.to(device), persistent=False)
|
||||
|
|
|
@ -5,6 +5,6 @@ import rich.progress as progress
|
|||
|
||||
def track_on_main_process(sequence, *args, **kwargs):
|
||||
if not dist.is_main_process():
|
||||
return sequence
|
||||
|
||||
yield from progress.track(sequence, *args, **kwargs)
|
||||
yield from sequence
|
||||
else:
|
||||
yield from progress.track(sequence, *args, **kwargs)
|
||||
|
|
Loading…
Reference in New Issue