suanPan
Loading...
Searching...
No Matches
IterativeSolver.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2017-2024 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
17
18#ifndef ITERATIVESOLVER_HPP
19#define ITERATIVESOLVER_HPP
20
21#include <Toolbox/utility.h>
22#include "SolverSetting.hpp"
23
24template<typename T, typename data_t> concept HasEvaluate = requires(T* t, const Col<data_t>& x) { { t->evaluate(x) } -> std::convertible_to<Col<data_t>>; };
25
26template<sp_d data_t, HasEvaluate<data_t> System> int GMRES(const System* system, Col<data_t>& x, const Col<data_t>& b, SolverSetting<data_t>& setting) {
27 constexpr sp_d auto ZERO = data_t(0);
28 constexpr sp_d auto ONE = data_t(1);
29
30 const auto& conditioner = setting.preconditioner;
31
32 auto generate_rotation = [](const data_t dx, const data_t dy, data_t& cs, data_t& sn) -> void {
34 cs = ONE;
35 sn = ZERO;
36 }
37 else if(std::fabs(dy) > std::fabs(dx)) {
38 const data_t fraction = dx / dy;
39 sn = ONE / std::sqrt(ONE + fraction * fraction);
40 cs = fraction * sn;
41 }
42 else {
43 const data_t fraction = dy / dx;
44 cs = ONE / std::sqrt(ONE + fraction * fraction);
45 sn = fraction * cs;
46 }
47 };
48
49 auto apply_rotation = [](data_t& dx, data_t& dy, const data_t cs, const data_t sn) -> void {
50 const data_t factor = cs * dx + sn * dy;
51 dy = cs * dy - sn * dx;
52 dx = factor;
53 };
54
55 if(x.empty()) x = conditioner->apply(b);
56 else x.zeros(arma::size(b));
57
58 const auto mp = setting.restart + 1;
59
60 Mat<data_t> hessenberg(mp, setting.restart, fill::zeros);
61
62 auto counter = 1;
63 data_t beta, residual;
64 Col<data_t> s(mp, fill::none), cs(mp, fill::none), sn(mp, fill::none), r;
65
66 auto norm_b = arma::norm(conditioner->apply(b));
67 if(suanpan::approx_equal(norm_b, ZERO)) norm_b = ONE;
68
69 auto stop_criterion = [&] {
70 residual = (beta = arma::norm(r = conditioner->apply(b - system->evaluate(x)))) / norm_b;
71 suanpan_debug("GMRES iterative solver residual: {:.5E}.\n", residual);
72 if(residual > setting.tolerance) return SUANPAN_FAIL;
73 setting.tolerance = residual;
74 setting.max_iteration = counter;
75 return SUANPAN_SUCCESS;
76 };
77
78 if(SUANPAN_SUCCESS == stop_criterion()) return SUANPAN_SUCCESS;
79
80 Mat<data_t> v(b.n_rows, mp, fill::none);
81
82 auto update = [&](const int k) -> Col<data_t> {
83 Col<data_t> y = s.head(k + 1llu);
84
85 for(auto i = k; i >= 0; --i) {
86 y(i) /= hessenberg(i, i);
87 y.head(i) -= hessenberg.col(i).head(i) * y(i);
88 }
89
90 return v.head_cols(k + 1llu) * y;
91 };
92
93 while(counter <= setting.max_iteration) {
94 v.col(0) = r / beta;
95 s.zeros();
96 s(0) = beta;
97
98 for(auto i = 0, j = 1; i < setting.restart && counter <= setting.max_iteration; ++i, ++j, ++counter) {
99 auto w = conditioner->apply(system->evaluate(v.col(i)));
100 for(auto k = 0; k <= i; ++k) w -= (hessenberg(k, i) = arma::dot(w, v.col(k))) * v.col(k);
101 v.col(j) = w / (hessenberg(j, i) = arma::norm(w));
102
103 for(auto k = 0; k < i; ++k) apply_rotation(hessenberg(k, i), hessenberg(k + 1llu, i), cs(k), sn(k));
104
105 generate_rotation(hessenberg(i, i), hessenberg(j, i), cs(i), sn(i));
106 apply_rotation(hessenberg(i, i), hessenberg(j, i), cs(i), sn(i));
107 apply_rotation(s(i), s(j), cs(i), sn(i));
108
109 residual = std::fabs(s(j)) / norm_b;
110 suanpan_debug("GMRES iterative solver residual: {:.5E}.\n", residual);
111 if(residual < setting.tolerance) {
112 x += update(i);
113 setting.tolerance = residual;
114 setting.max_iteration = counter;
115 return SUANPAN_SUCCESS;
116 }
117 }
118
119 x += update(setting.restart - 1);
120 if(SUANPAN_SUCCESS == stop_criterion()) return SUANPAN_SUCCESS;
121 }
122
123 setting.tolerance = residual;
124 return SUANPAN_FAIL;
125}
126
127template<sp_d data_t, HasEvaluate<data_t> System> int BiCGSTAB(const System* system, Col<data_t>& x, const Col<data_t>& b, SolverSetting<data_t>& setting) {
128 constexpr sp_d auto ZERO = data_t(0);
129 constexpr sp_d auto ONE = data_t(1);
130
131 const auto& conditioner = setting.preconditioner;
132
133 data_t norm_b = arma::norm(b);
134 if(suanpan::approx_equal(norm_b, ZERO)) norm_b = ONE;
135
136 if(x.empty()) x = conditioner->apply(b);
137 else x.zeros(arma::size(b));
138
139 Col<data_t> r = b - system->evaluate(x);
140 const auto initial_r = r;
141
142 data_t residual = arma::norm(r) / norm_b;
143 suanpan_debug("BiCGSTAB iterative solver residual: {:.5E}.\n", residual);
144 if(residual < setting.tolerance) {
145 setting.tolerance = residual;
146 setting.max_iteration = 0;
147 return 0;
148 }
149
150 sp_d auto pre_rho = ZERO, alpha = ZERO, omega = ZERO;
151 Col<data_t> v, p;
152
153 for(auto i = 1; i <= setting.max_iteration; ++i) {
154 const auto rho = arma::dot(initial_r, r);
155 if(suanpan::approx_equal(rho, ZERO)) {
156 setting.tolerance = residual;
157 setting.max_iteration = i;
158 return SUANPAN_FAIL;
159 }
160
161 if(1 == i) p = r;
162 else p = r + rho / pre_rho * alpha / omega * (p - omega * v);
163
164 const auto phat = conditioner->apply(p);
165 v = system->evaluate(phat);
166 alpha = rho / arma::dot(initial_r, v);
167 const Col<data_t> s = r - alpha * v;
168
169 suanpan_debug("BiCGSTAB iterative solver residual: {:.5E}.\n", residual = arma::norm(s) / norm_b);
170 if(residual < setting.tolerance) {
171 x += alpha * phat;
172 setting.tolerance = residual;
173 setting.max_iteration = i;
174 return SUANPAN_SUCCESS;
175 }
176
177 const auto shat = conditioner->apply(s);
178 const Col<data_t> t = system->evaluate(shat);
179 omega = arma::dot(t, s) / arma::dot(t, t);
180 x += alpha * phat + omega * shat;
181 r = s - omega * t;
182
183 pre_rho = rho;
184
185 suanpan_debug("BiCGSTAB iterative solver residual: {:.5E}.\n", residual = arma::norm(r) / norm_b);
186 if(residual < setting.tolerance) {
187 setting.tolerance = residual;
188 setting.max_iteration = i;
189 return SUANPAN_SUCCESS;
190 }
191
192 if(suanpan::approx_equal(omega, ZERO)) {
193 setting.tolerance = residual;
194 setting.max_iteration = i;
195 return SUANPAN_FAIL;
196 }
197 }
198
199 setting.tolerance = residual;
200 return SUANPAN_FAIL;
201}
202
203#endif
int BiCGSTAB(const System *system, Col< data_t > &x, const Col< data_t > &b, SolverSetting< data_t > &setting)
Definition IterativeSolver.hpp:127
Definition IterativeSolver.hpp:24
Definition suanPan.h:330
std::enable_if_t<!std::numeric_limits< T >::is_integer, bool > approx_equal(T x, T y, int ulp=2)
Definition utility.h:60
Definition TestSolver.h:6
#define suanpan_debug(...)
Definition suanPan.h:307
constexpr auto SUANPAN_SUCCESS
Definition suanPan.h:172
constexpr auto SUANPAN_FAIL
Definition suanPan.h:173