mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
* fix multi-GPU bug in retrieval * fix bugs (cannot load vecs in dist and diff test-val recall) * load weight in each process
32 lines
1007 B
Python
32 lines
1007 B
Python
# Copyright (c) OpenMMLab. All rights reserved
|
|
import warnings
|
|
|
|
from mmengine.hooks import Hook
|
|
from mmengine.model import is_model_wrapper
|
|
|
|
from mmcls.models import BaseRetriever
|
|
from mmcls.registry import HOOKS
|
|
|
|
|
|
@HOOKS.register_module()
class PrepareProtoBeforeValLoopHook(Hook):
    """The hook to prepare the prototype in retrievers.

    Since the encoders of the retriever change during training, the prototype
    changes accordingly. So the `prototype_vecs` needs to be regenerated
    before the validation loop.
    """

    def before_val(self, runner) -> None:
        """Regenerate the retriever prototype ahead of validation.

        Args:
            runner: The runner whose ``model`` attribute holds the model to
                prepare. If the model is a model wrapper (as reported by
                ``is_model_wrapper``), the wrapped ``module`` is used instead.

        Warns:
            UserWarning: If the (unwrapped) model is not a ``BaseRetriever``,
                in which case nothing is prepared.
        """
        model = runner.model
        if is_model_wrapper(model):
            # Wrappers keep the real model on ``.module``.
            model = model.module

        if not isinstance(model, BaseRetriever):
            # Report the name of the hook that is actually running, so the
            # message matches what the user configured (and stays right for
            # subclasses).
            hook_name = type(self).__name__
            warnings.warn(
                'Only the `mmcls.models.retrievers.BaseRetriever` can execute '
                f'`{hook_name}`, but got `{type(model)}`')
            return

        # A retriever without ``prepare_prototype`` is skipped silently; this
        # best-effort behaviour is kept as is.
        if hasattr(model, 'prepare_prototype'):
            model.prepare_prototype()
|