suanPan
Loading...
Searching...
No Matches
SparseMatSuperLU.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2017-2024 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
29// ReSharper disable CppCStyleCast
30#ifndef SPARSEMATSUPERLU_HPP
31#define SPARSEMATSUPERLU_HPP
32
33#include <superlu-mt/superlu-mt.h>
34#include "SparseMat.hpp"
35#include "csc_form.hpp"
36
37template<sp_d T> class SparseMatSuperLU final : public SparseMat<T> {
38 SuperMatrix A{}, L{}, U{}, B{};
39
40#ifndef SUANPAN_SUPERLUMT
41 superlu_options_t options{};
42
43 SuperLUStat_t stat{};
44#else
45 const int ordering_num = 1;
46
47 Gstat_t stat{};
48#endif
49
50 void* t_val = nullptr;
51 int* t_row = nullptr;
52 int* t_col = nullptr;
53
54 int* perm_r = nullptr;
55 int* perm_c = nullptr;
56
57 bool allocated = false;
58
59 template<sp_d ET> void alloc(csc_form<ET, int>&&);
60 void dealloc();
61
62 template<sp_d ET> void wrap_b(const Mat<ET>&);
63 template<sp_d ET> void tri_solve(int&);
64 template<sp_d ET> void full_solve(int&);
65
66 int solve_trs(Mat<T>&, Mat<T>&&);
67
68protected:
69 int direct_solve(Mat<T>& out_mat, const Mat<T>& in_mat) override { return this->direct_solve(out_mat, Mat<T>(in_mat)); }
70
71 int direct_solve(Mat<T>&, Mat<T>&&) override;
72
73public:
74 SparseMatSuperLU(uword, uword, uword = 0);
76 SparseMatSuperLU(SparseMatSuperLU&&) noexcept = delete;
77 SparseMatSuperLU& operator=(const SparseMatSuperLU&) = delete;
78 SparseMatSuperLU& operator=(SparseMatSuperLU&&) noexcept = delete;
79 ~SparseMatSuperLU() override;
80
81 void zeros() override;
82
83 unique_ptr<MetaMat<T>> make_copy() override;
84};
85
86template<sp_d T> template<sp_d ET> void SparseMatSuperLU<T>::alloc(csc_form<ET, int>&& in) {
87 dealloc();
88
89 auto t_size = sizeof(ET) * in.n_elem;
90 t_val = superlu_malloc(t_size);
91 memcpy(t_val, (void*)in.val_mem(), t_size);
92
93 t_size = sizeof(int) * in.n_elem;
94 t_row = (int*)superlu_malloc(t_size);
95 memcpy(t_row, (void*)in.row_mem(), t_size);
96
97 t_size = sizeof(int) * (in.n_cols + 1llu);
98 t_col = (int*)superlu_malloc(t_size);
99 memcpy(t_col, (void*)in.col_mem(), t_size);
100
101 if constexpr(std::is_same_v<ET, double>) {
102 using E = double;
103 dCreate_CompCol_Matrix(&A, in.n_rows, in.n_cols, in.n_elem, (E*)t_val, t_row, t_col, Stype_t::SLU_NC, Dtype_t::SLU_D, Mtype_t::SLU_GE);
104 }
105 else {
106 using E = float;
107 sCreate_CompCol_Matrix(&A, in.n_rows, in.n_cols, in.n_elem, (E*)t_val, t_row, t_col, Stype_t::SLU_NC, Dtype_t::SLU_S, Mtype_t::SLU_GE);
108 }
109
110 perm_r = (int*)superlu_malloc(sizeof(int) * (this->n_rows + 1));
111 perm_c = (int*)superlu_malloc(sizeof(int) * (this->n_cols + 1));
112
113 allocated = true;
114}
115
116template<sp_d T> void SparseMatSuperLU<T>::dealloc() {
117 if(!allocated) return;
118
119 Destroy_SuperMatrix_Store(&A);
120#ifdef SUANPAN_SUPERLUMT
121 Destroy_SuperNode_SCP(&L);
122 Destroy_CompCol_NCP(&U);
123#else
124 Destroy_SuperNode_Matrix(&L);
125 Destroy_CompCol_Matrix(&U);
126#endif
127
128 if(t_val) superlu_free(t_val);
129 if(t_row) superlu_free(t_row);
130 if(t_col) superlu_free(t_col);
131 if(perm_r) superlu_free(perm_r);
132 if(perm_c) superlu_free(perm_c);
133
134 allocated = false;
135}
136
137template<sp_d T> template<sp_d ET> void SparseMatSuperLU<T>::wrap_b(const Mat<ET>& in_mat) {
138 if constexpr(std::is_same_v<ET, float>) {
139 using E = float;
140 sCreate_Dense_Matrix(&B, (int)in_mat.n_rows, (int)in_mat.n_cols, (E*)in_mat.memptr(), (int)in_mat.n_rows, Stype_t::SLU_DN, Dtype_t::SLU_S, Mtype_t::SLU_GE);
141 }
142 else {
143 using E = double;
144 dCreate_Dense_Matrix(&B, (int)in_mat.n_rows, (int)in_mat.n_cols, (E*)in_mat.memptr(), (int)in_mat.n_rows, Stype_t::SLU_DN, Dtype_t::SLU_D, Mtype_t::SLU_GE);
145 }
146}
147
148template<sp_d T> template<sp_d ET> void SparseMatSuperLU<T>::tri_solve(int& flag) {
149#ifdef SUANPAN_SUPERLUMT
150 if(std::is_same_v<ET, float>) sgstrs(NOTRANS, &L, &U, perm_c, perm_r, &B, &stat, &flag);
151 else dgstrs(NOTRANS, &L, &U, perm_c, perm_r, &B, &stat, &flag);
152#else
153 superlu::gstrs<ET>(options.Trans, &L, &U, perm_c, perm_r, &B, &stat, &flag);
154#endif
155
156 Destroy_SuperMatrix_Store(&B);
157}
158
159template<sp_d T> template<sp_d ET> void SparseMatSuperLU<T>::full_solve(int& flag) {
160#ifdef SUANPAN_SUPERLUMT
161 get_perm_c(ordering_num, &A, perm_c);
162 if(std::is_same_v<ET, float>) psgssv(SUANPAN_NUM_THREADS, &A, perm_c, perm_r, &L, &U, &B, &flag);
163 else pdgssv(SUANPAN_NUM_THREADS, &A, perm_c, perm_r, &L, &U, &B, &flag);
164#else
165 superlu::gssv<ET>(&options, &A, perm_c, perm_r, &L, &U, &B, &stat, &flag);
166#endif
167
168 Destroy_SuperMatrix_Store(&B);
169}
170
171template<sp_d T> SparseMatSuperLU<T>::SparseMatSuperLU(const uword in_row, const uword in_col, const uword in_elem)
172 : SparseMat<T>(in_row, in_col, in_elem) {
173#ifndef SUANPAN_SUPERLUMT
174 set_default_options(&options);
175 options.IterRefine = std::is_same_v<T, float> ? superlu::IterRefine_t::SLU_SINGLE : superlu::IterRefine_t::SLU_DOUBLE;
176 options.Equil = superlu::yes_no_t::NO;
177
178 arrayops::fill_zeros(reinterpret_cast<char*>(&stat), sizeof(SuperLUStat_t));
179
180 StatInit(&stat);
181#else
182 StatAlloc(static_cast<int>(in_col), SUANPAN_NUM_THREADS, sp_ienv(1), sp_ienv(2), &stat);
183 StatInit(static_cast<int>(in_col), SUANPAN_NUM_THREADS, &stat);
184#endif
185}
186
188 : SparseMat<T>(other) {
189#ifndef SUANPAN_SUPERLUMT
190 set_default_options(&options);
191 options.IterRefine = std::is_same_v<T, float> ? superlu::IterRefine_t::SLU_SINGLE : superlu::IterRefine_t::SLU_DOUBLE;
192 options.Equil = superlu::yes_no_t::NO;
193
194 arrayops::fill_zeros(reinterpret_cast<char*>(&stat), sizeof(SuperLUStat_t));
195
196 StatInit(&stat);
197#else
198 StatAlloc(static_cast<int>(other.n_cols), SUANPAN_NUM_THREADS, sp_ienv(1), sp_ienv(2), &stat);
199 StatInit(static_cast<int>(other.n_cols), SUANPAN_NUM_THREADS, &stat);
200#endif
201}
202
204 dealloc();
205 StatFree(&stat);
206}
207
208template<sp_d T> void SparseMatSuperLU<T>::zeros() {
210 dealloc();
211}
212
213template<sp_d T> unique_ptr<MetaMat<T>> SparseMatSuperLU<T>::make_copy() { return std::make_unique<SparseMatSuperLU>(*this); }
214
215template<sp_d T> int SparseMatSuperLU<T>::direct_solve(Mat<T>& out_mat, Mat<T>&& in_mat) {
216 if(this->factored) return solve_trs(out_mat, std::forward<Mat<T>>(in_mat));
217
218 this->factored = true;
219
220 auto flag = 0;
221
222 if constexpr(std::is_same_v<T, float>) {
223 alloc(csc_form<float, int>(this->triplet_mat));
224
225 wrap_b(in_mat);
226
227 full_solve<float>(flag);
228
229 out_mat = std::move(in_mat);
230 }
231 else if(Precision::FULL == this->setting.precision) {
232 alloc(csc_form<double, int>(this->triplet_mat));
233
234 wrap_b(in_mat);
235
236 full_solve<double>(flag);
237
238 out_mat = std::move(in_mat);
239 }
240 else {
241 alloc(csc_form<float, int>(this->triplet_mat));
242
243 const fmat f_mat(arma::size(in_mat), fill::none);
244
245 wrap_b(f_mat);
246
247 full_solve<float>(flag);
248
249 if(0 == flag) flag = solve_trs(out_mat, std::forward<Mat<T>>(in_mat));
250 }
251
252 return flag;
253}
254
255template<sp_d T> int SparseMatSuperLU<T>::solve_trs(Mat<T>& out_mat, Mat<T>&& in_mat) {
256 auto flag = 0;
257
258 if constexpr(std::is_same_v<T, float>) {
259 wrap_b(in_mat);
260
261 tri_solve<float>(flag);
262
263 out_mat = std::move(in_mat);
264 }
265 else if(Precision::FULL == this->setting.precision) {
266 wrap_b(in_mat);
267
268 tri_solve<double>(flag);
269
270 out_mat = std::move(in_mat);
271 }
272 else {
273 out_mat.zeros(arma::size(in_mat));
274
275 auto multiplier = arma::norm(in_mat);
276
277 auto counter = 0u;
278 while(counter++ < this->setting.iterative_refinement) {
279 if(multiplier < this->setting.tolerance) break;
280
281 auto residual = conv_to<fmat>::from(in_mat / multiplier);
282
283 wrap_b(residual);
284
285 tri_solve<float>(flag);
286
287 if(0 != flag) break;
288
289 const mat incre = multiplier * conv_to<mat>::from(residual);
290
291 out_mat += incre;
292
293 suanpan_debug("Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(in_mat -= this->operator*(incre)));
294 }
295 }
296
297 return flag;
298}
299#endif
300
A MetaMat class that holds matrices.
Definition MetaMat.hpp:72
const uword n_cols
Definition MetaMat.hpp:119
const uword n_rows
Definition MetaMat.hpp:118
A SparseMat class that holds matrices.
Definition SparseMat.hpp:34
void zeros() override
Definition SparseMat.hpp:46
A SparseMatSuperLU class that holds matrices.
Definition SparseMatSuperLU.hpp:37
SparseMatSuperLU(SparseMatSuperLU &&) noexcept=delete
int direct_solve(Mat< T > &out_mat, const Mat< T > &in_mat) override
Definition SparseMatSuperLU.hpp:69
Definition csc_form.hpp:25
int SUANPAN_NUM_THREADS
Definition command.cpp:71
Definition suanPan.h:330
~SparseMatSuperLU() override
Definition SparseMatSuperLU.hpp:203
unique_ptr< MetaMat< T > > make_copy() override
Definition SparseMatSuperLU.hpp:213
void zeros() override
Definition SparseMatSuperLU.hpp:208
SparseMatSuperLU(uword, uword, uword=0)
Definition SparseMatSuperLU.hpp:171
#define suanpan_debug(...)
Definition suanPan.h:307