ezp
Loading...
Searching...
No Matches
pdbsv.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2025 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
57#ifndef PDBSV_HPP
58#define PDBSV_HPP
59
60#include "abstract/band_solver.hpp"
61
62namespace ezp {
63 template<data_t DT, index_t IT> class pdbsv final : public detail::band_solver<DT, IT, band_mat<DT, IT>> {
65
66 struct band_system {
67 IT n{-1}, kl{-1}, ku{-1}, max_klu{-1}, lead{-1}, block{-1}, lines{-1};
68 desc<IT> desc1d_a;
69 std::vector<DT> a, b, work;
70 };
71
72 band_system loc;
73
85 auto init_storage(const IT n, const IT kl, const IT ku) {
86 loc.n = n;
87 loc.kl = kl;
88 loc.ku = ku;
89 loc.max_klu = std::max(loc.kl, loc.ku);
90 loc.lead = loc.kl + loc.ku + 1;
91 loc.block = std::max(loc.n / std::max(IT{1}, this->ctx.n_rows - 1) + 1, std::max(2 * loc.max_klu, this->ctx.row_block(loc.n)));
92 loc.block = std::min(loc.block, loc.n);
93 loc.lines = this->ctx.rows(loc.n, loc.block);
94 loc.desc1d_a = {501, this->trans_ctx.context, loc.n, loc.block, 0, loc.lead, 0, 0, 0};
95
96 loc.a.resize(loc.lead * loc.lines);
97 }
98
99 using base_t::to_band;
100 using base_t::to_full;
101
102 public:
103 explicit pdbsv(const IT rows = get_env<IT>().size())
104 : base_t(rows) {}
105
106 class indexer {
107 IT n, kl, ku;
108
109 public:
110 explicit indexer(const band_mat<DT, IT>& A)
111 : n(A.n)
112 , kl(A.kl)
113 , ku(A.ku) {}
114
115 indexer(const IT N, const IT KL, const IT KU)
116 : n(N)
117 , kl(KL)
118 , ku(KU) {}
119
120 auto operator()(const IT i, const IT j) const {
121 if(i < 0 || i >= n || j < 0 || j >= n) return IT{-1};
122 if(i - j > kl || j - i > ku) return IT{-1};
123 return ku + i + j * (kl + ku);
124 }
125 };
126
127 template<band_container_t AT, full_container_t BT> IT solve(AT&& A, BT&& B) { return solve(to_band(A), to_full(B)); }
128
129 IT solve(band_mat<DT, IT>&& A, full_mat<DT, IT>&& B) {
130 if(!this->ctx.is_valid() || !this->trans_ctx.is_valid()) return 0;
131
132 if(A.n_rows != A.n_cols || A.n_rows != B.n_rows) return -1;
133
134 init_storage(A.n_cols, A.kl, A.ku);
135
136 // pretend that A is a full matrix of size (2*(kl+ku)+1) x n
137 // redistribute A to the process grid
138 this->trans_ctx.scatter(
139 full_mat<DT, IT>(loc.lead, loc.n, A.data, A.distributed),
140 this->trans_ctx.desc_g(loc.lead, loc.n),
141 loc.a,
142 this->trans_ctx.desc_l(loc.lead, loc.n, loc.lead, loc.block, loc.lead)
143 );
144
145 const IT laf = loc.block * (loc.kl + loc.ku) + 6 * std::pow(loc.max_klu, 2);
146 const IT lwork = loc.max_klu * std::max(2 * B.n_cols - 1, loc.max_klu);
147 loc.work.resize(laf + lwork);
148
149 IT info{-1};
150 // ReSharper disable CppCStyleCast
151 if constexpr(std::is_same_v<DT, double>) {
152 using E = double;
153 pddbtrf(&loc.n, &loc.kl, &loc.ku, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
154 }
155 else if constexpr(std::is_same_v<DT, float>) {
156 using E = float;
157 psdbtrf(&loc.n, &loc.kl, &loc.ku, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
158 }
159 else if constexpr(std::is_same_v<DT, complex16>) {
160 using E = complex16;
161 pzdbtrf(&loc.n, &loc.kl, &loc.ku, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
162 }
163 else if constexpr(std::is_same_v<DT, complex8>) {
164 using E = complex8;
165 pcdbtrf(&loc.n, &loc.kl, &loc.ku, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
166 }
167 // ReSharper restore CppCStyleCast
168
169 if((info = this->trans_ctx.amx(info)) != 0) return info;
170
171 return solve(std::move(B));
172 }
173
174 IT solve(full_mat<DT, IT>&& B) {
175 static constexpr char TRANS = 'N';
176
177 if(B.n_rows != loc.n) return -1;
178
179 if(!this->ctx.is_valid() || !this->trans_ctx.is_valid()) return 0;
180
181 const auto lead_b = std::max(loc.block, loc.lines);
182
183 loc.b.resize(lead_b * B.n_cols);
184
185 const auto full_desc_b = this->ctx.desc_g(B.n_rows, B.n_cols);
186 const auto local_desc_b = this->ctx.desc_l(B.n_rows, B.n_cols, loc.block, B.n_cols, lead_b);
187
188 this->ctx.scatter(B, full_desc_b, loc.b, local_desc_b);
189
190 const IT laf = loc.block * (loc.kl + loc.ku) + 6 * std::pow(loc.max_klu, 2);
191 const IT lwork = loc.max_klu * std::max(2 * B.n_cols - 1, loc.max_klu);
192 loc.work.resize(laf + lwork);
193
194 desc<IT> desc1d_b{502, this->trans_ctx.context, loc.n, loc.block, 0, lead_b, 0, 0, 0};
195
196 IT info{-1};
197 // ReSharper disable CppCStyleCast
198 if constexpr(std::is_same_v<DT, double>) {
199 using E = double;
200 pddbtrs(&TRANS, &loc.n, &loc.kl, &loc.ku, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
201 }
202 else if constexpr(std::is_same_v<DT, float>) {
203 using E = float;
204 psdbtrs(&TRANS, &loc.n, &loc.kl, &loc.ku, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
205 }
206 else if constexpr(std::is_same_v<DT, complex16>) {
207 using E = complex16;
208 pzdbtrs(&TRANS, &loc.n, &loc.kl, &loc.ku, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
209 }
210 else if constexpr(std::is_same_v<DT, complex8>) {
211 using E = complex8;
212 pcdbtrs(&TRANS, &loc.n, &loc.kl, &loc.ku, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
213 }
214 // ReSharper restore CppCStyleCast
215
216 if((info = this->trans_ctx.amx(info)) == 0) this->ctx.gather(loc.b, local_desc_b, B, full_desc_b);
217
218 return info;
219 }
220 };
221
222 template<index_t IT> using par_ddbsv = pdbsv<double, IT>;
223 template<index_t IT> using par_sdbsv = pdbsv<float, IT>;
224 template<index_t IT> using par_zdbsv = pdbsv<complex16, IT>;
225 template<index_t IT> using par_cdbsv = pdbsv<complex8, IT>;
226} // namespace ezp
227
228#endif // PDBSV_HPP
229
Definition band_solver.hpp:24
Definition pdbsv.hpp:106
Definition pdbsv.hpp:63
Solver for general band matrices.
Definition abstract_solver.hpp:72
Definition abstract_solver.hpp:68