40 #include "abstract/full_solver.hpp"
/// FACT argument passed to p?posvx. Per the ScaLAPACK documentation, 'E'
/// means the matrix A is equilibrated if necessary, then copied to AF and
/// factored.
static constexpr char FACT{'E'};
47 static constexpr
char UPLO = UL;
51 static auto ceil(
const IT a,
const IT b) {
return (a + b - 1) / b; }
53 auto compute_lwork() {
54 const auto ceil_a = std::max(IT{1}, ceil(this->ctx.n_rows - 1, this->ctx.n_cols));
55 const auto ceil_b = std::max(IT{1}, ceil(this->ctx.n_cols - 1, this->ctx.n_rows));
56 const auto ppocon_lwork = 2 * (this->loc.rows + this->loc.cols) + std::max(IT{2}, std::max(this->loc.block * ceil_a, this->loc.cols + this->loc.block * ceil_b));
58 const auto pporfs_lwork = 3 * this->loc.rows;
60 return std::max(ppocon_lwork, pporfs_lwork);
63 auto compute_lrwork() {
64 const auto lcmp = std::lcm(this->ctx.n_rows, this->ctx.n_cols) / this->ctx.n_rows;
66 return this->loc.rows + 2 * this->loc.cols + this->loc.block * ceil(ceil(this->loc.rows, this->loc.block), lcmp);
// Storage used only by the expert driver p?posvx (members beyond those
// shown, e.g. the lwork used below as exp.lwork, are not visible here).
69 struct expert_system {
// af: receives the scattered input matrix and is passed in p?posvx's A slot.
// work: the WORK array, sized by compute_lwork().
71 std::vector<DT> af, work;
// sr / sc: row and column scale factors (SR / SC), sized to the local
// row and column counts.
72 std::vector<work_t<DT>> sr, sc;
// Sizes the expert-driver buffers in `exp` from the current local layout.
// Reads this->loc, so in the solve path it is called right after
// init_storage().
77 auto init_expert_storage() {
78 exp.lwork = compute_lwork();
// af mirrors the local block of A, so it gets the same element count as loc.a.
79 exp.af.resize(this->loc.a.size());
80 exp.work.resize(exp.lwork);
// One scale factor per local row (SR) and per local column (SC).
81 exp.sr.resize(this->loc.rows);
82 exp.sc.resize(this->loc.cols);
// Constructor. NOTE(review): rows/cols are presumably the process-grid
// dimensions forwarded to the full_solver base; the initializer list and body
// are not visible in this view — confirm.
89 pposvx(
const IT rows,
const IT cols)
// Body of the two-matrix solve. NOTE(review): the enclosing signature is not
// visible here. A appears to be the global coefficient matrix and B the global
// right-hand sides; B is overwritten with the solution by the final gather.
// Processes for which the context is invalid return 0 without doing any work.
95 if(!this->ctx.is_valid())
return 0;
// A must be square and B must have as many rows as A; otherwise return -1.
97 if(A.n_rows != A.n_cols || A.n_cols != B.n_rows)
return -1;
99 this->init_storage(A.n_rows);
100 init_expert_storage();
// A is distributed into exp.af, which is passed as p?posvx's A argument below.
// loc.a is passed in the AF slot and receives the factor.
102 this->ctx.scatter(A, this->ctx.desc_g(A.n_rows, A.n_cols), exp.af, this->loc.desc_a);
// Distribute B using the same block size and local row count as A.
104 const auto loc_cols_b = this->ctx.cols(B.n_cols, this->loc.block);
105 this->loc.b.resize(this->loc.rows * loc_cols_b);
107 const auto full_desc_b = this->ctx.desc_g(B.n_rows, B.n_cols);
108 const auto loc_desc_b = this->ctx.desc_l(B.n_rows, B.n_cols, this->loc.block, this->loc.rows);
110 this->ctx.scatter(B, full_desc_b, this->loc.b, loc_desc_b);
// Forward / backward error bounds, one entry per local right-hand-side column.
112 std::vector<work_t<DT>> ferr(loc_cols_b), berr(loc_cols_b);
// Local solution X, laid out like loc.b (both are passed with loc_desc_b).
115 std::vector<DT> x(this->loc.b.size());
// Dispatch to the ScaLAPACK routine matching the element type.
121 if constexpr(std::is_same_v<DT, double>) {
// NOTE(review): ScaLAPACK requires LIWORK >= LLD(A) >= 1; if loc.rows can be
// 0 on some process this is too small — confirm.
124 const auto liwork = this->loc.rows;
125 std::vector<IT> iwork(liwork);
127 pdposvx(&FACT, &UPLO, &this->loc.n, &B.n_cols, (E*)exp.af.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), &equed, (E*)exp.sr.data(), (E*)exp.sc.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)x.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)&rcond, (E*)ferr.data(), (E*)berr.data(), (E*)exp.work.data(), &exp.lwork, iwork.data(), &liwork, &info);
129 else if constexpr(std::is_same_v<DT, float>) {
// Same integer workspace sizing as the double branch.
132 const auto liwork = this->loc.rows;
133 std::vector<IT> iwork(liwork);
135 psposvx(&FACT, &UPLO, &this->loc.n, &B.n_cols, (E*)exp.af.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), &equed, (E*)exp.sr.data(), (E*)exp.sc.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)x.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)&rcond, (E*)ferr.data(), (E*)berr.data(), (E*)exp.work.data(), &exp.lwork, iwork.data(), &liwork, &info);
137 else if constexpr(std::is_same_v<DT, complex16>) {
// Complex drivers take a real RWORK instead of IWORK. Scale factors, rcond
// and the error bounds are passed as the real base type BE.
139 using BE = work_t<complex16>;
141 const auto lrwork = compute_lrwork();
142 std::vector<BE> rwork(lrwork);
144 pzposvx(&FACT, &UPLO, &this->loc.n, &B.n_cols, (E*)exp.af.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), &equed, (BE*)exp.sr.data(), (BE*)exp.sc.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)x.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (BE*)&rcond, (BE*)ferr.data(), (BE*)berr.data(), (E*)exp.work.data(), &exp.lwork, rwork.data(), &lrwork, &info);
146 else if constexpr(std::is_same_v<DT, complex8>) {
148 using BE = work_t<complex8>;
150 const auto lrwork = compute_lrwork();
151 std::vector<BE> rwork(lrwork);
153 pcposvx(&FACT, &UPLO, &this->loc.n, &B.n_cols, (E*)exp.af.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), &equed, (BE*)exp.sr.data(), (BE*)exp.sc.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)x.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (BE*)&rcond, (BE*)ferr.data(), (BE*)berr.data(), (E*)exp.work.data(), &exp.lwork, rwork.data(), &lrwork, &info);
// info is combined across processes via ctx.amx (presumably a max-reduction —
// confirm). Any nonzero value is returned before the solution is gathered.
// NOTE(review): per ScaLAPACK, info == N+1 signals RCOND below machine
// precision while X was still computed; that solution is discarded here.
157 if((info = this->ctx.amx(info)) != 0)
return info;
// Collect the distributed solution X back into the global B.
159 this->ctx.gather(x, loc_desc_b, B, full_desc_b);
164 IT solve(
full_mat<DT, IT>&&)
override {
throw std::runtime_error(
"not implemented"); }
Definition: full_solver.hpp:24
Definition: pposvx.hpp:45
Definition: traits.hpp:85