suanPan
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages Concepts
SparseMatCUDA.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2017-2025 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
31// ReSharper disable CppCStyleCast
32#ifndef SPARSEMATCUDA_HPP
33#define SPARSEMATCUDA_HPP
34
35#include "../SparseMat.hpp"
36#include "../csr_form.hpp"
37
38#include <cusolverSp.h>
39#include <cusparse.h>
40
41template<sp_d T> class SparseMatCUDA final : public SparseMat<T> {
42 cusolverSpHandle_t handle = nullptr;
43 cudaStream_t stream = nullptr;
44 cusparseMatDescr_t descr = nullptr;
45
46 void* d_val_idx = nullptr;
47 void* d_col_idx = nullptr;
48 void* d_row_ptr = nullptr;
49
50 void acquire() {
51 cusolverSpCreate(&handle);
52 cudaStreamCreate(&stream);
53 cusolverSpSetStream(handle, stream);
54 cusparseCreateMatDescr(&descr);
55 cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL);
56 cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO);
57 cudaMalloc(&d_row_ptr, sizeof(int) * (this->n_rows + 1));
58 }
59
60 void release() const {
61 if(handle) cusolverSpDestroy(handle);
62 if(stream) cudaStreamDestroy(stream);
63 if(descr) cusparseDestroyMatDescr(descr);
64 if(d_row_ptr) cudaFree(d_row_ptr);
65 }
66
67 template<sp_d ET> void device_alloc(csr_form<ET, int>&& csr_mat) {
68 const size_t n_val = sizeof(ET) * csr_mat.n_elem;
69 const size_t n_col = sizeof(int) * csr_mat.n_elem;
70
71 cudaMalloc(&d_val_idx, n_val);
72 cudaMalloc(&d_col_idx, n_col);
73
74 cudaMemcpyAsync(d_val_idx, csr_mat.val_mem(), n_val, cudaMemcpyHostToDevice, stream);
75 cudaMemcpyAsync(d_col_idx, csr_mat.col_mem(), n_col, cudaMemcpyHostToDevice, stream);
76 cudaMemcpyAsync(d_row_ptr, csr_mat.row_mem(), sizeof(int) * (csr_mat.n_rows + 1llu), cudaMemcpyHostToDevice, stream);
77 }
78
79 void device_dealloc() const {
80 if(d_val_idx) cudaFree(d_val_idx);
81 if(d_col_idx) cudaFree(d_col_idx);
82 }
83
84protected:
86
87 int direct_solve(Mat<T>&, const Mat<T>&) override;
88
89public:
90 SparseMatCUDA(const uword in_row, const uword in_col, const uword in_elem = 0)
91 : SparseMat<T>(in_row, in_col, in_elem) { acquire(); }
92
94 : SparseMat<T>(other) { acquire(); }
95
    // Move construction and all assignments are disabled. NOTE(review): presumably because the raw
    // handles and device pointers are released in the destructor, so a shallow transfer would free them twice.
    SparseMatCUDA(SparseMatCUDA&&) noexcept = delete;
    SparseMatCUDA& operator=(const SparseMatCUDA&) = delete;
    SparseMatCUDA& operator=(SparseMatCUDA&&) noexcept = delete;
