ezp
lightweight C++ wrapper for selected distributed solvers for linear systems
ppbsv.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * Copyright (C) 2025-2026 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
84 #ifndef PPBSV_HPP
85 #define PPBSV_HPP
86 
87 #include "abstract/band_solver.hpp"
88 
89 namespace ezp {
90  template<data_t DT, index_t IT, char UL = 'L'> class ppbsv final : public detail::band_solver<DT, IT, band_symm_mat<DT, IT>> {
91  static constexpr char UPLO = UL;
92 
94 
95  struct band_system {
96  IT n{-1}, klu{-1}, lead{-1}, block{-1}, lines{-1};
97  desc<IT> desc1d_a;
98  std::vector<DT> a, b, work;
99  };
100 
101  band_system loc;
102 
103  auto init_storage(const IT n, const IT klu) {
104  loc.n = n;
105  loc.klu = klu;
106  loc.lead = loc.klu + 1;
107  loc.block = std::max(loc.n / std::max(IT{1}, this->ctx.n_rows - 1) + 1, std::max(2 * loc.klu, this->ctx.row_block(loc.n)));
108  loc.block = std::min(loc.block, loc.n);
109  loc.lines = this->ctx.rows(loc.n, loc.block);
110  loc.desc1d_a = {501, this->trans_ctx.context, loc.n, loc.block, 0, loc.lead, 0, 0, 0};
111 
112  loc.a.resize(loc.lead * loc.lines);
113  }
114 
115  using base_t::to_band_symm;
116  using base_t::to_full;
117 
118  public:
119  explicit ppbsv(const IT rows = get_env<IT>().size())
120  : base_t(rows) {}
121 
122  class indexer {
123  IT n, klu;
124 
125  public:
126  explicit indexer(const band_mat<DT, IT>& A)
127  : n(A.n_rows)
128  , klu(A.klu) {}
129 
130  indexer(const IT N, const IT KLU)
131  : n(N)
132  , klu(KLU) {}
133 
134  auto operator()(IT i, IT j) const {
135  if(i < 0 || i >= n || j < 0 || j >= n) return IT{-1};
136  if('L' == UL) {
137  if(i < j) std::swap(i, j);
138  if(i - j > klu) return IT{-1};
139  return i + j * klu;
140  }
141  else {
142  if(i > j) std::swap(i, j);
143  if(j - i > klu) return IT{-1};
144  return 2 * j - i + (j + 1) * klu;
145  }
146  }
147  };
148 
149  template<band_symm_container_t AT, full_container_t BT> IT solve(AT&& A, BT&& B) { return solve(to_band_symm(std::forward<AT>(A)), to_full(std::forward<BT>(B))); }
150  template<band_symm_container_t AT> IT solve(AT&& A, full_mat<DT, IT>&& B) { return solve(to_band_symm(std::forward<AT>(A)), to_full(std::move(B))); }
151 
152  IT solve(band_symm_mat<DT, IT>&& A, full_mat<DT, IT>&& B) override {
153  if(!this->ctx.is_valid() || !this->trans_ctx.is_valid()) return 0;
154 
155  if(A.n_rows != A.n_cols || A.n_cols != B.n_rows) return -1;
156 
157  init_storage(A.n_cols, A.klu);
158 
159  // pretend that A is a full matrix of size (2*(kl+ku)+1) x n
160  // redistribute A to the process grid
161  this->trans_ctx.scatter(
162  full_mat<DT, IT>(loc.lead, loc.n, A.data, A.distributed),
163  this->trans_ctx.desc_g(loc.lead, loc.n),
164  loc.a,
165  this->trans_ctx.desc_l(loc.lead, loc.n, loc.lead, loc.block, loc.lead)
166  );
167 
168  const IT laf = (loc.block + 2 * loc.klu) * loc.klu;
169  const IT lwork = loc.klu * std::max(B.n_cols, loc.klu);
170  loc.work.resize(laf + lwork);
171 
172  IT info{-1};
173  // ReSharper disable CppCStyleCast
174  if constexpr(std::is_same_v<DT, double>) {
175  using E = double;
176  pdpbtrf(&UPLO, &loc.n, &loc.klu, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
177  }
178  else if constexpr(std::is_same_v<DT, float>) {
179  using E = float;
180  pspbtrf(&UPLO, &loc.n, &loc.klu, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
181  }
182  else if constexpr(std::is_same_v<DT, complex16>) {
183  using E = complex16;
184  pzpbtrf(&UPLO, &loc.n, &loc.klu, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
185  }
186  else if constexpr(std::is_same_v<DT, complex8>) {
187  using E = complex8;
188  pcpbtrf(&UPLO, &loc.n, &loc.klu, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
189  }
190  // ReSharper restore CppCStyleCast
191 
192  if((info = this->trans_ctx.amx(info)) != 0) return info;
193 
194  return solve(std::move(B));
195  }
196 
197  IT solve(full_mat<DT, IT>&& B) override {
198  if(B.n_rows != loc.n) return -1;
199 
200  if(!this->ctx.is_valid() || !this->trans_ctx.is_valid()) return 0;
201 
202  const auto lead_b = std::max(loc.block, loc.lines);
203 
204  loc.b.resize(lead_b * B.n_cols);
205 
206  const auto full_desc_b = this->ctx.desc_g(B.n_rows, B.n_cols);
207  const auto local_desc_b = this->ctx.desc_l(B.n_rows, B.n_cols, loc.block, B.n_cols, lead_b);
208 
209  this->ctx.scatter(B, full_desc_b, loc.b, local_desc_b);
210 
211  const IT laf = (loc.block + 2 * loc.klu) * loc.klu;
212  const IT lwork = loc.klu * std::max(B.n_cols, loc.klu);
213  loc.work.resize(laf + lwork);
214 
215  desc<IT> desc1d_b{502, this->trans_ctx.context, loc.n, loc.block, 0, lead_b, 0, 0, 0};
216 
217  IT info{-1};
218  // ReSharper disable CppCStyleCast
219  if constexpr(std::is_same_v<DT, double>) {
220  using E = double;
221  pdpbtrs(&UPLO, &loc.n, &loc.klu, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
222  }
223  else if constexpr(std::is_same_v<DT, float>) {
224  using E = float;
225  pspbtrs(&UPLO, &loc.n, &loc.klu, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
226  }
227  else if constexpr(std::is_same_v<DT, complex16>) {
228  using E = complex16;
229  pzpbtrs(&UPLO, &loc.n, &loc.klu, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
230  }
231  else if constexpr(std::is_same_v<DT, complex8>) {
232  using E = complex8;
233  pcpbtrs(&UPLO, &loc.n, &loc.klu, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
234  }
235  // ReSharper restore CppCStyleCast
236 
237  if((info = this->trans_ctx.amx(info)) == 0) this->ctx.gather(loc.b, local_desc_b, B, full_desc_b);
238 
239  return info;
240  }
241  };
242 
243  template<index_t IT, char UL = 'L'> using par_dpbsv = ppbsv<double, IT, UL>;
244  template<index_t IT, char UL = 'L'> using par_spbsv = ppbsv<float, IT, UL>;
245  template<index_t IT, char UL = 'L'> using par_zpbsv = ppbsv<complex16, IT, UL>;
246  template<index_t IT, char UL = 'L'> using par_cpbsv = ppbsv<complex8, IT, UL>;
247  template<index_t IT> using par_dpbsv_u = ppbsv<double, IT, 'U'>;
248  template<index_t IT> using par_spbsv_u = ppbsv<float, IT, 'U'>;
249  template<index_t IT> using par_zpbsv_u = ppbsv<complex16, IT, 'U'>;
250  template<index_t IT> using par_cpbsv_u = ppbsv<complex8, IT, 'U'>;
251 } // namespace ezp
252 
253 #endif // PPBSV_HPP
254 
Definition: band_solver.hpp:24
Definition: ppbsv.hpp:122
Definition: ppbsv.hpp:90
Solver for symmetric positive definite band matrices.
Definition: traits.hpp:89
Definition: traits.hpp:85