suanPan
Loading...
Searching...
No Matches
FullMatCUDA.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2017-2024 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
29// ReSharper disable CppCStyleCast
30#ifndef FULLMATCUDA_HPP
31#define FULLMATCUDA_HPP
32
33#ifdef SUANPAN_CUDA
34
35#include <cuda_runtime.h>
36#include <cusolverDn.h>
37#include "FullMat.hpp"
38
39template<sp_d T> class FullMatCUDA final : public FullMat<T> {
40 cusolverDnHandle_t handle = nullptr;
41 cudaStream_t stream = nullptr;
42
43 int* info = nullptr;
44 int* ipiv = nullptr;
45 void* d_A = nullptr;
46 void* buffer = nullptr;
47
48 void acquire() {
49 cusolverDnCreate(&handle);
50 cudaStreamCreate(&stream);
51 cusolverDnSetStream(handle, stream);
52
53 cudaMalloc(&info, sizeof(int));
54 cudaMemset(info, 0, sizeof(int));
55 cudaMalloc(&ipiv, sizeof(int) * this->n_rows);
56
57 int bufferSize = 0;
58 if constexpr(std::is_same_v<T, float>) {
59 cudaMalloc(&d_A, sizeof(float) * this->n_elem);
60 cusolverDnSgetrf_bufferSize(handle, static_cast<int>(this->n_rows), static_cast<int>(this->n_cols), (float*)d_A, static_cast<int>(this->n_elem), &bufferSize);
61 cudaMalloc(&buffer, sizeof(float) * bufferSize);
62 }
63 else if(Precision::MIXED == this->setting.precision) {
64 cudaMalloc(&d_A, sizeof(float) * this->n_elem);
65 cusolverDnSgetrf_bufferSize(handle, static_cast<int>(this->n_rows), static_cast<int>(this->n_cols), (float*)d_A, static_cast<int>(this->n_elem), &bufferSize);
66 cudaMalloc(&buffer, sizeof(float) * bufferSize);
67 }
68 else {
69 cudaMalloc(&d_A, sizeof(double) * this->n_elem);
70 cusolverDnDgetrf_bufferSize(handle, static_cast<int>(this->n_rows), static_cast<int>(this->n_cols), (double*)d_A, static_cast<int>(this->n_elem), &bufferSize);
71 cudaMalloc(&buffer, sizeof(double) * bufferSize);
72 }
73 }
74
75 void release() const {
76 if(handle) cusolverDnDestroy(handle);
77 if(stream) cudaStreamDestroy(stream);
78
79 if(info) cudaFree(info);
80 if(d_A) cudaFree(d_A);
81 if(buffer) cudaFree(buffer);
82 if(ipiv) cudaFree(ipiv);
83 }
84
85protected:
86 int direct_solve(Mat<T>& X, Mat<T>&& B) override { return this->direct_solve(X, B); }
87
88 int direct_solve(Mat<T>&, const Mat<T>&) override;
89
90public:
91 FullMatCUDA(const uword in_rows, const uword in_cols)
92 : FullMat<T>(in_rows, in_cols) { acquire(); }
93
94 FullMatCUDA(const FullMatCUDA& other)
95 : FullMat<T>(other) { acquire(); }
96
97 FullMatCUDA(FullMatCUDA&&) noexcept = delete;
98 FullMatCUDA& operator=(const FullMatCUDA&) = delete;
99 FullMatCUDA& operator=(FullMatCUDA&&) noexcept = delete;
100
101 ~FullMatCUDA() override { release(); }
102
103 unique_ptr<MetaMat<T>> make_copy() override { return std::make_unique<FullMatCUDA>(*this); }
104};
105
106template<sp_d T> int FullMatCUDA<T>::direct_solve(Mat<T>& X, const Mat<T>& B) {
107 if constexpr(std::is_same_v<T, float>) {
108 // pure float
109 if(!this->factored) {
110 this->factored = true;
111 cudaMemcpyAsync(d_A, this->memptr(), sizeof(float) * this->n_elem, cudaMemcpyHostToDevice, stream);
112 cusolverDnSgetrf(handle, static_cast<int>(this->n_rows), static_cast<int>(this->n_cols), (float*)d_A, static_cast<int>(this->n_rows), (float*)buffer, ipiv, info);
113 }
114
115 const size_t byte_size = sizeof(float) * B.n_elem;
116
117 void* d_x = nullptr;
118 cudaMalloc(&d_x, byte_size);
119 cudaMemcpyAsync(d_x, B.memptr(), byte_size, cudaMemcpyHostToDevice, stream);
120 cusolverDnSgetrs(handle, CUBLAS_OP_N, static_cast<int>(this->n_rows), static_cast<int>(B.n_cols), (float*)d_A, static_cast<int>(this->n_rows), ipiv, (float*)d_x, static_cast<int>(this->n_rows), info);
121
122 X.set_size(arma::size(B));
123
124 cudaMemcpyAsync(X.memptr(), d_x, byte_size, cudaMemcpyDeviceToHost, stream);
125
126 cudaDeviceSynchronize();
127
128 if(d_x) cudaFree(d_x);
129 }
130 else if(Precision::MIXED == this->setting.precision) {
131 // mixed precision
132 if(!this->factored) {
133 this->factored = true;
134 this->s_memory = this->to_float();
135 cudaMemcpyAsync(d_A, this->s_memory.memptr(), sizeof(float) * this->s_memory.n_elem, cudaMemcpyHostToDevice, stream);
136 cusolverDnSgetrf(handle, static_cast<int>(this->n_rows), static_cast<int>(this->n_cols), (float*)d_A, static_cast<int>(this->n_rows), (float*)buffer, ipiv, info);
137 }
138
139 const size_t byte_size = sizeof(float) * B.n_elem;
140
141 void* d_x = nullptr;
142 cudaMalloc(&d_x, byte_size);
143
144 X = arma::zeros(B.n_rows, B.n_cols);
145
146 mat full_residual = B;
147
148 auto multiplier = norm(full_residual);
149
150 auto counter = 0u;
151 while(counter++ < this->setting.iterative_refinement) {
152 if(multiplier < this->setting.tolerance) break;
153
154 auto residual = conv_to<fmat>::from(full_residual / multiplier);
155
156 cudaMemcpyAsync(d_x, residual.memptr(), byte_size, cudaMemcpyHostToDevice, stream);
157 cusolverDnSgetrs(handle, CUBLAS_OP_N, static_cast<int>(this->n_rows), static_cast<int>(B.n_cols), (float*)d_A, static_cast<int>(this->n_rows), ipiv, (float*)d_x, static_cast<int>(this->n_rows), info);
158 cudaMemcpyAsync(residual.memptr(), d_x, byte_size, cudaMemcpyDeviceToHost, stream);
159
160 cudaDeviceSynchronize();
161
162 const mat incre = multiplier * conv_to<mat>::from(residual);
163
164 X += incre;
165
166 suanpan_debug("Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(full_residual -= this->operator*(incre)));
167 }
168
169 if(d_x) cudaFree(d_x);
170 }
171 else {
172 // pure double
173 if(!this->factored) {
174 this->factored = true;
175 cudaMemcpyAsync(d_A, this->memptr(), sizeof(double) * this->n_elem, cudaMemcpyHostToDevice, stream);
176 cusolverDnDgetrf(handle, static_cast<int>(this->n_rows), static_cast<int>(this->n_cols), (double*)d_A, static_cast<int>(this->n_rows), (double*)buffer, ipiv, info);
177 }
178
179 const size_t byte_size = sizeof(double) * B.n_elem;
180
181 void* d_x = nullptr;
182 cudaMalloc(&d_x, byte_size);
183 cudaMemcpyAsync(d_x, B.memptr(), byte_size, cudaMemcpyHostToDevice, stream);
184 cusolverDnDgetrs(handle, CUBLAS_OP_N, static_cast<int>(this->n_rows), static_cast<int>(B.n_cols), (double*)d_A, static_cast<int>(this->n_rows), ipiv, (double*)d_x, static_cast<int>(this->n_rows), info);
185
186 X.set_size(arma::size(B));
187
188 cudaMemcpyAsync(X.memptr(), d_x, byte_size, cudaMemcpyDeviceToHost, stream);
189
190 cudaDeviceSynchronize();
191
192 if(d_x) cudaFree(d_x);
193 }
194
195 return 0;
196}
197
198#endif
199
200#endif
201
A FullMatCUDA class that holds a dense matrix and solves linear systems with cuSOLVER LU factorization on the GPU.
A FullMat class that holds matrices.
Definition FullMat.hpp:35
unique_ptr< Material > make_copy(const shared_ptr< Material > &)
Definition Material.cpp:370
void info(const std::string_view format_str, const T &... args)
Definition suanPan.h:249
double norm(const vec &)
Definition tensor.cpp:370
#define suanpan_debug(...)
Definition suanPan.h:307