30#ifndef BANDSYMMMAT_HPP
31#define BANDSYMMMAT_HPP
36 static constexpr char UPLO =
'L';
43 int solve_trs(Mat<T>&, Mat<T>&&);
44 int solve_trs(Mat<T>&,
const Mat<T>&);
49 unique_ptr<MetaMat<T>>
make_copy()
override;
51 void unify(uword)
override;
56 T&
at(uword, uword)
override;
58 Mat<T>
operator*(
const Mat<T>&)
const override;
67 :
DenseMat<
T>(in_size, in_size, (in_bandwidth + 1) * in_size)
69 , m_rows(in_bandwidth + 1) {}
75 access::rw(this->memory[
K * m_rows]) = 1.;
79 suanpan_for(std::max(band,
K) - band,
K, [&](
const uword I) { access::rw(this->memory[
K - I + I * m_rows]) = 0.; });
80 const auto t_factor =
K * m_rows -
K;
81 suanpan_for(
K, std::min(this->n_rows,
K + band + 1), [&](
const uword I) { access::rw(this->memory[I + t_factor]) = 0.; });
83 this->factored =
false;
87 if(in_row > band + in_col)
return bin = 0.;
88 return this->memory[in_row > in_col ? in_row - in_col + in_col * m_rows : in_col - in_row + in_row * m_rows];
92 this->factored =
false;
93 return access::rw(this->memory[in_row - in_col + in_col * m_rows]);
97 if(in_row > band + in_col || in_row < in_col) [[unlikely]]
return bin = 0.;
98 this->factored =
false;
99 return access::rw(this->memory[in_row - in_col + in_col * m_rows]);
103 Mat<T> Y(arma::size(X));
105 const auto N =
static_cast<int>(this->n_cols);
106 const auto K =
static_cast<int>(band);
107 const auto LDA =
static_cast<int>(m_rows);
112 if(std::is_same_v<T, float>) {
114 suanpan_for(0llu, X.n_cols, [&](
const uword I) { arma_fortran(arma_ssbmv)(&UPLO, &N, &K, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
116 else if(std::is_same_v<T, double>) {
118 suanpan_for(0llu, X.n_cols, [&](
const uword I) { arma_fortran(arma_dsbmv)(&UPLO, &N, &K, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
125 if(this->factored)
return this->solve_trs(X, B);
127 const auto N =
static_cast<int>(this->n_rows);
128 const auto KD =
static_cast<int>(band);
129 const auto NRHS =
static_cast<int>(B.n_cols);
130 const auto LDAB =
static_cast<int>(m_rows);
131 const auto LDB =
static_cast<int>(B.n_rows);
134 this->factored =
true;
136 if(std::is_same_v<T, float>) {
139 arma_fortran(arma_spbsv)(&UPLO, &
N, &KD, &NRHS, (
E*)this->memptr(), &LDAB, (
E*)X.memptr(), &LDB, &INFO);
144 arma_fortran(arma_dpbsv)(&UPLO, &
N, &KD, &NRHS, (
E*)this->memptr(), &LDAB, (
E*)X.memptr(), &LDB, &INFO);
147 this->s_memory = this->to_float();
148 arma_fortran(arma_spbtrf)(&UPLO, &
N, &KD, this->s_memory.memptr(), &LDAB, &INFO);
149 if(0 == INFO) INFO = this->solve_trs(X, B);
153 suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
159 const auto N =
static_cast<int>(this->n_rows);
160 const auto KD =
static_cast<int>(band);
161 const auto NRHS =
static_cast<int>(B.n_cols);
162 const auto LDAB =
static_cast<int>(m_rows);
163 const auto LDB =
static_cast<int>(B.n_rows);
166 if(std::is_same_v<T, float>) {
169 arma_fortran(arma_spbtrs)(&UPLO, &
N, &KD, &NRHS, (
E*)this->memptr(), &LDAB, (
E*)X.memptr(), &LDB, &INFO);
174 arma_fortran(arma_dpbtrs)(&UPLO, &
N, &KD, &NRHS, (
E*)this->memptr(), &LDAB, (
E*)X.memptr(), &LDB, &INFO);
177 X = arma::zeros(B.n_rows, B.n_cols);
179 mat full_residual = B;
181 auto multiplier =
norm(full_residual);
184 while(counter++ < this->setting.iterative_refinement) {
185 if(multiplier < this->setting.tolerance)
break;
187 auto residual = conv_to<fmat>::from(full_residual / multiplier);
189 arma_fortran(arma_spbtrs)(&UPLO, &
N, &KD, &NRHS, this->s_memory.memptr(), &LDAB, residual.memptr(), &LDB, &INFO);
192 const mat incre = multiplier * conv_to<mat>::from(residual);
196 suanpan_debug(
"Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(full_residual -= this->
operator*(incre)));
201 suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
207 if(this->factored)
return this->solve_trs(X, std::forward<Mat<T>>(B));
209 const auto N =
static_cast<int>(this->n_rows);
210 const auto KD =
static_cast<int>(band);
211 const auto NRHS =
static_cast<int>(B.n_cols);
212 const auto LDAB =
static_cast<int>(m_rows);
213 const auto LDB =
static_cast<int>(B.n_rows);
216 this->factored =
true;
218 if(std::is_same_v<T, float>) {
220 arma_fortran(arma_spbsv)(&UPLO, &
N, &KD, &NRHS, (
E*)this->memptr(), &LDAB, (
E*)B.memptr(), &LDB, &INFO);
225 arma_fortran(arma_dpbsv)(&UPLO, &
N, &KD, &NRHS, (
E*)this->memptr(), &LDAB, (
E*)B.memptr(), &LDB, &INFO);
229 this->s_memory = this->to_float();
230 arma_fortran(arma_spbtrf)(&UPLO, &
N, &KD, this->s_memory.memptr(), &LDAB, &INFO);
231 if(0 == INFO) INFO = this->solve_trs(X, std::forward<Mat<T>>(B));
235 suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
241 const auto N =
static_cast<int>(this->n_rows);
242 const auto KD =
static_cast<int>(band);
243 const auto NRHS =
static_cast<int>(B.n_cols);
244 const auto LDAB =
static_cast<int>(m_rows);
245 const auto LDB =
static_cast<int>(B.n_rows);
248 if(std::is_same_v<T, float>) {
250 arma_fortran(arma_spbtrs)(&UPLO, &
N, &KD, &NRHS, (
E*)this->memptr(), &LDAB, (
E*)B.memptr(), &LDB, &INFO);
255 arma_fortran(arma_dpbtrs)(&UPLO, &
N, &KD, &NRHS, (
E*)this->memptr(), &LDAB, (
E*)B.memptr(), &LDB, &INFO);
259 X = arma::zeros(B.n_rows, B.n_cols);
261 auto multiplier =
norm(B);
264 while(counter++ < this->setting.iterative_refinement) {
265 if(multiplier < this->setting.tolerance)
break;
267 auto residual = conv_to<fmat>::from(B / multiplier);
269 arma_fortran(arma_spbtrs)(&UPLO, &
N, &KD, &NRHS, this->s_memory.memptr(), &LDAB, residual.memptr(), &LDB, &INFO);
272 const mat incre = multiplier * conv_to<mat>::from(residual);
276 suanpan_debug(
"Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(B -= this->
operator*(incre)));
281 suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
A BandSymmMat class that holds matrices.
Definition: BandSymmMat.hpp:35
A DenseMat class that holds matrices.
Definition: DenseMat.hpp:34
double norm(const vec &)
Definition: tensor.cpp:302
#define suanpan_debug(...)
Definition: suanPan.h:295
#define suanpan_error(...)
Definition: suanPan.h:297
void suanpan_for(const IT start, const IT end, F &&FN)
Definition: utility.h:27