ezp
Loading...
Searching...
No Matches
pposv.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2025 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
27#ifndef PPOSV_HPP
28#define PPOSV_HPP
29
30#include "abstract/full_solver.hpp"
31
32namespace ezp {
33 template<data_t DT, index_t IT, char UL = 'L', char ODER = 'R'> class pposv final : public detail::full_solver<DT, IT, ODER> {
34 static constexpr char UPLO = UL;
35
37
38 public:
39 pposv()
40 : base_t() {}
41
42 pposv(const IT rows, const IT cols)
43 : base_t(rows, cols) {}
44
45 using base_t::solve;
46
47 IT solve(full_mat<DT, IT>&& A, full_mat<DT, IT>&& B) override {
48 if(!this->ctx.is_valid()) return 0;
49
50 if(A.n_rows != A.n_cols || A.n_rows != B.n_rows) return -1;
51
52 this->init_storage(A.n_rows);
53
54 this->ctx.scatter(A, this->ctx.desc_g(A.n_rows, A.n_cols), this->loc.a, this->loc.desc_a);
55
56 IT info{-1};
57 // ReSharper disable CppCStyleCast
58 if constexpr(std::is_same_v<DT, double>) {
59 using E = double;
60 pdpotrf(&UPLO, &this->loc.n, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), &info);
61 }
62 else if constexpr(std::is_same_v<DT, float>) {
63 using E = float;
64 pspotrf(&UPLO, &this->loc.n, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), &info);
65 }
66 else if constexpr(std::is_same_v<DT, complex16>) {
67 using E = complex16;
68 pzpotrf(&UPLO, &this->loc.n, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), &info);
69 }
70 else if constexpr(std::is_same_v<DT, complex8>) {
71 using E = complex8;
72 pcpotrf(&UPLO, &this->loc.n, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), &info);
73 }
74 // ReSharper restore CppCStyleCast
75
76 if((info = this->ctx.amx(info)) != 0) return info;
77
78 return solve(std::move(B));
79 }
80
81 IT solve(full_mat<DT, IT>&& B) override {
82 if(B.n_rows != this->loc.n) return -1;
83
84 if(!this->ctx.is_valid()) return 0;
85
86 this->loc.b.resize(this->loc.rows * this->ctx.cols(B.n_cols, this->loc.block));
87
88 const auto full_desc_b = this->ctx.desc_g(B.n_rows, B.n_cols);
89 const auto loc_desc_b = this->ctx.desc_l(B.n_rows, B.n_cols, this->loc.block, this->loc.rows);
90
91 this->ctx.scatter(B, full_desc_b, this->loc.b, loc_desc_b);
92
93 IT info{-1};
94 // ReSharper disable CppCStyleCast
95 if constexpr(std::is_same_v<DT, double>) {
96 using E = double;
97 pdpotrs(&UPLO, &this->loc.n, &B.n_cols, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), &info);
98 }
99 else if constexpr(std::is_same_v<DT, float>) {
100 using E = float;
101 pspotrs(&UPLO, &this->loc.n, &B.n_cols, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), &info);
102 }
103 else if constexpr(std::is_same_v<DT, complex16>) {
104 using E = complex16;
105 pzpotrs(&UPLO, &this->loc.n, &B.n_cols, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), &info);
106 }
107 else if constexpr(std::is_same_v<DT, complex8>) {
108 using E = complex8;
109 pcpotrs(&UPLO, &this->loc.n, &B.n_cols, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), &info);
110 }
111 // ReSharper restore CppCStyleCast
112
113 if((info = this->ctx.amx(info)) == 0) this->ctx.gather(this->loc.b, loc_desc_b, B, full_desc_b);
114
115 return info;
116 }
117 };
118
119 template<index_t IT, char UL = 'L', char ODER = 'R'> using par_dposv = pposv<double, IT, UL, ODER>;
120 template<index_t IT, char UL = 'L', char ODER = 'R'> using par_sposv = pposv<float, IT, UL, ODER>;
121 template<index_t IT, char UL = 'L', char ODER = 'R'> using par_zposv = pposv<double, IT, UL, ODER>;
122 template<index_t IT, char UL = 'L', char ODER = 'R'> using par_cposv = pposv<float, IT, UL, ODER>;
123 template<index_t IT, char UL = 'L'> using par_dposv_c = pposv<double, IT, UL, 'C'>;
124 template<index_t IT, char UL = 'L'> using par_sposv_c = pposv<float, IT, UL, 'C'>;
125 template<index_t IT, char UL = 'L'> using par_zposv_c = pposv<double, IT, UL, 'C'>;
126 template<index_t IT, char UL = 'L'> using par_cposv_c = pposv<float, IT, UL, 'C'>;
127 template<index_t IT, char ODER = 'R'> using par_dposv_u = pposv<double, IT, 'U', ODER>;
128 template<index_t IT, char ODER = 'R'> using par_sposv_u = pposv<float, IT, 'U', ODER>;
129 template<index_t IT, char ODER = 'R'> using par_zposv_u = pposv<double, IT, 'U', ODER>;
130 template<index_t IT, char ODER = 'R'> using par_cposv_u = pposv<float, IT, 'U', ODER>;
131 template<index_t IT> using par_dposv_uc = pposv<double, IT, 'U', 'C'>;
132 template<index_t IT> using par_sposv_uc = pposv<float, IT, 'U', 'C'>;
133 template<index_t IT> using par_zposv_uc = pposv<double, IT, 'U', 'C'>;
134 template<index_t IT> using par_cposv_uc = pposv<float, IT, 'U', 'C'>;
135} // namespace ezp
136
137#endif // PPOSV_HPP
138
Definition full_solver.hpp:24
Definition pposv.hpp:33
Definition abstract_solver.hpp:68