ezp
lightweight C++ wrapper for selected distributed solvers for linear systems
pgbsv.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * Copyright (C) 2025-2026 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
64 #ifndef PGBSV_HPP
65 #define PGBSV_HPP
66 
67 #include "abstract/band_solver.hpp"
68 
69 namespace ezp {
70  template<data_t DT, index_t IT> class pgbsv final : public detail::band_solver<DT, IT, band_mat<DT, IT>> {
72 
73  struct band_system {
74  IT n{-1}, kl{-1}, ku{-1}, lead{-1}, block{-1}, lines{-1};
75  desc<IT> desc1d_a;
76  std::vector<DT> a, b, work;
77  std::vector<IT> ipiv;
78  };
79 
80  band_system loc;
81 
82  auto init_storage(const IT n, const IT kl, const IT ku) {
83  loc.n = n;
84  loc.kl = kl;
85  loc.ku = ku;
86  loc.lead = 2 * (loc.kl + loc.ku) + 1;
87  loc.block = std::max(loc.n / std::max(IT{1}, this->ctx.n_rows - 1) + 1, std::max(loc.kl + loc.ku + 1, this->ctx.row_block(loc.n)));
88  loc.block = std::min(loc.block, loc.n);
89  loc.lines = this->ctx.rows(loc.n, loc.block);
90  loc.desc1d_a = {501, this->trans_ctx.context, loc.n, loc.block, 0, loc.lead, 0, 0, 0};
91 
92  // see: https://github.com/Reference-ScaLAPACK/scalapack/issues/117
93  loc.a.resize(loc.lead * loc.lines + loc.ku);
94  loc.ipiv.resize(std::min(loc.n, loc.lines + loc.kl + loc.ku), -987654);
95  }
96 
97  using base_t::to_band;
98  using base_t::to_full;
99 
100  public:
101  explicit pgbsv(const IT rows = get_env<IT>().size())
102  : base_t(rows) {}
103 
104  class indexer {
105  IT n, kl, ku;
106 
107  public:
108  explicit indexer(const band_mat<DT, IT>& A)
109  : n(A.n_rows)
110  , kl(A.kl)
111  , ku(A.ku) {}
112 
113  indexer(const IT N, const IT KL, const IT KU)
114  : n(N)
115  , kl(KL)
116  , ku(KU) {}
117 
118  auto operator()(const IT i, const IT j) const {
119  if(i < 0 || i >= n || j < 0 || j >= n) return IT{-1};
120  if(i - j > kl || j - i > ku) return IT{-1};
121  return 2 * ku + kl + i + 2 * j * (kl + ku);
122  }
123  };
124 
125  template<band_container_t AT, full_container_t BT> IT solve(AT&& A, BT&& B) { return solve(to_band(std::forward<AT>(A)), to_full(std::forward<BT>(B))); }
126  template<band_container_t AT> IT solve(AT&& A, full_mat<DT, IT>&& B) { return solve(to_band(std::forward<AT>(A)), to_full(std::move(B))); }
127 
128  IT solve(band_mat<DT, IT>&& A, full_mat<DT, IT>&& B) override {
129  if(!this->ctx.is_valid() || !this->trans_ctx.is_valid()) return 0;
130 
131  if(A.n_rows != A.n_cols || A.n_cols != B.n_rows) return -1;
132 
133  init_storage(A.n_cols, A.kl, A.ku);
134 
135  // pretend that A is a full matrix of size (2*(kl+ku)+1) x n
136  // redistribute A to the process grid
137  this->trans_ctx.scatter(
138  full_mat<DT, IT>(loc.lead, loc.n, A.data, A.distributed),
139  this->trans_ctx.desc_g(loc.lead, loc.n),
140  loc.a,
141  this->trans_ctx.desc_l(loc.lead, loc.n, loc.lead, loc.block, loc.lead)
142  );
143 
144  const IT laf = (loc.block + 6 * loc.kl + 13 * loc.ku) * (loc.kl + loc.ku);
145  const IT lwork = std::max(B.n_cols * (loc.block + 2 * loc.kl + 4 * loc.ku), IT{1});
146  loc.work.resize(laf + lwork);
147 
148  IT info{-1};
149  // ReSharper disable CppCStyleCast
150  if constexpr(std::is_same_v<DT, double>) {
151  using E = double;
152  pdgbtrf(&loc.n, &loc.kl, &loc.ku, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), loc.ipiv.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
153  }
154  else if constexpr(std::is_same_v<DT, float>) {
155  using E = float;
156  psgbtrf(&loc.n, &loc.kl, &loc.ku, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), loc.ipiv.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
157  }
158  else if constexpr(std::is_same_v<DT, complex16>) {
159  using E = complex16;
160  pzgbtrf(&loc.n, &loc.kl, &loc.ku, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), loc.ipiv.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
161  }
162  else if constexpr(std::is_same_v<DT, complex8>) {
163  using E = complex8;
164  pcgbtrf(&loc.n, &loc.kl, &loc.ku, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), loc.ipiv.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
165  }
166  // ReSharper restore CppCStyleCast
167 
168  if((info = this->trans_ctx.amx(info)) != 0) return info;
169 
170  return solve(std::move(B));
171  }
172 
173  IT solve(full_mat<DT, IT>&& B) override {
174  static constexpr char TRANS = 'N';
175 
176  if(B.n_rows != loc.n) return -1;
177 
178  if(!this->ctx.is_valid() || !this->trans_ctx.is_valid()) return 0;
179 
180  const auto lead_b = std::max(loc.block, loc.lines);
181 
182  loc.b.resize(lead_b * B.n_cols);
183 
184  const auto full_desc_b = this->ctx.desc_g(B.n_rows, B.n_cols);
185  const auto local_desc_b = this->ctx.desc_l(B.n_rows, B.n_cols, loc.block, B.n_cols, lead_b);
186 
187  this->ctx.scatter(B, full_desc_b, loc.b, local_desc_b);
188 
189  const IT laf = (loc.block + 6 * loc.kl + 13 * loc.ku) * (loc.kl + loc.ku);
190  const IT lwork = std::max(B.n_cols * (loc.block + 2 * loc.kl + 4 * loc.ku), IT{1});
191  loc.work.resize(laf + lwork);
192 
193  desc<IT> desc1d_b{502, this->trans_ctx.context, loc.n, loc.block, 0, lead_b, 0, 0, 0};
194 
195  IT info{-1};
196  // ReSharper disable CppCStyleCast
197  if constexpr(std::is_same_v<DT, double>) {
198  using E = double;
199  pdgbtrs(&TRANS, &loc.n, &loc.kl, &loc.ku, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), loc.ipiv.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
200  }
201  else if constexpr(std::is_same_v<DT, float>) {
202  using E = float;
203  psgbtrs(&TRANS, &loc.n, &loc.kl, &loc.ku, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), loc.ipiv.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
204  }
205  else if constexpr(std::is_same_v<DT, complex16>) {
206  using E = complex16;
207  pzgbtrs(&TRANS, &loc.n, &loc.kl, &loc.ku, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), loc.ipiv.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
208  }
209  else if constexpr(std::is_same_v<DT, complex8>) {
210  using E = complex8;
211  pcgbtrs(&TRANS, &loc.n, &loc.kl, &loc.ku, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), loc.ipiv.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
212  }
213  // ReSharper restore CppCStyleCast
214 
215  if((info = this->trans_ctx.amx(info)) == 0) this->ctx.gather(loc.b, local_desc_b, B, full_desc_b);
216 
217  return info;
218  }
219  };
220 
221  template<index_t IT> using par_dgbsv = pgbsv<double, IT>;
222  template<index_t IT> using par_sgbsv = pgbsv<float, IT>;
223  template<index_t IT> using par_zgbsv = pgbsv<complex16, IT>;
224  template<index_t IT> using par_cgbsv = pgbsv<complex8, IT>;
225 } // namespace ezp
226 
227 #endif // PGBSV_HPP
228 
Definition: band_solver.hpp:24
Definition: pgbsv.hpp:104
Definition: pgbsv.hpp:70
Solver for general band matrices.
Definition: traits.hpp:89
Definition: traits.hpp:85