30#ifndef BANDMATSPIKE_HPP
31#define BANDMATSPIKE_HPP
33#include <feast/spike.h>
// Transpose flag handed to BLAS/SPIKE: 'N' selects the non-transposed operation.
// It is passed by address to the ?gbmv matrix products and to the ?spike_gbtrs solves.
37 static constexpr char TRAN =
'N';
// SPIKE parameter array (constructed with 64 entries). It is filled by spikeinit_, adjusted by
// sspike_tune_/dspike_tune_, and passed to every ?spike_gbtrf / ?spike_gbtrs call.
// Entry SPIKE(9) is used as a multiplier when sizing the WORK/SWORK factorisation workspaces.
45 podarray<int> SPIKE{64};
// Single-precision SPIKE workspace for the mixed-precision path. It is zero-sized to
// KLU * KLU * SPIKE(9) and filled by sspike_gbtrf_ on the float copy (s_memory). The residual
// solves then reuse it through sspike_gbtrs_.
47 podarray<float> SWORK;
50 auto N =
static_cast<int>(this->
n_rows);
51 auto KLU =
static_cast<int>(std::max(l_band, u_band));
53 spikeinit_(SPIKE.memptr(), &
N, &KLU);
55 std::is_same_v<T, float> ? sspike_tune_(SPIKE.memptr()) : dspike_tune_(SPIKE.memptr());
// Back-substitution using the SPIKE factorisation already stored in this object (via ?spike_gbtrs).
// NOTE(review): presumably the first argument receives the solution and the second holds the
// right-hand sides, and the int result is a status code; confirm against the definition.
58 int solve_trs(Mat<T>&, Mat<T>&&);
66 BandMatSpike(
const uword in_size,
const uword in_l,
const uword in_u)
67 :
DenseMat<
T>(in_size, in_size, (in_l + in_u + 1) * in_size)
70 , m_rows(in_l + in_u + 1) { init_spike(); }
72 unique_ptr<MetaMat<T>>
make_copy()
override {
return std::make_unique<BandMatSpike>(*
this); }
76 suanpan::for_each(std::max(
K, u_band) - u_band, std::min(this->
n_rows, K + l_band + 1), [&](
const uword I) { this->
memory[I + u_band +
K * (m_rows - 1)] =
T(0); });
77 suanpan::for_each(std::max(
K, l_band) - l_band, std::min(this->
n_cols, K + u_band + 1), [&](
const uword I) { this->
memory[K + u_band + I * (m_rows - 1)] =
T(0); });
80 T operator()(
const uword in_row,
const uword in_col)
const override {
81 if(in_row > in_col + l_band || in_row + u_band < in_col) [[unlikely]]
return bin =
T(0);
82 return this->
memory[in_row + u_band + in_col * (m_rows - 1)];
85 T&
unsafe_at(
const uword in_row,
const uword in_col)
override {
87 return this->
memory[in_row + u_band + in_col * (m_rows - 1)];
90 T&
at(
const uword in_row,
const uword in_col)
override {
91 if(in_row > in_col + l_band || in_row + u_band < in_col) [[unlikely]]
return bin =
T(0);
// Product of this banded matrix with a dense matrix.
// The out-of-class definition (only partly visible here) appears to call ?gbmv column by column
// of the argument, using the band storage directly.
95 Mat<T>
operator*(
const Mat<T>&)
const override;
97 [[nodiscard]]
int sign_det()
const override {
throw invalid_argument(
"not supported"); }
103 Mat<T> Y(arma::size(X));
105 const auto M =
static_cast<int>(this->n_rows);
106 const auto N =
static_cast<int>(this->n_cols);
107 const auto KL =
static_cast<int>(l_band);
108 const auto KU =
static_cast<int>(u_band);
109 const auto LDA =
static_cast<int>(m_rows);
110 constexpr auto INC = 1;
114 if constexpr(std::is_same_v<T, float>) {
116 suanpan::for_each(X.n_cols, [&](
const uword I) { arma_fortran(arma_sgbmv)(&TRAN, &M, &N, &KL, &KU, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
120 suanpan::for_each(X.n_cols, [&](
const uword I) { arma_fortran(arma_dgbmv)(&TRAN, &M, &N, &KL, &KU, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
127 if(!this->factored) {
128 suanpan_assert([&] {
if(this->n_rows != this->n_cols)
throw invalid_argument(
"requires a square matrix"); });
132 auto N =
static_cast<int>(this->n_rows);
133 auto KL =
static_cast<int>(l_band);
134 auto KU =
static_cast<int>(u_band);
135 auto LDAB =
static_cast<int>(m_rows);
136 const auto KLU = std::max(l_band, u_band);
137 this->factored =
true;
139 if constexpr(std::is_same_v<T, float>) {
141 WORK.zeros(KLU * KLU *
SPIKE(9));
142 sspike_gbtrf_(
SPIKE.memptr(), &
N, &KL, &KU, (
E*)this->memptr(), &LDAB, (
E*)WORK.memptr(), &INFO);
146 WORK.zeros(KLU * KLU *
SPIKE(9));
147 dspike_gbtrf_(
SPIKE.memptr(), &
N, &KL, &KU, (
E*)this->memptr(), &LDAB, (
E*)WORK.memptr(), &INFO);
150 this->s_memory = this->to_float();
151 SWORK.zeros(KLU * KLU *
SPIKE(9));
152 sspike_gbtrf_(
SPIKE.memptr(), &
N, &KL, &KU, this->s_memory.memptr(), &LDAB, SWORK.memptr(), &INFO);
156 suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
161 return this->solve_trs(X, std::forward<Mat<T>>(B));
165 auto N =
static_cast<int>(this->n_rows);
166 auto KL =
static_cast<int>(l_band);
167 auto KU =
static_cast<int>(u_band);
168 auto NRHS =
static_cast<int>(B.n_cols);
169 auto LDAB =
static_cast<int>(m_rows);
170 auto LDB =
static_cast<int>(B.n_rows);
172 if constexpr(std::is_same_v<T, float>) {
174 sspike_gbtrs_(
SPIKE.memptr(), &TRAN, &N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, (E*)WORK.memptr(), (E*)B.memptr(), &LDB);
179 dspike_gbtrs_(
SPIKE.memptr(), &TRAN, &N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, (E*)WORK.memptr(), (E*)B.memptr(), &LDB);
183 this->mixed_trs(X, std::forward<Mat<T>>(B), [&](fmat& residual) {
184 sspike_gbtrs_(
SPIKE.memptr(), &TRAN, &N, &KL, &KU, &NRHS, this->s_memory.memptr(), &LDAB, SWORK.memptr(), residual.memptr(), &LDB);
A BandMatSpike class that holds matrices.
Definition: BandMatSpike.hpp:36
T & at(const uword in_row, const uword in_col) override
Access element with a band check; positions outside the stored band yield a reference to the zeroed placeholder `bin`.
Definition: BandMatSpike.hpp:90
BandMatSpike(const uword in_size, const uword in_l, const uword in_u)
Definition: BandMatSpike.hpp:66
T operator()(const uword in_row, const uword in_col) const override
Access element (read-only); returns zero for positions outside the stored band.
Definition: BandMatSpike.hpp:80
int sign_det() const override
Definition: BandMatSpike.hpp:97
void nullify(const uword K) override
Definition: BandMatSpike.hpp:74
unique_ptr< MetaMat< T > > make_copy() override
Definition: BandMatSpike.hpp:72
T & unsafe_at(const uword in_row, const uword in_col) override
Access element without a band check.
Definition: BandMatSpike.hpp:85
A DenseMat class that holds matrices.
Definition: DenseMat.hpp:39
std::unique_ptr< T[]> memory
Definition: DenseMat.hpp:48
void for_each(const IT start, const IT end, F &&FN)
Definition: utility.h:28
constexpr auto SUANPAN_SUCCESS
Definition: suanPan.h:172
void suanpan_assert(const std::function< void()> &F)
Definition: suanPan.h:296
#define suanpan_error(...)
Definition: suanPan.h:309