Move encode_fp16 and decode_fp16 into a separate entity (#2405)

Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2405

Move encode_fp16 and decode_fp16 out of impl/ScalarQuantizer.cpp into utils/fp_16.h. This is needed because fp16 functions might be needed elsewhere, not only in SQ code.

Reviewed By: mdouze

Differential Revision: D38428096

fbshipit-source-id: 73c9f32919b7b450827cc2394d4d083e0fff1aea
This commit is contained in:
Alexandr Guzhva 2022-08-08 08:32:33 -07:00 committed by Facebook GitHub Bot
parent 838f85cb52
commit dbc3d1d54b
5 changed files with 134 additions and 108 deletions

View File

@ -167,6 +167,9 @@ set(FAISS_HEADERS
utils/distances.h
utils/extra_distances-inl.h
utils/extra_distances.h
utils/fp16-fp16c.h
utils/fp16-inl.h
utils/fp16.h
utils/hamming-inl.h
utils/hamming.h
utils/ordered_key_value.h

View File

@ -22,6 +22,7 @@
#include <faiss/IndexIVF.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/fp16.h>
#include <faiss/utils/utils.h>
namespace faiss {
@ -202,114 +203,6 @@ struct Codec6bit {
#endif
};
#ifdef USE_F16C
uint16_t encode_fp16(float x) {
__m128 xf = _mm_set1_ps(x);
__m128i xi =
_mm_cvtps_ph(xf, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
return _mm_cvtsi128_si32(xi) & 0xffff;
}
float decode_fp16(uint16_t x) {
__m128i xi = _mm_set1_epi16(x);
__m128 xf = _mm_cvtph_ps(xi);
return _mm_cvtss_f32(xf);
}
#else
// non-intrinsic FP16 <-> FP32 code adapted from
// https://github.com/ispc/ispc/blob/master/stdlib.ispc
float floatbits(uint32_t x) {
void* xptr = &x;
return *(float*)xptr;
}
uint32_t intbits(float f) {
void* fptr = &f;
return *(uint32_t*)fptr;
}
uint16_t encode_fp16(float f) {
// via Fabian "ryg" Giesen.
// https://gist.github.com/2156668
uint32_t sign_mask = 0x80000000u;
int32_t o;
uint32_t fint = intbits(f);
uint32_t sign = fint & sign_mask;
fint ^= sign;
// NOTE all the integer compares in this function can be safely
// compiled into signed compares since all operands are below
// 0x80000000. Important if you want fast straight SSE2 code (since
// there's no unsigned PCMPGTD).
// Inf or NaN (all exponent bits set)
// NaN->qNaN and Inf->Inf
// unconditional assignment here, will override with right value for
// the regular case below.
uint32_t f32infty = 255u << 23;
o = (fint > f32infty) ? 0x7e00u : 0x7c00u;
// (De)normalized number or zero
// update fint unconditionally to save the blending; we don't need it
// anymore for the Inf/NaN case anyway.
const uint32_t round_mask = ~0xfffu;
const uint32_t magic = 15u << 23;
// Shift exponent down, denormalize if necessary.
// NOTE This represents half-float denormals using single
// precision denormals. The main reason to do this is that
// there's no shift with per-lane variable shifts in SSE*, which
// we'd otherwise need. It has some funky side effects though:
// - This conversion will actually respect the FTZ (Flush To Zero)
// flag in MXCSR - if it's set, no half-float denormals will be
// generated. I'm honestly not sure whether this is good or
// bad. It's definitely interesting.
// - If the underlying HW doesn't support denormals (not an issue
// with Intel CPUs, but might be a problem on GPUs or PS3 SPUs),
// you will always get flush-to-zero behavior. This is bad,
// unless you're on a CPU where you don't care.
// - Denormals tend to be slow. FP32 denormals are rare in
// practice outside of things like recursive filters in DSP -
// not a typical half-float application. Whether FP16 denormals
// are rare in practice, I don't know. Whatever slow path your
// HW may or may not have for denormals, this may well hit it.
float fscale = floatbits(fint & round_mask) * floatbits(magic);
fscale = std::min(fscale, floatbits((31u << 23) - 0x1000u));
int32_t fint2 = intbits(fscale) - round_mask;
if (fint < f32infty)
o = fint2 >> 13; // Take the bits!
return (o | (sign >> 16));
}
float decode_fp16(uint16_t h) {
// https://gist.github.com/2144712
// Fabian "ryg" Giesen.
const uint32_t shifted_exp = 0x7c00u << 13; // exponent mask after shift
int32_t o = ((int32_t)(h & 0x7fffu)) << 13; // exponent/mantissa bits
int32_t exp = shifted_exp & o; // just the exponent
o += (int32_t)(127 - 15) << 23; // exponent adjust
int32_t infnan_val = o + ((int32_t)(128 - 16) << 23);
int32_t zerodenorm_val =
intbits(floatbits(o + (1u << 23)) - floatbits(113u << 23));
int32_t reg_val = (exp == 0) ? zerodenorm_val : o;
int32_t sign_bit = ((int32_t)(h & 0x8000u)) << 16;
return floatbits(((exp == shifted_exp) ? infnan_val : reg_val) | sign_bit);
}
#endif
/*******************************************************************
* Quantizer: normalizes scalar vector components, then passes them
* through a codec

21
faiss/utils/fp16-fp16c.h Normal file
View File

@ -0,0 +1,21 @@
#pragma once
#include <immintrin.h>
#include <cstdint>
namespace faiss {
inline uint16_t encode_fp16(float x) {
__m128 xf = _mm_set1_ps(x);
__m128i xi =
_mm_cvtps_ph(xf, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
return _mm_cvtsi128_si32(xi) & 0xffff;
}
inline float decode_fp16(uint16_t x) {
__m128i xi = _mm_set1_epi16(x);
__m128 xf = _mm_cvtph_ps(xi);
return _mm_cvtss_f32(xf);
}
} // namespace faiss

100
faiss/utils/fp16-inl.h Normal file
View File

@ -0,0 +1,100 @@
#pragma once
#include <cstdint>
namespace faiss {
// non-intrinsic FP16 <-> FP32 code adapted from
// https://github.com/ispc/ispc/blob/master/stdlib.ispc
namespace {
inline float floatbits(uint32_t x) {
void* xptr = &x;
return *(float*)xptr;
}
inline uint32_t intbits(float f) {
void* fptr = &f;
return *(uint32_t*)fptr;
}
} // namespace
inline uint16_t encode_fp16(float f) {
// via Fabian "ryg" Giesen.
// https://gist.github.com/2156668
uint32_t sign_mask = 0x80000000u;
int32_t o;
uint32_t fint = intbits(f);
uint32_t sign = fint & sign_mask;
fint ^= sign;
// NOTE all the integer compares in this function can be safely
// compiled into signed compares since all operands are below
// 0x80000000. Important if you want fast straight SSE2 code (since
// there's no unsigned PCMPGTD).
// Inf or NaN (all exponent bits set)
// NaN->qNaN and Inf->Inf
// unconditional assignment here, will override with right value for
// the regular case below.
uint32_t f32infty = 255u << 23;
o = (fint > f32infty) ? 0x7e00u : 0x7c00u;
// (De)normalized number or zero
// update fint unconditionally to save the blending; we don't need it
// anymore for the Inf/NaN case anyway.
const uint32_t round_mask = ~0xfffu;
const uint32_t magic = 15u << 23;
// Shift exponent down, denormalize if necessary.
// NOTE This represents half-float denormals using single
// precision denormals. The main reason to do this is that
// there's no shift with per-lane variable shifts in SSE*, which
// we'd otherwise need. It has some funky side effects though:
// - This conversion will actually respect the FTZ (Flush To Zero)
// flag in MXCSR - if it's set, no half-float denormals will be
// generated. I'm honestly not sure whether this is good or
// bad. It's definitely interesting.
// - If the underlying HW doesn't support denormals (not an issue
// with Intel CPUs, but might be a problem on GPUs or PS3 SPUs),
// you will always get flush-to-zero behavior. This is bad,
// unless you're on a CPU where you don't care.
// - Denormals tend to be slow. FP32 denormals are rare in
// practice outside of things like recursive filters in DSP -
// not a typical half-float application. Whether FP16 denormals
// are rare in practice, I don't know. Whatever slow path your
// HW may or may not have for denormals, this may well hit it.
float fscale = floatbits(fint & round_mask) * floatbits(magic);
fscale = std::min(fscale, floatbits((31u << 23) - 0x1000u));
int32_t fint2 = intbits(fscale) - round_mask;
if (fint < f32infty)
o = fint2 >> 13; // Take the bits!
return (o | (sign >> 16));
}
inline float decode_fp16(uint16_t h) {
// https://gist.github.com/2144712
// Fabian "ryg" Giesen.
const uint32_t shifted_exp = 0x7c00u << 13; // exponent mask after shift
int32_t o = ((int32_t)(h & 0x7fffu)) << 13; // exponent/mantissa bits
int32_t exp = shifted_exp & o; // just the exponent
o += (int32_t)(127 - 15) << 23; // exponent adjust
int32_t infnan_val = o + ((int32_t)(128 - 16) << 23);
int32_t zerodenorm_val =
intbits(floatbits(o + (1u << 23)) - floatbits(113u << 23));
int32_t reg_val = (exp == 0) ? zerodenorm_val : o;
int32_t sign_bit = ((int32_t)(h & 0x8000u)) << 16;
return floatbits(((exp == shifted_exp) ? infnan_val : reg_val) | sign_bit);
}
} // namespace faiss

9
faiss/utils/fp16.h Normal file
View File

@ -0,0 +1,9 @@
#pragma once
#include <cstdint>
#if defined(__SSE__) && defined(USE_F16C)
#include <faiss/utils/fp16-fp16c.h>
#else
#include <faiss/utils/fp16-inl.h>
#endif