suanPan
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages Concepts
SparseMatSuperLU.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2017-2025 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
29// ReSharper disable CppCStyleCast
30#ifndef SPARSEMATSUPERLU_HPP
31#define SPARSEMATSUPERLU_HPP
32
33#include "../SparseMat.hpp"
34#include "../csc_form.hpp"
35
36#include <superlu-mt/superlu-mt.h>
37
38template<sp_d T> class SparseMatSuperLU final : public SparseMat<T> {
39 SuperMatrix A{}, L{}, U{}, B{};
40
41#ifndef SUANPAN_SUPERLUMT
42 superlu_options_t options{};
43
44 SuperLUStat_t stat{};
45#else
46 const int ordering_num = 1;
47
48 Gstat_t stat{};
49#endif
50
51 std::vector<T> t_val;
52 std::vector<int> t_row, t_col, perm_r, perm_c;
53
54 bool allocated = false;
55
56 auto init_config();
57
58 template<sp_d ET> void alloc(csc_form<ET, int>&&);
59 void dealloc();
60
61 template<sp_d ET> void wrap_b(const Mat<ET>&);
62 template<sp_d ET> void tri_solve(int&);
63 template<sp_d ET> void full_solve(int&);
64
65 int solve_trs(Mat<T>&, Mat<T>&&);
66
67protected:
68 int direct_solve(Mat<T>& out_mat, const Mat<T>& in_mat) override { return this->direct_solve(out_mat, Mat<T>(in_mat)); }
69
70 int direct_solve(Mat<T>&, Mat<T>&&) override;
71
72public:
73 SparseMatSuperLU(uword, uword, uword = 0);
75 SparseMatSuperLU(SparseMatSuperLU&&) noexcept = delete;
76 SparseMatSuperLU& operator=(const SparseMatSuperLU&) = delete;
77 SparseMatSuperLU& operator=(SparseMatSuperLU&&) noexcept = delete;
78 ~SparseMatSuperLU() override;
79
80 unique_ptr<MetaMat<T>> make_copy() override;
81};
82
83template<sp_d T> auto SparseMatSuperLU<T>::init_config() {
84#ifndef SUANPAN_SUPERLUMT
85 set_default_options(&options);
86 options.IterRefine = std::is_same_v<T, float> ? superlu::IterRefine_t::SLU_SINGLE : superlu::IterRefine_t::SLU_DOUBLE;
87 options.Equil = superlu::yes_no_t::NO;
88
89 StatInit(&stat);
90#else
91 StatAlloc(static_cast<int>(this->n_cols), SUANPAN_NUM_THREADS, sp_ienv(1), sp_ienv(2), &stat);
92 StatInit(static_cast<int>(this->n_cols), SUANPAN_NUM_THREADS, &stat);
93#endif
94}
95
96template<sp_d T> template<sp_d ET> void SparseMatSuperLU<T>::alloc(csc_form<ET, int>&& in) {
97 dealloc();
98
99 t_row = std::vector<int>(in.row_mem(), in.row_mem() + in.n_elem);
100 t_col = std::vector<int>(in.col_mem(), in.col_mem() + in.n_cols + 1);
101 t_val = std::vector<ET>(in.val_mem(), in.val_mem() + in.n_elem);
102
103 if constexpr(std::is_same_v<ET, double>) {
104 using E = double;
105 dCreate_CompCol_Matrix(&A, in.n_rows, in.n_cols, in.n_elem, (E*)t_val.data(), t_row.data(), t_col.data(), Stype_t::SLU_NC, Dtype_t::SLU_D, Mtype_t::SLU_GE);
106 }
107 else {
108 using E = float;
109 sCreate_CompCol_Matrix(&A, in.n_rows, in.n_cols, in.n_elem, (E*)t_val.data(), t_row.data(), t_col.data(), Stype_t::SLU_NC, Dtype_t::SLU_S, Mtype_t::SLU_GE);
110 }
111
112 perm_r = std::vector<int>(this->n_rows + 1);
113 perm_c = std::vector<int>(this->n_cols + 1);
114
115 allocated = true;
116}
117
118template<sp_d T> void SparseMatSuperLU<T>::dealloc() {
119 if(!allocated) return;
120
121 Destroy_SuperMatrix_Store(&A);
122#ifdef SUANPAN_SUPERLUMT
123 Destroy_SuperNode_SCP(&L);
124 Destroy_CompCol_NCP(&U);
125#else
126 Destroy_SuperNode_Matrix(&L);
127 Destroy_CompCol_Matrix(&U);
128#endif
129
130 allocated = false;
131}
132
133template<sp_d T> template<sp_d ET> void SparseMatSuperLU<T>::wrap_b(const Mat<ET>& in_mat) {
134 if constexpr(std::is_same_v<ET, float>) {
135 using E = float;
136 sCreate_Dense_Matrix(&B, (int)in_mat.n_rows, (int)in_mat.n_cols, (E*)in_mat.memptr(), (int)in_mat.n_rows, Stype_t::SLU_DN, Dtype_t::SLU_S, Mtype_t::SLU_GE);
137 }
138 else {
139 using E = double;
140 dCreate_Dense_Matrix(&B, (int)in_mat.n_rows, (int)in_mat.n_cols, (E*)in_mat.memptr(), (int)in_mat.n_rows, Stype_t::SLU_DN, Dtype_t::SLU_D, Mtype_t::SLU_GE);
141 }
142}
143
144template<sp_d T> template<sp_d ET> void SparseMatSuperLU<T>::tri_solve(int& flag) {
145#ifdef SUANPAN_SUPERLUMT
146 if(std::is_same_v<ET, float>) sgstrs(NOTRANS, &L, &U, perm_c.data(), perm_r.data(), &B, &stat, &flag);
147 else dgstrs(NOTRANS, &L, &U, perm_c.data(), perm_r.data(), &B, &stat, &flag);
148#else
149 superlu::gstrs<ET>(options.Trans, &L, &U, perm_c.data(), perm_r.data(), &B, &stat, &flag);
150#endif
151
152 Destroy_SuperMatrix_Store(&B);
153}
154
155template<sp_d T> template<sp_d ET> void SparseMatSuperLU<T>::full_solve(int& flag) {
156#ifdef SUANPAN_SUPERLUMT
157 get_perm_c(ordering_num, &A, perm_c.data());
158 if(std::is_same_v<ET, float>) psgssv(SUANPAN_NUM_THREADS, &A, perm_c.data(), perm_r.data(), &L, &U, &B, &flag);
159 else pdgssv(SUANPAN_NUM_THREADS, &A, perm_c.data(), perm_r.data(), &L, &U, &B, &flag);
160#else
161 superlu::gssv<ET>(&options, &A, perm_c.data(), perm_r.data(), &L, &U, &B, &stat, &flag);
162#endif
163
164 Destroy_SuperMatrix_Store(&B);
165}
166
167template<sp_d T> SparseMatSuperLU<T>::SparseMatSuperLU(const uword in_row, const uword in_col, const uword in_elem)
168 : SparseMat<T>(in_row, in_col, in_elem) { init_config(); }
169
171 : SparseMat<T>(other) {
172 init_config();
173 this->factored = false;
174}
175
177 dealloc();
178 StatFree(&stat);
179}
180
181template<sp_d T> unique_ptr<MetaMat<T>> SparseMatSuperLU<T>::make_copy() { return std::make_unique<SparseMatSuperLU>(*this); }
182
183template<sp_d T> int SparseMatSuperLU<T>::direct_solve(Mat<T>& out_mat, Mat<T>&& in_mat) {
184 if(this->factored) return solve_trs(out_mat, std::move(in_mat));
185
186 this->factored = true;
187
188 alloc(csc_form<T, int>(this->triplet_mat));
189
190 wrap_b(in_mat);
191
192 auto flag = 0;
193
194 full_solve<T>(flag);
195
196 out_mat = std::move(in_mat);
197
198 return flag;
199}
200
201template<sp_d T> int SparseMatSuperLU<T>::solve_trs(Mat<T>& out_mat, Mat<T>&& in_mat) {
202 wrap_b(in_mat);
203
204 auto flag = 0;
205
206 tri_solve<T>(flag);
207
208 out_mat = std::move(in_mat);
209
210 return flag;
211}
212#endif
213
A MetaMat class that holds matrices.
Definition MetaMat.hpp:74
const uword n_cols
Definition MetaMat.hpp:117
const uword n_rows
Definition MetaMat.hpp:116
bool factored
Definition MetaMat.hpp:76
A SparseMat class that holds matrices.
Definition SparseMat.hpp:34
A SparseMatSuperLU class that holds matrices.
Definition SparseMatSuperLU.hpp:38
SparseMatSuperLU(SparseMatSuperLU &&) noexcept=delete
int direct_solve(Mat< T > &out_mat, const Mat< T > &in_mat) override
Definition SparseMatSuperLU.hpp:68
Definition csc_form.hpp:25
int SUANPAN_NUM_THREADS
Definition command.cpp:70
Definition suanPan.h:394
~SparseMatSuperLU() override
Definition SparseMatSuperLU.hpp:176
unique_ptr< MetaMat< T > > make_copy() override
Definition SparseMatSuperLU.hpp:181
SparseMatSuperLU(uword, uword, uword=0)
Definition SparseMatSuperLU.hpp:167