suanPan
Loading...
Searching...
No Matches
BandMatSpike.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2017-2024 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
29// ReSharper disable CppCStyleCast
30#ifndef BANDMATSPIKE_HPP
31#define BANDMATSPIKE_HPP
32
33#include <feast/spike.h>
34#include "DenseMat.hpp"
35
36template<sp_d T> class BandMatSpike final : public DenseMat<T> {
37 static constexpr char TRAN = 'N';
38
39 static T bin;
40
41 const uword l_band;
42 const uword u_band;
43 const uword m_rows; // memory block layout
44
45 podarray<int> SPIKE{64};
46 podarray<T> WORK;
47 podarray<float> SWORK;
48
49 void init_spike() {
50 auto N = static_cast<int>(this->n_rows);
51 auto KLU = static_cast<int>(std::max(l_band, u_band));
52
53 spikeinit_(SPIKE.memptr(), &N, &KLU);
54
55 std::is_same_v<T, float> ? sspike_tune_(SPIKE.memptr()) : dspike_tune_(SPIKE.memptr());
56 }
57
58 int solve_trs(Mat<T>&, Mat<T>&&);
59
60protected:
62
63 int direct_solve(Mat<T>&, Mat<T>&&) override;
64
65public:
66 BandMatSpike(const uword in_size, const uword in_l, const uword in_u)
67 : DenseMat<T>(in_size, in_size, (in_l + in_u + 1) * in_size)
68 , l_band(in_l)
69 , u_band(in_u)
70 , m_rows(in_l + in_u + 1) { init_spike(); }
71
72 unique_ptr<MetaMat<T>> make_copy() override { return std::make_unique<BandMatSpike>(*this); }
73
74 void nullify(const uword K) override {
75 this->factored = false;
76 suanpan::for_each(std::max(K, u_band) - u_band, std::min(this->n_rows, K + l_band + 1), [&](const uword I) { this->memory[I + u_band + K * (m_rows - 1)] = T(0); });
77 suanpan::for_each(std::max(K, l_band) - l_band, std::min(this->n_cols, K + u_band + 1), [&](const uword I) { this->memory[K + u_band + I * (m_rows - 1)] = T(0); });
78 }
79
80 T operator()(const uword in_row, const uword in_col) const override {
81 if(in_row > in_col + l_band || in_row + u_band < in_col) [[unlikely]] return bin = T(0);
82 return this->memory[in_row + u_band + in_col * (m_rows - 1)];
83 }
84
85 T& unsafe_at(const uword in_row, const uword in_col) override {
86 this->factored = false;
87 return this->memory[in_row + u_band + in_col * (m_rows - 1)];
88 }
89
90 T& at(const uword in_row, const uword in_col) override {
91 if(in_row > in_col + l_band || in_row + u_band < in_col) [[unlikely]] return bin = T(0);
92 return this->unsafe_at(in_row, in_col);
93 }
94
95 Mat<T> operator*(const Mat<T>&) const override;
96
97 [[nodiscard]] int sign_det() const override { throw invalid_argument("not supported"); }
98};
99
100template<sp_d T> T BandMatSpike<T>::bin = T(0);
101
102template<sp_d T> Mat<T> BandMatSpike<T>::operator*(const Mat<T>& X) const {
103 Mat<T> Y(arma::size(X));
104
105 const auto M = static_cast<int>(this->n_rows);
106 const auto N = static_cast<int>(this->n_cols);
107 const auto KL = static_cast<int>(l_band);
108 const auto KU = static_cast<int>(u_band);
109 const auto LDA = static_cast<int>(m_rows);
110 constexpr auto INC = 1;
111 T ALPHA = T(1);
112 T BETA = T(0);
113
114 if constexpr(std::is_same_v<T, float>) {
115 using E = float;
116 suanpan::for_each(X.n_cols, [&](const uword I) { arma_fortran(arma_sgbmv)(&TRAN, &M, &N, &KL, &KU, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
117 }
118 else {
119 using E = double;
120 suanpan::for_each(X.n_cols, [&](const uword I) { arma_fortran(arma_dgbmv)(&TRAN, &M, &N, &KL, &KU, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
121 }
122
123 return Y;
124}
125
126template<sp_d T> int BandMatSpike<T>::direct_solve(Mat<T>& X, Mat<T>&& B) {
127 if(!this->factored) {
128 suanpan_assert([&] { if(this->n_rows != this->n_cols) throw invalid_argument("requires a square matrix"); });
129
130 auto INFO = 0;
131
132 auto N = static_cast<int>(this->n_rows);
133 auto KL = static_cast<int>(l_band);
134 auto KU = static_cast<int>(u_band);
135 auto LDAB = static_cast<int>(m_rows);
136 const auto KLU = std::max(l_band, u_band);
137 this->factored = true;
138
139 if constexpr(std::is_same_v<T, float>) {
140 using E = float;
141 WORK.zeros(KLU * KLU * SPIKE(9));
142 sspike_gbtrf_(SPIKE.memptr(), &N, &KL, &KU, (E*)this->memptr(), &LDAB, (E*)WORK.memptr(), &INFO);
143 }
144 else if(Precision::FULL == this->setting.precision) {
145 using E = double;
146 WORK.zeros(KLU * KLU * SPIKE(9));
147 dspike_gbtrf_(SPIKE.memptr(), &N, &KL, &KU, (E*)this->memptr(), &LDAB, (E*)WORK.memptr(), &INFO);
148 }
149 else {
150 this->s_memory = this->to_float();
151 SWORK.zeros(KLU * KLU * SPIKE(9));
152 sspike_gbtrf_(SPIKE.memptr(), &N, &KL, &KU, this->s_memory.memptr(), &LDAB, SWORK.memptr(), &INFO);
153 }
154
155 if(0 != INFO) {
156 suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
157 return INFO;
158 }
159 }
160
161 return this->solve_trs(X, std::forward<Mat<T>>(B));
162}
163
164template<sp_d T> int BandMatSpike<T>::solve_trs(Mat<T>& X, Mat<T>&& B) {
165 auto N = static_cast<int>(this->n_rows);
166 auto KL = static_cast<int>(l_band);
167 auto KU = static_cast<int>(u_band);
168 auto NRHS = static_cast<int>(B.n_cols);
169 auto LDAB = static_cast<int>(m_rows);
170 auto LDB = static_cast<int>(B.n_rows);
171
172 if constexpr(std::is_same_v<T, float>) {
173 using E = float;
174 sspike_gbtrs_(SPIKE.memptr(), &TRAN, &N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, (E*)WORK.memptr(), (E*)B.memptr(), &LDB);
175 X = std::move(B);
176 }
177 else if(Precision::FULL == this->setting.precision) {
178 using E = double;
179 dspike_gbtrs_(SPIKE.memptr(), &TRAN, &N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, (E*)WORK.memptr(), (E*)B.memptr(), &LDB);
180 X = std::move(B);
181 }
182 else
183 this->mixed_trs(X, std::forward<Mat<T>>(B), [&](fmat& residual) {
184 sspike_gbtrs_(SPIKE.memptr(), &TRAN, &N, &KL, &KU, &NRHS, this->s_memory.memptr(), &LDAB, SWORK.memptr(), residual.memptr(), &LDB);
185 return 0;
186 });
187
188 return SUANPAN_SUCCESS;
189}
190
191#endif
192
A BandMatSpike class that holds matrices.
Definition BandMatSpike.hpp:36
T & at(const uword in_row, const uword in_col) override
Access element with bound check.
Definition BandMatSpike.hpp:90
BandMatSpike(const uword in_size, const uword in_l, const uword in_u)
Definition BandMatSpike.hpp:66
T operator()(const uword in_row, const uword in_col) const override
Access element (read-only), returns zero if out-of-bound.
Definition BandMatSpike.hpp:80
int sign_det() const override
Definition BandMatSpike.hpp:97
void nullify(const uword K) override
Definition BandMatSpike.hpp:74
unique_ptr< MetaMat< T > > make_copy() override
Definition BandMatSpike.hpp:72
T & unsafe_at(const uword in_row, const uword in_col) override
Access element without bound check.
Definition BandMatSpike.hpp:85
A DenseMat class that holds matrices.
Definition DenseMat.hpp:39
std::unique_ptr< T[]> memory
Definition DenseMat.hpp:48
const uword n_cols
Definition MetaMat.hpp:119
const uword n_rows
Definition MetaMat.hpp:118
bool factored
Definition MetaMat.hpp:74
Mat< T > operator*(const Mat< T > &) const override
Definition BandMatSpike.hpp:102
int direct_solve(Mat< T > &, Mat< T > &&) override
Definition BandMatSpike.hpp:126
void for_each(const IT start, const IT end, F &&FN)
Definition utility.h:28
constexpr auto SUANPAN_SUCCESS
Definition suanPan.h:172
void suanpan_assert(const std::function< void()> &F)
Definition suanPan.h:296
#define suanpan_error(...)
Definition suanPan.h:309