ezp
A lightweight C++ wrapper for selected distributed solvers for linear systems.
pposvx.hpp
Go to the documentation of this file.
/*******************************************************************************
 * Copyright (C) 2025-2026 Theodore Chang
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 ******************************************************************************/
37#ifndef PPOSVX_HPP
38#define PPOSVX_HPP
39
#include "abstract/full_solver.hpp"

#include <algorithm>
#include <numeric>
#include <stdexcept>
#include <type_traits>
#include <vector>
43
44namespace ezp {
45 template<data_t DT, index_t IT, char UL = 'L', char ODER = 'R'> class pposvx final : public detail::full_solver<DT, IT, ODER> {
46 static constexpr char FACT = 'E';
47 static constexpr char UPLO = UL;
48
50
51 static auto ceil(const IT a, const IT b) { return (a + b - 1) / b; }
52
53 auto compute_lwork() {
54 const auto ceil_a = std::max(IT{1}, ceil(this->ctx.n_rows - 1, this->ctx.n_cols));
55 const auto ceil_b = std::max(IT{1}, ceil(this->ctx.n_cols - 1, this->ctx.n_rows));
56 const auto ppocon_lwork = 2 * (this->loc.rows + this->loc.cols) + std::max(IT{2}, std::max(this->loc.block * ceil_a, this->loc.cols + this->loc.block * ceil_b));
57
58 const auto pporfs_lwork = 3 * this->loc.rows;
59
60 return std::max(ppocon_lwork, pporfs_lwork);
61 }
62
63 auto compute_lrwork() {
64 const auto lcmp = std::lcm(this->ctx.n_rows, this->ctx.n_cols) / this->ctx.n_rows;
65
66 return this->loc.rows + 2 * this->loc.cols + this->loc.block * ceil(ceil(this->loc.rows, this->loc.block), lcmp);
67 }
68
69 struct expert_system {
70 IT lwork;
71 std::vector<DT> af, work;
72 std::vector<work_t<DT>> sr, sc;
73 };
74
75 expert_system exp;
76
77 auto init_expert_storage() {
78 exp.lwork = compute_lwork();
79 exp.af.resize(this->loc.a.size());
80 exp.work.resize(exp.lwork);
81 exp.sr.resize(this->loc.rows);
82 exp.sc.resize(this->loc.cols);
83 }
84
85 public:
86 pposvx()
87 : base_t() {}
88
89 pposvx(const IT rows, const IT cols)
90 : base_t(rows, cols) {}
91
92 using base_t::solve;
93
94 IT solve(full_mat<DT, IT>&& A, full_mat<DT, IT>&& B) override {
95 if(!this->ctx.is_valid()) return 0;
96
97 if(A.n_rows != A.n_cols || A.n_cols != B.n_rows) return -1;
98
99 this->init_storage(A.n_rows);
100 init_expert_storage();
101
102 this->ctx.scatter(A, this->ctx.desc_g(A.n_rows, A.n_cols), exp.af, this->loc.desc_a);
103
104 const auto loc_cols_b = this->ctx.cols(B.n_cols, this->loc.block);
105 this->loc.b.resize(this->loc.rows * loc_cols_b);
106
107 const auto full_desc_b = this->ctx.desc_g(B.n_rows, B.n_cols);
108 const auto loc_desc_b = this->ctx.desc_l(B.n_rows, B.n_cols, this->loc.block, this->loc.rows);
109
110 this->ctx.scatter(B, full_desc_b, this->loc.b, loc_desc_b);
111
112 std::vector<work_t<DT>> ferr(loc_cols_b), berr(loc_cols_b);
113 work_t<DT> rcond;
114
115 std::vector<DT> x(this->loc.b.size());
116
117 auto equed = 'X';
118
119 IT info{-1};
120 // ReSharper disable CppCStyleCast
121 if constexpr(std::is_same_v<DT, double>) {
122 using E = double;
123
124 const auto liwork = this->loc.rows;
125 std::vector<IT> iwork(liwork);
126
127 pdposvx(&FACT, &UPLO, &this->loc.n, &B.n_cols, (E*)exp.af.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), &equed, (E*)exp.sr.data(), (E*)exp.sc.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)x.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)&rcond, (E*)ferr.data(), (E*)berr.data(), (E*)exp.work.data(), &exp.lwork, iwork.data(), &liwork, &info);
128 }
129 else if constexpr(std::is_same_v<DT, float>) {
130 using E = float;
131
132 const auto liwork = this->loc.rows;
133 std::vector<IT> iwork(liwork);
134
135 psposvx(&FACT, &UPLO, &this->loc.n, &B.n_cols, (E*)exp.af.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), &equed, (E*)exp.sr.data(), (E*)exp.sc.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)x.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)&rcond, (E*)ferr.data(), (E*)berr.data(), (E*)exp.work.data(), &exp.lwork, iwork.data(), &liwork, &info);
136 }
137 else if constexpr(std::is_same_v<DT, complex16>) {
138 using E = complex16;
139 using BE = work_t<complex16>;
140
141 const auto lrwork = compute_lrwork();
142 std::vector<BE> rwork(lrwork);
143
144 pzposvx(&FACT, &UPLO, &this->loc.n, &B.n_cols, (E*)exp.af.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), &equed, (BE*)exp.sr.data(), (BE*)exp.sc.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)x.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (BE*)&rcond, (BE*)ferr.data(), (BE*)berr.data(), (E*)exp.work.data(), &exp.lwork, rwork.data(), &lrwork, &info);
145 }
146 else if constexpr(std::is_same_v<DT, complex8>) {
147 using E = complex8;
148 using BE = work_t<complex8>;
149
150 const auto lrwork = compute_lrwork();
151 std::vector<BE> rwork(lrwork);
152
153 pcposvx(&FACT, &UPLO, &this->loc.n, &B.n_cols, (E*)exp.af.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), &equed, (BE*)exp.sr.data(), (BE*)exp.sc.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)x.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (BE*)&rcond, (BE*)ferr.data(), (BE*)berr.data(), (E*)exp.work.data(), &exp.lwork, rwork.data(), &lrwork, &info);
154 }
155 // ReSharper restore CppCStyleCast
156
157 if((info = this->ctx.amx(info)) != 0) return info;
158
159 this->ctx.gather(x, loc_desc_b, B, full_desc_b);
160
161 return info;
162 }
163
164 IT solve(full_mat<DT, IT>&&) override { throw std::runtime_error("not implemented"); }
165 };
166
167 template<index_t IT, char UL = 'L', char ODER = 'R'> using par_dposvx = pposvx<double, IT, UL, ODER>;
168 template<index_t IT, char UL = 'L', char ODER = 'R'> using par_sposvx = pposvx<float, IT, UL, ODER>;
169 template<index_t IT, char UL = 'L', char ODER = 'R'> using par_zposvx = pposvx<complex16, IT, UL, ODER>;
170 template<index_t IT, char UL = 'L', char ODER = 'R'> using par_cposvx = pposvx<complex8, IT, UL, ODER>;
171} // namespace ezp
172
173#endif // PPOSVX_HPP
174
Definition full_solver.hpp:24
Definition pposvx.hpp:45
Definition traits.hpp:85