mirror of https://github.com/PyRetri/PyRetri.git
82 lines
2.2 KiB
Python
82 lines
2.2 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
from yacs.config import CfgNode
|
|
|
|
from .registry import COLLATEFNS, FOLDERS, TRANSFORMERS
|
|
from .collate_fn import CollateFnBase
|
|
from .folder import FolderBase
|
|
from .transformer import TransformerBase
|
|
|
|
from ..utils import simple_build
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
from torchvision.transforms import Compose
|
|
|
|
|
|
def build_collate(cfg: CfgNode) -> CollateFnBase:
    """
    Create the collate-function object described by a configuration node.

    The value stored under ``cfg["name"]`` is passed to ``simple_build``,
    together with ``cfg`` itself and the ``COLLATEFNS`` registry, and the
    resulting object is returned unchanged.

    Args:
        cfg (CfgNode): the configuration tree; must contain a ``"name"`` key.

    Returns:
        collate (CollateFnBase): the instantiated collate class.
    """
    return simple_build(cfg["name"], cfg, COLLATEFNS)
|
|
|
|
|
|
def build_transformers(cfg: CfgNode) -> Compose:
    """
    Instantiate a compose class containing several transforms with the given configuration tree.

    Every entry of ``cfg["names"]`` is passed to ``simple_build`` together
    with the same ``cfg`` node and the ``TRANSFORMERS`` registry. The built
    objects are wrapped in a ``Compose`` in the same order as the names.

    Args:
        cfg (CfgNode): the configuration tree; must contain a ``"names"`` key
            holding an iterable of transformer names.

    Returns:
        transformers (Compose): a compose class.
    """
    # Keep the list and the Compose in separate names instead of rebinding
    # one variable to two different types.
    transform_list = [simple_build(name, cfg, TRANSFORMERS) for name in cfg["names"]]
    return Compose(transform_list)
|
|
|
|
|
|
def build_folder(data_json_path: str, cfg: CfgNode) -> FolderBase:
    """
    Build the dataset folder object described by ``cfg.folder``.

    The transform pipeline is created first from ``cfg.transformers``. It is
    then passed to the folder as its ``transformer`` keyword argument,
    alongside ``data_json_path``.

    Args:
        data_json_path (str): the path of the data json file.
        cfg (CfgNode): the configuration tree; expected to provide a
            ``transformers`` sub-node and a ``folder`` sub-node with a
            ``"name"`` key.

    Returns:
        folder (FolderBase): a folder class.
    """
    # Build the transforms before touching cfg.folder, matching the
    # original evaluation order.
    pipeline = build_transformers(cfg.transformers)
    folder_cfg = cfg.folder
    return simple_build(
        folder_cfg["name"],
        folder_cfg,
        FOLDERS,
        data_json_path=data_json_path,
        transformer=pipeline,
    )
|
|
|
|
|
|
def build_loader(
    folder: FolderBase,
    cfg: CfgNode,
    *,
    num_workers: int = 8,
    pin_memory: bool = True,
) -> DataLoader:
    """
    Instantiate a data loader class with the given configuration tree.

    Args:
        folder (FolderBase): the dataset folder instance to load samples from.
        cfg (CfgNode): the configuration tree; must contain a ``"batch_size"``
            key and a ``collate_fn`` sub-node.
        num_workers (int): number of worker processes handed to ``DataLoader``.
            Defaults to 8, the value previously hard-coded here.
        pin_memory (bool): forwarded to ``DataLoader``'s ``pin_memory``.
            Defaults to True, the value previously hard-coded here.

    Returns:
        data_loader (DataLoader): a data loader class.
    """
    co_fn = build_collate(cfg.collate_fn)

    # No ``shuffle`` argument is passed, so DataLoader's default sequential
    # sampling keeps batches in folder order.
    data_loader = DataLoader(
        folder,
        batch_size=cfg["batch_size"],
        collate_fn=co_fn,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return data_loader
|