suanPan
Loading...
Searching...
No Matches
FullMat.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2017-2023 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
29// ReSharper disable CppCStyleCast
30#ifndef FULLMAT_HPP
31#define FULLMAT_HPP
32
33#include "DenseMat.hpp"
34
35template<sp_d T> class FullMat : public DenseMat<T> {
36 static constexpr char TRAN = 'N';
37
38 int solve_trs(Mat<T>&, Mat<T>&&);
39 int solve_trs(Mat<T>&, const Mat<T>&);
40
41public:
42 FullMat(uword, uword);
43
44 unique_ptr<MetaMat<T>> make_copy() override;
45
46 void unify(uword) override;
47 void nullify(uword) override;
48
49 const T& operator()(uword, uword) const override;
50
51 T& at(uword, uword) override;
52
53 Mat<T> operator*(const Mat<T>&) const override;
54
55 int direct_solve(Mat<T>&, Mat<T>&&) override;
56 int direct_solve(Mat<T>&, const Mat<T>&) override;
57};
58
59template<sp_d T> FullMat<T>::FullMat(const uword in_rows, const uword in_cols)
60 : DenseMat<T>(in_rows, in_cols, in_rows * in_cols) {}
61
62template<sp_d T> unique_ptr<MetaMat<T>> FullMat<T>::make_copy() { return std::make_unique<FullMat<T>>(*this); }
63
64template<sp_d T> void FullMat<T>::unify(const uword K) {
65 nullify(K);
66 at(K, K) = 1.;
67}
68
69template<sp_d T> void FullMat<T>::nullify(const uword K) {
70 suanpan_for(0llu, this->n_rows, [&](const uword I) { at(I, K) = 0.; });
71 suanpan_for(0llu, this->n_cols, [&](const uword I) { at(K, I) = 0.; });
72
73 this->factored = false;
74}
75
76template<sp_d T> const T& FullMat<T>::operator()(const uword in_row, const uword in_col) const { return this->memory[in_row + in_col * this->n_rows]; }
77
78template<sp_d T> T& FullMat<T>::at(const uword in_row, const uword in_col) {
79 this->factored = false;
80 return access::rw(this->memory[in_row + in_col * this->n_rows]);
81}
82
83template<sp_d T> Mat<T> FullMat<T>::operator*(const Mat<T>& B) const {
84 Mat<T> C(arma::size(B));
85
86 const auto M = static_cast<int>(this->n_rows);
87 const auto N = static_cast<int>(this->n_cols);
88
89 T ALPHA = 1., BETA = 0.;
90
91 if(1 == B.n_cols) {
92 constexpr auto INCX = 1, INCY = 1;
93
94 if(std::is_same_v<T, float>) {
95 using E = float;
96 arma_fortran(arma_sgemv)(&TRAN, &M, &N, (E*)&ALPHA, (E*)this->memptr(), &M, (E*)B.memptr(), &INCX, (E*)&BETA, (E*)C.memptr(), &INCY);
97 }
98 else if(std::is_same_v<T, double>) {
99 using E = double;
100 arma_fortran(arma_dgemv)(&TRAN, &M, &N, (E*)&ALPHA, (E*)this->memptr(), &M, (E*)B.memptr(), &INCX, (E*)&BETA, (E*)C.memptr(), &INCY);
101 }
102 }
103 else {
104 const auto K = static_cast<int>(B.n_cols);
105
106 if(std::is_same_v<T, float>) {
107 using E = float;
108 arma_fortran(arma_sgemm)(&TRAN, &TRAN, &M, &K, &N, (E*)&ALPHA, (E*)this->memptr(), &M, (E*)B.memptr(), &N, (E*)&BETA, (E*)C.memptr(), &M);
109 }
110 else if(std::is_same_v<T, double>) {
111 using E = double;
112 arma_fortran(arma_dgemm)(&TRAN, &TRAN, &M, &K, &N, (E*)&ALPHA, (E*)this->memptr(), &M, (E*)B.memptr(), &N, (E*)&BETA, (E*)C.memptr(), &M);
113 }
114 }
115
116 return C;
117}
118
119template<sp_d T> int FullMat<T>::direct_solve(Mat<T>& X, const Mat<T>& B) {
120 if(this->factored) return this->solve_trs(X, B);
121
122 auto N = static_cast<int>(this->n_rows);
123 const auto NRHS = static_cast<int>(B.n_cols);
124 const auto LDB = static_cast<int>(B.n_rows);
125 auto INFO = 0;
126 this->pivot.zeros(N);
127 this->factored = true;
128
129 if(std::is_same_v<T, float>) {
130 using E = float;
131 X = B;
132 arma_fortran(arma_sgesv)(&N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)X.memptr(), &LDB, &INFO);
133 }
134 else if(Precision::FULL == this->setting.precision) {
135 using E = double;
136 X = B;
137 arma_fortran(arma_dgesv)(&N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)X.memptr(), &LDB, &INFO);
138 }
139 else {
140 this->s_memory = this->to_float();
141 arma_fortran(arma_sgetrf)(&N, &N, this->s_memory.memptr(), &N, this->pivot.memptr(), &INFO);
142 if(0 == INFO) INFO = this->solve_trs(X, B);
143 }
144
145 if(0 != INFO)
146 suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
147
148 return INFO;
149}
150
151template<sp_d T> int FullMat<T>::solve_trs(Mat<T>& X, const Mat<T>& B) {
152 const auto N = static_cast<int>(this->n_rows);
153 const auto NRHS = static_cast<int>(B.n_cols);
154 const auto LDB = static_cast<int>(B.n_rows);
155 auto INFO = 0;
156
157 if(std::is_same_v<T, float>) {
158 using E = float;
159 X = B;
160 arma_fortran(arma_sgetrs)(&TRAN, &N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)X.memptr(), &LDB, &INFO);
161 }
162 else if(Precision::FULL == this->setting.precision) {
163 using E = double;
164 X = B;
165 arma_fortran(arma_dgetrs)(&TRAN, &N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)X.memptr(), &LDB, &INFO);
166 }
167 else {
168 X = arma::zeros(B.n_rows, B.n_cols);
169
170 mat full_residual = B;
171
172 auto multiplier = norm(full_residual);
173
174 auto counter = 0u;
175 while(counter++ < this->setting.iterative_refinement) {
176 if(multiplier < this->setting.tolerance) break;
177
178 auto residual = conv_to<fmat>::from(full_residual / multiplier);
179
180 arma_fortran(arma_sgetrs)(&TRAN, &N, &NRHS, this->s_memory.memptr(), &N, this->pivot.memptr(), residual.memptr(), &LDB, &INFO);
181 if(0 != INFO) break;
182
183 const mat incre = multiplier * conv_to<mat>::from(residual);
184
185 X += incre;
186
187 suanpan_debug("Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(full_residual -= this->operator*(incre)));
188 }
189 }
190
191 return INFO;
192}
193
194template<sp_d T> int FullMat<T>::direct_solve(Mat<T>& X, Mat<T>&& B) {
195 if(this->factored) return this->solve_trs(X, std::forward<Mat<T>>(B));
196
197 auto N = static_cast<int>(this->n_rows);
198 const auto NRHS = static_cast<int>(B.n_cols);
199 const auto LDB = static_cast<int>(B.n_rows);
200 auto INFO = 0;
201
202 this->pivot.zeros(N);
203
204 this->factored = true;
205
206 if(std::is_same_v<T, float>) {
207 using E = float;
208 arma_fortran(arma_sgesv)(&N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
209 X = std::move(B);
210 }
211 else if(Precision::FULL == this->setting.precision) {
212 using E = double;
213 arma_fortran(arma_dgesv)(&N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
214 X = std::move(B);
215 }
216 else {
217 this->s_memory = this->to_float();
218 arma_fortran(arma_sgetrf)(&N, &N, this->s_memory.memptr(), &N, this->pivot.memptr(), &INFO);
219 if(0 == INFO) INFO = this->solve_trs(X, std::forward<Mat<T>>(B));
220 }
221
222 if(0 != INFO)
223 suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
224
225 return INFO;
226}
227
228template<sp_d T> int FullMat<T>::solve_trs(Mat<T>& X, Mat<T>&& B) {
229 const auto N = static_cast<int>(this->n_rows);
230 const auto NRHS = static_cast<int>(B.n_cols);
231 const auto LDB = static_cast<int>(B.n_rows);
232 auto INFO = 0;
233
234 if(std::is_same_v<T, float>) {
235 using E = float;
236 arma_fortran(arma_sgetrs)(&TRAN, &N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
237 X = std::move(B);
238 }
239 else if(Precision::FULL == this->setting.precision) {
240 using E = double;
241 arma_fortran(arma_dgetrs)(&TRAN, &N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
242 X = std::move(B);
243 }
244 else {
245 X = arma::zeros(B.n_rows, B.n_cols);
246
247 auto multiplier = arma::norm(B);
248
249 auto counter = 0u;
250 while(counter++ < this->setting.iterative_refinement) {
251 if(multiplier < this->setting.tolerance) break;
252
253 auto residual = conv_to<fmat>::from(B / multiplier);
254
255 arma_fortran(arma_sgetrs)(&TRAN, &N, &NRHS, this->s_memory.memptr(), &N, this->pivot.memptr(), residual.memptr(), &LDB, &INFO);
256 if(0 != INFO) break;
257
258 const mat incre = multiplier * conv_to<mat>::from(residual);
259
260 X += incre;
261
262 suanpan_debug("Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(B -= this->operator*(incre)));
263 }
264 }
265
266 return INFO;
267}
268
269#endif
270
A DenseMat class that holds matrices.
Definition: DenseMat.hpp:34
A FullMat class that holds matrices.
Definition: FullMat.hpp:35
int direct_solve(Mat< T > &, Mat< T > &&) override
Definition: FullMat.hpp:194
T & at(uword, uword) override
Access element with bound check.
Definition: FullMat.hpp:78
const T & operator()(uword, uword) const override
Access element (read-only), returns zero if out-of-bound.
Definition: FullMat.hpp:76
void unify(uword) override
Definition: FullMat.hpp:64
void nullify(uword) override
Definition: FullMat.hpp:69
FullMat(uword, uword)
Definition: FullMat.hpp:59
Mat< T > operator*(const Mat< T > &) const override
Definition: FullMat.hpp:83
unique_ptr< MetaMat< T > > make_copy() override
Definition: FullMat.hpp:62
double norm(const vec &)
Definition: tensor.cpp:302
#define suanpan_debug(...)
Definition: suanPan.h:295
#define suanpan_error(...)
Definition: suanPan.h:297
void suanpan_for(const IT start, const IT end, F &&FN)
Definition: utility.h:27