ezp
lightweight C++ wrapper for selected distributed solvers for linear systems
pdbsv.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * Copyright (C) 2025-2026 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
61 #ifndef PDBSV_HPP
62 #define PDBSV_HPP
63 
64 #include "abstract/band_solver.hpp"
65 
66 namespace ezp {
67  template<data_t DT, index_t IT> class pdbsv final : public detail::band_solver<DT, IT, band_mat<DT, IT>> {
69 
70  struct band_system {
71  IT n{-1}, kl{-1}, ku{-1}, max_klu{-1}, lead{-1}, block{-1}, lines{-1};
72  desc<IT> desc1d_a;
73  std::vector<DT> a, b, work;
74  };
75 
76  band_system loc;
77 
89  auto init_storage(const IT n, const IT kl, const IT ku) {
90  loc.n = n;
91  loc.kl = kl;
92  loc.ku = ku;
93  loc.max_klu = std::max(loc.kl, loc.ku);
94  loc.lead = loc.kl + loc.ku + 1;
95  loc.block = std::max(loc.n / std::max(IT{1}, this->ctx.n_rows - 1) + 1, std::max(2 * loc.max_klu, this->ctx.row_block(loc.n)));
96  loc.block = std::min(loc.block, loc.n);
97  loc.lines = this->ctx.rows(loc.n, loc.block);
98  loc.desc1d_a = {501, this->trans_ctx.context, loc.n, loc.block, 0, loc.lead, 0, 0, 0};
99 
100  loc.a.resize(loc.lead * loc.lines);
101  }
102 
103  using base_t::to_band;
104  using base_t::to_full;
105 
106  public:
107  explicit pdbsv(const IT rows = get_env<IT>().size())
108  : base_t(rows) {}
109 
110  class indexer {
111  IT n, kl, ku;
112 
113  public:
114  explicit indexer(const band_mat<DT, IT>& A)
115  : n(A.n_rows)
116  , kl(A.kl)
117  , ku(A.ku) {}
118 
119  indexer(const IT N, const IT KL, const IT KU)
120  : n(N)
121  , kl(KL)
122  , ku(KU) {}
123 
124  auto operator()(const IT i, const IT j) const {
125  if(i < 0 || i >= n || j < 0 || j >= n) return IT{-1};
126  if(i - j > kl || j - i > ku) return IT{-1};
127  return ku + i + j * (kl + ku);
128  }
129  };
130 
131  template<band_container_t AT, full_container_t BT> IT solve(AT&& A, BT&& B) { return solve(to_band(std::forward<AT>(A)), to_full(std::forward<BT>(B))); }
132  template<band_container_t AT> IT solve(AT&& A, full_mat<DT, IT>&& B) { return solve(to_band(std::forward<AT>(A)), to_full(std::move(B))); }
133 
134  IT solve(band_mat<DT, IT>&& A, full_mat<DT, IT>&& B) override {
135  if(!this->ctx.is_valid() || !this->trans_ctx.is_valid()) return 0;
136 
137  if(A.n_rows != A.n_cols || A.n_cols != B.n_rows) return -1;
138 
139  init_storage(A.n_cols, A.kl, A.ku);
140 
141  // pretend that A is a full matrix of size (2*(kl+ku)+1) x n
142  // redistribute A to the process grid
143  this->trans_ctx.scatter(
144  full_mat<DT, IT>(loc.lead, loc.n, A.data, A.distributed),
145  this->trans_ctx.desc_g(loc.lead, loc.n),
146  loc.a,
147  this->trans_ctx.desc_l(loc.lead, loc.n, loc.lead, loc.block, loc.lead)
148  );
149 
150  const IT laf = loc.block * (loc.kl + loc.ku) + 6 * loc.max_klu * loc.max_klu;
151  const IT lwork = loc.max_klu * std::max(2 * B.n_cols - 1, loc.max_klu);
152  loc.work.resize(laf + lwork);
153 
154  IT info{-1};
155  // ReSharper disable CppCStyleCast
156  if constexpr(std::is_same_v<DT, double>) {
157  using E = double;
158  pddbtrf(&loc.n, &loc.kl, &loc.ku, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
159  }
160  else if constexpr(std::is_same_v<DT, float>) {
161  using E = float;
162  psdbtrf(&loc.n, &loc.kl, &loc.ku, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
163  }
164  else if constexpr(std::is_same_v<DT, complex16>) {
165  using E = complex16;
166  pzdbtrf(&loc.n, &loc.kl, &loc.ku, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
167  }
168  else if constexpr(std::is_same_v<DT, complex8>) {
169  using E = complex8;
170  pcdbtrf(&loc.n, &loc.kl, &loc.ku, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
171  }
172  // ReSharper restore CppCStyleCast
173 
174  if((info = this->trans_ctx.amx(info)) != 0) return info;
175 
176  return solve(std::move(B));
177  }
178 
179  IT solve(full_mat<DT, IT>&& B) override {
180  static constexpr char TRANS = 'N';
181 
182  if(B.n_rows != loc.n) return -1;
183 
184  if(!this->ctx.is_valid() || !this->trans_ctx.is_valid()) return 0;
185 
186  const auto lead_b = std::max(loc.block, loc.lines);
187 
188  loc.b.resize(lead_b * B.n_cols);
189 
190  const auto full_desc_b = this->ctx.desc_g(B.n_rows, B.n_cols);
191  const auto local_desc_b = this->ctx.desc_l(B.n_rows, B.n_cols, loc.block, B.n_cols, lead_b);
192 
193  this->ctx.scatter(B, full_desc_b, loc.b, local_desc_b);
194 
195  const IT laf = loc.block * (loc.kl + loc.ku) + 6 * loc.max_klu * loc.max_klu;
196  const IT lwork = loc.max_klu * std::max(2 * B.n_cols - 1, loc.max_klu);
197  loc.work.resize(laf + lwork);
198 
199  desc<IT> desc1d_b{502, this->trans_ctx.context, loc.n, loc.block, 0, lead_b, 0, 0, 0};
200 
201  IT info{-1};
202  // ReSharper disable CppCStyleCast
203  if constexpr(std::is_same_v<DT, double>) {
204  using E = double;
205  pddbtrs(&TRANS, &loc.n, &loc.kl, &loc.ku, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
206  }
207  else if constexpr(std::is_same_v<DT, float>) {
208  using E = float;
209  psdbtrs(&TRANS, &loc.n, &loc.kl, &loc.ku, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
210  }
211  else if constexpr(std::is_same_v<DT, complex16>) {
212  using E = complex16;
213  pzdbtrs(&TRANS, &loc.n, &loc.kl, &loc.ku, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
214  }
215  else if constexpr(std::is_same_v<DT, complex8>) {
216  using E = complex8;
217  pcdbtrs(&TRANS, &loc.n, &loc.kl, &loc.ku, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
218  }
219  // ReSharper restore CppCStyleCast
220 
221  if((info = this->trans_ctx.amx(info)) == 0) this->ctx.gather(loc.b, local_desc_b, B, full_desc_b);
222 
223  return info;
224  }
225  };
226 
227  template<index_t IT> using par_ddbsv = pdbsv<double, IT>;
228  template<index_t IT> using par_sdbsv = pdbsv<float, IT>;
229  template<index_t IT> using par_zdbsv = pdbsv<complex16, IT>;
230  template<index_t IT> using par_cdbsv = pdbsv<complex8, IT>;
231 } // namespace ezp
232 
233 #endif // PDBSV_HPP
234 
Definition: band_solver.hpp:24
Definition: pdbsv.hpp:110
Definition: pdbsv.hpp:67
Solver for general band matrices.
Definition: traits.hpp:89
Definition: traits.hpp:85