suanPan
csr_form.hpp
Go to the documentation of this file.
/*******************************************************************************
 * Copyright (C) 2017-2024 Theodore Chang
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 ******************************************************************************/
17
18#ifndef CSR_FORM_HPP
19#define CSR_FORM_HPP
20
21#include "triplet_form.hpp"
22
23template<sp_d data_t, sp_i index_t> class csc_form;
24
25template<sp_d data_t, sp_i index_t> class csr_form final {
26 const data_t bin = data_t(0);
27
28 using index_ptr = std::unique_ptr<index_t[]>;
29 using data_ptr = std::unique_ptr<data_t[]>;
30
31 index_ptr row_ptr = nullptr; // index storage
32 index_ptr col_idx = nullptr; // index storage
33 data_ptr val_idx = nullptr; // value storage
34
35 template<sp_d in_dt, sp_i in_it> void copy_to(in_it* const new_row_ptr, in_it* const new_col_idx, in_dt* const new_val_idx) const {
36 suanpan::for_each(n_rows + 1, [&](const index_t I) { new_row_ptr[I] = in_it(row_ptr[I]); });
37 suanpan::for_each(n_elem, [&](const index_t I) {
38 new_col_idx[I] = in_it(col_idx[I]);
39 new_val_idx[I] = in_dt(val_idx[I]);
40 });
41 }
42
43 void init(const index_t in_elem) {
44 row_ptr = std::move(index_ptr(new index_t[n_rows + 1]));
45 col_idx = std::move(index_ptr(new index_t[in_elem]));
46 val_idx = std::move(data_ptr(new data_t[in_elem]));
47 }
48
49public:
50 const index_t n_rows = 0;
51 const index_t n_cols = 0;
52 const index_t n_elem = 0;
53
54 csr_form() = default;
55 csr_form(const csr_form&);
56 csr_form(csr_form&&) noexcept;
57 csr_form& operator=(const csr_form&);
58 csr_form& operator=(csr_form&&) noexcept;
59 ~csr_form() = default;
60
61 [[nodiscard]] const index_t* row_mem() const { return row_ptr.get(); }
62
63 [[nodiscard]] const index_t* col_mem() const { return col_idx.get(); }
64
65 [[nodiscard]] const data_t* val_mem() const { return val_idx.get(); }
66
67 [[nodiscard]] index_t* row_mem() { return row_ptr.get(); }
68
69 [[nodiscard]] index_t* col_mem() { return col_idx.get(); }
70
71 [[nodiscard]] data_t* val_mem() { return val_idx.get(); }
72
73 [[nodiscard]] data_t max() const {
74 if(0 == n_elem) return data_t(0);
75 return *suanpan::max_element(val_idx.get(), val_idx.get() + n_elem);
76 }
77
78 void print() const;
79
80 template<sp_d T2> csr_form operator*(const T2 scalar) const {
81 auto copy = *this;
82 return copy *= scalar;
83 }
84
85 template<sp_d T2> csr_form operator/(const T2 scalar) const {
86 auto copy = *this;
87 return copy /= scalar;
88 }
89
90 template<sp_d T2> csr_form& operator*=(const T2 scalar) {
91 suanpan_for_each(val_idx.get(), val_idx.get() + n_elem, [=](data_t& I) { I *= data_t(scalar); });
92 return *this;
93 }
94
95 template<sp_d T2> csr_form& operator/=(const T2 scalar) {
96 suanpan_for_each(val_idx.get(), val_idx.get() + n_elem, [=](data_t& I) { I /= data_t(scalar); });
97 return *this;
98 }
99
100 template<sp_d in_dt, sp_i in_it> explicit csr_form(triplet_form<in_dt, in_it>&, SparseBase = SparseBase::ZERO, bool = false);
101
102 template<sp_d in_dt, sp_i in_it> csr_form& operator=(triplet_form<in_dt, in_it>&);
103
104 data_t operator()(const index_t in_row, const index_t in_col) const {
105 if(in_row < n_rows && in_col < n_cols) for(auto I = row_ptr[in_row]; I < row_ptr[in_row + 1]; ++I) if(in_col == col_idx[I]) return val_idx[I];
106 return access::rw(bin) = data_t(0);
107 }
108
109 Mat<data_t> operator*(const Col<data_t>& in_mat) const {
110 Mat<data_t> out_mat = arma::zeros<Mat<data_t>>(in_mat.n_rows, 1);
111
112 suanpan::for_each(n_rows, [&](const index_t I) { for(auto J = row_ptr[I]; J < row_ptr[I + 1]; ++J) out_mat(I) += val_idx[J] * in_mat(col_idx[J]); });
113
114 return out_mat;
115 }
116
117 Mat<data_t> operator*(const Mat<data_t>& in_mat) const {
118 Mat<data_t> out_mat = arma::zeros<Mat<data_t>>(in_mat.n_rows, in_mat.n_cols);
119
120 suanpan::for_each(n_rows, [&](const index_t I) { for(auto J = row_ptr[I]; J < row_ptr[I + 1]; ++J) out_mat.row(I) += val_idx[J] * in_mat.row(col_idx[J]); });
121
122 return out_mat;
123 }
124};
125
126template<sp_d data_t, sp_i index_t> csr_form<data_t, index_t>::csr_form(const csr_form& in_mat)
127 : n_rows{in_mat.n_rows}
128 , n_cols{in_mat.n_cols}
129 , n_elem{in_mat.n_elem} {
130 init(n_elem);
131 in_mat.copy_to(row_ptr.get(), col_idx.get(), val_idx.get());
132}
133
134template<sp_d data_t, sp_i index_t> csr_form<data_t, index_t>::csr_form(csr_form&& in_mat) noexcept
135 : row_ptr{std::move(in_mat.row_ptr)}
136 , col_idx{std::move(in_mat.col_idx)}
137 , val_idx{std::move(in_mat.val_idx)}
138 , n_rows{in_mat.n_rows}
139 , n_cols{in_mat.n_cols}
140 , n_elem{in_mat.n_elem} {}
141
142template<sp_d data_t, sp_i index_t> csr_form<data_t, index_t>& csr_form<data_t, index_t>::operator=(const csr_form& in_mat) {
143 if(this == &in_mat) return *this;
144 access::rw(n_rows) = in_mat.n_rows;
145 access::rw(n_cols) = in_mat.n_cols;
146 access::rw(n_elem) = in_mat.n_elem;
147 init(n_elem);
148 in_mat.copy_to(row_ptr.get(), col_idx.get(), val_idx.get());
149 return *this;
150}
151
152template<sp_d data_t, sp_i index_t> csr_form<data_t, index_t>& csr_form<data_t, index_t>::operator=(csr_form&& in_mat) noexcept {
153 if(this == &in_mat) return *this;
154 access::rw(n_rows) = in_mat.n_rows;
155 access::rw(n_cols) = in_mat.n_cols;
156 access::rw(n_elem) = in_mat.n_elem;
157 row_ptr = std::move(in_mat.row_ptr);
158 col_idx = std::move(in_mat.col_idx);
159 val_idx = std::move(in_mat.val_idx);
160 return *this;
161}
162
163template<sp_d data_t, sp_i index_t> void csr_form<data_t, index_t>::print() const {
164 suanpan_info("A sparse matrix in triplet form with size of {} by {}, the sparsity of {:.3f}%.\n", n_rows, n_cols, 1E2 - static_cast<double>(n_elem) / static_cast<double>(n_rows) / static_cast<double>(n_cols) * 1E2);
165 if(n_elem > index_t(1000)) {
166 suanpan_info("More than 1000 elements exist.\n");
167 return;
168 }
169
170 index_t c_idx = 1;
171 for(index_t I = 0; I < n_elem; ++I) {
172 if(I >= row_ptr[c_idx]) ++c_idx;
173 suanpan_info("({}, {}) ===> {:+.8E}\n", c_idx - 1, col_idx[I], val_idx[I]);
174 }
175}
176
177template<sp_d data_t, sp_i index_t> template<sp_d in_dt, sp_i in_it> csr_form<data_t, index_t>::csr_form(triplet_form<in_dt, in_it>& in_mat, const SparseBase base, const bool full)
178 : n_rows(index_t(in_mat.n_rows))
179 , n_cols(index_t(in_mat.n_cols)) {
180 if(full) in_mat.full_csr_condense();
181 else in_mat.csr_condense();
182
183 init(access::rw(n_elem) = index_t(in_mat.n_elem));
184
185 const sp_i auto shift = index_t(base);
186
187 suanpan::for_each(in_mat.n_elem, [&](const in_it I) {
188 col_idx[I] = index_t(in_mat.col_idx[I]) + shift;
189 val_idx[I] = data_t(in_mat.val_idx[I]);
190 });
191
192 in_it current_pos = 0, current_row = 0;
193
194 while(current_pos < in_mat.n_elem)
195 if(in_mat.row_idx[current_pos] < current_row) ++current_pos;
196 else row_ptr[current_row++] = index_t(current_pos) + shift;
197
198 row_ptr[0] = shift;
199 row_ptr[n_rows] = n_elem + shift;
200}
201
202template<sp_d data_t, sp_i index_t> template<sp_d in_dt, sp_i in_it> csr_form<data_t, index_t>& csr_form<data_t, index_t>::operator=(triplet_form<in_dt, in_it>& in_mat) {
203 in_mat.csr_condense();
204
205 access::rw(n_rows) = index_t(in_mat.n_rows);
206 access::rw(n_cols) = index_t(in_mat.n_cols);
207
208 init(access::rw(n_elem) = index_t(in_mat.n_elem));
209
210 suanpan::for_each(in_mat.n_elem, [&](const in_it I) {
211 col_idx[I] = index_t(in_mat.col_idx[I]);
212 val_idx[I] = data_t(in_mat.val_idx[I]);
213 });
214
215 in_it current_pos = 0, current_row = 0;
216
217 while(current_pos < in_mat.n_elem)
218 if(in_mat.row_idx[current_pos] < current_row) ++current_pos;
219 else row_ptr[current_row++] = index_t(current_pos);
220
221 row_ptr[0] = index_t(0);
222 row_ptr[n_rows] = n_elem;
223
224 return *this;
225}
226
227#endif
Definition: csc_form.hpp:25
Definition: csr_form.hpp:25
data_t max() const
Definition: csr_form.hpp:73
csr_form()=default
const index_t * row_mem() const
Definition: csr_form.hpp:61
index_t * row_mem()
Definition: csr_form.hpp:67
Mat< data_t > operator*(const Col< data_t > &in_mat) const
Definition: csr_form.hpp:109
const index_t n_rows
Definition: csr_form.hpp:50
const index_t n_cols
Definition: csr_form.hpp:51
csr_form & operator=(const csr_form &)
Definition: csr_form.hpp:142
csr_form & operator=(triplet_form< in_dt, in_it > &)
void print() const
Definition: csr_form.hpp:163
const data_t * val_mem() const
Definition: csr_form.hpp:65
csr_form & operator*=(const T2 scalar)
Definition: csr_form.hpp:90
const index_t * col_mem() const
Definition: csr_form.hpp:63
csr_form & operator/=(const T2 scalar)
Definition: csr_form.hpp:95
csr_form operator*(const T2 scalar) const
Definition: csr_form.hpp:80
csr_form operator/(const T2 scalar) const
Definition: csr_form.hpp:85
Mat< data_t > operator*(const Mat< data_t > &in_mat) const
Definition: csr_form.hpp:117
const index_t n_elem
Definition: csr_form.hpp:52
index_t * col_mem()
Definition: csr_form.hpp:69
data_t * val_mem()
Definition: csr_form.hpp:71
data_t operator()(const index_t in_row, const index_t in_col) const
Definition: csr_form.hpp:104
Definition: triplet_form.hpp:62
const index_t n_rows
Definition: triplet_form.hpp:128
void full_csr_condense()
Definition: triplet_form.hpp:214
void csr_condense()
Definition: triplet_form.hpp:204
const index_t n_cols
Definition: triplet_form.hpp:129
const index_t n_elem
Definition: triplet_form.hpp:130
Definition: suanPan.h:331
constexpr T max_element(T start, T end)
Definition: utility.h:39
void for_each(const IT start, const IT end, F &&FN)
Definition: utility.h:28
#define suanpan_info
Definition: suanPan.h:305
#define suanpan_for_each
Definition: suanPan.h:187
SparseBase
Definition: triplet_form.hpp:27