1 #ifndef GRADIENT_MATRICES_HPP
2 #define GRADIENT_MATRICES_HPP
/**
 * @brief Retrieve the sparse-matrix data for one gradient component.
 *
 * Declared here with C linkage; no definition is visible in this header.
 * Every argument after @p ind is an out-parameter written by the callee.
 * The GradientMatrices(bool) constructor calls this once with X and once
 * with Y. It passes the results straight to the Matrix<Device> constructor,
 * converting @p is_csr_int to a bool via `is_csr_int == 1`.
 *
 * @param[in]  ind        Selects the gradient component (callers pass X or Y).
 * @param[out] m          First size argument forwarded to Matrix<Device>
 *                        (presumably the row count -- TODO confirm).
 * @param[out] n          Second size argument forwarded to Matrix<Device>
 *                        (presumably the column count -- TODO confirm).
 * @param[out] w          Third argument forwarded to Matrix<Device>; the
 *                        name suggests a row/matrix width -- TODO confirm.
 * @param[out] nnz        Fourth argument forwarded to Matrix<Device>;
 *                        presumably the number of stored non-zeros.
 * @param[out] is_csr_int Storage-format flag. Callers treat exactly 1 as
 *                        `true` (CSR); any other value maps to `false`.
 * @param[out] csr_ridx_or_eindex    Receives an int array pointer. Per its
 *                        name: CSR row index when is_csr_int == 1, otherwise
 *                        an "eindex" array (layout not visible -- TODO confirm).
 * @param[out] csr_cidx_or_nelement  Receives an int array pointer. Per its
 *                        name: CSR column index when is_csr_int == 1,
 *                        otherwise an "nelement" array -- TODO confirm.
 * @param[out] values     Receives a pointer to the matrix values (double).
 *
 * @note NOTE(review): ownership and lifetime of the arrays returned through
 *       the pointer-to-pointer outputs are not visible here. Confirm whether
 *       Matrix<Device> copies them or the callee keeps ownership.
 */
extern "C" void get_grad_mat_data(
    int ind,
    int* m,
    int* n,
    int* w,
    int* nnz,
    int* is_csr_int,
    int** csr_ridx_or_eindex,
    int** csr_cidx_or_nelement,
    double** values);
17 int m, n, w, nnz, is_csr_int;
18 int* csr_ridx_or_eindex;
19 int* csr_cidx_or_nelement;
24 get_grad_mat_data(X, &m, &n, &w, &nnz, &is_csr_int, &csr_ridx_or_eindex, &csr_cidx_or_nelement, &values);
25 gradx =
Matrix<Device>(m, n, w, nnz, (is_csr_int==1), csr_ridx_or_eindex, csr_cidx_or_nelement, values);
28 get_grad_mat_data(Y, &m, &n, &w, &nnz, &is_csr_int, &csr_ridx_or_eindex, &csr_cidx_or_nelement, &values);
29 grady =
Matrix<Device>(m, n, w, nnz, (is_csr_int==1), csr_ridx_or_eindex, csr_cidx_or_nelement, values);
33 template<
class Device2>
44 :
gradx(nnode, nnode, matrix_width),
45 grady(nnode, nnode, matrix_width){}
Definition: gradient_matrices.hpp:9
void get_grad_mat_data(int ind, int *m, int *n, int *w, int *nnz, int *is_csr_int, int **csr_ridx_or_eindex, int **csr_cidx_or_nelement, double **values)
GradientMatrices(int nnode, int matrix_width)
Definition: gradient_matrices.hpp:43
Matrix< Device > gradx
Definition: gradient_matrices.hpp:10
Definition: matrix.hpp:11
GradientMatrices(bool use)
Definition: gradient_matrices.hpp:16
GradientMatrices< Device2 > mirror() const
Definition: gradient_matrices.hpp:34
Matrix< Device > grady
Definition: gradient_matrices.hpp:11
GradientMatrices()
Definition: gradient_matrices.hpp:13