18 #ifndef SPARSE_SOLVER_HPP
19 #define SPARSE_SOLVER_HPP
26 #include <tbb/parallel_sort.h>
27 #define ezp_sort tbb::parallel_sort
29 #define ezp_sort std::sort
/**
 * @brief Constructor taking two size values and three raw array pointers.
 *
 * @param n    value of type IT (presumably the matrix dimension -- TODO confirm).
 * @param nnz  value of type IT (presumably the number of stored entries -- TODO confirm).
 * @param row  pointer to IT values (presumably the row index of each entry -- TODO confirm).
 * @param col  pointer to IT values (presumably the column index of each entry -- TODO confirm).
 * @param data pointer to DT values (presumably the value of each entry -- TODO confirm).
 *
 * NOTE(review): the initialiser list and the body are not visible in this
 * excerpt, so how the arguments are stored cannot be stated here.
 */
38 sparse_coo_mat(
const IT n,
const IT nnz, IT*
const row, IT*
const col, DT*
const data)
45 auto is_valid()
const {
return row && col && data; }
50 const IT*
const row_idx;
51 const IT*
const col_idx;
/**
 * @brief Builds a comparator from two index arrays.
 *
 * @param in_row_idx pointer to the row index array.
 * @param in_col_idx pointer to the column index array; stored in `col_idx`.
 *
 * NOTE(review): only the `col_idx` initialiser is visible in this excerpt; the
 * initialiser for `row_idx` is presumably on the elided preceding line -- confirm.
 */
54 csr_comparator(
const IT*
const in_row_idx,
const IT*
const in_col_idx)
56 , col_idx(in_col_idx) {}
/**
 * @brief Strict ordering of two entry indices by (row index, column index).
 *
 * @param idx_a index of the first entry; used to subscript `row_idx` and `col_idx`.
 * @param idx_b index of the second entry; used to subscript `row_idx` and `col_idx`.
 * @return true if entry `idx_a` has a smaller row index than entry `idx_b`, or
 *         the same row index and a smaller column index.
 */
