ezp
lightweight C++ wrapper for selected distributed solvers for linear systems
Loading...
Searching...
No Matches
full_solver.hpp
1/*******************************************************************************
2 * Copyright (C) 2025-2026 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
17
18#ifndef FULL_SOLVER_HPP
19#define FULL_SOLVER_HPP
20
21#include "abstract_solver.hpp"
22
23namespace ezp::detail {
24 template<data_t DT, index_t IT, char ODER = 'R'> class full_solver : public abstract_solver<DT, IT, full_mat<DT, IT>> {
26
27 struct full_system {
28 IT n{-1}, block{-1}, rows{-1}, cols{-1};
29 desc<IT> desc_a;
30 std::vector<DT> a, b;
31 std::vector<IT> ipiv;
32 };
33
34 protected:
35 full_system loc;
36
38
39 auto init_storage(const IT n) {
40 loc.n = n;
41 loc.block = std::max(IT{1}, static_cast<IT>(std::sqrt(ctx.row_block(loc.n) * ctx.col_block(loc.n))));
42 loc.rows = ctx.rows(loc.n, loc.block);
43 loc.cols = ctx.cols(loc.n, loc.block);
44 loc.desc_a = ctx.desc_l(loc.n, loc.n, loc.block, loc.rows);
45
46 loc.a.resize(loc.rows * loc.cols);
47 loc.ipiv.resize(loc.rows + loc.block);
48 }
49
50 auto gather_pivot() {
51 const auto ipiv_l = ctx.desc_l(loc.n, 1, loc.block, loc.rows);
52 const auto ipiv_g = ctx.desc_g(loc.n, 1);
53
54 std::vector<IT> ipiv;
55 if(0 == ctx.rank) ipiv.resize(loc.n);
56
57 ctx.copy_to(loc.ipiv.data(), ipiv_l.data(), ipiv.data(), ipiv_g.data());
58
59 return ipiv;
60 }
61
62 using base_t::to_full;
63
64 public:
66 : base_t()
67 , ctx(ODER) {}
68
69 full_solver(const IT rows, const IT cols)
70 : base_t()
71 , ctx(rows, cols, ODER) {}
72
73 class indexer {
74 IT n, m;
75
76 public:
77 explicit indexer(const full_mat<DT, IT>& A)
78 : n(A.n_rows)
79 , m(A.n_cols) {}
80
81 explicit indexer(const IT N)
82 : indexer(N, N) {}
83
84 indexer(const IT N, const IT M)
85 : n(N)
86 , m(M) {}
87
88 auto operator()(const IT i, const IT j) const {
89 if(i < 0 || i >= n || j < 0 || j >= n) return IT{-1};
90 return i + j * n;
91 }
92 };
93
94 using base_t::solve;
95
96 template<full_container_t AT, full_container_t BT> IT solve(AT&& A, BT&& B) { return solve(to_full(std::forward<AT>(A)), to_full(std::forward<BT>(B))); }
97 template<full_container_t AT> IT solve(AT&& A, full_mat<DT, IT>&& B) { return solve(to_full(std::forward<AT>(A)), std::move(B)); }
98 template<full_container_t BT> IT solve(full_mat<DT, IT>&& A, BT&& B) { return solve(std::move(A), to_full(std::forward<BT>(B))); }
99 };
100} // namespace ezp::detail
101
102#endif // FULL_SOLVER_HPP
Definition traits.hpp:176
auto desc_l(const IT num_rows, const IT num_cols, const IT row_block, const IT col_block, const IT lead)
Generates a descriptor for a local matrix.
Definition traits.hpp:297
auto rows(const IT n, const IT nb) const
Computes the number of local rows of the current process.
Definition traits.hpp:344
auto col_block(const IT n) const
Computes the column block size.
Definition traits.hpp:332
auto row_block(const IT n) const
Computes the row block size.
Definition traits.hpp:327
auto cols(const IT n, const IT nb) const
Computes the number of local columns of the current process.
Definition traits.hpp:356
auto desc_g(const IT num_rows, const IT num_cols)
Generates a descriptor for a global matrix.
Definition traits.hpp:275
Definition abstract_solver.hpp:24
Definition full_solver.hpp:73
Definition full_solver.hpp:24
Definition traits.hpp:85