23template<sp_d data_t, sp_i index_t>
class csr_form;
25template<sp_d data_t, sp_i index_t>
class csc_form final {
// Zero-valued scratch member; operator() resets it to zero and returns it
// when the requested entry is not stored.
26 const data_t bin = data_t(0);
// Owning handles for the heap arrays that back the matrix.
28 using index_ptr = std::unique_ptr<index_t[]>;
29 using data_ptr = std::unique_ptr<data_t[]>;
// Row index of every stored entry (allocated with in_elem entries by init()).
31 index_ptr row_idx =
nullptr;
// Per-column offsets into row_idx/val_idx: the entries of column I occupy
// [col_ptr[I], col_ptr[I + 1]); allocated with n_cols + 1 entries by init().
32 index_ptr col_ptr =
nullptr;
// Value of every stored entry, parallel to row_idx.
33 data_ptr val_idx =
nullptr;
// Copies the stored arrays into caller-supplied buffers, converting each index
// to in_it and each value to in_dt on the way.
// new_col_ptr is written for the index range passed to suanpan_for below
// (start index_t(0), end n_cols + 1); new_row_idx and new_val_idx are written
// entry by entry from row_idx and val_idx.
// The destination buffers are raw pointers: this function does not allocate
// or check their sizes.
35 template<sp_d in_dt, sp_i in_it>
void copy_to(in_it*
const new_row_idx, in_it*
const new_col_ptr, in_dt*
const new_val_idx)
const {
36 suanpan_for(index_t(0),
n_cols + 1, [&](
const index_t I) { new_col_ptr[I] = in_it(col_ptr[I]); });
// NOTE(review): the loop header that supplies I for the two copies below is
// not visible in this listing; presumably it iterates over the stored entries.
38 new_row_idx[I] = in_it(row_idx[I]);
39 new_val_idx[I] = in_dt(val_idx[I]);
// Replaces the three storage arrays with freshly allocated ones:
//   row_idx and val_idx get in_elem entries, col_ptr gets n_cols + 1 entries.
// n_cols is read here, so it must already hold the intended column count.
// Assigning to the unique_ptr members releases whatever they held before.
// The arrays come from plain new[] with no initializer, so this function does
// not fill them; callers are expected to populate them afterwards.
43 void init(
const index_t in_elem) {
44 row_idx = std::move(index_ptr(
new index_t[in_elem]));
45 col_ptr = std::move(index_ptr(
new index_t[
n_cols + 1]));
46 val_idx = std::move(data_ptr(
new data_t[in_elem]));
61 [[nodiscard]] const index_t*
row_mem()
const {
return row_idx.get(); }
63 [[nodiscard]]
const index_t*
col_mem()
const {
return col_ptr.get(); }
65 [[nodiscard]]
const data_t*
val_mem()
const {
return val_idx.get(); }
67 [[nodiscard]] index_t*
row_mem() {
return row_idx.get(); }
69 [[nodiscard]] index_t*
col_mem() {
return col_ptr.get(); }
71 [[nodiscard]] data_t*
val_mem() {
return val_idx.get(); }
73 index_t
row(
const index_t I)
const {
return row_idx[I]; }
75 index_t
col(
const index_t I)
const {
return col_ptr[I]; }
77 data_t
val(
const index_t I)
const {
return val_idx[I]; }
79 [[nodiscard]] data_t
max()
const {
80 if(0 ==
n_elem)
return data_t(0);
88 return copy *= scalar;
93 return copy /= scalar;
// Element lookup by (row, column).
// When in_row < n_rows and in_col < n_cols, the entries of column in_col
// (positions col_ptr[in_col] up to, but excluding, col_ptr[in_col + 1]) are
// scanned linearly and the value whose row index equals in_row is returned.
// In every other case (index at or beyond the dimensions, or no stored entry
// at that position) the member bin is reset to zero and returned.
110 data_t
operator()(
const index_t in_row,
const index_t in_col)
const {
111 if(in_row <
n_rows && in_col <
n_cols)
for(
auto I = col_ptr[in_col]; I < col_ptr[in_col + 1]; ++I)
if(in_row == row_idx[I])
return val_idx[I];
// bin is declared const, so it is written through access::rw.
112 return access::rw(bin) = data_t(0);
// Product of this sparse matrix with a dense column vector.
// Walks the matrix column by column; each stored entry (row_idx[J], I) with
// value val_idx[J] adds val_idx[J] * in_mat(I) to out_mat(row_idx[J]).
115 Mat<data_t>
operator*(
const Col<data_t>& in_mat)
const {
// NOTE(review): the result has in_mat.n_rows rows, but it is indexed by
// row_idx[J]. These only agree when in_mat.n_rows covers every stored row
// index (e.g. square matrices) — confirm whether n_rows was intended.
116 Mat<data_t> out_mat = arma::zeros<Mat<data_t>>(in_mat.n_rows, 1);
117 for(index_t I = 0; I <
n_cols; ++I)
for(
auto J = col_ptr[I]; J < col_ptr[I + 1]; ++J) out_mat(row_idx[J]) += val_idx[J] * in_mat(I);
// Product of this sparse matrix with a dense matrix.
// Walks the matrix column by column; each stored entry (row_idx[J], I) with
// value val_idx[J] adds val_idx[J] * in_mat.row(I) to out_mat.row(row_idx[J]).
121 Mat<data_t>
operator*(
const Mat<data_t>& in_mat)
const {
// NOTE(review): the result has in_mat.n_rows rows, but rows are addressed by
// row_idx[J]. These only agree when in_mat.n_rows covers every stored row
// index (e.g. square matrices) — confirm whether n_rows was intended.
122 Mat<data_t> out_mat = arma::zeros<Mat<data_t>>(in_mat.n_rows, in_mat.n_cols);
123 for(index_t I = 0; I <
n_cols; ++I)
for(
auto J = col_ptr[I]; J < col_ptr[I + 1]; ++J) out_mat.row(row_idx[J]) += val_idx[J] * in_mat.row(I);
129 : n_rows{in_mat.n_rows}
130 , n_cols{in_mat.n_cols}
131 , n_elem{in_mat.n_elem} {
133 in_mat.copy_to(row_idx.get(), col_ptr.get(), val_idx.get());
137 : row_idx{std::move(in_mat.row_idx)}
138 , col_ptr{std::move(in_mat.col_ptr)}
139 , val_idx{std::move(in_mat.val_idx)}
140 , n_rows{in_mat.n_rows}
141 , n_cols{in_mat.n_cols}
142 , n_elem{in_mat.n_elem} {}
145 if(
this == &in_mat)
return *
this;
146 access::rw(n_rows) = in_mat.
n_rows;
147 access::rw(n_cols) = in_mat.
n_cols;
148 access::rw(n_elem) = in_mat.
n_elem;
150 in_mat.copy_to(row_idx.get(), col_ptr.get(), val_idx.get());
155 if(
this == &in_mat)
return *
this;
156 access::rw(n_rows) = in_mat.
n_rows;
157 access::rw(n_cols) = in_mat.n_cols;
158 access::rw(n_elem) = in_mat.n_elem;
159 row_idx = std::move(in_mat.row_idx);
160 col_ptr = std::move(in_mat.col_ptr);
161 val_idx = std::move(in_mat.val_idx);
166 suanpan_info(
"A sparse matrix in triplet form with size of {} by {}, the sparsity of {:.3f}%.\n", n_rows, n_cols, 1E2 -
static_cast<double>(n_elem) /
static_cast<double>(n_rows) /
static_cast<double>(n_cols) * 1E2);
167 if(n_elem > index_t(1000)) {
173 for(index_t I = 0; I < n_elem; ++I) {
174 if(I >= col_ptr[c_idx]) ++c_idx;
175 suanpan_info(
"({}, {}) ===> {:+.8E}\n", row_idx[I], c_idx - 1, val_idx[I]);
180 : n_rows(index_t(in_mat.n_rows))
181 , n_cols(index_t(in_mat.n_cols)) {
187 const sp_i auto shift = index_t(base);
190 row_idx[I] = index_t(in_mat.row_idx[I]) + shift;
191 val_idx[I] = data_t(in_mat.val_idx[I]);
194 in_it current_pos = 0, current_col = 0;
196 while(current_pos < in_mat.
n_elem)
197 if(in_mat.col_idx[current_pos] < current_col) ++current_pos;
198 else col_ptr[current_col++] = index_t(current_pos) + shift;
207 access::rw(n_rows) = index_t(in_mat.
n_rows);
208 access::rw(n_cols) = index_t(in_mat.
n_cols);
210 init(access::rw(n_elem) = index_t(in_mat.
n_elem));
213 row_idx[I] = index_t(in_mat.row_idx[I]);
214 val_idx[I] = data_t(in_mat.val_idx[I]);
217 in_it current_pos = 0, current_col = 0;
219 while(current_pos < in_mat.
n_elem)
220 if(in_mat.col_idx[current_pos] < current_col) ++current_pos;
221 else col_ptr[current_col++] = index_t(current_pos);
223 col_ptr[0] = index_t(0);
224 col_ptr[n_cols] = n_elem;
Definition: suanPan.h:319
#define suanpan_info
Definition: suanPan.h:293
#define suanpan_for_each
Definition: suanPan.h:177
constexpr T suanpan_max_element(T start, T end)
Definition: utility.h:36
void suanpan_for(const IT start, const IT end, F &&FN)
Definition: utility.h:27