ezp
Lightweight C++ wrapper for selected distributed linear-system solvers
Loading...
Searching...
No Matches
pgesvx.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2025-2026 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
41#ifndef PGESVX_HPP
42#define PGESVX_HPP
43
44#include "abstract/full_solver.hpp"
45
46#include <numeric>
47
48namespace ezp {
49 template<data_t DT, index_t IT, char FCT = 'E', char ODER = 'R'> class pgesvx final : public detail::full_solver<DT, IT, ODER> {
50 static constexpr char TRANS = 'N';
51 static constexpr char FACT = FCT;
52
54
55 static auto ceil(const IT a, const IT b) { return (a + b - 1) / b; }
56
57 auto compute_lwork() {
58 const auto ceil_a = std::max(IT{1}, ceil(this->ctx.n_rows - 1, this->ctx.n_cols));
59 const auto ceil_b = std::max(IT{1}, ceil(this->ctx.n_cols - 1, this->ctx.n_rows));
60 const auto pgecon_lwork = 2 * (this->loc.rows + this->loc.cols) + std::max(IT{2}, std::max(this->loc.block * ceil_a, this->loc.cols + this->loc.block * ceil_b));
61
62 const auto lcmq = std::lcm(this->ctx.n_rows, this->ctx.n_cols) / this->ctx.n_cols;
63 const auto nqb = ceil(this->loc.n, this->loc.block * this->ctx.n_cols);
64 const auto pgerfs_lwork = 4 * this->loc.rows + this->loc.cols + this->loc.block * ceil(nqb, lcmq);
65
66 return this->loc.rows + std::max(pgecon_lwork, pgerfs_lwork);
67 }
68
69 struct expert_system {
70 IT lwork;
71 std::vector<DT> af, work;
72 std::vector<work_t<DT>> r, c;
73 };
74
75 expert_system exp;
76
77 auto init_expert_storage() {
78 exp.lwork = compute_lwork();
79 exp.af.resize(this->loc.a.size());
80 exp.work.resize(exp.lwork);
81 exp.r.resize(this->loc.rows);
82 exp.c.resize(this->loc.cols);
83 }
84
85 public:
86 pgesvx()
87 : base_t() {}
88
89 pgesvx(const IT rows, const IT cols)
90 : base_t(rows, cols) {}
91
92 using base_t::solve;
93
94 IT solve(full_mat<DT, IT>&& A, full_mat<DT, IT>&& B) override {
95 if(!this->ctx.is_valid()) return 0;
96
97 if(A.n_rows != A.n_cols || A.n_cols != B.n_rows) return -1;
98
99 this->init_storage(A.n_rows);
100 init_expert_storage();
101
102 this->ctx.scatter(A, this->ctx.desc_g(A.n_rows, A.n_cols), exp.af, this->loc.desc_a);
103
104 const auto loc_cols_b = this->ctx.cols(B.n_cols, this->loc.block);
105 this->loc.b.resize(this->loc.rows * loc_cols_b);
106
107 const auto full_desc_b = this->ctx.desc_g(B.n_rows, B.n_cols);
108 const auto loc_desc_b = this->ctx.desc_l(B.n_rows, B.n_cols, this->loc.block, this->loc.rows);
109
110 this->ctx.scatter(B, full_desc_b, this->loc.b, loc_desc_b);
111
112 std::vector<work_t<DT>> ferr(loc_cols_b), berr(loc_cols_b);
113 work_t<DT> rcond;
114
115 std::vector<DT> x(this->loc.b.size());
116
117 auto equed = 'X';
118
119 IT info{-1};
120 // ReSharper disable CppCStyleCast
121 if constexpr(std::is_same_v<DT, double>) {
122 using E = double;
123
124 const auto liwork = this->loc.rows;
125 std::vector<IT> iwork(liwork);
126
127 pdgesvx(&FACT, &TRANS, &this->loc.n, &B.n_cols, (E*)exp.af.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), &equed, (E*)exp.r.data(), (E*)exp.c.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)x.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)&rcond, (E*)ferr.data(), (E*)berr.data(), (E*)exp.work.data(), &exp.lwork, iwork.data(), &liwork, &info);
128 }
129 else if constexpr(std::is_same_v<DT, float>) {
130 using E = float;
131
132 const auto liwork = this->loc.rows;
133 std::vector<IT> iwork(liwork);
134
135 psgesvx(&FACT, &TRANS, &this->loc.n, &B.n_cols, (E*)exp.af.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), &equed, (E*)exp.r.data(), (E*)exp.c.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)x.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)&rcond, (E*)ferr.data(), (E*)berr.data(), (E*)exp.work.data(), &exp.lwork, iwork.data(), &liwork, &info);
136 }
137 else if constexpr(std::is_same_v<DT, complex16>) {
138 using E = complex16;
139 using BE = work_t<complex16>;
140
141 const auto lrwork = std::max(this->loc.rows, 2 * this->loc.cols);
142 std::vector<BE> rwork(lrwork);
143
144 pzgesvx(&FACT, &TRANS, &this->loc.n, &B.n_cols, (E*)exp.af.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), &equed, (BE*)exp.r.data(), (BE*)exp.c.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)x.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (BE*)&rcond, (BE*)ferr.data(), (BE*)berr.data(), (E*)exp.work.data(), &exp.lwork, rwork.data(), &lrwork, &info);
145 }
146 else if constexpr(std::is_same_v<DT, complex8>) {
147 using E = complex8;
148 using BE = work_t<complex8>;
149
150 const auto lrwork = std::max(this->loc.rows, 2 * this->loc.cols);
151 std::vector<BE> rwork(lrwork);
152
153 pcgesvx(&FACT, &TRANS, &this->loc.n, &B.n_cols, (E*)exp.af.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), &equed, (BE*)exp.r.data(), (BE*)exp.c.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)x.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (BE*)&rcond, (BE*)ferr.data(), (BE*)berr.data(), (E*)exp.work.data(), &exp.lwork, rwork.data(), &lrwork, &info);
154 }
155 // ReSharper restore CppCStyleCast
156
157 if((info = this->ctx.amx(info)) != 0) return info;
158
159 this->ctx.gather(x, loc_desc_b, B, full_desc_b);
160
161 return info;
162 }
163
164 IT solve(full_mat<DT, IT>&&) override { throw std::runtime_error("not implemented"); }
165 };
166
167 template<index_t IT, char FCT = 'E', char ODER = 'R'> using par_dgesvx = pgesvx<double, IT, FCT, ODER>;
168 template<index_t IT, char FCT = 'E', char ODER = 'R'> using par_sgesvx = pgesvx<float, IT, FCT, ODER>;
169 template<index_t IT, char FCT = 'E', char ODER = 'R'> using par_zgesvx = pgesvx<complex16, IT, FCT, ODER>;
170 template<index_t IT, char FCT = 'E', char ODER = 'R'> using par_cgesvx = pgesvx<complex8, IT, FCT, ODER>;
171} // namespace ezp
172
173#endif // PGESVX_HPP
174
Definition full_solver.hpp:24
Definition pgesvx.hpp:49
Definition traits.hpp:85