mirror of
https://github.com/facebookresearch/faiss.git
synced 2025-06-03 21:54:02 +08:00
various bugfixes from github issues kmean with some frozen centroids GPU better tiling for large flat datasets default AVX for vector ops
591 lines
69 KiB
HTML
591 lines
69 KiB
HTML
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
|
|
<html xmlns="http://www.w3.org/1999/xhtml">
|
|
<head>
|
|
<meta http-equiv="Content-Type" content="text/xhtml;charset=UTF-8"/>
|
|
<meta http-equiv="X-UA-Compatible" content="IE=9"/>
|
|
<meta name="generator" content="Doxygen 1.8.5"/>
|
|
<title>Faiss: /data/users/matthijs/github_faiss/faiss/gpu/impl/Distance.cu Source File</title>
|
|
<link href="tabs.css" rel="stylesheet" type="text/css"/>
|
|
<script type="text/javascript" src="jquery.js"></script>
|
|
<script type="text/javascript" src="dynsections.js"></script>
|
|
<link href="search/search.css" rel="stylesheet" type="text/css"/>
|
|
<script type="text/javascript" src="search/search.js"></script>
|
|
<script type="text/javascript">
|
|
$(document).ready(function() { searchBox.OnSelectItem(0); });
|
|
</script>
|
|
<link href="doxygen.css" rel="stylesheet" type="text/css" />
|
|
</head>
|
|
<body>
|
|
<div id="top"><!-- do not remove this div, it is closed by doxygen! -->
|
|
<div id="titlearea">
|
|
<table cellspacing="0" cellpadding="0">
|
|
<tbody>
|
|
<tr style="height: 56px;">
|
|
<td style="padding-left: 0.5em;">
|
|
<div id="projectname">Faiss
|
|
</div>
|
|
</td>
|
|
</tr>
|
|
</tbody>
|
|
</table>
|
|
</div>
|
|
<!-- end header part -->
|
|
<!-- Generated by Doxygen 1.8.5 -->
|
|
<script type="text/javascript">
|
|
var searchBox = new SearchBox("searchBox", "search",false,'Search');
|
|
</script>
|
|
<div id="navrow1" class="tabs">
|
|
<ul class="tablist">
|
|
<li><a href="index.html"><span>Main Page</span></a></li>
|
|
<li><a href="namespaces.html"><span>Namespaces</span></a></li>
|
|
<li><a href="annotated.html"><span>Classes</span></a></li>
|
|
<li class="current"><a href="files.html"><span>Files</span></a></li>
|
|
<li>
|
|
<div id="MSearchBox" class="MSearchBoxInactive">
|
|
<span class="left">
|
|
<img id="MSearchSelect" src="search/mag_sel.png"
|
|
onmouseover="return searchBox.OnSearchSelectShow()"
|
|
onmouseout="return searchBox.OnSearchSelectHide()"
|
|
alt=""/>
|
|
<input type="text" id="MSearchField" value="Search" accesskey="S"
|
|
onfocus="searchBox.OnSearchFieldFocus(true)"
|
|
onblur="searchBox.OnSearchFieldFocus(false)"
|
|
onkeyup="searchBox.OnSearchFieldChange(event)"/>
|
|
</span><span class="right">
|
|
<a id="MSearchClose" href="javascript:searchBox.CloseResultsWindow()"><img id="MSearchCloseImg" border="0" src="search/close.png" alt=""/></a>
|
|
</span>
|
|
</div>
|
|
</li>
|
|
</ul>
|
|
</div>
|
|
<div id="navrow2" class="tabs2">
|
|
<ul class="tablist">
|
|
<li><a href="files.html"><span>File List</span></a></li>
|
|
</ul>
|
|
</div>
|
|
<!-- window showing the filter options -->
|
|
<div id="MSearchSelectWindow"
|
|
onmouseover="return searchBox.OnSearchSelectShow()"
|
|
onmouseout="return searchBox.OnSearchSelectHide()"
|
|
onkeydown="return searchBox.OnSearchSelectKey(event)">
|
|
<a class="SelectItem" href="javascript:void(0)" onclick="searchBox.OnSelectItem(0)"><span class="SelectionMark"> </span>All</a><a class="SelectItem" href="javascript:void(0)" onclick="searchBox.OnSelectItem(1)"><span class="SelectionMark"> </span>Classes</a><a class="SelectItem" href="javascript:void(0)" onclick="searchBox.OnSelectItem(2)"><span class="SelectionMark"> </span>Namespaces</a><a class="SelectItem" href="javascript:void(0)" onclick="searchBox.OnSelectItem(3)"><span class="SelectionMark"> </span>Functions</a><a class="SelectItem" href="javascript:void(0)" onclick="searchBox.OnSelectItem(4)"><span class="SelectionMark"> </span>Variables</a><a class="SelectItem" href="javascript:void(0)" onclick="searchBox.OnSelectItem(5)"><span class="SelectionMark"> </span>Typedefs</a><a class="SelectItem" href="javascript:void(0)" onclick="searchBox.OnSelectItem(6)"><span class="SelectionMark"> </span>Enumerations</a><a class="SelectItem" href="javascript:void(0)" onclick="searchBox.OnSelectItem(7)"><span class="SelectionMark"> </span>Enumerator</a><a class="SelectItem" href="javascript:void(0)" onclick="searchBox.OnSelectItem(8)"><span class="SelectionMark"> </span>Friends</a></div>
|
|
|
|
<!-- iframe showing the search results (closed by default) -->
|
|
<div id="MSearchResultsWindow">
|
|
<iframe src="javascript:void(0)" frameborder="0"
|
|
name="MSearchResults" id="MSearchResults">
|
|
</iframe>
|
|
</div>
|
|
|
|
<div id="nav-path" class="navpath">
|
|
<ul>
|
|
<li class="navelem"><a class="el" href="dir_6b3ae6988449b0834e9596fad5d75199.html">gpu</a></li><li class="navelem"><a class="el" href="dir_49d1182a3b8dfb62757c53ae905481ad.html">impl</a></li> </ul>
|
|
</div>
|
|
</div><!-- top -->
|
|
<div class="header">
|
|
<div class="headertitle">
|
|
<div class="title">Distance.cu</div> </div>
|
|
</div><!--header-->
|
|
<div class="contents">
|
|
<div class="fragment"><div class="line"><a name="l00001"></a><span class="lineno"> 1</span> <span class="comment">/**</span></div>
|
|
<div class="line"><a name="l00002"></a><span class="lineno"> 2</span> <span class="comment"> * Copyright (c) 2015-present, Facebook, Inc.</span></div>
|
|
<div class="line"><a name="l00003"></a><span class="lineno"> 3</span> <span class="comment"> * All rights reserved.</span></div>
|
|
<div class="line"><a name="l00004"></a><span class="lineno"> 4</span> <span class="comment"> *</span></div>
|
|
<div class="line"><a name="l00005"></a><span class="lineno"> 5</span> <span class="comment"> * This source code is licensed under the BSD+Patents license found in the</span></div>
|
|
<div class="line"><a name="l00006"></a><span class="lineno"> 6</span> <span class="comment"> * LICENSE file in the root directory of this source tree.</span></div>
|
|
<div class="line"><a name="l00007"></a><span class="lineno"> 7</span> <span class="comment"> */</span></div>
|
|
<div class="line"><a name="l00008"></a><span class="lineno"> 8</span> </div>
|
|
<div class="line"><a name="l00009"></a><span class="lineno"> 9</span> <span class="comment">// Copyright 2004-present Facebook. All Rights Reserved.</span></div>
|
|
<div class="line"><a name="l00010"></a><span class="lineno"> 10</span> </div>
|
|
<div class="line"><a name="l00011"></a><span class="lineno"> 11</span> <span class="preprocessor">#include "Distance.cuh"</span></div>
|
|
<div class="line"><a name="l00012"></a><span class="lineno"> 12</span> <span class="preprocessor">#include "BroadcastSum.cuh"</span></div>
|
|
<div class="line"><a name="l00013"></a><span class="lineno"> 13</span> <span class="preprocessor">#include "L2Norm.cuh"</span></div>
|
|
<div class="line"><a name="l00014"></a><span class="lineno"> 14</span> <span class="preprocessor">#include "L2Select.cuh"</span></div>
|
|
<div class="line"><a name="l00015"></a><span class="lineno"> 15</span> <span class="preprocessor">#include "../../FaissAssert.h"</span></div>
|
|
<div class="line"><a name="l00016"></a><span class="lineno"> 16</span> <span class="preprocessor">#include "../GpuResources.h"</span></div>
|
|
<div class="line"><a name="l00017"></a><span class="lineno"> 17</span> <span class="preprocessor">#include "../utils/DeviceUtils.h"</span></div>
|
|
<div class="line"><a name="l00018"></a><span class="lineno"> 18</span> <span class="preprocessor">#include "../utils/Limits.cuh"</span></div>
|
|
<div class="line"><a name="l00019"></a><span class="lineno"> 19</span> <span class="preprocessor">#include "../utils/MatrixMult.cuh"</span></div>
|
|
<div class="line"><a name="l00020"></a><span class="lineno"> 20</span> <span class="preprocessor">#include "../utils/BlockSelectKernel.cuh"</span></div>
|
|
<div class="line"><a name="l00021"></a><span class="lineno"> 21</span> </div>
|
|
<div class="line"><a name="l00022"></a><span class="lineno"> 22</span> <span class="preprocessor">#include &lt;memory&gt;</span></div>
|
|
<div class="line"><a name="l00023"></a><span class="lineno"> 23</span> <span class="preprocessor">#include &lt;thrust/fill.h&gt;</span></div>
|
|
<div class="line"><a name="l00024"></a><span class="lineno"> 24</span> <span class="preprocessor">#include &lt;thrust/for_each.h&gt;</span></div>
|
|
<div class="line"><a name="l00025"></a><span class="lineno"> 25</span> <span class="preprocessor">#include &lt;thrust/device_ptr.h&gt;</span></div>
|
|
<div class="line"><a name="l00026"></a><span class="lineno"> 26</span> <span class="preprocessor">#include &lt;thrust/execution_policy.h&gt;</span></div>
|
|
<div class="line"><a name="l00027"></a><span class="lineno"> 27</span> </div>
|
|
<div class="line"><a name="l00028"></a><span class="lineno"> 28</span> <span class="keyword">namespace </span>faiss { <span class="keyword">namespace </span>gpu {</div>
|
|
<div class="line"><a name="l00029"></a><span class="lineno"> 29</span> </div>
|
|
<div class="line"><a name="l00030"></a><span class="lineno"> 30</span> <span class="keyword">namespace </span>{</div>
|
|
<div class="line"><a name="l00031"></a><span class="lineno"> 31</span> </div>
|
|
<div class="line"><a name="l00032"></a><span class="lineno"> 32</span> <span class="keyword">template</span> &lt;<span class="keyword">typename</span> T&gt;</div>
|
|
<div class="line"><a name="l00033"></a><span class="lineno"> 33</span> Tensor&lt;T, 2, true&gt; sliceCentroids(Tensor&lt;T, 2, true&gt;&amp; centroids,</div>
|
|
<div class="line"><a name="l00034"></a><span class="lineno"> 34</span>  Tensor&lt;T, 2, true&gt;* centroidsTransposed,</div>
|
|
<div class="line"><a name="l00035"></a><span class="lineno"> 35</span>  <span class="keywordtype">int</span> startCentroid,</div>
|
|
<div class="line"><a name="l00036"></a><span class="lineno"> 36</span>  <span class="keywordtype">int</span> num) {</div>
|
|
<div class="line"><a name="l00037"></a><span class="lineno"> 37</span>  <span class="keywordflow">if</span> (startCentroid == 0 && num == centroids.getSize(0)) {</div>
|
|
<div class="line"><a name="l00038"></a><span class="lineno"> 38</span>  <span class="keywordflow">if</span> (centroidsTransposed) {</div>
|
|
<div class="line"><a name="l00039"></a><span class="lineno"> 39</span>  <span class="keywordflow">return</span> *centroidsTransposed;</div>
|
|
<div class="line"><a name="l00040"></a><span class="lineno"> 40</span>  } <span class="keywordflow">else</span> {</div>
|
|
<div class="line"><a name="l00041"></a><span class="lineno"> 41</span>  <span class="keywordflow">return</span> centroids;</div>
|
|
<div class="line"><a name="l00042"></a><span class="lineno"> 42</span>  }</div>
|
|
<div class="line"><a name="l00043"></a><span class="lineno"> 43</span>  }</div>
|
|
<div class="line"><a name="l00044"></a><span class="lineno"> 44</span> </div>
|
|
<div class="line"><a name="l00045"></a><span class="lineno"> 45</span>  <span class="keywordflow">if</span> (centroidsTransposed) {</div>
|
|
<div class="line"><a name="l00046"></a><span class="lineno"> 46</span>  <span class="comment">// (dim, num)</span></div>
|
|
<div class="line"><a name="l00047"></a><span class="lineno"> 47</span>  <span class="keywordflow">return</span> centroidsTransposed->narrow(1, startCentroid, num);</div>
|
|
<div class="line"><a name="l00048"></a><span class="lineno"> 48</span>  } <span class="keywordflow">else</span> {</div>
|
|
<div class="line"><a name="l00049"></a><span class="lineno"> 49</span>  <span class="keywordflow">return</span> centroids.narrow(0, startCentroid, num);</div>
|
|
<div class="line"><a name="l00050"></a><span class="lineno"> 50</span>  }</div>
|
|
<div class="line"><a name="l00051"></a><span class="lineno"> 51</span> }</div>
|
|
<div class="line"><a name="l00052"></a><span class="lineno"> 52</span> </div>
|
|
<div class="line"><a name="l00053"></a><span class="lineno"> 53</span> <span class="comment">// For each chunk of k indices, increment the index by chunk * increment</span></div>
|
|
<div class="line"><a name="l00054"></a><span class="lineno"> 54</span> <span class="keyword">template</span> &lt;<span class="keyword">typename</span> T&gt;</div>
|
|
<div class="line"><a name="l00055"></a><span class="lineno"> 55</span> __global__ <span class="keywordtype">void</span> incrementIndex(Tensor&lt;T, 2, true&gt; indices,</div>
|
|
<div class="line"><a name="l00056"></a><span class="lineno"> 56</span>  <span class="keywordtype">int</span> k,</div>
|
|
<div class="line"><a name="l00057"></a><span class="lineno"> 57</span>  <span class="keywordtype">int</span> increment) {</div>
|
|
<div class="line"><a name="l00058"></a><span class="lineno"> 58</span>  <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = threadIdx.x; i < k; i += blockDim.x) {</div>
|
|
<div class="line"><a name="l00059"></a><span class="lineno"> 59</span>  indices[blockIdx.y][blockIdx.x * k + i] += blockIdx.x * increment;</div>
|
|
<div class="line"><a name="l00060"></a><span class="lineno"> 60</span>  }</div>
|
|
<div class="line"><a name="l00061"></a><span class="lineno"> 61</span> }</div>
|
|
<div class="line"><a name="l00062"></a><span class="lineno"> 62</span> </div>
|
|
<div class="line"><a name="l00063"></a><span class="lineno"> 63</span> <span class="comment">// Used to update result indices in distance computation where the number of</span></div>
|
|
<div class="line"><a name="l00064"></a><span class="lineno"> 64</span> <span class="comment">// centroids is high, and is tiled</span></div>
|
|
<div class="line"><a name="l00065"></a><span class="lineno"> 65</span> <span class="keyword">template</span> &lt;<span class="keyword">typename</span> T&gt;</div>
|
|
<div class="line"><a name="l00066"></a><span class="lineno"> 66</span> <span class="keywordtype">void</span> runIncrementIndex(Tensor&lt;T, 2, true&gt;&amp; indices,</div>
|
|
<div class="line"><a name="l00067"></a><span class="lineno"> 67</span>  <span class="keywordtype">int</span> k,</div>
|
|
<div class="line"><a name="l00068"></a><span class="lineno"> 68</span>  <span class="keywordtype">int</span> increment,</div>
|
|
<div class="line"><a name="l00069"></a><span class="lineno"> 69</span>  cudaStream_t stream) {</div>
|
|
<div class="line"><a name="l00070"></a><span class="lineno"> 70</span>  dim3 grid(indices.getSize(1) / k, indices.getSize(0));</div>
|
|
<div class="line"><a name="l00071"></a><span class="lineno"> 71</span>  <span class="keywordtype">int</span> block = std::min(k, 512);</div>
|
|
<div class="line"><a name="l00072"></a><span class="lineno"> 72</span> </div>
|
|
<div class="line"><a name="l00073"></a><span class="lineno"> 73</span>  <span class="comment">// should be exact</span></div>
|
|
<div class="line"><a name="l00074"></a><span class="lineno"> 74</span>  FAISS_ASSERT(grid.x * k == indices.getSize(1));</div>
|
|
<div class="line"><a name="l00075"></a><span class="lineno"> 75</span> </div>
|
|
<div class="line"><a name="l00076"></a><span class="lineno"> 76</span>  incrementIndex&lt;&lt;&lt;grid, block, 0, stream&gt;&gt;&gt;(indices, k, increment);</div>
|
|
<div class="line"><a name="l00077"></a><span class="lineno"> 77</span> </div>
|
|
<div class="line"><a name="l00078"></a><span class="lineno"> 78</span>  cudaDeviceSynchronize();</div>
|
|
<div class="line"><a name="l00079"></a><span class="lineno"> 79</span> }</div>
|
|
<div class="line"><a name="l00080"></a><span class="lineno"> 80</span> </div>
|
|
<div class="line"><a name="l00081"></a><span class="lineno"> 81</span> <span class="comment">// If the inner size (dim) of the vectors is small, we want a larger query tile</span></div>
|
|
<div class="line"><a name="l00082"></a><span class="lineno"> 82</span> <span class="comment">// size, like 1024</span></div>
|
|
<div class="line"><a name="l00083"></a><span class="lineno"> 83</span> </div>
|
|
<div class="line"><a name="l00084"></a><span class="lineno"> 84</span> <span class="keywordtype">void</span> chooseTileSize(<span class="keywordtype">int</span> numQueries,</div>
|
|
<div class="line"><a name="l00085"></a><span class="lineno"> 85</span>  <span class="keywordtype">int</span> numCentroids,</div>
|
|
<div class="line"><a name="l00086"></a><span class="lineno"> 86</span>  <span class="keywordtype">int</span> dim,</div>
|
|
<div class="line"><a name="l00087"></a><span class="lineno"> 87</span>  <span class="keywordtype">int</span> elementSize,</div>
|
|
<div class="line"><a name="l00088"></a><span class="lineno"> 88</span>  <span class="keywordtype">size_t</span> tempMemAvailable,</div>
|
|
<div class="line"><a name="l00089"></a><span class="lineno"> 89</span>  <span class="keywordtype">int</span>& tileRows,</div>
|
|
<div class="line"><a name="l00090"></a><span class="lineno"> 90</span>  <span class="keywordtype">int</span>& tileCols) {</div>
|
|
<div class="line"><a name="l00091"></a><span class="lineno"> 91</span>  <span class="comment">// The matrix multiplication should be large enough to be efficient, but if it</span></div>
|
|
<div class="line"><a name="l00092"></a><span class="lineno"> 92</span>  <span class="comment">// is too large, we seem to lose efficiency as opposed to double-streaming.</span></div>
|
|
<div class="line"><a name="l00093"></a><span class="lineno"> 93</span>  <span class="comment">// Each tile size here defines 1/2 of the memory use due to double streaming.</span></div>
|
|
<div class="line"><a name="l00094"></a><span class="lineno"> 94</span>  <span class="comment">// We ignore available temporary memory, as that is adjusted independently by</span></div>
|
|
<div class="line"><a name="l00095"></a><span class="lineno"> 95</span>  <span class="comment">// the user and can thus meet these requirements (or not).</span></div>
|
|
<div class="line"><a name="l00096"></a><span class="lineno"> 96</span>  <span class="comment">// For <= 4 GB GPUs, prefer 512 MB of usage.</span></div>
|
|
<div class="line"><a name="l00097"></a><span class="lineno"> 97</span>  <span class="comment">// For <= 8 GB GPUs, prefer 768 MB of usage.</span></div>
|
|
<div class="line"><a name="l00098"></a><span class="lineno"> 98</span>  <span class="comment">// Otherwise, prefer 1 GB of usage.</span></div>
|
|
<div class="line"><a name="l00099"></a><span class="lineno"> 99</span>  <span class="keyword">auto</span> totalMem = getCurrentDeviceProperties().totalGlobalMem;</div>
|
|
<div class="line"><a name="l00100"></a><span class="lineno"> 100</span> </div>
|
|
<div class="line"><a name="l00101"></a><span class="lineno"> 101</span>  <span class="keywordtype">int</span> targetUsage = 0;</div>
|
|
<div class="line"><a name="l00102"></a><span class="lineno"> 102</span> </div>
|
|
<div class="line"><a name="l00103"></a><span class="lineno"> 103</span>  <span class="keywordflow">if</span> (totalMem <= ((<span class="keywordtype">size_t</span>) 4) * 1024 * 1024 * 1024) {</div>
|
|
<div class="line"><a name="l00104"></a><span class="lineno"> 104</span>  targetUsage = 512 * 1024 * 1024;</div>
|
|
<div class="line"><a name="l00105"></a><span class="lineno"> 105</span>  } <span class="keywordflow">else</span> <span class="keywordflow">if</span> (totalMem <= ((<span class="keywordtype">size_t</span>) 8) * 1024 * 1024 * 1024) {</div>
|
|
<div class="line"><a name="l00106"></a><span class="lineno"> 106</span>  targetUsage = 768 * 1024 * 1024;</div>
|
|
<div class="line"><a name="l00107"></a><span class="lineno"> 107</span>  } <span class="keywordflow">else</span> {</div>
|
|
<div class="line"><a name="l00108"></a><span class="lineno"> 108</span>  targetUsage = 1024 * 1024 * 1024;</div>
|
|
<div class="line"><a name="l00109"></a><span class="lineno"> 109</span>  }</div>
|
|
<div class="line"><a name="l00110"></a><span class="lineno"> 110</span> </div>
|
|
<div class="line"><a name="l00111"></a><span class="lineno"> 111</span>  targetUsage /= 2 * elementSize;</div>
|
|
<div class="line"><a name="l00112"></a><span class="lineno"> 112</span> </div>
|
|
<div class="line"><a name="l00113"></a><span class="lineno"> 113</span>  <span class="comment">// 512 seems to be a batch size sweetspot for float32.</span></div>
|
|
<div class="line"><a name="l00114"></a><span class="lineno"> 114</span>  <span class="comment">// If we are on float16, increase to 512.</span></div>
|
|
<div class="line"><a name="l00115"></a><span class="lineno"> 115</span>  <span class="comment">// If the k size (vec dim) of the matrix multiplication is small (<= 32),</span></div>
|
|
<div class="line"><a name="l00116"></a><span class="lineno"> 116</span>  <span class="comment">// increase to 1024.</span></div>
|
|
<div class="line"><a name="l00117"></a><span class="lineno"> 117</span>  <span class="keywordtype">int</span> preferredTileRows = 512;</div>
|
|
<div class="line"><a name="l00118"></a><span class="lineno"> 118</span>  <span class="keywordflow">if</span> (dim <= 32) {</div>
|
|
<div class="line"><a name="l00119"></a><span class="lineno"> 119</span>  preferredTileRows = 1024;</div>
|
|
<div class="line"><a name="l00120"></a><span class="lineno"> 120</span>  }</div>
|
|
<div class="line"><a name="l00121"></a><span class="lineno"> 121</span> </div>
|
|
<div class="line"><a name="l00122"></a><span class="lineno"> 122</span>  tileRows = std::min(preferredTileRows, numQueries);</div>
|
|
<div class="line"><a name="l00123"></a><span class="lineno"> 123</span> </div>
|
|
<div class="line"><a name="l00124"></a><span class="lineno"> 124</span>  <span class="comment">// tileCols is the remainder size</span></div>
|
|
<div class="line"><a name="l00125"></a><span class="lineno"> 125</span>  tileCols = std::min(targetUsage / preferredTileRows, numCentroids);</div>
|
|
<div class="line"><a name="l00126"></a><span class="lineno"> 126</span> }</div>
|
|
<div class="line"><a name="l00127"></a><span class="lineno"> 127</span> </div>
|
|
<div class="line"><a name="l00128"></a><span class="lineno"> 128</span> }</div>
|
|
<div class="line"><a name="l00129"></a><span class="lineno"> 129</span> </div>
|
|
<div class="line"><a name="l00130"></a><span class="lineno"> 130</span> <span class="keyword">template</span> <<span class="keyword">typename</span> T></div>
|
|
<div class="line"><a name="l00131"></a><span class="lineno"> 131</span> <span class="keywordtype">void</span> runDistance(<span class="keywordtype">bool</span> computeL2,</div>
|
|
<div class="line"><a name="l00132"></a><span class="lineno"> 132</span>  GpuResources* resources,</div>
|
|
<div class="line"><a name="l00133"></a><span class="lineno"> 133</span>  Tensor<T, 2, true>& centroids,</div>
|
|
<div class="line"><a name="l00134"></a><span class="lineno"> 134</span>  Tensor<T, 2, true>* centroidsTransposed,</div>
|
|
<div class="line"><a name="l00135"></a><span class="lineno"> 135</span>  Tensor<T, 1, true>* centroidNorms,</div>
|
|
<div class="line"><a name="l00136"></a><span class="lineno"> 136</span>  Tensor<T, 2, true>& queries,</div>
|
|
<div class="line"><a name="l00137"></a><span class="lineno"> 137</span>  <span class="keywordtype">int</span> k,</div>
|
|
<div class="line"><a name="l00138"></a><span class="lineno"> 138</span>  Tensor<T, 2, true>& outDistances,</div>
|
|
<div class="line"><a name="l00139"></a><span class="lineno"> 139</span>  Tensor<int, 2, true>& outIndices,</div>
|
|
<div class="line"><a name="l00140"></a><span class="lineno"> 140</span>  <span class="keywordtype">bool</span> useHgemm,</div>
|
|
<div class="line"><a name="l00141"></a><span class="lineno"> 141</span>  <span class="keywordtype">bool</span> ignoreOutDistances) {</div>
|
|
<div class="line"><a name="l00142"></a><span class="lineno"> 142</span>  FAISS_ASSERT(outDistances.getSize(0) == queries.getSize(0));</div>
|
|
<div class="line"><a name="l00143"></a><span class="lineno"> 143</span>  FAISS_ASSERT(outIndices.getSize(0) == queries.getSize(0));</div>
|
|
<div class="line"><a name="l00144"></a><span class="lineno"> 144</span>  FAISS_ASSERT(outDistances.getSize(1) == k);</div>
|
|
<div class="line"><a name="l00145"></a><span class="lineno"> 145</span>  FAISS_ASSERT(outIndices.getSize(1) == k);</div>
|
|
<div class="line"><a name="l00146"></a><span class="lineno"> 146</span> </div>
|
|
<div class="line"><a name="l00147"></a><span class="lineno"> 147</span>  <span class="keyword">auto</span>& mem = resources->getMemoryManagerCurrentDevice();</div>
|
|
<div class="line"><a name="l00148"></a><span class="lineno"> 148</span>  <span class="keyword">auto</span> defaultStream = resources->getDefaultStreamCurrentDevice();</div>
|
|
<div class="line"><a name="l00149"></a><span class="lineno"> 149</span> </div>
|
|
<div class="line"><a name="l00150"></a><span class="lineno"> 150</span>  <span class="comment">// If we're querying against a 0 sized set, just return empty results</span></div>
|
|
<div class="line"><a name="l00151"></a><span class="lineno"> 151</span>  <span class="keywordflow">if</span> (centroids.numElements() == 0) {</div>
|
|
<div class="line"><a name="l00152"></a><span class="lineno"> 152</span>  thrust::fill(thrust::cuda::par.on(defaultStream),</div>
|
|
<div class="line"><a name="l00153"></a><span class="lineno"> 153</span>  outDistances.data(), outDistances.end(),</div>
|
|
<div class="line"><a name="l00154"></a><span class="lineno"> 154</span>  Limits<T>::getMax());</div>
|
|
<div class="line"><a name="l00155"></a><span class="lineno"> 155</span> </div>
|
|
<div class="line"><a name="l00156"></a><span class="lineno"> 156</span>  thrust::fill(thrust::cuda::par.on(defaultStream),</div>
|
|
<div class="line"><a name="l00157"></a><span class="lineno"> 157</span>  outIndices.data(), outIndices.end(),</div>
|
|
<div class="line"><a name="l00158"></a><span class="lineno"> 158</span>  -1);</div>
|
|
<div class="line"><a name="l00159"></a><span class="lineno"> 159</span> </div>
|
|
<div class="line"><a name="l00160"></a><span class="lineno"> 160</span>  <span class="keywordflow">return</span>;</div>
|
|
<div class="line"><a name="l00161"></a><span class="lineno"> 161</span>  }</div>
|
|
<div class="line"><a name="l00162"></a><span class="lineno"> 162</span> </div>
|
|
<div class="line"><a name="l00163"></a><span class="lineno"> 163</span>  <span class="comment">// L2: If ||c||^2 is not pre-computed, calculate it</span></div>
|
|
<div class="line"><a name="l00164"></a><span class="lineno"> 164</span>  DeviceTensor<T, 1, true> cNorms;</div>
|
|
<div class="line"><a name="l00165"></a><span class="lineno"> 165</span>  <span class="keywordflow">if</span> (computeL2 && !centroidNorms) {</div>
|
|
<div class="line"><a name="l00166"></a><span class="lineno"> 166</span>  cNorms = std::move(DeviceTensor<T, 1, true>(</div>
|
|
<div class="line"><a name="l00167"></a><span class="lineno"> 167</span>  mem,</div>
|
|
<div class="line"><a name="l00168"></a><span class="lineno"> 168</span>  {centroids.getSize(0)}, defaultStream));</div>
|
|
<div class="line"><a name="l00169"></a><span class="lineno"> 169</span>  runL2Norm(centroids, cNorms, <span class="keyword">true</span>, defaultStream);</div>
|
|
<div class="line"><a name="l00170"></a><span class="lineno"> 170</span>  centroidNorms = &cNorms;</div>
|
|
<div class="line"><a name="l00171"></a><span class="lineno"> 171</span>  }</div>
|
|
<div class="line"><a name="l00172"></a><span class="lineno"> 172</span> </div>
|
|
<div class="line"><a name="l00173"></a><span class="lineno"> 173</span>  <span class="comment">//</span></div>
|
|
<div class="line"><a name="l00174"></a><span class="lineno"> 174</span>  <span class="comment">// Prepare norm vector ||q||^2; ||c||^2 is already pre-computed</span></div>
|
|
<div class="line"><a name="l00175"></a><span class="lineno"> 175</span>  <span class="comment">//</span></div>
|
|
<div class="line"><a name="l00176"></a><span class="lineno"> 176</span>  <span class="keywordtype">int</span> qNormSize[1] = {queries.getSize(0)};</div>
|
|
<div class="line"><a name="l00177"></a><span class="lineno"> 177</span>  DeviceTensor<T, 1, true> queryNorms(mem, qNormSize, defaultStream);</div>
|
|
<div class="line"><a name="l00178"></a><span class="lineno"> 178</span> </div>
|
|
<div class="line"><a name="l00179"></a><span class="lineno"> 179</span>  <span class="comment">// ||q||^2</span></div>
|
|
<div class="line"><a name="l00180"></a><span class="lineno"> 180</span>  <span class="keywordflow">if</span> (computeL2) {</div>
|
|
<div class="line"><a name="l00181"></a><span class="lineno"> 181</span>  runL2Norm(queries, queryNorms, <span class="keyword">true</span>, defaultStream);</div>
|
|
<div class="line"><a name="l00182"></a><span class="lineno"> 182</span>  }</div>
|
|
<div class="line"><a name="l00183"></a><span class="lineno"> 183</span> </div>
|
|
<div class="line"><a name="l00184"></a><span class="lineno"> 184</span>  <span class="comment">// Tile the computation: chooseTileSize() picks the query and centroid tile</span></div>
|
|
<div class="line"><a name="l00185"></a><span class="lineno"> 185</span>  <span class="comment">// sizes from total GPU memory (512 MB to 1 GB target) and the vector dim.</span></div>
|
|
<div class="line"><a name="l00186"></a><span class="lineno"> 186</span>  <span class="keywordtype">int</span> tileRows = 0;</div>
|
|
<div class="line"><a name="l00187"></a><span class="lineno"> 187</span>  <span class="keywordtype">int</span> tileCols = 0;</div>
|
|
<div class="line"><a name="l00188"></a><span class="lineno"> 188</span>  chooseTileSize(queries.getSize(0),</div>
|
|
<div class="line"><a name="l00189"></a><span class="lineno"> 189</span>  centroids.getSize(0),</div>
|
|
<div class="line"><a name="l00190"></a><span class="lineno"> 190</span>  queries.getSize(1),</div>
|
|
<div class="line"><a name="l00191"></a><span class="lineno"> 191</span>  <span class="keyword">sizeof</span>(T),</div>
|
|
<div class="line"><a name="l00192"></a><span class="lineno"> 192</span>  mem.getSizeAvailable(),</div>
|
|
<div class="line"><a name="l00193"></a><span class="lineno"> 193</span>  tileRows,</div>
|
|
<div class="line"><a name="l00194"></a><span class="lineno"> 194</span>  tileCols);</div>
|
|
<div class="line"><a name="l00195"></a><span class="lineno"> 195</span> </div>
|
|
<div class="line"><a name="l00196"></a><span class="lineno"> 196</span>  <span class="keywordtype">int</span> numColTiles = utils::divUp(centroids.getSize(0), tileCols);</div>
|
|
<div class="line"><a name="l00197"></a><span class="lineno"> 197</span> </div>
|
|
<div class="line"><a name="l00198"></a><span class="lineno"> 198</span>  FAISS_ASSERT(k <= centroids.getSize(0));</div>
|
|
<div class="line"><a name="l00199"></a><span class="lineno"> 199</span>  FAISS_ASSERT(k <= 1024); <span class="comment">// select limitation</span></div>
|
|
<div class="line"><a name="l00200"></a><span class="lineno"> 200</span> </div>
|
|
<div class="line"><a name="l00201"></a><span class="lineno"> 201</span>  <span class="comment">// Temporary output memory space we'll use</span></div>
|
|
<div class="line"><a name="l00202"></a><span class="lineno"> 202</span>  DeviceTensor<T, 2, true> distanceBuf1(</div>
|
|
<div class="line"><a name="l00203"></a><span class="lineno"> 203</span>  mem, {tileRows, tileCols}, defaultStream);</div>
|
|
<div class="line"><a name="l00204"></a><span class="lineno"> 204</span>  DeviceTensor<T, 2, true> distanceBuf2(</div>
|
|
<div class="line"><a name="l00205"></a><span class="lineno"> 205</span>  mem, {tileRows, tileCols}, defaultStream);</div>
|
|
<div class="line"><a name="l00206"></a><span class="lineno"> 206</span>  DeviceTensor<T, 2, true>* distanceBufs[2] =</div>
|
|
<div class="line"><a name="l00207"></a><span class="lineno"> 207</span>  {&distanceBuf1, &distanceBuf2};</div>
|
|
<div class="line"><a name="l00208"></a><span class="lineno"> 208</span> </div>
|
|
<div class="line"><a name="l00209"></a><span class="lineno"> 209</span>  DeviceTensor<T, 2, true> outDistanceBuf1(</div>
|
|
<div class="line"><a name="l00210"></a><span class="lineno"> 210</span>  mem, {tileRows, numColTiles * k}, defaultStream);</div>
|
|
<div class="line"><a name="l00211"></a><span class="lineno"> 211</span>  DeviceTensor&lt;T, 2, true&gt; outDistanceBuf2(</div>
|
|
<div class="line"><a name="l00212"></a><span class="lineno"> 212</span>  mem, {tileRows, numColTiles * k}, defaultStream);</div>
|
|
<div class="line"><a name="l00213"></a><span class="lineno"> 213</span>  DeviceTensor&lt;T, 2, true&gt;* outDistanceBufs[2] =</div>
|
|
<div class="line"><a name="l00214"></a><span class="lineno"> 214</span>  {&amp;outDistanceBuf1, &amp;outDistanceBuf2};</div>
|
|
<div class="line"><a name="l00215"></a><span class="lineno"> 215</span> </div>
|
|
<div class="line"><a name="l00216"></a><span class="lineno"> 216</span>  DeviceTensor&lt;int, 2, true&gt; outIndexBuf1(</div>
|
|
<div class="line"><a name="l00217"></a><span class="lineno"> 217</span>  mem, {tileRows, numColTiles * k}, defaultStream);</div>
|
|
<div class="line"><a name="l00218"></a><span class="lineno"> 218</span>  DeviceTensor&lt;int, 2, true&gt; outIndexBuf2(</div>
|
|
<div class="line"><a name="l00219"></a><span class="lineno"> 219</span>  mem, {tileRows, numColTiles * k}, defaultStream);</div>
|
|
<div class="line"><a name="l00220"></a><span class="lineno"> 220</span>  DeviceTensor&lt;int, 2, true&gt;* outIndexBufs[2] =</div>
|
|
<div class="line"><a name="l00221"></a><span class="lineno"> 221</span>  {&amp;outIndexBuf1, &amp;outIndexBuf2};</div>
|
|
<div class="line"><a name="l00222"></a><span class="lineno"> 222</span> </div>
|
|
<div class="line"><a name="l00223"></a><span class="lineno"> 223</span>  <span class="keyword">auto</span> streams = resources->getAlternateStreamsCurrentDevice();</div>
|
|
<div class="line"><a name="l00224"></a><span class="lineno"> 224</span>  streamWait(streams, {defaultStream});</div>
|
|
<div class="line"><a name="l00225"></a><span class="lineno"> 225</span> </div>
|
|
<div class="line"><a name="l00226"></a><span class="lineno"> 226</span>  <span class="keywordtype">int</span> curStream = 0;</div>
|
|
<div class="line"><a name="l00227"></a><span class="lineno"> 227</span> </div>
|
|
<div class="line"><a name="l00228"></a><span class="lineno"> 228</span>  <span class="comment">// Tile over the input queries</span></div>
|
|
<div class="line"><a name="l00229"></a><span class="lineno"> 229</span>  <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; queries.getSize(0); i += tileRows) {</div>
|
|
<div class="line"><a name="l00230"></a><span class="lineno"> 230</span>  <span class="keywordtype">int</span> curQuerySize = std::min(tileRows, queries.getSize(0) - i);</div>
|
|
<div class="line"><a name="l00231"></a><span class="lineno"> 231</span> </div>
|
|
<div class="line"><a name="l00232"></a><span class="lineno"> 232</span>  <span class="keyword">auto</span> outDistanceView =</div>
|
|
<div class="line"><a name="l00233"></a><span class="lineno"> 233</span>  outDistances.narrow(0, i, curQuerySize);</div>
|
|
<div class="line"><a name="l00234"></a><span class="lineno"> 234</span>  <span class="keyword">auto</span> outIndexView =</div>
|
|
<div class="line"><a name="l00235"></a><span class="lineno"> 235</span>  outIndices.narrow(0, i, curQuerySize);</div>
|
|
<div class="line"><a name="l00236"></a><span class="lineno"> 236</span> </div>
|
|
<div class="line"><a name="l00237"></a><span class="lineno"> 237</span>  <span class="keyword">auto</span> queryView =</div>
|
|
<div class="line"><a name="l00238"></a><span class="lineno"> 238</span>  queries.narrow(0, i, curQuerySize);</div>
|
|
<div class="line"><a name="l00239"></a><span class="lineno"> 239</span>  <span class="keyword">auto</span> queryNormNiew =</div>
|
|
<div class="line"><a name="l00240"></a><span class="lineno"> 240</span>  queryNorms.narrow(0, i, curQuerySize);</div>
|
|
<div class="line"><a name="l00241"></a><span class="lineno"> 241</span> </div>
|
|
<div class="line"><a name="l00242"></a><span class="lineno"> 242</span>  <span class="keyword">auto</span> outDistanceBufRowView =</div>
|
|
<div class="line"><a name="l00243"></a><span class="lineno"> 243</span>  outDistanceBufs[curStream]->narrow(0, 0, curQuerySize);</div>
|
|
<div class="line"><a name="l00244"></a><span class="lineno"> 244</span>  <span class="keyword">auto</span> outIndexBufRowView =</div>
|
|
<div class="line"><a name="l00245"></a><span class="lineno"> 245</span>  outIndexBufs[curStream]->narrow(0, 0, curQuerySize);</div>
|
|
<div class="line"><a name="l00246"></a><span class="lineno"> 246</span> </div>
|
|
<div class="line"><a name="l00247"></a><span class="lineno"> 247</span>  <span class="comment">// Tile over the centroids</span></div>
|
|
<div class="line"><a name="l00248"></a><span class="lineno"> 248</span>  <span class="keywordflow">for</span> (<span class="keywordtype">int</span> j = 0; j &lt; centroids.getSize(0); j += tileCols) {</div>
|
|
<div class="line"><a name="l00249"></a><span class="lineno"> 249</span>  <span class="keywordtype">int</span> curCentroidSize = std::min(tileCols, centroids.getSize(0) - j);</div>
|
|
<div class="line"><a name="l00250"></a><span class="lineno"> 250</span> </div>
|
|
<div class="line"><a name="l00251"></a><span class="lineno"> 251</span>  <span class="keywordtype">int</span> curColTile = j / tileCols;</div>
|
|
<div class="line"><a name="l00252"></a><span class="lineno"> 252</span> </div>
|
|
<div class="line"><a name="l00253"></a><span class="lineno"> 253</span>  <span class="keyword">auto</span> centroidsView =</div>
|
|
<div class="line"><a name="l00254"></a><span class="lineno"> 254</span>  sliceCentroids(centroids, centroidsTransposed, j, curCentroidSize);</div>
|
|
<div class="line"><a name="l00255"></a><span class="lineno"> 255</span> </div>
|
|
<div class="line"><a name="l00256"></a><span class="lineno"> 256</span>  <span class="keyword">auto</span> distanceBufView = distanceBufs[curStream]-></div>
|
|
<div class="line"><a name="l00257"></a><span class="lineno"> 257</span>  narrow(0, 0, curQuerySize).narrow(1, 0, curCentroidSize);</div>
|
|
<div class="line"><a name="l00258"></a><span class="lineno"> 258</span> </div>
|
|
<div class="line"><a name="l00259"></a><span class="lineno"> 259</span>  <span class="keyword">auto</span> outDistanceBufColView =</div>
|
|
<div class="line"><a name="l00260"></a><span class="lineno"> 260</span>  outDistanceBufRowView.narrow(1, k * curColTile, k);</div>
|
|
<div class="line"><a name="l00261"></a><span class="lineno"> 261</span>  <span class="keyword">auto</span> outIndexBufColView =</div>
|
|
<div class="line"><a name="l00262"></a><span class="lineno"> 262</span>  outIndexBufRowView.narrow(1, k * curColTile, k);</div>
|
|
<div class="line"><a name="l00263"></a><span class="lineno"> 263</span> </div>
|
|
<div class="line"><a name="l00264"></a><span class="lineno"> 264</span>  <span class="comment">// L2: distance is ||c||^2 - 2qc + ||q||^2, we compute -2qc</span></div>
|
|
<div class="line"><a name="l00265"></a><span class="lineno"> 265</span>  <span class="comment">// IP: just compute qc</span></div>
|
|
<div class="line"><a name="l00266"></a><span class="lineno"> 266</span>  <span class="comment">// (query id x dim) x (centroid id, dim)' = (query id, centroid id)</span></div>
|
|
<div class="line"><a name="l00267"></a><span class="lineno"> 267</span>  runMatrixMult(distanceBufView, <span class="keyword">false</span>,</div>
|
|
<div class="line"><a name="l00268"></a><span class="lineno"> 268</span>  queryView, <span class="keyword">false</span>,</div>
|
|
<div class="line"><a name="l00269"></a><span class="lineno"> 269</span>  centroidsView,</div>
|
|
<div class="line"><a name="l00270"></a><span class="lineno"> 270</span>  centroidsTransposed ? <span class="keyword">false</span> : <span class="keyword">true</span>,</div>
|
|
<div class="line"><a name="l00271"></a><span class="lineno"> 271</span>  computeL2 ? -2.0f : 1.0f, 0.0f, useHgemm,</div>
|
|
<div class="line"><a name="l00272"></a><span class="lineno"> 272</span>  resources->getBlasHandleCurrentDevice(),</div>
|
|
<div class="line"><a name="l00273"></a><span class="lineno"> 273</span>  streams[curStream]);</div>
|
|
<div class="line"><a name="l00274"></a><span class="lineno"> 274</span> </div>
|
|
<div class="line"><a name="l00275"></a><span class="lineno"> 275</span>  <span class="keywordflow">if</span> (computeL2) {</div>
|
|
<div class="line"><a name="l00276"></a><span class="lineno"> 276</span>  <span class="comment">// For L2 distance, we use this fused kernel that performs both</span></div>
|
|
<div class="line"><a name="l00277"></a><span class="lineno"> 277</span>  <span class="comment">// adding ||c||^2 to -2qc and k-selection, so we only need two</span></div>
|
|
<div class="line"><a name="l00278"></a><span class="lineno"> 278</span>  <span class="comment">// passes (one write by the gemm, one read here) over the huge</span></div>
|
|
<div class="line"><a name="l00279"></a><span class="lineno"> 279</span>  <span class="comment">// region of output memory</span></div>
|
|
<div class="line"><a name="l00280"></a><span class="lineno"> 280</span>  <span class="comment">//</span></div>
|
|
<div class="line"><a name="l00281"></a><span class="lineno"> 281</span>  <span class="comment">// If we aren't tiling along the number of centroids, we can perform the</span></div>
|
|
<div class="line"><a name="l00282"></a><span class="lineno"> 282</span>  <span class="comment">// output work directly</span></div>
|
|
<div class="line"><a name="l00283"></a><span class="lineno"> 283</span>  <span class="keywordflow">if</span> (tileCols == centroids.getSize(0)) {</div>
|
|
<div class="line"><a name="l00284"></a><span class="lineno"> 284</span>  <span class="comment">// Write into the final output</span></div>
|
|
<div class="line"><a name="l00285"></a><span class="lineno"> 285</span>  runL2SelectMin(distanceBufView,</div>
|
|
<div class="line"><a name="l00286"></a><span class="lineno"> 286</span>  *centroidNorms,</div>
|
|
<div class="line"><a name="l00287"></a><span class="lineno"> 287</span>  outDistanceView,</div>
|
|
<div class="line"><a name="l00288"></a><span class="lineno"> 288</span>  outIndexView,</div>
|
|
<div class="line"><a name="l00289"></a><span class="lineno"> 289</span>  k,</div>
|
|
<div class="line"><a name="l00290"></a><span class="lineno"> 290</span>  streams[curStream]);</div>
|
|
<div class="line"><a name="l00291"></a><span class="lineno"> 291</span> </div>
|
|
<div class="line"><a name="l00292"></a><span class="lineno"> 292</span>  <span class="keywordflow">if</span> (!ignoreOutDistances) {</div>
|
|
<div class="line"><a name="l00293"></a><span class="lineno"> 293</span>  <span class="comment">// expand (query id) to (query id, k) by duplicating along rows</span></div>
|
|
<div class="line"><a name="l00294"></a><span class="lineno"> 294</span>  <span class="comment">// top-k ||c||^2 - 2qc + ||q||^2 in the form (query id, k)</span></div>
|
|
<div class="line"><a name="l00295"></a><span class="lineno"> 295</span>  runSumAlongRows(queryNormNiew, outDistanceView, streams[curStream]);</div>
|
|
<div class="line"><a name="l00296"></a><span class="lineno"> 296</span>  }</div>
|
|
<div class="line"><a name="l00297"></a><span class="lineno"> 297</span>  } <span class="keywordflow">else</span> {</div>
|
|
<div class="line"><a name="l00298"></a><span class="lineno"> 298</span>  <span class="keyword">auto</span> centroidNormsView =</div>
|
|
<div class="line"><a name="l00299"></a><span class="lineno"> 299</span>  centroidNorms->narrow(0, j, curCentroidSize);</div>
|
|
<div class="line"><a name="l00300"></a><span class="lineno"> 300</span> </div>
|
|
<div class="line"><a name="l00301"></a><span class="lineno"> 301</span>  <span class="comment">// Write into our intermediate output</span></div>
|
|
<div class="line"><a name="l00302"></a><span class="lineno"> 302</span>  runL2SelectMin(distanceBufView,</div>
|
|
<div class="line"><a name="l00303"></a><span class="lineno"> 303</span>  centroidNormsView,</div>
|
|
<div class="line"><a name="l00304"></a><span class="lineno"> 304</span>  outDistanceBufColView,</div>
|
|
<div class="line"><a name="l00305"></a><span class="lineno"> 305</span>  outIndexBufColView,</div>
|
|
<div class="line"><a name="l00306"></a><span class="lineno"> 306</span>  k,</div>
|
|
<div class="line"><a name="l00307"></a><span class="lineno"> 307</span>  streams[curStream]);</div>
|
|
<div class="line"><a name="l00308"></a><span class="lineno"> 308</span> </div>
|
|
<div class="line"><a name="l00309"></a><span class="lineno"> 309</span>  <span class="keywordflow">if</span> (!ignoreOutDistances) {</div>
|
|
<div class="line"><a name="l00310"></a><span class="lineno"> 310</span>  <span class="comment">// expand (query id) to (query id, k) by duplicating along rows</span></div>
|
|
<div class="line"><a name="l00311"></a><span class="lineno"> 311</span>  <span class="comment">// top-k ||c||^2 - 2qc + ||q||^2 in the form (query id, k)</span></div>
|
|
<div class="line"><a name="l00312"></a><span class="lineno"> 312</span>  runSumAlongRows(queryNormNiew,</div>
|
|
<div class="line"><a name="l00313"></a><span class="lineno"> 313</span>  outDistanceBufColView,</div>
|
|
<div class="line"><a name="l00314"></a><span class="lineno"> 314</span>  streams[curStream]);</div>
|
|
<div class="line"><a name="l00315"></a><span class="lineno"> 315</span>  }</div>
|
|
<div class="line"><a name="l00316"></a><span class="lineno"> 316</span>  }</div>
|
|
<div class="line"><a name="l00317"></a><span class="lineno"> 317</span>  } <span class="keywordflow">else</span> {</div>
|
|
<div class="line"><a name="l00318"></a><span class="lineno"> 318</span>  <span class="comment">// For IP, just k-select the output for this tile</span></div>
|
|
<div class="line"><a name="l00319"></a><span class="lineno"> 319</span>  <span class="keywordflow">if</span> (tileCols == centroids.getSize(0)) {</div>
|
|
<div class="line"><a name="l00320"></a><span class="lineno"> 320</span>  <span class="comment">// Write into the final output</span></div>
|
|
<div class="line"><a name="l00321"></a><span class="lineno"> 321</span>  runBlockSelect(distanceBufView,</div>
|
|
<div class="line"><a name="l00322"></a><span class="lineno"> 322</span>  outDistanceView,</div>
|
|
<div class="line"><a name="l00323"></a><span class="lineno"> 323</span>  outIndexView,</div>
|
|
<div class="line"><a name="l00324"></a><span class="lineno"> 324</span>  <span class="keyword">true</span>, k, streams[curStream]);</div>
|
|
<div class="line"><a name="l00325"></a><span class="lineno"> 325</span>  } <span class="keywordflow">else</span> {</div>
|
|
<div class="line"><a name="l00326"></a><span class="lineno"> 326</span>  <span class="comment">// Write into the intermediate output</span></div>
|
|
<div class="line"><a name="l00327"></a><span class="lineno"> 327</span>  runBlockSelect(distanceBufView,</div>
|
|
<div class="line"><a name="l00328"></a><span class="lineno"> 328</span>  outDistanceBufColView,</div>
|
|
<div class="line"><a name="l00329"></a><span class="lineno"> 329</span>  outIndexBufColView,</div>
|
|
<div class="line"><a name="l00330"></a><span class="lineno"> 330</span>  <span class="keyword">true</span>, k, streams[curStream]);</div>
|
|
<div class="line"><a name="l00331"></a><span class="lineno"> 331</span>  }</div>
|
|
<div class="line"><a name="l00332"></a><span class="lineno"> 332</span>  }</div>
|
|
<div class="line"><a name="l00333"></a><span class="lineno"> 333</span>  }</div>
|
|
<div class="line"><a name="l00334"></a><span class="lineno"> 334</span> </div>
|
|
<div class="line"><a name="l00335"></a><span class="lineno"> 335</span>  <span class="comment">// As we're finished with processing a full set of centroids, perform the</span></div>
|
|
<div class="line"><a name="l00336"></a><span class="lineno"> 336</span>  <span class="comment">// final k-selection</span></div>
|
|
<div class="line"><a name="l00337"></a><span class="lineno"> 337</span>  <span class="keywordflow">if</span> (tileCols != centroids.getSize(0)) {</div>
|
|
<div class="line"><a name="l00338"></a><span class="lineno"> 338</span>  <span class="comment">// The indices are tile-relative; for each tile of k, we need to add</span></div>
|
|
<div class="line"><a name="l00339"></a><span class="lineno"> 339</span>  <span class="comment">// tileCols to the index</span></div>
|
|
<div class="line"><a name="l00340"></a><span class="lineno"> 340</span>  runIncrementIndex(outIndexBufRowView, k, tileCols, streams[curStream]);</div>
|
|
<div class="line"><a name="l00341"></a><span class="lineno"> 341</span> </div>
|
|
<div class="line"><a name="l00342"></a><span class="lineno"> 342</span>  runBlockSelectPair(outDistanceBufRowView,</div>
|
|
<div class="line"><a name="l00343"></a><span class="lineno"> 343</span>  outIndexBufRowView,</div>
|
|
<div class="line"><a name="l00344"></a><span class="lineno"> 344</span>  outDistanceView,</div>
|
|
<div class="line"><a name="l00345"></a><span class="lineno"> 345</span>  outIndexView,</div>
|
|
<div class="line"><a name="l00346"></a><span class="lineno"> 346</span>  computeL2 ? <span class="keyword">false</span> : <span class="keyword">true</span>, k, streams[curStream]);</div>
|
|
<div class="line"><a name="l00347"></a><span class="lineno"> 347</span>  }</div>
|
|
<div class="line"><a name="l00348"></a><span class="lineno"> 348</span> </div>
|
|
<div class="line"><a name="l00349"></a><span class="lineno"> 349</span>  curStream = (curStream + 1) % 2;</div>
|
|
<div class="line"><a name="l00350"></a><span class="lineno"> 350</span>  }</div>
|
|
<div class="line"><a name="l00351"></a><span class="lineno"> 351</span> </div>
|
|
<div class="line"><a name="l00352"></a><span class="lineno"> 352</span>  <span class="comment">// Have the desired ordering stream wait on the multi-stream</span></div>
|
|
<div class="line"><a name="l00353"></a><span class="lineno"> 353</span>  streamWait({defaultStream}, streams);</div>
|
|
<div class="line"><a name="l00354"></a><span class="lineno"> 354</span> }</div>
|
|
<div class="line"><a name="l00355"></a><span class="lineno"> 355</span> </div>
|
|
<div class="line"><a name="l00356"></a><span class="lineno"> 356</span> <span class="keyword">template</span> &lt;<span class="keyword">typename</span> T&gt;</div>
|
|
<div class="line"><a name="l00357"></a><span class="lineno"> 357</span> <span class="keywordtype">void</span> runL2Distance(GpuResources* resources,</div>
|
|
<div class="line"><a name="l00358"></a><span class="lineno"> 358</span>  Tensor&lt;T, 2, true&gt;&amp; centroids,</div>
|
|
<div class="line"><a name="l00359"></a><span class="lineno"> 359</span>  Tensor&lt;T, 2, true&gt;* centroidsTransposed,</div>
|
|
<div class="line"><a name="l00360"></a><span class="lineno"> 360</span>  Tensor&lt;T, 1, true&gt;* centroidNorms,</div>
|
|
<div class="line"><a name="l00361"></a><span class="lineno"> 361</span>  Tensor&lt;T, 2, true&gt;&amp; queries,</div>
|
|
<div class="line"><a name="l00362"></a><span class="lineno"> 362</span>  <span class="keywordtype">int</span> k,</div>
|
|
<div class="line"><a name="l00363"></a><span class="lineno"> 363</span>  Tensor&lt;T, 2, true&gt;&amp; outDistances,</div>
|
|
<div class="line"><a name="l00364"></a><span class="lineno"> 364</span>  Tensor&lt;int, 2, true&gt;&amp; outIndices,</div>
|
|
<div class="line"><a name="l00365"></a><span class="lineno"> 365</span>  <span class="keywordtype">bool</span> useHgemm,</div>
|
|
<div class="line"><a name="l00366"></a><span class="lineno"> 366</span>  <span class="keywordtype">bool</span> ignoreOutDistances = <span class="keyword">false</span>) {</div>
|
|
<div class="line"><a name="l00367"></a><span class="lineno"> 367</span>  runDistance&lt;T&gt;(<span class="keyword">true</span>, <span class="comment">// L2</span></div>
|
|
<div class="line"><a name="l00368"></a><span class="lineno"> 368</span>  resources,</div>
|
|
<div class="line"><a name="l00369"></a><span class="lineno"> 369</span>  centroids,</div>
|
|
<div class="line"><a name="l00370"></a><span class="lineno"> 370</span>  centroidsTransposed,</div>
|
|
<div class="line"><a name="l00371"></a><span class="lineno"> 371</span>  centroidNorms,</div>
|
|
<div class="line"><a name="l00372"></a><span class="lineno"> 372</span>  queries,</div>
|
|
<div class="line"><a name="l00373"></a><span class="lineno"> 373</span>  k,</div>
|
|
<div class="line"><a name="l00374"></a><span class="lineno"> 374</span>  outDistances,</div>
|
|
<div class="line"><a name="l00375"></a><span class="lineno"> 375</span>  outIndices,</div>
|
|
<div class="line"><a name="l00376"></a><span class="lineno"> 376</span>  useHgemm,</div>
|
|
<div class="line"><a name="l00377"></a><span class="lineno"> 377</span>  ignoreOutDistances);</div>
|
|
<div class="line"><a name="l00378"></a><span class="lineno"> 378</span> }</div>
|
|
<div class="line"><a name="l00379"></a><span class="lineno"> 379</span> </div>
|
|
<div class="line"><a name="l00380"></a><span class="lineno"> 380</span> <span class="keyword">template</span> &lt;<span class="keyword">typename</span> T&gt;</div>
|
|
<div class="line"><a name="l00381"></a><span class="lineno"> 381</span> <span class="keywordtype">void</span> runIPDistance(GpuResources* resources,</div>
|
|
<div class="line"><a name="l00382"></a><span class="lineno"> 382</span>  Tensor&lt;T, 2, true&gt;&amp; centroids,</div>
|
|
<div class="line"><a name="l00383"></a><span class="lineno"> 383</span>  Tensor&lt;T, 2, true&gt;* centroidsTransposed,</div>
|
|
<div class="line"><a name="l00384"></a><span class="lineno"> 384</span>  Tensor&lt;T, 2, true&gt;&amp; queries,</div>
|
|
<div class="line"><a name="l00385"></a><span class="lineno"> 385</span>  <span class="keywordtype">int</span> k,</div>
|
|
<div class="line"><a name="l00386"></a><span class="lineno"> 386</span>  Tensor&lt;T, 2, true&gt;&amp; outDistances,</div>
|
|
<div class="line"><a name="l00387"></a><span class="lineno"> 387</span>  Tensor&lt;int, 2, true&gt;&amp; outIndices,</div>
|
|
<div class="line"><a name="l00388"></a><span class="lineno"> 388</span>  <span class="keywordtype">bool</span> useHgemm) {</div>
|
|
<div class="line"><a name="l00389"></a><span class="lineno"> 389</span>  runDistance&lt;T&gt;(<span class="keyword">false</span>, <span class="comment">// IP</span></div>
|
|
<div class="line"><a name="l00390"></a><span class="lineno"> 390</span>  resources,</div>
|
|
<div class="line"><a name="l00391"></a><span class="lineno"> 391</span>  centroids,</div>
|
|
<div class="line"><a name="l00392"></a><span class="lineno"> 392</span>  centroidsTransposed,</div>
|
|
<div class="line"><a name="l00393"></a><span class="lineno"> 393</span>  <span class="keyword">nullptr</span>,</div>
|
|
<div class="line"><a name="l00394"></a><span class="lineno"> 394</span>  queries,</div>
|
|
<div class="line"><a name="l00395"></a><span class="lineno"> 395</span>  k,</div>
|
|
<div class="line"><a name="l00396"></a><span class="lineno"> 396</span>  outDistances,</div>
|
|
<div class="line"><a name="l00397"></a><span class="lineno"> 397</span>  outIndices,</div>
|
|
<div class="line"><a name="l00398"></a><span class="lineno"> 398</span>  useHgemm,</div>
|
|
<div class="line"><a name="l00399"></a><span class="lineno"> 399</span>  <span class="keyword">false</span>);</div>
|
|
<div class="line"><a name="l00400"></a><span class="lineno"> 400</span> }</div>
|
|
<div class="line"><a name="l00401"></a><span class="lineno"> 401</span> </div>
|
|
<div class="line"><a name="l00402"></a><span class="lineno"> 402</span> <span class="comment">//</span></div>
|
|
<div class="line"><a name="l00403"></a><span class="lineno"> 403</span> <span class="comment">// Instantiations of the distance templates</span></div>
|
|
<div class="line"><a name="l00404"></a><span class="lineno"> 404</span> <span class="comment">//</span></div>
|
|
<div class="line"><a name="l00405"></a><span class="lineno"> 405</span> </div>
|
|
<div class="line"><a name="l00406"></a><span class="lineno"> 406</span> <span class="keywordtype">void</span></div>
|
|
<div class="line"><a name="l00407"></a><span class="lineno"> 407</span> runIPDistance(GpuResources* resources,</div>
|
|
<div class="line"><a name="l00408"></a><span class="lineno"> 408</span>  Tensor&lt;float, 2, true&gt;&amp; vectors,</div>
|
|
<div class="line"><a name="l00409"></a><span class="lineno"> 409</span>  Tensor&lt;float, 2, true&gt;* vectorsTransposed,</div>
|
|
<div class="line"><a name="l00410"></a><span class="lineno"> 410</span>  Tensor&lt;float, 2, true&gt;&amp; queries,</div>
|
|
<div class="line"><a name="l00411"></a><span class="lineno"> 411</span>  <span class="keywordtype">int</span> k,</div>
|
|
<div class="line"><a name="l00412"></a><span class="lineno"> 412</span>  Tensor&lt;float, 2, true&gt;&amp; outDistances,</div>
|
|
<div class="line"><a name="l00413"></a><span class="lineno"> 413</span>  Tensor&lt;int, 2, true&gt;&amp; outIndices) {</div>
|
|
<div class="line"><a name="l00414"></a><span class="lineno"> 414</span>  runIPDistance&lt;float&gt;(resources,</div>
|
|
<div class="line"><a name="l00415"></a><span class="lineno"> 415</span>  vectors,</div>
|
|
<div class="line"><a name="l00416"></a><span class="lineno"> 416</span>  vectorsTransposed,</div>
|
|
<div class="line"><a name="l00417"></a><span class="lineno"> 417</span>  queries,</div>
|
|
<div class="line"><a name="l00418"></a><span class="lineno"> 418</span>  k,</div>
|
|
<div class="line"><a name="l00419"></a><span class="lineno"> 419</span>  outDistances,</div>
|
|
<div class="line"><a name="l00420"></a><span class="lineno"> 420</span>  outIndices,</div>
|
|
<div class="line"><a name="l00421"></a><span class="lineno"> 421</span>  <span class="keyword">false</span>);</div>
|
|
<div class="line"><a name="l00422"></a><span class="lineno"> 422</span> }</div>
|
|
<div class="line"><a name="l00423"></a><span class="lineno"> 423</span> </div>
|
|
<div class="line"><a name="l00424"></a><span class="lineno"> 424</span> <span class="preprocessor">#ifdef FAISS_USE_FLOAT16</span></div>
|
|
<div class="line"><a name="l00425"></a><span class="lineno"> 425</span> <span class="preprocessor"></span><span class="keywordtype">void</span></div>
|
|
<div class="line"><a name="l00426"></a><span class="lineno"> 426</span> runIPDistance(GpuResources* resources,</div>
|
|
<div class="line"><a name="l00427"></a><span class="lineno"> 427</span>  Tensor&lt;half, 2, true&gt;&amp; vectors,</div>
|
|
<div class="line"><a name="l00428"></a><span class="lineno"> 428</span>  Tensor&lt;half, 2, true&gt;* vectorsTransposed,</div>
|
|
<div class="line"><a name="l00429"></a><span class="lineno"> 429</span>  Tensor&lt;half, 2, true&gt;&amp; queries,</div>
|
|
<div class="line"><a name="l00430"></a><span class="lineno"> 430</span>  <span class="keywordtype">int</span> k,</div>
|
|
<div class="line"><a name="l00431"></a><span class="lineno"> 431</span>  Tensor&lt;half, 2, true&gt;&amp; outDistances,</div>
|
|
<div class="line"><a name="l00432"></a><span class="lineno"> 432</span>  Tensor&lt;int, 2, true&gt;&amp; outIndices,</div>
|
|
<div class="line"><a name="l00433"></a><span class="lineno"> 433</span>  <span class="keywordtype">bool</span> useHgemm) {</div>
|
|
<div class="line"><a name="l00434"></a><span class="lineno"> 434</span>  runIPDistance&lt;half&gt;(resources,</div>
|
|
<div class="line"><a name="l00435"></a><span class="lineno"> 435</span>  vectors,</div>
|
|
<div class="line"><a name="l00436"></a><span class="lineno"> 436</span>  vectorsTransposed,</div>
|
|
<div class="line"><a name="l00437"></a><span class="lineno"> 437</span>  queries,</div>
|
|
<div class="line"><a name="l00438"></a><span class="lineno"> 438</span>  k,</div>
|
|
<div class="line"><a name="l00439"></a><span class="lineno"> 439</span>  outDistances,</div>
|
|
<div class="line"><a name="l00440"></a><span class="lineno"> 440</span>  outIndices,</div>
|
|
<div class="line"><a name="l00441"></a><span class="lineno"> 441</span>  useHgemm);</div>
|
|
<div class="line"><a name="l00442"></a><span class="lineno"> 442</span> }</div>
|
|
<div class="line"><a name="l00443"></a><span class="lineno"> 443</span> <span class="preprocessor">#endif</span></div>
|
|
<div class="line"><a name="l00444"></a><span class="lineno"> 444</span> <span class="preprocessor"></span></div>
|
|
<div class="line"><a name="l00445"></a><span class="lineno"> 445</span> <span class="keywordtype">void</span></div>
|
|
<div class="line"><a name="l00446"></a><span class="lineno"> 446</span> runL2Distance(GpuResources* resources,</div>
|
|
<div class="line"><a name="l00447"></a><span class="lineno"> 447</span>  Tensor&lt;float, 2, true&gt;&amp; vectors,</div>
|
|
<div class="line"><a name="l00448"></a><span class="lineno"> 448</span>  Tensor&lt;float, 2, true&gt;* vectorsTransposed,</div>
|
|
<div class="line"><a name="l00449"></a><span class="lineno"> 449</span>  Tensor&lt;float, 1, true&gt;* vectorNorms,</div>
|
|
<div class="line"><a name="l00450"></a><span class="lineno"> 450</span>  Tensor&lt;float, 2, true&gt;&amp; queries,</div>
|
|
<div class="line"><a name="l00451"></a><span class="lineno"> 451</span>  <span class="keywordtype">int</span> k,</div>
|
|
<div class="line"><a name="l00452"></a><span class="lineno"> 452</span>  Tensor&lt;float, 2, true&gt;&amp; outDistances,</div>
|
|
<div class="line"><a name="l00453"></a><span class="lineno"> 453</span>  Tensor&lt;int, 2, true&gt;&amp; outIndices,</div>
|
|
<div class="line"><a name="l00454"></a><span class="lineno"> 454</span>  <span class="keywordtype">bool</span> ignoreOutDistances) {</div>
|
|
<div class="line"><a name="l00455"></a><span class="lineno"> 455</span>  runL2Distance&lt;float&gt;(resources,</div>
|
|
<div class="line"><a name="l00456"></a><span class="lineno"> 456</span>  vectors,</div>
|
|
<div class="line"><a name="l00457"></a><span class="lineno"> 457</span>  vectorsTransposed,</div>
|
|
<div class="line"><a name="l00458"></a><span class="lineno"> 458</span>  vectorNorms,</div>
|
|
<div class="line"><a name="l00459"></a><span class="lineno"> 459</span>  queries,</div>
|
|
<div class="line"><a name="l00460"></a><span class="lineno"> 460</span>  k,</div>
|
|
<div class="line"><a name="l00461"></a><span class="lineno"> 461</span>  outDistances,</div>
|
|
<div class="line"><a name="l00462"></a><span class="lineno"> 462</span>  outIndices,</div>
|
|
<div class="line"><a name="l00463"></a><span class="lineno"> 463</span>  <span class="keyword">false</span>,</div>
|
|
<div class="line"><a name="l00464"></a><span class="lineno"> 464</span>  ignoreOutDistances);</div>
|
|
<div class="line"><a name="l00465"></a><span class="lineno"> 465</span> }</div>
|
|
<div class="line"><a name="l00466"></a><span class="lineno"> 466</span> </div>
|
|
<div class="line"><a name="l00467"></a><span class="lineno"> 467</span> <span class="preprocessor">#ifdef FAISS_USE_FLOAT16</span></div>
|
|
<div class="line"><a name="l00468"></a><span class="lineno"> 468</span> <span class="preprocessor"></span><span class="keywordtype">void</span></div>
|
|
<div class="line"><a name="l00469"></a><span class="lineno"> 469</span> runL2Distance(GpuResources* resources,</div>
|
|
<div class="line"><a name="l00470"></a><span class="lineno"> 470</span>  Tensor&lt;half, 2, true&gt;&amp; vectors,</div>
|
|
<div class="line"><a name="l00471"></a><span class="lineno"> 471</span>  Tensor&lt;half, 2, true&gt;* vectorsTransposed,</div>
|
|
<div class="line"><a name="l00472"></a><span class="lineno"> 472</span>  Tensor&lt;half, 1, true&gt;* vectorNorms,</div>
|
|
<div class="line"><a name="l00473"></a><span class="lineno"> 473</span>  Tensor&lt;half, 2, true&gt;&amp; queries,</div>
|
|
<div class="line"><a name="l00474"></a><span class="lineno"> 474</span>  <span class="keywordtype">int</span> k,</div>
|
|
<div class="line"><a name="l00475"></a><span class="lineno"> 475</span>  Tensor&lt;half, 2, true&gt;&amp; outDistances,</div>
|
|
<div class="line"><a name="l00476"></a><span class="lineno"> 476</span>  Tensor&lt;int, 2, true&gt;&amp; outIndices,</div>
|
|
<div class="line"><a name="l00477"></a><span class="lineno"> 477</span>  <span class="keywordtype">bool</span> useHgemm,</div>
|
|
<div class="line"><a name="l00478"></a><span class="lineno"> 478</span>  <span class="keywordtype">bool</span> ignoreOutDistances) {</div>
|
|
<div class="line"><a name="l00479"></a><span class="lineno"> 479</span>  runL2Distance&lt;half&gt;(resources,</div>
|
|
<div class="line"><a name="l00480"></a><span class="lineno"> 480</span>  vectors,</div>
|
|
<div class="line"><a name="l00481"></a><span class="lineno"> 481</span>  vectorsTransposed,</div>
|
|
<div class="line"><a name="l00482"></a><span class="lineno"> 482</span>  vectorNorms,</div>
|
|
<div class="line"><a name="l00483"></a><span class="lineno"> 483</span>  queries,</div>
|
|
<div class="line"><a name="l00484"></a><span class="lineno"> 484</span>  k,</div>
|
|
<div class="line"><a name="l00485"></a><span class="lineno"> 485</span>  outDistances,</div>
|
|
<div class="line"><a name="l00486"></a><span class="lineno"> 486</span>  outIndices,</div>
|
|
<div class="line"><a name="l00487"></a><span class="lineno"> 487</span>  useHgemm,</div>
|
|
<div class="line"><a name="l00488"></a><span class="lineno"> 488</span>  ignoreOutDistances);</div>
|
|
<div class="line"><a name="l00489"></a><span class="lineno"> 489</span> }</div>
|
|
<div class="line"><a name="l00490"></a><span class="lineno"> 490</span> <span class="preprocessor">#endif</span></div>
|
|
<div class="line"><a name="l00491"></a><span class="lineno"> 491</span> <span class="preprocessor"></span></div>
|
|
<div class="line"><a name="l00492"></a><span class="lineno"> 492</span> } } <span class="comment">// namespace</span></div>
|
|
</div><!-- fragment --></div><!-- contents -->
|
|
<!-- start footer part -->
|
|
<hr class="footer"/><address class="footer"><small>
|
|
Generated by  <a href="http://www.doxygen.org/index.html">
|
|
<img class="footer" src="doxygen.png" alt="doxygen"/>
|
|
</a> 1.8.5
|
|
</small></address>
|
|
</body>
|
|
</html>
|