// Value handed to the p?posvx routines as their FACT argument.
// (ScaLAPACK documents 'E' as: equilibrate the matrix if necessary, then factor it.)
static constexpr char FACT{'E'};
// Forwarded as the UPLO argument of the p?posvx calls. Its value comes from UL,
// which is declared outside this excerpt (presumably a template parameter
// selecting the stored triangle -- confirm).
37 static constexpr char UPLO = UL;
41 static auto ceil(
const IT a,
const IT b) {
return (a + b - 1) / b; }
// Returns the length used for the `work` buffer: the larger of two estimates,
// one named after ppocon and one named after pporfs.
// NOTE(review): the expressions look like the LWORK formulas from the ScaLAPACK
// p?pocon / p?porfs documentation -- confirm against those references.
43 auto compute_lwork() {
// Ceiling quotients over the process-grid dimensions, each clamped to at least 1.
44 const auto ceil_a = std::max(IT{1}, ceil(this->ctx.n_rows - 1, this->ctx.n_cols));
45 const auto ceil_b = std::max(IT{1}, ceil(this->ctx.n_cols - 1, this->ctx.n_rows));
// 2 * (local rows + local cols) plus a block-size dependent term that is never below 2.
46 const auto ppocon_lwork = 2 * (this->loc.rows + this->loc.cols) + std::max(IT{2}, std::max(this->loc.block * ceil_a, this->loc.cols + this->loc.block * ceil_b));
// Three times the local row count.
48 const auto pporfs_lwork = 3 * this->loc.rows;
// A single buffer of this length is handed to the driver, so take the larger estimate.
50 return std::max(ppocon_lwork, pporfs_lwork);
// Returns the length used for the `rwork` buffer that the complex branches
// pass to pzposvx / pcposvx.
// NOTE(review): the formula is not derived in the visible code -- confirm it
// against the LRWORK requirement in the ScaLAPACK documentation.
53 auto compute_lrwork() {
// lcm of the process-grid dimensions divided by the number of process rows.
54 const auto lcmp = std::lcm(this->ctx.n_rows, this->ctx.n_cols) / this->ctx.n_rows;
// local rows + 2 * local cols + block * ceil(ceil(local rows / block) / lcmp).
56 return this->loc.rows + 2 * this->loc.cols + this->loc.block * ceil(ceil(this->loc.rows, this->loc.block), lcmp);
// Extra buffers needed by the expert driver in addition to the base solver storage.
// NOTE(review): an `lwork` member is used elsewhere in this file (exp.lwork) but
// its declaration is not visible in this excerpt.
59 struct expert_system {
// af:   matrix buffer with the same size as loc.a (see init_expert_storage).
// work: workspace of length lwork.
61 std::vector<DT> af, work;
// sr / sc: buffers of work_t<DT>, sized to the local row / column counts and
// passed to p?posvx right after `equed` (presumably scale factors -- confirm).
62 std::vector<work_t<DT>> sr, sc;
// Sizes every buffer held in `exp` from the current local layout (this->loc).
// It reads this->loc.a, this->loc.rows and this->loc.cols, so the base storage
// must already be set up (the visible caller runs init_storage first).
67 auto init_expert_storage() {
// Workspace length, also handed to the driver by pointer later on.
68 exp.lwork = compute_lwork();
// af mirrors the size of the local matrix buffer.
69 exp.af.resize(this->loc.a.size());
70 exp.work.resize(exp.lwork);
// sr gets one entry per local row, sc one entry per local column.
71 exp.sr.resize(this->loc.rows);
72 exp.sc.resize(this->loc.cols);
// NOTE(review): only the parameter list `pposvx(rows, cols)` is visible here;
// whatever follows it (initializer list / body) is not part of this excerpt.
79 pposvx(
const IT rows,
const IT cols)
// NOTE(review): the statements below use `A`, `B`, `equed`, `rcond`, `info` and
// `E`, none of which are declared in the visible text -- presumably they belong
// to a solve routine whose header and local declarations are not shown; confirm.
//
// Visible flow: validate -> allocate -> scatter A and B -> call the type-specific
// p?posvx driver -> reduce info -> optional rescale of x -> gather x into B.
//
// Returns 0 immediately when the context is not valid.
85 if(!this->ctx.is_valid())
return 0;
// A must be square and B must have as many rows as A has columns; otherwise -1.
87 if(A.n_rows != A.n_cols || A.n_cols != B.n_rows)
return -1;
// Base storage first, then the expert buffers that are sized from it.
89 this->init_storage(A.n_rows);
90 init_expert_storage();
// A is scattered into exp.af (not loc.a); exp.af is then the first matrix
// argument of the driver call and loc.a the second.
92 this->ctx.scatter(A, this->ctx.desc_g(A.n_rows, A.n_cols), exp.af, this->loc.desc_a);
// Local buffer for the right-hand side: loc.rows * (local column count of B).
94 const auto loc_cols_b = this->ctx.cols(B.n_cols, this->loc.block);
95 this->loc.b.resize(this->loc.rows * loc_cols_b);
// Global and local descriptors for B, reused for x and for the final gather.
97 const auto full_desc_b = this->ctx.desc_g(B.n_rows, B.n_cols);
98 const auto loc_desc_b = this->ctx.desc_l(B.n_rows, B.n_cols, this->loc.block, this->loc.rows);
100 this->ctx.scatter(B, full_desc_b, this->loc.b, loc_desc_b);
// ferr / berr: one entry per local column of B, filled by the driver.
102 std::vector<work_t<DT>> ferr(loc_cols_b), berr(loc_cols_b);
// x: output buffer with the same size as the local part of B.
105 std::vector<DT> x(this->loc.b.size());
// Compile-time dispatch on the element type. The real branches pass an integer
// workspace (iwork) of length loc.rows.
111 if constexpr(std::is_same_v<DT, double>) {
114 const auto liwork = this->loc.rows;
115 std::vector<IT> iwork(liwork);
117 pdposvx(&FACT, &UPLO, &this->loc.n, &B.n_cols, (E*)exp.af.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), &equed, (E*)exp.sr.data(), (E*)exp.sc.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)x.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)&rcond, (E*)ferr.data(), (E*)berr.data(), (E*)exp.work.data(), &exp.lwork, iwork.data(), &liwork, &info);
119 else if constexpr(std::is_same_v<DT, float>) {
122 const auto liwork = this->loc.rows;
123 std::vector<IT> iwork(liwork);
125 psposvx(&FACT, &UPLO, &this->loc.n, &B.n_cols, (E*)exp.af.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), &equed, (E*)exp.sr.data(), (E*)exp.sc.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)x.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)&rcond, (E*)ferr.data(), (E*)berr.data(), (E*)exp.work.data(), &exp.lwork, iwork.data(), &liwork, &info);
// The complex branches pass a real-valued workspace (rwork) of length
// compute_lrwork(), and cast sr, sc, rcond, ferr, berr to BE = work_t<DT>.
127 else if constexpr(std::is_same_v<DT, complex16>) {
129 using BE = work_t<complex16>;
131 const auto lrwork = compute_lrwork();
132 std::vector<BE> rwork(lrwork);
134 pzposvx(&FACT, &UPLO, &this->loc.n, &B.n_cols, (E*)exp.af.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), &equed, (BE*)exp.sr.data(), (BE*)exp.sc.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)x.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (BE*)&rcond, (BE*)ferr.data(), (BE*)berr.data(), (E*)exp.work.data(), &exp.lwork, rwork.data(), &lrwork, &info);
136 else if constexpr(std::is_same_v<DT, complex8>) {
138 using BE = work_t<complex8>;
140 const auto lrwork = compute_lrwork();
141 std::vector<BE> rwork(lrwork);
143 pcposvx(&FACT, &UPLO, &this->loc.n, &B.n_cols, (E*)exp.af.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), (E*)this->loc.a.data(), &this->ONE, &this->ONE, this->loc.desc_a.data(), &equed, (BE*)exp.sr.data(), (BE*)exp.sc.data(), (E*)this->loc.b.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (E*)x.data(), &this->ONE, &this->ONE, loc_desc_b.data(), (BE*)&rcond, (BE*)ferr.data(), (BE*)berr.data(), (E*)exp.work.data(), &exp.lwork, rwork.data(), &lrwork, &info);
// info is passed through ctx.amx and returned when the result is non-zero.
// (amx is not visible here -- presumably a reduction across processes; confirm.)
147 if((info = this->ctx.amx(info)) != 0)
return info;
// When equed is 'C' or 'B', every entry of x is divided by exp.sc[j].
// NOTE(review): the ScaLAPACK p?posvx documentation lists EQUED as 'N' or 'Y'
// only -- confirm this branch is reachable. Also, j runs over loc.rows while
// exp.sc is resized to loc.cols in init_expert_storage, and x is indexed as
// x[j * loc_cols_b + i] -- confirm the bounds and the intended memory layout.
149 if(equed ==
'C' || equed ==
'B')
150 for(
auto i = 0; i < loc_cols_b; ++i)
151 for(
auto j = 0; j < this->loc.rows; ++j) x[j * loc_cols_b + i] /= exp.sc[j];
// Collect the distributed result x into B, which therefore holds the output.
153 this->ctx.gather(x, loc_desc_b, B, full_desc_b);