mirror of https://github.com/PyRetri/PyRetri.git
84 lines
2.1 KiB
Python
84 lines
2.1 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
from yacs.config import CfgNode
|
|
|
|
from .registry import AGGREGATORS, SPLITTERS, EXTRACTORS
|
|
from .extractor import ExtractorBase
|
|
from .splitter import SplitterBase
|
|
from .aggregator import AggregatorBase
|
|
from .helper import ExtractHelper
|
|
|
|
from ..utils import simple_build
|
|
|
|
import torch.nn as nn
|
|
|
|
from typing import List
|
|
|
|
|
|
def build_aggregators(cfg: CfgNode) -> List[AggregatorBase]:
    """
    Instantiate a list of aggregator classes.

    One aggregator is built for every entry of ``cfg["names"]``, in order, by
    calling ``simple_build(name, cfg, AGGREGATORS)``. The whole ``cfg`` node is
    passed to each build call.

    Args:
        cfg (CfgNode): the aggregators configuration tree; must provide ``names``.

    Returns:
        aggregators (list): aggregator instances, in the same order as ``cfg["names"]``.
    """
    return [simple_build(name, cfg, AGGREGATORS) for name in cfg["names"]]
|
|
|
|
|
|
def build_extractor(model: nn.Module, cfg: CfgNode) -> ExtractorBase:
    """
    Instantiate an extractor class.

    The extractor is selected by ``cfg["name"]`` and built through
    ``simple_build`` with the ``EXTRACTORS`` registry. ``model`` is forwarded
    as a keyword argument.

    Args:
        model (nn.Module): the model for extracting features.
        cfg (CfgNode): the extractor configuration tree; must provide ``name``.

    Returns:
        extractor (ExtractorBase): an instance of extractor class.
    """
    return simple_build(cfg["name"], cfg, EXTRACTORS, model=model)
|
|
|
|
|
|
def build_splitter(cfg: CfgNode) -> SplitterBase:
    """
    Instantiate a splitter class.

    The splitter is selected by ``cfg["name"]`` and built through
    ``simple_build`` with the ``SPLITTERS`` registry.

    Args:
        cfg (CfgNode): the splitter configuration tree; must provide ``name``.

    Returns:
        splitter (SplitterBase): an instance of splitter class.
    """
    return simple_build(cfg["name"], cfg, SPLITTERS)
|
|
|
|
|
|
def build_extract_helper(model: nn.Module, cfg: CfgNode) -> ExtractHelper:
    """
    Instantiate an extract helper class.

    ``cfg.assemble`` is read first. The sub-components are then built in this
    order:

    - the extractor, from ``cfg.extractor`` (using ``model``);
    - the splitter, from ``cfg.splitter``;
    - the aggregators, from ``cfg.aggregators``.

    All four values are passed positionally to ``ExtractHelper``.

    Args:
        model (nn.Module): the model for extracting features.
        cfg (CfgNode): the extract configuration tree; must provide
            ``assemble``, ``extractor``, ``splitter`` and ``aggregators``.

    Returns:
        helper (ExtractHelper): an instance of extract helper class.
    """
    # Read ``assemble`` before building anything, as the original ordering did.
    assemble_setting = cfg.assemble
    # Tuple elements are evaluated left to right, so the build order is unchanged.
    components = (
        build_extractor(model, cfg.extractor),
        build_splitter(cfg.splitter),
        build_aggregators(cfg.aggregators),
    )
    return ExtractHelper(assemble_setting, *components)
|
|
|