mmpretrain/tests/test_engine/test_hooks/test_retrievers_hooks.py
zzc98 693596bc2f
[Feature] Add Base Retriever and Image2Image Retriever for retrieval tasks. (#1055)
* feat: add image retriever

* feat: add image retriever

* feat: add image retriever

* feat: add image retriever

* feat: add image retriever

* feat: add image retriever

* feat: add image retriever

* feat: add image retriever

* feat: add image retriever

* update retriever

* fix lint

* add hook unit test

* Use `register_buffer` to save prototype vectors and add a progress bar
while preparing the prototype.

* update UTs

* update UTs

* fix typo

* modify the hook

Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
Co-authored-by: mzr1996 <mzr1996@163.com>
2022-11-02 17:43:56 +08:00

35 lines
1.0 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import MagicMock
import torch
from mmcls.engine import PrepareProtoBeforeValLoopHook
from mmcls.models.retrievers import BaseRetriever
class ToyRetriever(BaseRetriever):
    """Minimal ``BaseRetriever`` subclass used as a fixture by the hook test.

    It only provides what the test needs: a ``forward`` stub and a
    ``prepare_prototype`` that records that it has been called.
    """

    def forward(self, inputs, data_samples=None, mode: str = 'loss') -> None:
        """Placeholder forward pass; does nothing and returns ``None``.

        The previous body was the bare expression
        ``self.prototype_inited is False``, a comparison whose result was
        discarded, so it had no effect. It is replaced by an explicit
        no-op.

        Args:
            inputs: Unused.
            data_samples: Unused.
            mode (str): Unused. Defaults to ``'loss'``.
        """
        # NOTE(review): if the original intent was to reset the flag
        # (``self.prototype_inited = False``), restore it as an assignment.
        return None

    def prepare_prototype(self) -> None:
        """Store a dummy prototype tensor and mark it as initialised."""
        self.prototype_vecs = torch.tensor([0])
        self.prototype_inited = True
class TestPrepareProtBeforeValLoopHook(TestCase):
    """Unit test for ``PrepareProtoBeforeValLoopHook.before_val``."""

    def setUp(self):
        """Create the hook and a mocked runner whose model is a toy retriever.
        """
        # Instantiate the hook so ``before_val`` runs as a bound method.
        # Previously the class itself was stored and the TestCase instance
        # was passed in as the hook's ``self``.
        self.hook = PrepareProtoBeforeValLoopHook()
        self.runner = MagicMock()
        self.runner.model = ToyRetriever()

    def test_before_val(self):
        """Run ``before_val`` and check the prototype is set and initialised.
        """
        model = self.runner.model

        # NOTE(review): the prototype is prepared manually here, so the
        # assertions after the hook call would pass even if the hook did
        # nothing. Consider dropping this call once it is confirmed that
        # the model starts with ``prototype_inited`` set to ``False``.
        model.prepare_prototype()
        self.assertTrue(model.prototype_inited)

        self.hook.before_val(self.runner)
        self.assertIsNotNone(model.prototype_vecs)
        self.assertTrue(model.prototype_inited)