30#ifndef FULLMATCUDA_HPP
31#define FULLMATCUDA_HPP
35#include <cuda_runtime.h>
36#include <cusolverDn.h>
// cuSOLVER dense handle; created with cusolverDnCreate() and destroyed with cusolverDnDestroy() elsewhere in this file.
40 cusolverDnHandle_t handle =
nullptr;
// CUDA stream bound to the handle via cusolverDnSetStream(); also passed to the cudaMemcpyAsync() calls in this file.
41 cudaStream_t stream =
nullptr;
// Device workspace: sized from cusolverDn{S,D}getrf_bufferSize() and handed to cusolverDn{S,D}getrf() as scratch space.
46 void* buffer =
nullptr;
// Overrides a base-class virtual and returns a unique_ptr<MetaMat<T>>.
// NOTE(review): the definition is not part of this listing; presumably it clones this matrix — confirm.
59 unique_ptr<
MetaMat<
T>> make_copy() override;
// Overload taking the second matrix by rvalue reference; returns int (presumably a status code — confirm).
61 int direct_solve(Mat<
T>&, Mat<
T>&&) override;
// Overload taking the second matrix by const reference; returns int (presumably a status code — confirm).
62 int direct_solve(Mat<
T>&, const Mat<
T>&) override;
// Create the cuSOLVER dense handle and a CUDA stream, then make the handle issue its work on that stream.
// NOTE(review): none of the CUDA/cuSOLVER return codes in this section are checked.
66 cusolverDnCreate(&handle);
67 cudaStreamCreate(&stream);
68 cusolverDnSetStream(handle, stream);
// Device allocations: one zero-initialised int passed to the solver as its status output,
// and one int per matrix row passed to the solver as the pivot array.
70 cudaMalloc(&info,
sizeof(
int));
71 cudaMemset(info, 0,
sizeof(
int));
72 cudaMalloc(&ipiv,
sizeof(
int) * this->n_rows);
// Single-precision device storage is selected when T is float or the MIXED precision setting is active.
74 if(
int bufferSize = 0; std::is_same_v<T, float> || Precision::MIXED == this->setting.precision) {
75 cudaMalloc(&d_A,
sizeof(
float) * this->n_elem);
// The fifth argument is the leading dimension of A. Use n_rows, exactly as the getrf/getrs calls do;
// the previous int(n_elem) was inconsistent with them and can overflow int for large matrices.
76 cusolverDnSgetrf_bufferSize(handle,
int(this->n_rows),
int(this->n_cols), (
float*)d_A,
int(this->n_rows), &bufferSize);
// Workspace for the single-precision factorisation, bufferSize elements long.
77 cudaMalloc(&buffer,
sizeof(
float) * bufferSize);
// NOTE(review): the lines separating the two precision branches are not visible in this listing.
// Double-precision counterpart of the allocations above.
80 cudaMalloc(&d_A,
sizeof(
double) * this->n_elem);
// Leading dimension is n_rows here as well (see the note on the single-precision query).
81 cusolverDnDgetrf_bufferSize(handle,
int(this->n_rows),
int(this->n_cols), (
double*)d_A,
int(this->n_rows), &bufferSize);
// Workspace for the double-precision factorisation, bufferSize elements long.
82 cudaMalloc(&buffer,
sizeof(
double) * bufferSize);
87 if(handle) cusolverDnDestroy(handle);
88 if(stream) cudaStreamDestroy(stream);
90 if(info) cudaFree(info);
91 if(d_A) cudaFree(d_A);
92 if(buffer) cudaFree(buffer);
93 if(ipiv) cudaFree(ipiv);
// Forwards in_rows/in_cols to the FullMat<T> base constructor, then calls acquire().
97 :
FullMat<
T>(in_rows, in_cols) { acquire(); }
// Path taken when T is float: factorise and solve entirely in single precision.
109 if(std::is_same_v<T, float>) {
// Upload the matrix and LU-factorise it (getrf) only while the factored flag is unset;
// once the flag is set, later calls skip straight to the solve below.
111 if(!this->factored) {
112 cudaMemcpyAsync(d_A, this->memptr(),
sizeof(
float) * this->n_elem, cudaMemcpyHostToDevice, stream);
113 cusolverDnSgetrf(handle,
int(this->n_rows),
int(this->n_cols), (
float*)d_A,
int(this->n_rows), (
float*)buffer, ipiv, info);
115 this->factored =
true;
// Size in bytes of B when stored as float.
118 const size_t byte_size =
sizeof(float) * B.n_elem;
// NOTE(review): the declaration of d_x is not among the lines visible in this listing.
// A device buffer is allocated on every call and freed at the end of this path.
121 cudaMalloc(&d_x, byte_size);
// Copy the right-hand side to the device and solve with the stored factors (no transpose).
// d_x is passed as the right-hand side array of getrs and is copied back as the result afterwards.
122 cudaMemcpyAsync(d_x, B.memptr(), byte_size, cudaMemcpyHostToDevice, stream);
123 cusolverDnSgetrs(handle, CUBLAS_OP_N,
int(this->n_rows),
int(B.n_cols), (
float*)d_A,
int(this->n_rows), ipiv, (
float*)d_x,
int(this->n_rows), info);
// Give X the same dimensions as B before copying the device result into it.
125 X.set_size(arma::size(B));
127 cudaMemcpyAsync(X.memptr(), d_x, byte_size, cudaMemcpyDeviceToHost, stream);
// Wait for the device to finish before the host uses X or frees d_x.
// NOTE(review): info is never copied back or inspected in the visible lines.
129 cudaDeviceSynchronize();
131 if(d_x) cudaFree(d_x);
// Mixed-precision path: factorise a float copy of the matrix on the device, then refine the solution
// iteratively using residuals evaluated on the host.
// NOTE(review): the condition that selects this path is not visible in this listing.
// Factorise once: obtain s_memory from to_float() (uploaded below as float data), copy it to the device
// and run the single-precision LU (getrf). Skipped while the factored flag is set.
135 if(!this->factored) {
136 this->s_memory = this->to_float();
138 cudaMemcpyAsync(d_A, this->s_memory.memptr(),
sizeof(
float) * this->s_memory.n_elem, cudaMemcpyHostToDevice, stream);
139 cusolverDnSgetrf(handle,
int(this->n_rows),
int(this->n_cols), (
float*)d_A,
int(this->n_rows), (
float*)buffer, ipiv, info);
141 this->factored =
true;
// Size in bytes of B when stored as float.
144 const size_t byte_size =
sizeof(float) * B.n_elem;
// NOTE(review): the declaration of d_x is not among the lines visible in this listing.
147 cudaMalloc(&d_x, byte_size);
// Start from a zero solution and treat B as the initial residual.
149 X = arma::zeros(B.n_rows, B.n_cols);
151 mat full_residual = B;
// multiplier holds the norm of the current residual; it scales the residual before the float solve
// and is compared against the tolerance to stop early.
153 auto multiplier =
norm(full_residual);
// NOTE(review): the declaration of counter is not among the lines visible in this listing.
// At most setting.iterative_refinement passes are performed.
156 while(counter++ < this->setting.iterative_refinement) {
157 if(multiplier < this->setting.tolerance)
break;
// Normalise the residual by its norm and convert it to float for the device solve.
159 auto residual = conv_to<fmat>::from(full_residual / multiplier);
// Round trip: upload the scaled residual, solve with the stored factors, download the result in place.
161 cudaMemcpyAsync(d_x, residual.memptr(), byte_size, cudaMemcpyHostToDevice, stream);
162 cusolverDnSgetrs(handle, CUBLAS_OP_N,
int(this->n_rows),
int(B.n_cols), (
float*)d_A,
int(this->n_rows), ipiv, (
float*)d_x,
int(this->n_rows), info);
163 cudaMemcpyAsync(residual.memptr(), d_x, byte_size, cudaMemcpyDeviceToHost, stream);
// Wait for the device before the host reads residual.
165 cudaDeviceSynchronize();
// Undo the normalisation to obtain the correction in working precision.
167 const mat incre = multiplier * conv_to<mat>::from(residual);
// NOTE(review): no visible line adds incre to X — confirm against the complete file.
// Update the residual and its norm as ordinary statements. Previously this happened only as a side
// effect inside the suanpan_debug argument list, so the refinement state depended on how that macro
// expands; now the macro merely reports the value.
multiplier = arma::norm(full_residual -= this->operator*(incre));
suanpan_debug("Mixed precision algorithm multiplier: {:.5E}.\n", multiplier);
174 if(d_x) cudaFree(d_x);
// Double-precision path: factorise and solve entirely in double.
// NOTE(review): the condition that selects this path is not visible in this listing.
// Upload the matrix and LU-factorise it (getrf) only while the factored flag is unset.
178 if(!this->factored) {
179 cudaMemcpyAsync(d_A, this->memptr(),
sizeof(
double) * this->n_elem, cudaMemcpyHostToDevice, stream);
180 cusolverDnDgetrf(handle,
int(this->n_rows),
int(this->n_cols), (
double*)d_A,
int(this->n_rows), (
double*)buffer, ipiv, info);
182 this->factored =
true;
// Size in bytes of B when stored as double.
185 const size_t byte_size =
sizeof(double) * B.n_elem;
// NOTE(review): the declaration of d_x is not among the lines visible in this listing.
// A device buffer is allocated on every call and freed at the end of this path.
188 cudaMalloc(&d_x, byte_size);
// Copy the right-hand side to the device and solve with the stored factors (no transpose).
// d_x is passed as the right-hand side array of getrs and is copied back as the result afterwards.
189 cudaMemcpyAsync(d_x, B.memptr(), byte_size, cudaMemcpyHostToDevice, stream);
190 cusolverDnDgetrs(handle, CUBLAS_OP_N,
int(this->n_rows),
int(B.n_cols), (
double*)d_A,
int(this->n_rows), ipiv, (
double*)d_x,
int(this->n_rows), info);
// Give X the same dimensions as B before copying the device result into it.
192 X.set_size(arma::size(B));
194 cudaMemcpyAsync(X.memptr(), d_x, byte_size, cudaMemcpyDeviceToHost, stream);
// Wait for the device to finish before the host uses X or frees d_x.
// NOTE(review): info is never copied back or inspected in the visible lines.
196 cudaDeviceSynchronize();
198 if(d_x) cudaFree(d_x);
A FullMatCUDA class that holds matrices.
A FullMat class that holds matrices.
Definition: FullMat.hpp:35
Definition: suanPan.h:318
void info(const std::string_view format_str, const T &... args)
Definition: suanPan.h:237
double norm(const vec &)
Definition: tensor.cpp:302
#define suanpan_debug(...)
Definition: suanPan.h:295