// Transpose-option character passed by address as the leading argument(s) of the
// gemv/gemm/getrs calls made in this file.
// NOTE(review): 'N' is understood to be the BLAS/LAPACK code for "no transpose" — confirm against the LAPACK reference.
36 static constexpr char TRAN =
'N';
// Overload whose second matrix is taken as an rvalue reference.
// The int result is assigned to INFO by the caller visible in this file, i.e. it is treated as a status code.
38 int solve_trs(Mat<T>&, Mat<T>&&);
// Overload whose second matrix is taken as a const reference (left untouched by the signature).
// The int result is assigned to INFO by the caller visible in this file, i.e. it is treated as a status code.
39 int solve_trs(Mat<T>&,
const Mat<T>&);
// Overrides a base-class virtual and hands back a uniquely-owned MetaMat<T>.
// NOTE(review): presumably a copy of this matrix — the body is not visible here, confirm.
44 unique_ptr<MetaMat<T>>
make_copy()
override;
// Overrides a base-class virtual; takes a single unsigned index and returns nothing.
// NOTE(review): the body is not fully visible in this view — confirm the exact effect on the indexed row/column.
46 void unify(uword)
override;
// Mutable element accessor: takes two unsigned indices and returns a T& (overrides a base-class virtual).
// NOTE(review): presumably (row, column) order, matching operator() — confirm.
51 T&
at(uword, uword)
override;
// Const product operator: accepts a Mat<T> by const reference and returns a new Mat<T> by value
// (overrides a base-class virtual).
53 Mat<T>
operator*(
const Mat<T>&)
const override;
60 :
DenseMat<
T>(in_rows, in_cols, in_rows * in_cols) {}
70 suanpan_for(0llu, this->n_rows, [&](
const uword I) { at(I,
K) = 0.; });
71 suanpan_for(0llu, this->n_cols, [&](
const uword I) { at(
K, I) = 0.; });
73 this->factored =
false;
76template<sp_d T>
const T&
FullMat<T>::operator()(
const uword in_row,
const uword in_col)
const {
return this->memory[in_row + in_col * this->n_rows]; }
79 this->factored =
false;
80 return access::rw(this->memory[in_row + in_col * this->n_rows]);
84 Mat<T> C(arma::size(B));
86 const auto M =
static_cast<int>(this->n_rows);
87 const auto N =
static_cast<int>(this->n_cols);
89 T ALPHA = 1., BETA = 0.;
92 constexpr auto INCX = 1, INCY = 1;
94 if(std::is_same_v<T, float>) {
96 arma_fortran(arma_sgemv)(&TRAN, &
M, &
N, (
E*)&ALPHA, (
E*)this->memptr(), &
M, (
E*)B.memptr(), &INCX, (
E*)&BETA, (
E*)C.memptr(), &INCY);
98 else if(std::is_same_v<T, double>) {
100 arma_fortran(arma_dgemv)(&TRAN, &
M, &
N, (
E*)&ALPHA, (
E*)this->memptr(), &
M, (
E*)B.memptr(), &INCX, (
E*)&BETA, (
E*)C.memptr(), &INCY);
104 const auto K =
static_cast<int>(B.n_cols);
106 if(std::is_same_v<T, float>) {
108 arma_fortran(arma_sgemm)(&TRAN, &TRAN, &
M, &
K, &
N, (
E*)&ALPHA, (
E*)this->memptr(), &
M, (
E*)B.memptr(), &
N, (
E*)&BETA, (
E*)C.memptr(), &
M);
110 else if(std::is_same_v<T, double>) {
112 arma_fortran(arma_dgemm)(&TRAN, &TRAN, &
M, &
K, &
N, (
E*)&ALPHA, (
E*)this->memptr(), &
M, (
E*)B.memptr(), &
N, (
E*)&BETA, (
E*)C.memptr(), &
M);
120 if(this->factored)
return this->solve_trs(X, B);
122 auto N =
static_cast<int>(this->n_rows);
123 const auto NRHS =
static_cast<int>(B.n_cols);
124 const auto LDB =
static_cast<int>(B.n_rows);
126 this->pivot.zeros(
N);
127 this->factored =
true;
129 if(std::is_same_v<T, float>) {
132 arma_fortran(arma_sgesv)(&
N, &NRHS, (
E*)this->memptr(), &
N, this->pivot.memptr(), (
E*)X.memptr(), &LDB, &INFO);
137 arma_fortran(arma_dgesv)(&
N, &NRHS, (
E*)this->memptr(), &
N, this->pivot.memptr(), (
E*)X.memptr(), &LDB, &INFO);
140 this->s_memory = this->to_float();
141 arma_fortran(arma_sgetrf)(&
N, &
N, this->s_memory.memptr(), &
N, this->pivot.memptr(), &INFO);
142 if(0 == INFO) INFO = this->solve_trs(X, B);
146 suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
152 const auto N =
static_cast<int>(this->n_rows);
153 const auto NRHS =
static_cast<int>(B.n_cols);
154 const auto LDB =
static_cast<int>(B.n_rows);
157 if(std::is_same_v<T, float>) {
160 arma_fortran(arma_sgetrs)(&TRAN, &
N, &NRHS, (
E*)this->memptr(), &
N, this->pivot.memptr(), (
E*)X.memptr(), &LDB, &INFO);
165 arma_fortran(arma_dgetrs)(&TRAN, &
N, &NRHS, (
E*)this->memptr(), &
N, this->pivot.memptr(), (
E*)X.memptr(), &LDB, &INFO);
168 X = arma::zeros(B.n_rows, B.n_cols);
170 mat full_residual = B;
172 auto multiplier =
norm(full_residual);
175 while(counter++ < this->setting.iterative_refinement) {
176 if(multiplier < this->setting.tolerance)
break;
178 auto residual = conv_to<fmat>::from(full_residual / multiplier);
180 arma_fortran(arma_sgetrs)(&TRAN, &
N, &NRHS, this->s_memory.memptr(), &
N, this->pivot.memptr(), residual.memptr(), &LDB, &INFO);
183 const mat incre = multiplier * conv_to<mat>::from(residual);
187 suanpan_debug(
"Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(full_residual -= this->
operator*(incre)));
195 if(this->factored)
return this->solve_trs(X, std::forward<Mat<T>>(B));
197 auto N =
static_cast<int>(this->n_rows);
198 const auto NRHS =
static_cast<int>(B.n_cols);
199 const auto LDB =
static_cast<int>(B.n_rows);
202 this->pivot.zeros(
N);
204 this->factored =
true;
206 if(std::is_same_v<T, float>) {
208 arma_fortran(arma_sgesv)(&
N, &NRHS, (
E*)this->memptr(), &
N, this->pivot.memptr(), (
E*)B.memptr(), &LDB, &INFO);
213 arma_fortran(arma_dgesv)(&
N, &NRHS, (
E*)this->memptr(), &
N, this->pivot.memptr(), (
E*)B.memptr(), &LDB, &INFO);
217 this->s_memory = this->to_float();
218 arma_fortran(arma_sgetrf)(&
N, &
N, this->s_memory.memptr(), &
N, this->pivot.memptr(), &INFO);
219 if(0 == INFO) INFO = this->solve_trs(X, std::forward<Mat<T>>(B));
223 suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
229 const auto N =
static_cast<int>(this->n_rows);
230 const auto NRHS =
static_cast<int>(B.n_cols);
231 const auto LDB =
static_cast<int>(B.n_rows);
234 if(std::is_same_v<T, float>) {
236 arma_fortran(arma_sgetrs)(&TRAN, &
N, &NRHS, (
E*)this->memptr(), &
N, this->pivot.memptr(), (
E*)B.memptr(), &LDB, &INFO);
241 arma_fortran(arma_dgetrs)(&TRAN, &
N, &NRHS, (
E*)this->memptr(), &
N, this->pivot.memptr(), (
E*)B.memptr(), &LDB, &INFO);
245 X = arma::zeros(B.n_rows, B.n_cols);
247 auto multiplier = arma::norm(B);
250 while(counter++ < this->setting.iterative_refinement) {
251 if(multiplier < this->setting.tolerance)
break;
253 auto residual = conv_to<fmat>::from(B / multiplier);
255 arma_fortran(arma_sgetrs)(&TRAN, &
N, &NRHS, this->s_memory.memptr(), &
N, this->pivot.memptr(), residual.memptr(), &LDB, &INFO);
258 const mat incre = multiplier * conv_to<mat>::from(residual);
262 suanpan_debug(
"Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(B -= this->
operator*(incre)));
A DenseMat class that holds matrices.
Definition: DenseMat.hpp:34
A FullMat class that holds matrices.
Definition: FullMat.hpp:35
double norm(const vec &)
Definition: tensor.cpp:302
#define suanpan_debug(...)
Definition: suanPan.h:295
#define suanpan_error(...)
Definition: suanPan.h:297
void suanpan_for(const IT start, const IT end, F &&FN)
Definition: utility.h:27