// Transposition flag passed as the TRANS argument to the banded BLAS/LAPACK
// routines (?gbmv, ?gbtrs) used by this class; 'N' selects the matrix itself
// (no transpose).
36 static constexpr char TRAN =
'N';
// Solves using an already-computed factorisation (factored band storage plus
// `pivot`); taken as a shortcut when `factored` is set. Declared here and
// defined out of line. Callers treat the int result as a LAPACK-style INFO
// code (0 == success).
43 int solve_trs(Mat<T>&, Mat<T>&&);
// Constructs an in_size x in_size banded matrix.
//   in_size : matrix order (rows == cols)
//   in_l    : lower bandwidth (number of sub-diagonals; passed to LAPACK as KL)
//   in_u    : upper bandwidth (number of super-diagonals; passed to LAPACK as KU)
// Each column stores 2 * in_l + in_u + 1 entries. This is LAPACK general band
// storage: in_l + in_u + 1 rows hold the band, and in_l extra rows at the top
// leave room for the fill-in produced by LU factorisation with partial pivoting.
// LAPACK ?gbtrf/?gbsv require LDAB >= 2*KL + KU + 1, and m_rows is passed as LDAB.
54 BandMat(
const uword in_size,
const uword in_l,
const uword in_u)
55 :
DenseMat<
T>(in_size, in_size, (2 * in_l + in_u + 1) * in_size)
57 , m_rows(2 * in_l + in_u + 1)
// NOTE(review): the condition guarding this warning (source lines 58-60) is not
// visible in this view. Judging by the message, it presumably fires when the
// per-column band storage (m_rows) exceeds the full column height (in_size) — confirm.
61 suanpan_warning(
"The storage requirement for the banded matrix is larger than that of a full matrix, consider using a full/sparse matrix instead.\n");
// Returns a heap-allocated copy of this banded matrix, made with BandMat's copy
// constructor and owned by the caller through the MetaMat<T> interface.
64 unique_ptr<MetaMat<T>>
make_copy()
override {
return std::make_unique<BandMat>(*
this); }
// Read-only element access.
// Entries outside the band (more than l_band below or u_band above the
// diagonal) are not stored. For those, the scratch member `bin` is reset to
// zero and that zero is returned. Because `bin` is written from a const method,
// it is presumably declared `mutable` elsewhere — confirm.
// In-band entries use column-major band storage with leading dimension m_rows:
// A(i, j) sits at row (s_band + i - j) of column j, i.e. linear index
// (s_band + i - j) + j * m_rows == i + s_band + j * (m_rows - 1).
// NOTE(review): this matches LAPACK's ?gbtrf layout only if s_band == l_band + u_band.
// s_band is declared outside this view — confirm.
// No check of in_row/in_col against n_rows/n_cols is visible here.
72 T operator()(
const uword in_row,
const uword in_col)
const override {
73 if(in_row > in_col +
l_band || in_row +
u_band < in_col) [[unlikely]]
return bin =
T(0);
74 return this->
memory[in_row + s_band + in_col * (m_rows - 1)];
// Mutable element access with no band check in the visible lines.
// The caller must ensure (in_row, in_col) lies inside the band. Otherwise the
// index formula (same as operator()) addresses a slot that does not hold
// A(in_row, in_col), and near the last columns it may point past the storage.
// NOTE(review): source line 78 is not visible here — confirm whether it does
// any bookkeeping (e.g. invalidating a previous factorisation).
77 T&
unsafe_at(
const uword in_row,
const uword in_col)
override {
79 return this->
memory[in_row + s_band + in_col * (m_rows - 1)];
// Mutable element access with a band check.
// For positions outside the band, the scratch member `bin` is reset to zero and
// a reference to it is returned. Writes through that reference land in `bin`,
// not in the matrix, so out-of-band assignments are silently discarded.
// NOTE(review): the in-band branch (source lines 84-86) is not visible in this view.
82 T&
at(
const uword in_row,
const uword in_col)
override {
83 if(in_row > in_col +
l_band || in_row +
u_band < in_col) [[unlikely]]
return bin =
T(0);
// Product of this banded matrix with a dense matrix. Declared here (no body)
// and defined out of line.
87 Mat<T>
operator*(
const Mat<T>&)
const override;
// Result has the same shape as X. The constructor always builds a square matrix,
// so M == N here.
93 Mat<T> Y(arma::size(X));
// BLAS ?gbmv arguments: an M x N matrix with KL sub-diagonals and KU
// super-diagonals, stored with leading dimension LDA = m_rows; unit strides.
95 const auto M =
static_cast<int>(this->n_rows);
96 const auto N =
static_cast<int>(this->n_cols);
97 const auto KL =
static_cast<int>(l_band);
98 const auto KU =
static_cast<int>(u_band);
99 const auto LDA =
static_cast<int>(m_rows);
100 constexpr auto INC = 1;
// Each column of X is processed independently:
//   Y(:, I) = ALPHA * A * X(:, I) + BETA * Y(:, I).
// The storage pointer is advanced by l_band rows. The top l_band rows of every
// column are the LU fill-in area, and ?gbmv expects the KL + KU + 1 band rows to
// start at the top of the array it receives.
// E, ALPHA and BETA are defined on source lines not visible here (presumably
// ALPHA = 1 and BETA = 0) — confirm.
104 if constexpr(std::is_same_v<T, float>) {
106 suanpan::for_each(X.n_cols, [&](
const uword I) { arma_fortran(arma_sgbmv)(&TRAN, &M, &N, &KL, &KU, (E*)&ALPHA, (E*)(this->memptr() + l_band), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
// Double-precision branch (its `else` line is not visible): same call via dgbmv.
110 suanpan::for_each(X.n_cols, [&](
const uword I) { arma_fortran(arma_dgbmv)(&TRAN, &M, &N, &KL, &KU, (E*)&ALPHA, (E*)(this->memptr() + l_band), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
// If a factorisation already exists, only the triangular solves are needed.
117 if(this->factored)
return this->solve_trs(X, std::forward<Mat<T>>(B));
// Band LU requires a square matrix (checked through suanpan_assert).
119 suanpan_assert([&] {
if(this->n_rows != this->n_cols)
throw invalid_argument(
"requires a square matrix"); });
// LAPACK arguments: order N, bandwidths KL/KU, NRHS right-hand sides, and
// leading dimensions of the band storage (LDAB = m_rows) and of B.
123 auto N =
static_cast<int>(this->n_rows);
124 const auto KL =
static_cast<int>(l_band);
125 const auto KU =
static_cast<int>(u_band);
126 const auto NRHS =
static_cast<int>(B.n_cols);
127 const auto LDAB =
static_cast<int>(m_rows);
128 const auto LDB =
static_cast<int>(B.n_rows);
// pivot receives the row interchanges of the LU factorisation.
129 this->pivot.zeros(
N);
// NOTE(review): `factored` is set before the factorisation result (INFO) is
// known, and the visible lines never reset it when INFO != 0. If the hidden
// lines don't either, a later call takes the solve_trs shortcut on a failed
// factorisation — confirm.
130 this->factored =
true;
// Single precision: ?gbsv factorises and solves in one call. It overwrites the
// band storage with the LU factors and B with the solution.
132 if constexpr(std::is_same_v<T, float>) {
134 arma_fortran(arma_sgbsv)(&
N, &KL, &KU, &NRHS, (
E*)this->memptr(), &LDAB, this->pivot.memptr(), (
E*)B.memptr(), &LDB, &INFO);
// Double-precision branch (its condition is on a hidden line): same via dgbsv.
139 arma_fortran(arma_dgbsv)(&
N, &KL, &KU, &NRHS, (
E*)this->memptr(), &LDAB, this->pivot.memptr(), (
E*)B.memptr(), &LDB, &INFO);
// Mixed-precision branch (its condition is on a hidden line): factorise a
// single-precision copy of the storage (s_memory) with sgbtrf (M = N = N).
// On success, delegate the solve to solve_trs.
143 this->s_memory = this->to_float();
144 arma_fortran(arma_sgbtrf)(&
N, &
N, &KL, &KU, this->s_memory.memptr(), &LDAB, this->pivot.memptr(), &INFO);
145 if(0 == INFO) INFO = this->solve_trs(X, std::forward<Mat<T>>(B));
// Non-zero INFO from LAPACK: < 0 means an illegal argument; > 0 means
// U(INFO, INFO) is exactly zero (singular matrix).
149 suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
// ?gbtrs arguments. The band storage and `pivot` are expected to hold the LU
// factors and row interchanges from an earlier ?gbtrf/?gbsv call.
157 const auto N =
static_cast<int>(this->n_rows);
158 const auto KL =
static_cast<int>(l_band);
159 const auto KU =
static_cast<int>(u_band);
160 const auto NRHS =
static_cast<int>(B.n_cols);
161 const auto LDAB =
static_cast<int>(m_rows);
162 const auto LDB =
static_cast<int>(B.n_rows);
// Single precision: solve in place in B using the stored factors.
164 if constexpr(std::is_same_v<T, float>) {
166 arma_fortran(arma_sgbtrs)(&TRAN, &
N, &KL, &KU, &NRHS, (
E*)this->memptr(), &LDAB, this->pivot.memptr(), (
E*)B.memptr(), &LDB, &INFO);
// Double-precision branch (its condition is on a hidden line): same via dgbtrs.
171 arma_fortran(arma_dgbtrs)(&TRAN, &
N, &KL, &KU, &NRHS, (
E*)this->memptr(), &LDAB, this->pivot.memptr(), (
E*)B.memptr(), &LDB, &INFO);
// Mixed precision: mixed_trs (defined elsewhere; presumably iterative
// refinement) calls back to solve for a single-precision residual. The callback
// uses the float LU factors in s_memory.
175 this->mixed_trs(X, std::forward<Mat<T>>(B), [&](fmat& residual) {
176 arma_fortran(arma_sgbtrs)(&TRAN, &
N, &KL, &KU, &NRHS, this->s_memory.memptr(), &LDAB, this->pivot.memptr(), residual.memptr(), &LDB, &INFO);
// Reported when INFO is non-zero (the check itself is on a hidden line).
181 suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
A BandMat class that holds banded matrices.
Definition BandMat.hpp:35
T & unsafe_at(const uword in_row, const uword in_col) override
Access element without band check.
Definition BandMat.hpp:77
BandMat(const uword in_size, const uword in_l, const uword in_u)
Definition BandMat.hpp:54
T & at(const uword in_row, const uword in_col) override
Access element with band check; out-of-band positions return a reference to a zeroed scratch value.
Definition BandMat.hpp:82
const uword u_band
Definition BandMat.hpp:47
unique_ptr< MetaMat< T > > make_copy() override
Definition BandMat.hpp:64
T operator()(const uword in_row, const uword in_col) const override
Access element (read-only); returns zero for entries outside the band.
Definition BandMat.hpp:72
const uword l_band
Definition BandMat.hpp:46
void nullify(const uword K) override
Definition BandMat.hpp:66
A DenseMat class that holds matrices.
Definition DenseMat.hpp:39
std::unique_ptr< T[]> memory
Definition DenseMat.hpp:48
void for_each(const IT start, const IT end, F &&FN)
Definition utility.h:28
#define suanpan_warning(...)
Definition suanPan.h:308
void suanpan_assert(const std::function< void()> &F)
Definition suanPan.h:296
#define suanpan_error(...)
Definition suanPan.h:309