// SparseMatSuperLU: sparse matrix storage that factorizes and solves through
// SuperLU (or SuperLU_MT when SUANPAN_SUPERLUMT is defined).
// NOTE: this listing is non-contiguous; the class head, several members and
// all closing braces are not shown.
30#ifndef SPARSEMATSUPERLU_HPP
31#define SPARSEMATSUPERLU_HPP
33#include <superlu-mt/superlu-mt.h>
// SuperLU handles: A is built from the stored sparse matrix, L/U receive the
// factors, and B wraps the dense right-hand side for each solve.
38 SuperMatrix A{}, L{}, U{}, B{};
// Solver options are declared only for the sequential SuperLU build.
40#ifndef SUANPAN_SUPERLUMT
41 superlu_options_t options{};
// Ordering selector passed to get_perm_c() in the SuperLU_MT path.
// NOTE(review): confirm which column ordering the value 1 selects.
45 const int ordering_num = 1;
// Buffer obtained from superlu_malloc holding a copy of the nonzero values;
// the matching t_row/t_col declarations are on lines not shown here.
50 void* t_val =
nullptr;
// Row / column permutation vectors shared by the factorization (gssv) and
// the subsequent triangular solves (gstrs).
54 int* perm_r =
nullptr;
55 int* perm_c =
nullptr;
// Guards dealloc_supermatrix(): it returns immediately while this is false.
// Presumably set after allocation (the assignment is not shown) — confirm.
57 bool allocated =
false;
// Releases A, L, U and all superlu_malloc'd buffers.
60 void dealloc_supermatrix();
// Wraps a dense matrix as the right-hand side B.
62 template<sp_d ET>
void wrap_b(
const Mat<ET>&);
// Triangular solves with the existing factors; status is written to the int&.
63 template<sp_d ET>
void tri_solve(
int&);
// Factorize and solve in one call; status is written to the int&.
64 template<sp_d ET>
void full_solve(
int&);
// Solve using an existing factorization (used once this->factored is true).
66 int solve_trs(Mat<T>&, Mat<T>&&);
67 int solve_trs(Mat<T>&,
const Mat<T>&);
// Overrides the base-class zeros().
77 void zeros() override;
// (Enclosing function header not shown; `in` appears to be a compressed-
// column input exposing val_mem()/row_mem()/col_mem() — confirm.)
// Drop any previously held SuperLU state before building a new one.
86 dealloc_supermatrix();
// Copy the nonzero values (n_elem entries of ET) into SuperLU-owned memory.
88 auto t_size =
sizeof(ET) * in.n_elem;
89 t_val = superlu_malloc(t_size);
90 memcpy(t_val, (
void*)in.val_mem(), t_size);
// Copy the row indices (one per nonzero).
// NOTE(review): sizes use sizeof(int), so this assumes the input's index
// type is int — confirm against the caller's template arguments.
92 t_size =
sizeof(int) * in.n_elem;
93 t_row = (
int*)superlu_malloc(t_size);
94 memcpy(t_row, (
void*)in.row_mem(), t_size);
// Copy the column pointers (n_cols + 1 entries).
96 t_size =
sizeof(int) * (in.n_cols + 1llu);
97 t_col = (
int*)superlu_malloc(t_size);
98 memcpy(t_col, (
void*)in.col_mem(), t_size);
// NOTE(review): no check of superlu_malloc's return value is visible.
// Build A as a general (SLU_GE) compressed-column (SLU_NC) matrix over the
// copied buffers: double -> dCreate, otherwise sCreate. The alias E and the
// else-branch lines are not shown in this listing.
100 if(std::is_same_v<ET, double>) {
102 dCreate_CompCol_Matrix(&A, in.n_rows, in.n_cols, in.n_elem, (
E*)t_val, t_row, t_col, Stype_t::SLU_NC, Dtype_t::SLU_D, Mtype_t::SLU_GE);
106 sCreate_CompCol_Matrix(&A, in.n_rows, in.n_cols, in.n_elem, (
E*)t_val, t_row, t_col, Stype_t::SLU_NC, Dtype_t::SLU_S, Mtype_t::SLU_GE);
// Permutation vectors, sized n_rows + 1 and n_cols + 1.
109 perm_r = (
int*)superlu_malloc(
sizeof(
int) * (this->
n_rows + 1));
110 perm_c = (
int*)superlu_malloc(
sizeof(
int) * (this->
n_cols + 1));
// dealloc_supermatrix(): release SuperLU state; no-op unless `allocated`.
116 if(!allocated)
return;
// Destroy A's Store; the t_val/t_row/t_col buffers A was built over are
// freed explicitly below.
118 Destroy_SuperMatrix_Store(&A);
// The factor storage formats differ between builds: SuperLU_MT uses
// SuperNode_SCP / CompCol_NCP, sequential SuperLU uses SuperNode / CompCol.
// (The #else/#endif lines are not shown in this listing.)
119#ifdef SUANPAN_SUPERLUMT
120 Destroy_SuperNode_SCP(&L);
121 Destroy_CompCol_NCP(&U);
123 Destroy_SuperNode_Matrix(&L);
124 Destroy_CompCol_Matrix(&U);
// Free the superlu_malloc'd buffers, skipping null pointers.
// NOTE(review): resetting these pointers to nullptr and clearing
// `allocated` are not visible here — confirm they happen on the omitted
// lines, otherwise a second call would free the same buffers twice.
127 if(t_val) superlu_free(t_val);
128 if(t_row) superlu_free(t_row);
129 if(t_col) superlu_free(t_col);
130 if(perm_r) superlu_free(perm_r);
131 if(perm_c) superlu_free(perm_c);
// wrap_b(): describe in_mat as the dense right-hand side B, with leading
// dimension n_rows. B points at in_mat.memptr() directly, so no copy is
// made. float -> sCreate, otherwise dCreate (E is declared on lines not shown).
// NOTE(review): the C-style (E*) cast drops the const of `const Mat<ET>&`,
// and the solve writes into B — callers should pass a matrix they own.
137 if(std::is_same_v<ET, float>) {
139 sCreate_Dense_Matrix(&B, (
int)in_mat.n_rows, (
int)in_mat.n_cols, (
E*)in_mat.memptr(), (
int)in_mat.n_rows, Stype_t::SLU_DN, Dtype_t::SLU_S, Mtype_t::SLU_GE);
143 dCreate_Dense_Matrix(&B, (
int)in_mat.n_rows, (
int)in_mat.n_cols, (
E*)in_mat.memptr(), (
int)in_mat.n_rows, Stype_t::SLU_DN, Dtype_t::SLU_D, Mtype_t::SLU_GE);
// tri_solve(): triangular solves only, reusing L, U, perm_c and perm_r from
// a previous factorization; the SuperLU status lands in `flag`.
// SuperLU_MT: non-transposed solve, float/double chosen by a runtime `if`.
// Sequential SuperLU: transpose mode taken from options.Trans.
148#ifdef SUANPAN_SUPERLUMT
149 if(std::is_same_v<ET, float>) sgstrs(NOTRANS, &L, &U, perm_c, perm_r, &B, &stat, &flag);
150 else dgstrs(NOTRANS, &L, &U, perm_c, perm_r, &B, &stat, &flag);
152 superlu::gstrs<ET>(options.Trans, &L, &U, perm_c, perm_r, &B, &stat, &flag);
// Release only B's Store wrapper; the data it pointed at (set up by wrap_b
// via memptr()) belongs to the Armadillo matrix.
155 Destroy_SuperMatrix_Store(&B);
// full_solve(): factorize A and solve for B in one call; status -> `flag`.
// SuperLU_MT: compute the column permutation with get_perm_c, then run the
// parallel driver with SUANPAN_NUM_THREADS threads (only the float branch
// is visible here). Sequential SuperLU: superlu::gssv with options and stat.
159#ifdef SUANPAN_SUPERLUMT
160 get_perm_c(ordering_num, &A, perm_c);
161 if(std::is_same_v<ET, float>) psgssv(
SUANPAN_NUM_THREADS, &A, perm_c, perm_r, &L, &U, &B, &flag);
164 superlu::gssv<ET>(&options, &A, perm_c, perm_r, &L, &U, &B, &stat, &flag);
// Release only B's Store wrapper; the solution stays in the wrapped matrix.
167 Destroy_SuperMatrix_Store(&B);
// Sequential SuperLU only: start from SuperLU's default options, request
// iterative refinement in the matrix's own precision (float -> single,
// otherwise double), and turn equilibration off.
172#ifndef SUANPAN_SUPERLUMT
173 set_default_options(&options);
174 options.IterRefine = std::is_same_v<T, float> ? superlu::IterRefine_t::SLU_SINGLE : superlu::IterRefine_t::SLU_DOUBLE;
175 options.Equil = superlu::yes_no_t::NO;
// Zero the SuperLUStat_t struct byte-wise (the #endif position is not shown).
177 arrayops::fill_zeros(
reinterpret_cast<char*
>(&stat),
sizeof(SuperLUStat_t));
// Same initialisation as the other constructor-like body: default SuperLU
// options, precision-matched iterative refinement, no equilibration.
188#ifndef SUANPAN_SUPERLUMT
189 set_default_options(&options);
190 options.IterRefine = std::is_same_v<T, float> ? superlu::IterRefine_t::SLU_SINGLE : superlu::IterRefine_t::SLU_DOUBLE;
191 options.Equil = superlu::yes_no_t::NO;
// Zero the SuperLUStat_t struct byte-wise.
193 arrayops::fill_zeros(
reinterpret_cast<char*
>(&stat),
sizeof(SuperLUStat_t));
// Two separate functions release the SuperLU state here; their headers are
// not shown (presumably the destructor and zeros() — confirm).
203 dealloc_supermatrix();
209 dealloc_supermatrix();
// Solve with an lvalue right-hand side: reuse the factorization if present.
215 if(this->factored)
return solve_trs(out_mat, in_mat);
// Later calls go straight to solve_trs().
217 this->factored =
true;
// Full-precision path (float matrices or Precision::FULL); body not shown.
221 if(std::is_same_v<T, float> ||
Precision::FULL == this->setting.precision) {
// Mixed-precision path: factorize in single precision using a float buffer
// sized like in_mat (filled on omitted lines).
235 const fmat f_mat(arma::size(in_mat), fill::none);
239 full_solve<float>(flag);
// On success, obtain the double-precision answer by refinement in
// solve_trs(); otherwise return the SuperLU status code.
241 return 0 == flag ? solve_trs(out_mat, in_mat) : flag;
// solve_trs(): full-precision path, body not shown.
247 if(std::is_same_v<T, float> ||
Precision::FULL == this->setting.precision) {
// Mixed-precision iterative refinement: out_mat accumulates corrections
// and full_residual holds the current double-precision residual.
257 out_mat.zeros(arma::size(in_mat));
259 mat full_residual = in_mat;
// NOTE(review): unqualified `norm` here, while the rvalue overload uses
// arma::norm and the project also declares norm(const vec&) — consider
// qualifying as arma::norm so overload selection is explicit.
261 auto multiplier =
norm(full_residual);
// At most setting.iterative_refinement passes; stop once the residual norm
// falls below setting.tolerance.
264 while(counter++ < this->setting.iterative_refinement) {
265 if(multiplier < this->setting.tolerance)
break;
// Scale the residual to unit norm and demote it to float for the
// single-precision triangular solves.
267 auto residual = conv_to<fmat>::from(full_residual / multiplier);
271 tri_solve<float>(flag);
// Undo the scaling in double precision (accumulation line not shown).
275 const mat incre = multiplier * conv_to<mat>::from(residual);
// Update residual -= A * incre and its norm, printing it via suanpan_debug.
// NOTE(review): both updates live inside the macro's argument; if
// suanpan_debug expands to nothing in some builds, the loop would stop
// updating the residual — confirm the macro always evaluates its arguments.
279 suanpan_debug(
"Mixed precision algorithm multiplier: {:.5E}.\n", multiplier =
norm(full_residual -= this->
operator*(incre)));
// Solve with an rvalue right-hand side: reuse the factorization if present.
// std::forward<Mat<T>> here acts as a move of the rvalue reference.
286 if(this->factored)
return solve_trs(out_mat, std::forward<Mat<T>>(in_mat));
// Later calls go straight to solve_trs().
288 this->factored =
true;
// Full-precision path: in_mat's storage is moved into out_mat
// (remaining lines not shown).
292 if(std::is_same_v<T, float> ||
Precision::FULL == this->setting.precision) {
299 out_mat = std::move(in_mat);
// Mixed-precision path: factorize in single precision, then refine.
306 const fmat f_mat(arma::size(in_mat), fill::none);
310 full_solve<float>(flag);
312 return 0 == flag ? solve_trs(out_mat, std::forward<Mat<T>>(in_mat)) : flag;
// solve_trs() rvalue overload: full-precision path moves in_mat into out_mat.
318 if(std::is_same_v<T, float> ||
Precision::FULL == this->setting.precision) {
323 out_mat = std::move(in_mat);
// Mixed-precision refinement: unlike the lvalue overload, in_mat itself is
// reused as the residual buffer instead of a copy.
328 out_mat.zeros(arma::size(in_mat));
330 auto multiplier = arma::norm(in_mat);
// At most setting.iterative_refinement passes; stop below setting.tolerance.
333 while(counter++ < this->setting.iterative_refinement) {
334 if(multiplier < this->setting.tolerance)
break;
// Unit-norm residual in float for the single-precision triangular solves.
336 auto residual = conv_to<fmat>::from(in_mat / multiplier);
340 tri_solve<float>(flag);
// Undo the scaling in double precision (accumulation line not shown).
344 const mat incre = multiplier * conv_to<mat>::from(residual);
// Update in_mat -= A * incre and its norm, printing it via suanpan_debug.
// NOTE(review): both updates live inside the macro's argument; confirm
// suanpan_debug always evaluates its arguments.
348 suanpan_debug(
"Mixed precision algorithm multiplier: {:.5E}.\n", multiplier =
norm(in_mat -= this->
operator*(incre)));
A SparseMat class that holds matrices.
Definition: SparseMat.hpp:34
A SparseMatSuperLU class that holds sparse matrices and solves linear systems with SuperLU (or SuperLU_MT).
Definition: SparseMatSuperLU.hpp:37
SparseMatSuperLU(SparseMatSuperLU &&) noexcept=delete
int SUANPAN_NUM_THREADS
Definition: command.cpp:67
Definition: suanPan.h:318
double norm(const vec &)
Definition: tensor.cpp:302
#define suanpan_debug(...)
Definition: suanPan.h:295