mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
* feat: add image retriever * feat: add image retriever * feat: add image retriever * feat: add image retriever * feat: add image retriever * feat: add image retriever * feat: add image retriever * feat: add image retriever * feat: add image retriever * update retriever * fix lint * add hook unit test * Use `register_buffer` to save prototype vectors and add a progress bar during preparing prototype. * update UTs * update UTs * fix typo * modify the hook Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com> Co-authored-by: mzr1996 <mzr1996@163.com>
27 lines
859 B
Python
27 lines
859 B
Python
# Copyright (c) OpenMMLab. All rights reserved
|
|
import warnings

from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper

from mmcls.models import BaseRetriever
from mmcls.registry import HOOKS
|
|
|
|
|
|
@HOOKS.register_module()
class PrepareProtoBeforeValLoopHook(Hook):
    """The hook to prepare the prototype in retrievers.

    Since the encoders of the retriever change during training, the prototype
    changes accordingly. So the ``prototype_vecs`` needs to be regenerated
    before the validation loop.

    If ``runner.model`` is wrapped by a model wrapper (as recognized by
    ``mmengine.model.is_model_wrapper``), the wrapped ``module`` is inspected
    instead of the wrapper itself.
    """

    def before_val(self, runner) -> None:
        """Regenerate the retriever's prototype before validation starts.

        Args:
            runner: The runner driving training/validation. Its ``model``
                attribute is expected to be a retriever, possibly wrapped by
                a model wrapper.
        """
        model = runner.model
        # A wrapper (e.g. for distributed training) is not itself a
        # BaseRetriever, so check and call the real model underneath it.
        if is_model_wrapper(model):
            model = model.module

        if isinstance(model, BaseRetriever):
            if hasattr(model, 'prepare_prototype'):
                model.prepare_prototype()
        else:
            warnings.warn(
                f'Only the retrievers can execute {type(self).__name__}'
                f', but got {type(model)}')