diff --git a/benchs/bench_fw/benchmark_io.py b/benchs/bench_fw/benchmark_io.py
index 5ee3eb3a6..379fa608b 100644
--- a/benchs/bench_fw/benchmark_io.py
+++ b/benchs/bench_fw/benchmark_io.py
@@ -10,7 +10,7 @@ import logging
 import os
 import pickle
-from dataclasses import dataclass
-from typing import Any, List, Optional
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional
 from zipfile import ZipFile
 
 import faiss  # @manual=//faiss/python:pyfaiss_gpu
@@ -46,13 +46,12 @@ def merge_rcq_itq(
 @dataclass
 class BenchmarkIO:
     path: str
+    # default_factory gives every instance its own dict; @dataclass rejects a bare `{}` default with ValueError.
+    cached_ds: Dict[Any, Any] = field(default_factory=dict)
 
     def clone(self):
         return BenchmarkIO(path=self.path)
 
-    def __post_init__(self):
-        self.cached_ds = {}
-
     # TODO(kuarora): rename it as get_local_file
     def get_local_filename(self, filename):
         if len(filename) > 184:
diff --git a/benchs/bench_fw/descriptors.py b/benchs/bench_fw/descriptors.py
index e76278ced..a553da589 100644
--- a/benchs/bench_fw/descriptors.py
+++ b/benchs/bench_fw/descriptors.py
@@ -78,6 +78,8 @@ class DatasetDescriptor:
     # number of vectors to load from the dataset
     num_vectors: Optional[int] = None
 
+    embedding_column: Optional[str] = None
+
     def __hash__(self):
         return hash(self.get_filename())
 