30#ifndef SYMMPACKMAT_HPP
31#define SYMMPACKMAT_HPP
36 static constexpr char UPLO =
'L';
42 int solve_trs(Mat<T>&, Mat<T>&&);
43 int solve_trs(Mat<T>&,
const Mat<T>&);
48 unique_ptr<MetaMat<T>>
make_copy()
override;
50 void unify(uword)
override;
55 T&
at(uword, uword)
override;
57 Mat<T>
operator*(
const Mat<T>&)
const override;
66 :
DenseMat<
T>(in_size, in_size, (in_size + 1) * in_size / 2)
67 , length(2 * in_size - 1) {}
73 access::rw(this->memory[(length -
K + 2) *
K / 2]) = 1.;
77 suanpan_for(0llu,
K, [&](
const uword I) { access::rw(this->memory[
K + (length - I) * I / 2]) = 0.; });
78 const auto t_factor = (length -
K) *
K / 2;
79 suanpan_for(
K, this->n_rows, [&](
const uword I) { access::rw(this->memory[I + t_factor]) = 0.; });
81 this->factored =
false;
84template<sp_d T>
const T&
SymmPackMat<T>::operator()(
const uword in_row,
const uword in_col)
const {
return this->memory[in_row >= in_col ? in_row + (length - in_col) * in_col / 2 : in_col + (length - in_row) * in_row / 2]; }
87 this->factored =
false;
88 return access::rw(this->memory[in_row + (length - in_col) * in_col / 2]);
92 if(in_row < in_col) [[unlikely]]
return bin;
93 this->factored =
false;
94 return access::rw(this->memory[in_row + (length - in_col) * in_col / 2]);
98 auto Y = Mat<T>(arma::size(X), fill::none);
100 const auto N =
static_cast<int>(this->n_rows);
105 if(std::is_same_v<T, float>) {
107 suanpan_for(0llu, X.n_cols, [&](
const uword I) { arma_fortran(arma_sspmv)(&UPLO, &N, (E*)&ALPHA, (E*)this->memptr(), (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
109 else if(std::is_same_v<T, double>) {
111 suanpan_for(0llu, X.n_cols, [&](
const uword I) { arma_fortran(arma_dspmv)(&UPLO, &N, (E*)&ALPHA, (E*)this->memptr(), (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
118 if(this->factored)
return this->solve_trs(X, B);
120 const auto N =
static_cast<int>(this->n_rows);
121 const auto NRHS =
static_cast<int>(B.n_cols);
122 const auto LDB =
static_cast<int>(B.n_rows);
125 this->factored =
true;
127 if(std::is_same_v<T, float>) {
130 arma_fortran(arma_sppsv)(&UPLO, &
N, &NRHS, (
E*)this->memptr(), (
E*)X.memptr(), &LDB, &INFO);
135 arma_fortran(arma_dppsv)(&UPLO, &
N, &NRHS, (
E*)this->memptr(), (
E*)X.memptr(), &LDB, &INFO);
138 this->s_memory = this->to_float();
139 arma_fortran(arma_spptrf)(&UPLO, &
N, this->s_memory.memptr(), &INFO);
140 if(0 == INFO) INFO = this->solve_trs(X, B);
144 suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
150 const auto N =
static_cast<int>(this->n_rows);
151 const auto NRHS =
static_cast<int>(B.n_cols);
152 const auto LDB =
static_cast<int>(B.n_rows);
155 if(std::is_same_v<T, float>) {
158 arma_fortran(arma_spptrs)(&UPLO, &
N, &NRHS, (
E*)this->memptr(), (
E*)X.memptr(), &LDB, &INFO);
163 arma_fortran(arma_dpptrs)(&UPLO, &
N, &NRHS, (
E*)this->memptr(), (
E*)X.memptr(), &LDB, &INFO);
166 X = arma::zeros(B.n_rows, B.n_cols);
168 mat full_residual = B;
170 auto multiplier =
norm(full_residual);
173 while(counter++ < this->setting.iterative_refinement) {
174 if(multiplier < this->setting.tolerance)
break;
176 auto residual = conv_to<fmat>::from(full_residual / multiplier);
178 arma_fortran(arma_spptrs)(&UPLO, &
N, &NRHS, this->s_memory.memptr(), residual.memptr(), &LDB, &INFO);
181 const mat incre = multiplier * conv_to<mat>::from(residual);
185 suanpan_debug(
"Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(full_residual -= this->
operator*(incre)));
190 suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
196 if(this->factored)
return this->solve_trs(X, std::forward<Mat<T>>(B));
198 const auto N =
static_cast<int>(this->n_rows);
199 const auto NRHS =
static_cast<int>(B.n_cols);
200 const auto LDB =
static_cast<int>(B.n_rows);
203 this->factored =
true;
205 if(std::is_same_v<T, float>) {
207 arma_fortran(arma_sppsv)(&UPLO, &
N, &NRHS, (
E*)this->memptr(), (
E*)B.memptr(), &LDB, &INFO);
212 arma_fortran(arma_dppsv)(&UPLO, &
N, &NRHS, (
E*)this->memptr(), (
E*)B.memptr(), &LDB, &INFO);
216 this->s_memory = this->to_float();
217 arma_fortran(arma_spptrf)(&UPLO, &
N, this->s_memory.memptr(), &INFO);
218 if(0 == INFO) INFO = this->solve_trs(X, std::forward<Mat<T>>(B));
222 suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
228 const auto N =
static_cast<int>(this->n_rows);
229 const auto NRHS =
static_cast<int>(B.n_cols);
230 const auto LDB =
static_cast<int>(B.n_rows);
233 if(std::is_same_v<T, float>) {
235 arma_fortran(arma_spptrs)(&UPLO, &
N, &NRHS, (
E*)this->memptr(), (
E*)B.memptr(), &LDB, &INFO);
240 arma_fortran(arma_dpptrs)(&UPLO, &
N, &NRHS, (
E*)this->memptr(), (
E*)B.memptr(), &LDB, &INFO);
244 X = arma::zeros(B.n_rows, B.n_cols);
246 auto multiplier = arma::norm(B);
249 while(counter++ < this->setting.iterative_refinement) {
250 if(multiplier < this->setting.tolerance)
break;
252 auto residual = conv_to<fmat>::from(B / multiplier);
254 arma_fortran(arma_spptrs)(&UPLO, &
N, &NRHS, this->s_memory.memptr(), residual.memptr(), &LDB, &INFO);
257 const mat incre = multiplier * conv_to<mat>::from(residual);
261 suanpan_debug(
"Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(B -= this->
operator*(incre)));
266 suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
A DenseMat class that holds matrices.
Definition: DenseMat.hpp:34
A SymmPackMat class that holds symmetric matrices in packed (lower-triangular) storage.
Definition: SymmPackMat.hpp:35
double norm(const vec &)
Definition: tensor.cpp:302
#define suanpan_debug(...)
Definition: suanPan.h:295
#define suanpan_error(...)
Definition: suanPan.h:297
void suanpan_for(const IT start, const IT end, F &&FN)
Definition: utility.h:27