Improve performance of Hamming computer (#1661)

Summary:
Signed-off-by: shengjun.li <shengjun.li@zilliz.com>

Improve performance of Hamming computer

Pull Request resolved: https://github.com/facebookresearch/faiss/pull/1661

Reviewed By: wickedfoo

Differential Revision: D26222892

Pulled By: mdouze

fbshipit-source-id: 5c1228b9e6c0f196ebcdfb0227ecdf7a02610871
pull/1669/head
shengjun.li 2021-02-03 10:30:44 -08:00 committed by Facebook GitHub Bot
parent 8894ba7488
commit cf33102a7e
9 changed files with 150 additions and 86 deletions

View File

@ -0,0 +1,84 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <cstdio>
#include <omp.h>
#include <vector>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/hamming.h>
#include <faiss/utils/random.h>
#include <faiss/utils/utils.h>
using namespace faiss;
template<class T>
void hamming_cpt_test(int code_size, uint8_t* data1, uint8_t* data2, int n, int* rst) {
T computer(data1, code_size);
for (int i = 0; i < n; i++) {
rst[i] = computer.hamming(data2);
data2 += code_size;
}
}
int main() {
size_t n = 4 * 1000 * 1000;
std::vector<size_t> code_size = {128, 256, 512, 1000};
std::vector<uint8_t> x(n * code_size.back());
byte_rand(x.data(), n, 12345);
int nrun = 100;
for(size_t cs: code_size) {
printf("benchmark with code_size=%zd n=%zd nrun=%d\n", cs, n, nrun);
double tot_t1 = 0, tot_t2 = 0, tot_t3 = 0;
#pragma omp parallel reduction(+: tot_t1, tot_t2, tot_t3)
{
std::vector<int> rst_m4(n);
std::vector<int> rst_m8(n);
std::vector<int> rst_default(n);
#pragma omp for
for (int run = 0; run < nrun; run++) {
double t0, t1, t2, t3;
t0 = getmillisecs();
// new implem from Zilliz
hamming_cpt_test<HammingComputerDefault>(cs, x.data(), x.data(), n, rst_default.data());
t1 = getmillisecs();
// M8
hamming_cpt_test<HammingComputerM8>(cs, x.data(), x.data(), n, rst_m8.data());
t2 = getmillisecs();
// M4
hamming_cpt_test<HammingComputerM4>(cs, x.data(), x.data(), n, rst_m4.data());
t3= getmillisecs();
tot_t1 += t1 - t0;
tot_t2 += t2 - t1;
tot_t3 += t3 - t2;
}
for (int i=0;i<n;i++){
FAISS_THROW_IF_NOT_FMT(
(rst_m4[i] == rst_m8[i] && rst_m4[i] == rst_default[i]),
"wrong result i=%d, m4 %d m8 %d default %d",
i, rst_m4[i], rst_m8[i], rst_default[i]);
}
}
printf("Hamming_Dft implem: %.3f ms\n", tot_t1 / nrun);
printf("Hamming_M8 implem: %.3f ms\n", tot_t2 / nrun);
printf("Hamming_M4 implem: %.3f ms\n", tot_t3 / nrun);
}
return 0;
}

View File

@ -310,11 +310,7 @@ DistanceComputer *IndexBinaryHNSW::get_distance_computer() const {
case 64:
return new FlatHammingDis<HammingComputer64>(*flat_storage);
default:
if (code_size % 8 == 0) {
return new FlatHammingDis<HammingComputerM8>(*flat_storage);
} else if (code_size % 4 == 0) {
return new FlatHammingDis<HammingComputerM4>(*flat_storage);
}
break;
}
return new FlatHammingDis<HammingComputerDefault>(*flat_storage);

View File

@ -195,12 +195,7 @@ search_single_query(const IndexBinaryHash & index, const uint8_t *q,
case 16: HC(HammingComputer16); break;
case 20: HC(HammingComputer20); break;
case 32: HC(HammingComputer32); break;
default:
if (index.code_size % 8 == 0) {
HC(HammingComputerM8);
} else {
HC(HammingComputerDefault);
}
default: HC(HammingComputerDefault); break;
}
#undef HC
}
@ -413,12 +408,7 @@ search_1_query_multihash(const IndexBinaryMultiHash & index, const uint8_t *xi,
case 16: HC(HammingComputer16); break;
case 20: HC(HammingComputer20); break;
case 32: HC(HammingComputer32); break;
default:
if (index.code_size % 8 == 0) {
HC(HammingComputerM8);
} else {
HC(HammingComputerDefault);
}
default: HC(HammingComputerDefault); break;
}
#undef HC
}

