// UPLO character handed (by address) as the first argument to every p?pbtrf / p?pbtrs call in this file.
// Its value is `UL`, which is declared outside this excerpt (presumably 'U' or 'L' — confirm).
87 static constexpr char UPLO = UL;
// Layout bookkeeping accessed through `loc` elsewhere in this file; -1 presumably marks "not yet set".
//   n     - order passed to the ScaLAPACK calls and compared against B.n_rows
//   klu   - bandwidth passed to the ScaLAPACK calls
//   lead  - klu + 1 (assigned in init_storage)
//   block - blocking factor chosen in init_storage
//   lines - value returned by ctx.rows(n, block)
92 IT n{-1}, klu{-1}, lead{-1}, block{-1}, lines{-1};
// a: local band storage (lead * lines entries); b: local right-hand-side buffer;
// work: scratch space shared by the factorisation and solve calls.
94 std::vector<DT> a, b, work;
// Derives the distributed layout (lead, block, lines, descriptor) for an n-by-n band matrix with
// bandwidth klu and sizes the local band buffer `loc.a`.
// NOTE(review): the statements that copy `n` / `klu` into `loc` are not visible in this excerpt.
99 auto init_storage(
const IT n,
const IT klu) {
// Leading dimension of the band storage: bandwidth plus one (presumably the diagonal plus klu off-diagonals).
102 loc.lead = loc.klu + 1;
// Blocking factor: the largest of n / max(1, n_rows - 1) + 1, twice the bandwidth, and ctx.row_block(n).
// presumably the 2 * klu floor reflects the p?pbtrf requirement NB >= 2 * BW — confirm against the ScaLAPACK docs.
103 loc.block = std::max(loc.n / std::max(IT{1}, this->ctx.n_rows - 1) + 1, std::max(2 * loc.klu, this->ctx.row_block(loc.n)));
// Never use a block larger than the matrix order.
104 loc.block = std::min(loc.block, loc.n);
// Local extent reported by the context for this order and blocking factor.
105 loc.lines = this->ctx.rows(loc.n, loc.block);
// Descriptor with type code 501 built on the trans_ctx context; it carries n, the block size and the leading dimension.
106 loc.desc1d_a = {501, this->trans_ctx.context, loc.n, loc.block, 0, loc.lead, 0, 0, 0};
// Local band storage holds lead * lines entries.
108 loc.a.resize(loc.lead * loc.lines);
// Make the base-class conversion helpers visible by unqualified name; the generic solve() overload calls both.
111 using base_t::to_band_symm;
112 using base_t::to_full;
// Constructor; `rows` defaults to get_env<IT>().size().
// The initialiser list and body are outside this excerpt, so how `rows` is consumed cannot be stated here.
115 explicit ppbsv(
const IT rows = get_env<IT>().size())
// Maps a (row, column) pair of the n-by-n matrix to an offset in packed band storage; IT{-1} means "not stored".
// NOTE(review): lines between the visible statements are missing from this excerpt, so the two
// swap/return groups below presumably belong to the lower and upper variants respectively — confirm.
126 indexer(
const IT N,
const IT KLU)
130 auto operator()(IT i, IT j)
const {
// Reject indices that fall outside the matrix.
131 if(i < 0 || i >= n || j < 0 || j >= n)
return IT{-1};
// First group: order the pair so that i >= j, then reject entries further than klu from the diagonal.
133 if(i < j) std::swap(i, j);
134 if(i - j > klu)
return IT{-1};
// Second group: order the pair so that i <= j, then apply the same bandwidth check.
138 if(i > j) std::swap(i, j);
139 if(j - i > klu)
return IT{-1};
// The expression below equals klu + (j - i) + j * (klu + 1).
// NOTE(review): conventional LAPACK upper band storage places A(i, j), i <= j, at
// klu + (i - j) + j * (klu + 1) = i + (j + 1) * klu; confirm the sign of (j - i) here is intended,
// since for j > i this offset lands past row klu of column j.
140 return 2 * j - i + (j + 1) * klu;
145 template<band_symm_container_t AT, full_container_t BT> IT solve(AT&& A, BT&& B) {
return solve(to_band_symm(A), to_full(B)); }
// Factorise-and-solve path: distributes the band matrix A, calls p?pbtrf, then hands B to the
// right-hand-side overload of solve(). The function header is not visible in this excerpt.
// Processes whose contexts are not valid do nothing and report 0.
148 if(!this->ctx.is_valid() || !this->trans_ctx.is_valid())
return 0;
// A must be square and B must have as many rows as A; otherwise report -1.
150 if(A.n_rows != A.n_cols || A.n_rows != B.n_rows)
return -1;
// Compute the layout and size the local band buffer for this order and bandwidth.
152 init_storage(A.n_cols, A.klu);
// Distribute the band storage: the global view is lead x n, the local view uses `block` as its column blocking.
156 this->trans_ctx.scatter(
158 this->trans_ctx.desc_g(loc.lead, loc.n),
160 this->trans_ctx.desc_l(loc.lead, loc.n, loc.lead, loc.block, loc.lead)
// Workspace sizes: laf is passed as the auxiliary-storage length, lwork as the work-array length.
// lwork takes the larger of B.n_cols and klu so the same buffer size also serves the later p?pbtrs call.
163 const IT laf = (loc.block + 2 * loc.klu) * loc.klu;
164 const IT lwork = loc.klu * std::max(B.n_cols, loc.klu);
// One allocation: the first laf entries are the auxiliary array, the remaining lwork entries the work array.
165 loc.work.resize(laf + lwork);
// Pick the ScaLAPACK factorisation routine that matches the element type DT.
169 if constexpr(std::is_same_v<DT, double>) {
171 pdpbtrf(&UPLO, &loc.n, &loc.klu, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
173 else if constexpr(std::is_same_v<DT, float>) {
175 pspbtrf(&UPLO, &loc.n, &loc.klu, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
177 else if constexpr(std::is_same_v<DT, complex16>) {
179 pzpbtrf(&UPLO, &loc.n, &loc.klu, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
181 else if constexpr(std::is_same_v<DT, complex8>) {
183 pcpbtrf(&UPLO, &loc.n, &loc.klu, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
// amx() presumably combines `info` across processes (TODO confirm); any non-zero result aborts before the solve.
187 if((info = this->trans_ctx.amx(info)) != 0)
return info;
// Factorisation succeeded: continue with the right-hand-side overload.
189 return solve(std::move(B));
// Solves for the right-hand side(s) B using the factorisation already held in `loc.a`.
// Returns -1 when B's row count does not match loc.n and 0 on processes without a valid context;
// the final return statement lies beyond the visible excerpt.
192 IT solve(full_mat<DT, IT>&& B) {
// B must match the order recorded in loc.n.
193 if(B.n_rows != loc.n)
return -1;
// Processes whose contexts are not valid do nothing and report 0.
195 if(!this->ctx.is_valid() || !this->trans_ctx.is_valid())
return 0;
// Leading dimension of the local right-hand-side buffer: the larger of the block size and the local extent.
197 const auto lead_b = std::max(loc.block, loc.lines);
// One column of lead_b entries per right-hand side.
199 loc.b.resize(lead_b * B.n_cols);
// Global and local descriptors used to move B into and out of the local buffer.
201 const auto full_desc_b = this->ctx.desc_g(B.n_rows, B.n_cols);
202 const auto local_desc_b = this->ctx.desc_l(B.n_rows, B.n_cols, loc.block, B.n_cols, lead_b);
// Distribute B into loc.b.
204 this->ctx.scatter(B, full_desc_b, loc.b, local_desc_b);
// Same workspace sizing as the factorisation path; the first laf entries are the auxiliary array.
206 const IT laf = (loc.block + 2 * loc.klu) * loc.klu;
207 const IT lwork = loc.klu * std::max(B.n_cols, loc.klu);
208 loc.work.resize(laf + lwork);
// Descriptor with type code 502 for the right-hand side, built on the trans_ctx context with lead_b as leading dimension.
210 desc<IT> desc1d_b{502, this->trans_ctx.context, loc.n, loc.block, 0, lead_b, 0, 0, 0};
// Pick the ScaLAPACK triangular-solve routine that matches the element type DT.
214 if constexpr(std::is_same_v<DT, double>) {
216 pdpbtrs(&UPLO, &loc.n, &loc.klu, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
218 else if constexpr(std::is_same_v<DT, float>) {
220 pspbtrs(&UPLO, &loc.n, &loc.klu, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
222 else if constexpr(std::is_same_v<DT, complex16>) {
224 pzpbtrs(&UPLO, &loc.n, &loc.klu, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
226 else if constexpr(std::is_same_v<DT, complex8>) {
228 pcpbtrs(&UPLO, &loc.n, &loc.klu, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
// Only when the combined info is zero is the local solution gathered back into B.
// amx() presumably combines `info` across processes — TODO confirm.
232 if((info = this->trans_ctx.amx(info)) == 0) this->ctx.gather(loc.b, local_desc_b, B, full_desc_b);