ezp
Loading...
Searching...
No Matches
pgesv.hpp
Go to the documentation of this file.
/*******************************************************************************
 * Copyright (C) 2025 Theodore Chang
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 ******************************************************************************/
27#ifndef PGESV_HPP
28#define PGESV_HPP
29
30#include "abstract/full_solver.hpp"
31
32namespace ezp {
33 template<data_t DT, index_t IT, char ODER = 'R'> class pgesv final : public detail::full_solver<DT, IT, ODER> {
35
36 public:
37 pgesv()
38 : base_t() {}
39
40 pgesv(const IT rows, const IT cols)
41 : base_t(rows, cols) {}
42
43 using base_t::solve;
44
45 IT solve(full_mat<DT, IT>&& A, full_mat<DT, IT>&& B) override {
46 if(!this->ctx.is_valid()) return 0;
47
48 if(A.n_rows != A.n_cols || A.n_rows != B.n_rows) return -1;
49
50 this->init_storage(A.n_rows);
51
52 this->ctx.scatter(A, this->ctx.desc_g(A.n_rows, A.n_cols), this->loc.a, this->loc.desc_a);
53
54 IT info{-1};
55 // ReSharper disable CppCStyleCast
56 if constexpr(std::is_same_v<DT, double>) {
57 using E = double;
58 pdgetrf(&this->loc.n, &this->loc.n, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), &info);
59 }
60 else if constexpr(std::is_same_v<DT, float>) {
61 using E = float;
62 psgetrf(&this->loc.n, &this->loc.n, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), &info);
63 }
64 else if constexpr(std::is_same_v<DT, complex16>) {
65 using E = complex16;
66 pzgetrf(&this->loc.n, &this->loc.n, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), &info);
67 }
68 else if constexpr(std::is_same_v<DT, complex8>) {
69 using E = complex8;
70 pcgetrf(&this->loc.n, &this->loc.n, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), &info);
71 }
72 // ReSharper restore CppCStyleCast
73
74 if((info = this->ctx.amx(info)) != 0) return info;
75
76 return solve(std::move(B));
77 }
78
79 IT solve(full_mat<DT, IT>&& B) override {
80 static constexpr char TRANS = 'N';
81
82 if(B.n_rows != this->loc.n) return -1;
83
84 if(!this->ctx.is_valid()) return 0;
85
86 this->loc.b.resize(this->loc.rows * this->ctx.cols(B.n_cols, this->loc.block));
87
88 const auto full_desc_b = this->ctx.desc_g(B.n_rows, B.n_cols);
89 const auto loc_desc_b = this->ctx.desc_l(B.n_rows, B.n_cols, this->loc.block, this->loc.rows);
90
91 this->ctx.scatter(B, full_desc_b, this->loc.b, loc_desc_b);
92
93 IT info{-1};
94 // ReSharper disable CppCStyleCast
95 if constexpr(std::is_same_v<DT, double>) {
96 using E = double;
97 pdgetrs(&TRANS, &this->loc.n, &B.n_cols, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), &info);
98 }
99 else if constexpr(std::is_same_v<DT, float>) {
100 using E = float;
101 psgetrs(&TRANS, &this->loc.n, &B.n_cols, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), &info);
102 }
103 else if constexpr(std::is_same_v<DT, complex16>) {
104 using E = complex16;
105 pzgetrs(&TRANS, &this->loc.n, &B.n_cols, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), &info);
106 }
107 else if constexpr(std::is_same_v<DT, complex8>) {
108 using E = complex8;
109 pcgetrs(&TRANS, &this->loc.n, &B.n_cols, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), &info);
110 }
111 // ReSharper restore CppCStyleCast
112
113 if((info = this->ctx.amx(info)) == 0) this->ctx.gather(this->loc.b, loc_desc_b, B, full_desc_b);
114
115 return info;
116 }
117 };
118
119 template<index_t IT, char ODER = 'R'> using par_dgesv = pgesv<double, IT, ODER>;
120 template<index_t IT, char ODER = 'R'> using par_sgesv = pgesv<float, IT, ODER>;
121 template<index_t IT, char ODER = 'R'> using par_zgesv = pgesv<complex16, IT, ODER>;
122 template<index_t IT, char ODER = 'R'> using par_cgesv = pgesv<complex8, IT, ODER>;
123 template<index_t IT> using par_dgesv_c = pgesv<double, IT, 'C'>;
124 template<index_t IT> using par_sgesv_c = pgesv<float, IT, 'C'>;
125 template<index_t IT> using par_zgesv_c = pgesv<complex16, IT, 'C'>;
126 template<index_t IT> using par_cgesv_c = pgesv<complex8, IT, 'C'>;
127} // namespace ezp
128
129#endif // PGESV_HPP
130
Definition full_solver.hpp:24
Definition pgesv.hpp:33
Definition abstract_solver.hpp:68