suanPan
Loading...
Searching...
No Matches
BandMatCUDA.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2017-2023 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
29// ReSharper disable CppCStyleCast
30#ifndef BANDMATCUDA_HPP
31#define BANDMATCUDA_HPP
32
33#ifdef SUANPAN_CUDA
34
35#include "BandMat.hpp"
36#include <cusolverSp.h>
37#include <cusparse.h>
38#include "csr_form.hpp"
39
40template<sp_d T> class BandMatCUDA final : public BandMat<T> {
41 cusolverSpHandle_t handle = nullptr;
42 cudaStream_t stream = nullptr;
43 cusparseMatDescr_t descr = nullptr;
44
45 void* d_val_idx = nullptr;
46 void* d_col_idx = nullptr;
47 void* d_row_ptr = nullptr;
48
49 triplet_form<float, int> s_mat{static_cast<int>(this->n_rows), static_cast<int>(this->n_cols), static_cast<int>(this->n_elem)};
50
51 void acquire() {
52 cusolverSpCreate(&handle);
53 cudaStreamCreate(&stream);
54 cusolverSpSetStream(handle, stream);
55 cusparseCreateMatDescr(&descr);
56 cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL);
57 cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO);
58 cudaMalloc(&d_row_ptr, sizeof(int) * (this->n_rows + 1));
59 }
60
61 void release() const {
62 if(handle) cusolverSpDestroy(handle);
63 if(stream) cudaStreamDestroy(stream);
64 if(descr) cusparseDestroyMatDescr(descr);
65 if(d_row_ptr) cudaFree(d_row_ptr);
66 }
67
68 void device_alloc(csr_form<float, int>&& csr_mat) {
69 const size_t n_val = sizeof(float) * csr_mat.n_elem;
70 const size_t n_col = sizeof(int) * csr_mat.n_elem;
71
72 cudaMalloc(&d_val_idx, n_val);
73 cudaMalloc(&d_col_idx, n_col);
74
75 cudaMemcpyAsync(d_val_idx, csr_mat.val_mem(), n_val, cudaMemcpyHostToDevice, stream);
76 cudaMemcpyAsync(d_col_idx, csr_mat.col_mem(), n_col, cudaMemcpyHostToDevice, stream);
77 cudaMemcpyAsync(d_row_ptr, csr_mat.row_mem(), sizeof(int) * (csr_mat.n_rows + 1llu), cudaMemcpyHostToDevice, stream);
78 }
79
80 void device_dealloc() const {
81 if(d_val_idx) cudaFree(d_val_idx);
82 if(d_col_idx) cudaFree(d_col_idx);
83 }
84
85protected:
86 using BandMat<T>::direct_solve;
87
88 int direct_solve(Mat<T>&, Mat<T>&&) override;
89
90public:
91 BandMatCUDA(const uword in_size, const uword in_l, const uword in_u)
92 : BandMat<T>(in_size, in_l, in_u) { acquire(); }
93
94 BandMatCUDA(const BandMatCUDA& other)
95 : BandMat<T>(other) { acquire(); }
96
97 BandMatCUDA(BandMatCUDA&&) noexcept = delete;
98 BandMatCUDA& operator=(const BandMatCUDA&) = delete;
99 BandMatCUDA& operator=(BandMatCUDA&&) noexcept = delete;
100
101 ~BandMatCUDA() override {
102 release();
103 device_dealloc();
104 }
105
106 unique_ptr<MetaMat<T>> make_copy() override { return std::make_unique<BandMatCUDA>(*this); }
107};
108
109template<sp_d T> int BandMatCUDA<T>::direct_solve(Mat<T>& X, Mat<T>&& B) {
110 suanpan_assert([&] { if(this->n_rows != this->n_cols) throw invalid_argument("requires a square matrix"); });
111
112 if(!this->factored) {
113 this->factored = true;
114
115 device_dealloc();
116
117 s_mat.zeros();
118 for(auto I = 0; I < static_cast<int>(this->n_rows); ++I) for(auto J = std::max(0, I - static_cast<int>(this->u_band)); J <= std::min(static_cast<int>(this->n_rows) - 1, I + static_cast<int>(this->l_band)); ++J) s_mat.at(J, I) = static_cast<float>(this->at(J, I));
119
120 device_alloc(csr_form<float, int>(s_mat));
121 }
122
123 const size_t n_rhs = sizeof(float) * B.n_elem;
124
125 void* d_b = nullptr;
126 void* d_x = nullptr;
127
128 cudaMalloc(&d_b, n_rhs);
129 cudaMalloc(&d_x, n_rhs);
130
131 auto INFO = this->mixed_trs(X, std::move(B), [&](fmat& residual) {
132 cudaMemcpyAsync(d_b, residual.memptr(), n_rhs, cudaMemcpyHostToDevice, stream);
133
134 int singularity;
135
136 auto code = 0;
137 for(auto I = 0llu; I < residual.n_elem; I += residual.n_rows) code += cusolverSpScsrlsvqr(handle, static_cast<int>(this->n_rows), static_cast<int>(this->s_mat.n_elem), descr, (float*)d_val_idx, (int*)d_row_ptr, (int*)d_col_idx, (float*)d_b + I, static_cast<float>(this->setting.tolerance), 3, (float*)d_x + I, &singularity);
138
139 cudaMemcpyAsync(residual.memptr(), d_x, n_rhs, cudaMemcpyDeviceToHost, stream);
140
141 cudaDeviceSynchronize();
142
143 return code;
144 });
145
146 if(0 != INFO)
147 suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
148
149 return INFO;
150}
151
152#endif
153
154#endif
155
A BandMatCUDA class that holds matrices.
A BandMat class that holds matrices.
Definition BandMat.hpp:35
Definition csr_form.hpp:25
Definition triplet_form.hpp:62
unique_ptr< Material > make_copy(const shared_ptr< Material > &)
Definition Material.cpp:356
void suanpan_assert(const std::function< void()> &F)
Definition suanPan.h:296
#define suanpan_error(...)
Definition suanPan.h:309