use dispatcher function to call HammingComputer (#2918)
Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2918 The HammingComputer class is optimized for several vector sizes. So far it's been the caller's responsiblity to instanciate the relevant optimized version. This diff introduces a `dispatch_HammingComputer` function that can be called with a template class that is instanciated for all existing optimized HammingComputer's. Reviewed By: algoriddle Differential Revision: D46858553 fbshipit-source-id: 32c31689bba7c0b406b309fc8574c95fa24022bapull/2922/head
parent
a27036aa72
commit
a91a2887fe
|
@ -18,6 +18,66 @@
|
|||
|
||||
using namespace faiss;
|
||||
|
||||
// These implementations are currently slower than HammingComputerDefault so
|
||||
// they are not in the main faiss anymore.
|
||||
struct HammingComputerM8 {
|
||||
const uint64_t* a;
|
||||
int n;
|
||||
|
||||
HammingComputerM8() {}
|
||||
|
||||
HammingComputerM8(const uint8_t* a8, int code_size) {
|
||||
set(a8, code_size);
|
||||
}
|
||||
|
||||
void set(const uint8_t* a8, int code_size) {
|
||||
assert(code_size % 8 == 0);
|
||||
a = (uint64_t*)a8;
|
||||
n = code_size / 8;
|
||||
}
|
||||
|
||||
int hamming(const uint8_t* b8) const {
|
||||
const uint64_t* b = (uint64_t*)b8;
|
||||
int accu = 0;
|
||||
for (int i = 0; i < n; i++)
|
||||
accu += popcount64(a[i] ^ b[i]);
|
||||
return accu;
|
||||
}
|
||||
|
||||
inline int get_code_size() const {
|
||||
return n * 8;
|
||||
}
|
||||
};
|
||||
|
||||
struct HammingComputerM4 {
|
||||
const uint32_t* a;
|
||||
int n;
|
||||
|
||||
HammingComputerM4() {}
|
||||
|
||||
HammingComputerM4(const uint8_t* a4, int code_size) {
|
||||
set(a4, code_size);
|
||||
}
|
||||
|
||||
void set(const uint8_t* a4, int code_size) {
|
||||
assert(code_size % 4 == 0);
|
||||
a = (uint32_t*)a4;
|
||||
n = code_size / 4;
|
||||
}
|
||||
|
||||
int hamming(const uint8_t* b8) const {
|
||||
const uint32_t* b = (uint32_t*)b8;
|
||||
int accu = 0;
|
||||
for (int i = 0; i < n; i++)
|
||||
accu += popcount64(a[i] ^ b[i]);
|
||||
return accu;
|
||||
}
|
||||
|
||||
inline int get_code_size() const {
|
||||
return n * 4;
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
void hamming_cpt_test(
|
||||
int code_size,
|
||||
|
|
|
@ -281,31 +281,21 @@ struct FlatHammingDis : DistanceComputer {
|
|||
}
|
||||
};
|
||||
|
||||
struct BuildDistanceComputer {
|
||||
using T = DistanceComputer*;
|
||||
template <class HammingComputer>
|
||||
DistanceComputer* f(IndexBinaryFlat* flat_storage) {
|
||||
return new FlatHammingDis<HammingComputer>(*flat_storage);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
DistanceComputer* IndexBinaryHNSW::get_distance_computer() const {
|
||||
IndexBinaryFlat* flat_storage = dynamic_cast<IndexBinaryFlat*>(storage);
|
||||
|
||||
FAISS_ASSERT(flat_storage != nullptr);
|
||||
|
||||
switch (code_size) {
|
||||
case 4:
|
||||
return new FlatHammingDis<HammingComputer4>(*flat_storage);
|
||||
case 8:
|
||||
return new FlatHammingDis<HammingComputer8>(*flat_storage);
|
||||
case 16:
|
||||
return new FlatHammingDis<HammingComputer16>(*flat_storage);
|
||||
case 20:
|
||||
return new FlatHammingDis<HammingComputer20>(*flat_storage);
|
||||
case 32:
|
||||
return new FlatHammingDis<HammingComputer32>(*flat_storage);
|
||||
case 64:
|
||||
return new FlatHammingDis<HammingComputer64>(*flat_storage);
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return new FlatHammingDis<HammingComputerDefault>(*flat_storage);
|
||||
BuildDistanceComputer bd;
|
||||
return dispatch_HammingComputer(code_size, bd, flat_storage);
|
||||
}
|
||||
|
||||
} // namespace faiss
|
||||
|
|
|
@ -176,6 +176,14 @@ void search_single_query_template(
|
|||
} while (fe.next());
|
||||
}
|
||||
|
||||
struct Run_search_single_query {
|
||||
using T = void;
|
||||
template <class HammingComputer, class... Types>
|
||||
T f(Types... args) {
|
||||
search_single_query_template<HammingComputer>(args...);
|
||||
}
|
||||
};
|
||||
|
||||
template <class SearchResults>
|
||||
void search_single_query(
|
||||
const IndexBinaryHash& index,
|
||||
|
@ -184,29 +192,9 @@ void search_single_query(
|
|||
size_t& n0,
|
||||
size_t& nlist,
|
||||
size_t& ndis) {
|
||||
#define HC(name) \
|
||||
search_single_query_template<name>(index, q, res, n0, nlist, ndis);
|
||||
switch (index.code_size) {
|
||||
case 4:
|
||||
HC(HammingComputer4);
|
||||
break;
|
||||
case 8:
|
||||
HC(HammingComputer8);
|
||||
break;
|
||||
case 16:
|
||||
HC(HammingComputer16);
|
||||
break;
|
||||
case 20:
|
||||
HC(HammingComputer20);
|
||||
break;
|
||||
case 32:
|
||||
HC(HammingComputer32);
|
||||
break;
|
||||
default:
|
||||
HC(HammingComputerDefault);
|
||||
break;
|
||||
}
|
||||
#undef HC
|
||||
Run_search_single_query r;
|
||||
dispatch_HammingComputer(
|
||||
index.code_size, r, index, q, res, n0, nlist, ndis);
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
@ -349,15 +337,15 @@ namespace {
|
|||
|
||||
template <class HammingComputer, class SearchResults>
|
||||
static void verify_shortlist(
|
||||
const IndexBinaryFlat& index,
|
||||
const IndexBinaryFlat* index,
|
||||
const uint8_t* q,
|
||||
const std::unordered_set<idx_t>& shortlist,
|
||||
SearchResults& res) {
|
||||
size_t code_size = index.code_size;
|
||||
size_t code_size = index->code_size;
|
||||
size_t nlist = 0, ndis = 0, n0 = 0;
|
||||
|
||||
HammingComputer hc(q, code_size);
|
||||
const uint8_t* codes = index.xb.data();
|
||||
const uint8_t* codes = index->xb.data();
|
||||
|
||||
for (auto i : shortlist) {
|
||||
int dis = hc.hamming(codes + i * code_size);
|
||||
|
@ -365,6 +353,14 @@ static void verify_shortlist(
|
|||
}
|
||||
}
|
||||
|
||||
struct Run_verify_shortlist {
|
||||
using T = void;
|
||||
template <class HammingComputer, class... Types>
|
||||
void f(Types... args) {
|
||||
verify_shortlist<HammingComputer>(args...);
|
||||
}
|
||||
};
|
||||
|
||||
template <class SearchResults>
|
||||
void search_1_query_multihash(
|
||||
const IndexBinaryMultiHash& index,
|
||||
|
@ -405,29 +401,9 @@ void search_1_query_multihash(
|
|||
ndis += shortlist.size();
|
||||
|
||||
// verify shortlist
|
||||
|
||||
#define HC(name) verify_shortlist<name>(*index.storage, xi, shortlist, res)
|
||||
switch (index.code_size) {
|
||||
case 4:
|
||||
HC(HammingComputer4);
|
||||
break;
|
||||
case 8:
|
||||
HC(HammingComputer8);
|
||||
break;
|
||||
case 16:
|
||||
HC(HammingComputer16);
|
||||
break;
|
||||
case 20:
|
||||
HC(HammingComputer20);
|
||||
break;
|
||||
case 32:
|
||||
HC(HammingComputer32);
|
||||
break;
|
||||
default:
|
||||
HC(HammingComputerDefault);
|
||||
break;
|
||||
}
|
||||
#undef HC
|
||||
Run_verify_shortlist r;
|
||||
dispatch_HammingComputer(
|
||||
index.code_size, r, index.storage, xi, shortlist, res);
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
|
|
@ -370,7 +370,7 @@ struct IVFBinaryScannerL2 : BinaryInvertedListScanner {
|
|||
};
|
||||
|
||||
void search_knn_hamming_heap(
|
||||
const IndexBinaryIVF& ivf,
|
||||
const IndexBinaryIVF* ivf,
|
||||
size_t n,
|
||||
const uint8_t* __restrict x,
|
||||
idx_t k,
|
||||
|
@ -380,10 +380,10 @@ void search_knn_hamming_heap(
|
|||
idx_t* __restrict labels,
|
||||
bool store_pairs,
|
||||
const IVFSearchParameters* params) {
|
||||
idx_t nprobe = params ? params->nprobe : ivf.nprobe;
|
||||
nprobe = std::min((idx_t)ivf.nlist, nprobe);
|
||||
idx_t max_codes = params ? params->max_codes : ivf.max_codes;
|
||||
MetricType metric_type = ivf.metric_type;
|
||||
idx_t nprobe = params ? params->nprobe : ivf->nprobe;
|
||||
nprobe = std::min((idx_t)ivf->nlist, nprobe);
|
||||
idx_t max_codes = params ? params->max_codes : ivf->max_codes;
|
||||
MetricType metric_type = ivf->metric_type;
|
||||
|
||||
// almost verbatim copy from IndexIVF::search_preassigned
|
||||
|
||||
|
@ -394,11 +394,11 @@ void search_knn_hamming_heap(
|
|||
#pragma omp parallel if (n > 1) reduction(+ : nlistv, ndis, nheap)
|
||||
{
|
||||
std::unique_ptr<BinaryInvertedListScanner> scanner(
|
||||
ivf.get_InvertedListScanner(store_pairs));
|
||||
ivf->get_InvertedListScanner(store_pairs));
|
||||
|
||||
#pragma omp for
|
||||
for (idx_t i = 0; i < n; i++) {
|
||||
const uint8_t* xi = x + i * ivf.code_size;
|
||||
const uint8_t* xi = x + i * ivf->code_size;
|
||||
scanner->set_query(xi);
|
||||
|
||||
const idx_t* keysi = keys + i * nprobe;
|
||||
|
@ -420,23 +420,24 @@ void search_knn_hamming_heap(
|
|||
continue;
|
||||
}
|
||||
FAISS_THROW_IF_NOT_FMT(
|
||||
key < (idx_t)ivf.nlist,
|
||||
key < (idx_t)ivf->nlist,
|
||||
"Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
|
||||
key,
|
||||
ik,
|
||||
ivf.nlist);
|
||||
ivf->nlist);
|
||||
|
||||
scanner->set_list(key, coarse_dis[i * nprobe + ik]);
|
||||
|
||||
nlistv++;
|
||||
|
||||
size_t list_size = ivf.invlists->list_size(key);
|
||||
InvertedLists::ScopedCodes scodes(ivf.invlists, key);
|
||||
size_t list_size = ivf->invlists->list_size(key);
|
||||
InvertedLists::ScopedCodes scodes(ivf->invlists, key);
|
||||
std::unique_ptr<InvertedLists::ScopedIds> sids;
|
||||
const idx_t* ids = nullptr;
|
||||
|
||||
if (!store_pairs) {
|
||||
sids.reset(new InvertedLists::ScopedIds(ivf.invlists, key));
|
||||
sids.reset(
|
||||
new InvertedLists::ScopedIds(ivf->invlists, key));
|
||||
ids = sids->get();
|
||||
}
|
||||
|
||||
|
@ -466,7 +467,7 @@ void search_knn_hamming_heap(
|
|||
|
||||
template <class HammingComputer, bool store_pairs>
|
||||
void search_knn_hamming_count(
|
||||
const IndexBinaryIVF& ivf,
|
||||
const IndexBinaryIVF* ivf,
|
||||
size_t nx,
|
||||
const uint8_t* __restrict x,
|
||||
const idx_t* __restrict keys,
|
||||
|
@ -474,21 +475,21 @@ void search_knn_hamming_count(
|
|||
int32_t* __restrict distances,
|
||||
idx_t* __restrict labels,
|
||||
const IVFSearchParameters* params) {
|
||||
const int nBuckets = ivf.d + 1;
|
||||
const int nBuckets = ivf->d + 1;
|
||||
std::vector<int> all_counters(nx * nBuckets, 0);
|
||||
std::unique_ptr<idx_t[]> all_ids_per_dis(new idx_t[nx * nBuckets * k]);
|
||||
|
||||
idx_t nprobe = params ? params->nprobe : ivf.nprobe;
|
||||
nprobe = std::min((idx_t)ivf.nlist, nprobe);
|
||||
idx_t max_codes = params ? params->max_codes : ivf.max_codes;
|
||||
idx_t nprobe = params ? params->nprobe : ivf->nprobe;
|
||||
nprobe = std::min((idx_t)ivf->nlist, nprobe);
|
||||
idx_t max_codes = params ? params->max_codes : ivf->max_codes;
|
||||
|
||||
std::vector<HCounterState<HammingComputer>> cs;
|
||||
for (size_t i = 0; i < nx; ++i) {
|
||||
cs.push_back(HCounterState<HammingComputer>(
|
||||
all_counters.data() + i * nBuckets,
|
||||
all_ids_per_dis.get() + i * nBuckets * k,
|
||||
x + i * ivf.code_size,
|
||||
ivf.d,
|
||||
x + i * ivf->code_size,
|
||||
ivf->d,
|
||||
k));
|
||||
}
|
||||
|
||||
|
@ -508,27 +509,28 @@ void search_knn_hamming_count(
|
|||
continue;
|
||||
}
|
||||
FAISS_THROW_IF_NOT_FMT(
|
||||
key < (idx_t)ivf.nlist,
|
||||
key < (idx_t)ivf->nlist,
|
||||
"Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
|
||||
key,
|
||||
ik,
|
||||
ivf.nlist);
|
||||
ivf->nlist);
|
||||
|
||||
nlistv++;
|
||||
size_t list_size = ivf.invlists->list_size(key);
|
||||
InvertedLists::ScopedCodes scodes(ivf.invlists, key);
|
||||
size_t list_size = ivf->invlists->list_size(key);
|
||||
InvertedLists::ScopedCodes scodes(ivf->invlists, key);
|
||||
const uint8_t* list_vecs = scodes.get();
|
||||
const idx_t* ids =
|
||||
store_pairs ? nullptr : ivf.invlists->get_ids(key);
|
||||
store_pairs ? nullptr : ivf->invlists->get_ids(key);
|
||||
|
||||
for (size_t j = 0; j < list_size; j++) {
|
||||
const uint8_t* yj = list_vecs + ivf.code_size * j;
|
||||
const uint8_t* yj = list_vecs + ivf->code_size * j;
|
||||
|
||||
idx_t id = store_pairs ? (key << 32 | j) : ids[j];
|
||||
csi.update_counter(yj, id);
|
||||
}
|
||||
if (ids)
|
||||
ivf.invlists->release_ids(key, ids);
|
||||
if (ids) {
|
||||
ivf->invlists->release_ids(key, ids);
|
||||
}
|
||||
|
||||
nscan += list_size;
|
||||
if (max_codes && nscan >= max_codes)
|
||||
|
@ -634,7 +636,7 @@ struct BlockSearchVariableK {
|
|||
|
||||
template <class HammingComputer>
|
||||
void search_knn_hamming_per_invlist(
|
||||
const IndexBinaryIVF& ivf,
|
||||
const IndexBinaryIVF* ivf,
|
||||
size_t n,
|
||||
const uint8_t* __restrict x,
|
||||
idx_t k,
|
||||
|
@ -644,12 +646,12 @@ void search_knn_hamming_per_invlist(
|
|||
idx_t* __restrict labels,
|
||||
bool store_pairs,
|
||||
const IVFSearchParameters* params) {
|
||||
idx_t nprobe = params ? params->nprobe : ivf.nprobe;
|
||||
nprobe = std::min((idx_t)ivf.nlist, nprobe);
|
||||
idx_t max_codes = params ? params->max_codes : ivf.max_codes;
|
||||
idx_t nprobe = params ? params->nprobe : ivf->nprobe;
|
||||
nprobe = std::min((idx_t)ivf->nlist, nprobe);
|
||||
idx_t max_codes = params ? params->max_codes : ivf->max_codes;
|
||||
FAISS_THROW_IF_NOT(max_codes == 0);
|
||||
FAISS_THROW_IF_NOT(!store_pairs);
|
||||
MetricType metric_type = ivf.metric_type;
|
||||
MetricType metric_type = ivf->metric_type;
|
||||
|
||||
// reorder buckets
|
||||
std::vector<int64_t> lims(n + 1);
|
||||
|
@ -658,18 +660,18 @@ void search_knn_hamming_per_invlist(
|
|||
for (idx_t i = 0; i < n * nprobe; i++) {
|
||||
keys[i] = keys_in[i];
|
||||
}
|
||||
matrix_bucket_sort_inplace(n, nprobe, keys, ivf.nlist, lims.data(), 0);
|
||||
matrix_bucket_sort_inplace(n, nprobe, keys, ivf->nlist, lims.data(), 0);
|
||||
|
||||
using C = CMax<int32_t, idx_t>;
|
||||
heap_heapify<C>(n * k, distances, labels);
|
||||
const size_t code_size = ivf.code_size;
|
||||
const size_t code_size = ivf->code_size;
|
||||
|
||||
for (idx_t l = 0; l < ivf.nlist; l++) {
|
||||
for (idx_t l = 0; l < ivf->nlist; l++) {
|
||||
idx_t l0 = lims[l], nq = lims[l + 1] - l0;
|
||||
|
||||
InvertedLists::ScopedCodes scodes(ivf.invlists, l);
|
||||
InvertedLists::ScopedIds sidx(ivf.invlists, l);
|
||||
idx_t nb = ivf.invlists->list_size(l);
|
||||
InvertedLists::ScopedCodes scodes(ivf->invlists, l);
|
||||
InvertedLists::ScopedIds sidx(ivf->invlists, l);
|
||||
idx_t nb = ivf->invlists->list_size(l);
|
||||
const uint8_t* bcodes = scodes.get();
|
||||
const idx_t* ids = sidx.get();
|
||||
|
||||
|
@ -735,151 +737,70 @@ void search_knn_hamming_per_invlist(
|
|||
}
|
||||
}
|
||||
|
||||
template <bool store_pairs>
|
||||
void search_knn_hamming_count_1(
|
||||
const IndexBinaryIVF& ivf,
|
||||
size_t nx,
|
||||
const uint8_t* x,
|
||||
const idx_t* keys,
|
||||
int k,
|
||||
int32_t* distances,
|
||||
idx_t* labels,
|
||||
const IVFSearchParameters* params) {
|
||||
switch (ivf.code_size) {
|
||||
#define HANDLE_CS(cs) \
|
||||
case cs: \
|
||||
search_knn_hamming_count<HammingComputer##cs, store_pairs>( \
|
||||
ivf, nx, x, keys, k, distances, labels, params); \
|
||||
break;
|
||||
HANDLE_CS(4);
|
||||
HANDLE_CS(8);
|
||||
HANDLE_CS(16);
|
||||
HANDLE_CS(20);
|
||||
HANDLE_CS(32);
|
||||
HANDLE_CS(64);
|
||||
#undef HANDLE_CS
|
||||
default:
|
||||
search_knn_hamming_count<HammingComputerDefault, store_pairs>(
|
||||
ivf, nx, x, keys, k, distances, labels, params);
|
||||
break;
|
||||
}
|
||||
}
|
||||
struct Run_search_knn_hamming_per_invlist {
|
||||
using T = void;
|
||||
|
||||
void search_knn_hamming_per_invlist_1(
|
||||
const IndexBinaryIVF& ivf,
|
||||
size_t n,
|
||||
const uint8_t* x,
|
||||
idx_t k,
|
||||
const idx_t* keys,
|
||||
const int32_t* coarse_dis,
|
||||
int32_t* distances,
|
||||
idx_t* labels,
|
||||
bool store_pairs,
|
||||
const IVFSearchParameters* params) {
|
||||
switch (ivf.code_size) {
|
||||
#define HANDLE_CS(cs) \
|
||||
case cs: \
|
||||
search_knn_hamming_per_invlist<HammingComputer##cs>( \
|
||||
ivf, \
|
||||
n, \
|
||||
x, \
|
||||
k, \
|
||||
keys, \
|
||||
coarse_dis, \
|
||||
distances, \
|
||||
labels, \
|
||||
store_pairs, \
|
||||
params); \
|
||||
break;
|
||||
HANDLE_CS(4);
|
||||
HANDLE_CS(8);
|
||||
HANDLE_CS(16);
|
||||
HANDLE_CS(20);
|
||||
HANDLE_CS(32);
|
||||
HANDLE_CS(64);
|
||||
#undef HANDLE_CS
|
||||
default:
|
||||
search_knn_hamming_per_invlist<HammingComputerDefault>(
|
||||
ivf,
|
||||
n,
|
||||
x,
|
||||
k,
|
||||
keys,
|
||||
coarse_dis,
|
||||
distances,
|
||||
labels,
|
||||
store_pairs,
|
||||
params);
|
||||
break;
|
||||
template <class HammingComputer, class... Types>
|
||||
void f(Types... args) {
|
||||
search_knn_hamming_per_invlist<HammingComputer>(args...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <bool store_pairs>
|
||||
struct Run_search_knn_hamming_count {
|
||||
using T = void;
|
||||
|
||||
template <class HammingComputer, class... Types>
|
||||
void f(Types... args) {
|
||||
search_knn_hamming_count<HammingComputer, store_pairs>(args...);
|
||||
}
|
||||
};
|
||||
|
||||
struct BuildScanner {
|
||||
using T = BinaryInvertedListScanner*;
|
||||
|
||||
template <class HammingComputer>
|
||||
T f(size_t code_size, bool store_pairs) {
|
||||
return new IVFBinaryScannerL2<HammingComputer>(code_size, store_pairs);
|
||||
}
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
BinaryInvertedListScanner* IndexBinaryIVF::get_InvertedListScanner(
|
||||
bool store_pairs) const {
|
||||
#define HC(name) return new IVFBinaryScannerL2<name>(code_size, store_pairs)
|
||||
switch (code_size) {
|
||||
case 4:
|
||||
HC(HammingComputer4);
|
||||
case 8:
|
||||
HC(HammingComputer8);
|
||||
case 16:
|
||||
HC(HammingComputer16);
|
||||
case 20:
|
||||
HC(HammingComputer20);
|
||||
case 32:
|
||||
HC(HammingComputer32);
|
||||
case 64:
|
||||
HC(HammingComputer64);
|
||||
default:
|
||||
HC(HammingComputerDefault);
|
||||
}
|
||||
#undef HC
|
||||
BuildScanner bs;
|
||||
return dispatch_HammingComputer(code_size, bs, code_size, store_pairs);
|
||||
}
|
||||
|
||||
void IndexBinaryIVF::search_preassigned(
|
||||
idx_t n,
|
||||
const uint8_t* x,
|
||||
idx_t k,
|
||||
const idx_t* idx,
|
||||
const int32_t* coarse_dis,
|
||||
int32_t* distances,
|
||||
idx_t* labels,
|
||||
const idx_t* cidx,
|
||||
const int32_t* cdis,
|
||||
int32_t* dis,
|
||||
idx_t* idx,
|
||||
bool store_pairs,
|
||||
const IVFSearchParameters* params) const {
|
||||
if (per_invlist_search) {
|
||||
search_knn_hamming_per_invlist_1(
|
||||
*this,
|
||||
n,
|
||||
x,
|
||||
k,
|
||||
idx,
|
||||
coarse_dis,
|
||||
distances,
|
||||
labels,
|
||||
store_pairs,
|
||||
params);
|
||||
Run_search_knn_hamming_per_invlist r;
|
||||
// clang-format off
|
||||
dispatch_HammingComputer(
|
||||
code_size, r, this, n, x, k,
|
||||
cidx, cdis, dis, idx, store_pairs, params);
|
||||
// clang-format on
|
||||
} else if (use_heap) {
|
||||
search_knn_hamming_heap(
|
||||
*this,
|
||||
n,
|
||||
x,
|
||||
k,
|
||||
idx,
|
||||
coarse_dis,
|
||||
distances,
|
||||
labels,
|
||||
store_pairs,
|
||||
params);
|
||||
} else {
|
||||
if (store_pairs) {
|
||||
search_knn_hamming_count_1<true>(
|
||||
*this, n, x, idx, k, distances, labels, params);
|
||||
} else {
|
||||
search_knn_hamming_count_1<false>(
|
||||
*this, n, x, idx, k, distances, labels, params);
|
||||
}
|
||||
this, n, x, k, cidx, cdis, dis, idx, store_pairs, params);
|
||||
} else if (store_pairs) { // !use_heap && store_pairs
|
||||
Run_search_knn_hamming_count<true> r;
|
||||
dispatch_HammingComputer(
|
||||
code_size, r, this, n, x, cidx, k, dis, idx, params);
|
||||
} else { // !use_heap && !store_pairs
|
||||
Run_search_knn_hamming_count<false> r;
|
||||
dispatch_HammingComputer(
|
||||
code_size, r, this, n, x, cidx, k, dis, idx, params);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1154,30 +1154,23 @@ struct IVFPQScannerT : QueryTables {
|
|||
{ indexIVFPQ_stats.n_hamming_pass += n_hamming_pass; }
|
||||
}
|
||||
|
||||
template <class SearchResultType>
|
||||
struct Run_scan_list_polysemous_hc {
|
||||
using T = void;
|
||||
template <class HammingComputer, class... Types>
|
||||
void f(const IVFPQScannerT* scanner, Types... args) {
|
||||
scanner->scan_list_polysemous_hc<HammingComputer, SearchResultType>(
|
||||
args...);
|
||||
}
|
||||
};
|
||||
|
||||
template <class SearchResultType>
|
||||
void scan_list_polysemous(
|
||||
size_t ncode,
|
||||
const uint8_t* codes,
|
||||
SearchResultType& res) const {
|
||||
switch (pq.code_size) {
|
||||
#define HANDLE_CODE_SIZE(cs) \
|
||||
case cs: \
|
||||
scan_list_polysemous_hc<HammingComputer##cs, SearchResultType>( \
|
||||
ncode, codes, res); \
|
||||
break
|
||||
HANDLE_CODE_SIZE(4);
|
||||
HANDLE_CODE_SIZE(8);
|
||||
HANDLE_CODE_SIZE(16);
|
||||
HANDLE_CODE_SIZE(20);
|
||||
HANDLE_CODE_SIZE(32);
|
||||
HANDLE_CODE_SIZE(64);
|
||||
#undef HANDLE_CODE_SIZE
|
||||
default:
|
||||
scan_list_polysemous_hc<
|
||||
HammingComputerDefault,
|
||||
SearchResultType>(ncode, codes, res);
|
||||
break;
|
||||
}
|
||||
Run_scan_list_polysemous_hc<SearchResultType> r;
|
||||
dispatch_HammingComputer(pq.code_size, r, this, ncode, codes, res);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -288,26 +288,23 @@ struct IVFScanner : InvertedListScanner {
|
|||
}
|
||||
};
|
||||
|
||||
struct BuildScanner {
|
||||
using T = InvertedListScanner*;
|
||||
|
||||
template <class HammingComputer>
|
||||
static T f(const IndexIVFSpectralHash* index, bool store_pairs) {
|
||||
return new IVFScanner<HammingComputer>(index, store_pairs);
|
||||
}
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
InvertedListScanner* IndexIVFSpectralHash::get_InvertedListScanner(
|
||||
bool store_pairs,
|
||||
const IDSelector* sel) const {
|
||||
FAISS_THROW_IF_NOT(!sel);
|
||||
switch (code_size) {
|
||||
#define HANDLE_CODE_SIZE(cs) \
|
||||
case cs: \
|
||||
return new IVFScanner<HammingComputer##cs>(this, store_pairs)
|
||||
HANDLE_CODE_SIZE(4);
|
||||
HANDLE_CODE_SIZE(8);
|
||||
HANDLE_CODE_SIZE(16);
|
||||
HANDLE_CODE_SIZE(20);
|
||||
HANDLE_CODE_SIZE(32);
|
||||
HANDLE_CODE_SIZE(64);
|
||||
#undef HANDLE_CODE_SIZE
|
||||
default:
|
||||
return new IVFScanner<HammingComputerDefault>(this, store_pairs);
|
||||
}
|
||||
BuildScanner bs;
|
||||
return dispatch_HammingComputer(code_size, bs, this, store_pairs);
|
||||
}
|
||||
|
||||
void IndexIVFSpectralHash::replace_vt(VectorTransform* vt_in, bool own) {
|
||||
|
|
|
@ -263,21 +263,23 @@ void IndexPQStats::reset() {
|
|||
|
||||
IndexPQStats indexPQ_stats;
|
||||
|
||||
namespace {
|
||||
|
||||
template <class HammingComputer>
|
||||
static size_t polysemous_inner_loop(
|
||||
const IndexPQ& index,
|
||||
size_t polysemous_inner_loop(
|
||||
const IndexPQ* index,
|
||||
const float* dis_table_qi,
|
||||
const uint8_t* q_code,
|
||||
size_t k,
|
||||
float* heap_dis,
|
||||
int64_t* heap_ids,
|
||||
int ht) {
|
||||
int M = index.pq.M;
|
||||
int code_size = index.pq.code_size;
|
||||
int ksub = index.pq.ksub;
|
||||
size_t ntotal = index.ntotal;
|
||||
int M = index->pq.M;
|
||||
int code_size = index->pq.code_size;
|
||||
int ksub = index->pq.ksub;
|
||||
size_t ntotal = index->ntotal;
|
||||
|
||||
const uint8_t* b_code = index.codes.data();
|
||||
const uint8_t* b_code = index->codes.data();
|
||||
|
||||
size_t n_pass_i = 0;
|
||||
|
||||
|
@ -305,6 +307,16 @@ static size_t polysemous_inner_loop(
|
|||
return n_pass_i;
|
||||
}
|
||||
|
||||
struct Run_polysemous_inner_loop {
|
||||
using T = size_t;
|
||||
template <class HammingComputer, class... Types>
|
||||
size_t f(Types... args) {
|
||||
return polysemous_inner_loop<HammingComputer>(args...);
|
||||
}
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
void IndexPQ::search_core_polysemous(
|
||||
idx_t n,
|
||||
const float* x,
|
||||
|
@ -355,45 +367,24 @@ void IndexPQ::search_core_polysemous(
|
|||
maxheap_heapify(k, heap_dis, heap_ids);
|
||||
|
||||
if (!generalized_hamming) {
|
||||
switch (pq.code_size) {
|
||||
#define DISPATCH(cs) \
|
||||
case cs: \
|
||||
n_pass += polysemous_inner_loop<HammingComputer##cs>( \
|
||||
*this, \
|
||||
dis_table_qi, \
|
||||
q_code, \
|
||||
k, \
|
||||
heap_dis, \
|
||||
heap_ids, \
|
||||
polysemous_ht); \
|
||||
break;
|
||||
DISPATCH(4)
|
||||
DISPATCH(8)
|
||||
DISPATCH(16)
|
||||
DISPATCH(32)
|
||||
DISPATCH(20)
|
||||
default:
|
||||
if (pq.code_size % 4 == 0) {
|
||||
n_pass += polysemous_inner_loop<HammingComputerDefault>(
|
||||
*this,
|
||||
dis_table_qi,
|
||||
q_code,
|
||||
k,
|
||||
heap_dis,
|
||||
heap_ids,
|
||||
polysemous_ht);
|
||||
} else {
|
||||
bad_code_size++;
|
||||
}
|
||||
break;
|
||||
}
|
||||
#undef DISPATCH
|
||||
Run_polysemous_inner_loop r;
|
||||
n_pass += dispatch_HammingComputer(
|
||||
pq.code_size,
|
||||
r,
|
||||
this,
|
||||
dis_table_qi,
|
||||
q_code,
|
||||
k,
|
||||
heap_dis,
|
||||
heap_ids,
|
||||
polysemous_ht);
|
||||
|
||||
} else { // generalized hamming
|
||||
switch (pq.code_size) {
|
||||
#define DISPATCH(cs) \
|
||||
case cs: \
|
||||
n_pass += polysemous_inner_loop<GenHammingComputer##cs>( \
|
||||
*this, \
|
||||
this, \
|
||||
dis_table_qi, \
|
||||
q_code, \
|
||||
k, \
|
||||
|
@ -407,7 +398,7 @@ void IndexPQ::search_core_polysemous(
|
|||
default:
|
||||
if (pq.code_size % 8 == 0) {
|
||||
n_pass += polysemous_inner_loop<GenHammingComputerM8>(
|
||||
*this,
|
||||
this,
|
||||
dis_table_qi,
|
||||
q_code,
|
||||
k,
|
||||
|
|
|
@ -5,14 +5,13 @@
|
|||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
// -*- c++ -*-
|
||||
|
||||
/*
|
||||
* Implementation of Hamming related functions (distances, smallest distance
|
||||
* selection with regular heap|radix and probabilistic heap|radix.
|
||||
*
|
||||
* IMPLEMENTATION NOTES
|
||||
* Bitvectors are generally assumed to be multiples of 64 bits.
|
||||
* Optimal speed is typically obtained for vector sizes of multiples of 64
|
||||
* bits.
|
||||
*
|
||||
* hamdis_t is used for distances because at this time
|
||||
* it is not clear how we will need to balance
|
||||
|
@ -20,8 +19,6 @@
|
|||
* - memory usage
|
||||
* - cache-misses when dealing with large volumes of data (lower bits is better)
|
||||
*
|
||||
* The hamdis_t should optimally be compatibe with one of the Torch Storage
|
||||
* (Byte,Short,Long) and therefore should be signed for 2-bytes and 4-bytes
|
||||
*/
|
||||
|
||||
#include <faiss/utils/hamming.h>
|
||||
|
@ -165,9 +162,11 @@ size_t match_hamming_thres(
|
|||
return posm;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
/* Return closest neighbors w.r.t Hamming distance, using a heap. */
|
||||
template <class HammingComputer>
|
||||
static void hammings_knn_hc(
|
||||
void hammings_knn_hc(
|
||||
int bytes_per_code,
|
||||
int_maxheap_array_t* __restrict ha,
|
||||
const uint8_t* __restrict bs1,
|
||||
|
@ -234,7 +233,7 @@ static void hammings_knn_hc(
|
|||
|
||||
/* Return closest neighbors w.r.t Hamming distance, using max count. */
|
||||
template <class HammingComputer>
|
||||
static void hammings_knn_mc(
|
||||
void hammings_knn_mc(
|
||||
int bytes_per_code,
|
||||
const uint8_t* __restrict a,
|
||||
const uint8_t* __restrict b,
|
||||
|
@ -287,6 +286,63 @@ static void hammings_knn_mc(
|
|||
}
|
||||
}
|
||||
|
||||
template <class HammingComputer>
|
||||
void hamming_range_search(
|
||||
const uint8_t* a,
|
||||
const uint8_t* b,
|
||||
size_t na,
|
||||
size_t nb,
|
||||
int radius,
|
||||
size_t code_size,
|
||||
RangeSearchResult* res) {
|
||||
#pragma omp parallel
|
||||
{
|
||||
RangeSearchPartialResult pres(res);
|
||||
|
||||
#pragma omp for
|
||||
for (int64_t i = 0; i < na; i++) {
|
||||
HammingComputer hc(a + i * code_size, code_size);
|
||||
const uint8_t* yi = b;
|
||||
RangeQueryResult& qres = pres.new_result(i);
|
||||
|
||||
for (size_t j = 0; j < nb; j++) {
|
||||
int dis = hc.hamming(yi);
|
||||
if (dis < radius) {
|
||||
qres.add(dis, j);
|
||||
}
|
||||
yi += code_size;
|
||||
}
|
||||
}
|
||||
pres.finalize();
|
||||
}
|
||||
}
|
||||
|
||||
struct Run_hammings_knn_hc {
|
||||
using T = void;
|
||||
template <class HammingComputer, class... Types>
|
||||
void f(Types... args) {
|
||||
hammings_knn_hc<HammingComputer>(args...);
|
||||
}
|
||||
};
|
||||
|
||||
struct Run_hammings_knn_mc {
|
||||
using T = void;
|
||||
template <class HammingComputer, class... Types>
|
||||
void f(Types... args) {
|
||||
hammings_knn_mc<HammingComputer>(args...);
|
||||
}
|
||||
};
|
||||
|
||||
struct Run_hamming_range_search {
|
||||
using T = void;
|
||||
template <class HammingComputer, class... Types>
|
||||
void f(Types... args) {
|
||||
hamming_range_search<HammingComputer>(args...);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/* Functions to maps vectors to bits. Assume proper allocation done beforehand,
|
||||
meaning that b should be be able to receive as many bits as x may produce. */
|
||||
|
||||
|
@ -437,28 +493,9 @@ void hammings_knn_hc(
|
|||
size_t ncodes,
|
||||
int order,
|
||||
ApproxTopK_mode_t approx_topk_mode) {
|
||||
switch (ncodes) {
|
||||
case 4:
|
||||
hammings_knn_hc<faiss::HammingComputer4>(
|
||||
4, ha, a, b, nb, order, true, approx_topk_mode);
|
||||
break;
|
||||
case 8:
|
||||
hammings_knn_hc<faiss::HammingComputer8>(
|
||||
8, ha, a, b, nb, order, true, approx_topk_mode);
|
||||
break;
|
||||
case 16:
|
||||
hammings_knn_hc<faiss::HammingComputer16>(
|
||||
16, ha, a, b, nb, order, true, approx_topk_mode);
|
||||
break;
|
||||
case 32:
|
||||
hammings_knn_hc<faiss::HammingComputer32>(
|
||||
32, ha, a, b, nb, order, true, approx_topk_mode);
|
||||
break;
|
||||
default:
|
||||
hammings_knn_hc<faiss::HammingComputerDefault>(
|
||||
ncodes, ha, a, b, nb, order, true, approx_topk_mode);
|
||||
break;
|
||||
}
|
||||
Run_hammings_knn_hc r;
|
||||
dispatch_HammingComputer(
|
||||
ncodes, r, ncodes, ha, a, b, nb, order, true, approx_topk_mode);
|
||||
}
|
||||
|
||||
void hammings_knn_mc(
|
||||
|
@ -470,58 +507,9 @@ void hammings_knn_mc(
|
|||
size_t ncodes,
|
||||
int32_t* __restrict distances,
|
||||
int64_t* __restrict labels) {
|
||||
switch (ncodes) {
|
||||
case 4:
|
||||
hammings_knn_mc<faiss::HammingComputer4>(
|
||||
4, a, b, na, nb, k, distances, labels);
|
||||
break;
|
||||
case 8:
|
||||
hammings_knn_mc<faiss::HammingComputer8>(
|
||||
8, a, b, na, nb, k, distances, labels);
|
||||
break;
|
||||
case 16:
|
||||
hammings_knn_mc<faiss::HammingComputer16>(
|
||||
16, a, b, na, nb, k, distances, labels);
|
||||
break;
|
||||
case 32:
|
||||
hammings_knn_mc<faiss::HammingComputer32>(
|
||||
32, a, b, na, nb, k, distances, labels);
|
||||
break;
|
||||
default:
|
||||
hammings_knn_mc<faiss::HammingComputerDefault>(
|
||||
ncodes, a, b, na, nb, k, distances, labels);
|
||||
break;
|
||||
}
|
||||
}
|
||||
template <class HammingComputer>
|
||||
static void hamming_range_search_template(
|
||||
const uint8_t* a,
|
||||
const uint8_t* b,
|
||||
size_t na,
|
||||
size_t nb,
|
||||
int radius,
|
||||
size_t code_size,
|
||||
RangeSearchResult* res) {
|
||||
#pragma omp parallel
|
||||
{
|
||||
RangeSearchPartialResult pres(res);
|
||||
|
||||
#pragma omp for
|
||||
for (int64_t i = 0; i < na; i++) {
|
||||
HammingComputer hc(a + i * code_size, code_size);
|
||||
const uint8_t* yi = b;
|
||||
RangeQueryResult& qres = pres.new_result(i);
|
||||
|
||||
for (size_t j = 0; j < nb; j++) {
|
||||
int dis = hc.hamming(yi);
|
||||
if (dis < radius) {
|
||||
qres.add(dis, j);
|
||||
}
|
||||
yi += code_size;
|
||||
}
|
||||
}
|
||||
pres.finalize();
|
||||
}
|
||||
Run_hammings_knn_mc r;
|
||||
dispatch_HammingComputer(
|
||||
ncodes, r, ncodes, a, b, na, nb, k, distances, labels);
|
||||
}
|
||||
|
||||
void hamming_range_search(
|
||||
|
@ -532,27 +520,9 @@ void hamming_range_search(
|
|||
int radius,
|
||||
size_t code_size,
|
||||
RangeSearchResult* result) {
|
||||
#define HC(name) \
|
||||
hamming_range_search_template<name>(a, b, na, nb, radius, code_size, result)
|
||||
|
||||
switch (code_size) {
|
||||
case 4:
|
||||
HC(HammingComputer4);
|
||||
break;
|
||||
case 8:
|
||||
HC(HammingComputer8);
|
||||
break;
|
||||
case 16:
|
||||
HC(HammingComputer16);
|
||||
break;
|
||||
case 32:
|
||||
HC(HammingComputer32);
|
||||
break;
|
||||
default:
|
||||
HC(HammingComputerDefault);
|
||||
break;
|
||||
}
|
||||
#undef HC
|
||||
Run_hamming_range_search r;
|
||||
dispatch_HammingComputer(
|
||||
code_size, r, a, b, na, nb, radius, code_size, result);
|
||||
}
|
||||
|
||||
/* Count number of matches given a max threshold */
|
||||
|
|
|
@ -345,93 +345,6 @@ struct HammingComputerDefault {
|
|||
}
|
||||
};
|
||||
|
||||
// more inefficient than HammingComputerDefault (obsolete)
|
||||
struct HammingComputerM8 {
|
||||
const uint64_t* a;
|
||||
int n;
|
||||
|
||||
HammingComputerM8() {}
|
||||
|
||||
HammingComputerM8(const uint8_t* a8, int code_size) {
|
||||
set(a8, code_size);
|
||||
}
|
||||
|
||||
void set(const uint8_t* a8, int code_size) {
|
||||
assert(code_size % 8 == 0);
|
||||
a = (uint64_t*)a8;
|
||||
n = code_size / 8;
|
||||
}
|
||||
|
||||
int hamming(const uint8_t* b8) const {
|
||||
const uint64_t* b = (uint64_t*)b8;
|
||||
int accu = 0;
|
||||
for (int i = 0; i < n; i++)
|
||||
accu += popcount64(a[i] ^ b[i]);
|
||||
return accu;
|
||||
}
|
||||
|
||||
inline int get_code_size() const {
|
||||
return n * 8;
|
||||
}
|
||||
};
|
||||
|
||||
// more inefficient than HammingComputerDefault (obsolete)
|
||||
struct HammingComputerM4 {
|
||||
const uint32_t* a;
|
||||
int n;
|
||||
|
||||
HammingComputerM4() {}
|
||||
|
||||
HammingComputerM4(const uint8_t* a4, int code_size) {
|
||||
set(a4, code_size);
|
||||
}
|
||||
|
||||
void set(const uint8_t* a4, int code_size) {
|
||||
assert(code_size % 4 == 0);
|
||||
a = (uint32_t*)a4;
|
||||
n = code_size / 4;
|
||||
}
|
||||
|
||||
int hamming(const uint8_t* b8) const {
|
||||
const uint32_t* b = (uint32_t*)b8;
|
||||
int accu = 0;
|
||||
for (int i = 0; i < n; i++)
|
||||
accu += popcount64(a[i] ^ b[i]);
|
||||
return accu;
|
||||
}
|
||||
|
||||
inline int get_code_size() const {
|
||||
return n * 4;
|
||||
}
|
||||
};
|
||||
|
||||
/***************************************************************************
|
||||
* Equivalence with a template class when code size is known at compile time
|
||||
**************************************************************************/
|
||||
|
||||
// default template
|
||||
template <int CODE_SIZE>
|
||||
struct HammingComputer : HammingComputerDefault {
|
||||
HammingComputer(const uint8_t* a, int code_size)
|
||||
: HammingComputerDefault(a, code_size) {}
|
||||
};
|
||||
|
||||
#define SPECIALIZED_HC(CODE_SIZE) \
|
||||
template <> \
|
||||
struct HammingComputer<CODE_SIZE> : HammingComputer##CODE_SIZE { \
|
||||
HammingComputer(const uint8_t* a) \
|
||||
: HammingComputer##CODE_SIZE(a, CODE_SIZE) {} \
|
||||
}
|
||||
|
||||
SPECIALIZED_HC(4);
|
||||
SPECIALIZED_HC(8);
|
||||
SPECIALIZED_HC(16);
|
||||
SPECIALIZED_HC(20);
|
||||
SPECIALIZED_HC(32);
|
||||
SPECIALIZED_HC(64);
|
||||
|
||||
#undef SPECIALIZED_HC
|
||||
|
||||
/***************************************************************************
|
||||
* generalized Hamming = number of bytes that are different between
|
||||
* two codes.
|
||||
|
|
|
@ -17,6 +17,7 @@ using hamdis_t = int32_t;
|
|||
|
||||
namespace faiss {
|
||||
|
||||
// trust the compiler to provide efficient popcount implementations
|
||||
inline int popcount32(uint32_t x) {
|
||||
return __builtin_popcount(x);
|
||||
}
|
||||
|
|
|
@ -329,93 +329,6 @@ struct HammingComputerDefault {
|
|||
}
|
||||
};
|
||||
|
||||
// more inefficient than HammingComputerDefault (obsolete)
|
||||
struct HammingComputerM8 {
|
||||
const uint64_t* a;
|
||||
int n;
|
||||
|
||||
HammingComputerM8() {}
|
||||
|
||||
HammingComputerM8(const uint8_t* a8, int code_size) {
|
||||
set(a8, code_size);
|
||||
}
|
||||
|
||||
void set(const uint8_t* a8, int code_size) {
|
||||
assert(code_size % 8 == 0);
|
||||
a = (uint64_t*)a8;
|
||||
n = code_size / 8;
|
||||
}
|
||||
|
||||
int hamming(const uint8_t* b8) const {
|
||||
const uint64_t* b = (uint64_t*)b8;
|
||||
int accu = 0;
|
||||
for (int i = 0; i < n; i++)
|
||||
accu += popcount64(a[i] ^ b[i]);
|
||||
return accu;
|
||||
}
|
||||
|
||||
inline int get_code_size() const {
|
||||
return n * 8;
|
||||
}
|
||||
};
|
||||
|
||||
// more inefficient than HammingComputerDefault (obsolete)
|
||||
struct HammingComputerM4 {
|
||||
const uint32_t* a;
|
||||
int n;
|
||||
|
||||
HammingComputerM4() {}
|
||||
|
||||
HammingComputerM4(const uint8_t* a4, int code_size) {
|
||||
set(a4, code_size);
|
||||
}
|
||||
|
||||
void set(const uint8_t* a4, int code_size) {
|
||||
assert(code_size % 4 == 0);
|
||||
a = (uint32_t*)a4;
|
||||
n = code_size / 4;
|
||||
}
|
||||
|
||||
int hamming(const uint8_t* b8) const {
|
||||
const uint32_t* b = (uint32_t*)b8;
|
||||
int accu = 0;
|
||||
for (int i = 0; i < n; i++)
|
||||
accu += popcount64(a[i] ^ b[i]);
|
||||
return accu;
|
||||
}
|
||||
|
||||
inline int get_code_size() const {
|
||||
return n * 4;
|
||||
}
|
||||
};
|
||||
|
||||
/***************************************************************************
|
||||
* Equivalence with a template class when code size is known at compile time
|
||||
**************************************************************************/
|
||||
|
||||
// default template
|
||||
template <int CODE_SIZE>
|
||||
struct HammingComputer : HammingComputerDefault {
|
||||
HammingComputer(const uint8_t* a, int code_size)
|
||||
: HammingComputerDefault(a, code_size) {}
|
||||
};
|
||||
|
||||
#define SPECIALIZED_HC(CODE_SIZE) \
|
||||
template <> \
|
||||
struct HammingComputer<CODE_SIZE> : HammingComputer##CODE_SIZE { \
|
||||
HammingComputer(const uint8_t* a) \
|
||||
: HammingComputer##CODE_SIZE(a, CODE_SIZE) {} \
|
||||
}
|
||||
|
||||
SPECIALIZED_HC(4);
|
||||
SPECIALIZED_HC(8);
|
||||
SPECIALIZED_HC(16);
|
||||
SPECIALIZED_HC(20);
|
||||
SPECIALIZED_HC(32);
|
||||
SPECIALIZED_HC(64);
|
||||
|
||||
#undef SPECIALIZED_HC
|
||||
|
||||
/***************************************************************************
|
||||
* generalized Hamming = number of bytes that are different between
|
||||
* two codes.
|
||||
|
|
|
@ -23,4 +23,61 @@
|
|||
#include <faiss/utils/hamming_distance/generic-inl.h>
|
||||
#endif
|
||||
|
||||
namespace faiss {
|
||||
|
||||
/***************************************************************************
|
||||
* Equivalence with a template class when code size is known at compile time
|
||||
**************************************************************************/
|
||||
|
||||
// default template
|
||||
template <int CODE_SIZE>
|
||||
struct HammingComputer : HammingComputerDefault {
|
||||
HammingComputer(const uint8_t* a, int code_size)
|
||||
: HammingComputerDefault(a, code_size) {}
|
||||
};
|
||||
|
||||
#define SPECIALIZED_HC(CODE_SIZE) \
|
||||
template <> \
|
||||
struct HammingComputer<CODE_SIZE> : HammingComputer##CODE_SIZE { \
|
||||
HammingComputer(const uint8_t* a) \
|
||||
: HammingComputer##CODE_SIZE(a, CODE_SIZE) {} \
|
||||
}
|
||||
|
||||
SPECIALIZED_HC(4);
|
||||
SPECIALIZED_HC(8);
|
||||
SPECIALIZED_HC(16);
|
||||
SPECIALIZED_HC(20);
|
||||
SPECIALIZED_HC(32);
|
||||
SPECIALIZED_HC(64);
|
||||
|
||||
#undef SPECIALIZED_HC
|
||||
|
||||
/***************************************************************************
|
||||
* Dispatching function that takes a code size and a consumer object
|
||||
* the consumer object should contain a retun type t and a operation template
|
||||
* function f() that to be called to perform the operation.
|
||||
**************************************************************************/
|
||||
|
||||
template <class Consumer, class... Types>
|
||||
typename Consumer::T dispatch_HammingComputer(
|
||||
int code_size,
|
||||
Consumer& consumer,
|
||||
Types... args) {
|
||||
switch (code_size) {
|
||||
#define DISPATCH_HC(CODE_SIZE) \
|
||||
case CODE_SIZE: \
|
||||
return consumer.template f<HammingComputer##CODE_SIZE>(args...);
|
||||
DISPATCH_HC(4);
|
||||
DISPATCH_HC(8);
|
||||
DISPATCH_HC(16);
|
||||
DISPATCH_HC(20);
|
||||
DISPATCH_HC(32);
|
||||
DISPATCH_HC(64);
|
||||
default:
|
||||
return consumer.template f<HammingComputerDefault>(args...);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace faiss
|
||||
|
||||
#endif
|
||||
|
|
|
@ -392,109 +392,6 @@ struct HammingComputerDefault {
|
|||
}
|
||||
};
|
||||
|
||||
// more inefficient than HammingComputerDefault (obsolete)
|
||||
struct HammingComputerM8 {
|
||||
const uint64_t* a;
|
||||
int n;
|
||||
|
||||
HammingComputerM8() {}
|
||||
|
||||
HammingComputerM8(const uint8_t* a8, int code_size) {
|
||||
set(a8, code_size);
|
||||
}
|
||||
|
||||
void set(const uint8_t* a8, int code_size) {
|
||||
assert(code_size % 8 == 0);
|
||||
a = (uint64_t*)a8;
|
||||
n = code_size / 8;
|
||||
}
|
||||
|
||||
int hamming(const uint8_t* b8) const {
|
||||
const uint64_t* b = (uint64_t*)b8;
|
||||
int n4 = (n / 4) * 4;
|
||||
int accu = 0;
|
||||
|
||||
int i = 0;
|
||||
for (; i < n4; i += 4) {
|
||||
accu += ::faiss::hamming<256>(a + i, b + i);
|
||||
}
|
||||
for (; i < n; i++) {
|
||||
accu += popcount64(a[i] ^ b[i]);
|
||||
}
|
||||
return accu;
|
||||
}
|
||||
|
||||
inline int get_code_size() const {
|
||||
return n * 8;
|
||||
}
|
||||
};
|
||||
|
||||
// more inefficient than HammingComputerDefault (obsolete)
|
||||
struct HammingComputerM4 {
|
||||
const uint32_t* a;
|
||||
int n;
|
||||
|
||||
HammingComputerM4() {}
|
||||
|
||||
HammingComputerM4(const uint8_t* a4, int code_size) {
|
||||
set(a4, code_size);
|
||||
}
|
||||
|
||||
void set(const uint8_t* a4, int code_size) {
|
||||
assert(code_size % 4 == 0);
|
||||
a = (uint32_t*)a4;
|
||||
n = code_size / 4;
|
||||
}
|
||||
|
||||
int hamming(const uint8_t* b8) const {
|
||||
const uint32_t* b = (uint32_t*)b8;
|
||||
|
||||
int n8 = (n / 8) * 8;
|
||||
int accu = 0;
|
||||
|
||||
int i = 0;
|
||||
for (; i < n8; i += 8) {
|
||||
accu += ::faiss::hamming<256>(
|
||||
(const uint64_t*)(a + i), (const uint64_t*)(b + i));
|
||||
}
|
||||
for (; i < n; i++) {
|
||||
accu += popcount64(a[i] ^ b[i]);
|
||||
}
|
||||
return accu;
|
||||
}
|
||||
|
||||
inline int get_code_size() const {
|
||||
return n * 4;
|
||||
}
|
||||
};
|
||||
|
||||
/***************************************************************************
|
||||
* Equivalence with a template class when code size is known at compile time
|
||||
**************************************************************************/
|
||||
|
||||
// default template
|
||||
template <int CODE_SIZE>
|
||||
struct HammingComputer : HammingComputerDefault {
|
||||
HammingComputer(const uint8_t* a, int code_size)
|
||||
: HammingComputerDefault(a, code_size) {}
|
||||
};
|
||||
|
||||
#define SPECIALIZED_HC(CODE_SIZE) \
|
||||
template <> \
|
||||
struct HammingComputer<CODE_SIZE> : HammingComputer##CODE_SIZE { \
|
||||
HammingComputer(const uint8_t* a) \
|
||||
: HammingComputer##CODE_SIZE(a, CODE_SIZE) {} \
|
||||
}
|
||||
|
||||
SPECIALIZED_HC(4);
|
||||
SPECIALIZED_HC(8);
|
||||
SPECIALIZED_HC(16);
|
||||
SPECIALIZED_HC(20);
|
||||
SPECIALIZED_HC(32);
|
||||
SPECIALIZED_HC(64);
|
||||
|
||||
#undef SPECIALIZED_HC
|
||||
|
||||
/***************************************************************************
|
||||
* generalized Hamming = number of bytes that are different between
|
||||
* two codes.
|
||||
|
|
Loading…
Reference in New Issue