suanPan
Loading...
Searching...
No Matches
BandMatSpike.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2017-2023 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
29// ReSharper disable CppCStyleCast
30#ifndef BANDMATSPIKE_HPP
31#define BANDMATSPIKE_HPP
32
33#include <feast/spike.h>
34#include "DenseMat.hpp"
35
/**
 * Banded matrix stored in a compact band layout and solved through the SPIKE
 * routines declared in <feast/spike.h>.
 *
 * The storage has (l_band + u_band + 1) rows per column; element (i, j) lives
 * at memory[i + u_band + j * (m_rows - 1)] (see operator() and at()).
 */
36template<sp_d T> class BandMatSpike final : public DenseMat<T> {
// Flag handed to the gbmv/gbtrs calls as their transpose argument
// (presumably 'N' = no transpose -- confirm against BLAS/SPIKE docs).
37 static constexpr char TRAN = 'N';
38
// Shared dummy element returned by operator()/at() for out-of-band indices.
39 static T bin;
40
// Number of sub-diagonals (l_band) and super-diagonals (u_band).
41 const uword l_band;
42 const uword u_band;
// l_band + u_band + 1, used as the leading dimension of the band storage.
43 const uword m_rows; // memory block layout
44
// 64-entry integer parameter array filled by spikeinit_ and passed to every SPIKE call.
45 podarray<int> SPIKE = podarray<int>(64);
// Work array for the factorisation in the element type T.
46 podarray<T> WORK;
// Work array for the single precision factorisation used by the mixed precision path.
47 podarray<float> SWORK;
48
// Fills SPIKE via spikeinit_ and the s/d tuning routine.
49 void init_spike();
50
// Solve with the stored factorisation; the rvalue overload reuses B's storage.
51 int solve_trs(Mat<T>&, Mat<T>&&);
52 int solve_trs(Mat<T>&, const Mat<T>&);
53
54public:
// Arguments: matrix size, lower bandwidth, upper bandwidth.
55 BandMatSpike(uword, uword, uword);
56
57 unique_ptr<MetaMat<T>> make_copy() override;
58
// unify(K): zero row and column K inside the band, then set the diagonal entry to one.
59 void unify(uword) override;
// nullify(K): zero row and column K inside the band.
60 void nullify(uword) override;
61
// Element access; indices outside the band yield a reference to the zeroed dummy `bin`.
62 const T& operator()(uword, uword) const override;
63 T& at(uword, uword) override;
64
// Band matrix times dense matrix, evaluated column by column with gbmv.
65 Mat<T> operator*(const Mat<T>&) const override;
66
// Factorise on demand (if not already factored), then solve for the given right hand side.
67 int direct_solve(Mat<T>&, Mat<T>&&) override;
68 int direct_solve(Mat<T>&, const Mat<T>&) override;
69
// Not supported by this storage; the definition throws invalid_argument.
70 [[nodiscard]] int sign_det() const override;
71};
72
73template<sp_d T> T BandMatSpike<T>::bin = 0.;
74
75template<sp_d T> void BandMatSpike<T>::init_spike() {
76 auto N = static_cast<int>(this->n_rows);
77 auto KLU = static_cast<int>(std::max(l_band, u_band));
78
79 spikeinit_(SPIKE.memptr(), &N, &KLU);
80
81 std::is_same_v<T, float> ? sspike_tune_(SPIKE.memptr()) : dspike_tune_(SPIKE.memptr());
82}
83
84template<sp_d T> BandMatSpike<T>::BandMatSpike(const uword in_size, const uword in_l, const uword in_u)
85 : DenseMat<T>(in_size, in_size, (in_l + in_u + 1) * in_size)
86 , l_band(in_l)
87 , u_band(in_u)
88 , m_rows(in_l + in_u + 1) { init_spike(); }
89
90template<sp_d T> unique_ptr<MetaMat<T>> BandMatSpike<T>::make_copy() { return std::make_unique<BandMatSpike<T>>(*this); }
91
92template<sp_d T> void BandMatSpike<T>::unify(const uword K) {
93 nullify(K);
94 access::rw(this->memory[u_band + K * m_rows]) = 1.;
95}
96
97template<sp_d T> void BandMatSpike<T>::nullify(const uword K) {
98 suanpan_for(std::max(K, u_band) - u_band, std::min(this->n_rows, K + l_band + 1), [&](const uword I) { access::rw(this->memory[I + u_band + K * (m_rows - 1)]) = 0.; });
99 suanpan_for(std::max(K, l_band) - l_band, std::min(this->n_cols, K + u_band + 1), [&](const uword I) { access::rw(this->memory[K + u_band + I * (m_rows - 1)]) = 0.; });
100
101 this->factored = false;
102}
103
104template<sp_d T> const T& BandMatSpike<T>::operator()(const uword in_row, const uword in_col) const {
105 if(in_row > in_col + l_band || in_row + u_band < in_col) return bin = 0.;
106 return this->memory[in_row + u_band + in_col * (m_rows - 1)];
107}
108
109template<sp_d T> T& BandMatSpike<T>::at(const uword in_row, const uword in_col) {
110 if(in_row > in_col + l_band || in_row + u_band < in_col) return bin = 0.;
111 this->factored = false;
112 return access::rw(this->memory[in_row + u_band + in_col * (m_rows - 1)]);
113}
114
115template<sp_d T> Mat<T> BandMatSpike<T>::operator*(const Mat<T>& X) const {
116 Mat<T> Y(arma::size(X));
117
118 const auto M = static_cast<int>(this->n_rows);
119 const auto N = static_cast<int>(this->n_cols);
120 const auto KL = static_cast<int>(l_band);
121 const auto KU = static_cast<int>(u_band);
122 const auto LDA = static_cast<int>(m_rows);
123 const auto INC = 1;
124 T ALPHA = 1.;
125 T BETA = 0.;
126
127 if(std::is_same_v<T, float>) {
128 using E = float;
129 suanpan_for(0llu, X.n_cols, [&](const uword I) { arma_fortran(arma_sgbmv)(&TRAN, &M, &N, &KL, &KU, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
130 }
131 else if(std::is_same_v<T, double>) {
132 using E = double;
133 suanpan_for(0llu, X.n_cols, [&](const uword I) { arma_fortran(arma_dgbmv)(&TRAN, &M, &N, &KL, &KU, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
134 }
135
136 return Y;
137}
138
139template<sp_d T> int BandMatSpike<T>::direct_solve(Mat<T>& X, const Mat<T>& B) {
140 if(!this->factored) {
141 auto N = static_cast<int>(this->n_rows);
142 auto KL = static_cast<int>(l_band);
143 auto KU = static_cast<int>(u_band);
144 auto LDAB = static_cast<int>(m_rows);
145 const auto KLU = std::max(l_band, u_band);
146 auto INFO = 0;
147
148 if(std::is_same_v<T, float>) {
149 using E = float;
150 WORK.zeros(KLU * KLU * SPIKE(9));
151 sspike_gbtrf_(SPIKE.memptr(), &N, &KL, &KU, (E*)this->memptr(), &LDAB, (E*)WORK.memptr(), &INFO);
152 }
153 else if(Precision::FULL == this->setting.precision) {
154 using E = double;
155 WORK.zeros(KLU * KLU * SPIKE(9));
156 dspike_gbtrf_(SPIKE.memptr(), &N, &KL, &KU, (E*)this->memptr(), &LDAB, (E*)WORK.memptr(), &INFO);
157 }
158 else {
159 this->s_memory = this->to_float();
160 SWORK.zeros(KLU * KLU * SPIKE(9));
161 sspike_gbtrf_(SPIKE.memptr(), &N, &KL, &KU, this->s_memory.mem, &LDAB, SWORK.memptr(), &INFO);
162 }
163
164 if(0 != INFO) {
165 suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
166 return INFO;
167 }
168
169 this->factored = true;
170 }
171
172 return this->solve_trs(X, B);
173}
174
175template<sp_d T> int BandMatSpike<T>::solve_trs(Mat<T>& X, const Mat<T>& B) {
176 auto N = static_cast<int>(this->n_rows);
177 auto KL = static_cast<int>(l_band);
178 auto KU = static_cast<int>(u_band);
179 auto NRHS = static_cast<int>(B.n_cols);
180 auto LDAB = static_cast<int>(m_rows);
181 auto LDB = static_cast<int>(B.n_rows);
182
183 if(std::is_same_v<T, float>) {
184 using E = float;
185 X = B;
186 sspike_gbtrs_(SPIKE.memptr(), &TRAN, &N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, (E*)WORK.memptr(), (E*)X.memptr(), &LDB);
187 }
188 else if(Precision::FULL == this->setting.precision) {
189 using E = double;
190 X = B;
191 dspike_gbtrs_(SPIKE.memptr(), &TRAN, &N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, (E*)WORK.memptr(), (E*)X.memptr(), &LDB);
192 }
193 else {
194 X = arma::zeros(B.n_rows, B.n_cols);
195
196 mat full_residual = B;
197
198 auto multiplier = norm(full_residual);
199
200 auto counter = 0u;
201 while(counter++ < this->setting.iterative_refinement) {
202 if(multiplier < this->setting.tolerance) break;
203
204 auto residual = conv_to<fmat>::from(full_residual / multiplier);
205
206 sspike_gbtrs_(SPIKE.memptr(), &TRAN, &N, &KL, &KU, &NRHS, this->s_memory.memptr(), &LDAB, SWORK.memptr(), residual.memptr(), &LDB);
207
208 const mat incre = multiplier * conv_to<mat>::from(residual);
209
210 X += incre;
211
212 suanpan_debug("Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(full_residual -= this->operator*(incre)));
213 }
214 }
215
216 return SUANPAN_SUCCESS;
217}
218
219template<sp_d T> int BandMatSpike<T>::direct_solve(Mat<T>& X, Mat<T>&& B) {
220 if(!this->factored) {
221 auto N = static_cast<int>(this->n_rows);
222 auto KL = static_cast<int>(l_band);
223 auto KU = static_cast<int>(u_band);
224 auto LDAB = static_cast<int>(m_rows);
225 const auto KLU = std::max(l_band, u_band);
226 auto INFO = 0;
227
228 if(std::is_same_v<T, float>) {
229 using E = float;
230 WORK.zeros(KLU * KLU * SPIKE(9));
231 sspike_gbtrf_(SPIKE.memptr(), &N, &KL, &KU, (E*)this->memptr(), &LDAB, (E*)WORK.memptr(), &INFO);
232 }
233 else if(Precision::FULL == this->setting.precision) {
234 using E = double;
235 WORK.zeros(KLU * KLU * SPIKE(9));
236 dspike_gbtrf_(SPIKE.memptr(), &N, &KL, &KU, (E*)this->memptr(), &LDAB, (E*)WORK.memptr(), &INFO);
237 }
238 else {
239 this->s_memory = this->to_float();
240 SWORK.zeros(KLU * KLU * SPIKE(9));
241 sspike_gbtrf_(SPIKE.memptr(), &N, &KL, &KU, this->s_memory.mem, &LDAB, SWORK.memptr(), &INFO);
242 }
243
244 if(0 != INFO) {
245 suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
246 return INFO;
247 }
248
249 this->factored = true;
250 }
251
252 return this->solve_trs(X, std::forward<Mat<T>>(B));
253}
254
255template<sp_d T> int BandMatSpike<T>::solve_trs(Mat<T>& X, Mat<T>&& B) {
256 auto N = static_cast<int>(this->n_rows);
257 auto KL = static_cast<int>(l_band);
258 auto KU = static_cast<int>(u_band);
259 auto NRHS = static_cast<int>(B.n_cols);
260 auto LDAB = static_cast<int>(m_rows);
261 auto LDB = static_cast<int>(B.n_rows);
262
263 if(std::is_same_v<T, float>) {
264 using E = float;
265 sspike_gbtrs_(SPIKE.memptr(), &TRAN, &N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, (E*)WORK.memptr(), (E*)B.memptr(), &LDB);
266 X = std::move(B);
267 }
268 else if(Precision::FULL == this->setting.precision) {
269 using E = double;
270 dspike_gbtrs_(SPIKE.memptr(), &TRAN, &N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, (E*)WORK.memptr(), (E*)B.memptr(), &LDB);
271 X = std::move(B);
272 }
273 else {
274 X = arma::zeros(B.n_rows, B.n_cols);
275
276 auto multiplier = norm(B);
277
278 auto counter = 0u;
279 while(counter++ < this->setting.iterative_refinement) {
280 if(multiplier < this->setting.tolerance) break;
281
282 auto residual = conv_to<fmat>::from(B / multiplier);
283
284 sspike_gbtrs_(SPIKE.memptr(), &TRAN, &N, &KL, &KU, &NRHS, this->s_memory.memptr(), &LDAB, SWORK.memptr(), residual.memptr(), &LDB);
285
286 const mat incre = multiplier * conv_to<mat>::from(residual);
287
288 X += incre;
289
290 suanpan_debug("Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(B -= this->operator*(incre)));
291 }
292 }
293
294 return SUANPAN_SUCCESS;
295}
296
297template<sp_d T> int BandMatSpike<T>::sign_det() const { throw invalid_argument("not supported"); }
298
299#endif
300
A BandMatSpike class that holds matrices.
Definition: BandMatSpike.hpp:36
A DenseMat class that holds matrices.
Definition: DenseMat.hpp:34
void unify(uword) override
Definition: BandMatSpike.hpp:92
Mat< T > operator*(const Mat< T > &) const override
Definition: BandMatSpike.hpp:115
int direct_solve(Mat< T > &, Mat< T > &&) override
Definition: BandMatSpike.hpp:219
int sign_det() const override
Definition: BandMatSpike.hpp:297
const T & operator()(uword, uword) const override
Access element (read-only), returns zero if out-of-bound.
Definition: BandMatSpike.hpp:104
void nullify(uword) override
Definition: BandMatSpike.hpp:97
unique_ptr< MetaMat< T > > make_copy() override
Definition: BandMatSpike.hpp:90
BandMatSpike(uword, uword, uword)
Definition: BandMatSpike.hpp:84
T & at(uword, uword) override
Access element with bound check.
Definition: BandMatSpike.hpp:109
double norm(const vec &)
Definition: tensor.cpp:302
#define suanpan_debug(...)
Definition: suanPan.h:295
constexpr auto SUANPAN_SUCCESS
Definition: suanPan.h:162
#define suanpan_error(...)
Definition: suanPan.h:297
void suanpan_for(const IT start, const IT end, F &&FN)
Definition: utility.h:27