mmclassification/mmcls/engine/hooks/retriever_hooks.py
zzc98 693596bc2f
[Feature] Add Base Retriever and Image2Image Retriever for retrieval tasks. (#1055)
* feat: add image retriever

* update retriever

* fix lint

* add hook unit test

* Use `register_buffer` to save prototype vectors and add a progress bar
while preparing the prototype.

* update UTs

* fix typo

* modify the hook

Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
Co-authored-by: mzr1996 <mzr1996@163.com>
2022-11-02 17:43:56 +08:00

# Copyright (c) OpenMMLab. All rights reserved.
import warnings

from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper

from mmcls.models import BaseRetriever
from mmcls.registry import HOOKS


@HOOKS.register_module()
class PrepareProtoBeforeValLoopHook(Hook):
    """The hook to prepare the prototype in retrievers.

    Since the encoders of the retriever change during training, the prototype
    changes accordingly. So the `prototype_vecs` needs to be regenerated
    before the validation loop.
    """

    def before_val(self, runner) -> None:
        model = runner.model
        # Unwrap distributed wrappers (e.g. MMDistributedDataParallel) so the
        # type check below sees the underlying retriever.
        if is_model_wrapper(model):
            model = model.module

        if isinstance(model, BaseRetriever):
            if hasattr(model, 'prepare_prototype'):
                model.prepare_prototype()
        else:
            warnings.warn(
                'Only the retrievers can execute PrepareProtoBeforeValLoopHook'
                f', but got {type(model)}')
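The hook rests on one invariant: cached prototype vectors are only valid for the encoder weights that produced them. The framework-free sketch below illustrates that pattern under stated assumptions. `ToyRetriever`, `ToyRunner`, and `RefreshPrototypeHook` are hypothetical stand-ins, not mmcls or mmengine classes; a single scalar `scale` plays the role of the encoder weights, and duck typing replaces the `isinstance(..., BaseRetriever)` check. `ToyRunner.call_hook` only loosely imitates how a runner dispatches hook events.

```python
import warnings
from typing import List, Sequence


class ToyRetriever:
    """Hypothetical retriever whose 'encoder' is a single trainable scale."""

    def __init__(self, gallery: Sequence[float]) -> None:
        self.gallery = list(gallery)
        self.scale = 1.0  # stands in for the encoder weights
        self.prototype_vecs: List[float] = []
        self.prepare_prototype()

    def encode(self, x: float) -> float:
        return self.scale * x

    def prepare_prototype(self) -> None:
        # Re-encode the whole gallery with the current encoder weights.
        self.prototype_vecs = [self.encode(x) for x in self.gallery]

    def train_step(self) -> None:
        # Training changes the encoder, so cached prototypes become stale.
        self.scale *= 2.0

    def retrieve(self, query: float) -> int:
        # Return the index of the cached prototype nearest the encoded query.
        q = self.encode(query)
        return min(range(len(self.prototype_vecs)),
                   key=lambda i: abs(self.prototype_vecs[i] - q))


class RefreshPrototypeHook:
    """Hypothetical hook: rebuild prototypes before every validation run."""

    def before_val(self, runner: "ToyRunner") -> None:
        model = runner.model
        # Duck typing stands in for the isinstance(..., BaseRetriever) check.
        if hasattr(model, "prepare_prototype"):
            model.prepare_prototype()
        else:
            warnings.warn(f"Expected a retriever, but got {type(model)}")


class ToyRunner:
    """Hypothetical runner that dispatches named events to its hooks."""

    def __init__(self, model: object, hooks: list) -> None:
        self.model = model
        self.hooks = hooks

    def call_hook(self, name: str) -> None:
        for hook in self.hooks:
            method = getattr(hook, name, None)
            if method is not None:
                method(self)

    def train_then_val(self, query: float) -> int:
        self.model.train_step()
        self.call_hook("before_val")  # prototypes are refreshed here
        return self.model.retrieve(query)


if __name__ == "__main__":
    stale = ToyRetriever([1.0, 2.0, 4.0])
    stale.train_step()  # no hook: prototypes were built with scale 1.0
    print(stale.retrieve(2.0))  # 2, the wrong neighbor

    model = ToyRetriever([1.0, 2.0, 4.0])
    runner = ToyRunner(model, [RefreshPrototypeHook()])
    print(runner.train_then_val(2.0))  # 1, the query's own gallery entry
```

Without the refresh, the query is encoded with the new weights but compared against prototypes built with the old ones, so retrieval silently degrades. In an actual mmcls config, the real hook is typically enabled through MMEngine's `custom_hooks` field, e.g. `custom_hooks = [dict(type='PrepareProtoBeforeValLoopHook')]`; treat that line as an assumption and check it against the retrieval configs shipped with your mmcls version.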