ezp
lightweight C++ wrapper for selected distributed solvers for linear systems
Loading...
Searching...
No Matches
sparse_solver.hpp
1/*******************************************************************************
2 * Copyright (C) 2025-2026 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
17
18#ifndef SPARSE_SOLVER_HPP
19#define SPARSE_SOLVER_HPP
20
21#include "traits.hpp"
22
23#include <algorithm>
24#include <numeric>
25#ifdef EZP_TBB
26#include <tbb/parallel_sort.h>
27#define ezp_sort tbb::parallel_sort
28#else
29#define ezp_sort std::sort
30#endif
31
32namespace ezp {
33 template<data_t DT, index_t IT> struct sparse_coo_mat {
34 IT n, nnz;
35 IT *row, *col;
36 DT* data;
37
38 sparse_coo_mat(const IT n, const IT nnz, IT* const row, IT* const col, DT* const data)
39 : n(n)
40 , nnz(nnz)
41 , row(row)
42 , col(col)
43 , data(data) {}
44
45 auto is_valid() const { return row && col && data; }
46 };
47
48 namespace detail {
49 template<index_t IT> class csr_comparator {
50 const IT* const row_idx;
51 const IT* const col_idx;
52
53 public:
54 csr_comparator(const IT* const in_row_idx, const IT* const in_col_idx)
55 : row_idx(in_row_idx)
56 , col_idx(in_col_idx) {}
57
58 bool operator()(const IT idx_a, const IT idx_b) const {
59 if(row_idx[idx_a] == row_idx[idx_b]) return col_idx[idx_a] < col_idx[idx_b];
60 return row_idx[idx_a] < row_idx[idx_b];
61 }
62 };
63
64 template<typename T> bool approx_equal(T x, T y, int ulp = 2)
65 requires(!std::numeric_limits<T>::is_integer)
66 { return std::fabs(x - y) <= std::numeric_limits<T>::epsilon() * std::fabs(x + y) * ulp || std::fabs(x - y) < std::numeric_limits<T>::min(); }
67 } // namespace detail
68
69 template<data_t DT, index_t IT> struct sparse_csr_mat final {
70 IT n, nnz;
71 IT *row_ptr, *col_idx;
72 DT* data;
73
74 std::vector<IT> row_storage, col_storage;
75 std::vector<DT> data_storage;
76
77 sparse_csr_mat() = default;
78 sparse_csr_mat(const sparse_csr_mat& other)
79 : n(other.n)
80 , row_ptr(other.row_ptr)
81 , col_idx(other.col_idx)
82 , data(other.data)
83 , row_storage(other.row_storage)
84 , col_storage(other.col_storage)
85 , data_storage(other.data_storage) {
86 if(!row_storage.empty()) row_ptr = row_storage.data();
87 if(!col_storage.empty()) col_idx = col_storage.data();
88 if(!data_storage.empty()) data = data_storage.data();
89 }
91 sparse_csr_mat& operator=(const sparse_csr_mat&) = delete;
92 sparse_csr_mat& operator=(sparse_csr_mat&&) = default;
93 ~sparse_csr_mat() = default;
94
95 sparse_csr_mat(const IT n, const IT nnz, IT* const row_ptr, IT* const col_idx, DT* const data)
96 : n(n)
97 , nnz(nnz)
98 , row_ptr(row_ptr)
99 , col_idx(col_idx)
100 , data(data) {}
101
102 template<data_t DT2, index_t IT2> explicit sparse_csr_mat(const sparse_coo_mat<DT2, IT2>& coo, const bool one_based = false, const bool full = false)
103 : n(IT{coo.n})
104 , nnz(IT{coo.nnz}) {
105 if(!coo.is_valid()) return;
106
107 std::vector<IT2> index(nnz);
108 std::iota(index.begin(), index.end(), IT2(0));
109 ezp_sort(index.begin(), index.end(), detail::csr_comparator(coo.row, coo.col));
110
111 row_storage.resize(nnz);
112 col_storage.resize(nnz);
113 data_storage.resize(nnz);
114
115 for(auto I = IT{0}; I < nnz; ++I) {
116 row_storage[I] = coo.row[index[I]];
117 col_storage[I] = coo.col[index[I]];
118 data_storage[I] = coo.data[index[I]];
119 }
120
121 condense(one_based, full);
122
123 row_ptr = row_storage.data();
124 col_idx = col_storage.data();
125 data = data_storage.data();
126 }
127
128 auto condense(const bool one_based, const bool full) {
129 auto last_row = row_storage[0], last_col = col_storage[0];
130
131 auto current_pos = IT{0};
132 auto last_sum = DT{0};
133
134 auto populate = [&] {
135 if(detail::approx_equal(last_sum, DT(0)) && (!full || last_row != last_col)) return;
136 row_storage[current_pos] = last_row;
137 col_storage[current_pos] = last_col;
138 data_storage[current_pos] = last_sum;
139 ++current_pos;
140 last_sum = DT{0};
141 };
142
143 for(auto I = IT{0}; I < nnz; ++I) {
144 if(last_row != row_storage[I] || last_col != col_storage[I]) {
145 populate();
146 last_row = row_storage[I];
147 last_col = col_storage[I];
148 }
149 last_sum += data_storage[I];
150 }
151
152 populate();
153
154 nnz = current_pos;
155
156 auto current_row = current_pos = IT{0};
157 auto shift = one_based ? IT{1} : IT{0};
158
159 while(current_pos < nnz)
160 if(row_storage[current_pos] < current_row + shift) ++current_pos;
161 else row_storage[current_row++] = current_pos + shift;
162
163 row_storage[0] = IT{0} + shift;
164 row_storage[n] = nnz + shift;
165
166 row_storage.resize(n + 1);
167 col_storage.resize(nnz);
168 data_storage.resize(nnz);
169 }
170
171 auto is_valid() { return row_ptr && col_idx && data; }
172 };
173} // namespace ezp
174
175#endif // SPARSE_SOLVER_HPP
Definition sparse_solver.hpp:49
Definition sparse_solver.hpp:33
Definition sparse_solver.hpp:69