PyRetri/pyretri/datasets/collate_fn/collate_fn_base.py
2020-04-15 14:44:22 +08:00

28 lines
569 B
Python

# -*- coding: utf-8 -*-
from abc import abstractmethod
from typing import Dict, List, Optional

import torch

from ...utils import ModuleBase
class CollateFnBase(ModuleBase):
    """
    The base class of collate function.

    A collate function is called with a list of per-sample dicts and returns a
    single dict of batched tensors. Subclasses implement ``__call__``.
    """

    # Hyper parameters of this collate function (empty for the base class).
    # NOTE(review): presumably overridden by subclasses and merged with ``hps``
    # inside ModuleBase.__init__ — that logic is outside this file; confirm.
    default_hyper_params = dict()

    def __init__(self, hps: Optional[Dict] = None):
        """
        Args:
            hps: hyper parameters in a dict (keys, values), or None. Forwarded
                unchanged to ``ModuleBase.__init__``.
        """
        super().__init__(hps)

    @abstractmethod
    def __call__(self, batch: List[Dict]) -> Dict[str, torch.Tensor]:
        """
        Collate a list of samples into one batch.

        Args:
            batch: list of per-sample dicts (presumably the items yielded by the
                dataset — verify against the concrete dataset classes).

        Returns:
            a dict mapping string keys to batched tensors.

        Raises:
            NotImplementedError: always; concrete collate functions must override
                this. Raising (rather than ``pass``) keeps a missing override from
                silently yielding ``None`` if ModuleBase is not an ABC.
        """
        raise NotImplementedError