Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
test_blas.cpp
1 
2 /**
3  * Copyright (c) 2015-present, Facebook, Inc.
4  * All rights reserved.
5  *
6  * This source code is licensed under the CC-by-NC license found in the
7  * LICENSE file in the root directory of this source tree.
8  */
9 
10 #include <cstdio>
11 #include <cstdlib>
12 
// Fortran INTEGER type used in the prototypes below. It is deliberately
// `long`: main() probes at run time whether the linked LAPACK actually
// reads 32 or 64 bits of each integer argument.
#undef FINTEGER
#define FINTEGER long


extern "C" {

/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */

/* Single-precision general matrix multiply:
 *   C = alpha * op(A) * op(B) + beta * C
 * Fortran calling convention: every argument, scalars included, is passed
 * by pointer. Only C is written; alpha and beta are read-only inputs. */
int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
            n, FINTEGER *k, const float *alpha, const float *a,
            FINTEGER *lda, const float *b, FINTEGER *
            ldb, const float *beta, float *c, FINTEGER *ldc);

/* Lapack functions, see http://www.netlib.org/clapack/old/single/sgeqrf.c */

/* QR factorization of an m x n single-precision matrix `a` (in place).
 * Passing *lwork == -1 requests a workspace-size query; the optimal size is
 * returned in work[0]. The status code is written to *info. */
int sgeqrf_ (FINTEGER *m, FINTEGER *n, float *a, FINTEGER *lda,
             float *tau, float *work, FINTEGER *lwork, FINTEGER *info);

}
32 
/* Allocate an array of `size` floats with new[] and fill it with successive
 * drand48() samples (POSIX: uniformly distributed in [0.0, 1.0)), each
 * narrowed to float. The caller owns the result and must release it with
 * delete[]. */
float *new_random_vec(int size)
{
    float *vec = new float[size];
    float *const end = vec + size;
    for (float *p = vec; p != end; ++p) {
        *p = drand48();
    }
    return vec;
}
40 
41 
42 int main() {
43 
44  FINTEGER m = 10, n = 20, k = 30;
45  float *a = new_random_vec(m * k), *b = new_random_vec(n * k), *c = new float[n * m];
46  float one = 1.0, zero = 0.0;
47 
48  printf("BLAS test\n");
49 
50  sgemm_("Not transposed", "Not transposed",
51  &m, &n, &k, &one, a, &m, b, &k, &zero, c, &m);
52 
53  printf("errors=\n");
54 
55  for (int i = 0; i < m; i++) {
56  for (int j = 0; j < n; j++) {
57  float accu = 0;
58  for (int l = 0; l < k; l++)
59  accu += a[i + l * m] * b[l + j * k];
60  printf ("%6.3f ", accu - c[i + j * m]);
61  }
62  printf("\n");
63  }
64 
65  long info = 0x64bL << 32;
66  long mi = 0x64bL << 32 | m;
67  float *tau = new float[m];
68  FINTEGER lwork = -1;
69 
70  float work1;
71 
72  printf("Intentional Lapack error (appears only for 64-bit INTEGER):\n");
73  sgeqrf_ (&mi, &n, c, &m, tau, &work1, &lwork, (FINTEGER*)&info);
74 
75  // sgeqrf_ (&m, &n, c, &zeroi, tau, &work1, &lwork, (FINTEGER*)&info);
76  printf("info=%016lx\n", info);
77 
78  if(info >> 32 == 0x64b) {
79  printf("Lapack uses 32-bit integers\n");
80  } else {
81  printf("Lapack uses 64-bit integers\n");
82  }
83 
84 
85  return 0;
86 }