99
100 ~SparseMatCUDA() override {
101 release();
102 device_dealloc();
103 }
104
105 unique_ptr<MetaMat<T>> make_copy() override { return std::make_unique<SparseMatCUDA>(*this); }
106};
107
108template<sp_d T> int SparseMatCUDA<T>::direct_solve(Mat<T>& X, const Mat<T>& B) {
109 if(!this->factored) {
110 // deallocate memory previously allocated for csr matrix
111 device_dealloc();
112
113 std::is_same_v<T, float> || Precision::MIXED == this->setting.precision ? device_alloc(csr_form<float, int>(this->triplet_mat)) : device_alloc(csr_form<double, int>(this->triplet_mat));
114
115 this->factored = true;
116 }
117
118 const size_t n_rhs = (std::is_same_v<T, float> || Precision::MIXED == this->setting.precision ? sizeof(float) : sizeof(double)) * B.n_elem;
119
120 void* d_b = nullptr;
121 void* d_x = nullptr;
122
123 cudaMalloc(&d_b, n_rhs);
124 cudaMalloc(&d_x, n_rhs);
125
126 int singularity;
127 auto code = 0;
128
129 if constexpr(std::is_same_v<T, float>) {
130 cudaMemcpyAsync(d_b, B.memptr(), n_rhs, cudaMemcpyHostToDevice, stream);
131
132 for(auto I = 0llu; I < B.n_elem; I += B.n_rows) code += cusolverSpScsrlsvqr(handle, int(this->n_rows), int(this->triplet_mat.n_elem), descr, (float*)d_val_idx, (int*)d_row_ptr, (int*)d_col_idx, (float*)d_b + I, float(this->setting.tolerance), 3, (float*)d_x + I, &singularity);
133
134 X.set_size(arma::size(B));
135
136 cudaMemcpyAsync(X.memptr(), d_x, n_rhs, cudaMemcpyDeviceToHost, stream);
137
138 cudaDeviceSynchronize();
139 }
140 else if(Precision::FULL == this->setting.precision) {
141 cudaMemcpyAsync(d_b, B.memptr(), n_rhs, cudaMemcpyHostToDevice, stream);
142
143 for(auto I = 0llu; I < B.n_elem; I += B.n_rows) code += cusolverSpDcsrlsvqr(handle, int(this->n_rows), int(this->triplet_mat.n_elem), descr, (double*)d_val_idx, (int*)d_row_ptr, (int*)d_col_idx, (double*)d_b + I, this->setting.tolerance, 3, (double*)d_x + I, &singularity);
144
145 X.set_size(arma::size(B));
146
147 cudaMemcpyAsync(X.memptr(), d_x, n_rhs, cudaMemcpyDeviceToHost, stream);
148
149 cudaDeviceSynchronize();
150 }
151 else {
152 X = arma::zeros(arma::size(B));
153
154 mat full_residual = B;
155
156 auto multiplier = norm(full_residual);
157
158 auto counter = std::uint8_t{0};
159 while(counter++ < this->setting.iterative_refinement) {
160 if(multiplier < this->setting.tolerance) break;
161
162 auto residual = conv_to<fmat>::from(full_residual / multiplier);
163
164 cudaMemcpyAsync(d_b, residual.memptr(), n_rhs, cudaMemcpyHostToDevice, stream);
165
166 code = 0;
167 for(auto I = 0llu; I < B.n_elem; I += B.n_rows) code += cusolverSpScsrlsvqr(handle, int(this->n_rows), int(this->triplet_mat.n_elem), descr, (float*)d_val_idx, (int*)d_row_ptr, (int*)d_col_idx, (float*)d_b + I, float(this->setting.tolerance), 3, (float*)d_x + I, &singularity);
168 if(0 != code) break;
169
170 cudaMemcpyAsync(residual.memptr(), d_x, n_rhs, cudaMemcpyDeviceToHost, stream);
171
172 cudaDeviceSynchronize();
173
174 const mat incre = multiplier * conv_to<mat>::from(residual);
175
176 X += incre;
177
178 suanpan_debug("Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(full_residual -= this->operator*(incre)));
179 }
180 }
181
182 if(d_b) cudaFree(d_b);
183 if(d_x) cudaFree(d_x);
184
185 return 0 == code ? SUANPAN_SUCCESS : SUANPAN_FAIL;
186}
187
188#endif
189
const uword n_rows
Definition MetaMat.hpp:116
A sparse matrix class that solves linear systems on the GPU using cuSOLVER's sparse QR solver.
Definition SparseMatCUDA.hpp:41
SparseMatCUDA(SparseMatCUDA &&) noexcept=delete
unique_ptr< MetaMat< T > > make_copy() override
Definition SparseMatCUDA.hpp:105
SparseMatCUDA(const uword in_row, const uword in_col, const uword in_elem=0)
Definition SparseMatCUDA.hpp:90
SparseMatCUDA(const SparseMatCUDA &other)
Definition SparseMatCUDA.hpp:93
A SparseMat class that holds matrices.
Definition SparseMat.hpp:34
Definition csr_form.hpp:25
int direct_solve(Mat< T > &, const Mat< T > &) override
Definition SparseMatCUDA.hpp:108
#define suanpan_debug(...)
Definition suanPan.h:371
constexpr auto SUANPAN_SUCCESS
Definition suanPan.h:180
constexpr auto SUANPAN_FAIL
Definition suanPan.h:181