ezp
Lightweight C++ wrapper for selected distributed linear-system solvers
pposvx.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * Copyright (C) 2025-2026 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
37 #ifndef PPOSVX_HPP
38 #define PPOSVX_HPP
39 
40 #include "abstract/full_solver.hpp"
41 
42 #include <numeric>
43 
44 namespace ezp {
45  template<data_t DT, index_t IT, char UL = 'L', char ODER = 'R'> class pposvx final : public detail::full_solver<DT, IT, ODER> {
46  static constexpr char FACT = 'E';
47  static constexpr char UPLO = UL;
48 
50 
51  static auto ceil(const IT a, const IT b) { return (a + b - 1) / b; }
52 
53  auto compute_lwork() {
54  const auto ceil_a = std::max(IT{1}, ceil(this->ctx.n_rows - 1, this->ctx.n_cols));
55  const auto ceil_b = std::max(IT{1}, ceil(this->ctx.n_cols - 1, this->ctx.n_rows));
56  const auto ppocon_lwork = 2 * (this->loc.rows + this->loc.cols) + std::max(IT{2}, std::max(this->loc.block * ceil_a, this->loc.cols + this->loc.block * ceil_b));
57 
58  const auto pporfs_lwork = 3 * this->loc.rows;
59 
60  return std::max(ppocon_lwork, pporfs_lwork);
61  }
62 
63  auto compute_lrwork() {
64  const auto lcmp = std::lcm(this->ctx.n_rows, this->ctx.n_cols) / this->ctx.n_rows;
65 
66  return this->loc.rows + 2 * this->loc.cols + this->loc.block * ceil(ceil(this->loc.rows, this->loc.block), lcmp);
67  }
68 
69  struct expert_system {
70  IT lwork;
71  std::vector<DT> af, work;
72  std::vector<work_t<DT>> sr, sc;
73  };
74 
75  expert_system exp;
76 
77  auto init_expert_storage() {
78  exp.lwork = compute_lwork();
79  exp.af.resize(this->loc.a.size());
80  exp.work.resize(exp.lwork);
81  exp.sr.resize(this->loc.rows);
82  exp.sc.resize(this->loc.cols);
83  }
84 
85  public:
86  pposvx()
87  : base_t() {}
88 
89  pposvx(const IT rows, const IT cols)
90  : base_t(rows, cols) {}
91 
92  using base_t::solve;
93 
94  IT solve(full_mat<DT, IT>&& A, full_mat<DT, IT>&& B) override {
95  if(!this->ctx.is_valid()) return 0;
96 
97  if(A.n_rows != A.n_cols || A.n_cols != B.n_rows) return -1;
98 
99  this->init_storage(A.n_rows);
100  init_expert_storage();
101 
102  this->ctx.scatter(A, this->ctx.desc_g(A.n_rows, A.n_cols), exp.af, this->loc.desc_a);
103 
104  const auto loc_cols_b = this->ctx.cols(B.n_cols, this->loc.block);
105  this->loc.b.resize(this->loc.rows * loc_cols_b);
106 
107  const auto full_desc_b = this->ctx.desc_g(B.n_rows, B.n_cols);
108  const auto loc_desc_b = this->ctx.desc_l(B.n_rows, B.n_cols, this->loc.block, this->loc.rows);
109 
110  this->ctx.scatter(B, full_desc_b, this->loc.b, loc_desc_b);
111 
112  std::vector<work_t<DT>> ferr(loc_cols_b), berr(loc_cols_b);
113  work_t<DT> rcond;
114 
115  std::vector<DT> x(this->loc.b.size());
116 
117  auto equed = 'X';
118 
119  IT info{-1};
120  // ReSharper disable CppCStyleCast
121  if constexpr(std::is_same_v<DT, double>) {
122  using E = double;
123 
124  const auto liwork = this->loc.rows;
125  std::vector<IT> iwork(liwork);
126 
127  pdposvx(&FACT, &UPLO, &this->loc.n, &B.n_cols, (E*)exp.af.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), &equed, (E*)exp.sr.data(), (E*)exp.sc.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)x.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)&rcond, (E*)ferr.data(), (E*)berr.data(), (E*)exp.work.data(), &exp.lwork, iwork.data(), &liwork, &info);
128  }
129  else if constexpr(std::is_same_v<DT, float>) {
130  using E = float;
131 
132  const auto liwork = this->loc.rows;
133  std::vector<IT> iwork(liwork);
134 
135  psposvx(&FACT, &UPLO, &this->loc.n, &B.n_cols, (E*)exp.af.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), &equed, (E*)exp.sr.data(), (E*)exp.sc.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)x.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)&rcond, (E*)ferr.data(), (E*)berr.data(), (E*)exp.work.data(), &exp.lwork, iwork.data(), &liwork, &info);
136  }
137  else if constexpr(std::is_same_v<DT, complex16>) {
138  using E = complex16;
139  using BE = work_t<complex16>;
140 
141  const auto lrwork = compute_lrwork();
142  std::vector<BE> rwork(lrwork);
143 
144  pzposvx(&FACT, &UPLO, &this->loc.n, &B.n_cols, (E*)exp.af.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), &equed, (BE*)exp.sr.data(), (BE*)exp.sc.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)x.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (BE*)&rcond, (BE*)ferr.data(), (BE*)berr.data(), (E*)exp.work.data(), &exp.lwork, rwork.data(), &lrwork, &info);
145  }
146  else if constexpr(std::is_same_v<DT, complex8>) {
147  using E = complex8;
148  using BE = work_t<complex8>;
149 
150  const auto lrwork = compute_lrwork();
151  std::vector<BE> rwork(lrwork);
152 
153  pcposvx(&FACT, &UPLO, &this->loc.n, &B.n_cols, (E*)exp.af.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), &equed, (BE*)exp.sr.data(), (BE*)exp.sc.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)x.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (BE*)&rcond, (BE*)ferr.data(), (BE*)berr.data(), (E*)exp.work.data(), &exp.lwork, rwork.data(), &lrwork, &info);
154  }
155  // ReSharper restore CppCStyleCast
156 
157  if((info = this->ctx.amx(info)) != 0) return info;
158 
159  this->ctx.gather(x, loc_desc_b, B, full_desc_b);
160 
161  return info;
162  }
163 
164  IT solve(full_mat<DT, IT>&&) override { throw std::runtime_error("not implemented"); }
165  };
166 
167  template<index_t IT, char UL = 'L', char ODER = 'R'> using par_dposvx = pposvx<double, IT, UL, ODER>;
168  template<index_t IT, char UL = 'L', char ODER = 'R'> using par_sposvx = pposvx<float, IT, UL, ODER>;
169  template<index_t IT, char UL = 'L', char ODER = 'R'> using par_zposvx = pposvx<complex16, IT, UL, ODER>;
170  template<index_t IT, char UL = 'L', char ODER = 'R'> using par_cposvx = pposvx<complex8, IT, UL, ODER>;
171 } // namespace ezp
172 
173 #endif // PPOSVX_HPP
174 
Definition: full_solver.hpp:24
Definition: pposvx.hpp:45
Definition: traits.hpp:85