suanPan
Loading...
Searching...
No Matches
MetaMat.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2017-2023 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
29#ifndef METAMAT_HPP
30#define METAMAT_HPP
31
32#include "triplet_form.hpp"
33#include "IterativeSolver.hpp"
34#include "ILU.hpp"
35#include "Jacobi.hpp"
36
37template<typename T, typename U> concept ArmaContainer = std::is_floating_point_v<U> && (std::is_convertible_v<T, Mat<U>> || std::is_convertible_v<T, SpMat<U>>) ;
38
39template<sp_d T> class MetaMat {
40protected:
41 bool factored = false;
42
44
45 virtual int direct_solve(Mat<T>&, const Mat<T>&) = 0;
46
47 virtual int direct_solve(Mat<T>&, Mat<T>&&) = 0;
48
49 int direct_solve(Mat<T>& X, const SpMat<T>& B) { return this->direct_solve(X, Mat<T>(B)); }
50
51 int direct_solve(Mat<T>& X, SpMat<T>&& B) { return this->direct_solve(X, B); }
52
53 int iterative_solve(Mat<T>&, const Mat<T>&);
54
55 int iterative_solve(Mat<T>& X, const SpMat<T>& B) { return this->iterative_solve(X, Mat<T>(B)); }
56
57 template<std::invocable<fmat&> F> int mixed_trs(mat& X, mat&& B, F trs) {
58 auto INFO = 0;
59
60 X = arma::zeros(size(B));
61
62 auto multiplier = norm(B);
63
64 auto counter = 0u;
65 while(counter++ < this->setting.iterative_refinement) {
66 if(multiplier < this->setting.tolerance) break;
67
68 auto residual = conv_to<fmat>::from(B / multiplier);
69
70 if(0 != (INFO = trs(residual))) break;
71
72 const mat incre = multiplier * conv_to<mat>::from(residual);
73
74 X += incre;
75
76 suanpan_debug("Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(B -= this->operator*(incre)));
77 }
78
79 return INFO;
80 }
81
82public:
84
85 const uword n_rows;
86 const uword n_cols;
87 const uword n_elem;
88
89 MetaMat(const uword in_rows, const uword in_cols, const uword in_elem)
90 : triplet_mat(in_rows, in_cols)
91 , n_rows(in_rows)
92 , n_cols(in_cols)
93 , n_elem(in_elem) {}
94
95 MetaMat(const MetaMat&) = default;
96 MetaMat(MetaMat&&) noexcept = delete;
97 MetaMat& operator=(const MetaMat&) = delete;
98 MetaMat& operator=(MetaMat&&) noexcept = delete;
99 virtual ~MetaMat() = default;
100
102
103 [[nodiscard]] SolverSetting<T>& get_solver_setting() { return setting; }
104
105 void set_factored(const bool F) { factored = F; }
106
107 [[nodiscard]] virtual bool is_empty() const = 0;
108 virtual void zeros() = 0;
109
110 virtual unique_ptr<MetaMat> make_copy() = 0;
111
112 void unify(const uword K) {
113 this->nullify(K);
114 this->at(K, K) = T(1);
115 }
116
117 virtual void nullify(uword) = 0;
118
119 [[nodiscard]] virtual T max() const = 0;
120 [[nodiscard]] virtual Col<T> diag() const = 0;
121
126 virtual T operator()(uword, uword) const = 0;
131 virtual T& unsafe_at(const uword I, const uword J) { return this->at(I, J); }
132
137 virtual T& at(uword, uword) = 0;
138
139 [[nodiscard]] virtual const T* memptr() const = 0;
140 virtual T* memptr() = 0;
141
142 virtual void operator+=(const shared_ptr<MetaMat>&) = 0;
143 virtual void operator-=(const shared_ptr<MetaMat>&) = 0;
144
145 virtual void operator+=(const triplet_form<T, uword>&) = 0;
146 virtual void operator-=(const triplet_form<T, uword>&) = 0;
147
148 virtual Mat<T> operator*(const Mat<T>&) const = 0;
149
150 virtual void operator*=(T) = 0;
151
152 template<ArmaContainer<T> C> int solve(Mat<T>& X, const C& B) { return IterativeSolver::NONE == this->setting.iterative_solver ? this->direct_solve(X, B) : this->iterative_solve(X, B); }
153
154 template<ArmaContainer<T> C> int solve(Mat<T>& X, C&& B) { return IterativeSolver::NONE == this->setting.iterative_solver ? this->direct_solve(X, std::forward<C>(B)) : this->iterative_solve(X, std::forward<C>(B)); }
155
156 [[nodiscard]] virtual int sign_det() const = 0;
157
158 void save(const char* name) {
159 if(!to_mat(*this).save(name))
160 suanpan_error("Cannot save to file \"{}\".\n", name);
161 }
162
163 virtual void csc_condense() {}
164
165 virtual void csr_condense() {}
166
167 [[nodiscard]] Col<T> evaluate(const Col<T>& X) const { return this->operator*(X); }
168};
169
170template<sp_d T> int MetaMat<T>::iterative_solve(Mat<T>& X, const Mat<T>& B) {
171 this->csc_condense();
172
173 X.zeros(arma::size(B));
174
175 unique_ptr<Preconditioner<T>> preconditioner;
176 if(PreconditionerType::JACOBI == this->setting.preconditioner_type) preconditioner = std::make_unique<Jacobi<T>>(this->diag());
177#ifndef SUANPAN_SUPERLUMT
178 else if(PreconditionerType::ILU == this->setting.preconditioner_type) {
179 if(this->triplet_mat.is_empty()) preconditioner = std::make_unique<ILU<T>>(to_triplet_form<T, int>(this));
180 else preconditioner = std::make_unique<ILU<T>>(this->triplet_mat);
181 }
182#endif
183 else if(PreconditionerType::NONE == this->setting.preconditioner_type) preconditioner = std::make_unique<UnityPreconditioner<T>>();
184
185 if(SUANPAN_SUCCESS != preconditioner->init()) return SUANPAN_FAIL;
186
187 this->setting.preconditioner = preconditioner.get();
188
189 std::atomic_int code = 0;
190
191 if(IterativeSolver::GMRES == setting.iterative_solver)
192 suanpan_for(0llu, B.n_cols, [&](const uword I) {
193 Col<T> sub_x(X.colptr(I), X.n_rows, false, true);
194 const Col<T> sub_b(B.colptr(I), B.n_rows);
195 auto col_setting = setting;
196 code += GMRES(this, sub_x, sub_b, col_setting);
197 });
198 else if(IterativeSolver::BICGSTAB == setting.iterative_solver)
199 suanpan_for(0llu, B.n_cols, [&](const uword I) {
200 Col<T> sub_x(X.colptr(I), X.n_rows, false, true);
201 const Col<T> sub_b(B.colptr(I), B.n_rows);
202 auto col_setting = setting;
203 code += BiCGSTAB(this, sub_x, sub_b, col_setting);
204 });
205 else throw invalid_argument("no proper iterative solver assigned but somehow iterative solving is called");
206
207 return 0 == code ? SUANPAN_SUCCESS : SUANPAN_FAIL;
208}
209
210template<sp_d T> Mat<T> to_mat(const MetaMat<T>& in_mat) {
211 Mat<T> out_mat(in_mat.n_rows, in_mat.n_cols);
212 for(uword J = 0; J < in_mat.n_cols; ++J) for(uword I = 0; I < in_mat.n_rows; ++I) out_mat(I, J) = in_mat(I, J);
213 return out_mat;
214}
215
216template<sp_d T> Mat<T> to_mat(const shared_ptr<MetaMat<T>>& in_mat) { return to_mat(*in_mat); }
217
218template<sp_d data_t, sp_i index_t> Mat<data_t> to_mat(const triplet_form<data_t, index_t>& in_mat) {
219 Mat<data_t> out_mat(in_mat.n_rows, in_mat.n_cols, fill::zeros);
220 for(index_t I = 0; I < in_mat.n_elem; ++I) out_mat(in_mat.row(I), in_mat.col(I)) += in_mat.val(I);
221 return out_mat;
222}
223
224template<sp_d data_t, sp_i index_t> Mat<data_t> to_mat(const csr_form<data_t, index_t>& in_mat) {
225 Mat<data_t> out_mat(in_mat.n_rows, in_mat.n_cols, fill::zeros);
226
227 index_t c_idx = 1;
228 for(index_t I = 0; I < in_mat.n_elem; ++I) {
229 if(I >= in_mat.row_mem()[c_idx]) ++c_idx;
230 out_mat(c_idx - 1, in_mat.col_mem()[I]) += in_mat.val_mem()[I];
231 }
232
233 return out_mat;
234}
235
236template<sp_d data_t, sp_i index_t> Mat<data_t> to_mat(const csc_form<data_t, index_t>& in_mat) {
237 Mat<data_t> out_mat(in_mat.n_rows, in_mat.n_cols, fill::zeros);
238
239 index_t c_idx = 1;
240 for(index_t I = 0; I < in_mat.n_elem; ++I) {
241 if(I >= in_mat.col_mem()[c_idx]) ++c_idx;
242 out_mat(in_mat.row_mem()[I], c_idx - 1) += in_mat.val_mem()[I];
243 }
244
245 return out_mat;
246}
247
248template<sp_d data_t, sp_i index_t> triplet_form<data_t, index_t> to_triplet_form(MetaMat<data_t>* in_mat) {
249 if(!in_mat->triplet_mat.is_empty()) return triplet_form<data_t, index_t>(in_mat->triplet_mat);
250
251 const sp_i auto n_rows = index_t(in_mat->n_rows);
252 const sp_i auto n_cols = index_t(in_mat->n_cols);
253 const sp_i auto n_elem = index_t(in_mat->n_elem);
254
255 triplet_form<data_t, index_t> out_mat(n_rows, n_cols, n_elem);
256 for(index_t J = 0; J < n_cols; ++J) for(index_t I = 0; I < n_rows; ++I) out_mat.at(I, J) = in_mat->operator()(I, J);
257
258 return out_mat;
259}
260
261template<sp_d data_t, sp_i index_t> triplet_form<data_t, index_t> to_triplet_form(const shared_ptr<MetaMat<data_t>>& in_mat) { return to_triplet_form<data_t, index_t>(in_mat.get()); }
262
263#endif
264
An ILU class.
Definition: ILU.hpp:40
A MetaMat class that holds matrices.
Definition: MetaMat.hpp:39
triplet_form< T, uword > triplet_mat
Definition: MetaMat.hpp:83
MetaMat(const MetaMat &)=default
int direct_solve(Mat< T > &X, SpMat< T > &&B)
Definition: MetaMat.hpp:51
virtual T max() const =0
virtual int sign_det() const =0
const uword n_cols
Definition: MetaMat.hpp:86
void unify(const uword K)
Definition: MetaMat.hpp:112
int solve(Mat< T > &X, C &&B)
Definition: MetaMat.hpp:154
MetaMat(const uword in_rows, const uword in_cols, const uword in_elem)
Definition: MetaMat.hpp:89
virtual const T * memptr() const =0
virtual bool is_empty() const =0
void set_factored(const bool F)
Definition: MetaMat.hpp:105
virtual unique_ptr< MetaMat > make_copy()=0
virtual void operator-=(const shared_ptr< MetaMat > &)=0
const uword n_rows
Definition: MetaMat.hpp:85
void save(const char *name)
Definition: MetaMat.hpp:158
virtual T & unsafe_at(const uword I, const uword J)
Access element without bound check.
Definition: MetaMat.hpp:131
SolverSetting< T > & get_solver_setting()
Definition: MetaMat.hpp:103
virtual T * memptr()=0
virtual void nullify(uword)=0
MetaMat(MetaMat &&) noexcept=delete
virtual int direct_solve(Mat< T > &, Mat< T > &&)=0
virtual void csc_condense()
Definition: MetaMat.hpp:163
virtual void csr_condense()
Definition: MetaMat.hpp:165
int iterative_solve(Mat< T > &X, const SpMat< T > &B)
Definition: MetaMat.hpp:55
bool factored
Definition: MetaMat.hpp:41
virtual void operator+=(const shared_ptr< MetaMat > &)=0
virtual int direct_solve(Mat< T > &, const Mat< T > &)=0
void set_solver_setting(const SolverSetting< T > &SS)
Definition: MetaMat.hpp:101
virtual void operator+=(const triplet_form< T, uword > &)=0
virtual T operator()(uword, uword) const =0
Access element (read-only), returns zero if out-of-bound.
Col< T > evaluate(const Col< T > &X) const
Definition: MetaMat.hpp:167
int solve(Mat< T > &X, const C &B)
Definition: MetaMat.hpp:152
virtual Mat< T > operator*(const Mat< T > &) const =0
const uword n_elem
Definition: MetaMat.hpp:87
int mixed_trs(mat &X, mat &&B, F trs)
Definition: MetaMat.hpp:57
SolverSetting< T > setting
Definition: MetaMat.hpp:43
virtual void operator*=(T)=0
virtual Col< T > diag() const =0
virtual T & at(uword, uword)=0
Access element with bound check.
virtual void zeros()=0
int direct_solve(Mat< T > &X, const SpMat< T > &B)
Definition: MetaMat.hpp:49
virtual void operator-=(const triplet_form< T, uword > &)=0
Definition: csc_form.hpp:25
const index_t n_rows
Definition: csc_form.hpp:50
const index_t n_cols
Definition: csc_form.hpp:51
const index_t * col_mem() const
Definition: csc_form.hpp:63
const index_t * row_mem() const
Definition: csc_form.hpp:61
const data_t * val_mem() const
Definition: csc_form.hpp:65
const index_t n_elem
Definition: csc_form.hpp:52
Definition: csr_form.hpp:25
const index_t * row_mem() const
Definition: csr_form.hpp:61
const index_t n_rows
Definition: csr_form.hpp:50
const index_t n_cols
Definition: csr_form.hpp:51
const data_t * val_mem() const
Definition: csr_form.hpp:65
const index_t * col_mem() const
Definition: csr_form.hpp:63
const index_t n_elem
Definition: csr_form.hpp:52
Definition: triplet_form.hpp:62
const index_t n_rows
Definition: triplet_form.hpp:128
bool is_empty() const
Definition: triplet_form.hpp:169
data_t & at(index_t, index_t)
Definition: triplet_form.hpp:384
const index_t n_cols
Definition: triplet_form.hpp:129
const index_t n_elem
Definition: triplet_form.hpp:130
index_t col(const index_t I) const
Definition: triplet_form.hpp:161
data_t val(const index_t I) const
Definition: triplet_form.hpp:163
index_t row(const index_t I) const
Definition: triplet_form.hpp:159
Definition: MetaMat.hpp:37
Definition: suanPan.h:319
Mat< T > to_mat(const MetaMat< T > &in_mat)
Definition: MetaMat.hpp:210
triplet_form< data_t, index_t > to_triplet_form(MetaMat< data_t > *in_mat)
Definition: MetaMat.hpp:248
int iterative_solve(Mat< T > &, const Mat< T > &)
Definition: MetaMat.hpp:170
Definition: SolverSetting.hpp:40
unsigned iterative_refinement
Definition: SolverSetting.hpp:44
data_t tolerance
Definition: SolverSetting.hpp:43
IterativeSolver iterative_solver
Definition: SolverSetting.hpp:46
#define suanpan_debug(...)
Definition: suanPan.h:295
constexpr auto SUANPAN_SUCCESS
Definition: suanPan.h:162
constexpr auto SUANPAN_FAIL
Definition: suanPan.h:163
#define suanpan_error(...)
Definition: suanPan.h:297
void suanpan_for(const IT start, const IT end, F &&FN)
Definition: utility.h:27