27 lines
859 B
Python
27 lines
859 B
Python
# Copyright (c) OpenMMLab. All rights reserved
|
|
import warnings
|
|
|
|
from mmengine.hooks import Hook
|
|
|
|
from mmcls.models import BaseRetriever
|
|
from mmcls.registry import HOOKS
|
|
|
|
|
|
@HOOKS.register_module()
class PrepareProtoBeforeValLoopHook(Hook):
    """The hook to prepare the prototype in retrievers.

    Since the encoders of the retriever change during training, the
    prototype changes accordingly. So the ``prototype_vecs`` needs to be
    regenerated before the validation loop.
    """

    def before_val(self, runner) -> None:
        """Regenerate the retriever prototype before validation.

        If the runner's model (after stripping a model wrapper, if any) is
        a ``BaseRetriever`` and exposes ``prepare_prototype``, that method
        is called. For any other model type a warning is emitted and
        nothing else happens.

        Args:
            runner: The runner whose ``model`` attribute is inspected.
        """
        # Imported locally so the file's top-level import block does not
        # need to change; ``mmengine`` is already a dependency of this file.
        from mmengine.model import is_model_wrapper

        model = runner.model
        # When the model is held inside a wrapper (e.g. for distributed
        # training), the retriever is the wrapper's ``module`` attribute.
        # Checking the wrapper itself would never match ``BaseRetriever``.
        if is_model_wrapper(model):
            model = model.module

        if not isinstance(model, BaseRetriever):
            warnings.warn(
                f'Only the retrievers can execute {type(self).__name__}'
                f', but got {type(model)}')
            return

        # Retrievers without ``prepare_prototype`` are skipped silently,
        # as in the original implementation.
        if hasattr(model, 'prepare_prototype'):
            model.prepare_prototype()
|