Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
test_blas.cpp
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <cstdio>
10 #include <cstdlib>
11 
// Integer type used for every BLAS/LAPACK INTEGER argument in this test.
// It is forced to `long` so that main() can pass 64-bit values and observe
// how many of their bits the linked LAPACK actually reads/writes.
// NOTE(review): `long` is 64 bits only on LP64 platforms (not on Windows).
#undef FINTEGER
#define FINTEGER long


extern "C" {

/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */

/* Fortran-convention SGEMM: every argument is passed by pointer and matrices
   are column-major. Computes c = alpha * op(a) * op(b) + beta * c.
   alpha and beta are read-only scalars, hence both are const-qualified. */
int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *n,
            FINTEGER *k, const float *alpha, const float *a, FINTEGER *lda,
            const float *b, FINTEGER *ldb, const float *beta, float *c,
            FINTEGER *ldc);

/* Lapack functions, see http://www.netlib.org/clapack/old/single/sgeqrf.c */

/* QR factorization of the m x n column-major matrix a. Per the LAPACK
   reference, *lwork == -1 requests a workspace-size query and argument
   errors are reported through *info. */
int sgeqrf_ (FINTEGER *m, FINTEGER *n, float *a, FINTEGER *lda,
             float *tau, float *work, FINTEGER *lwork, FINTEGER *info);

}
31 
// Returns a freshly allocated array of `size` floats, each drawn from
// drand48() (so the sequence depends on the current srand48 seed).
// The caller owns the result and must release it with delete[].
float *new_random_vec(int size)
{
    float *vec = new float[size];
    float *const end = vec + size;
    for (float *p = vec; p != end; ++p) {
        *p = drand48();
    }
    return vec;
}
39 
40 
41 int main() {
42 
43  FINTEGER m = 10, n = 20, k = 30;
44  float *a = new_random_vec(m * k), *b = new_random_vec(n * k), *c = new float[n * m];
45  float one = 1.0, zero = 0.0;
46 
47  printf("BLAS test\n");
48 
49  sgemm_("Not transposed", "Not transposed",
50  &m, &n, &k, &one, a, &m, b, &k, &zero, c, &m);
51 
52  printf("errors=\n");
53 
54  for (int i = 0; i < m; i++) {
55  for (int j = 0; j < n; j++) {
56  float accu = 0;
57  for (int l = 0; l < k; l++)
58  accu += a[i + l * m] * b[l + j * k];
59  printf ("%6.3f ", accu - c[i + j * m]);
60  }
61  printf("\n");
62  }
63 
64  long info = 0x64bL << 32;
65  long mi = 0x64bL << 32 | m;
66  float *tau = new float[m];
67  FINTEGER lwork = -1;
68 
69  float work1;
70 
71  printf("Intentional Lapack error (appears only for 64-bit INTEGER):\n");
72  sgeqrf_ (&mi, &n, c, &m, tau, &work1, &lwork, (FINTEGER*)&info);
73 
74  // sgeqrf_ (&m, &n, c, &zeroi, tau, &work1, &lwork, (FINTEGER*)&info);
75  printf("info=%016lx\n", info);
76 
77  if(info >> 32 == 0x64b) {
78  printf("Lapack uses 32-bit integers\n");
79  } else {
80  printf("Lapack uses 64-bit integers\n");
81  }
82 
83 
84  return 0;
85 }