// Integer layout parameters, each brace-initialised to -1.
// NOTE(review): -1 presumably marks "not configured yet"; these are the
// fields read and written through `loc.` elsewhere in this file — confirm.
70 IT n{-1}, kl{-1}, ku{-1}, lead{-1}, block{-1}, lines{-1};
// Buffers of element type DT: `a` and `b` are passed to the p?gbtrf/p?gbtrs
// calls as the local matrix and right-hand side, `work` as their workspace.
72 std::vector<DT> a, b, work;
// Derives the local storage layout in `loc` from loc.n, loc.kl and loc.ku and
// sizes the buffers that depend on it.
// Parameters: n, kl, ku — NOTE(review): presumably the matrix order and the
// lower/upper bandwidths (the caller passes A.n_cols, A.kl, A.ku) — confirm.
78 auto init_storage(
const IT n,
const IT kl,
const IT ku) {
// Leading dimension of the local band storage: 2 * (kl + ku) + 1 rows.
82 loc.lead = 2 * (loc.kl + loc.ku) + 1;
// Block size: the largest of
//   n / max(1, n_rows - 1) + 1,  kl + ku + 1,  and ctx.row_block(n).
83 loc.block = std::max(loc.n / std::max(IT{1}, this->ctx.n_rows - 1) + 1, std::max(loc.kl + loc.ku + 1, this->ctx.row_block(loc.n)));
// A block is never larger than the problem itself.
84 loc.block = std::min(loc.block, loc.n);
// Local extent reported by the context for size n with this block size.
85 loc.lines = this->ctx.rows(loc.n, loc.block);
// Nine-entry descriptor for `a`, built on trans_ctx's context from n, block and lead.
// NOTE(review): 501 is presumably the ScaLAPACK 1-D block-column descriptor
// type expected by the p?gb* routines — confirm against the ScaLAPACK docs.
86 loc.desc1d_a = {501, this->trans_ctx.context, loc.n, loc.block, 0, loc.lead, 0, 0, 0};
// Local band storage: lead * lines entries plus ku extra entries.
89 loc.a.resize(loc.lead * loc.lines + loc.ku);
// Pivot storage: min(n, lines + kl + ku) entries; entries added by this
// resize are filled with the value -987654.
90 loc.ipiv.resize(std::min(loc.n, loc.lines + loc.kl + loc.ku), -987654);
// Bring the base class conversion helpers into scope; they are called
// unqualified by the forwarding solve() template in this class.
93 using base_t::to_band;
94 using base_t::to_full;
// Constructor. `explicit` prevents implicit conversion from a bare IT value.
// `rows` defaults to get_env<IT>().size().
// NOTE(review): `rows` is presumably the number of process rows used for the
// grid, defaulting to the total number of processes — confirm.
97 explicit pgbsv(
const IT rows = get_env<IT>().size())
// Constructs an indexer from three integer parameters.
// NOTE(review): N, KL, KU presumably initialise the n, kl, ku values that
// operator() checks against (size and lower/upper bandwidth) — confirm.
109 indexer(
const IT N,
const IT KL,
const IT KU)
// Maps a zero-based (row i, column j) position to a linear storage index.
// Returns IT{-1} when the position is out of range or outside the band.
114 auto operator()(
const IT i,
const IT j)
const {
// Reject positions outside [0, n) in either direction.
115 if(i < 0 || i >= n || j < 0 || j >= n)
return IT{-1};
// Reject positions more than kl below or more than ku above the diagonal.
116 if(i - j > kl || j - i > ku)
return IT{-1};
// Algebraically equal to j * (2 * (kl + ku) + 1) + (2 * ku + kl) + (i - j):
// a column stride of 2 * (kl + ku) + 1 with the diagonal at offset 2 * ku + kl.
117 return 2 * ku + kl + i + 2 * j * (kl + ku);
// Convenience overload: accepts any band_container_t / full_container_t pair,
// converts them with to_band() / to_full() and forwards to the solve()
// overload that accepts the converted values. Returns that overload's result.
121 template<band_container_t AT, full_container_t BT> IT solve(AT&& A, BT&& B) {
return solve(to_band(A), to_full(B)); }
// Factorisation path for the overload that receives A and B: validate,
// lay out local storage, distribute A, factorise with p?gbtrf, then solve.
// A process whose ctx or trans_ctx is not valid does nothing and reports 0.
124 if(!this->ctx.is_valid() || !this->trans_ctx.is_valid())
return 0;
// A must be square and B must have the same number of rows; otherwise -1.
126 if(A.n_rows != A.n_cols || A.n_rows != B.n_rows)
return -1;
// Compute lead/block/lines and size loc.a / loc.ipiv for this matrix.
128 init_storage(A.n_cols, A.kl, A.ku);
// Distribute through trans_ctx, using a (lead x n) global descriptor and a
// local descriptor with blocking (lead, block) and leading dimension lead.
132 this->trans_ctx.scatter(
134 this->trans_ctx.desc_g(loc.lead, loc.n),
136 this->trans_ctx.desc_l(loc.lead, loc.n, loc.lead, loc.block, loc.lead)
// Size of the auxiliary fill-in storage passed to p?gbtrf as (pointer, laf).
// NOTE(review): appears to match the LAF lower bound documented for
// P?GBTRF, (NB + bwu) * (bwl + bwu) + 6 * (bwl + bwu) * (bwl + 2 * bwu) — confirm.
139 const IT laf = (loc.block + 6 * loc.kl + 13 * loc.ku) * (loc.kl + loc.ku);
// Scratch workspace size, at least 1; it scales with the number of
// right-hand-side columns, the same formula the solve(full_mat&&) overload uses.
140 const IT lwork = std::max(B.n_cols * (loc.block + 2 * loc.kl + 4 * loc.ku), IT{1});
// One buffer holds both regions: [0, laf) is the auxiliary storage,
// [laf, laf + lwork) is the scratch workspace.
141 loc.work.resize(laf + lwork);
// Compile-time dispatch on the element type DT to the matching routine.
145 if constexpr(std::is_same_v<DT, double>) {
147 pdgbtrf(&loc.n, &loc.kl, &loc.ku, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), loc.ipiv.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
149 else if constexpr(std::is_same_v<DT, float>) {
151 psgbtrf(&loc.n, &loc.kl, &loc.ku, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), loc.ipiv.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
153 else if constexpr(std::is_same_v<DT, complex16>) {
155 pzgbtrf(&loc.n, &loc.kl, &loc.ku, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), loc.ipiv.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
157 else if constexpr(std::is_same_v<DT, complex8>) {
159 pcgbtrf(&loc.n, &loc.kl, &loc.ku, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), loc.ipiv.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
// Replace the local status with the value returned by trans_ctx.amx() and
// stop on any non-zero result.
// NOTE(review): amx presumably reduces info across the processes of
// trans_ctx so every process sees the same status — confirm.
163 if((info = this->trans_ctx.amx(info)) != 0)
return info;
// Factorisation succeeded: solve for B using the factors kept in loc.
165 return solve(std::move(B));
// Solves for the right-hand side B using the factors already stored in loc
// (loc.a, loc.ipiv, loc.desc1d_a). On success the solution is gathered back into B.
// Returns -1 on a row-count mismatch, 0 on a process without valid contexts.
168 IT solve(full_mat<DT, IT>&& B) {
// 'N' is passed as the TRANS argument of p?gbtrs.
169 static constexpr char TRANS =
'N';
// B must have as many rows as the stored system.
// NOTE(review): this check runs before the context-validity check, whereas
// the overload taking A and B checks validity first and returns 0 before
// init_storage() runs. A process without a valid context could therefore
// report -1 here instead of 0 — confirm which is intended.
171 if(B.n_rows != loc.n)
return -1;
173 if(!this->ctx.is_valid() || !this->trans_ctx.is_valid())
return 0;
// Leading dimension of the local right-hand-side block.
175 const auto lead_b = std::max(loc.block, loc.lines);
// One column of lead_b entries per right-hand side.
177 loc.b.resize(lead_b * B.n_cols);
// Global and local descriptors for B on ctx. The local one uses blocking
// (block, B.n_cols) and leading dimension lead_b.
179 const auto full_desc_b = this->ctx.desc_g(B.n_rows, B.n_cols);
180 const auto local_desc_b = this->ctx.desc_l(B.n_rows, B.n_cols, loc.block, B.n_cols, lead_b);
// Distribute B into the local buffer loc.b.
182 this->ctx.scatter(B, full_desc_b, loc.b, local_desc_b);
// Same workspace sizing as the factorisation step: [0, laf) is the auxiliary
// storage and [laf, laf + lwork) is the scratch workspace.
184 const IT laf = (loc.block + 6 * loc.kl + 13 * loc.ku) * (loc.kl + loc.ku);
185 const IT lwork = std::max(B.n_cols * (loc.block + 2 * loc.kl + 4 * loc.ku), IT{1});
186 loc.work.resize(laf + lwork);
// Nine-entry descriptor for loc.b, built on trans_ctx's context from n, block and lead_b.
// NOTE(review): 502 is presumably the ScaLAPACK 1-D block-row descriptor
// type used for the right-hand side of the p?gb* routines — confirm.
188 desc<IT> desc1d_b{502, this->trans_ctx.context, loc.n, loc.block, 0, lead_b, 0, 0, 0};
// Compile-time dispatch on the element type DT to the matching routine.
192 if constexpr(std::is_same_v<DT, double>) {
194 pdgbtrs(&TRANS, &loc.n, &loc.kl, &loc.ku, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), loc.ipiv.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
196 else if constexpr(std::is_same_v<DT, float>) {
198 psgbtrs(&TRANS, &loc.n, &loc.kl, &loc.ku, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), loc.ipiv.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
200 else if constexpr(std::is_same_v<DT, complex16>) {
202 pzgbtrs(&TRANS, &loc.n, &loc.kl, &loc.ku, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), loc.ipiv.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
204 else if constexpr(std::is_same_v<DT, complex8>) {
206 pcgbtrs(&TRANS, &loc.n, &loc.kl, &loc.ku, &B.n_cols, (E*)loc.a.data(), &this->ONE, loc.desc1d_a.data(), loc.ipiv.data(), (E*)loc.b.data(), &this->ONE, desc1d_b.data(), (E*)loc.work.data(), &laf, (E*)(loc.work.data() + laf), &lwork, &info);
// Replace the local status with the value returned by trans_ctx.amx().
// The solution is gathered from loc.b back into B only when that value is 0.
210 if((info = this->trans_ctx.amx(info)) == 0) this->ctx.gather(loc.b, local_desc_b, B, full_desc_b);