39 #include "abstract/full_solver.hpp"
49 pgesv(
const IT rows,
const IT cols)
78 if(!this->ctx.is_valid())
return determinant;
80 this->ctx.gather(this->loc.a, this->loc.desc_a, A, this->ctx.desc_g(this->loc.n, this->loc.n));
82 const auto ipiv = this->gather_pivot();
84 if(0 == this->ctx.rank) {
88 for(
auto I = IT{0}; I < this->loc.n; ++I) {
89 determinant *= A.data[idx(I, I)];
90 if(ipiv[I] != I + 1) ++swaps;
92 if(swaps % 2 == 1) determinant = -determinant;
99 if(!this->ctx.is_valid())
return 0;
101 if(A.n_rows != A.n_cols || A.n_cols != B.n_rows)
return -1;
103 this->init_storage(A.n_rows);
105 this->ctx.scatter(A, this->ctx.desc_g(A.n_rows, A.n_cols), this->loc.a, this->loc.desc_a);
109 if constexpr(std::is_same_v<DT, double>) {
111 pdgetrf(&this->loc.n, &this->loc.n, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), &info);
113 else if constexpr(std::is_same_v<DT, float>) {
115 psgetrf(&this->loc.n, &this->loc.n, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), &info);
117 else if constexpr(std::is_same_v<DT, complex16>) {
119 pzgetrf(&this->loc.n, &this->loc.n, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), &info);
121 else if constexpr(std::is_same_v<DT, complex8>) {
123 pcgetrf(&this->loc.n, &this->loc.n, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), &info);
127 if((info = this->ctx.amx(info)) != 0)
return info;
129 return solve(std::move(B));
132 IT solve(full_mat<DT, IT>&& B)
override {
133 static constexpr
char TRANS =
'N';
135 if(B.n_rows != this->loc.n)
return -1;
137 if(!this->ctx.is_valid())
return 0;
139 this->loc.b.resize(this->loc.rows * this->ctx.cols(B.n_cols, this->loc.block));
141 const auto full_desc_b = this->ctx.desc_g(B.n_rows, B.n_cols);
142 const auto loc_desc_b = this->ctx.desc_l(B.n_rows, B.n_cols, this->loc.block, this->loc.rows);
144 this->ctx.scatter(B, full_desc_b, this->loc.b, loc_desc_b);
148 if constexpr(std::is_same_v<DT, double>) {
150 pdgetrs(&TRANS, &this->loc.n, &B.n_cols, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), &info);
152 else if constexpr(std::is_same_v<DT, float>) {
154 psgetrs(&TRANS, &this->loc.n, &B.n_cols, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), &info);
156 else if constexpr(std::is_same_v<DT, complex16>) {
158 pzgetrs(&TRANS, &this->loc.n, &B.n_cols, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), &info);
160 else if constexpr(std::is_same_v<DT, complex8>) {
162 pcgetrs(&TRANS, &this->loc.n, &B.n_cols, (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), this->loc.ipiv.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), &info);
166 if((info = this->ctx.amx(info)) == 0) this->ctx.gather(this->loc.b, loc_desc_b, B, full_desc_b);
Definition: full_solver.hpp:73
Definition: full_solver.hpp:24
auto det(full_mat< DT, IT > &&A)
Computes the determinant of a matrix.
Definition: pgesv.hpp:75
Solver for general full matrices.
Definition: traits.hpp:85