ezp
lightweight C++ wrapper for selected distributed solvers for linear systems
Loading...
Searching...
No Matches
pposv.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2025-2026 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
37#ifndef PPOSV_HPP
38#define PPOSV_HPP
39
40#include "abstract/full_solver.hpp"
41
42namespace ezp {
43 template<data_t DT, index_t IT, char UL = 'L', char ODER = 'R'> class pposv final : public detail::full_solver<DT, IT, ODER> {
44 static constexpr char UPLO = UL;
45
47
48 public:
49 pposv()
50 : base_t() {}
51
52 pposv(const IT rows, const IT cols)
53 : base_t(rows, cols) {}
54
55 using base_t::solve;
56
57 IT solve(full_mat<DT, IT>&& A, full_mat<DT, IT>&& B) override {
58 if(!this->ctx.is_valid()) return 0;
59
60 if(A.n_rows != A.n_cols || A.n_cols != B.n_rows) return -1;
61
62 this->init_storage(A.n_rows);
63
64 this->ctx.scatter(A, this->ctx.desc_g(A.n_rows, A.n_cols), this->loc.a, this->loc.desc_a);
65
66 IT info{-1};
67 // ReSharper disable CppCStyleCast
68 if constexpr(std::is_same_v<DT, double>) {
69 using E = double;
70 pdpotrf(&UPLO, &this->loc.n, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), &info);
71 }
72 else if constexpr(std::is_same_v<DT, float>) {
73 using E = float;
74 pspotrf(&UPLO, &this->loc.n, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), &info);
75 }
76 else if constexpr(std::is_same_v<DT, complex16>) {
77 using E = complex16;
78 pzpotrf(&UPLO, &this->loc.n, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), &info);
79 }
80 else if constexpr(std::is_same_v<DT, complex8>) {
81 using E = complex8;
82 pcpotrf(&UPLO, &this->loc.n, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), &info);
83 }
84 // ReSharper restore CppCStyleCast
85
86 if((info = this->ctx.amx(info)) != 0) return info;
87
88 return solve(std::move(B));
89 }
90
91 IT solve(full_mat<DT, IT>&& B) override {
92 if(B.n_rows != this->loc.n) return -1;
93
94 if(!this->ctx.is_valid()) return 0;
95
96 this->loc.b.resize(this->loc.rows * this->ctx.cols(B.n_cols, this->loc.block));
97
98 const auto full_desc_b = this->ctx.desc_g(B.n_rows, B.n_cols);
99 const auto loc_desc_b = this->ctx.desc_l(B.n_rows, B.n_cols, this->loc.block, this->loc.rows);
100
101 this->ctx.scatter(B, full_desc_b, this->loc.b, loc_desc_b);
102
103 IT info{-1};
104 // ReSharper disable CppCStyleCast
105 if constexpr(std::is_same_v<DT, double>) {
106 using E = double;
107 pdpotrs(&UPLO, &this->loc.n, &B.n_cols, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), &info);
108 }
109 else if constexpr(std::is_same_v<DT, float>) {
110 using E = float;
111 pspotrs(&UPLO, &this->loc.n, &B.n_cols, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), &info);
112 }
113 else if constexpr(std::is_same_v<DT, complex16>) {
114 using E = complex16;
115 pzpotrs(&UPLO, &this->loc.n, &B.n_cols, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), &info);
116 }
117 else if constexpr(std::is_same_v<DT, complex8>) {
118 using E = complex8;
119 pcpotrs(&UPLO, &this->loc.n, &B.n_cols, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), &info);
120 }
121 // ReSharper restore CppCStyleCast
122
123 if((info = this->ctx.amx(info)) == 0) this->ctx.gather(this->loc.b, loc_desc_b, B, full_desc_b);
124
125 return info;
126 }
127 };
128
129 template<index_t IT, char UL = 'L', char ODER = 'R'> using par_dposv = pposv<double, IT, UL, ODER>;
130 template<index_t IT, char UL = 'L', char ODER = 'R'> using par_sposv = pposv<float, IT, UL, ODER>;
131 template<index_t IT, char UL = 'L', char ODER = 'R'> using par_zposv = pposv<double, IT, UL, ODER>;
132 template<index_t IT, char UL = 'L', char ODER = 'R'> using par_cposv = pposv<float, IT, UL, ODER>;
133 template<index_t IT, char UL = 'L'> using par_dposv_c = pposv<double, IT, UL, 'C'>;
134 template<index_t IT, char UL = 'L'> using par_sposv_c = pposv<float, IT, UL, 'C'>;
135 template<index_t IT, char UL = 'L'> using par_zposv_c = pposv<double, IT, UL, 'C'>;
136 template<index_t IT, char UL = 'L'> using par_cposv_c = pposv<float, IT, UL, 'C'>;
137 template<index_t IT, char ODER = 'R'> using par_dposv_u = pposv<double, IT, 'U', ODER>;
138 template<index_t IT, char ODER = 'R'> using par_sposv_u = pposv<float, IT, 'U', ODER>;
139 template<index_t IT, char ODER = 'R'> using par_zposv_u = pposv<double, IT, 'U', ODER>;
140 template<index_t IT, char ODER = 'R'> using par_cposv_u = pposv<float, IT, 'U', ODER>;
141 template<index_t IT> using par_dposv_uc = pposv<double, IT, 'U', 'C'>;
142 template<index_t IT> using par_sposv_uc = pposv<float, IT, 'U', 'C'>;
143 template<index_t IT> using par_zposv_uc = pposv<double, IT, 'U', 'C'>;
144 template<index_t IT> using par_cposv_uc = pposv<float, IT, 'U', 'C'>;
145} // namespace ezp
146
147#endif // PPOSV_HPP
148
Definition full_solver.hpp:24
Definition pposv.hpp:43
Definition traits.hpp:85