suanPan
Loading...
Searching...
No Matches
MetaMat.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2017-2025 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
29#ifndef METAMAT_HPP
30#define METAMAT_HPP
31
32#include "triplet_form.hpp"
33#include "IterativeSolver.hpp"
34#include "ILU.hpp"
35#include "Jacobi.hpp"
36
37template<typename T, typename U> concept ArmaContainer = std::is_floating_point_v<U> && (std::is_convertible_v<T, Mat<U>> || std::is_convertible_v<T, SpMat<U>>) ;
38
39template<sp_d T> class MetaMat;
40
41template<sp_d T> class op_add {
42 friend MetaMat<T>;
43
44 shared_ptr<MetaMat<T>> mat_a, mat_b;
45
46public:
47 explicit op_add(const shared_ptr<MetaMat<T>>& A)
48 : mat_a(A)
49 , mat_b(nullptr) {}
50
51 op_add(const shared_ptr<MetaMat<T>>& A, const shared_ptr<MetaMat<T>>& B)
52 : mat_a(A)
53 , mat_b(B) {}
54};
55
56template<sp_d T> class op_scale {
57 friend MetaMat<T>;
58
59 T scalar;
60 op_add<T> bracket;
61
62public:
63 op_scale(const T A, const shared_ptr<MetaMat<T>>& B)
64 : scalar(A)
65 , bracket(B) {}
66
67 op_scale(const T A, op_add<T>&& B)
68 : scalar(A)
69 , bracket(std::forward<op_add<T>>(B)) {}
70};
71
72template<sp_d T> class MetaMat {
73protected:
74 bool factored = false;
75
77
78 virtual int direct_solve(Mat<T>&, const Mat<T>&) = 0;
79
80 virtual int direct_solve(Mat<T>&, Mat<T>&&) = 0;
81
82 int direct_solve(Mat<T>& X, const SpMat<T>& B) { return this->direct_solve(X, Mat<T>(B)); }
83
84 int direct_solve(Mat<T>& X, SpMat<T>&& B) { return this->direct_solve(X, B); }
85
86 int iterative_solve(Mat<T>&, const Mat<T>&);
87
88 int iterative_solve(Mat<T>& X, const SpMat<T>& B) { return this->iterative_solve(X, Mat<T>(B)); }
89
90 template<std::invocable<fmat&> F> int mixed_trs(mat& X, mat&& B, F trs) {
91 auto INFO = 0;
92
93 X = arma::zeros(size(B));
94
95 auto multiplier = norm(B);
96
97 auto counter = 0u;
98 while(counter++ < this->setting.iterative_refinement) {
99 if(multiplier < this->setting.tolerance) break;
100
101 auto residual = conv_to<fmat>::from(B / multiplier);
102
103 if(0 != (INFO = trs(residual))) break;
104
105 const mat incre = multiplier * conv_to<mat>::from(residual);
106
107 X += incre;
108
109 suanpan_debug("Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(B -= this->operator*(incre)));
110 }
111
112 return INFO;
113 }
114
115public:
117
118 const uword n_rows;
119 const uword n_cols;
120 const uword n_elem;
121
122 MetaMat(const uword in_rows, const uword in_cols, const uword in_elem)
123 : triplet_mat(in_rows, in_cols)
124 , n_rows(in_rows)
125 , n_cols(in_cols)
126 , n_elem(in_elem) {}
127
128 MetaMat(const MetaMat&) = default;
129 MetaMat(MetaMat&&) noexcept = delete;
130 MetaMat& operator=(const MetaMat&) = delete;
131 MetaMat& operator=(MetaMat&&) noexcept = delete;
132 virtual ~MetaMat() = default;
133
135
136 [[nodiscard]] SolverSetting<T>& get_solver_setting() { return setting; }
137
138 void set_factored(const bool F) { factored = F; }
139
140 [[nodiscard]] virtual bool is_empty() const = 0;
141 virtual void zeros() = 0;
142
143 virtual unique_ptr<MetaMat> make_copy() = 0;
144
145 void unify(const uword K) {
146 this->nullify(K);
147 this->at(K, K) = T(1);
148 }
149
150 virtual void nullify(uword) = 0;
151
152 [[nodiscard]] virtual T max() const = 0;
153 [[nodiscard]] virtual Col<T> diag() const = 0;
154
159 virtual T operator()(uword, uword) const = 0;
164 virtual T& unsafe_at(const uword I, const uword J) { return this->at(I, J); }
165
170 virtual T& at(uword, uword) = 0;
171
172 [[nodiscard]] virtual const T* memptr() const = 0;
173 virtual T* memptr() = 0;
174
175 virtual void scale_accu(T, const shared_ptr<MetaMat>&) = 0;
176 virtual void scale_accu(T, const triplet_form<T, uword>&) = 0;
177
178 void operator+=(const shared_ptr<MetaMat>& M) { return this->scale_accu(1., M); }
179
180 void operator-=(const shared_ptr<MetaMat>& M) { return this->scale_accu(-1., M); }
181
182 void operator+=(const op_scale<T>& M) {
183 const auto& bracket = M.bracket;
184 if(nullptr != bracket.mat_a) this->scale_accu(M.scalar, bracket.mat_a);
185 if(nullptr != bracket.mat_b) this->scale_accu(M.scalar, bracket.mat_b);
186 }
187
188 void operator-=(const op_scale<T>& M) {
189 const auto& bracket = M.bracket;
190 if(nullptr != bracket.mat_a) this->scale_accu(-M.scalar, bracket.mat_a);
191 if(nullptr != bracket.mat_b) this->scale_accu(-M.scalar, bracket.mat_b);
192 }
193
194 void operator+=(const triplet_form<T, uword>& M) { return this->scale_accu(1., M); }
195
196 void operator-=(const triplet_form<T, uword>& M) { return this->scale_accu(-1., M); }
197
198 virtual Mat<T> operator*(const Mat<T>&) const = 0;
199
200 virtual void operator*=(T) = 0;
201
202 template<ArmaContainer<T> C> int solve(Mat<T>& X, C&& B) { return IterativeSolver::NONE == this->setting.iterative_solver ? this->direct_solve(X, std::forward<C>(B)) : this->iterative_solve(X, std::forward<C>(B)); }
203
204 template<ArmaContainer<T> C> Mat<T> solve(C&& B) {
205 Mat<T> X;
206
207 if(SUANPAN_SUCCESS != this->solve(X, std::forward<C>(B))) throw std::runtime_error("fail to solve the system");
208
209 return X;
210 }
211
212 [[nodiscard]] virtual int sign_det() const = 0;
213
214 void save(const char* name) {
215 if(!to_mat(*this).save(name, raw_ascii))
216 suanpan_error("Cannot save to file \"{}\".\n", name);
217 }
218
219 virtual void csc_condense() {}
220
221 virtual void csr_condense() {}
222
223 [[nodiscard]] Col<T> evaluate(const Col<T>& X) const { return this->operator*(X); }
224};
225
226template<sp_d T> int MetaMat<T>::iterative_solve(Mat<T>& X, const Mat<T>& B) {
227 this->csc_condense();
228
229 X.zeros(arma::size(B));
230
231 unique_ptr<Preconditioner<T>> preconditioner;
232 if(PreconditionerType::JACOBI == this->setting.preconditioner_type) preconditioner = std::make_unique<Jacobi<T>>(this->diag());
233#ifndef SUANPAN_SUPERLUMT
234 else if(PreconditionerType::ILU == this->setting.preconditioner_type) {
235 if(this->triplet_mat.is_empty()) preconditioner = std::make_unique<ILU<T>>(to_triplet_form<T, int>(this));
236 else preconditioner = std::make_unique<ILU<T>>(this->triplet_mat);
237 }
238#endif
239 else if(PreconditionerType::NONE == this->setting.preconditioner_type) preconditioner = std::make_unique<UnityPreconditioner<T>>();
240
241 if(SUANPAN_SUCCESS != preconditioner->init()) return SUANPAN_FAIL;
242
243 this->setting.preconditioner = preconditioner.get();
244
245 std::atomic_int code = 0;
246
247 if(IterativeSolver::GMRES == setting.iterative_solver)
248 suanpan::for_each(B.n_cols, [&](const uword I) {
249 Col<T> sub_x(X.colptr(I), X.n_rows, false, true);
250 const Col<T> sub_b(B.colptr(I), B.n_rows);
251 auto col_setting = setting;
252 code += GMRES(this, sub_x, sub_b, col_setting);
253 });
254 else if(IterativeSolver::BICGSTAB == setting.iterative_solver)
255 suanpan::for_each(B.n_cols, [&](const uword I) {
256 Col<T> sub_x(X.colptr(I), X.n_rows, false, true);
257 const Col<T> sub_b(B.colptr(I), B.n_rows);
258 auto col_setting = setting;
259 code += BiCGSTAB(this, sub_x, sub_b, col_setting);
260 });
261 else throw invalid_argument("no proper iterative solver assigned but somehow iterative solving is called");
262
263 return 0 == code ? SUANPAN_SUCCESS : SUANPAN_FAIL;
264}
265
266template<sp_d T> Mat<T> to_mat(const MetaMat<T>& in_mat) {
267 Mat<T> out_mat(in_mat.n_rows, in_mat.n_cols);
268 for(uword J = 0; J < in_mat.n_cols; ++J) for(uword I = 0; I < in_mat.n_rows; ++I) out_mat(I, J) = in_mat(I, J);
269 return out_mat;
270}
271
272template<sp_d T> Mat<T> to_mat(const shared_ptr<MetaMat<T>>& in_mat) { return to_mat(*in_mat); }
273
274template<sp_d data_t, sp_i index_t> Mat<data_t> to_mat(const triplet_form<data_t, index_t>& in_mat) {
275 Mat<data_t> out_mat(in_mat.n_rows, in_mat.n_cols, fill::zeros);
276 for(index_t I = 0; I < in_mat.n_elem; ++I) out_mat(in_mat.row(I), in_mat.col(I)) += in_mat.val(I);
277 return out_mat;
278}
279
280template<sp_d data_t, sp_i index_t> Mat<data_t> to_mat(const csr_form<data_t, index_t>& in_mat) {
281 Mat<data_t> out_mat(in_mat.n_rows, in_mat.n_cols, fill::zeros);
282
283 index_t c_idx = 1;
284 for(index_t I = 0; I < in_mat.n_elem; ++I) {
285 if(I >= in_mat.row_mem()[c_idx]) ++c_idx;
286 out_mat(c_idx - 1, in_mat.col_mem()[I]) += in_mat.val_mem()[I];
287 }
288
289 return out_mat;
290}
291
292template<sp_d data_t, sp_i index_t> Mat<data_t> to_mat(const csc_form<data_t, index_t>& in_mat) {
293 Mat<data_t> out_mat(in_mat.n_rows, in_mat.n_cols, fill::zeros);
294
295 index_t c_idx = 1;
296 for(index_t I = 0; I < in_mat.n_elem; ++I) {
297 if(I >= in_mat.col_mem()[c_idx]) ++c_idx;
298 out_mat(in_mat.row_mem()[I], c_idx - 1) += in_mat.val_mem()[I];
299 }
300
301 return out_mat;
302}
303
304template<sp_d data_t, sp_i index_t> triplet_form<data_t, index_t> to_triplet_form(MetaMat<data_t>* in_mat) {
305 if(!in_mat->triplet_mat.is_empty()) return triplet_form<data_t, index_t>(in_mat->triplet_mat);
306
307 const sp_i auto n_rows = index_t(in_mat->n_rows);
308 const sp_i auto n_cols = index_t(in_mat->n_cols);
309 const sp_i auto n_elem = index_t(in_mat->n_elem);
310
311 triplet_form<data_t, index_t> out_mat(n_rows, n_cols, n_elem);
312 for(index_t J = 0; J < n_cols; ++J) for(index_t I = 0; I < n_rows; ++I) out_mat.at(I, J) = in_mat->operator()(I, J);
313
314 return out_mat;
315}
316
317template<sp_d data_t, sp_i index_t> triplet_form<data_t, index_t> to_triplet_form(const shared_ptr<MetaMat<data_t>>& in_mat) { return to_triplet_form<data_t, index_t>(in_mat.get()); }
318
319#endif
320
An ILU class.
Definition ILU.hpp:40
A MetaMat class that holds matrices.
Definition MetaMat.hpp:72
triplet_form< T, uword > triplet_mat
Definition MetaMat.hpp:116
MetaMat(const MetaMat &)=default
int direct_solve(Mat< T > &X, SpMat< T > &&B)
Definition MetaMat.hpp:84
virtual T max() const =0
virtual int sign_det() const =0
const uword n_cols
Definition MetaMat.hpp:119
void unify(const uword K)
Definition MetaMat.hpp:145
int solve(Mat< T > &X, C &&B)
Definition MetaMat.hpp:202
MetaMat(const uword in_rows, const uword in_cols, const uword in_elem)
Definition MetaMat.hpp:122
virtual const T * memptr() const =0
void operator-=(const shared_ptr< MetaMat > &M)
Definition MetaMat.hpp:180
Mat< T > solve(C &&B)
Definition MetaMat.hpp:204
virtual bool is_empty() const =0
void set_factored(const bool F)
Definition MetaMat.hpp:138
virtual unique_ptr< MetaMat > make_copy()=0
const uword n_rows
Definition MetaMat.hpp:118
void save(const char *name)
Definition MetaMat.hpp:214
virtual void scale_accu(T, const shared_ptr< MetaMat > &)=0
virtual T & unsafe_at(const uword I, const uword J)
Access element without bound check.
Definition MetaMat.hpp:164
SolverSetting< T > & get_solver_setting()
Definition MetaMat.hpp:136
virtual T * memptr()=0
virtual void scale_accu(T, const triplet_form< T, uword > &)=0
virtual void nullify(uword)=0
void operator-=(const triplet_form< T, uword > &M)
Definition MetaMat.hpp:196
MetaMat(MetaMat &&) noexcept=delete
void operator+=(const shared_ptr< MetaMat > &M)
Definition MetaMat.hpp:178
virtual int direct_solve(Mat< T > &, Mat< T > &&)=0
virtual void csc_condense()
Definition MetaMat.hpp:219
void operator-=(const op_scale< T > &M)
Definition MetaMat.hpp:188
virtual void csr_condense()
Definition MetaMat.hpp:221
int iterative_solve(Mat< T > &X, const SpMat< T > &B)
Definition MetaMat.hpp:88
bool factored
Definition MetaMat.hpp:74
virtual int direct_solve(Mat< T > &, const Mat< T > &)=0
void set_solver_setting(const SolverSetting< T > &SS)
Definition MetaMat.hpp:134
virtual T operator()(uword, uword) const =0
Access element (read-only), returns zero if out-of-bound.
void operator+=(const triplet_form< T, uword > &M)
Definition MetaMat.hpp:194
Col< T > evaluate(const Col< T > &X) const
Definition MetaMat.hpp:223
virtual Mat< T > operator*(const Mat< T > &) const =0
const uword n_elem
Definition MetaMat.hpp:120
int mixed_trs(mat &X, mat &&B, F trs)
Definition MetaMat.hpp:90
void operator+=(const op_scale< T > &M)
Definition MetaMat.hpp:182
SolverSetting< T > setting
Definition MetaMat.hpp:76
virtual void operator*=(T)=0
virtual Col< T > diag() const =0
virtual T & at(uword, uword)=0
Access element with bound check.
virtual void zeros()=0
int direct_solve(Mat< T > &X, const SpMat< T > &B)
Definition MetaMat.hpp:82
Definition csc_form.hpp:25
const index_t n_rows
Definition csc_form.hpp:50
const index_t n_cols
Definition csc_form.hpp:51
const index_t * col_mem() const
Definition csc_form.hpp:63
const index_t * row_mem() const
Definition csc_form.hpp:61
const data_t * val_mem() const
Definition csc_form.hpp:65
const index_t n_elem
Definition csc_form.hpp:52
Definition csr_form.hpp:25
const index_t * row_mem() const
Definition csr_form.hpp:61
const index_t n_rows
Definition csr_form.hpp:50
const index_t n_cols
Definition csr_form.hpp:51
const data_t * val_mem() const
Definition csr_form.hpp:65
const index_t * col_mem() const
Definition csr_form.hpp:63
const index_t n_elem
Definition csr_form.hpp:52
Definition MetaMat.hpp:41
op_add(const shared_ptr< MetaMat< T > > &A, const shared_ptr< MetaMat< T > > &B)
Definition MetaMat.hpp:51
op_add(const shared_ptr< MetaMat< T > > &A)
Definition MetaMat.hpp:47
Definition MetaMat.hpp:56
op_scale(const T A, const shared_ptr< MetaMat< T > > &B)
Definition MetaMat.hpp:63
op_scale(const T A, op_add< T > &&B)
Definition MetaMat.hpp:67
Definition triplet_form.hpp:62
const index_t n_rows
Definition triplet_form.hpp:128
bool is_empty() const
Definition triplet_form.hpp:169
data_t & at(index_t, index_t)
Definition triplet_form.hpp:384
const index_t n_cols
Definition triplet_form.hpp:129
const index_t n_elem
Definition triplet_form.hpp:130
index_t col(const index_t I) const
Definition triplet_form.hpp:161
data_t val(const index_t I) const
Definition triplet_form.hpp:163
index_t row(const index_t I) const
Definition triplet_form.hpp:159
Definition MetaMat.hpp:37
Definition suanPan.h:331
Mat< T > to_mat(const MetaMat< T > &in_mat)
Definition MetaMat.hpp:266
triplet_form< data_t, index_t > to_triplet_form(MetaMat< data_t > *in_mat)
Definition MetaMat.hpp:304
int iterative_solve(Mat< T > &, const Mat< T > &)
Definition MetaMat.hpp:226
void for_each(const IT start, const IT end, F &&FN)
Definition utility.h:28
Definition SolverSetting.hpp:40
unsigned iterative_refinement
Definition SolverSetting.hpp:44
data_t tolerance
Definition SolverSetting.hpp:43
IterativeSolver iterative_solver
Definition SolverSetting.hpp:46
#define suanpan_debug(...)
Definition suanPan.h:307
constexpr auto SUANPAN_SUCCESS
Definition suanPan.h:172
constexpr auto SUANPAN_FAIL
Definition suanPan.h:173
#define suanpan_error(...)
Definition suanPan.h:309