30#ifndef FULLMATCUDA_HPP
31#define FULLMATCUDA_HPP
35#include <cuda_runtime.h>
36#include <cusolverDn.h>
40 cusolverDnHandle_t handle =
nullptr;
41 cudaStream_t stream =
nullptr;
46 void* buffer =
nullptr;
// acquire(): GPU-side setup for this matrix, called from the constructor (the function
// header is not visible in this listing). Creates a cuSOLVER dense handle and a CUDA
// stream, and binds the handle to the stream so the cuSOLVER calls are issued on `stream`.
// NOTE(review): none of the CUDA/cuSOLVER status codes returned in this block are checked.
49 cusolverDnCreate(&handle);
50 cudaStreamCreate(&stream);
51 cusolverDnSetStream(handle, stream);
// Single device int used as the getrf/getrs status word, initialised to 0.
53 cudaMalloc(&info,
sizeof(
int));
54 cudaMemset(info, 0,
sizeof(
int));
// Pivot indices written by getrf: one int per row.
55 cudaMalloc(&ipiv,
sizeof(
int) * this->n_rows);
// T == float: single-precision device copy of the matrix plus the Sgetrf workspace.
58 if constexpr(std::is_same_v<T, float>) {
59 cudaMalloc(&d_A,
sizeof(
float) * this->n_elem);
// NOTE(review): lda is passed as n_elem here, whereas every getrf/getrs call in
// direct_solve passes n_rows; for column-major storage the leading dimension is n_rows —
// confirm and align.
60 cusolverDnSgetrf_bufferSize(handle,
static_cast<int>(this->n_rows),
static_cast<int>(this->n_cols), (
float*)d_A,
static_cast<int>(this->n_elem), &bufferSize);
// bufferSize is reported as an element count, hence the sizeof(float) scaling.
61 cudaMalloc(&buffer,
sizeof(
float) * bufferSize);
// Branch header not visible in this listing. Presumably the mixed-precision path: the
// matrix is factored in float there (see the to_float()/Sgetrf code in direct_solve), so
// d_A and the workspace are sized for float — confirm.
64 cudaMalloc(&d_A,
sizeof(
float) * this->n_elem);
65 cusolverDnSgetrf_bufferSize(handle,
static_cast<int>(this->n_rows),
static_cast<int>(this->n_cols), (
float*)d_A,
static_cast<int>(this->n_elem), &bufferSize);
66 cudaMalloc(&buffer,
sizeof(
float) * bufferSize);
// Double-precision path (branch header not visible): device copy and Dgetrf workspace.
69 cudaMalloc(&d_A,
sizeof(
double) * this->n_elem);
70 cusolverDnDgetrf_bufferSize(handle,
static_cast<int>(this->n_rows),
static_cast<int>(this->n_cols), (
double*)d_A,
static_cast<int>(this->n_elem), &bufferSize);
71 cudaMalloc(&buffer,
sizeof(
double) * bufferSize);
// release(): tears down what acquire() created. It destroys the cuSOLVER handle and the
// stream, then frees the device buffers info, d_A, buffer and ipiv, each only if non-null.
// NOTE(review): being const, it cannot reset these pointers to nullptr afterwards. If it
// can run twice on the same object, or on two objects sharing the same pointers, it would
// double-free — confirm ownership. (The closing brace is not visible in this listing.)
75 void release()
const {
76 if(handle) cusolverDnDestroy(handle);
77 if(stream) cudaStreamDestroy(stream);
79 if(info) cudaFree(info);
80 if(d_A) cudaFree(d_A);
81 if(buffer) cudaFree(buffer);
82 if(ipiv) cudaFree(ipiv);
86 int direct_solve(Mat<T>& X, Mat<T>&& B)
override {
return this->direct_solve(X, B); }
88 int direct_solve(Mat<T>&,
const Mat<T>&)
override;
91 FullMatCUDA(
const uword in_rows,
const uword in_cols)
92 :
FullMat<
T>(in_rows, in_cols) { acquire(); }
103 unique_ptr<MetaMat<T>>
make_copy()
override {
return std::make_unique<FullMatCUDA>(*
this); }
// direct_solve body, T == float path (the function header is not visible in this listing).
// Factor once: upload the host matrix to d_A on `stream`, then LU-factor it in place with
// cusolverDnSgetrf, keeping the pivots in ipiv.
// NOTE(review): `factored` is set before the factorisation runs. `info` is never copied back
// in the visible lines, so a singular matrix (info > 0) is not reported here.
107 if constexpr(std::is_same_v<T, float>) {
109 if(!this->factored) {
110 this->factored =
true;
111 cudaMemcpyAsync(d_A, this->memptr(),
sizeof(
float) * this->n_elem, cudaMemcpyHostToDevice, stream);
112 cusolverDnSgetrf(handle,
static_cast<int>(this->n_rows),
static_cast<int>(this->n_cols), (
float*)d_A,
static_cast<int>(this->n_rows), (
float*)buffer, ipiv, info);
// Solve: copy B into a temporary device buffer and run Sgetrs (no transpose) for all
// B.n_cols right-hand sides. Sgetrs overwrites d_x with the solution, which is then
// copied back into X.
115 const size_t byte_size =
sizeof(float) * B.n_elem;
// NOTE(review): d_x is allocated and freed on every call; caching it, or using
// cudaMallocAsync on `stream`, would avoid a synchronous allocation per solve.
118 cudaMalloc(&d_x, byte_size);
119 cudaMemcpyAsync(d_x, B.memptr(), byte_size, cudaMemcpyHostToDevice, stream);
120 cusolverDnSgetrs(handle, CUBLAS_OP_N,
static_cast<int>(this->n_rows),
static_cast<int>(B.n_cols), (
float*)d_A,
static_cast<int>(this->n_rows), ipiv, (
float*)d_x,
static_cast<int>(this->n_rows), info);
122 X.set_size(arma::size(B));
124 cudaMemcpyAsync(X.memptr(), d_x, byte_size, cudaMemcpyDeviceToHost, stream);
// The async copy must complete before X is used. NOTE(review): cudaDeviceSynchronize waits
// for the whole device; cudaStreamSynchronize(stream) would be the narrower wait.
126 cudaDeviceSynchronize();
128 if(d_x) cudaFree(d_x);
// Mixed-precision path. The branch header is not visible in this listing; presumably this is
// T == double with mixed precision enabled — confirm. The matrix is factored once in single
// precision from a float copy kept in s_memory. The solution is then recovered by
// iterative refinement, with the residual held in double precision.
132 if(!this->factored) {
133 this->factored =
true;
134 this->s_memory = this->to_float();
135 cudaMemcpyAsync(d_A, this->s_memory.memptr(),
sizeof(
float) * this->s_memory.n_elem, cudaMemcpyHostToDevice, stream);
136 cusolverDnSgetrf(handle,
static_cast<int>(this->n_rows),
static_cast<int>(this->n_cols), (
float*)d_A,
static_cast<int>(this->n_rows), (
float*)buffer, ipiv, info);
// Sized with sizeof(float): the device RHS buffer holds the float-converted residual.
139 const size_t byte_size =
sizeof(float) * B.n_elem;
142 cudaMalloc(&d_x, byte_size);
// Refinement state. X starts at zero; presumably it accumulates each `incre` (that line is
// not visible here). full_residual starts as B and is reduced by A * incre after each pass.
// multiplier holds its norm.
// NOTE(review): this uses an unqualified `norm`, while the later update uses `arma::norm`.
// Presumably both resolve to Armadillo's norm for a `mat` — confirm that the project-level
// norm(const vec&) is not selected.
144 X = arma::zeros(B.n_rows, B.n_cols);
146 mat full_residual = B;
148 auto multiplier =
norm(full_residual);
// At most setting.iterative_refinement passes, stopping early once the residual norm drops
// below setting.tolerance.
151 while(counter++ < this->setting.iterative_refinement) {
152 if(multiplier < this->setting.tolerance)
break;
// The residual is scaled by 1/multiplier before narrowing to float, presumably to keep it
// well-scaled for single precision. It is solved on the device against the float LU
// factors, then scaled back by multiplier in double precision.
154 auto residual = conv_to<fmat>::from(full_residual / multiplier);
156 cudaMemcpyAsync(d_x, residual.memptr(), byte_size, cudaMemcpyHostToDevice, stream);
157 cusolverDnSgetrs(handle, CUBLAS_OP_N,
static_cast<int>(this->n_rows),
static_cast<int>(B.n_cols), (
float*)d_A,
static_cast<int>(this->n_rows), ipiv, (
float*)d_x,
static_cast<int>(this->n_rows), info);
158 cudaMemcpyAsync(residual.memptr(), d_x, byte_size, cudaMemcpyDeviceToHost, stream);
160 cudaDeviceSynchronize();
162 const mat incre = multiplier * conv_to<mat>::from(residual);
166 suanpan_debug(
"Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(full_residual -= this->
operator*(incre)));
// End of the mixed-precision path: release the per-call device RHS buffer.
169 if(d_x) cudaFree(d_x);
// Double-precision path (branch header not visible in this listing). It follows the same
// flow as the float path, but uses cusolverDnDgetrf/Dgetrs on a double copy of the matrix.
// NOTE(review): as in the other paths, `info` is never read back in the visible lines, and
// `factored` is set before the factorisation runs.
173 if(!this->factored) {
174 this->factored =
true;
175 cudaMemcpyAsync(d_A, this->memptr(),
sizeof(
double) * this->n_elem, cudaMemcpyHostToDevice, stream);
176 cusolverDnDgetrf(handle,
static_cast<int>(this->n_rows),
static_cast<int>(this->n_cols), (
double*)d_A,
static_cast<int>(this->n_rows), (
double*)buffer, ipiv, info);
179 const size_t byte_size =
sizeof(double) * B.n_elem;
// Temporary device buffer for B; Dgetrs overwrites it with the solution in place.
182 cudaMalloc(&d_x, byte_size);
183 cudaMemcpyAsync(d_x, B.memptr(), byte_size, cudaMemcpyHostToDevice, stream);
184 cusolverDnDgetrs(handle, CUBLAS_OP_N,
static_cast<int>(this->n_rows),
static_cast<int>(B.n_cols), (
double*)d_A,
static_cast<int>(this->n_rows), ipiv, (
double*)d_x,
static_cast<int>(this->n_rows), info);
186 X.set_size(arma::size(B));
188 cudaMemcpyAsync(X.memptr(), d_x, byte_size, cudaMemcpyDeviceToHost, stream);
// Block until the copy back into X has completed (waits for the whole device, not only `stream`).
190 cudaDeviceSynchronize();
192 if(d_x) cudaFree(d_x);
A FullMatCUDA class that holds a dense matrix and solves linear systems with it on the GPU using cuSOLVER LU factorization (getrf/getrs).
A FullMat class that holds matrices.
Definition FullMat.hpp:35
unique_ptr< Material > make_copy(const shared_ptr< Material > &)
Definition Material.cpp:370
void info(const std::string_view format_str, const T &... args)
Definition suanPan.h:249
double norm(const vec &)
Definition tensor.cpp:370
#define suanpan_debug(...)
Definition suanPan.h:307