ezp
Loading...
Searching...
No Matches
ppbsv.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2025 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
80#ifndef PPBSV_HPP
81#define PPBSV_HPP
82
83#include "abstract/band_solver.hpp"
84
85namespace ezp {
86 template<data_t DT, index_t IT, char UL = 'L'> class ppbsv final : public detail::band_solver<DT, IT, band_symm_mat<DT, IT>> {
87 static constexpr char UPLO = UL;
88
90
91 struct band_system {
92 IT n{-1}, klu{-1}, lead{-1}, block{-1}, lines{-1};
93 desc<IT> desc1d_a;
94 std::vector<DT> a, b, work;
95 };
96
97 band_system loc;
98
99 auto init_storage(const IT n, const IT klu) {
100 loc.n = n;
101 loc.klu = klu;
102 loc.lead = loc.klu + 1;
103 loc.block = std::max(loc.n / std::max(IT{1}, this->ctx.n_rows - 1) + 1, std::max(2 * loc.klu, this->ctx.row_block(loc.n)));
104 loc.block = std::min(loc.block, loc.n);
105 loc.lines = this->ctx.rows(loc.n, loc.block);
106 loc.desc1d_a = {501, this->trans_ctx.context, loc.n, loc.block, 0, loc.lead, 0, 0, 0};
107
108 loc.a.resize(loc.lead * loc.lines);
109 }
110
111 using base_t::to_band_symm;
112 using base_t::to_full;
113
114 public:
115 explicit ppbsv(const IT rows = get_env<IT>().size())
116 : base_t(rows) {}
117
118 class indexer {
119 IT n, klu;
120
121 public:
122 explicit indexer(const band_mat<DT, IT>& A)
123 : n(A.n)
124 , klu(A.klu) {}
125
126 indexer(const IT N, const IT KLU)
127 : n(N)
128 , klu(KLU) {}
129
130 auto operator()(IT i, IT j) const {
131 if(i < 0 || i >= n || j < 0 || j >= n) return IT{-1};
132 if('L' == UL) {
133 if(i < j) std::swap(i, j);
134 if(i - j > klu) return IT{-1};
135 return i + j * klu;
136 }
137 else {
138 if(i > j) std::swap(i, j);
139 if(j - i > klu) return IT{-1};
140 return 2 * j - i + (j + 1) * klu;
141 }
142 }
143 };
144
145 template<band_symm_container_t AT, full_container_t BT> IT solve(AT&& A, BT&& B) { return solve(to_band_symm(A), to_full(B)); }
146
147 IT solve(band_symm_mat<DT, IT>&& A, full_mat<DT, IT>&& B) {
148 if(!this->ctx.is_valid() || !this->trans_ctx.is_valid()) return 0;
149
150 if(A.n_rows != A.n_cols || A.n_rows != B.n_rows) return -1;
151
152 init_storage(A.n_cols, A.klu);
153
154 // pretend that A is a full matrix of size (2*(kl+ku)+1) x n
155 // redistribute A to the process grid
156 this->trans_ctx.scatter(
157 full_mat<DT, IT>(loc.lead, loc.n, A.data, A.distributed),
158 this->trans_ctx.desc_g(loc.lead, loc.n),
159 loc.a,
160 this->trans_ctx.desc_l(loc.lead, loc.n, loc.lead, loc.block, loc.lead)
161 );
162
163 const IT laf = (loc.block + 2 * loc.klu) * loc.klu;
164 const IT lwork = loc.klu * std::max(B.n_cols, loc.klu);
165 loc.work.resize(laf + lwork);
166
167 IT info{-1};
168 // ReSharper disable CppCStyleCast
169 if constexpr(std::is_same_v<DT, double>) {
170 using E = double;
171 pdpbtrf(&UPLO, &loc.n, &loc.klu, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
172 }
173 else if constexpr(std::is_same_v<DT, float>) {
174 using E = float;
175 pspbtrf(&UPLO, &loc.n, &loc.klu, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
176 }
177 else if constexpr(std::is_same_v<DT, complex16>) {
178 using E = complex16;
179 pzpbtrf(&UPLO, &loc.n, &loc.klu, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
180 }
181 else if constexpr(std::is_same_v<DT, complex8>) {
182 using E = complex8;
183 pcpbtrf(&UPLO, &loc.n, &loc.klu, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
184 }
185 // ReSharper restore CppCStyleCast
186
187 if((info = this->trans_ctx.amx(info)) != 0) return info;
188
189 return solve(std::move(B));
190 }
191
192 IT solve(full_mat<DT, IT>&& B) {
193 if(B.n_rows != loc.n) return -1;
194
195 if(!this->ctx.is_valid() || !this->trans_ctx.is_valid()) return 0;
196
197 const auto lead_b = std::max(loc.block, loc.lines);
198
199 loc.b.resize(lead_b * B.n_cols);
200
201 const auto full_desc_b = this->ctx.desc_g(B.n_rows, B.n_cols);
202 const auto local_desc_b = this->ctx.desc_l(B.n_rows, B.n_cols, loc.block, B.n_cols, lead_b);
203
204 this->ctx.scatter(B, full_desc_b, loc.b, local_desc_b);
205
206 const IT laf = (loc.block + 2 * loc.klu) * loc.klu;
207 const IT lwork = loc.klu * std::max(B.n_cols, loc.klu);
208 loc.work.resize(laf + lwork);
209
210 desc<IT> desc1d_b{502, this->trans_ctx.context, loc.n, loc.block, 0, lead_b, 0, 0, 0};
211
212 IT info{-1};
213 // ReSharper disable CppCStyleCast
214 if constexpr(std::is_same_v<DT, double>) {
215 using E = double;
216 pdpbtrs(&UPLO, &loc.n, &loc.klu, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
217 }
218 else if constexpr(std::is_same_v<DT, float>) {
219 using E = float;
220 pspbtrs(&UPLO, &loc.n, &loc.klu, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
221 }
222 else if constexpr(std::is_same_v<DT, complex16>) {
223 using E = complex16;
224 pzpbtrs(&UPLO, &loc.n, &loc.klu, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
225 }
226 else if constexpr(std::is_same_v<DT, complex8>) {
227 using E = complex8;
228 pcpbtrs(&UPLO, &loc.n, &loc.klu, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
229 }
230 // ReSharper restore CppCStyleCast
231
232 if((info = this->trans_ctx.amx(info)) == 0) this->ctx.gather(loc.b, local_desc_b, B, full_desc_b);
233
234 return info;
235 }
236 };
237
238 template<index_t IT, char UL = 'L'> using par_dpbsv = ppbsv<double, IT, UL>;
239 template<index_t IT, char UL = 'L'> using par_spbsv = ppbsv<float, IT, UL>;
240 template<index_t IT, char UL = 'L'> using par_zpbsv = ppbsv<complex16, IT, UL>;
241 template<index_t IT, char UL = 'L'> using par_cpbsv = ppbsv<complex8, IT, UL>;
242 template<index_t IT> using par_dpbsv_u = ppbsv<double, IT, 'U'>;
243 template<index_t IT> using par_spbsv_u = ppbsv<float, IT, 'U'>;
244 template<index_t IT> using par_zpbsv_u = ppbsv<complex16, IT, 'U'>;
245 template<index_t IT> using par_cpbsv_u = ppbsv<complex8, IT, 'U'>;
246} // namespace ezp
247
248#endif // PPBSV_HPP
249
Definition band_solver.hpp:24
Definition ppbsv.hpp:118
Definition ppbsv.hpp:86
Solver for symmetric band positive definite matrices.
Definition abstract_solver.hpp:72
Definition abstract_solver.hpp:83
Definition abstract_solver.hpp:68