mmclassification/mmcls/engine/hooks/retriever_hooks.py
Ezra-Yu 705ed2be49
[Fix] Fix retrieval multi gpu bug (#1319)
* fix multi-gpu bug in retrieval

* fix bugs (cannot load vecs in dist, and diff test-val recall)

* load weight each process
2023-02-09 15:55:47 +08:00

32 lines
1007 B
Python

# Copyright (c) OpenMMLab. All rights reserved
import warnings
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmcls.models import BaseRetriever
from mmcls.registry import HOOKS
@HOOKS.register_module()
class PrepareProtoBeforeValLoopHook(Hook):
    """The hook to prepare the prototype in retrievers.

    Since the encoders of the retriever change during training, the prototype
    changes accordingly. So the ``prototype_vecs`` need to be regenerated
    before the validation loop.
    """

    def before_val(self, runner) -> None:
        """Regenerate the retriever prototype before validation.

        Args:
            runner: The runner whose ``model`` is prepared. If the model is a
                model wrapper, the wrapped ``model.module`` is used instead.

        Warns:
            UserWarning: If the (unwrapped) model is not a ``BaseRetriever``;
                in that case nothing is prepared.
        """
        model = runner.model
        # Model wrappers keep the real model in ``.module``; both the type
        # check and the ``prepare_prototype`` call must target that model.
        if is_model_wrapper(model):
            model = model.module

        if isinstance(model, BaseRetriever):
            # Silently skip retrievers that do not define
            # ``prepare_prototype``.
            if hasattr(model, 'prepare_prototype'):
                model.prepare_prototype()
        else:
            # Use the real class name so the message matches the hook the
            # user actually configured.
            warnings.warn(
                'Only the `mmcls.models.retrievers.BaseRetriever` can execute '
                f'`{type(self).__name__}`, but got `{type(model)}`')