23template<sp_d data_t, sp_i index_t>
class csc_form;
25template<sp_d data_t, sp_i index_t>
class csr_form final {
26 const data_t bin = data_t(0);
28 using index_ptr = std::unique_ptr<index_t[]>;
29 using data_ptr = std::unique_ptr<data_t[]>;
31 index_ptr row_ptr =
nullptr;
32 index_ptr col_idx =
nullptr;
33 data_ptr val_idx =
nullptr;
35 template<sp_d in_dt, sp_i in_it>
void copy_to(in_it*
const new_row_ptr, in_it*
const new_col_idx, in_dt*
const new_val_idx)
const {
36 suanpan_for(index_t(0),
n_rows + 1, [&](
const index_t I) { new_row_ptr[I] = in_it(row_ptr[I]); });
38 new_col_idx[I] = in_it(col_idx[I]);
39 new_val_idx[I] = in_dt(val_idx[I]);
43 void init(
const index_t in_elem) {
44 row_ptr = std::move(index_ptr(
new index_t[
n_rows + 1]));
45 col_idx = std::move(index_ptr(
new index_t[in_elem]));
46 val_idx = std::move(data_ptr(
new data_t[in_elem]));
61 [[nodiscard]] const index_t*
row_mem()
const {
return row_ptr.get(); }
63 [[nodiscard]]
const index_t*
col_mem()
const {
return col_idx.get(); }
65 [[nodiscard]]
const data_t*
val_mem()
const {
return val_idx.get(); }
67 [[nodiscard]] index_t*
row_mem() {
return row_ptr.get(); }
69 [[nodiscard]] index_t*
col_mem() {
return col_idx.get(); }
71 [[nodiscard]] data_t*
val_mem() {
return val_idx.get(); }
73 [[nodiscard]] data_t
max()
const {
74 if(0 ==
n_elem)
return data_t(0);
82 return copy *= scalar;
87 return copy /= scalar;
104 const data_t&
operator()(
const index_t in_row,
const index_t in_col)
const {
105 if(in_row <
n_rows && in_col <
n_cols)
for(
auto I = row_ptr[in_row]; I < row_ptr[in_row + 1]; ++I)
if(in_col == col_idx[I])
return val_idx[I];
106 return access::rw(bin) = data_t(0);
109 Mat<data_t>
operator*(
const Col<data_t>& in_mat)
const {
110 Mat<data_t> out_mat = arma::zeros<Mat<data_t>>(in_mat.n_rows, 1);
112 suanpan_for(index_t(0),
n_rows, [&](
const index_t I) {
for(
auto J = row_ptr[I]; J < row_ptr[I + 1]; ++J) out_mat(I) += val_idx[J] * in_mat(col_idx[J]); });
117 Mat<data_t>
operator*(
const Mat<data_t>& in_mat)
const {
118 Mat<data_t> out_mat = arma::zeros<Mat<data_t>>(in_mat.n_rows, in_mat.n_cols);
120 suanpan_for(index_t(0),
n_rows, [&](
const index_t I) {
for(
auto J = row_ptr[I]; J < row_ptr[I + 1]; ++J) out_mat.row(I) += val_idx[J] * in_mat.row(col_idx[J]); });
127 : n_rows{in_mat.n_rows}
128 , n_cols{in_mat.n_cols}
129 , n_elem{in_mat.n_elem} {
131 in_mat.copy_to(row_ptr.get(), col_idx.get(), val_idx.get());
135 : row_ptr{std::move(in_mat.row_ptr)}
136 , col_idx{std::move(in_mat.col_idx)}
137 , val_idx{std::move(in_mat.val_idx)}
138 , n_rows{in_mat.n_rows}
139 , n_cols{in_mat.n_cols}
140 , n_elem{in_mat.n_elem} {}
143 if(
this == &in_mat)
return *
this;
144 access::rw(n_rows) = in_mat.
n_rows;
145 access::rw(n_cols) = in_mat.
n_cols;
146 access::rw(n_elem) = in_mat.
n_elem;
148 in_mat.copy_to(row_ptr.get(), col_idx.get(), val_idx.get());
153 if(
this == &in_mat)
return *
this;
154 access::rw(n_rows) = in_mat.
n_rows;
155 access::rw(n_cols) = in_mat.n_cols;
156 access::rw(n_elem) = in_mat.n_elem;
157 row_ptr = std::move(in_mat.row_ptr);
158 col_idx = std::move(in_mat.col_idx);
159 val_idx = std::move(in_mat.val_idx);
164 suanpan_info(
"A sparse matrix in triplet form with size of {} by {}, the sparsity of {:.3f}%.\n",
static_cast<unsigned>(n_rows),
static_cast<unsigned>(n_cols), 1E2 -
static_cast<double>(n_elem) /
static_cast<double>(n_rows) /
static_cast<double>(n_cols) * 1E2);
165 if(n_elem > index_t(1000)) {
171 for(index_t I = 0; I < n_elem; ++I) {
172 if(I >= row_ptr[c_idx]) ++c_idx;
173 suanpan_info(
"({}, {}) ===> {:+.8E}\n",
static_cast<unsigned>(c_idx) - 1,
static_cast<unsigned>(col_idx[I]), val_idx[I]);
178 : n_rows(index_t(in_mat.n_rows))
179 , n_cols(index_t(in_mat.n_cols)) {
185 const sp_i auto shift = index_t(base);
188 col_idx[I] = index_t(in_mat.col_idx[I]) + shift;
189 val_idx[I] = data_t(in_mat.val_idx[I]);
192 in_it current_pos = 0, current_row = 0;
194 while(current_pos < in_mat.
n_elem)
195 if(in_mat.row_idx[current_pos] < current_row) ++current_pos;
196 else row_ptr[current_row++] = index_t(current_pos) + shift;
205 access::rw(n_rows) = index_t(in_mat.
n_rows);
206 access::rw(n_cols) = index_t(in_mat.
n_cols);
208 init(access::rw(n_elem) = index_t(in_mat.
n_elem));
211 col_idx[I] = index_t(in_mat.col_idx[I]);
212 val_idx[I] = data_t(in_mat.val_idx[I]);
215 in_it current_pos = 0, current_row = 0;
217 while(current_pos < in_mat.
n_elem)
218 if(in_mat.row_idx[current_pos] < current_row) ++current_pos;
219 else row_ptr[current_row++] = index_t(current_pos);
221 row_ptr[0] = index_t(0);
222 row_ptr[n_rows] = n_elem;
Definition: suanPan.h:319
#define suanpan_info
Definition: suanPan.h:293
#define suanpan_for_each
Definition: suanPan.h:177
constexpr T suanpan_max_element(T start, T end)
Definition: utility.h:36
void suanpan_for(const IT start, const IT end, F &&FN)
Definition: utility.h:27