30#ifndef BANDMATCUDA_HPP
31#define BANDMATCUDA_HPP
36#include <cusolverSp.h>
41 cusolverSpHandle_t handle =
nullptr;
42 cudaStream_t stream =
nullptr;
43 cusparseMatDescr_t descr =
nullptr;
45 void* d_val_idx =
nullptr;
46 void* d_col_idx =
nullptr;
47 void* d_row_ptr =
nullptr;
49 triplet_form<float, int> s_mat{
static_cast<int>(this->n_rows),
static_cast<int>(this->n_cols),
static_cast<int>(this->n_elem)};
52 cusolverSpCreate(&handle);
53 cudaStreamCreate(&stream);
54 cusolverSpSetStream(handle, stream);
55 cusparseCreateMatDescr(&descr);
56 cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL);
57 cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO);
58 cudaMalloc(&d_row_ptr,
sizeof(
int) * (this->n_rows + 1));
61 void release()
const {
62 if(handle) cusolverSpDestroy(handle);
63 if(stream) cudaStreamDestroy(stream);
64 if(descr) cusparseDestroyMatDescr(descr);
65 if(d_row_ptr) cudaFree(d_row_ptr);
69 const size_t n_val =
sizeof(float) * csr_mat.n_elem;
70 const size_t n_col =
sizeof(
int) * csr_mat.n_elem;
72 cudaMalloc(&d_val_idx, n_val);
73 cudaMalloc(&d_col_idx, n_col);
75 cudaMemcpyAsync(d_val_idx, csr_mat.val_mem(), n_val, cudaMemcpyHostToDevice, stream);
76 cudaMemcpyAsync(d_col_idx, csr_mat.col_mem(), n_col, cudaMemcpyHostToDevice, stream);
77 cudaMemcpyAsync(d_row_ptr, csr_mat.row_mem(),
sizeof(
int) * (csr_mat.n_rows + 1llu), cudaMemcpyHostToDevice, stream);
80 void device_dealloc()
const {
81 if(d_val_idx) cudaFree(d_val_idx);
82 if(d_col_idx) cudaFree(d_col_idx);
88 int direct_solve(Mat<T>&, Mat<T>&&)
override;
91 BandMatCUDA(
const uword in_size,
const uword in_l,
const uword in_u)
92 :
BandMat<
T>(in_size, in_l, in_u) { acquire(); }
106 unique_ptr<MetaMat<T>>
make_copy()
override {
return std::make_unique<BandMatCUDA>(*
this); }
110 suanpan_assert([&] {
if(this->n_rows != this->n_cols)
throw invalid_argument(
"requires a square matrix"); });
112 if(!this->factored) {
113 this->factored =
true;
118 for(
auto I = 0; I < static_cast<int>(this->n_rows); ++I)
for(
auto J = std::max(0, I -
static_cast<int>(this->u_band)); J <= std::min(static_cast<int>(this->n_rows) - 1, I +
static_cast<int>(this->l_band)); ++J) s_mat.at(J, I) =
static_cast<float>(this->at(J, I));
123 const size_t n_rhs =
sizeof(float) * B.n_elem;
128 cudaMalloc(&d_b, n_rhs);
129 cudaMalloc(&d_x, n_rhs);
131 auto INFO = this->mixed_trs(X, std::move(B), [&](fmat& residual) {
132 cudaMemcpyAsync(d_b, residual.memptr(), n_rhs, cudaMemcpyHostToDevice, stream);
137 for(
auto I = 0llu; I < residual.n_elem; I += residual.n_rows) code += cusolverSpScsrlsvqr(handle,
static_cast<int>(this->n_rows),
static_cast<int>(this->s_mat.n_elem), descr, (
float*)d_val_idx, (
int*)d_row_ptr, (
int*)d_col_idx, (
float*)d_b + I,
static_cast<float>(this->setting.tolerance), 3, (
float*)d_x + I, &singularity);
139 cudaMemcpyAsync(residual.memptr(), d_x, n_rhs, cudaMemcpyDeviceToHost, stream);
141 cudaDeviceSynchronize();
147 suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
A BandMatCUDA class that holds matrices.
A BandMat class that holds matrices.
Definition: BandMat.hpp:35
unique_ptr< Material > make_copy(const shared_ptr< Material > &)
Definition: Material.cpp:370
void suanpan_assert(const std::function< void()> &F)
Definition: suanPan.h:296
#define suanpan_error(...)
Definition: suanPan.h:309