ezp
lightweight C++ wrapper for selected distributed solvers for linear systems
pgesv.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * Copyright (C) 2025-2026 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
36 #ifndef PGESV_HPP
37 #define PGESV_HPP
38 
39 #include "abstract/full_solver.hpp"
40 
41 namespace ezp {
42  template<data_t DT, index_t IT, char ODER = 'R'> class pgesv final : public detail::full_solver<DT, IT, ODER> {
44 
45  public:
46  pgesv()
47  : base_t() {}
48 
49  pgesv(const IT rows, const IT cols)
50  : base_t(rows, cols) {}
51 
52  using base_t::solve;
53 
75  auto det(full_mat<DT, IT>&& A) {
76  DT determinant{0};
77 
78  if(!this->ctx.is_valid()) return determinant;
79 
80  this->ctx.gather(this->loc.a, this->loc.desc_a, A, this->ctx.desc_g(this->loc.n, this->loc.n));
81 
82  const auto ipiv = this->gather_pivot();
83 
84  if(0 == this->ctx.rank) {
85  const auto idx = typename base_t::indexer{A};
86  auto swaps = IT{0};
87  determinant = DT{1};
88  for(auto I = IT{0}; I < this->loc.n; ++I) {
89  determinant *= A.data[idx(I, I)];
90  if(ipiv[I] != I + 1) ++swaps;
91  }
92  if(swaps % 2 == 1) determinant = -determinant;
93  }
94 
95  return determinant;
96  }
97 
98  IT solve(full_mat<DT, IT>&& A, full_mat<DT, IT>&& B) override {
99  if(!this->ctx.is_valid()) return 0;
100 
101  if(A.n_rows != A.n_cols || A.n_cols != B.n_rows) return -1;
102 
103  this->init_storage(A.n_rows);
104 
105  this->ctx.scatter(A, this->ctx.desc_g(A.n_rows, A.n_cols), this->loc.a, this->loc.desc_a);
106 
107  IT info{-1};
108  // ReSharper disable CppCStyleCast
109  if constexpr(std::is_same_v<DT, double>) {
110  using E = double;
111  pdgetrf(&this->loc.n, &this->loc.n, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), &info);
112  }
113  else if constexpr(std::is_same_v<DT, float>) {
114  using E = float;
115  psgetrf(&this->loc.n, &this->loc.n, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), &info);
116  }
117  else if constexpr(std::is_same_v<DT, complex16>) {
118  using E = complex16;
119  pzgetrf(&this->loc.n, &this->loc.n, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), &info);
120  }
121  else if constexpr(std::is_same_v<DT, complex8>) {
122  using E = complex8;
123  pcgetrf(&this->loc.n, &this->loc.n, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), &info);
124  }
125  // ReSharper restore CppCStyleCast
126 
127  if((info = this->ctx.amx(info)) != 0) return info;
128 
129  return solve(std::move(B));
130  }
131 
132  IT solve(full_mat<DT, IT>&& B) override {
133  static constexpr char TRANS = 'N';
134 
135  if(B.n_rows != this->loc.n) return -1;
136 
137  if(!this->ctx.is_valid()) return 0;
138 
139  this->loc.b.resize(this->loc.rows * this->ctx.cols(B.n_cols, this->loc.block));
140 
141  const auto full_desc_b = this->ctx.desc_g(B.n_rows, B.n_cols);
142  const auto loc_desc_b = this->ctx.desc_l(B.n_rows, B.n_cols, this->loc.block, this->loc.rows);
143 
144  this->ctx.scatter(B, full_desc_b, this->loc.b, loc_desc_b);
145 
146  IT info{-1};
147  // ReSharper disable CppCStyleCast
148  if constexpr(std::is_same_v<DT, double>) {
149  using E = double;
150  pdgetrs(&TRANS, &this->loc.n, &B.n_cols, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), &info);
151  }
152  else if constexpr(std::is_same_v<DT, float>) {
153  using E = float;
154  psgetrs(&TRANS, &this->loc.n, &B.n_cols, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), &info);
155  }
156  else if constexpr(std::is_same_v<DT, complex16>) {
157  using E = complex16;
158  pzgetrs(&TRANS, &this->loc.n, &B.n_cols, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), &info);
159  }
160  else if constexpr(std::is_same_v<DT, complex8>) {
161  using E = complex8;
162  pcgetrs(&TRANS, &this->loc.n, &B.n_cols, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), &info);
163  }
164  // ReSharper restore CppCStyleCast
165 
166  if((info = this->ctx.amx(info)) == 0) this->ctx.gather(this->loc.b, loc_desc_b, B, full_desc_b);
167 
168  return info;
169  }
170  };
171 
172  template<index_t IT, char ODER = 'R'> using par_dgesv = pgesv<double, IT, ODER>;
173  template<index_t IT, char ODER = 'R'> using par_sgesv = pgesv<float, IT, ODER>;
174  template<index_t IT, char ODER = 'R'> using par_zgesv = pgesv<complex16, IT, ODER>;
175  template<index_t IT, char ODER = 'R'> using par_cgesv = pgesv<complex8, IT, ODER>;
176  template<index_t IT> using par_dgesv_c = pgesv<double, IT, 'C'>;
177  template<index_t IT> using par_sgesv_c = pgesv<float, IT, 'C'>;
178  template<index_t IT> using par_zgesv_c = pgesv<complex16, IT, 'C'>;
179  template<index_t IT> using par_cgesv_c = pgesv<complex8, IT, 'C'>;
180 } // namespace ezp
181 
182 #endif // PGESV_HPP
183 
Definition: full_solver.hpp:73
Definition: full_solver.hpp:24
Definition: pgesv.hpp:42
auto det(full_mat< DT, IT > &&A)
Computes the determinant of a matrix.
Definition: pgesv.hpp:75
Solver for general full matrices.
Definition: traits.hpp:85