30#ifndef SPARSEMATSUPERLU_HPP
31#define SPARSEMATSUPERLU_HPP
33#include <superlu-mt/superlu-mt.h>
38 SuperMatrix A{}, L{}, U{}, B{};
40#ifndef SUANPAN_SUPERLUMT
41 superlu_options_t options{};
45 const int ordering_num = 1;
50 void* t_val =
nullptr;
54 int* perm_r =
nullptr;
55 int* perm_c =
nullptr;
57 bool allocated =
false;
62 template<sp_d ET>
void wrap_b(
const Mat<ET>&);
63 template<sp_d ET>
void tri_solve(
int&);
64 template<sp_d ET>
void full_solve(
int&);
66 int solve_trs(Mat<T>&, Mat<T>&&);
81 void zeros() override;
89 auto t_size =
sizeof(ET) * in.n_elem;
90 t_val = superlu_malloc(t_size);
91 memcpy(t_val, (
void*)in.val_mem(), t_size);
93 t_size =
sizeof(int) * in.n_elem;
94 t_row = (
int*)superlu_malloc(t_size);
95 memcpy(t_row, (
void*)in.row_mem(), t_size);
97 t_size =
sizeof(int) * (in.n_cols + 1llu);
98 t_col = (
int*)superlu_malloc(t_size);
99 memcpy(t_col, (
void*)in.col_mem(), t_size);
101 if constexpr(std::is_same_v<ET, double>) {
103 dCreate_CompCol_Matrix(&A, in.n_rows, in.n_cols, in.n_elem, (
E*)t_val, t_row, t_col, Stype_t::SLU_NC, Dtype_t::SLU_D, Mtype_t::SLU_GE);
107 sCreate_CompCol_Matrix(&A, in.n_rows, in.n_cols, in.n_elem, (
E*)t_val, t_row, t_col, Stype_t::SLU_NC, Dtype_t::SLU_S, Mtype_t::SLU_GE);
110 perm_r = (
int*)superlu_malloc(
sizeof(
int) * (this->
n_rows + 1));
111 perm_c = (
int*)superlu_malloc(
sizeof(
int) * (this->
n_cols + 1));
117 if(!allocated)
return;
119 Destroy_SuperMatrix_Store(&A);
120#ifdef SUANPAN_SUPERLUMT
121 Destroy_SuperNode_SCP(&L);
122 Destroy_CompCol_NCP(&U);
124 Destroy_SuperNode_Matrix(&L);
125 Destroy_CompCol_Matrix(&U);
128 if(t_val) superlu_free(t_val);
129 if(t_row) superlu_free(t_row);
130 if(t_col) superlu_free(t_col);
131 if(perm_r) superlu_free(perm_r);
132 if(perm_c) superlu_free(perm_c);
138 if constexpr(std::is_same_v<ET, float>) {
140 sCreate_Dense_Matrix(&B, (
int)in_mat.n_rows, (
int)in_mat.n_cols, (
E*)in_mat.memptr(), (
int)in_mat.n_rows, Stype_t::SLU_DN, Dtype_t::SLU_S, Mtype_t::SLU_GE);
144 dCreate_Dense_Matrix(&B, (
int)in_mat.n_rows, (
int)in_mat.n_cols, (
E*)in_mat.memptr(), (
int)in_mat.n_rows, Stype_t::SLU_DN, Dtype_t::SLU_D, Mtype_t::SLU_GE);
149#ifdef SUANPAN_SUPERLUMT
150 if(std::is_same_v<ET, float>) sgstrs(NOTRANS, &L, &U, perm_c, perm_r, &B, &stat, &flag);
151 else dgstrs(NOTRANS, &L, &U, perm_c, perm_r, &B, &stat, &flag);
153 superlu::gstrs<ET>(options.Trans, &L, &U, perm_c, perm_r, &B, &stat, &flag);
156 Destroy_SuperMatrix_Store(&B);
160#ifdef SUANPAN_SUPERLUMT
161 get_perm_c(ordering_num, &A, perm_c);
162 if(std::is_same_v<ET, float>) psgssv(
SUANPAN_NUM_THREADS, &A, perm_c, perm_r, &L, &U, &B, &flag);
165 superlu::gssv<ET>(&options, &A, perm_c, perm_r, &L, &U, &B, &stat, &flag);
168 Destroy_SuperMatrix_Store(&B);
173#ifndef SUANPAN_SUPERLUMT
174 set_default_options(&options);
175 options.IterRefine = std::is_same_v<T, float> ? superlu::IterRefine_t::SLU_SINGLE : superlu::IterRefine_t::SLU_DOUBLE;
176 options.Equil = superlu::yes_no_t::NO;
178 arrayops::fill_zeros(
reinterpret_cast<char*
>(&stat),
sizeof(SuperLUStat_t));
189#ifndef SUANPAN_SUPERLUMT
190 set_default_options(&options);
191 options.IterRefine = std::is_same_v<T, float> ? superlu::IterRefine_t::SLU_SINGLE : superlu::IterRefine_t::SLU_DOUBLE;
192 options.Equil = superlu::yes_no_t::NO;
194 arrayops::fill_zeros(
reinterpret_cast<char*
>(&stat),
sizeof(SuperLUStat_t));
216 if(this->factored)
return solve_trs(out_mat, std::forward<Mat<T>>(in_mat));
218 this->factored =
true;
222 if constexpr(std::is_same_v<T, float>) {
227 full_solve<float>(flag);
229 out_mat = std::move(in_mat);
236 full_solve<double>(flag);
238 out_mat = std::move(in_mat);
243 const fmat f_mat(arma::size(in_mat), fill::none);
247 full_solve<float>(flag);
249 if(0 == flag) flag = solve_trs(out_mat, std::forward<Mat<T>>(in_mat));
258 if constexpr(std::is_same_v<T, float>) {
261 tri_solve<float>(flag);
263 out_mat = std::move(in_mat);
268 tri_solve<double>(flag);
270 out_mat = std::move(in_mat);
273 out_mat.zeros(arma::size(in_mat));
278 while(counter++ < this->setting.iterative_refinement) {
279 if(multiplier < this->setting.tolerance)
break;
281 auto residual = conv_to<fmat>::from(in_mat / multiplier);
285 tri_solve<float>(flag);
289 const mat incre = multiplier * conv_to<mat>::from(residual);
293 suanpan_debug(
"Mixed precision algorithm multiplier: {:.5E}.\n", multiplier =
arma::norm(in_mat -= this->
operator*(incre)));
A SparseMat class that holds matrices.
Definition: SparseMat.hpp:34
void zeros() override
Definition: SparseMat.hpp:46
A SparseMatSuperLU class that holds matrices.
Definition: SparseMatSuperLU.hpp:37
SparseMatSuperLU(SparseMatSuperLU &&) noexcept=delete
int direct_solve(Mat< T > &out_mat, const Mat< T > &in_mat) override
Definition: SparseMatSuperLU.hpp:69
int SUANPAN_NUM_THREADS
Definition: command.cpp:72
Definition: suanPan.h:330
double norm(const vec &)
Definition: tensor.cpp:370
#define suanpan_debug(...)
Definition: suanPan.h:307