PyRetri/retrieval_tool_box/datasets/builder.py

82 lines
2.2 KiB
Python

# -*- coding: utf-8 -*-
from yacs.config import CfgNode
from .registry import COLLATEFNS, FOLDERS, TRANSFORMERS
from .collate_fn import CollateFnBase
from .folder import FolderBase
from .transformer import TransformerBase
from ..utils import simple_build
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
def build_collate(cfg: CfgNode) -> CollateFnBase:
    """
    Build the collate function described by a configuration node.

    The entry stored under ``cfg["name"]`` is passed, together with the
    whole node, to ``simple_build`` using the ``COLLATEFNS`` registry.

    Args:
        cfg (CfgNode): configuration node; must provide a ``"name"`` key.

    Returns:
        CollateFnBase: whatever ``simple_build`` returns for that name.
    """
    return simple_build(cfg["name"], cfg, COLLATEFNS)
def build_transformers(cfg: CfgNode) -> Compose:
    """
    Build a ``Compose`` pipeline from the transformer names in the config.

    Each entry of ``cfg["names"]`` is built, in order, through
    ``simple_build`` with the ``TRANSFORMERS`` registry. The same ``cfg``
    node is handed to every build call.

    Args:
        cfg (CfgNode): configuration node; must provide a ``"names"`` key
            holding an iterable of transformer names.

    Returns:
        Compose: the built transformers wrapped in a single ``Compose``.
    """
    pipeline = [
        simple_build(transformer_name, cfg, TRANSFORMERS)
        for transformer_name in cfg["names"]
    ]
    return Compose(pipeline)
def build_folder(data_json_path: str, cfg: CfgNode) -> FolderBase:
    """
    Build the folder (dataset) object described by a configuration node.

    The transform pipeline is built first from ``cfg.transformers``. The
    folder is then created through ``simple_build`` with the ``FOLDERS``
    registry, receiving the json path and the pipeline as keyword
    arguments.

    Args:
        data_json_path (str): the path of the data json file.
        cfg (CfgNode): configuration node providing ``transformers`` and
            ``folder`` sub-nodes; ``folder`` must contain a ``"name"`` key.

    Returns:
        FolderBase: whatever ``simple_build`` returns for the folder name.
    """
    # Build the transforms before touching the folder config, as the
    # original implementation did.
    pipeline = build_transformers(cfg.transformers)
    folder_cfg = cfg.folder
    return simple_build(
        folder_cfg["name"],
        folder_cfg,
        FOLDERS,
        data_json_path=data_json_path,
        transformer=pipeline,
    )
def build_loader(
    folder: FolderBase,
    cfg: CfgNode,
    *,
    num_workers: int = 8,
    pin_memory: bool = True,
) -> DataLoader:
    """
    Instantiate a data loader class with the given configuration tree.

    The collate function is built from ``cfg.collate_fn`` and the batch size
    is read from ``cfg["batch_size"]``. The worker count and pinned-memory
    flag used to be hard-coded; they are now keyword-only arguments whose
    defaults reproduce the previous behavior, so existing two-argument calls
    are unaffected.

    Args:
        folder (FolderBase): the folder (dataset) object to load from.
        cfg (CfgNode): the configuration tree; must provide ``collate_fn``
            and ``batch_size``.
        num_workers (int): forwarded to ``DataLoader(num_workers=...)``.
            Defaults to 8.
        pin_memory (bool): forwarded to ``DataLoader(pin_memory=...)``.
            Defaults to True.

    Returns:
        data_loader (DataLoader): a data loader class.
    """
    co_fn = build_collate(cfg.collate_fn)
    data_loader = DataLoader(
        folder,
        batch_size=cfg["batch_size"],
        collate_fn=co_fn,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return data_loader