suanPan
Loading...
Searching...
No Matches
BandSymmMat.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2017-2023 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
29// ReSharper disable CppCStyleCast
30#ifndef BANDSYMMMAT_HPP
31#define BANDSYMMMAT_HPP
32
33#include "DenseMat.hpp"
34
35template<sp_d T> class BandSymmMat final : public DenseMat<T> {
36 static constexpr char UPLO = 'L';
37
38 static T bin;
39
40 const uword band;
41 const uword m_rows; // memory block layout
42
43 int solve_trs(Mat<T>&, Mat<T>&&);
44 int solve_trs(Mat<T>&, const Mat<T>&);
45
46public:
47 BandSymmMat(uword, uword);
48
49 unique_ptr<MetaMat<T>> make_copy() override;
50
51 void unify(uword) override;
52 void nullify(uword) override;
53
54 const T& operator()(uword, uword) const override;
55 T& unsafe_at(uword, uword) override;
56 T& at(uword, uword) override;
57
58 Mat<T> operator*(const Mat<T>&) const override;
59
60 int direct_solve(Mat<T>&, Mat<T>&&) override;
61 int direct_solve(Mat<T>&, const Mat<T>&) override;
62};
63
64template<sp_d T> T BandSymmMat<T>::bin = 0.;
65
66template<sp_d T> BandSymmMat<T>::BandSymmMat(const uword in_size, const uword in_bandwidth)
67 : DenseMat<T>(in_size, in_size, (in_bandwidth + 1) * in_size)
68 , band(in_bandwidth)
69 , m_rows(in_bandwidth + 1) {}
70
71template<sp_d T> unique_ptr<MetaMat<T>> BandSymmMat<T>::make_copy() { return std::make_unique<BandSymmMat<T>>(*this); }
72
73template<sp_d T> void BandSymmMat<T>::unify(const uword K) {
74 nullify(K);
75 access::rw(this->memory[K * m_rows]) = 1.;
76}
77
78template<sp_d T> void BandSymmMat<T>::nullify(const uword K) {
79 suanpan_for(std::max(band, K) - band, K, [&](const uword I) { access::rw(this->memory[K - I + I * m_rows]) = 0.; });
80 const auto t_factor = K * m_rows - K;
81 suanpan_for(K, std::min(this->n_rows, K + band + 1), [&](const uword I) { access::rw(this->memory[I + t_factor]) = 0.; });
82
83 this->factored = false;
84}
85
86template<sp_d T> const T& BandSymmMat<T>::operator()(const uword in_row, const uword in_col) const {
87 if(in_row > band + in_col) return bin = 0.;
88 return this->memory[in_row > in_col ? in_row - in_col + in_col * m_rows : in_col - in_row + in_row * m_rows];
89}
90
91template<sp_d T> T& BandSymmMat<T>::unsafe_at(const uword in_row, const uword in_col) {
92 this->factored = false;
93 return access::rw(this->memory[in_row - in_col + in_col * m_rows]);
94}
95
96template<sp_d T> T& BandSymmMat<T>::at(const uword in_row, const uword in_col) {
97 if(in_row > band + in_col || in_row < in_col) [[unlikely]] return bin = 0.;
98 this->factored = false;
99 return access::rw(this->memory[in_row - in_col + in_col * m_rows]);
100}
101
102template<sp_d T> Mat<T> BandSymmMat<T>::operator*(const Mat<T>& X) const {
103 Mat<T> Y(arma::size(X));
104
105 const auto N = static_cast<int>(this->n_cols);
106 const auto K = static_cast<int>(band);
107 const auto LDA = static_cast<int>(m_rows);
108 const auto INC = 1;
109 T ALPHA = 1.;
110 T BETA = 0.;
111
112 if(std::is_same_v<T, float>) {
113 using E = float;
114 suanpan_for(0llu, X.n_cols, [&](const uword I) { arma_fortran(arma_ssbmv)(&UPLO, &N, &K, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
115 }
116 else if(std::is_same_v<T, double>) {
117 using E = double;
118 suanpan_for(0llu, X.n_cols, [&](const uword I) { arma_fortran(arma_dsbmv)(&UPLO, &N, &K, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
119 }
120
121 return Y;
122}
123
124template<sp_d T> int BandSymmMat<T>::direct_solve(Mat<T>& X, const Mat<T>& B) {
125 if(this->factored) return this->solve_trs(X, B);
126
127 const auto N = static_cast<int>(this->n_rows);
128 const auto KD = static_cast<int>(band);
129 const auto NRHS = static_cast<int>(B.n_cols);
130 const auto LDAB = static_cast<int>(m_rows);
131 const auto LDB = static_cast<int>(B.n_rows);
132 auto INFO = 0;
133
134 this->factored = true;
135
136 if(std::is_same_v<T, float>) {
137 using E = float;
138 X = B;
139 arma_fortran(arma_spbsv)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)X.memptr(), &LDB, &INFO);
140 }
141 else if(Precision::FULL == this->setting.precision) {
142 using E = double;
143 X = B;
144 arma_fortran(arma_dpbsv)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)X.memptr(), &LDB, &INFO);
145 }
146 else {
147 this->s_memory = this->to_float();
148 arma_fortran(arma_spbtrf)(&UPLO, &N, &KD, this->s_memory.memptr(), &LDAB, &INFO);
149 if(0 == INFO) INFO = this->solve_trs(X, B);
150 }
151
152 if(0 != INFO)
153 suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
154
155 return INFO;
156}
157
158template<sp_d T> int BandSymmMat<T>::solve_trs(Mat<T>& X, const Mat<T>& B) {
159 const auto N = static_cast<int>(this->n_rows);
160 const auto KD = static_cast<int>(band);
161 const auto NRHS = static_cast<int>(B.n_cols);
162 const auto LDAB = static_cast<int>(m_rows);
163 const auto LDB = static_cast<int>(B.n_rows);
164 auto INFO = 0;
165
166 if(std::is_same_v<T, float>) {
167 using E = float;
168 X = B;
169 arma_fortran(arma_spbtrs)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)X.memptr(), &LDB, &INFO);
170 }
171 else if(Precision::FULL == this->setting.precision) {
172 using E = double;
173 X = B;
174 arma_fortran(arma_dpbtrs)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)X.memptr(), &LDB, &INFO);
175 }
176 else {
177 X = arma::zeros(B.n_rows, B.n_cols);
178
179 mat full_residual = B;
180
181 auto multiplier = norm(full_residual);
182
183 auto counter = 0u;
184 while(counter++ < this->setting.iterative_refinement) {
185 if(multiplier < this->setting.tolerance) break;
186
187 auto residual = conv_to<fmat>::from(full_residual / multiplier);
188
189 arma_fortran(arma_spbtrs)(&UPLO, &N, &KD, &NRHS, this->s_memory.memptr(), &LDAB, residual.memptr(), &LDB, &INFO);
190 if(0 != INFO) break;
191
192 const mat incre = multiplier * conv_to<mat>::from(residual);
193
194 X += incre;
195
196 suanpan_debug("Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(full_residual -= this->operator*(incre)));
197 }
198 }
199
200 if(0 != INFO)
201 suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
202
203 return INFO;
204}
205
206template<sp_d T> int BandSymmMat<T>::direct_solve(Mat<T>& X, Mat<T>&& B) {
207 if(this->factored) return this->solve_trs(X, std::forward<Mat<T>>(B));
208
209 const auto N = static_cast<int>(this->n_rows);
210 const auto KD = static_cast<int>(band);
211 const auto NRHS = static_cast<int>(B.n_cols);
212 const auto LDAB = static_cast<int>(m_rows);
213 const auto LDB = static_cast<int>(B.n_rows);
214 auto INFO = 0;
215
216 this->factored = true;
217
218 if(std::is_same_v<T, float>) {
219 using E = float;
220 arma_fortran(arma_spbsv)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)B.memptr(), &LDB, &INFO);
221 X = std::move(B);
222 }
223 else if(Precision::FULL == this->setting.precision) {
224 using E = double;
225 arma_fortran(arma_dpbsv)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)B.memptr(), &LDB, &INFO);
226 X = std::move(B);
227 }
228 else {
229 this->s_memory = this->to_float();
230 arma_fortran(arma_spbtrf)(&UPLO, &N, &KD, this->s_memory.memptr(), &LDAB, &INFO);
231 if(0 == INFO) INFO = this->solve_trs(X, std::forward<Mat<T>>(B));
232 }
233
234 if(0 != INFO)
235 suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
236
237 return INFO;
238}
239
240template<sp_d T> int BandSymmMat<T>::solve_trs(Mat<T>& X, Mat<T>&& B) {
241 const auto N = static_cast<int>(this->n_rows);
242 const auto KD = static_cast<int>(band);
243 const auto NRHS = static_cast<int>(B.n_cols);
244 const auto LDAB = static_cast<int>(m_rows);
245 const auto LDB = static_cast<int>(B.n_rows);
246 auto INFO = 0;
247
248 if(std::is_same_v<T, float>) {
249 using E = float;
250 arma_fortran(arma_spbtrs)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)B.memptr(), &LDB, &INFO);
251 X = std::move(B);
252 }
253 else if(Precision::FULL == this->setting.precision) {
254 using E = double;
255 arma_fortran(arma_dpbtrs)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)B.memptr(), &LDB, &INFO);
256 X = std::move(B);
257 }
258 else {
259 X = arma::zeros(B.n_rows, B.n_cols);
260
261 auto multiplier = norm(B);
262
263 auto counter = 0u;
264 while(counter++ < this->setting.iterative_refinement) {
265 if(multiplier < this->setting.tolerance) break;
266
267 auto residual = conv_to<fmat>::from(B / multiplier);
268
269 arma_fortran(arma_spbtrs)(&UPLO, &N, &KD, &NRHS, this->s_memory.memptr(), &LDAB, residual.memptr(), &LDB, &INFO);
270 if(0 != INFO) break;
271
272 const mat incre = multiplier * conv_to<mat>::from(residual);
273
274 X += incre;
275
276 suanpan_debug("Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(B -= this->operator*(incre)));
277 }
278 }
279
280 if(0 != INFO)
281 suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
282
283 return INFO;
284}
285
286#endif
287
A BandSymmMat class that holds matrices.
Definition: BandSymmMat.hpp:35
A DenseMat class that holds matrices.
Definition: DenseMat.hpp:34
T & at(uword, uword) override
Access element with bound check.
Definition: BandSymmMat.hpp:96
const T & operator()(uword, uword) const override
Access element (read-only), returns zero if out-of-bound.
Definition: BandSymmMat.hpp:86
void nullify(uword) override
Definition: BandSymmMat.hpp:78
BandSymmMat(uword, uword)
Definition: BandSymmMat.hpp:66
Mat< T > operator*(const Mat< T > &) const override
Definition: BandSymmMat.hpp:102
unique_ptr< MetaMat< T > > make_copy() override
Definition: BandSymmMat.hpp:71
void unify(uword) override
Definition: BandSymmMat.hpp:73
int direct_solve(Mat< T > &, Mat< T > &&) override
Definition: BandSymmMat.hpp:206
T & unsafe_at(uword, uword) override
Access element without bound check.
Definition: BandSymmMat.hpp:91
double norm(const vec &)
Definition: tensor.cpp:302
#define suanpan_debug(...)
Definition: suanPan.h:295
#define suanpan_error(...)
Definition: suanPan.h:297
void suanpan_for(const IT start, const IT end, F &&FN)
Definition: utility.h:27