# Mirror of https://github.com/open-mmlab/mmpretrain.git
# (synced 2025-06-03 14:59:18 +08:00); original file: 35 lines, 1.0 KiB.
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import MagicMock
import torch
from mmcls.engine import PrepareProtoBeforeValLoopHook
from mmcls.models.retrievers import BaseRetriever
class ToyRetriever(BaseRetriever):
    """Minimal concrete retriever used to exercise the prototype hook.

    It only tracks whether a prototype has been prepared, via the
    ``prototype_inited`` flag and a placeholder ``prototype_vecs`` tensor.
    """

    def forward(self, inputs, data_samples=None, mode: str = 'loss'):
        """Stub forward pass: marks the prototype as not initialised.

        Args:
            inputs: Ignored.
            data_samples: Ignored.
            mode (str): Ignored. Defaults to ``'loss'``.

        Returns:
            None. Nothing is computed.
        """
        # Must be an assignment: a bare ``is False`` comparison is evaluated
        # and discarded, leaving the flag unchanged.
        self.prototype_inited = False

    def prepare_prototype(self):
        """Preprocessing the prototype before predict."""
        # A placeholder tensor is sufficient: the test only checks that
        # ``prototype_vecs`` is set and that ``prototype_inited`` flips to True.
        self.prototype_vecs = torch.tensor([0])
        self.prototype_inited = True
class TestPrepareProtBeforeValLoopHook(TestCase):
    """Tests that the hook prepares the model's prototype before validation."""

    def setUp(self):
        # Use a hook *instance* so ``before_val`` is called as a normal bound
        # method, instead of passing the TestCase in as the hook's ``self``.
        self.hook = PrepareProtoBeforeValLoopHook()
        # Only ``runner.model`` is set explicitly; MagicMock answers any other
        # attribute lookup the hook may perform.
        self.runner = MagicMock()
        self.runner.model = ToyRetriever()

    def test_before_val(self):
        model = self.runner.model

        # Precondition: the prototype has not been prepared yet. Otherwise the
        # assertions below could not tell whether the hook did anything.
        # NOTE(review): the getattr default covers the case where BaseRetriever
        # does not initialise ``prototype_inited`` itself — confirm.
        self.assertFalse(getattr(model, 'prototype_inited', False))

        self.hook.before_val(self.runner)

        # The hook must have triggered ``prepare_prototype`` on the model.
        self.assertIsNotNone(model.prototype_vecs)
        self.assertTrue(model.prototype_inited)