XGCa
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
gradient_matrices.hpp
Go to the documentation of this file.
1 #ifndef GRADIENT_MATRICES_HPP
2 #define GRADIENT_MATRICES_HPP
3 
4 #include "matrix.hpp"
5 
6 extern "C" void get_grad_mat_data(int ind, int* m, int* n, int* w, int* nnz, int* is_csr_int, int** csr_ridx_or_eindex, int** csr_cidx_or_nelement, double** values);
7 
8 template<class Device>
12 
14 
15  /* Set up gradx and grady on GPU by (for now) copying from the fortran matrix */
16  GradientMatrices(bool use){
17  int m, n, w, nnz, is_csr_int;
18  int* csr_ridx_or_eindex;
19  int* csr_cidx_or_nelement;
20  double* values;
21  const int X=0;
22  const int Y=1;
23  // gradx
24  get_grad_mat_data(X, &m, &n, &w, &nnz, &is_csr_int, &csr_ridx_or_eindex, &csr_cidx_or_nelement, &values);
25  gradx = Matrix<Device>(m, n, w, nnz, (is_csr_int==1), csr_ridx_or_eindex, csr_cidx_or_nelement, values);
26 
27  // grady
28  get_grad_mat_data(Y, &m, &n, &w, &nnz, &is_csr_int, &csr_ridx_or_eindex, &csr_cidx_or_nelement, &values);
29  grady = Matrix<Device>(m, n, w, nnz, (is_csr_int==1), csr_ridx_or_eindex, csr_cidx_or_nelement, values);
30  }
31 
32  // Create a mirror with a different device type
33  template<class Device2>
36 
37  m.gradx = gradx.template mirror<Device2>();
38  m.grady = grady.template mirror<Device2>();
39 
40  return m;
41  }
42 
43  GradientMatrices(int nnode, int matrix_width)
44  : gradx(nnode, nnode, matrix_width),
45  grady(nnode, nnode, matrix_width){}
46 
47 };
48 #endif
Definition: gradient_matrices.hpp:9
void get_grad_mat_data(int ind, int *m, int *n, int *w, int *nnz, int *is_csr_int, int **csr_ridx_or_eindex, int **csr_cidx_or_nelement, double **values)
GradientMatrices(int nnode, int matrix_width)
Definition: gradient_matrices.hpp:43
Matrix< Device > gradx
Definition: gradient_matrices.hpp:10
Definition: matrix.hpp:13
GradientMatrices(bool use)
Definition: gradient_matrices.hpp:16
GradientMatrices< Device2 > mirror() const
Definition: gradient_matrices.hpp:34
Matrix< Device > grady
Definition: gradient_matrices.hpp:11
GradientMatrices()
Definition: gradient_matrices.hpp:13