suanPan
Loading...
Searching...
No Matches
SymmPackMat.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2017-2023 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
29// ReSharper disable CppCStyleCast
30#ifndef SYMMPACKMAT_HPP
31#define SYMMPACKMAT_HPP
32
33#include "DenseMat.hpp"
34
35template<sp_d T> class SymmPackMat final : public DenseMat<T> {
36 static constexpr char UPLO = 'L';
37
38 static T bin;
39
40 const uword length; // 2n-1
41
42 int solve_trs(Mat<T>&, Mat<T>&&);
43 int solve_trs(Mat<T>&, const Mat<T>&);
44
45public:
46 explicit SymmPackMat(uword);
47
48 unique_ptr<MetaMat<T>> make_copy() override;
49
50 void unify(uword) override;
51 void nullify(uword) override;
52
53 const T& operator()(uword, uword) const override;
54 T& unsafe_at(uword, uword) override;
55 T& at(uword, uword) override;
56
57 Mat<T> operator*(const Mat<T>&) const override;
58
59 int direct_solve(Mat<T>&, Mat<T>&&) override;
60 int direct_solve(Mat<T>&, const Mat<T>&) override;
61};
62
63template<sp_d T> T SymmPackMat<T>::bin = 0.;
64
65template<sp_d T> SymmPackMat<T>::SymmPackMat(const uword in_size)
66 : DenseMat<T>(in_size, in_size, (in_size + 1) * in_size / 2)
67 , length(2 * in_size - 1) {}
68
69template<sp_d T> unique_ptr<MetaMat<T>> SymmPackMat<T>::make_copy() { return std::make_unique<SymmPackMat<T>>(*this); }
70
71template<sp_d T> void SymmPackMat<T>::unify(const uword K) {
72 nullify(K);
73 access::rw(this->memory[(length - K + 2) * K / 2]) = 1.;
74}
75
76template<sp_d T> void SymmPackMat<T>::nullify(const uword K) {
77 suanpan_for(0llu, K, [&](const uword I) { access::rw(this->memory[K + (length - I) * I / 2]) = 0.; });
78 const auto t_factor = (length - K) * K / 2;
79 suanpan_for(K, this->n_rows, [&](const uword I) { access::rw(this->memory[I + t_factor]) = 0.; });
80
81 this->factored = false;
82}
83
84template<sp_d T> const T& SymmPackMat<T>::operator()(const uword in_row, const uword in_col) const { return this->memory[in_row >= in_col ? in_row + (length - in_col) * in_col / 2 : in_col + (length - in_row) * in_row / 2]; }
85
86template<sp_d T> T& SymmPackMat<T>::unsafe_at(const uword in_row, const uword in_col) {
87 this->factored = false;
88 return access::rw(this->memory[in_row + (length - in_col) * in_col / 2]);
89}
90
91template<sp_d T> T& SymmPackMat<T>::at(const uword in_row, const uword in_col) {
92 if(in_row < in_col) [[unlikely]] return bin;
93 this->factored = false;
94 return access::rw(this->memory[in_row + (length - in_col) * in_col / 2]);
95}
96
97template<sp_d T> Mat<T> SymmPackMat<T>::operator*(const Mat<T>& X) const {
98 auto Y = Mat<T>(arma::size(X), fill::none);
99
100 const auto N = static_cast<int>(this->n_rows);
101 const auto INC = 1;
102 T ALPHA = 1.;
103 T BETA = 0.;
104
105 if(std::is_same_v<T, float>) {
106 using E = float;
107 suanpan_for(0llu, X.n_cols, [&](const uword I) { arma_fortran(arma_sspmv)(&UPLO, &N, (E*)&ALPHA, (E*)this->memptr(), (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
108 }
109 else if(std::is_same_v<T, double>) {
110 using E = double;
111 suanpan_for(0llu, X.n_cols, [&](const uword I) { arma_fortran(arma_dspmv)(&UPLO, &N, (E*)&ALPHA, (E*)this->memptr(), (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
112 }
113
114 return Y;
115}
116
117template<sp_d T> int SymmPackMat<T>::direct_solve(Mat<T>& X, const Mat<T>& B) {
118 if(this->factored) return this->solve_trs(X, B);
119
120 const auto N = static_cast<int>(this->n_rows);
121 const auto NRHS = static_cast<int>(B.n_cols);
122 const auto LDB = static_cast<int>(B.n_rows);
123 auto INFO = 0;
124
125 this->factored = true;
126
127 if(std::is_same_v<T, float>) {
128 using E = float;
129 X = B;
130 arma_fortran(arma_sppsv)(&UPLO, &N, &NRHS, (E*)this->memptr(), (E*)X.memptr(), &LDB, &INFO);
131 }
132 else if(Precision::FULL == this->setting.precision) {
133 using E = double;
134 X = B;
135 arma_fortran(arma_dppsv)(&UPLO, &N, &NRHS, (E*)this->memptr(), (E*)X.memptr(), &LDB, &INFO);
136 }
137 else {
138 this->s_memory = this->to_float();
139 arma_fortran(arma_spptrf)(&UPLO, &N, this->s_memory.memptr(), &INFO);
140 if(0 == INFO) INFO = this->solve_trs(X, B);
141 }
142
143 if(0 != INFO)
144 suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
145
146 return INFO;
147}
148
149template<sp_d T> int SymmPackMat<T>::solve_trs(Mat<T>& X, const Mat<T>& B) {
150 const auto N = static_cast<int>(this->n_rows);
151 const auto NRHS = static_cast<int>(B.n_cols);
152 const auto LDB = static_cast<int>(B.n_rows);
153 auto INFO = 0;
154
155 if(std::is_same_v<T, float>) {
156 using E = float;
157 X = B;
158 arma_fortran(arma_spptrs)(&UPLO, &N, &NRHS, (E*)this->memptr(), (E*)X.memptr(), &LDB, &INFO);
159 }
160 else if(Precision::FULL == this->setting.precision) {
161 using E = double;
162 X = B;
163 arma_fortran(arma_dpptrs)(&UPLO, &N, &NRHS, (E*)this->memptr(), (E*)X.memptr(), &LDB, &INFO);
164 }
165 else {
166 X = arma::zeros(B.n_rows, B.n_cols);
167
168 mat full_residual = B;
169
170 auto multiplier = norm(full_residual);
171
172 auto counter = 0u;
173 while(counter++ < this->setting.iterative_refinement) {
174 if(multiplier < this->setting.tolerance) break;
175
176 auto residual = conv_to<fmat>::from(full_residual / multiplier);
177
178 arma_fortran(arma_spptrs)(&UPLO, &N, &NRHS, this->s_memory.memptr(), residual.memptr(), &LDB, &INFO);
179 if(0 != INFO) break;
180
181 const mat incre = multiplier * conv_to<mat>::from(residual);
182
183 X += incre;
184
185 suanpan_debug("Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(full_residual -= this->operator*(incre)));
186 }
187 }
188
189 if(0 != INFO)
190 suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
191
192 return INFO;
193}
194
195template<sp_d T> int SymmPackMat<T>::direct_solve(Mat<T>& X, Mat<T>&& B) {
196 if(this->factored) return this->solve_trs(X, std::forward<Mat<T>>(B));
197
198 const auto N = static_cast<int>(this->n_rows);
199 const auto NRHS = static_cast<int>(B.n_cols);
200 const auto LDB = static_cast<int>(B.n_rows);
201 auto INFO = 0;
202
203 this->factored = true;
204
205 if(std::is_same_v<T, float>) {
206 using E = float;
207 arma_fortran(arma_sppsv)(&UPLO, &N, &NRHS, (E*)this->memptr(), (E*)B.memptr(), &LDB, &INFO);
208 X = std::move(B);
209 }
210 else if(Precision::FULL == this->setting.precision) {
211 using E = double;
212 arma_fortran(arma_dppsv)(&UPLO, &N, &NRHS, (E*)this->memptr(), (E*)B.memptr(), &LDB, &INFO);
213 X = std::move(B);
214 }
215 else {
216 this->s_memory = this->to_float();
217 arma_fortran(arma_spptrf)(&UPLO, &N, this->s_memory.memptr(), &INFO);
218 if(0 == INFO) INFO = this->solve_trs(X, std::forward<Mat<T>>(B));
219 }
220
221 if(0 != INFO)
222 suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
223
224 return INFO;
225}
226
227template<sp_d T> int SymmPackMat<T>::solve_trs(Mat<T>& X, Mat<T>&& B) {
228 const auto N = static_cast<int>(this->n_rows);
229 const auto NRHS = static_cast<int>(B.n_cols);
230 const auto LDB = static_cast<int>(B.n_rows);
231 auto INFO = 0;
232
233 if(std::is_same_v<T, float>) {
234 using E = float;
235 arma_fortran(arma_spptrs)(&UPLO, &N, &NRHS, (E*)this->memptr(), (E*)B.memptr(), &LDB, &INFO);
236 X = std::move(B);
237 }
238 else if(Precision::FULL == this->setting.precision) {
239 using E = double;
240 arma_fortran(arma_dpptrs)(&UPLO, &N, &NRHS, (E*)this->memptr(), (E*)B.memptr(), &LDB, &INFO);
241 X = std::move(B);
242 }
243 else {
244 X = arma::zeros(B.n_rows, B.n_cols);
245
246 auto multiplier = arma::norm(B);
247
248 auto counter = 0u;
249 while(counter++ < this->setting.iterative_refinement) {
250 if(multiplier < this->setting.tolerance) break;
251
252 auto residual = conv_to<fmat>::from(B / multiplier);
253
254 arma_fortran(arma_spptrs)(&UPLO, &N, &NRHS, this->s_memory.memptr(), residual.memptr(), &LDB, &INFO);
255 if(0 != INFO) break;
256
257 const mat incre = multiplier * conv_to<mat>::from(residual);
258
259 X += incre;
260
261 suanpan_debug("Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(B -= this->operator*(incre)));
262 }
263 }
264
265 if(0 != INFO)
266 suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
267
268 return INFO;
269}
270
271#endif
272
A DenseMat class that holds matrices.
Definition: DenseMat.hpp:34
A SymmPackMat class that holds symmetric matrices in packed (lower-triangular) storage.
Definition: SymmPackMat.hpp:35
int direct_solve(Mat< T > &, Mat< T > &&) override
Definition: SymmPackMat.hpp:195
T & unsafe_at(uword, uword) override
Access element without bound check.
Definition: SymmPackMat.hpp:86
void nullify(uword) override
Definition: SymmPackMat.hpp:76
unique_ptr< MetaMat< T > > make_copy() override
Definition: SymmPackMat.hpp:69
Mat< T > operator*(const Mat< T > &) const override
Definition: SymmPackMat.hpp:97
void unify(uword) override
Definition: SymmPackMat.hpp:71
T & at(uword, uword) override
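Access element for writing; only the stored lower triangle is addressable, and upper-triangle indices (row < col) return a shared dummy reference.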
Definition: SymmPackMat.hpp:91
SymmPackMat(uword)
Definition: SymmPackMat.hpp:65
const T & operator()(uword, uword) const override
Access element (read-only); upper-triangle indices are mapped to the mirrored lower-triangle entry (no range check).
Definition: SymmPackMat.hpp:84
double norm(const vec &)
Definition: tensor.cpp:302
#define suanpan_debug(...)
Definition: suanPan.h:295
#define suanpan_error(...)
Definition: suanPan.h:297
void suanpan_for(const IT start, const IT end, F &&FN)
Definition: utility.h:27