suanPan
Loading...
Searching...
No Matches
FullMatCUDA.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2017-2023 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
29// ReSharper disable CppCStyleCast
30#ifndef FULLMATCUDA_HPP
31#define FULLMATCUDA_HPP
32
33#ifdef SUANPAN_CUDA
34
35#include <cuda_runtime.h>
36#include <cusolverDn.h>
37#include "FullMat.hpp"
38
39template<sp_d T> class FullMatCUDA final : public FullMat<T> {
40 cusolverDnHandle_t handle = nullptr;
41 cudaStream_t stream = nullptr;
42
43 int* info = nullptr;
44 int* ipiv = nullptr;
45 void* d_A = nullptr;
46 void* buffer = nullptr;
47
48 void acquire();
49 void release() const;
50
51public:
52 FullMatCUDA(uword, uword);
54 FullMatCUDA(FullMatCUDA&&) noexcept = delete;
55 FullMatCUDA& operator=(const FullMatCUDA&) = delete;
56 FullMatCUDA& operator=(FullMatCUDA&&) noexcept = delete;
57 ~FullMatCUDA() override;
58
59 unique_ptr<MetaMat<T>> make_copy() override;
60
61 int direct_solve(Mat<T>&, Mat<T>&&) override;
62 int direct_solve(Mat<T>&, const Mat<T>&) override;
63};
64
65template<sp_d T> void FullMatCUDA<T>::acquire() {
66 cusolverDnCreate(&handle);
67 cudaStreamCreate(&stream);
68 cusolverDnSetStream(handle, stream);
69
70 cudaMalloc(&info, sizeof(int));
71 cudaMemset(info, 0, sizeof(int));
72 cudaMalloc(&ipiv, sizeof(int) * this->n_rows);
73
74 if(int bufferSize = 0; std::is_same_v<T, float> || Precision::MIXED == this->setting.precision) {
75 cudaMalloc(&d_A, sizeof(float) * this->n_elem);
76 cusolverDnSgetrf_bufferSize(handle, int(this->n_rows), int(this->n_cols), (float*)d_A, int(this->n_elem), &bufferSize);
77 cudaMalloc(&buffer, sizeof(float) * bufferSize);
78 }
79 else {
80 cudaMalloc(&d_A, sizeof(double) * this->n_elem);
81 cusolverDnDgetrf_bufferSize(handle, int(this->n_rows), int(this->n_cols), (double*)d_A, int(this->n_elem), &bufferSize);
82 cudaMalloc(&buffer, sizeof(double) * bufferSize);
83 }
84}
85
86template<sp_d T> void FullMatCUDA<T>::release() const {
87 if(handle) cusolverDnDestroy(handle);
88 if(stream) cudaStreamDestroy(stream);
89
90 if(info) cudaFree(info);
91 if(d_A) cudaFree(d_A);
92 if(buffer) cudaFree(buffer);
93 if(ipiv) cudaFree(ipiv);
94}
95
96template<sp_d T> FullMatCUDA<T>::FullMatCUDA(const uword in_rows, const uword in_cols)
97 : FullMat<T>(in_rows, in_cols) { acquire(); }
98
99template<sp_d T> FullMatCUDA<T>::FullMatCUDA(const FullMatCUDA& other)
100 : FullMat<T>(other) { acquire(); }
101
102template<sp_d T> FullMatCUDA<T>::~FullMatCUDA() { release(); }
103
104template<sp_d T> unique_ptr<MetaMat<T>> FullMatCUDA<T>::make_copy() { return make_unique<FullMatCUDA<T>>(*this); }
105
106template<sp_d T> int FullMatCUDA<T>::direct_solve(Mat<T>& X, Mat<T>&& B) { return direct_solve(X, B); }
107
108template<sp_d T> int FullMatCUDA<T>::direct_solve(Mat<T>& X, const Mat<T>& B) {
109 if(std::is_same_v<T, float>) {
110 // pure float
111 if(!this->factored) {
112 cudaMemcpyAsync(d_A, this->memptr(), sizeof(float) * this->n_elem, cudaMemcpyHostToDevice, stream);
113 cusolverDnSgetrf(handle, int(this->n_rows), int(this->n_cols), (float*)d_A, int(this->n_rows), (float*)buffer, ipiv, info);
114
115 this->factored = true;
116 }
117
118 const size_t byte_size = sizeof(float) * B.n_elem;
119
120 void* d_x = nullptr;
121 cudaMalloc(&d_x, byte_size);
122 cudaMemcpyAsync(d_x, B.memptr(), byte_size, cudaMemcpyHostToDevice, stream);
123 cusolverDnSgetrs(handle, CUBLAS_OP_N, int(this->n_rows), int(B.n_cols), (float*)d_A, int(this->n_rows), ipiv, (float*)d_x, int(this->n_rows), info);
124
125 X.set_size(arma::size(B));
126
127 cudaMemcpyAsync(X.memptr(), d_x, byte_size, cudaMemcpyDeviceToHost, stream);
128
129 cudaDeviceSynchronize();
130
131 if(d_x) cudaFree(d_x);
132 }
133 else if(Precision::MIXED == this->setting.precision) {
134 // mixed precision
135 if(!this->factored) {
136 this->s_memory = this->to_float();
137
138 cudaMemcpyAsync(d_A, this->s_memory.memptr(), sizeof(float) * this->s_memory.n_elem, cudaMemcpyHostToDevice, stream);
139 cusolverDnSgetrf(handle, int(this->n_rows), int(this->n_cols), (float*)d_A, int(this->n_rows), (float*)buffer, ipiv, info);
140
141 this->factored = true;
142 }
143
144 const size_t byte_size = sizeof(float) * B.n_elem;
145
146 void* d_x = nullptr;
147 cudaMalloc(&d_x, byte_size);
148
149 X = arma::zeros(B.n_rows, B.n_cols);
150
151 mat full_residual = B;
152
153 auto multiplier = norm(full_residual);
154
155 auto counter = 0u;
156 while(counter++ < this->setting.iterative_refinement) {
157 if(multiplier < this->setting.tolerance) break;
158
159 auto residual = conv_to<fmat>::from(full_residual / multiplier);
160
161 cudaMemcpyAsync(d_x, residual.memptr(), byte_size, cudaMemcpyHostToDevice, stream);
162 cusolverDnSgetrs(handle, CUBLAS_OP_N, int(this->n_rows), int(B.n_cols), (float*)d_A, int(this->n_rows), ipiv, (float*)d_x, int(this->n_rows), info);
163 cudaMemcpyAsync(residual.memptr(), d_x, byte_size, cudaMemcpyDeviceToHost, stream);
164
165 cudaDeviceSynchronize();
166
167 const mat incre = multiplier * conv_to<mat>::from(residual);
168
169 X += incre;
170
171 suanpan_debug("Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(full_residual -= this->operator*(incre)));
172 }
173
174 if(d_x) cudaFree(d_x);
175 }
176 else {
177 // pure double
178 if(!this->factored) {
179 cudaMemcpyAsync(d_A, this->memptr(), sizeof(double) * this->n_elem, cudaMemcpyHostToDevice, stream);
180 cusolverDnDgetrf(handle, int(this->n_rows), int(this->n_cols), (double*)d_A, int(this->n_rows), (double*)buffer, ipiv, info);
181
182 this->factored = true;
183 }
184
185 const size_t byte_size = sizeof(double) * B.n_elem;
186
187 void* d_x = nullptr;
188 cudaMalloc(&d_x, byte_size);
189 cudaMemcpyAsync(d_x, B.memptr(), byte_size, cudaMemcpyHostToDevice, stream);
190 cusolverDnDgetrs(handle, CUBLAS_OP_N, int(this->n_rows), int(B.n_cols), (double*)d_A, int(this->n_rows), ipiv, (double*)d_x, int(this->n_rows), info);
191
192 X.set_size(arma::size(B));
193
194 cudaMemcpyAsync(X.memptr(), d_x, byte_size, cudaMemcpyDeviceToHost, stream);
195
196 cudaDeviceSynchronize();
197
198 if(d_x) cudaFree(d_x);
199 }
200
201 return 0;
202}
203
204#endif
205
206#endif
207
A FullMatCUDA class that holds matrices.
A FullMat class that holds matrices.
Definition: FullMat.hpp:35
A MetaMat class that holds matrices.
Definition: MetaMat.hpp:39
Definition: suanPan.h:318
void info(const std::string_view format_str, const T &... args)
Definition: suanPan.h:237
double norm(const vec &)
Definition: tensor.cpp:302
#define suanpan_debug(...)
Definition: suanPan.h:295