Improve performance of Hamming computer (#1661)
Summary: Signed-off-by: shengjun.li <shengjun.li@zilliz.com>
Improve performance of Hamming computer
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/1661
Reviewed By: wickedfoo
Differential Revision: D26222892
Pulled By: mdouze
fbshipit-source-id: 5c1228b9e6c0f196ebcdfb0227ecdf7a02610871
pull/1669/head
parent
8894ba7488
commit
cf33102a7e
|
@ -0,0 +1,84 @@
|
|||
/**
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
|
||||
#include <cstdio>
|
||||
#include <omp.h>
|
||||
#include <vector>
|
||||
|
||||
#include <faiss/impl/FaissAssert.h>
|
||||
#include <faiss/utils/hamming.h>
|
||||
#include <faiss/utils/random.h>
|
||||
#include <faiss/utils/utils.h>
|
||||
|
||||
using namespace faiss;
|
||||
|
||||
/**
 * Compute the Hamming distance from one query code to n database codes.
 *
 * A computer of type T is built once from the query code; its hamming()
 * method is then called on each database code in turn and the results are
 * written to rst in the same order.
 *
 * @tparam T         Hamming computer type; must be constructible from
 *                   (const uint8_t*, int) and expose hamming(const uint8_t*)
 * @param code_size  size of one code in bytes; also the stride used to step
 *                   from one database code to the next
 * @param data1      query code, handed to T's constructor with code_size
 * @param data2      n database codes stored back to back
 * @param n          number of database codes to compare against
 * @param rst        output array, receives n distances
 */
template <class T>
void hamming_cpt_test(
        int code_size,
        const uint8_t* data1,
        const uint8_t* data2,
        int n,
        int* rst) {
    T computer(data1, code_size);
    for (int i = 0; i < n; i++) {
        rst[i] = computer.hamming(data2);
        // advance to the next database code
        data2 += code_size;
    }
}
|
||||
|
||||
int main() {
|
||||
size_t n = 4 * 1000 * 1000;
|
||||
|
||||
std::vector<size_t> code_size = {128, 256, 512, 1000};
|
||||
|
||||
std::vector<uint8_t> x(n * code_size.back());
|
||||
byte_rand(x.data(), n, 12345);
|
||||
|
||||
int nrun = 100;
|
||||
for(size_t cs: code_size) {
|
||||
printf("benchmark with code_size=%zd n=%zd nrun=%d\n", cs, n, nrun);
|
||||
|
||||
double tot_t1 = 0, tot_t2 = 0, tot_t3 = 0;
|
||||
#pragma omp parallel reduction(+: tot_t1, tot_t2, tot_t3)
|
||||
{
|
||||
std::vector<int> rst_m4(n);
|
||||
std::vector<int> rst_m8(n);
|
||||
std::vector<int> rst_default(n);
|
||||
|
||||
#pragma omp for
|
||||
for (int run = 0; run < nrun; run++) {
|
||||
double t0, t1, t2, t3;
|
||||
t0 = getmillisecs();
|
||||
|
||||
// new implem from Zilliz
|
||||
hamming_cpt_test<HammingComputerDefault>(cs, x.data(), x.data(), n, rst_default.data());
|
||||
t1 = getmillisecs();
|
||||
|
||||
// M8
|
||||
hamming_cpt_test<HammingComputerM8>(cs, x.data(), x.data(), n, rst_m8.data());
|
||||
t2 = getmillisecs();
|
||||
|
||||
// M4
|
||||
hamming_cpt_test<HammingComputerM4>(cs, x.data(), x.data(), n, rst_m4.data());
|
||||
t3= getmillisecs();
|
||||
|
||||
tot_t1 += t1 - t0;
|
||||
tot_t2 += t2 - t1;
|
||||
tot_t3 += t3 - t2;
|
||||
}
|
||||
|
||||
for (int i=0;i<n;i++){
|
||||
FAISS_THROW_IF_NOT_FMT(
|
||||
(rst_m4[i] == rst_m8[i] && rst_m4[i] == rst_default[i]),
|
||||
"wrong result i=%d, m4 %d m8 %d default %d",
|
||||
i, rst_m4[i], rst_m8[i], rst_default[i]);
|
||||
}
|
||||
}
|
||||
|
||||
printf("Hamming_Dft implem: %.3f ms\n", tot_t1 / nrun);
|
||||
printf("Hamming_M8 implem: %.3f ms\n", tot_t2 / nrun);
|
||||
printf("Hamming_M4 implem: %.3f ms\n", tot_t3 / nrun);
|
||||
|
||||
}
|
||||
return 0;
|
||||
}
|
|
@ -310,11 +310,7 @@ DistanceComputer *IndexBinaryHNSW::get_distance_computer() const {
|
|||
case 64:
|
||||
return new FlatHammingDis<HammingComputer64>(*flat_storage);
|
||||
default:
|
||||
if (code_size % 8 == 0) {
|
||||
return new FlatHammingDis<HammingComputerM8>(*flat_storage);
|
||||
} else if (code_size % 4 == 0) {
|
||||
return new FlatHammingDis<HammingComputerM4>(*flat_storage);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
return new FlatHammingDis<HammingComputerDefault>(*flat_storage);
|
||||
|
|
|
@ -195,12 +195,7 @@ search_single_query(const IndexBinaryHash & index, const uint8_t *q,
|
|||
case 16: HC(HammingComputer16); break;
|
||||
case 20: HC(HammingComputer20); break;
|
||||
case 32: HC(HammingComputer32); break;
|
||||
default:
|
||||
if (index.code_size % 8 == 0) {
|
||||
HC(HammingComputerM8);
|
||||
} else {
|
||||
HC(HammingComputerDefault);
|
||||
}
|
||||
default: HC(HammingComputerDefault); break;
|
||||
}
|
||||
#undef HC
|
||||
}
|
||||
|
@ -413,12 +408,7 @@ search_1_query_multihash(const IndexBinaryMultiHash & index, const uint8_t *xi,
|
|||
case 16: HC(HammingComputer16); break;
|
||||
case 20: HC(HammingComputer20); break;
|
||||
case 32: HC(HammingComputer32); break;
|
||||
default:
|
||||
if (index.code_size % 8 == 0) {
|
||||
HC(HammingComputerM8);
|
||||
} else {
|
||||
HC(HammingComputerDefault);
|
||||
}
|
||||
default: HC(HammingComputerDefault); break;
|
||||
}
|
||||
#undef HC
|
||||
}
|
||||
|
|
|
@ -560,16 +560,8 @@ void search_knn_hamming_count_1 (
|
|||
HANDLE_CS(64);
|
||||
#undef HANDLE_CS
|
||||
default:
|
||||
if (ivf.code_size % 8 == 0) {
|
||||
search_knn_hamming_count<HammingComputerM8, store_pairs>
|
||||
(ivf, nx, x, keys, k, distances, labels, params);
|
||||
} else if (ivf.code_size % 4 == 0) {
|
||||
search_knn_hamming_count<HammingComputerM4, store_pairs>
|
||||
(ivf, nx, x, keys, k, distances, labels, params);
|
||||
} else {
|
||||
search_knn_hamming_count<HammingComputerDefault, store_pairs>
|
||||
(ivf, nx, x, keys, k, distances, labels, params);
|
||||
}
|
||||
search_knn_hamming_count<HammingComputerDefault, store_pairs>
|
||||
(ivf, nx, x, keys, k, distances, labels, params);
|
||||
break;
|
||||
}
|
||||
|
||||
|
@ -589,14 +581,7 @@ BinaryInvertedListScanner *IndexBinaryIVF::get_InvertedListScanner
|
|||
case 20: HC(HammingComputer20);
|
||||
case 32: HC(HammingComputer32);
|
||||
case 64: HC(HammingComputer64);
|
||||
default:
|
||||
if (code_size % 8 == 0) {
|
||||
HC(HammingComputerM8);
|
||||
} else if (code_size % 4 == 0) {
|
||||
HC(HammingComputerM4);
|
||||
} else {
|
||||
HC(HammingComputerDefault);
|
||||
}
|
||||
default: HC(HammingComputerDefault);
|
||||
}
|
||||
#undef HC
|
||||
|
||||
|
|
|
@ -1031,14 +1031,9 @@ struct IVFPQScannerT: QueryTables {
|
|||
HANDLE_CODE_SIZE(64);
|
||||
#undef HANDLE_CODE_SIZE
|
||||
default:
|
||||
if (pq.code_size % 8 == 0)
|
||||
scan_list_polysemous_hc
|
||||
<HammingComputerM8, SearchResultType>
|
||||
(ncode, codes, res);
|
||||
else
|
||||
scan_list_polysemous_hc
|
||||
<HammingComputerM4, SearchResultType>
|
||||
(ncode, codes, res);
|
||||
scan_list_polysemous_hc
|
||||
<HammingComputerDefault, SearchResultType>
|
||||
(ncode, codes, res);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -314,10 +314,8 @@ InvertedListScanner* IndexIVFSpectralHash::get_InvertedListScanner
|
|||
HANDLE_CODE_SIZE(64);
|
||||
#undef HANDLE_CODE_SIZE
|
||||
default:
|
||||
if (code_size % 8 == 0) {
|
||||
return new IVFScanner<HammingComputerM8>(this, store_pairs);
|
||||
} else if (code_size % 4 == 0) {
|
||||
return new IVFScanner<HammingComputerM4>(this, store_pairs);
|
||||
if (code_size % 4 == 0) {
|
||||
return new IVFScanner<HammingComputerDefault>(this, store_pairs);
|
||||
} else {
|
||||
FAISS_THROW_MSG("not supported");
|
||||
}
|
||||
|
|
|
@ -416,11 +416,8 @@ void IndexPQ::search_core_polysemous (idx_t n, const float *x, idx_t k,
|
|||
(*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
|
||||
break;
|
||||
default:
|
||||
if (pq.code_size % 8 == 0) {
|
||||
n_pass += polysemous_inner_loop<HammingComputerM8>
|
||||
(*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
|
||||
} else if (pq.code_size % 4 == 0) {
|
||||
n_pass += polysemous_inner_loop<HammingComputerM4>
|
||||
if (pq.code_size % 4 == 0) {
|
||||
n_pass += polysemous_inner_loop<HammingComputerDefault>
|
||||
(*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
|
||||
} else {
|
||||
FAISS_THROW_FMT(
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
|
||||
namespace faiss {
|
||||
|
||||
extern const uint8_t hamdis_tab_ham_bytes[256];
|
||||
|
||||
inline BitstringWriter::BitstringWriter(uint8_t *code, size_t code_size):
|
||||
code (code), code_size (code_size), i(0)
|
||||
|
@ -214,10 +215,10 @@ struct HammingComputer64 {
|
|||
|
||||
};
|
||||
|
||||
// very inefficient...
|
||||
struct HammingComputerDefault {
|
||||
const uint8_t *a;
|
||||
int n;
|
||||
const uint8_t *a8;
|
||||
int quotient8;
|
||||
int remainder8;
|
||||
|
||||
HammingComputerDefault () {}
|
||||
|
||||
|
@ -226,19 +227,52 @@ struct HammingComputerDefault {
|
|||
}
|
||||
|
||||
void set (const uint8_t *a8, int code_size) {
|
||||
a = a8;
|
||||
n = code_size;
|
||||
this->a8 = a8;
|
||||
quotient8 = code_size / 8;
|
||||
remainder8 = code_size % 8;
|
||||
}
|
||||
|
||||
int hamming (const uint8_t *b8) const {
|
||||
int accu = 0;
|
||||
for (int i = 0; i < n; i++)
|
||||
accu += popcount64 (a[i] ^ b8[i]);
|
||||
|
||||
const uint64_t *a64 = reinterpret_cast<const uint64_t *>(a8);
|
||||
const uint64_t *b64 = reinterpret_cast<const uint64_t *>(b8);
|
||||
int i = 0, len = quotient8;
|
||||
switch (len & 7) {
|
||||
default:
|
||||
while (len > 7) {
|
||||
len -= 8;
|
||||
accu += popcount64(a64[i] ^ b64[i]); i++;
|
||||
case 7: accu += popcount64(a64[i] ^ b64[i]); i++;
|
||||
case 6: accu += popcount64(a64[i] ^ b64[i]); i++;
|
||||
case 5: accu += popcount64(a64[i] ^ b64[i]); i++;
|
||||
case 4: accu += popcount64(a64[i] ^ b64[i]); i++;
|
||||
case 3: accu += popcount64(a64[i] ^ b64[i]); i++;
|
||||
case 2: accu += popcount64(a64[i] ^ b64[i]); i++;
|
||||
case 1: accu += popcount64(a64[i] ^ b64[i]); i++;
|
||||
}
|
||||
}
|
||||
if (remainder8) {
|
||||
const uint8_t *a = a8 + 8 * quotient8;
|
||||
const uint8_t *b = b8 + 8 * quotient8;
|
||||
switch (remainder8) {
|
||||
case 7: accu += hamdis_tab_ham_bytes[a[6] ^ b[6]];
|
||||
case 6: accu += hamdis_tab_ham_bytes[a[5] ^ b[5]];
|
||||
case 5: accu += hamdis_tab_ham_bytes[a[4] ^ b[4]];
|
||||
case 4: accu += hamdis_tab_ham_bytes[a[3] ^ b[3]];
|
||||
case 3: accu += hamdis_tab_ham_bytes[a[2] ^ b[2]];
|
||||
case 2: accu += hamdis_tab_ham_bytes[a[1] ^ b[1]];
|
||||
case 1: accu += hamdis_tab_ham_bytes[a[0] ^ b[0]];
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
|
||||
return accu;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
// more inefficient than HammingComputerDefault (obsolete)
|
||||
struct HammingComputerM8 {
|
||||
const uint64_t *a;
|
||||
int n;
|
||||
|
@ -265,7 +299,7 @@ struct HammingComputerM8 {
|
|||
|
||||
};
|
||||
|
||||
// even more inefficient!
|
||||
// more inefficient than HammingComputerDefault (obsolete)
|
||||
struct HammingComputerM4 {
|
||||
const uint32_t *a;
|
||||
int n;
|
||||
|
@ -298,9 +332,9 @@ struct HammingComputerM4 {
|
|||
|
||||
// default template
|
||||
template<int CODE_SIZE>
|
||||
struct HammingComputer: HammingComputerM8 {
|
||||
struct HammingComputer: HammingComputerDefault {
|
||||
HammingComputer (const uint8_t *a, int code_size):
|
||||
HammingComputerM8(a, code_size) {}
|
||||
HammingComputerDefault(a, code_size) {}
|
||||
};
|
||||
|
||||
#define SPECIALIZED_HC(CODE_SIZE) \
|
||||
|
|
|
@ -44,7 +44,7 @@ namespace faiss {
|
|||
|
||||
size_t hamming_batch_size = 65536;
|
||||
|
||||
static const uint8_t hamdis_tab_ham_bytes[256] = {
|
||||
const uint8_t hamdis_tab_ham_bytes[256] = {
|
||||
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
|
||||
1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
|
||||
1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
|
||||
|
@ -579,14 +579,9 @@ void hammings_knn_hc (
|
|||
(32, ha, a, b, nb, order, true);
|
||||
break;
|
||||
default:
|
||||
if(ncodes % 8 == 0) {
|
||||
hammings_knn_hc<faiss::HammingComputerM8>
|
||||
(ncodes, ha, a, b, nb, order, true);
|
||||
} else {
|
||||
hammings_knn_hc<faiss::HammingComputerDefault>
|
||||
(ncodes, ha, a, b, nb, order, true);
|
||||
|
||||
}
|
||||
hammings_knn_hc<faiss::HammingComputerDefault>
|
||||
(ncodes, ha, a, b, nb, order, true);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -624,15 +619,10 @@ void hammings_knn_mc(
|
|||
);
|
||||
break;
|
||||
default:
|
||||
if(ncodes % 8 == 0) {
|
||||
hammings_knn_mc<faiss::HammingComputerM8>(
|
||||
ncodes, a, b, na, nb, k, distances, labels
|
||||
);
|
||||
} else {
|
||||
hammings_knn_mc<faiss::HammingComputerDefault>(
|
||||
ncodes, a, b, na, nb, k, distances, labels
|
||||
);
|
||||
}
|
||||
hammings_knn_mc<faiss::HammingComputerDefault>(
|
||||
ncodes, a, b, na, nb, k, distances, labels
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
template <class HammingComputer>
|
||||
|
@ -686,12 +676,7 @@ void hamming_range_search (
|
|||
case 8: HC(HammingComputer8); break;
|
||||
case 16: HC(HammingComputer16); break;
|
||||
case 32: HC(HammingComputer32); break;
|
||||
default:
|
||||
if (code_size % 8 == 0) {
|
||||
HC(HammingComputerM8);
|
||||
} else {
|
||||
HC(HammingComputerDefault);
|
||||
}
|
||||
default: HC(HammingComputerDefault); break;
|
||||
}
|
||||
#undef HC
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue