18#ifndef OPERATOR_TIMES_HPP
19#define OPERATOR_TIMES_HPP
25 if(
nullptr ==
M)
return nullptr;
27 auto N =
M->make_copy();
33 if(
nullptr ==
M)
return nullptr;
35 auto N =
M->make_copy();
66 M->operator+=(std::forward<unique_ptr<MetaMat<T>>>(
A));
71 if(
nullptr ==
A &&
nullptr == B)
return nullptr;
74 A->operator+=(std::forward<unique_ptr<MetaMat<T>>>(B));
75 return std::forward<unique_ptr<MetaMat<T>>>(
A);
78 return std::forward<unique_ptr<MetaMat<T>>>(B);
83 return std::forward<unique_ptr<MetaMat<T>>>(B);
88 return std::forward<unique_ptr<MetaMat<T>>>(B);
97 M->operator-=(std::forward<unique_ptr<MetaMat<T>>>(
A));
102 if(
nullptr ==
A &&
nullptr == B)
return nullptr;
105 A->operator-=(std::forward<unique_ptr<MetaMat<T>>>(B));
106 return std::forward<unique_ptr<MetaMat<T>>>(
A);
109 return std::forward<unique_ptr<MetaMat<T>>>(B);
114 return std::forward<unique_ptr<MetaMat<T>>>(B);
119 return std::forward<unique_ptr<MetaMat<T>>>(B);
123 if(
nullptr ==
M)
return nullptr;
125 return M->operator*(
A);
129 if(
nullptr ==
M)
return nullptr;
131 return M->operator*(
A);
135 Mat<T> C(
A.n_rows,
A.n_cols);
137 constexpr auto TRAN =
'N';
139 const auto M =
static_cast<int>(
A.n_rows);
140 const auto N =
static_cast<int>(B.
n_cols);
141 const auto K =
static_cast<int>(
A.n_cols);
148 if(std::is_same_v<T, float>) {
150 arma_fortran(arma_sgemm)(&TRAN, &TRAN, &
M, &
N, &
K, (
E*)&ALPHA, (
E*)
A.memptr(), &LDA, (
E*)B.
memptr(), &LDB, (
E*)&BETA, (
E*)C.memptr(), &LDC);
152 else if(std::is_same_v<T, double>) {
154 arma_fortran(arma_dgemm)(&TRAN, &TRAN, &
M, &
N, &
K, (
E*)&ALPHA, (
E*)
A.memptr(), &LDA, (
E*)B.
memptr(), &LDB, (
E*)&BETA, (
E*)C.memptr(), &LDC);
const T * memptr() const override
Definition: DenseMat.hpp:91
A FullMat class that holds matrices.
Definition: FullMat.hpp:35
unique_ptr< MetaMat< T > > operator-(unique_ptr< MetaMat< T > > &&A, unique_ptr< MetaMat< T > > &&B)
Definition: operator_times.hpp:101
const shared_ptr< MetaMat< T > > & operator-=(const shared_ptr< MetaMat< T > > &M, const shared_ptr< MetaMat< T > > &A)
Definition: operator_times.hpp:91
const shared_ptr< MetaMat< T > > & operator*=(const shared_ptr< MetaMat< T > > &M, const T value)
Definition: operator_times.hpp:40
unique_ptr< MetaMat< T > > operator*(const T value, const unique_ptr< MetaMat< T > > &M)
Definition: operator_times.hpp:24
const shared_ptr< MetaMat< T > > & operator+=(const shared_ptr< MetaMat< T > > &M, const shared_ptr< MetaMat< T > > &A)
Definition: operator_times.hpp:50
unique_ptr< MetaMat< T > > operator+(unique_ptr< MetaMat< T > > &&A, unique_ptr< MetaMat< T > > &&B)
Definition: operator_times.hpp:70