// Transpose selector handed as the first argument to the *gbmv and *gbtrs calls in this listing.
// NOTE(review): 'N' is the BLAS/LAPACK code for "no transpose" — confirm against the reference docs.
36 static constexpr char TRAN =
'N';
// solve_trs overloads: the second argument is taken either as an rvalue reference or as a const reference.
// The int result is stored into INFO by the callers further down this listing.
45 int solve_trs(Mat<T>&, Mat<T>&&);
46 int solve_trs(Mat<T>&,
const Mat<T>&);
// The declarations below are marked override; the base-class declarations are outside this listing.
// Returns an owning pointer typed as MetaMat<T>; presumably a copy of this matrix, going by the name.
51 unique_ptr<MetaMat<T>>
make_copy()
override;
// Takes a single index and returns nothing.
53 void unify(uword)
override;
// Element access by two indices; returns a reference to T.
58 T&
at(uword, uword)
override;
// Product with a Mat<T>; declared const and returns the result by value.
60 Mat<T>
operator*(
const Mat<T>&)
const override;
// Constructor initialiser list (the constructor header is not part of this listing).
// The DenseMat<T> base receives in_size twice and (2 * in_l + in_u + 1) * in_size as its third argument;
// m_rows is initialised to 2 * in_l + in_u + 1.
// NOTE(review): 2 * in_l + in_u + 1 rows per column looks like the LAPACK band layout with extra rows
// reserved for pivoting fill-in — confirm against the ?gbsv documentation.
69 :
DenseMat<
T>(in_size, in_size, (2 * in_l + in_u + 1) * in_size)
73 , m_rows(2 * in_l + in_u + 1) {}
// Writes 1 into memory[s_band + K * m_rows]. Under the index formula used by the accessors in this
// listing (row + s_band + col * (m_rows - 1)) that slot is the diagonal entry (K, K).
// NOTE(review): the enclosing function header is missing from this listing; presumably this is unify — confirm.
79 access::rw(this->memory[s_band +
K * m_rows]) = 1.;
// Zeroes memory[I + s_band + K * (m_rows - 1)], i.e. (row I, column K) under the accessor index formula,
// with the bounds max(K, u_band) - u_band and min(n_rows, K + l_band + 1) handed to suanpan_for.
// NOTE(review): the max(...) - u_band form presumably avoids unsigned wrap-around of K - u_band — confirm K is unsigned.
83 suanpan_for(std::max(
K, u_band) - u_band, std::min(this->n_rows,
K + l_band + 1), [&](
const uword I) { access::rw(this->memory[I + s_band +
K * (m_rows - 1)]) = 0.; });
// Same treatment for (row K, column I), with the bounds max(K, l_band) - l_band and min(n_cols, K + u_band + 1).
84 suanpan_for(std::max(
K, l_band) - l_band, std::min(this->n_cols,
K + u_band + 1), [&](
const uword I) { access::rw(this->memory[
K + s_band + I * (m_rows - 1)]) = 0.; });
// Stored values were modified, so the factored flag is cleared.
86 this->factored =
false;
// Indices outside the band (in_row > in_col + l_band, or in_row + u_band < in_col):
// bin is assigned 0. and the result of that assignment is returned.
90 if(in_row > in_col + l_band || in_row + u_band < in_col)
return bin = 0.;
// Inside the band: read the flat storage at in_row + s_band + in_col * (m_rows - 1).
91 return this->memory[in_row + s_band + in_col * (m_rows - 1)];
// No band check in the visible lines: the factored flag is cleared and the storage slot at
// in_row + s_band + in_col * (m_rows - 1) is returned through access::rw.
// NOTE(review): callers presumably must guarantee that (in_row, in_col) lies inside the band — confirm.
95 this->factored =
false;
96 return access::rw(this->memory[in_row + s_band + in_col * (m_rows - 1)]);
// Out-of-band indices: bin is assigned 0. and returned, leaving the band storage and the factored flag untouched.
100 if(in_row > in_col + l_band || in_row + u_band < in_col)
return bin = 0.;
// In-band access: the factored flag is cleared before the slot is handed out through access::rw.
101 this->factored =
false;
102 return access::rw(this->memory[in_row + s_band + in_col * (m_rows - 1)]);
// Result matrix with the same dimensions as X.
106 Mat<T> Y(arma::size(X));
// Integer copies of the dimensions and band widths, passed by address to the ?gbmv routines; LDA is m_rows.
108 const auto M =
static_cast<int>(this->n_rows);
109 const auto N =
static_cast<int>(this->n_cols);
110 const auto KL =
static_cast<int>(l_band);
111 const auto KU =
static_cast<int>(u_band);
112 const auto LDA =
static_cast<int>(m_rows);
// float path: one sgbmv call per column I of X, writing into column I of Y.
// The matrix pointer is offset by l_band; NOTE(review): presumably this skips the rows reserved for
// factorisation fill-in — confirm. E, ALPHA, BETA and INC are declared on lines missing from this listing.
117 if(std::is_same_v<T, float>) {
119 suanpan_for(0llu, X.n_cols, [&](
const uword I) { arma_fortran(arma_sgbmv)(&TRAN, &M, &N, &KL, &KU, (E*)&ALPHA, (E*)(this->memptr() + l_band), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
// double path: identical argument list, calling dgbmv instead.
121 else if(std::is_same_v<T, double>) {
123 suanpan_for(0llu, X.n_cols, [&](
const uword I) { arma_fortran(arma_dgbmv)(&TRAN, &M, &N, &KL, &KU, (E*)&ALPHA, (E*)(this->memptr() + l_band), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
// When the factored flag is set, the call is forwarded to solve_trs.
130 if(this->factored)
return this->solve_trs(X, B);
// The square-matrix check (throwing invalid_argument) is a lambda passed to suanpan_assert.
132 suanpan_assert([&] {
if(this->n_rows != this->n_cols)
throw invalid_argument(
"requires a square matrix"); });
// Integer arguments for the LAPACK band routines; LDAB is m_rows and LDB is the row count of B.
136 auto N =
static_cast<int>(this->n_rows);
137 const auto KL =
static_cast<int>(l_band);
138 const auto KU =
static_cast<int>(u_band);
139 const auto NRHS =
static_cast<int>(B.n_cols);
140 const auto LDAB =
static_cast<int>(m_rows);
141 const auto LDB =
static_cast<int>(B.n_rows);
// pivot.zeros(N) is called and the factored flag is set before any LAPACK routine is invoked.
// NOTE(review): no visible line resets the flag when INFO is non-zero — confirm that a failed
// factorisation cannot leave factored == true.
142 this->pivot.zeros(
N);
143 this->factored =
true;
// float path: sgbsv is given this matrix's storage, the pivot vector and X.memptr().
// NOTE(review): X is presumably assigned from B on a line missing from this listing; E and INFO are
// also declared off-listing.
145 if(std::is_same_v<T, float>) {
148 arma_fortran(arma_sgbsv)(&
N, &KL, &KU, &NRHS, (
E*)this->memptr(), &LDAB, this->pivot.memptr(), (
E*)X.memptr(), &LDB, &INFO);
// dgbsv counterpart of the call above (its branch condition is not part of this listing).
153 arma_fortran(arma_dgbsv)(&
N, &KL, &KU, &NRHS, (
E*)this->memptr(), &LDAB, this->pivot.memptr(), (
E*)X.memptr(), &LDB, &INFO);
// Remaining path: s_memory receives the result of to_float() and is passed to sgbtrf;
// solve_trs is only run when that call reported INFO == 0.
156 this->s_memory = this->to_float();
157 arma_fortran(arma_sgbtrf)(&
N, &
N, &KL, &KU, this->s_memory.memptr(), &LDAB, this->pivot.memptr(), &INFO);
158 if(0 == INFO) INFO = this->solve_trs(X, B);
// Reports the INFO code; the guard around this call is not part of this listing.
162 suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
// Integer arguments for the LAPACK band routines; LDAB is m_rows and LDB is the row count of B.
170 const auto N =
static_cast<int>(this->n_rows);
171 const auto KL =
static_cast<int>(l_band);
172 const auto KU =
static_cast<int>(u_band);
173 const auto NRHS =
static_cast<int>(B.n_cols);
174 const auto LDAB =
static_cast<int>(m_rows);
175 const auto LDB =
static_cast<int>(B.n_rows);
// float path: sgbtrs with this matrix's storage and the stored pivot vector, operating on X.memptr().
// NOTE(review): X is presumably assigned from B on a line missing from this listing.
177 if(std::is_same_v<T, float>) {
180 arma_fortran(arma_sgbtrs)(&TRAN, &
N, &KL, &KU, &NRHS, (
E*)this->memptr(), &LDAB, this->pivot.memptr(), (
E*)X.memptr(), &LDB, &INFO);
// dgbtrs counterpart (its branch condition is not part of this listing).
185 arma_fortran(arma_dgbtrs)(&TRAN, &
N, &KL, &KU, &NRHS, (
E*)this->memptr(), &LDAB, this->pivot.memptr(), (
E*)X.memptr(), &LDB, &INFO);
// Remaining path: X starts at zero and B is copied into full_residual.
188 X = arma::zeros(B.n_rows, B.n_cols);
190 mat full_residual = B;
// multiplier holds the norm of the current residual and is used below to scale it.
192 auto multiplier =
norm(full_residual);
// At most setting.iterative_refinement passes; counter is declared on a line missing from this listing.
195 while(counter++ < this->setting.iterative_refinement) {
// Stop once the residual norm drops below setting.tolerance; this check runs before the division below.
196 if(multiplier < this->setting.tolerance)
break;
// The residual is divided by its norm and converted to fmat.
198 auto residual = conv_to<fmat>::from(full_residual / multiplier);
// sgbtrs is applied to the scaled residual using s_memory and the stored pivot vector.
200 arma_fortran(arma_sgbtrs)(&TRAN, &
N, &KL, &KU, &NRHS, this->s_memory.memptr(), &LDAB, this->pivot.memptr(), residual.memptr(), &LDB, &INFO);
// The result is converted back to mat and rescaled by multiplier.
// NOTE(review): the line that adds incre to X is not part of this listing — confirm.
203 const mat incre = multiplier * conv_to<mat>::from(residual);
// full_residual is reduced by operator*(incre) and multiplier is refreshed with its norm.
// NOTE(review): both updates are side effects inside the suanpan_debug argument list; if that macro
// can expand to nothing in some build, the refinement loop would stop updating — confirm its definition.
207 suanpan_debug(
"Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(full_residual -= this->
operator*(incre)));
// Reports the INFO code; the guard around this call is not part of this listing.
212 suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
// When the factored flag is set, the call is forwarded to solve_trs with B passed on as an rvalue.
218 if(this->factored)
return this->solve_trs(X, std::forward<Mat<T>>(B));
// The square-matrix check (throwing invalid_argument) is a lambda passed to suanpan_assert.
220 suanpan_assert([&] {
if(this->n_rows != this->n_cols)
throw invalid_argument(
"requires a square matrix"); });
// Integer arguments for the LAPACK band routines; LDAB is m_rows and LDB is the row count of B.
224 auto N =
static_cast<int>(this->n_rows);
225 const auto KL =
static_cast<int>(l_band);
226 const auto KU =
static_cast<int>(u_band);
227 const auto NRHS =
static_cast<int>(B.n_cols);
228 const auto LDAB =
static_cast<int>(m_rows);
229 const auto LDB =
static_cast<int>(B.n_rows);
// pivot.zeros(N) is called and the factored flag is set before any LAPACK routine is invoked.
// NOTE(review): no visible line resets the flag when INFO is non-zero — confirm that a failed
// factorisation cannot leave factored == true.
230 this->pivot.zeros(
N);
231 this->factored =
true;
// float path: sgbsv is given B.memptr() directly as the right-hand-side buffer.
// NOTE(review): the result is presumably moved from B into X on a line missing from this listing.
233 if(std::is_same_v<T, float>) {
235 arma_fortran(arma_sgbsv)(&
N, &KL, &KU, &NRHS, (
E*)this->memptr(), &LDAB, this->pivot.memptr(), (
E*)B.memptr(), &LDB, &INFO);
// dgbsv counterpart of the call above (its branch condition is not part of this listing).
240 arma_fortran(arma_dgbsv)(&
N, &KL, &KU, &NRHS, (
E*)this->memptr(), &LDAB, this->pivot.memptr(), (
E*)B.memptr(), &LDB, &INFO);
// Remaining path: s_memory receives the result of to_float() and is passed to sgbtrf;
// solve_trs is only run when that call reported INFO == 0.
244 this->s_memory = this->to_float();
245 arma_fortran(arma_sgbtrf)(&
N, &
N, &KL, &KU, this->s_memory.memptr(), &LDAB, this->pivot.memptr(), &INFO);
246 if(0 == INFO) INFO = this->solve_trs(X, std::forward<Mat<T>>(B));
// Reports the INFO code; the guard around this call is not part of this listing.
250 suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
// Integer arguments for the LAPACK band routines; LDAB is m_rows and LDB is the row count of B.
258 const auto N =
static_cast<int>(this->n_rows);
259 const auto KL =
static_cast<int>(l_band);
260 const auto KU =
static_cast<int>(u_band);
261 const auto NRHS =
static_cast<int>(B.n_cols);
262 const auto LDAB =
static_cast<int>(m_rows);
263 const auto LDB =
static_cast<int>(B.n_rows);
// float path: sgbtrs with this matrix's storage and the stored pivot vector, operating on B.memptr().
// NOTE(review): the result is presumably moved from B into X on a line missing from this listing.
265 if(std::is_same_v<T, float>) {
267 arma_fortran(arma_sgbtrs)(&TRAN, &
N, &KL, &KU, &NRHS, (
E*)this->memptr(), &LDAB, this->pivot.memptr(), (
E*)B.memptr(), &LDB, &INFO);
// dgbtrs counterpart (its branch condition is not part of this listing).
272 arma_fortran(arma_dgbtrs)(&TRAN, &
N, &KL, &KU, &NRHS, (
E*)this->memptr(), &LDAB, this->pivot.memptr(), (
E*)B.memptr(), &LDB, &INFO);
// Remaining path: X starts at zero and B itself is used as the running residual (no separate copy is made).
276 X = arma::zeros(B.n_rows, B.n_cols);
// multiplier holds the norm of the current residual and is used below to scale it.
278 auto multiplier =
norm(B);
// At most setting.iterative_refinement passes; counter is declared on a line missing from this listing.
281 while(counter++ < this->setting.iterative_refinement) {
// Stop once the residual norm drops below setting.tolerance; this check runs before the division below.
282 if(multiplier < this->setting.tolerance)
break;
// The residual is divided by its norm and converted to fmat.
284 auto residual = conv_to<fmat>::from(B / multiplier);
// sgbtrs is applied to the scaled residual using s_memory and the stored pivot vector.
286 arma_fortran(arma_sgbtrs)(&TRAN, &
N, &KL, &KU, &NRHS, this->s_memory.memptr(), &LDAB, this->pivot.memptr(), residual.memptr(), &LDB, &INFO);
// The result is converted back to mat and rescaled by multiplier.
// NOTE(review): the line that adds incre to X is not part of this listing — confirm.
289 const mat incre = multiplier * conv_to<mat>::from(residual);
// B is reduced by operator*(incre) and multiplier is refreshed with its norm.
// NOTE(review): both updates are side effects inside the suanpan_debug argument list; if that macro
// can expand to nothing in some build, the refinement loop would stop updating — confirm its definition.
293 suanpan_debug(
"Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(B -= this->
operator*(incre)));
// Reports the INFO code; the guard around this call is not part of this listing.
298 suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
A BandMat class that holds matrices.
Definition: BandMat.hpp:35
A DenseMat class that holds matrices.
Definition: DenseMat.hpp:34
double norm(const vec &)
Definition: tensor.cpp:302
#define suanpan_debug(...)
Definition: suanPan.h:295
void suanpan_assert(const std::function< void()> &F)
Definition: suanPan.h:284
#define suanpan_error(...)
Definition: suanPan.h:297
void suanpan_for(const IT start, const IT end, F &&FN)
Definition: utility.h:27