View File

@ -560,16 +560,8 @@ void search_knn_hamming_count_1 (
HANDLE_CS(64);
#undef HANDLE_CS
default:
if (ivf.code_size % 8 == 0) {
search_knn_hamming_count<HammingComputerM8, store_pairs>
(ivf, nx, x, keys, k, distances, labels, params);
} else if (ivf.code_size % 4 == 0) {
search_knn_hamming_count<HammingComputerM4, store_pairs>
(ivf, nx, x, keys, k, distances, labels, params);
} else {
search_knn_hamming_count<HammingComputerDefault, store_pairs>
(ivf, nx, x, keys, k, distances, labels, params);
}
search_knn_hamming_count<HammingComputerDefault, store_pairs>
(ivf, nx, x, keys, k, distances, labels, params);
break;
}
@ -589,14 +581,7 @@ BinaryInvertedListScanner *IndexBinaryIVF::get_InvertedListScanner
case 20: HC(HammingComputer20);
case 32: HC(HammingComputer32);
case 64: HC(HammingComputer64);
default:
if (code_size % 8 == 0) {
HC(HammingComputerM8);
} else if (code_size % 4 == 0) {
HC(HammingComputerM4);
} else {
HC(HammingComputerDefault);
}
default: HC(HammingComputerDefault);
}
#undef HC

View File

@ -1031,14 +1031,9 @@ struct IVFPQScannerT: QueryTables {
HANDLE_CODE_SIZE(64);
#undef HANDLE_CODE_SIZE
default:
if (pq.code_size % 8 == 0)
scan_list_polysemous_hc
<HammingComputerM8, SearchResultType>
(ncode, codes, res);
else
scan_list_polysemous_hc
<HammingComputerM4, SearchResultType>
(ncode, codes, res);
scan_list_polysemous_hc
<HammingComputerDefault, SearchResultType>
(ncode, codes, res);
break;
}
}

View File

