ezp
lightweight C++ wrapper for selected distributed solvers for linear systems
traits.hpp
1 /*******************************************************************************
2  * Copyright (C) 2025-2026 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
17 
18 #ifndef TRAITS_HPP
19 #define TRAITS_HPP
20 
#include "ezp.h"

#include <algorithm>
#include <array>
#include <atomic>
#include <cmath>
#include <complex>
#include <cstdint>
#include <ranges>
#include <type_traits>
#include <vector>
28 
29 namespace ezp {
/// Dependent `false`: lets a `static_assert` fire only when the surrounding template is instantiated.
template<typename> inline constexpr bool always_false_v = false;
31 
32  template<typename T> concept floating_t = std::is_same_v<T, float> || std::is_same_v<T, double>;
33  template<typename T> concept complex_t = std::is_same_v<T, std::complex<typename T::value_type>> && floating_t<typename T::value_type>;
34  template<typename T> concept data_t = floating_t<T> || complex_t<T>;
35  template<typename T> concept index_t = std::is_same_v<T, std::int32_t> || std::is_same_v<T, std::int64_t>;
36 
37  template<typename T> concept has_mem = requires(T t) { requires data_t<std::remove_pointer_t<decltype(t.mem())>>; };
38  template<typename T> concept has_memptr = requires(T t) { requires data_t<std::remove_pointer_t<decltype(t.memptr())>>; };
39  template<typename T> concept has_data_method = requires(T t) { requires data_t<std::remove_pointer_t<decltype(t.data())>>; };
40  template<typename T> concept has_data_member = requires(T t) { requires data_t<std::remove_pointer_t<decltype(t.data)>>; };
41  template<typename T> concept has_iterator = requires(T t) {
42  requires std::ranges::contiguous_range<T>;
43  requires data_t<std::remove_reference_t<decltype(*t.begin())>>;
44  };
45  template<typename T> concept has_data_pointer = has_mem<T> || has_memptr<T> || has_data_method<T> || has_iterator<T>;
46 
47  template<typename T> concept full_container_t = requires(T t) { requires index_t<decltype(t.n_rows)> && index_t<decltype(t.n_cols)> && has_data_pointer<T>; };
48  template<typename T> concept band_container_t = requires(T t) { requires index_t<decltype(t.kl)> && index_t<decltype(t.ku)> && full_container_t<T>; };
49  template<typename T> concept band_symm_container_t = requires(T t) { requires index_t<decltype(t.klu)> && full_container_t<T>; };
50 
51  template<typename T> struct WorkType {
52  static_assert(!std::is_same_v<T, T>, "not defined");
53  };
54  template<> struct WorkType<double> {
55  using type = double;
56  };
57  template<> struct WorkType<float> {
58  using type = float;
59  };
60  template<> struct WorkType<complex8> {
61  using type = float;
62  };
63  template<> struct WorkType<complex16> {
64  using type = double;
65  };
66  template<typename T> using work_t = typename WorkType<T>::type;
67 
68  template<data_t DT, index_t IT> struct base_mat {
69  using data_type = DT;
70  using index_type = IT;
71 
72  IT n_rows, n_cols;
73  DT* data;
74  bool distributed{};
75 
76  base_mat() = default;
77 
78  base_mat(const IT rows, const IT cols, DT* const ptr, const bool dist = false)
79  : n_rows(rows)
80  , n_cols(cols)
81  , data(ptr)
82  , distributed(dist) {}
83  };
84 
85  template<data_t DT, index_t IT> struct full_mat : base_mat<DT, IT> {
87  };
88 
89  template<data_t DT, index_t IT> struct band_mat : base_mat<DT, IT> {
90  IT kl, ku;
91 
92  band_mat() = default;
93 
94  band_mat(const IT rows, const IT cols, const IT kl, const IT ku, DT* const ptr, const bool dist = false)
95  : base_mat<DT, IT>(rows, cols, ptr, dist)
96  , kl(kl)
97  , ku(ku) {}
98  };
99 
100  template<data_t DT, index_t IT> struct band_symm_mat : base_mat<DT, IT> {
101  IT klu;
102 
103  band_symm_mat() = default;
104 
105  band_symm_mat(const IT rows, const IT cols, const IT klu, DT* const ptr, const bool dist = false)
106  : base_mat<DT, IT>(rows, cols, ptr, dist)
107  , klu(klu) {}
108  };
109 
110  template<typename T> concept wrapper_t = requires(T t) {
111  requires std::is_same_v<T, full_mat<typename T::data_type, typename T::index_type>> || std::is_same_v<T, band_mat<typename T::data_type, typename T::index_type>> || std::is_same_v<T, band_symm_mat<typename T::data_type, typename T::index_type>>;
112  };
113 
114  template<index_t IT> using desc = std::array<IT, 9>;
115 
/// Common base of blacs_env. With EZP_RELEASE_ONCE defined it carries the flag that makes
/// only the first blacs_env destructor shut BLACS down; otherwise it is an empty class.
class blacs_base {
#ifdef EZP_RELEASE_ONCE
protected:
    // Declared here; the single out-of-line definition comes from EZP_ENSURE_SAFE_EXIT.
    static std::atomic_bool RELEASED;
#endif
};
122 
#if defined(EZP_RELEASE_ONCE)
// Expands to the definition of blacs_base::RELEASED (initially false). Because it is a
// non-inline static data member, expand this exactly once, at global scope, in one translation unit.
#define EZP_ENSURE_SAFE_EXIT std::atomic_bool ezp::blacs_base::RELEASED = false;
#endif
126 
127  template<index_t IT = int_t> class blacs_env final : blacs_base {
128  static constexpr IT ZERO{0}, ONE{1};
129 
130  static std::atomic_bool FINALIZE;
131 
132  IT _rank{-1}, _size{-1};
133 
134  public:
135  blacs_env() { blacs_pinfo(&_rank, &_size); }
136 
137  ~blacs_env() {
138 #ifdef EZP_RELEASE_ONCE
139  if(RELEASED.exchange(true)) return;
140 #endif
141  blacs_exit(FINALIZE ? &ZERO : &ONE);
142  }
143 
151  static void do_not_manage_mpi() { FINALIZE = false; }
152 
153  auto rank() const { return _rank; }
154  auto size() const { return _size; }
155  };
156 
157  template<index_t IT> std::atomic_bool blacs_env<IT>::FINALIZE{true};
158 
170  template<index_t IT = int_t> const auto& get_env() {
171  static const blacs_env<IT> scoped_env;
172 
173  return scoped_env;
174  }
175 
176  template<index_t IT> class blacs_context final {
177  static constexpr IT ZERO{0}, ONE{1}, NEGONE{-1};
178  static constexpr char SCOPE = 'A', TOP = ' ';
179 
180  char layout{'R'};
181 
182  IT info{-1};
183 
184  auto init() {
185  blacs_get(&NEGONE, &ZERO, &context);
186  blacs_gridinit(&context, &layout, &n_rows, &n_cols);
187  blacs_pinfo(&rank, &size);
188  blacs_gridinfo(&context, &n_rows, &n_cols, &my_row, &my_col);
189  }
190 
191  auto release() {
192  if(context >= 0) blacs_gridexit(&context);
193  }
194 
211  template<data_t DT> auto copy_to(const DT* A, const IT* desc_a, DT* B, const IT* desc_b) {
212  // ReSharper disable CppCStyleCast
213  if(std::is_same_v<DT, double>) {
214  using E = double;
215  pdgemr2d(desc_a + 2, desc_a + 3, (E*)A, &ONE, &ONE, desc_a, (E*)B, &ONE, &ONE, desc_b, &context);
216  }
217  else if(std::is_same_v<DT, float>) {
218  using E = float;
219  psgemr2d(desc_a + 2, desc_a + 3, (E*)A, &ONE, &ONE, desc_a, (E*)B, &ONE, &ONE, desc_b, &context);
220  }
221  else if(std::is_same_v<DT, complex16>) {
222  using E = complex16;
223  pzgemr2d(desc_a + 2, desc_a + 3, (E*)A, &ONE, &ONE, desc_a, (E*)B, &ONE, &ONE, desc_b, &context);
224  }
225  else if(std::is_same_v<DT, complex8>) {
226  using E = complex8;
227  pcgemr2d(desc_a + 2, desc_a + 3, (E*)A, &ONE, &ONE, desc_a, (E*)B, &ONE, &ONE, desc_b, &context);
228  }
229  // ReSharper restore CppCStyleCast
230  }
231 
232  public:
233  IT n_rows, n_cols, context{-1}, rank{-1}, size{-1}, my_row{-1}, my_col{-1};
234 
235  blacs_context()
236  : blacs_context(get_env<IT>().size(), 1) {}
237 
238  explicit blacs_context(const char order)
239  : layout(order)
240  , n_rows(-1)
241  , n_cols(-1) {
242  const auto& env = get_env<IT>();
243  n_rows = std::max(IT{1}, static_cast<IT>(std::sqrt(env.size())));
244  n_cols = env.size() / n_rows;
245  init();
246  }
247 
248  blacs_context(const IT rows, const IT cols, const char order = 'R')
249  : layout(order)
250  , n_rows(rows)
251  , n_cols(cols) { init(); }
252 
253  blacs_context(const blacs_context& other)
254  : layout(other.layout)
255  , n_rows(other.n_rows)
256  , n_cols(other.n_cols) { init(); }
257 
258  blacs_context(blacs_context&&) noexcept = delete;
259  blacs_context& operator=(const blacs_context&) = delete;
260  blacs_context& operator=(blacs_context&&) noexcept = delete;
261 
262  ~blacs_context() { release(); }
263 
275  auto desc_g(const IT num_rows, const IT num_cols) {
276  desc<IT> desc_t{};
277 
278  descinit(desc_t.data(), &num_rows, &num_cols, &num_rows, &num_cols, &ZERO, &ZERO, &context, &num_rows, &info);
279 
280  return desc_t;
281  }
282 
297  auto desc_l(const IT num_rows, const IT num_cols, const IT row_block, const IT col_block, const IT lead) {
298  desc<IT> desc_t{};
299 
300  const auto loc_lead = std::max(IT{1}, lead);
301  descinit(desc_t.data(), &num_rows, &num_cols, &row_block, &col_block, &ZERO, &ZERO, &context, &loc_lead, &info);
302 
303  return desc_t;
304  }
305 
306  auto desc_l(const IT num_rows, const IT num_cols, const IT block, const IT lead) { return desc_l(num_rows, num_cols, block, block, lead); }
307 
308  template<data_t DT> auto scatter(const full_mat<DT, IT>& A, const desc<IT>& desc_a, std::vector<DT>& B, const desc<IT>& desc_b) {
309  if(!A.distributed) return copy_to(A.data, desc_a.data(), B.data(), desc_b.data());
310 
311  for(auto i = 0u; i < B.size(); ++i) B[i] = A.data[i];
312  }
313 
314  template<data_t DT> auto gather(const std::vector<DT>& A, const desc<IT>& desc_a, const full_mat<DT, IT>& B, const desc<IT>& desc_b) {
315  if(!B.distributed) return copy_to(A.data(), desc_a.data(), B.data, desc_b.data());
316 
317  for(auto i = 0u; i < A.size(); ++i) B.data[i] = A[i];
318  }
319 
320  auto copy_to(const IT* A, const IT* desc_a, IT* B, const IT* desc_b) { pigemr2d(desc_a + 2, desc_a + 3, A, &ONE, &ONE, desc_a, B, &ONE, &ONE, desc_b, &context); }
321 
322  [[nodiscard]] bool is_valid() const { return my_row >= 0 && my_col >= 0; }
323 
327  auto row_block(const IT n) const { return std::max(IT{1}, n / n_rows); }
328 
332  auto col_block(const IT n) const { return std::max(IT{1}, n / n_cols); }
333 
344  auto rows(const IT n, const IT nb) const { return numroc(&n, &nb, &my_row, &ZERO, &n_rows); }
345 
356  auto cols(const IT n, const IT nb) const { return numroc(&n, &nb, &my_col, &ZERO, &n_cols); }
357 
371  IT amx(IT number) const {
372  igamx2d(&context, &SCOPE, &TOP, &ONE, &ONE, &number, &ONE, nullptr, nullptr, &NEGONE, &NEGONE, &NEGONE);
373  return number;
374  }
375 
389  IT amn(IT number) const {
390  igamn2d(&context, &SCOPE, &TOP, &ONE, &ONE, &number, &ONE, nullptr, nullptr, &NEGONE, &NEGONE, &NEGONE);
391  return number;
392  }
393  };
394 } // namespace ezp
395 
396 #endif // TRAITS_HPP
Definition: traits.hpp:116
Definition: traits.hpp:176
IT amn(IT number) const
Performs a global absolute-minimum reduction (BLACS igamn2d) of an integer over all processes in the grid.
Definition: traits.hpp:389
auto desc_l(const IT num_rows, const IT num_cols, const IT row_block, const IT col_block, const IT lead)
Generates a descriptor for a local matrix.
Definition: traits.hpp:297
auto rows(const IT n, const IT nb) const
Computes the number of local rows of the current process.
Definition: traits.hpp:344
auto col_block(const IT n) const
Computes the column block size.
Definition: traits.hpp:332
auto row_block(const IT n) const
Computes the row block size.
Definition: traits.hpp:327
IT amx(IT number) const
Performs a global absolute-maximum reduction (BLACS igamx2d) of an integer over all processes in the grid.
Definition: traits.hpp:371
auto cols(const IT n, const IT nb) const
Computes the number of local columns of the current process.
Definition: traits.hpp:356
auto desc_g(const IT num_rows, const IT num_cols)
Generates a descriptor for a global matrix.
Definition: traits.hpp:275
Definition: traits.hpp:127
static void do_not_manage_mpi()
Stops blacs_env from finalizing MPI (Message Passing Interface) on exit, leaving MPI finalization to the caller.
Definition: traits.hpp:151
Definition: traits.hpp:51
Definition: traits.hpp:89
Definition: traits.hpp:100
Definition: traits.hpp:68
Definition: traits.hpp:85