ezp
Loading...
Searching...
No Matches
pposvx.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2025 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
27#ifndef PPOSVX_HPP
28#define PPOSVX_HPP
29
30#include "abstract/full_solver.hpp"
31
32#include <numeric>
33
34namespace ezp {
35 template<data_t DT, index_t IT, char UL = 'L', char ODER = 'R'> class pposvx final : public detail::full_solver<DT, IT, ODER> {
36 static constexpr char FACT = 'E';
37 static constexpr char UPLO = UL;
38
40
41 static auto ceil(const IT a, const IT b) { return (a + b - 1) / b; }
42
43 auto compute_lwork() {
44 const auto ceil_a = std::max(IT{1}, ceil(this->ctx.n_rows - 1, this->ctx.n_cols));
45 const auto ceil_b = std::max(IT{1}, ceil(this->ctx.n_cols - 1, this->ctx.n_rows));
46 const auto ppocon_lwork = 2 * (this->loc.rows + this->loc.cols) + std::max(IT{2}, std::max(this->loc.block * ceil_a, this->loc.cols + this->loc.block * ceil_b));
47
48 const auto pporfs_lwork = 3 * this->loc.rows;
49
50 return std::max(ppocon_lwork, pporfs_lwork);
51 }
52
53 auto compute_lrwork() {
54 const auto lcmp = std::lcm(this->ctx.n_rows, this->ctx.n_cols) / this->ctx.n_rows;
55
56 return this->loc.rows + 2 * this->loc.cols + this->loc.block * ceil(ceil(this->loc.rows, this->loc.block), lcmp);
57 }
58
59 struct expert_system {
60 IT lwork;
61 std::vector<DT> af, work;
62 std::vector<work_t<DT>> sr, sc;
63 };
64
65 expert_system exp;
66
67 auto init_expert_storage() {
68 exp.lwork = compute_lwork();
69 exp.af.resize(this->loc.a.size());
70 exp.work.resize(exp.lwork);
71 exp.sr.resize(this->loc.rows);
72 exp.sc.resize(this->loc.cols);
73 }
74
75 public:
76 pposvx()
77 : base_t() {}
78
79 pposvx(const IT rows, const IT cols)
80 : base_t(rows, cols) {}
81
82 using base_t::solve;
83
84 IT solve(full_mat<DT, IT>&& A, full_mat<DT, IT>&& B) override {
85 if(!this->ctx.is_valid()) return 0;
86
87 if(A.n_rows != A.n_cols || A.n_cols != B.n_rows) return -1;
88
89 this->init_storage(A.n_rows);
90 init_expert_storage();
91
92 this->ctx.scatter(A, this->ctx.desc_g(A.n_rows, A.n_cols), exp.af, this->loc.desc_a);
93
94 const auto loc_cols_b = this->ctx.cols(B.n_cols, this->loc.block);
95 this->loc.b.resize(this->loc.rows * loc_cols_b);
96
97 const auto full_desc_b = this->ctx.desc_g(B.n_rows, B.n_cols);
98 const auto loc_desc_b = this->ctx.desc_l(B.n_rows, B.n_cols, this->loc.block, this->loc.rows);
99
100 this->ctx.scatter(B, full_desc_b, this->loc.b, loc_desc_b);
101
102 std::vector<work_t<DT>> ferr(loc_cols_b), berr(loc_cols_b);
103 work_t<DT> rcond;
104
105 std::vector<DT> x(this->loc.b.size());
106
107 auto equed = 'X';
108
109 IT info{-1};
110 // ReSharper disable CppCStyleCast
111 if constexpr(std::is_same_v<DT, double>) {
112 using E = double;
113
114 const auto liwork = this->loc.rows;
115 std::vector<IT> iwork(liwork);
116
117 pdposvx(&FACT, &UPLO, &this->loc.n, &B.n_cols, (E*)exp.af.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), &equed, (E*)exp.sr.data(), (E*)exp.sc.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)x.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)&rcond, (E*)ferr.data(), (E*)berr.data(), (E*)exp.work.data(), &exp.lwork, iwork.data(), &liwork, &info);
118 }
119 else if constexpr(std::is_same_v<DT, float>) {
120 using E = float;
121
122 const auto liwork = this->loc.rows;
123 std::vector<IT> iwork(liwork);
124
125 psposvx(&FACT, &UPLO, &this->loc.n, &B.n_cols, (E*)exp.af.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), &equed, (E*)exp.sr.data(), (E*)exp.sc.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)x.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)&rcond, (E*)ferr.data(), (E*)berr.data(), (E*)exp.work.data(), &exp.lwork, iwork.data(), &liwork, &info);
126 }
127 else if constexpr(std::is_same_v<DT, complex16>) {
128 using E = complex16;
129 using BE = work_t<complex16>;
130
131 const auto lrwork = compute_lrwork();
132 std::vector<BE> rwork(lrwork);
133
134 pzposvx(&FACT, &UPLO, &this->loc.n, &B.n_cols, (E*)exp.af.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), &equed, (BE*)exp.sr.data(), (BE*)exp.sc.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)x.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (BE*)&rcond, (BE*)ferr.data(), (BE*)berr.data(), (E*)exp.work.data(), &exp.lwork, rwork.data(), &lrwork, &info);
135 }
136 else if constexpr(std::is_same_v<DT, complex8>) {
137 using E = complex8;
138 using BE = work_t<complex8>;
139
140 const auto lrwork = compute_lrwork();
141 std::vector<BE> rwork(lrwork);
142
143 pcposvx(&FACT, &UPLO, &this->loc.n, &B.n_cols, (E*)exp.af.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), &equed, (BE*)exp.sr.data(), (BE*)exp.sc.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)x.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (BE*)&rcond, (BE*)ferr.data(), (BE*)berr.data(), (E*)exp.work.data(), &exp.lwork, rwork.data(), &lrwork, &info);
144 }
145 // ReSharper restore CppCStyleCast
146
147 if((info = this->ctx.amx(info)) != 0) return info;
148
149 if(equed == 'C' || equed == 'B')
150 for(auto i = 0; i < loc_cols_b; ++i)
151 for(auto j = 0; j < this->loc.rows; ++j) x[j * loc_cols_b + i] /= exp.sc[j];
152
153 this->ctx.gather(x, loc_desc_b, B, full_desc_b);
154
155 return info;
156 }
157
158 IT solve(full_mat<DT, IT>&& B) override { return IT{-1}; }
159 };
160
161 template<index_t IT, char UL = 'L', char ODER = 'R'> using par_dposvx = pposvx<double, IT, UL, ODER>;
162 template<index_t IT, char UL = 'L', char ODER = 'R'> using par_sposvx = pposvx<float, IT, UL, ODER>;
163 template<index_t IT, char UL = 'L', char ODER = 'R'> using par_zposvx = pposvx<complex16, IT, UL, ODER>;
164 template<index_t IT, char UL = 'L', char ODER = 'R'> using par_cposvx = pposvx<complex8, IT, UL, ODER>;
165} // namespace ezp
166
167#endif // PPOSVX_HPP
168
Definition full_solver.hpp:24
Definition pposvx.hpp:35
Definition abstract_solver.hpp:85