ezp
lightweight C++ wrapper for selected distributed solvers for linear systems
Loading...
Searching...
No Matches
pgbsv.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2025-2026 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
64#ifndef PGBSV_HPP
65#define PGBSV_HPP
66
67#include "abstract/band_solver.hpp"
68
69namespace ezp {
70 template<data_t DT, index_t IT> class pgbsv final : public detail::band_solver<DT, IT, band_mat<DT, IT>> {
72
73 struct band_system {
74 IT n{-1}, kl{-1}, ku{-1}, lead{-1}, block{-1}, lines{-1};
75 desc<IT> desc1d_a;
76 std::vector<DT> a, b, work;
77 std::vector<IT> ipiv;
78 };
79
80 band_system loc;
81
82 auto init_storage(const IT n, const IT kl, const IT ku) {
83 loc.n = n;
84 loc.kl = kl;
85 loc.ku = ku;
86 loc.lead = 2 * (loc.kl + loc.ku) + 1;
87 loc.block = std::max(loc.n / std::max(IT{1}, this->ctx.n_rows - 1) + 1, std::max(loc.kl + loc.ku + 1, this->ctx.row_block(loc.n)));
88 loc.block = std::min(loc.block, loc.n);
89 loc.lines = this->ctx.rows(loc.n, loc.block);
90 loc.desc1d_a = {501, this->trans_ctx.context, loc.n, loc.block, 0, loc.lead, 0, 0, 0};
91
92 // see: https://github.com/Reference-ScaLAPACK/scalapack/issues/117
93 loc.a.resize(loc.lead * loc.lines + loc.ku);
94 loc.ipiv.resize(std::min(loc.n, loc.lines + loc.kl + loc.ku), -987654);
95 }
96
97 using base_t::to_band;
98 using base_t::to_full;
99
100 public:
101 explicit pgbsv(const IT rows = get_env<IT>().size())
102 : base_t(rows) {}
103
104 class indexer {
105 IT n, kl, ku;
106
107 public:
108 explicit indexer(const band_mat<DT, IT>& A)
109 : n(A.n_rows)
110 , kl(A.kl)
111 , ku(A.ku) {}
112
113 indexer(const IT N, const IT KL, const IT KU)
114 : n(N)
115 , kl(KL)
116 , ku(KU) {}
117
118 auto operator()(const IT i, const IT j) const {
119 if(i < 0 || i >= n || j < 0 || j >= n) return IT{-1};
120 if(i - j > kl || j - i > ku) return IT{-1};
121 return 2 * ku + kl + i + 2 * j * (kl + ku);
122 }
123 };
124
125 template<band_container_t AT, full_container_t BT> IT solve(AT&& A, BT&& B) { return solve(to_band(std::forward<AT>(A)), to_full(std::forward<BT>(B))); }
126 template<band_container_t AT> IT solve(AT&& A, full_mat<DT, IT>&& B) { return solve(to_band(std::forward<AT>(A)), to_full(std::move(B))); }
127
128 IT solve(band_mat<DT, IT>&& A, full_mat<DT, IT>&& B) override {
129 if(!this->ctx.is_valid() || !this->trans_ctx.is_valid()) return 0;
130
131 if(A.n_rows != A.n_cols || A.n_cols != B.n_rows) return -1;
132
133 init_storage(A.n_cols, A.kl, A.ku);
134
135 // pretend that A is a full matrix of size (2*(kl+ku)+1) x n
136 // redistribute A to the process grid
137 this->trans_ctx.scatter(
138 full_mat<DT, IT>(loc.lead, loc.n, A.data, A.distributed),
139 this->trans_ctx.desc_g(loc.lead, loc.n),
140 loc.a,
141 this->trans_ctx.desc_l(loc.lead, loc.n, loc.lead, loc.block, loc.lead)
142 );
143
144 const IT laf = (loc.block + 6 * loc.kl + 13 * loc.ku) * (loc.kl + loc.ku);
145 const IT lwork = std::max(B.n_cols * (loc.block + 2 * loc.kl + 4 * loc.ku), IT{1});
146 loc.work.resize(laf + lwork);
147
148 IT info{-1};
149 // ReSharper disable CppCStyleCast
150 if constexpr(std::is_same_v<DT, double>) {
151 using E = double;
152 pdgbtrf(&loc.n, &loc.kl, &loc.ku, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), loc.ipiv.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
153 }
154 else if constexpr(std::is_same_v<DT, float>) {
155 using E = float;
156 psgbtrf(&loc.n, &loc.kl, &loc.ku, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), loc.ipiv.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
157 }
158 else if constexpr(std::is_same_v<DT, complex16>) {
159 using E = complex16;
160 pzgbtrf(&loc.n, &loc.kl, &loc.ku, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), loc.ipiv.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
161 }
162 else if constexpr(std::is_same_v<DT, complex8>) {
163 using E = complex8;
164 pcgbtrf(&loc.n, &loc.kl, &loc.ku, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), loc.ipiv.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
165 }
166 // ReSharper restore CppCStyleCast
167
168 if((info = this->trans_ctx.amx(info)) != 0) return info;
169
170 return solve(std::move(B));
171 }
172
173 IT solve(full_mat<DT, IT>&& B) override {
174 static constexpr char TRANS = 'N';
175
176 if(B.n_rows != loc.n) return -1;
177
178 if(!this->ctx.is_valid() || !this->trans_ctx.is_valid()) return 0;
179
180 const auto lead_b = std::max(loc.block, loc.lines);
181
182 loc.b.resize(lead_b * B.n_cols);
183
184 const auto full_desc_b = this->ctx.desc_g(B.n_rows, B.n_cols);
185 const auto local_desc_b = this->ctx.desc_l(B.n_rows, B.n_cols, loc.block, B.n_cols, lead_b);
186
187 this->ctx.scatter(B, full_desc_b, loc.b, local_desc_b);
188
189 const IT laf = (loc.block + 6 * loc.kl + 13 * loc.ku) * (loc.kl + loc.ku);
190 const IT lwork = std::max(B.n_cols * (loc.block + 2 * loc.kl + 4 * loc.ku), IT{1});
191 loc.work.resize(laf + lwork);
192
193 desc<IT> desc1d_b{502, this->trans_ctx.context, loc.n, loc.block, 0, lead_b, 0, 0, 0};
194
195 IT info{-1};
196 // ReSharper disable CppCStyleCast
197 if constexpr(std::is_same_v<DT, double>) {
198 using E = double;
199 pdgbtrs(&TRANS, &loc.n, &loc.kl, &loc.ku, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), loc.ipiv.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
200 }
201 else if constexpr(std::is_same_v<DT, float>) {
202 using E = float;
203 psgbtrs(&TRANS, &loc.n, &loc.kl, &loc.ku, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), loc.ipiv.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
204 }
205 else if constexpr(std::is_same_v<DT, complex16>) {
206 using E = complex16;
207 pzgbtrs(&TRANS, &loc.n, &loc.kl, &loc.ku, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), loc.ipiv.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
208 }
209 else if constexpr(std::is_same_v<DT, complex8>) {
210 using E = complex8;
211 pcgbtrs(&TRANS, &loc.n, &loc.kl, &loc.ku, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), loc.ipiv.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
212 }
213 // ReSharper restore CppCStyleCast
214
215 if((info = this->trans_ctx.amx(info)) == 0) this->ctx.gather(loc.b, local_desc_b, B, full_desc_b);
216
217 return info;
218 }
219 };
220
221 template<index_t IT> using par_dgbsv = pgbsv<double, IT>;
222 template<index_t IT> using par_sgbsv = pgbsv<float, IT>;
223 template<index_t IT> using par_zgbsv = pgbsv<complex16, IT>;
224 template<index_t IT> using par_cgbsv = pgbsv<complex8, IT>;
225} // namespace ezp
226
227#endif // PGBSV_HPP
228
Definition band_solver.hpp:24
Definition pgbsv.hpp:104
Definition pgbsv.hpp:70
Solver for general band matrices.
Definition traits.hpp:89
Definition traits.hpp:85