58 bool operator()(
const IT idx_a,
const IT idx_b)
const {
// Same row: the column index decides the order.
59 if(row_idx[idx_a] == row_idx[idx_b])
return col_idx[idx_a] < col_idx[idx_b];
// Different rows: the row index decides the order.
60 return row_idx[idx_a] < row_idx[idx_b];
64 template<
typename T>
bool approx_equal(T x, T y,
int ulp = 2)
65 requires(!std::numeric_limits<T>::is_integer)
66 {
return std::fabs(x - y) <= std::numeric_limits<T>::epsilon() * std::fabs(x + y) * ulp || std::fabs(x - y) < std::numeric_limits<T>::min(); }
71 IT *row_ptr, *col_idx;
74 std::vector<IT> row_storage, col_storage;
75 std::vector<DT> data_storage;
// Copy step 1: the raw pointers and the storage vectors are copied from `other`
// member by member. At this point the copied pointers still refer to whatever
// `other`'s pointers refer to.
80 , row_ptr(other.row_ptr)
81 , col_idx(other.col_idx)
83 , row_storage(other.row_storage)
84 , col_storage(other.col_storage)
85 , data_storage(other.data_storage) {
// Copy step 2: whenever a storage vector holds elements, redirect the matching
// pointer to this object's own buffer instead of the one copied from `other`.
// With empty storage the pointer copied from `other` is kept unchanged.
86 if(!row_storage.empty()) row_ptr = row_storage.data();
87 if(!col_storage.empty()) col_idx = col_storage.data();
88 if(!data_storage.empty()) data = data_storage.data();
/**
 * @brief Constructor taking two size values and three raw array pointers.
 *
 * @param n       value of type IT (presumably the matrix dimension -- TODO confirm).
 * @param nnz     value of type IT (presumably the number of stored entries -- TODO confirm).
 * @param row_ptr pointer to IT values (presumably the CSR row pointer array -- TODO confirm).
 * @param col_idx pointer to IT values (presumably the column index of each entry -- TODO confirm).
 * @param data    pointer to DT values (presumably the value of each entry -- TODO confirm).
 *
 * NOTE(review): the initialiser list and the body are not visible in this
 * excerpt, so how the arguments are stored cannot be stated here.
 */
95 sparse_csr_mat(
const IT n,
const IT nnz, IT*
const row_ptr, IT*
const col_idx, DT*
const data)
// Nothing to convert when the source matrix reports itself as invalid.
105 if(!coo.is_valid())
return;
// Build the identity permutation 0, 1, ..., nnz - 1 over the source entries.
// NOTE(review): the line that reorders `index` is not visible in this excerpt;
// presumably it is sorted by (row, column) -- confirm.
107 std::vector<IT2> index(nnz);
108 std::iota(index.begin(), index.end(), IT2(0));
// Make room for one row index, one column index and one value per source entry.
111 row_storage.resize(nnz);
112 col_storage.resize(nnz);
113 data_storage.resize(nnz);
// Gather the source triplets into the owned storage in the order given by `index`.
115 for(
auto I = IT{0}; I < nnz; ++I) {
116 row_storage[I] = coo.row[index[I]];
117 col_storage[I] = coo.col[index[I]];
118 data_storage[I] = coo.data[index[I]];
// Post-process the gathered triplets in place (see condense()).
121 condense(one_based, full);
// Point the raw pointers at the owned buffers.
123 row_ptr = row_storage.data();
124 col_idx = col_storage.data();
125 data = data_storage.data();
/**
 * @brief Post-processes the triplets held in row_storage / col_storage /
 *        data_storage in place and rewrites row_storage as an array of
 *        n + 1 offsets.
 *
 * @param one_based when true every stored offset is shifted by one, otherwise by zero.
 * @param full      when true an approximately-zero sum is still stored if it
 *                  sits on the diagonal (row index equals column index).
 *
 * NOTE(review): several lines of this function are not visible in this excerpt
 * (e.g. where `populate` is invoked and where `current_pos` / `nnz` are
 * updated), so the comments below describe only the visible statements.
 */
128 auto condense(
const bool one_based,
const bool full) {
// Start tracking from the first stored triplet.
// NOTE(review): reads element 0 with no emptiness check visible here -- confirm
// that the storage cannot be empty at this point.
129 auto last_row = row_storage[0], last_col = col_storage[0];
// Write position for kept entries and running sum for the tracked position.
131 auto current_pos = IT{0};
132 auto last_sum = DT{0};
// Stores the tracked (row, column, sum) at `current_pos`. An approximately-zero
// sum is skipped, except when `full` is set and the entry is on the diagonal.
134 auto populate = [&] {
135 if(detail::approx_equal(last_sum, DT(0)) && (!full || last_row != last_col))
return;
136 row_storage[current_pos] = last_row;
137 col_storage[current_pos] = last_col;
138 data_storage[current_pos] = last_sum;
// Walk all entries: values are accumulated into `last_sum`, and the tracked
// (row, column) is replaced whenever the current entry differs from it.
143 for(
auto I = IT{0}; I < nnz; ++I) {
144 if(last_row != row_storage[I] || last_col != col_storage[I]) {
146 last_row = row_storage[I];
147 last_col = col_storage[I];
149 last_sum += data_storage[I];
// Second pass: reset `current_pos` and turn row indices into row offsets.
156 auto current_row = current_pos = IT{0};
// Offset added to every stored position: 1 for one-based output, else 0.
157 auto shift = one_based ? IT{1} : IT{0};
// Skip entries whose row index is below `current_row + shift`; otherwise record
// the current position (plus shift) as the offset of `current_row` and move on
// to the next row. row_storage is overwritten in place while it is being read.
159 while(current_pos < nnz)
160 if(row_storage[current_pos] < current_row + shift) ++current_pos;
161 else row_storage[current_row++] = current_pos + shift;
// First offset is the base itself, last offset is one past the final entry.
// NOTE(review): row_storage[n] is written before the resize to n + 1 below --
// confirm row_storage always holds more than n elements at this point.
163 row_storage[0] = IT{0} + shift;
164 row_storage[n] = nnz + shift;
// Trim the buffers to n + 1 offsets and nnz column indices / values.
166 row_storage.resize(n + 1);
167 col_storage.resize(nnz);
168 data_storage.resize(nnz);
171 auto is_valid() {
return row_ptr && col_idx && data; }
Definition: sparse_solver.hpp:49
Definition: sparse_solver.hpp:33
Definition: sparse_solver.hpp:69