32#ifndef SPARSEMATCUDA_HPP
33#define SPARSEMATCUDA_HPP
37#include <cusolverSp.h>
// cuSOLVER sparse handle, the CUDA stream all async copies/solves are issued on,
// and the cuSPARSE matrix descriptor. All start as nullptr until created.
43 cusolverSpHandle_t handle =
nullptr;
44 cudaStream_t stream =
nullptr;
45 cusparseMatDescr_t descr =
nullptr;
// Device copies of the CSR arrays: nonzero values, column indices, row pointers.
// NOTE(review): despite its name, d_val_idx stores the nonzero VALUES (csr_mat.val_mem()), not indices.
47 void* d_val_idx =
nullptr;
48 void* d_col_idx =
nullptr;
49 void* d_row_ptr =
nullptr;
// Declared const, so it cannot reset pointer members to nullptr after freeing them
// (unless those members are mutable).
// NOTE(review): presumably frees the device-side CSR value/column buffers — confirm it is
// never invoked twice on the same allocation.
55 void device_dealloc()
const;
// Overrides a base-class virtual; the result is returned as a unique_ptr<MetaMat<T>>.
65 unique_ptr<
MetaMat<
T>> make_copy() override;
// Overrides a base-class virtual; returns an int status code.
// NOTE(review): from how the body uses X and B, the arguments appear to be
// (solution X, right-hand side B) — confirm against the base declaration.
67 int direct_solve(Mat<
T>&, const Mat<
T>&) override;
// Create the cuSOLVER sparse handle and a dedicated CUDA stream, and bind the stream to the
// handle so solver calls are ordered with the async copies issued on the same stream.
// NOTE(review): none of the returned status codes (cusolverStatus_t / cudaError_t /
// cusparseStatus_t) are checked here.
71 cusolverSpCreate(&handle);
72 cudaStreamCreate(&stream);
73 cusolverSpSetStream(handle, stream);
// Describe A as a general (non-symmetric) matrix with zero-based CSR indexing.
74 cusparseCreateMatDescr(&descr);
75 cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL);
76 cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO);
// CSR row-pointer array: n_rows + 1 int entries.
77 cudaMalloc(&d_row_ptr,
sizeof(
int) * (this->n_rows + 1));
// Destroy each resource only if it was created (non-null).
// NOTE(review): the members are not reset to nullptr afterwards, so running this twice would
// destroy/free the same objects again — confirm it runs only once per object.
81 if(handle) cusolverSpDestroy(handle);
82 if(stream) cudaStreamDestroy(stream);
83 if(descr) cusparseDestroyMatDescr(descr);
84 if(d_row_ptr) cudaFree(d_row_ptr);
// Byte sizes of the CSR value and column-index arrays.
// NOTE(review): ET is assumed to be the element type actually stored in csr_mat — confirm.
88 const size_t n_val =
sizeof(ET) * csr_mat.n_elem;
89 const size_t n_col =
sizeof(
int) * csr_mat.n_elem;
// Value and column buffers are allocated here on each call, while d_row_ptr is reused.
// NOTE(review): if this path can run more than once, make sure the previous d_val_idx /
// d_col_idx were freed first, otherwise they leak. Allocation results are unchecked.
91 cudaMalloc(&d_val_idx, n_val);
92 cudaMalloc(&d_col_idx, n_col);
// Asynchronous uploads on `stream`: the host CSR arrays must stay alive and unmodified until
// the stream has been synchronized.
94 cudaMemcpyAsync(d_val_idx, csr_mat.val_mem(), n_val, cudaMemcpyHostToDevice, stream);
95 cudaMemcpyAsync(d_col_idx, csr_mat.col_mem(), n_col, cudaMemcpyHostToDevice, stream);
96 cudaMemcpyAsync(d_row_ptr, csr_mat.row_mem(),
sizeof(
int) * (csr_mat.n_rows + 1llu), cudaMemcpyHostToDevice, stream);
// Free the device CSR value and column-index buffers if they were allocated.
// NOTE(review): the pointers are left dangling (not reset to nullptr); a later free of the
// same members would be a double free.
100 if(d_val_idx) cudaFree(d_val_idx);
101 if(d_col_idx) cudaFree(d_col_idx);
// Forward the dimensions and element count to SparseMat<T>, then call acquire()
// (presumably the handle/stream/descriptor setup — confirm).
105 :
SparseMat<
T>(in_row, in_col, in_elem) { acquire(); }
// One-time setup guard: the enclosed work runs only while `factored` is false, and the flag is
// then set so later solves skip it.
118 if(!this->factored) {
124 this->factored =
true;
// Bytes for the device RHS/solution buffers: float storage when T is float or when mixed
// precision is selected (the refinement solves in float), double storage otherwise.
127 const size_t n_rhs = (std::is_same_v<T, float> ||
Precision::MIXED == this->setting.precision ?
sizeof(float) :
sizeof(
double)) * B.n_elem;
// Per-solve device buffers for B (d_b) and X (d_x).
// NOTE(review): allocation results are not checked.
132 cudaMalloc(&d_b, n_rhs);
133 cudaMalloc(&d_x, n_rhs);
// Single-precision path.
// NOTE(review): this is a runtime `if`, not `if constexpr`, so every branch is instantiated
// for every T.
138 if(std::is_same_v<T, float>) {
139 cudaMemcpyAsync(d_b, B.memptr(), n_rhs, cudaMemcpyHostToDevice, stream);
// Solve one column of B at a time with cusolverSpScsrlsvqr. Armadillo stores B column-major,
// so offset I (a multiple of n_rows) is the start of a column. Status codes are summed into
// `code`.
// NOTE(review): nnz is taken from triplet_mat.n_elem while the uploaded arrays were sized from
// csr_mat.n_elem — confirm they always agree.
// NOTE(review): `tol` here is cuSOLVER's singularity threshold, not a convergence tolerance.
// `singularity` (-1 if A is invertible) is overwritten by each column, so only the last
// column's value survives. Check whether `reorder` (passed as 3) has any effect on this
// device routine.
141 for(
auto I = 0llu; I < B.n_elem; I += B.n_rows) code += cusolverSpScsrlsvqr(handle,
int(this->n_rows),
int(this->triplet_mat.n_elem), descr, (
float*)d_val_idx, (
int*)d_row_ptr, (
int*)d_col_idx, (
float*)d_b + I,
float(this->setting.tolerance), 3, (
float*)d_x + I, &singularity);
// Copy the solution back and wait for it before X is used on the host.
// NOTE(review): cudaStreamSynchronize(stream) would suffice; cudaDeviceSynchronize() waits on
// all device work.
143 X.set_size(arma::size(B));
145 cudaMemcpyAsync(X.memptr(), d_x, n_rhs, cudaMemcpyDeviceToHost, stream);
147 cudaDeviceSynchronize();
// Double-precision path: upload B, solve each column with cusolverSpDcsrlsvqr, copy X back.
150 cudaMemcpyAsync(d_b, B.memptr(), n_rhs, cudaMemcpyHostToDevice, stream);
// One solve per column of B (column-major offsets of n_rows). Status codes are summed into
// `code`.
// NOTE(review): as in the float path, nnz comes from triplet_mat.n_elem, and `singularity`
// only reflects the last column.
152 for(
auto I = 0llu; I < B.n_elem; I += B.n_rows) code += cusolverSpDcsrlsvqr(handle,
int(this->n_rows),
int(this->triplet_mat.n_elem), descr, (
double*)d_val_idx, (
int*)d_row_ptr, (
int*)d_col_idx, (
double*)d_b + I, this->setting.tolerance, 3, (
double*)d_x + I, &singularity);
// Copy the solution back; the synchronize makes X valid on the host.
154 X.set_size(arma::size(B));
156 cudaMemcpyAsync(X.memptr(), d_x, n_rhs, cudaMemcpyDeviceToHost, stream);
158 cudaDeviceSynchronize();
// Mixed-precision iterative refinement: solve for a correction in float, then accumulate and
// track the residual in double. Starting from X = 0, the initial residual is B itself.
161 X = arma::zeros(B.n_rows, B.n_cols);
163 mat full_residual = B;
// NOTE(review): unqualified `norm` here vs `arma::norm` in the update below — qualify both the
// same way to avoid picking up an unrelated overload. For a multi-column matrix, arma::norm
// is the spectral 2-norm (SVD-based), not a per-column norm.
165 auto multiplier =
norm(full_residual);
// Stop after `iterative_refinement` passes or once the residual norm drops below tolerance.
168 while(counter++ < this->setting.iterative_refinement) {
169 if(multiplier < this->setting.tolerance)
break;
// Scale the residual by its norm before down-casting to float, to limit range/precision loss.
171 auto residual = conv_to<fmat>::from(full_residual / multiplier);
173 cudaMemcpyAsync(d_b, residual.memptr(), n_rhs, cudaMemcpyHostToDevice, stream);
// Float QR solve for each column of the scaled residual; status codes accumulate in `code`.
176 for(
auto I = 0llu; I < B.n_elem; I += B.n_rows) code += cusolverSpScsrlsvqr(handle,
int(this->n_rows),
int(this->triplet_mat.n_elem), descr, (
float*)d_val_idx, (
int*)d_row_ptr, (
int*)d_col_idx, (
float*)d_b + I,
float(this->setting.tolerance), 3, (
float*)d_x + I, &singularity);
// Reuse `residual` as the host buffer for the float correction.
// NOTE(review): a stream synchronize would suffice here.
179 cudaMemcpyAsync(residual.memptr(), d_x, n_rhs, cudaMemcpyDeviceToHost, stream);
181 cudaDeviceSynchronize();
// Undo the scaling and promote the correction back to double.
183 const mat incre = multiplier * conv_to<mat>::from(residual);
187 suanpan_debug(
"Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(full_residual -= this->
operator*(incre)));
// Release the per-solve RHS/solution device buffers.
// NOTE(review): the null checks are redundant — cudaFree(nullptr) is a no-op.
191 if(d_b) cudaFree(d_b);
192 if(d_x) cudaFree(d_x);
A SparseMatCUDA class that holds matrices.
A SparseMat class that holds matrices.
Definition: SparseMat.hpp:34
Definition: suanPan.h:318
double norm(const vec &)
Definition: tensor.cpp:302
#define suanpan_debug(...)
Definition: suanPan.h:295
constexpr auto SUANPAN_SUCCESS
Definition: suanPan.h:162
constexpr auto SUANPAN_FAIL
Definition: suanPan.h:163