ezp
lightweight C++ wrapper for selected distributed solvers for linear systems
full_solver.hpp
1 /*******************************************************************************
2  * Copyright (C) 2025-2026 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
17 
18 #ifndef FULL_SOLVER_HPP
19 #define FULL_SOLVER_HPP
20 
21 #include "abstract_solver.hpp"
22 
23 namespace ezp::detail {
24  template<data_t DT, index_t IT, char ODER = 'R'> class full_solver : public abstract_solver<DT, IT, full_mat<DT, IT>> {
26 
27  struct full_system {
28  IT n{-1}, block{-1}, rows{-1}, cols{-1};
29  desc<IT> desc_a;
30  std::vector<DT> a, b;
31  std::vector<IT> ipiv;
32  };
33 
34  protected:
35  full_system loc;
36 
38 
39  auto init_storage(const IT n) {
40  loc.n = n;
41  loc.block = std::max(IT{1}, static_cast<IT>(std::sqrt(ctx.row_block(loc.n) * ctx.col_block(loc.n))));
42  loc.rows = ctx.rows(loc.n, loc.block);
43  loc.cols = ctx.cols(loc.n, loc.block);
44  loc.desc_a = ctx.desc_l(loc.n, loc.n, loc.block, loc.rows);
45 
46  loc.a.resize(loc.rows * loc.cols);
47  loc.ipiv.resize(loc.rows + loc.block);
48  }
49 
50  auto gather_pivot() {
51  const auto ipiv_l = ctx.desc_l(loc.n, 1, loc.block, loc.rows);
52  const auto ipiv_g = ctx.desc_g(loc.n, 1);
53 
54  std::vector<IT> ipiv;
55  if(0 == ctx.rank) ipiv.resize(loc.n);
56 
57  ctx.copy_to(loc.ipiv.data(), ipiv_l.data(), ipiv.data(), ipiv_g.data());
58 
59  return ipiv;
60  }
61 
62  using base_t::to_full;
63 
64  public:
65  full_solver()
66  : base_t()
67  , ctx(ODER) {}
68 
69  full_solver(const IT rows, const IT cols)
70  : base_t()
71  , ctx(rows, cols, ODER) {}
72 
73  class indexer {
74  IT n, m;
75 
76  public:
77  explicit indexer(const full_mat<DT, IT>& A)
78  : n(A.n_rows)
79  , m(A.n_cols) {}
80 
81  explicit indexer(const IT N)
82  : indexer(N, N) {}
83 
84  indexer(const IT N, const IT M)
85  : n(N)
86  , m(M) {}
87 
88  auto operator()(const IT i, const IT j) const {
89  if(i < 0 || i >= n || j < 0 || j >= n) return IT{-1};
90  return i + j * n;
91  }
92  };
93 
94  using base_t::solve;
95 
96  template<full_container_t AT, full_container_t BT> IT solve(AT&& A, BT&& B) { return solve(to_full(std::forward<AT>(A)), to_full(std::forward<BT>(B))); }
97  template<full_container_t AT> IT solve(AT&& A, full_mat<DT, IT>&& B) { return solve(to_full(std::forward<AT>(A)), std::move(B)); }
98  template<full_container_t BT> IT solve(full_mat<DT, IT>&& A, BT&& B) { return solve(std::move(A), to_full(std::forward<BT>(B))); }
99  };
100 } // namespace ezp::detail
101 
102 #endif // FULL_SOLVER_HPP
Definition: traits.hpp:176
auto desc_l(const IT num_rows, const IT num_cols, const IT row_block, const IT col_block, const IT lead)
Generates a descriptor for a local matrix.
Definition: traits.hpp:297
auto rows(const IT n, const IT nb) const
Computes the number of local rows of the current process.
Definition: traits.hpp:344
auto col_block(const IT n) const
Computes the column block size.
Definition: traits.hpp:332
auto row_block(const IT n) const
Computes the row block size.
Definition: traits.hpp:327
auto cols(const IT n, const IT nb) const
Computes the number of local columns of the current process.
Definition: traits.hpp:356
auto desc_g(const IT num_rows, const IT num_cols)
Generates a descriptor for a global matrix.
Definition: traits.hpp:275
Definition: abstract_solver.hpp:24
Definition: full_solver.hpp:73
Definition: full_solver.hpp:24
Definition: traits.hpp:85