ezp
Lightweight C++ wrapper for selected distributed solvers of linear systems
Loading...
Searching...
No Matches
pgesv.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2025-2026 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
36#ifndef PGESV_HPP
37#define PGESV_HPP
38
39#include "abstract/full_solver.hpp"
40
41namespace ezp {
42 template<data_t DT, index_t IT, char ODER = 'R'> class pgesv final : public detail::full_solver<DT, IT, ODER> {
44
45 public:
46 pgesv()
47 : base_t() {}
48
49 pgesv(const IT rows, const IT cols)
50 : base_t(rows, cols) {}
51
52 using base_t::solve;
53
76 DT determinant{0};
77
78 if(!this->ctx.is_valid()) return determinant;
79
80 this->ctx.gather(this->loc.a, this->loc.desc_a, A, this->ctx.desc_g(this->loc.n, this->loc.n));
81
82 const auto ipiv = this->gather_pivot();
83
84 if(0 == this->ctx.rank) {
85 const auto idx = typename base_t::indexer{A};
86 auto swaps = IT{0};
87 determinant = DT{1};
88 for(auto I = IT{0}; I < this->loc.n; ++I) {
89 determinant *= A.data[idx(I, I)];
90 if(ipiv[I] != I + 1) ++swaps;
91 }
92 if(swaps % 2 == 1) determinant = -determinant;
93 }
94
95 return determinant;
96 }
97
98 IT solve(full_mat<DT, IT>&& A, full_mat<DT, IT>&& B) override {
99 if(!this->ctx.is_valid()) return 0;
100
101 if(A.n_rows != A.n_cols || A.n_cols != B.n_rows) return -1;
102
103 this->init_storage(A.n_rows);
104
105 this->ctx.scatter(A, this->ctx.desc_g(A.n_rows, A.n_cols), this->loc.a, this->loc.desc_a);
106
107 IT info{-1};
108 // ReSharper disable CppCStyleCast
109 if constexpr(std::is_same_v<DT, double>) {
110 using E = double;
111 pdgetrf(&this->loc.n, &this->loc.n, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), &info);
112 }
113 else if constexpr(std::is_same_v<DT, float>) {
114 using E = float;
115 psgetrf(&this->loc.n, &this->loc.n, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), &info);
116 }
117 else if constexpr(std::is_same_v<DT, complex16>) {
118 using E = complex16;
119 pzgetrf(&this->loc.n, &this->loc.n, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), &info);
120 }
121 else if constexpr(std::is_same_v<DT, complex8>) {
122 using E = complex8;
123 pcgetrf(&this->loc.n, &this->loc.n, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), &info);
124 }
125 // ReSharper restore CppCStyleCast
126
127 if((info = this->ctx.amx(info)) != 0) return info;
128
129 return solve(std::move(B));
130 }
131
132 IT solve(full_mat<DT, IT>&& B) override {
133 static constexpr char TRANS = 'N';
134
135 if(B.n_rows != this->loc.n) return -1;
136
137 if(!this->ctx.is_valid()) return 0;
138
139 this->loc.b.resize(this->loc.rows * this->ctx.cols(B.n_cols, this->loc.block));
140
141 const auto full_desc_b = this->ctx.desc_g(B.n_rows, B.n_cols);
142 const auto loc_desc_b = this->ctx.desc_l(B.n_rows, B.n_cols, this->loc.block, this->loc.rows);
143
144 this->ctx.scatter(B, full_desc_b, this->loc.b, loc_desc_b);
145
146 IT info{-1};
147 // ReSharper disable CppCStyleCast
148 if constexpr(std::is_same_v<DT, double>) {
149 using E = double;
150 pdgetrs(&TRANS, &this->loc.n, &B.n_cols, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), &info);
151 }
152 else if constexpr(std::is_same_v<DT, float>) {
153 using E = float;
154 psgetrs(&TRANS, &this->loc.n, &B.n_cols, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), &info);
155 }
156 else if constexpr(std::is_same_v<DT, complex16>) {
157 using E = complex16;
158 pzgetrs(&TRANS, &this->loc.n, &B.n_cols, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), &info);
159 }
160 else if constexpr(std::is_same_v<DT, complex8>) {
161 using E = complex8;
162 pcgetrs(&TRANS, &this->loc.n, &B.n_cols, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), &info);
163 }
164 // ReSharper restore CppCStyleCast
165
166 if((info = this->ctx.amx(info)) == 0) this->ctx.gather(this->loc.b, loc_desc_b, B, full_desc_b);
167
168 return info;
169 }
170 };
171
172 template<index_t IT, char ODER = 'R'> using par_dgesv = pgesv<double, IT, ODER>;
173 template<index_t IT, char ODER = 'R'> using par_sgesv = pgesv<float, IT, ODER>;
174 template<index_t IT, char ODER = 'R'> using par_zgesv = pgesv<complex16, IT, ODER>;
175 template<index_t IT, char ODER = 'R'> using par_cgesv = pgesv<complex8, IT, ODER>;
176 template<index_t IT> using par_dgesv_c = pgesv<double, IT, 'C'>;
177 template<index_t IT> using par_sgesv_c = pgesv<float, IT, 'C'>;
178 template<index_t IT> using par_zgesv_c = pgesv<complex16, IT, 'C'>;
179 template<index_t IT> using par_cgesv_c = pgesv<complex8, IT, 'C'>;
180} // namespace ezp
181
182#endif // PGESV_HPP
183
Definition full_solver.hpp:73
Definition full_solver.hpp:24
Definition pgesv.hpp:42
auto det(full_mat< DT, IT > &&A)
Computes the determinant of a matrix.
Definition pgesv.hpp:75
Solver for general full matrices.
Definition traits.hpp:85