ezp
lightweight C++ wrapper for selected distributed solvers for linear systems
Loading...
Searching...
No Matches
traits.hpp
1/*******************************************************************************
2 * Copyright (C) 2025-2026 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
17
18#ifndef TRAITS_HPP
19#define TRAITS_HPP
20
#include "ezp.h"

#include <algorithm>
#include <array>
#include <atomic>
#include <cmath>
#include <ranges>
#include <vector>
28
29namespace ezp {
/// Constant that is `false` for every `T`; being a template, its value is tied to the template argument.
template<typename T> constexpr bool always_false_v{false};
31
32 template<typename T> concept floating_t = std::is_same_v<T, float> || std::is_same_v<T, double>;
33 template<typename T> concept complex_t = std::is_same_v<T, std::complex<typename T::value_type>> && floating_t<typename T::value_type>;
34 template<typename T> concept data_t = floating_t<T> || complex_t<T>;
35 template<typename T> concept index_t = std::is_same_v<T, std::int32_t> || std::is_same_v<T, std::int64_t>;
36
37 template<typename T> concept has_mem = requires(T t) { requires data_t<std::remove_pointer_t<decltype(t.mem())>>; };
38 template<typename T> concept has_memptr = requires(T t) { requires data_t<std::remove_pointer_t<decltype(t.memptr())>>; };
39 template<typename T> concept has_data_method = requires(T t) { requires data_t<std::remove_pointer_t<decltype(t.data())>>; };
40 template<typename T> concept has_data_member = requires(T t) { requires data_t<std::remove_pointer_t<decltype(t.data)>>; };
41 template<typename T> concept has_iterator = requires(T t) {
42 requires std::ranges::contiguous_range<T>;
43 requires data_t<std::remove_reference_t<decltype(*t.begin())>>;
44 };
46
47 template<typename T> concept full_container_t = requires(T t) { requires index_t<decltype(t.n_rows)> && index_t<decltype(t.n_cols)> && has_data_pointer<T>; };
48 template<typename T> concept band_container_t = requires(T t) { requires index_t<decltype(t.kl)> && index_t<decltype(t.ku)> && full_container_t<T>; };
49 template<typename T> concept band_symm_container_t = requires(T t) { requires index_t<decltype(t.klu)> && full_container_t<T>; };
50
51 template<typename T> struct WorkType {
52 static_assert(!std::is_same_v<T, T>, "not defined");
53 };
54 template<> struct WorkType<double> {
55 using type = double;
56 };
57 template<> struct WorkType<float> {
58 using type = float;
59 };
60 template<> struct WorkType<complex8> {
61 using type = float;
62 };
63 template<> struct WorkType<complex16> {
64 using type = double;
65 };
66 template<typename T> using work_t = typename WorkType<T>::type;
67
68 template<data_t DT, index_t IT> struct base_mat {
69 using data_type = DT;
70 using index_type = IT;
71
72 IT n_rows, n_cols;
73 DT* data;
74 bool distributed{};
75
76 base_mat() = default;
77
78 base_mat(const IT rows, const IT cols, DT* const ptr, const bool dist = false)
79 : n_rows(rows)
80 , n_cols(cols)
81 , data(ptr)
82 , distributed(dist) {}
83 };
84
85 template<data_t DT, index_t IT> struct full_mat : base_mat<DT, IT> {
86 using base_mat<DT, IT>::base_mat;
87 };
88
89 template<data_t DT, index_t IT> struct band_mat : base_mat<DT, IT> {
90 IT kl, ku;
91
92 band_mat() = default;
93
94 band_mat(const IT rows, const IT cols, const IT kl, const IT ku, DT* const ptr, const bool dist = false)
95 : base_mat<DT, IT>(rows, cols, ptr, dist)
96 , kl(kl)
97 , ku(ku) {}
98 };
99
100 template<data_t DT, index_t IT> struct band_symm_mat : base_mat<DT, IT> {
101 IT klu;
102
103 band_symm_mat() = default;
104
105 band_symm_mat(const IT rows, const IT cols, const IT klu, DT* const ptr, const bool dist = false)
106 : base_mat<DT, IT>(rows, cols, ptr, dist)
107 , klu(klu) {}
108 };
109
110 template<typename T> concept wrapper_t = requires(T t) {
111 requires std::is_same_v<T, full_mat<typename T::data_type, typename T::index_type>> || std::is_same_v<T, band_mat<typename T::data_type, typename T::index_type>> || std::is_same_v<T, band_symm_mat<typename T::data_type, typename T::index_type>>;
112 };
113
114 template<index_t IT> using desc = std::array<IT, 9>;
115
117#ifdef EZP_RELEASE_ONCE
118 protected:
119 static std::atomic_bool RELEASED;
120#endif
121 };
122
#if defined(EZP_RELEASE_ONCE)
// Expands to the out-of-class definition of ezp::blacs_base::RELEASED, initialised to false.
// NOTE(review): presumably meant to be expanded exactly once, at global scope, in one translation unit — confirm.
#define EZP_ENSURE_SAFE_EXIT \
    std::atomic_bool ezp::blacs_base::RELEASED = false;
#endif
126
127 template<index_t IT = int_t> class blacs_env final : blacs_base {
128 static constexpr IT ZERO{0}, ONE{1};
129
130 static std::atomic_bool FINALIZE;
131
132 IT _rank{-1}, _size{-1};
133
134 public:
135 blacs_env() { blacs_pinfo(&_rank, &_size); }
136
137 ~blacs_env() {
138#ifdef EZP_RELEASE_ONCE
139 if(RELEASED.exchange(true)) return;
140#endif
141 blacs_exit(FINALIZE ? &ZERO : &ONE);
142 }
143
151 static void do_not_manage_mpi() { FINALIZE = false; }
152
153 auto rank() const { return _rank; }
154 auto size() const { return _size; }
155 };
156
157 template<index_t IT> std::atomic_bool blacs_env<IT>::FINALIZE{true};
158
170 template<index_t IT = int_t> const auto& get_env() {
171 static const blacs_env<IT> scoped_env;
172
173 return scoped_env;
174 }
175
176 template<index_t IT> class blacs_context final {
177 static constexpr IT ZERO{0}, ONE{1}, NEGONE{-1};
178 static constexpr char SCOPE = 'A', TOP = ' ';
179
180 char layout{'R'};
181
182 IT info{-1};
183
184 auto init() {
185 blacs_get(&NEGONE, &ZERO, &context);
186 blacs_gridinit(&context, &layout, &n_rows, &n_cols);
187 blacs_pinfo(&rank, &size);
188 blacs_gridinfo(&context, &n_rows, &n_cols, &my_row, &my_col);
189 }
190
191 auto release() {
192 if(context >= 0) blacs_gridexit(&context);
193 }
194
211 template<data_t DT> auto copy_to(const DT* A, const IT* desc_a, DT* B, const IT* desc_b) {
212 // ReSharper disable CppCStyleCast
213 if(std::is_same_v<DT, double>) {
214 using E = double;
215 pdgemr2d(desc_a + 2, desc_a + 3, (E*)A, &ONE, &ONE, desc_a, (E*)B, &ONE, &ONE, desc_b, &context);
216 }
217 else if(std::is_same_v<DT, float>) {
218 using E = float;
219 psgemr2d(desc_a + 2, desc_a + 3, (E*)A, &ONE, &ONE, desc_a, (E*)B, &ONE, &ONE, desc_b, &context);
220 }
221 else if(std::is_same_v<DT, complex16>) {
222 using E = complex16;
223 pzgemr2d(desc_a + 2, desc_a + 3, (E*)A, &ONE, &ONE, desc_a, (E*)B, &ONE, &ONE, desc_b, &context);
224 }
225 else if(std::is_same_v<DT, complex8>) {
226 using E = complex8;
227 pcgemr2d(desc_a + 2, desc_a + 3, (E*)A, &ONE, &ONE, desc_a, (E*)B, &ONE, &ONE, desc_b, &context);
228 }
229 // ReSharper restore CppCStyleCast
230 }
231
232 public:
233 IT n_rows, n_cols, context{-1}, rank{-1}, size{-1}, my_row{-1}, my_col{-1};
234
236 : blacs_context(get_env<IT>().size(), 1) {}
237
238 explicit blacs_context(const char order)
239 : layout(order)
240 , n_rows(-1)
241 , n_cols(-1) {
242 const auto& env = get_env<IT>();
243 n_rows = std::max(IT{1}, static_cast<IT>(std::sqrt(env.size())));
244 n_cols = env.size() / n_rows;
245 init();
246 }
247
248 blacs_context(const IT rows, const IT cols, const char order = 'R')
249 : layout(order)
250 , n_rows(rows)
251 , n_cols(cols) { init(); }
252
253 blacs_context(const blacs_context& other)
254 : layout(other.layout)
255 , n_rows(other.n_rows)
256 , n_cols(other.n_cols) { init(); }
257
258 blacs_context(blacs_context&&) noexcept = delete;
259 blacs_context& operator=(const blacs_context&) = delete;
260 blacs_context& operator=(blacs_context&&) noexcept = delete;
261
262 ~blacs_context() { release(); }
263
275 auto desc_g(const IT num_rows, const IT num_cols) {
276 desc<IT> desc_t{};
277
278 descinit(desc_t.data(), &num_rows, &num_cols, &num_rows, &num_cols, &ZERO, &ZERO, &context, &num_rows, &info);
279
280 return desc_t;
281 }
282
297 auto desc_l(const IT num_rows, const IT num_cols, const IT row_block, const IT col_block, const IT lead) {
298 desc<IT> desc_t{};
299
300 const auto loc_lead = std::max(IT{1}, lead);
301 descinit(desc_t.data(), &num_rows, &num_cols, &row_block, &col_block, &ZERO, &ZERO, &context, &loc_lead, &info);
302
303 return desc_t;
304 }
305
306 auto desc_l(const IT num_rows, const IT num_cols, const IT block, const IT lead) { return desc_l(num_rows, num_cols, block, block, lead); }
307
308 template<data_t DT> auto scatter(const full_mat<DT, IT>& A, const desc<IT>& desc_a, std::vector<DT>& B, const desc<IT>& desc_b) {
309 if(!A.distributed) return copy_to(A.data, desc_a.data(), B.data(), desc_b.data());
310
311 for(auto i = 0u; i < B.size(); ++i) B[i] = A.data[i];
312 }
313
314 template<data_t DT> auto gather(const std::vector<DT>& A, const desc<IT>& desc_a, const full_mat<DT, IT>& B, const desc<IT>& desc_b) {
315 if(!B.distributed) return copy_to(A.data(), desc_a.data(), B.data, desc_b.data());
316
317 for(auto i = 0u; i < A.size(); ++i) B.data[i] = A[i];
318 }
319
320 auto copy_to(const IT* A, const IT* desc_a, IT* B, const IT* desc_b) { pigemr2d(desc_a + 2, desc_a + 3, A, &ONE, &ONE, desc_a, B, &ONE, &ONE, desc_b, &context); }
321
322 [[nodiscard]] bool is_valid() const { return my_row >= 0 && my_col >= 0; }
323
327 auto row_block(const IT n) const { return std::max(IT{1}, n / n_rows); }
328
332 auto col_block(const IT n) const { return std::max(IT{1}, n / n_cols); }
333
344 auto rows(const IT n, const IT nb) const { return numroc(&n, &nb, &my_row, &ZERO, &n_rows); }
345
356 auto cols(const IT n, const IT nb) const { return numroc(&n, &nb, &my_col, &ZERO, &n_cols); }
357
371 IT amx(IT number) const {
372 igamx2d(&context, &SCOPE, &TOP, &ONE, &ONE, &number, &ONE, nullptr, nullptr, &NEGONE, &NEGONE, &NEGONE);
373 return number;
374 }
375
389 IT amn(IT number) const {
390 igamn2d(&context, &SCOPE, &TOP, &ONE, &ONE, &number, &ONE, nullptr, nullptr, &NEGONE, &NEGONE, &NEGONE);
391 return number;
392 }
393 };
394} // namespace ezp
395
396#endif // TRAITS_HPP
Definition traits.hpp:116
Definition traits.hpp:176
IT amn(IT number) const
Performs the global absolute-value minimum (amn) combine operation.
Definition traits.hpp:389
auto desc_l(const IT num_rows, const IT num_cols, const IT row_block, const IT col_block, const IT lead)
Generates a descriptor for a local matrix.
Definition traits.hpp:297
auto rows(const IT n, const IT nb) const
Computes the number of local rows of the current process.
Definition traits.hpp:344
auto col_block(const IT n) const
Computes the column block size.
Definition traits.hpp:332
auto row_block(const IT n) const
Computes the row block size.
Definition traits.hpp:327
IT amx(IT number) const
Performs the global absolute-value maximum (amx) combine operation.
Definition traits.hpp:371
auto cols(const IT n, const IT nb) const
Computes the number of local columns of the current process.
Definition traits.hpp:356
auto desc_g(const IT num_rows, const IT num_cols)
Generates a descriptor for a global matrix.
Definition traits.hpp:275
Definition traits.hpp:127
static void do_not_manage_mpi()
Disables the management of MPI (Message Passing Interface) finalization.
Definition traits.hpp:151
Definition traits.hpp:48
Definition traits.hpp:49
Definition traits.hpp:33
Definition traits.hpp:34
Definition traits.hpp:32
Definition traits.hpp:47
Definition traits.hpp:40
Definition traits.hpp:39
Definition traits.hpp:45
Definition traits.hpp:41
Definition traits.hpp:37
Definition traits.hpp:38
Definition traits.hpp:35
Definition traits.hpp:110
Definition traits.hpp:51
Definition traits.hpp:89
Definition traits.hpp:100
Definition traits.hpp:68
Definition traits.hpp:85