30#ifndef BANDMATSPIKE_HPP
31#define BANDMATSPIKE_HPP
33#include <feast/spike.h>
// Transpose selector passed to the gbmv and *spike_gbtrs_ calls in this file.
// 'N' is the BLAS/LAPACK code for "no transpose".
37 static constexpr char TRAN =
'N';
// Integer parameter array with 64 entries; it is handed to spikeinit_ and to every
// sspike_*/dspike_* call, and entry 9 is read when sizing the workspace.
45 podarray<int> SPIKE = podarray<int>(64);
// Single-precision workspace used together with s_memory (the float copy of the matrix).
47 podarray<float> SWORK;
// Two overloads that differ only in how the second matrix is passed:
// by rvalue reference (may be consumed) or by const reference (left untouched).
// Both return an int status code.
51 int solve_trs(Mat<T>&, Mat<T>&&);
52 int solve_trs(Mat<T>&,
const Mat<T>&);
// Returns an owning pointer to a MetaMat<T>.
57 unique_ptr<MetaMat<T>>
make_copy()
override;
// Takes a single index.
59 void unify(uword)
override;
// Element access by two indices, returning a mutable reference.
63 T&
at(uword, uword)
override;
// Product with a dense matrix; const, so it does not modify this object.
65 Mat<T>
operator*(
const Mat<T>&)
const override;
// [[nodiscard]]: the returned int must not be ignored by the caller.
70 [[nodiscard]]
int sign_det()
const override;
// The SPIKE routines take int arguments by pointer, so the sizes are narrowed to int
// and kept in non-const locals.
76 auto N =
static_cast<int>(this->n_rows);
// A single bandwidth is passed on: the larger of the lower and upper bandwidths.
77 auto KLU =
static_cast<int>(std::max(l_band, u_band));
// Initialise the SPIKE parameter array; spikeinit_ receives the array, the size and the bandwidth.
79 spikeinit_(
SPIKE.memptr(), &
N, &KLU);
// Call the tuning routine that matches the element type:
// sspike_tune_ for float, dspike_tune_ otherwise.
81 std::is_same_v<T, float> ? sspike_tune_(
SPIKE.memptr()) : dspike_tune_(
SPIKE.memptr());
// Base-class construction: in_size is passed for both dimensions and
// (in_l + in_u + 1) * in_size as the third argument
// (presumably the number of stored elements — confirm against DenseMat).
85 :
DenseMat<
T>(in_size, in_size, (in_l + in_u + 1) * in_size)
// m_rows = in_l + in_u + 1; the constructor body then calls init_spike().
88 , m_rows(in_l + in_u + 1) { init_spike(); }
// Write 1 to the diagonal entry (K, K). The offset u_band + K * m_rows equals the
// element-access formula in_row + u_band + in_col * (m_rows - 1) with in_row = in_col = K.
94 access::rw(this->memory[u_band +
K * m_rows]) = 1.;
// Zero the stored entries of column K: I runs over row indices from K - u_band
// (clamped at 0) up to min(n_rows, K + l_band + 1), addressing element (I, K).
98 suanpan_for(std::max(
K, u_band) - u_band, std::min(this->n_rows,
K + l_band + 1), [&](
const uword I) { access::rw(this->memory[I + u_band +
K * (m_rows - 1)]) = 0.; });
// Zero the stored entries of row K: I runs over column indices from K - l_band
// (clamped at 0) up to min(n_cols, K + u_band + 1), addressing element (K, I).
99 suanpan_for(std::max(
K, l_band) - l_band, std::min(this->n_cols,
K + u_band + 1), [&](
const uword I) { access::rw(this->memory[
K + u_band + I * (m_rows - 1)]) = 0.; });
// The stored values changed, so any existing factorisation is no longer valid.
101 this->factored =
false;
// Outside the band (more than l_band below or more than u_band above the diagonal):
// reset `bin` to zero and return it.
105 if(in_row > in_col + l_band || in_row + u_band < in_col)
return bin = 0.;
// Inside the band: element (in_row, in_col) is stored at in_row + u_band + in_col * (m_rows - 1).
106 return this->memory[in_row + u_band + in_col * (m_rows - 1)];
// Outside the band: reset `bin` to zero and return it, so writes through the result
// never reach the stored matrix.
110 if(in_row > in_col + l_band || in_row + u_band < in_col)
return bin = 0.;
// A writable reference is about to be handed out, so mark the matrix as unfactored.
111 this->factored =
false;
// Inside the band: element (in_row, in_col) is stored at in_row + u_band + in_col * (m_rows - 1).
112 return access::rw(this->memory[in_row + u_band + in_col * (m_rows - 1)]);
// Result matrix with the same dimensions as X.
116 Mat<T> Y(arma::size(X));
// gbmv arguments are passed by pointer to int, so sizes and bandwidths are converted to int.
118 const auto M =
static_cast<int>(this->n_rows);
119 const auto N =
static_cast<int>(this->n_cols);
120 const auto KL =
static_cast<int>(l_band);
121 const auto KU =
static_cast<int>(u_band);
// Leading dimension of the stored band array is m_rows.
122 const auto LDA =
static_cast<int>(m_rows);
// NOTE(review): E, ALPHA, BETA and INC are defined on lines not visible here.
// Single precision: one sgbmv call per column of X, writing the matching column of Y.
127 if(std::is_same_v<T, float>) {
129 suanpan_for(0llu, X.n_cols, [&](
const uword I) { arma_fortran(arma_sgbmv)(&TRAN, &M, &N, &KL, &KU, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
// Double precision: same column-by-column scheme using dgbmv.
131 else if(std::is_same_v<T, double>) {
133 suanpan_for(0llu, X.n_cols, [&](
const uword I) { arma_fortran(arma_dgbmv)(&TRAN, &M, &N, &KL, &KU, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
// Factorise only when no valid factorisation is available.
140 if(!this->factored) {
// SPIKE takes its integer arguments by pointer to int.
141 auto N =
static_cast<int>(this->n_rows);
142 auto KL =
static_cast<int>(l_band);
143 auto KU =
static_cast<int>(u_band);
144 auto LDAB =
static_cast<int>(m_rows);
// The larger bandwidth determines the workspace size below.
145 const auto KLU = std::max(l_band, u_band);
// Single-precision element type: factorise the stored matrix in place with sspike_gbtrf_.
148 if(std::is_same_v<T, float>) {
// Workspace of KLU * KLU * SPIKE(9) entries, zero-filled.
150 WORK.zeros(KLU * KLU *
SPIKE(9));
151 sspike_gbtrf_(
SPIKE.memptr(), &
N, &KL, &KU, (
E*)this->memptr(), &LDAB, (
E*)WORK.memptr(), &INFO);
// Same workspace sizing, factorised with dspike_gbtrf_
// (the condition selecting this branch is on a line not visible here).
155 WORK.zeros(KLU * KLU *
SPIKE(9));
156 dspike_gbtrf_(
SPIKE.memptr(), &
N, &KL, &KU, (
E*)this->memptr(), &LDAB, (
E*)WORK.memptr(), &INFO);
// Mixed-precision path: take a float copy of the matrix and factorise the copy with
// sspike_gbtrf_, using the float workspace SWORK.
159 this->s_memory = this->to_float();
160 SWORK.zeros(KLU * KLU *
SPIKE(9));
161 sspike_gbtrf_(
SPIKE.memptr(), &
N, &KL, &KU, this->s_memory.mem, &LDAB, SWORK.memptr(), &INFO);
// Report the INFO code written by the factorisation
// (NOTE(review): the check on INFO guarding this message is not visible here).
165 suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
// Remember that the factors are now available.
169 this->factored =
true;
// Hand over to the solve step, which uses the factors.
172 return this->solve_trs(X, B);
// SPIKE takes its integer arguments by pointer to int.
176 auto N =
static_cast<int>(this->n_rows);
177 auto KL =
static_cast<int>(l_band);
178 auto KU =
static_cast<int>(u_band);
// One right-hand side per column of B; the leading dimension is the row count of B.
179 auto NRHS =
static_cast<int>(B.n_cols);
180 auto LDAB =
static_cast<int>(m_rows);
181 auto LDB =
static_cast<int>(B.n_rows);
// Single-precision element type: sspike_gbtrs_ is given X's memory as the right-hand side buffer
// (NOTE(review): X is presumably set to B on a line not visible here — confirm).
183 if(std::is_same_v<T, float>) {
186 sspike_gbtrs_(
SPIKE.memptr(), &TRAN, &
N, &KL, &KU, &NRHS, (
E*)this->memptr(), &LDAB, (
E*)WORK.memptr(), (
E*)X.memptr(), &LDB);
// Same call shape using dspike_gbtrs_
// (the condition selecting this branch is on a line not visible here).
191 dspike_gbtrs_(
SPIKE.memptr(), &TRAN, &
N, &KL, &KU, &NRHS, (
E*)this->memptr(), &LDAB, (
E*)WORK.memptr(), (
E*)X.memptr(), &LDB);
// Mixed-precision path: start from a zero solution with the full right-hand side as residual.
194 X = arma::zeros(B.n_rows, B.n_cols);
196 mat full_residual = B;
// Norm of the residual; it is used both as the scaling factor and as the convergence measure.
198 auto multiplier =
norm(full_residual);
// At most setting.iterative_refinement passes.
201 while(counter++ < this->setting.iterative_refinement) {
// Stop once the residual norm drops below the tolerance.
202 if(multiplier < this->setting.tolerance)
break;
// Scale the residual by its norm and convert it to single precision.
204 auto residual = conv_to<fmat>::from(full_residual / multiplier);
// Solve with the single-precision factors; `residual` is passed as the right-hand side buffer.
206 sspike_gbtrs_(
SPIKE.memptr(), &TRAN, &
N, &KL, &KU, &NRHS, this->s_memory.memptr(), &LDAB, SWORK.memptr(), residual.memptr(), &LDB);
// Convert back to double precision and undo the scaling.
208 const mat incre = multiplier * conv_to<mat>::from(residual);
// NOTE(review): the residual update (full_residual -= A * incre) and the new multiplier
// are computed inside the suanpan_debug argument list, so the loop relies on
// suanpan_debug always evaluating its arguments — confirm against its definition.
212 suanpan_debug(
"Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(full_residual -= this->
operator*(incre)));
// Factorise only when no valid factorisation is available.
220 if(!this->factored) {
// SPIKE takes its integer arguments by pointer to int.
221 auto N =
static_cast<int>(this->n_rows);
222 auto KL =
static_cast<int>(l_band);
223 auto KU =
static_cast<int>(u_band);
224 auto LDAB =
static_cast<int>(m_rows);
// The larger bandwidth determines the workspace size below.
225 const auto KLU = std::max(l_band, u_band);
// Single-precision element type: factorise the stored matrix in place with sspike_gbtrf_.
228 if(std::is_same_v<T, float>) {
// Workspace of KLU * KLU * SPIKE(9) entries, zero-filled.
230 WORK.zeros(KLU * KLU *
SPIKE(9));
231 sspike_gbtrf_(
SPIKE.memptr(), &
N, &KL, &KU, (
E*)this->memptr(), &LDAB, (
E*)WORK.memptr(), &INFO);
// Same workspace sizing, factorised with dspike_gbtrf_
// (the condition selecting this branch is on a line not visible here).
235 WORK.zeros(KLU * KLU *
SPIKE(9));
236 dspike_gbtrf_(
SPIKE.memptr(), &
N, &KL, &KU, (
E*)this->memptr(), &LDAB, (
E*)WORK.memptr(), &INFO);
// Mixed-precision path: take a float copy of the matrix and factorise the copy with
// sspike_gbtrf_, using the float workspace SWORK.
239 this->s_memory = this->to_float();
240 SWORK.zeros(KLU * KLU *
SPIKE(9));
241 sspike_gbtrf_(
SPIKE.memptr(), &
N, &KL, &KU, this->s_memory.mem, &LDAB, SWORK.memptr(), &INFO);
// Report the INFO code written by the factorisation
// (NOTE(review): the check on INFO guarding this message is not visible here).
245 suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
// Remember that the factors are now available.
249 this->factored =
true;
// Pass B on as an rvalue so that the rvalue overload of solve_trs is selected.
252 return this->solve_trs(X, std::forward<Mat<T>>(B));
// SPIKE takes its integer arguments by pointer to int.
256 auto N =
static_cast<int>(this->n_rows);
257 auto KL =
static_cast<int>(l_band);
258 auto KU =
static_cast<int>(u_band);
// One right-hand side per column of B; the leading dimension is the row count of B.
259 auto NRHS =
static_cast<int>(B.n_cols);
260 auto LDAB =
static_cast<int>(m_rows);
261 auto LDB =
static_cast<int>(B.n_rows);
// Single-precision element type: B's own memory is passed as the right-hand side buffer,
// so no separate copy is made for the call.
263 if(std::is_same_v<T, float>) {
265 sspike_gbtrs_(
SPIKE.memptr(), &TRAN, &
N, &KL, &KU, &NRHS, (
E*)this->memptr(), &LDAB, (
E*)WORK.memptr(), (
E*)B.memptr(), &LDB);
// Same call shape using dspike_gbtrs_
// (the condition selecting this branch is on a line not visible here).
270 dspike_gbtrs_(
SPIKE.memptr(), &TRAN, &
N, &KL, &KU, &NRHS, (
E*)this->memptr(), &LDAB, (
E*)WORK.memptr(), (
E*)B.memptr(), &LDB);
// Mixed-precision path: start from a zero solution; B itself serves as the running residual.
274 X = arma::zeros(B.n_rows, B.n_cols);
// Norm of the residual; it is used both as the scaling factor and as the convergence measure.
276 auto multiplier =
norm(B);
// At most setting.iterative_refinement passes.
279 while(counter++ < this->setting.iterative_refinement) {
// Stop once the residual norm drops below the tolerance.
280 if(multiplier < this->setting.tolerance)
break;
// Scale the residual by its norm and convert it to single precision.
282 auto residual = conv_to<fmat>::from(B / multiplier);
// Solve with the single-precision factors; `residual` is passed as the right-hand side buffer.
284 sspike_gbtrs_(
SPIKE.memptr(), &TRAN, &
N, &KL, &KU, &NRHS, this->s_memory.memptr(), &LDAB, SWORK.memptr(), residual.memptr(), &LDB);
// Convert back to double precision and undo the scaling.
286 const mat incre = multiplier * conv_to<mat>::from(residual);
// NOTE(review): the residual update (B -= A * incre) and the new multiplier are computed
// inside the suanpan_debug argument list, so the loop relies on suanpan_debug always
// evaluating its arguments — confirm against its definition.
290 suanpan_debug(
"Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(B -= this->
operator*(incre)));
A BandMatSpike class that holds matrices.
Definition: BandMatSpike.hpp:36
A DenseMat class that holds matrices.
Definition: DenseMat.hpp:34
double norm(const vec &)
Definition: tensor.cpp:302
#define suanpan_debug(...)
Definition: suanPan.h:295
constexpr auto SUANPAN_SUCCESS
Definition: suanPan.h:162
#define suanpan_error(...)
Definition: suanPan.h:297
void suanpan_for(const IT start, const IT end, F &&FN)
Definition: utility.h:27