ezp
lightweight C++ wrapper for selected distributed solvers for linear systems
sparse_solver.hpp
1 /*******************************************************************************
2  * Copyright (C) 2025-2026 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
17 
18 #ifndef SPARSE_SOLVER_HPP
19 #define SPARSE_SOLVER_HPP
20 
21 #include "traits.hpp"
22 
23 #include <algorithm>
24 #include <numeric>
25 #ifdef EZP_TBB
26 #include <tbb/parallel_sort.h>
27 #define ezp_sort tbb::parallel_sort
28 #else
29 #define ezp_sort std::sort
30 #endif
31 
32 namespace ezp {
33  template<data_t DT, index_t IT> struct sparse_coo_mat {
34  IT n, nnz;
35  IT *row, *col;
36  DT* data;
37 
38  sparse_coo_mat(const IT n, const IT nnz, IT* const row, IT* const col, DT* const data)
39  : n(n)
40  , nnz(nnz)
41  , row(row)
42  , col(col)
43  , data(data) {}
44 
45  auto is_valid() const { return row && col && data; }
46  };
47 
48  namespace detail {
49  template<index_t IT> class csr_comparator {
50  const IT* const row_idx;
51  const IT* const col_idx;
52 
53  public:
54  csr_comparator(const IT* const in_row_idx, const IT* const in_col_idx)
55  : row_idx(in_row_idx)
56  , col_idx(in_col_idx) {}
57 
58  bool operator()(const IT idx_a, const IT idx_b) const {
59  if(row_idx[idx_a] == row_idx[idx_b]) return col_idx[idx_a] < col_idx[idx_b];
60  return row_idx[idx_a] < row_idx[idx_b];
61  }
62  };
63 
64  template<typename T> bool approx_equal(T x, T y, int ulp = 2)
65  requires(!std::numeric_limits<T>::is_integer)
66  { return std::fabs(x - y) <= std::numeric_limits<T>::epsilon() * std::fabs(x + y) * ulp || std::fabs(x - y) < std::numeric_limits<T>::min(); }
67  } // namespace detail
68 
69  template<data_t DT, index_t IT> struct sparse_csr_mat final {
70  IT n, nnz;
71  IT *row_ptr, *col_idx;
72  DT* data;
73 
74  std::vector<IT> row_storage, col_storage;
75  std::vector<DT> data_storage;
76 
77  sparse_csr_mat() = default;
78  sparse_csr_mat(const sparse_csr_mat& other)
79  : n(other.n)
80  , row_ptr(other.row_ptr)
81  , col_idx(other.col_idx)
82  , data(other.data)
83  , row_storage(other.row_storage)
84  , col_storage(other.col_storage)
85  , data_storage(other.data_storage) {
86  if(!row_storage.empty()) row_ptr = row_storage.data();
87  if(!col_storage.empty()) col_idx = col_storage.data();
88  if(!data_storage.empty()) data = data_storage.data();
89  }
90  sparse_csr_mat(sparse_csr_mat&&) = delete;
91  sparse_csr_mat& operator=(const sparse_csr_mat&) = delete;
92  sparse_csr_mat& operator=(sparse_csr_mat&&) = default;
93  ~sparse_csr_mat() = default;
94 
95  sparse_csr_mat(const IT n, const IT nnz, IT* const row_ptr, IT* const col_idx, DT* const data)
96  : n(n)
97  , nnz(nnz)
98  , row_ptr(row_ptr)
99  , col_idx(col_idx)
100  , data(data) {}
101 
102  template<data_t DT2, index_t IT2> explicit sparse_csr_mat(const sparse_coo_mat<DT2, IT2>& coo, const bool one_based = false, const bool full = false)
103  : n(IT{coo.n})
104  , nnz(IT{coo.nnz}) {
105  if(!coo.is_valid()) return;
106 
107  std::vector<IT2> index(nnz);
108  std::iota(index.begin(), index.end(), IT2(0));
109  ezp_sort(index.begin(), index.end(), detail::csr_comparator(coo.row, coo.col));
110 
111  row_storage.resize(nnz);
112  col_storage.resize(nnz);
113  data_storage.resize(nnz);
114 
115  for(auto I = IT{0}; I < nnz; ++I) {
116  row_storage[I] = coo.row[index[I]];
117  col_storage[I] = coo.col[index[I]];
118  data_storage[I] = coo.data[index[I]];
119  }
120 
121  condense(one_based, full);
122 
123  row_ptr = row_storage.data();
124  col_idx = col_storage.data();
125  data = data_storage.data();
126  }
127 
128  auto condense(const bool one_based, const bool full) {
129  auto last_row = row_storage[0], last_col = col_storage[0];
130 
131  auto current_pos = IT{0};
132  auto last_sum = DT{0};
133 
134  auto populate = [&] {
135  if(detail::approx_equal(last_sum, DT(0)) && (!full || last_row != last_col)) return;
136  row_storage[current_pos] = last_row;
137  col_storage[current_pos] = last_col;
138  data_storage[current_pos] = last_sum;
139  ++current_pos;
140  last_sum = DT{0};
141  };
142 
143  for(auto I = IT{0}; I < nnz; ++I) {
144  if(last_row != row_storage[I] || last_col != col_storage[I]) {
145  populate();
146  last_row = row_storage[I];
147  last_col = col_storage[I];
148  }
149  last_sum += data_storage[I];
150  }
151 
152  populate();
153 
154  nnz = current_pos;
155 
156  auto current_row = current_pos = IT{0};
157  auto shift = one_based ? IT{1} : IT{0};
158 
159  while(current_pos < nnz)
160  if(row_storage[current_pos] < current_row + shift) ++current_pos;
161  else row_storage[current_row++] = current_pos + shift;
162 
163  row_storage[0] = IT{0} + shift;
164  row_storage[n] = nnz + shift;
165 
166  row_storage.resize(n + 1);
167  col_storage.resize(nnz);
168  data_storage.resize(nnz);
169  }
170 
171  auto is_valid() { return row_ptr && col_idx && data; }
172  };
173 } // namespace ezp
174 
175 #endif // SPARSE_SOLVER_HPP
Definition: sparse_solver.hpp:49
Definition: sparse_solver.hpp:33
Definition: sparse_solver.hpp:69