[Fix] Fix retrieval multi gpu bug (#1319)

* fix multi-gpu bug in retrieval

* fix bugs (cannot load vecs in dist and diff test-val recall)

* load weights in each process
pull/1386/head
Ezra-Yu 2023-02-09 15:55:47 +08:00 committed by GitHub
parent 7ec6062415
commit 705ed2be49
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 18 additions and 15 deletions

View File

@ -2,6 +2,7 @@
import warnings
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmcls.models import BaseRetriever
from mmcls.registry import HOOKS
@ -17,10 +18,14 @@ class PrepareProtoBeforeValLoopHook(Hook):
"""
def before_val(self, runner) -> None:
if isinstance(runner.model, BaseRetriever):
if hasattr(runner.model, 'prepare_prototype'):
runner.model.prepare_prototype()
model = runner.model
if is_model_wrapper(model):
model = model.module
if isinstance(model, BaseRetriever):
if hasattr(model, 'prepare_prototype'):
model.prepare_prototype()
else:
warnings.warn(
'Only the retrievers can execute PrepareRetrieverPrototypeHook'
f', but got {type(runner.model)}')
'Only the `mmcls.models.retrievers.BaseRetriever` can execute '
f'`PrepareRetrieverPrototypeHook`, but got `{type(model)}`')

View File

@ -243,9 +243,9 @@ class ImageToImageRetriever(BaseRetriever):
score).set_pred_label(label))
return data_samples
def _get_prototype_vecs_from_dataloader(self):
def _get_prototype_vecs_from_dataloader(self, data_loader):
"""get prototype_vecs from dataloader."""
data_loader = self.prototype
self.eval()
num = len(data_loader.dataset)
prototype_vecs = None
@ -282,11 +282,9 @@ class ImageToImageRetriever(BaseRetriever):
prototype_vecs = self.prototype
elif isinstance(self.prototype, str):
prototype_vecs = torch.load(self.prototype)
elif isinstance(self.prototype, dict):
self.prototype = Runner.build_dataloader(self.prototype)
if isinstance(self.prototype, DataLoader):
prototype_vecs = self._get_prototype_vecs_from_dataloader()
elif isinstance(self.prototype, (dict, DataLoader)):
loader = Runner.build_dataloader(self.prototype)
prototype_vecs = self._get_prototype_vecs_from_dataloader(loader)
self.register_buffer(
'prototype_vecs', prototype_vecs.to(device), persistent=False)

View File

@ -5,6 +5,6 @@ import rich.progress as progress
def track_on_main_process(sequence, *args, **kwargs):
if not dist.is_main_process():
return sequence
yield from progress.track(sequence, *args, **kwargs)
yield from sequence
else:
yield from progress.track(sequence, *args, **kwargs)