suanPan
Loading...
Searching...
No Matches
BandMat.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2017-2023 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
29// ReSharper disable CppCStyleCast
30#ifndef BANDMAT_HPP
31#define BANDMAT_HPP
32
33#include "DenseMat.hpp"
34
35template<sp_d T> class BandMat : public DenseMat<T> {
36 static constexpr char TRAN = 'N';
37
38 static T bin;
39
40 const uword s_band;
41 const uword m_rows; // memory block layout
42
43 int solve_trs(Mat<T>&, Mat<T>&&);
44
45protected:
46 const uword l_band;
47 const uword u_band;
48
50
51 int direct_solve(Mat<T>&, Mat<T>&&) override;
52
53public:
54 BandMat(const uword in_size, const uword in_l, const uword in_u)
55 : DenseMat<T>(in_size, in_size, (2 * in_l + in_u + 1) * in_size)
56 , s_band(in_l + in_u)
57 , m_rows(2 * in_l + in_u + 1)
58 , l_band(in_l)
59 , u_band(in_u) {}
60
61 unique_ptr<MetaMat<T>> make_copy() override { return std::make_unique<BandMat>(*this); }
62
63 void nullify(const uword K) override {
64 this->factored = false;
65 suanpan_for(std::max(K, u_band) - u_band, std::min(this->n_rows, K + l_band + 1), [&](const uword I) { this->memory[I + s_band + K * (m_rows - 1)] = T(0); });
66 suanpan_for(std::max(K, l_band) - l_band, std::min(this->n_cols, K + u_band + 1), [&](const uword I) { this->memory[K + s_band + I * (m_rows - 1)] = T(0); });
67 }
68
69 T operator()(const uword in_row, const uword in_col) const override {
70 if(in_row > in_col + l_band || in_row + u_band < in_col) [[unlikely]] return bin = T(0);
71 return this->memory[in_row + s_band + in_col * (m_rows - 1)];
72 }
73
74 T& unsafe_at(const uword in_row, const uword in_col) override {
75 this->factored = false;
76 return this->memory[in_row + s_band + in_col * (m_rows - 1)];
77 }
78
79 T& at(const uword in_row, const uword in_col) override {
80 if(in_row > in_col + l_band || in_row + u_band < in_col) [[unlikely]] return bin = T(0);
81 return this->unsafe_at(in_row, in_col);
82 }
83
84 Mat<T> operator*(const Mat<T>&) const override;
85};
86
87template<sp_d T> T BandMat<T>::bin = T(0);
88
89template<sp_d T> Mat<T> BandMat<T>::operator*(const Mat<T>& X) const {
90 Mat<T> Y(arma::size(X));
91
92 const auto M = static_cast<int>(this->n_rows);
93 const auto N = static_cast<int>(this->n_cols);
94 const auto KL = static_cast<int>(l_band);
95 const auto KU = static_cast<int>(u_band);
96 const auto LDA = static_cast<int>(m_rows);
97 constexpr auto INC = 1;
98 T ALPHA = T(1);
99 T BETA = T(0);
100
101 if constexpr(std::is_same_v<T, float>) {
102 using E = float;
103 suanpan_for(0llu, X.n_cols, [&](const uword I) { arma_fortran(arma_sgbmv)(&TRAN, &M, &N, &KL, &KU, (E*)&ALPHA, (E*)(this->memptr() + l_band), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
104 }
105 else {
106 using E = double;
107 suanpan_for(0llu, X.n_cols, [&](const uword I) { arma_fortran(arma_dgbmv)(&TRAN, &M, &N, &KL, &KU, (E*)&ALPHA, (E*)(this->memptr() + l_band), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
108 }
109
110 return Y;
111}
112
113template<sp_d T> int BandMat<T>::direct_solve(Mat<T>& X, Mat<T>&& B) {
114 if(this->factored) return this->solve_trs(X, std::forward<Mat<T>>(B));
115
116 suanpan_assert([&] { if(this->n_rows != this->n_cols) throw invalid_argument("requires a square matrix"); });
117
118 auto INFO = 0;
119
120 auto N = static_cast<int>(this->n_rows);
121 const auto KL = static_cast<int>(l_band);
122 const auto KU = static_cast<int>(u_band);
123 const auto NRHS = static_cast<int>(B.n_cols);
124 const auto LDAB = static_cast<int>(m_rows);
125 const auto LDB = static_cast<int>(B.n_rows);
126 this->pivot.zeros(N);
127 this->factored = true;
128
129 if constexpr(std::is_same_v<T, float>) {
130 using E = float;
131 arma_fortran(arma_sgbsv)(&N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
132 X = std::move(B);
133 }
134 else if(Precision::FULL == this->setting.precision) {
135 using E = double;
136 arma_fortran(arma_dgbsv)(&N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
137 X = std::move(B);
138 }
139 else {
140 this->s_memory = this->to_float();
141 arma_fortran(arma_sgbtrf)(&N, &N, &KL, &KU, this->s_memory.memptr(), &LDAB, this->pivot.memptr(), &INFO);
142 if(0 == INFO) INFO = this->solve_trs(X, std::forward<Mat<T>>(B));
143 }
144
145 if(0 != INFO)
146 suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
147
148 return INFO;
149}
150
151template<sp_d T> int BandMat<T>::solve_trs(Mat<T>& X, Mat<T>&& B) {
152 auto INFO = 0;
153
154 const auto N = static_cast<int>(this->n_rows);
155 const auto KL = static_cast<int>(l_band);
156 const auto KU = static_cast<int>(u_band);
157 const auto NRHS = static_cast<int>(B.n_cols);
158 const auto LDAB = static_cast<int>(m_rows);
159 const auto LDB = static_cast<int>(B.n_rows);
160
161 if constexpr(std::is_same_v<T, float>) {
162 using E = float;
163 arma_fortran(arma_sgbtrs)(&TRAN, &N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
164 X = std::move(B);
165 }
166 else if(Precision::FULL == this->setting.precision) {
167 using E = double;
168 arma_fortran(arma_dgbtrs)(&TRAN, &N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
169 X = std::move(B);
170 }
171 else
172 this->mixed_trs(X, std::forward<Mat<T>>(B), [&](fmat& residual) {
173 arma_fortran(arma_sgbtrs)(&TRAN, &N, &KL, &KU, &NRHS, this->s_memory.memptr(), &LDAB, this->pivot.memptr(), residual.memptr(), &LDB, &INFO);
174 return INFO;
175 });
176
177 if(0 != INFO)
178 suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
179
180 return INFO;
181}
182
183#endif
184
A BandMat class that holds banded matrices.
Definition BandMat.hpp:35
T & unsafe_at(const uword in_row, const uword in_col) override
Access element without bound check.
Definition BandMat.hpp:74
BandMat(const uword in_size, const uword in_l, const uword in_u)
Definition BandMat.hpp:54
T & at(const uword in_row, const uword in_col) override
Access element with band check; entries outside the band map to a shared zero placeholder.
Definition BandMat.hpp:79
const uword u_band
Definition BandMat.hpp:47
unique_ptr< MetaMat< T > > make_copy() override
Definition BandMat.hpp:61
T operator()(const uword in_row, const uword in_col) const override
Access element (read-only); returns zero for entries outside the band.
Definition BandMat.hpp:69
const uword l_band
Definition BandMat.hpp:46
void nullify(const uword K) override
Definition BandMat.hpp:63
A DenseMat class that holds matrices.
Definition DenseMat.hpp:39
std::unique_ptr< T[]> memory
Definition DenseMat.hpp:48
const uword n_cols
Definition MetaMat.hpp:86
const uword n_rows
Definition MetaMat.hpp:85
bool factored
Definition MetaMat.hpp:41
Mat< T > operator*(const Mat< T > &) const override
Definition BandMat.hpp:89
int direct_solve(Mat< T > &, Mat< T > &&) override
Definition BandMat.hpp:113
void suanpan_assert(const std::function< void()> &F)
Definition suanPan.h:296
#define suanpan_error(...)
Definition suanPan.h:309
void suanpan_for(const IT start, const IT end, F &&FN)
Definition utility.h:27