@ -314,10 +314,8 @@ InvertedListScanner* IndexIVFSpectralHash::get_InvertedListScanner
HANDLE_CODE_SIZE(64);
#undef HANDLE_CODE_SIZE
default:
if (code_size % 8 == 0) {
return new IVFScanner<HammingComputerM8>(this, store_pairs);
} else if (code_size % 4 == 0) {
return new IVFScanner<HammingComputerM4>(this, store_pairs);
if (code_size % 4 == 0) {
return new IVFScanner<HammingComputerDefault>(this, store_pairs);
} else {
FAISS_THROW_MSG("not supported");
}

View File

@ -416,11 +416,8 @@ void IndexPQ::search_core_polysemous (idx_t n, const float *x, idx_t k,
(*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
break;
default:
if (pq.code_size % 8 == 0) {
n_pass += polysemous_inner_loop<HammingComputerM8>
(*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
} else if (pq.code_size % 4 == 0) {
n_pass += polysemous_inner_loop<HammingComputerM4>
if (pq.code_size % 4 == 0) {
n_pass += polysemous_inner_loop<HammingComputerDefault>
(*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
} else {
FAISS_THROW_FMT(

View File

@ -7,6 +7,7 @@
namespace faiss {
extern const uint8_t hamdis_tab_ham_bytes[256];
inline BitstringWriter::BitstringWriter(uint8_t *code, size_t code_size):
code (code), code_size (code_size), i(0)
@ -214,10 +215,10 @@ struct HammingComputer64 {
};
// very inefficient...
struct HammingComputerDefault {
const uint8_t *a;
int n;
const uint8_t *a8;
int quotient8;
int remainder8;
HammingComputerDefault () {}
@ -226,19 +227,52 @@ struct HammingComputerDefault {
}
void set (const uint8_t *a8, int code_size) {
a = a8;
n = code_size;
this->a8 = a8;
quotient8 = code_size / 8;
remainder8 = code_size % 8;
}
int hamming (const uint8_t *b8) const {
int accu = 0;
for (int i = 0; i < n; i++)
accu += popcount64 (a[i] ^ b8[i]);
const uint64_t *a64 = reinterpret_cast<const uint64_t *>(a8);
const uint64_t *b64 = reinterpret_cast<const uint64_t *>(b8);
int i = 0, len = quotient8;
switch (len & 7) {
default:
while (len > 7) {
len -= 8;
accu += popcount64(a64[i] ^ b64[i]); i++;
case 7: accu += popcount64(a64[i] ^ b64[i]); i++;
case 6: accu += popcount64(a64[i] ^ b64[i]); i++;
case 5: accu += popcount64(a64[i] ^ b64[i]); i++;
case 4: accu += popcount64(a64[i] ^ b64[i]); i++;
case 3: accu += popcount64(a64[i] ^ b64[i]); i++;
case 2: accu += popcount64(a64[i] ^ b64[i]); i++;
case 1: accu += popcount64(a64[i] ^ b64[i]); i++;
}
}
if (remainder8) {
const uint8_t *a = a8 + 8 * quotient8;
const uint8_t *b = b8 + 8 * quotient8;
switch (remainder8) {
case 7: accu += hamdis_tab_ham_bytes[a[6] ^ b[6]];
case 6: accu += hamdis_tab_ham_bytes[a[5] ^ b[5]];
case 5: accu += hamdis_tab_ham_bytes[a[4] ^ b[4]];
case 4: accu += hamdis_tab_ham_bytes[a[3] ^ b[3]];
case 3: accu += hamdis_tab_ham_bytes[a[2] ^ b[2]];
case 2: accu += hamdis_tab_ham_bytes[a[1] ^ b[1]];
case 1: accu += hamdis_tab_ham_bytes[a[0] ^ b[0]];
default: break;
}
}
return accu;
}
};
// more inefficient than HammingComputerDefault (obsolete)
struct HammingComputerM8 {
const uint64_t *a;
int n;
@ -265,7 +299,7 @@ struct HammingComputerM8 {
};
// even more inefficient!
// more inefficient than HammingComputerDefault (obsolete)
struct HammingComputerM4 {
const uint32_t *a;
int n;
@ -298,9 +332,9 @@ struct HammingComputerM4 {
// default template
template<int CODE_SIZE>
struct HammingComputer: HammingComputerM8 {
struct HammingComputer: HammingComputerDefault {
HammingComputer (const uint8_t *a, int code_size):
HammingComputerM8(a, code_size) {}
HammingComputerDefault(a, code_size) {}
};
#define SPECIALIZED_HC(CODE_SIZE) \

View File

@ -44,7 +44,7 @@ namespace faiss {
size_t hamming_batch_size = 65536;
static const uint8_t hamdis_tab_ham_bytes[256] = {
const uint8_t hamdis_tab_ham_bytes[256] = {
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
@ -579,14 +579,9 @@ void hammings_knn_hc (
(32, ha, a, b, nb, order, true);
break;
default:
if(ncodes % 8 == 0) {
hammings_knn_hc<faiss::HammingComputerM8>
(ncodes, ha, a, b, nb, order, true);
} else {
hammings_knn_hc<faiss::HammingComputerDefault>
(ncodes, ha, a, b, nb, order, true);
}
hammings_knn_hc<faiss::HammingComputerDefault>
(ncodes, ha, a, b, nb, order, true);
break;
}
}
@ -624,15 +619,10 @@ void hammings_knn_mc(
);
break;
default:
if(ncodes % 8 == 0) {
hammings_knn_mc<faiss::HammingComputerM8>(
ncodes, a, b, na, nb, k, distances, labels
);
} else {
hammings_knn_mc<faiss::HammingComputerDefault>(
ncodes, a, b, na, nb, k, distances, labels
);
}
hammings_knn_mc<faiss::HammingComputerDefault>(
ncodes, a, b, na, nb, k, distances, labels
);
break;
}
}
template <class HammingComputer>
@ -686,12 +676,7 @@ void hamming_range_search (
case 8: HC(HammingComputer8); break;
case 16: HC(HammingComputer16); break;
case 32: HC(HammingComputer32); break;
default:
if (code_size % 8 == 0) {
HC(HammingComputerM8);
} else {
HC(HammingComputerDefault);
}
default: HC(HammingComputerDefault); break;
}
#undef HC
}