From 705ed2be4965c91a096c904a965ea83358b928a6 Mon Sep 17 00:00:00 2001 From: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com> Date: Thu, 9 Feb 2023 15:55:47 +0800 Subject: [PATCH] [Fix] Fix retrieval multi gpu bug (#1319) * fix multi-gpu bug in retrieval * fix bugs(cannot load vecs in dist and diff test-val recall) * load weight each process --- mmcls/engine/hooks/retriever_hooks.py | 15 ++++++++++----- mmcls/models/retrievers/image2image.py | 12 +++++------- mmcls/utils/progress.py | 6 +++--- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/mmcls/engine/hooks/retriever_hooks.py b/mmcls/engine/hooks/retriever_hooks.py index ed9b6f99..f09c5ff0 100644 --- a/mmcls/engine/hooks/retriever_hooks.py +++ b/mmcls/engine/hooks/retriever_hooks.py @@ -2,6 +2,7 @@ import warnings from mmengine.hooks import Hook +from mmengine.model import is_model_wrapper from mmcls.models import BaseRetriever from mmcls.registry import HOOKS @@ -17,10 +18,14 @@ class PrepareProtoBeforeValLoopHook(Hook): """ def before_val(self, runner) -> None: - if isinstance(runner.model, BaseRetriever): - if hasattr(runner.model, 'prepare_prototype'): - runner.model.prepare_prototype() + model = runner.model + if is_model_wrapper(model): + model = model.module + + if isinstance(model, BaseRetriever): + if hasattr(model, 'prepare_prototype'): + model.prepare_prototype() else: warnings.warn( - 'Only the retrievers can execute PrepareRetrieverPrototypeHook' - f', but got {type(runner.model)}') + 'Only the `mmcls.models.retrievers.BaseRetriever` can execute ' + f'`PrepareRetrieverPrototypeHook`, but got `{type(model)}`') diff --git a/mmcls/models/retrievers/image2image.py b/mmcls/models/retrievers/image2image.py index 8038120f..2e1807f8 100644 --- a/mmcls/models/retrievers/image2image.py +++ b/mmcls/models/retrievers/image2image.py @@ -243,9 +243,9 @@ class ImageToImageRetriever(BaseRetriever): score).set_pred_label(label)) return data_samples - def _get_prototype_vecs_from_dataloader(self): + 
def _get_prototype_vecs_from_dataloader(self, data_loader): """get prototype_vecs from dataloader.""" - data_loader = self.prototype + self.eval() num = len(data_loader.dataset) prototype_vecs = None @@ -282,11 +282,9 @@ class ImageToImageRetriever(BaseRetriever): prototype_vecs = self.prototype elif isinstance(self.prototype, str): prototype_vecs = torch.load(self.prototype) - elif isinstance(self.prototype, dict): - self.prototype = Runner.build_dataloader(self.prototype) - - if isinstance(self.prototype, DataLoader): - prototype_vecs = self._get_prototype_vecs_from_dataloader() + elif isinstance(self.prototype, (dict, DataLoader)): + loader = Runner.build_dataloader(self.prototype) + prototype_vecs = self._get_prototype_vecs_from_dataloader(loader) self.register_buffer( 'prototype_vecs', prototype_vecs.to(device), persistent=False) diff --git a/mmcls/utils/progress.py b/mmcls/utils/progress.py index c200944f..66c6c32d 100644 --- a/mmcls/utils/progress.py +++ b/mmcls/utils/progress.py @@ -5,6 +5,6 @@ import rich.progress as progress def track_on_main_process(sequence, *args, **kwargs): if not dist.is_main_process(): - return sequence - - yield from progress.track(sequence, *args, **kwargs) + yield from sequence + else: + yield from progress.track(sequence, *args, **kwargs)