suanPan
Loading...
Searching...
No Matches
operator_times.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2017-2023 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
17
18#ifndef OPERATOR_TIMES_HPP
19#define OPERATOR_TIMES_HPP
20
21#include "FullMat.hpp"
22#include "SymmPackMat.hpp"
23
24template<sp_d T> unique_ptr<MetaMat<T>> operator*(const T value, const unique_ptr<MetaMat<T>>& M) {
25 if(nullptr == M) return nullptr;
26
27 auto N = M->make_copy();
28 N->operator*=(value);
29 return N;
30}
31
32template<sp_d T> unique_ptr<MetaMat<T>> operator*(const T value, const shared_ptr<MetaMat<T>>& M) {
33 if(nullptr == M) return nullptr;
34
35 auto N = M->make_copy();
36 N->operator*=(value);
37 return N;
38}
39
40template<sp_d T> const shared_ptr<MetaMat<T>>& operator*=(const shared_ptr<MetaMat<T>>& M, const T value) {
41 M->operator*=(value);
42 return M;
43}
44
45template<sp_d T> const unique_ptr<MetaMat<T>>& operator*=(const unique_ptr<MetaMat<T>>& M, const T value) {
46 M->operator*=(value);
47 return M;
48}
49
50template<sp_d T> const shared_ptr<MetaMat<T>>& operator+=(const shared_ptr<MetaMat<T>>& M, const shared_ptr<MetaMat<T>>& A) {
51 M->operator+=(A);
52 return M;
53}
54
55template<sp_d T> const unique_ptr<MetaMat<T>>& operator+=(const unique_ptr<MetaMat<T>>& M, const shared_ptr<MetaMat<T>>& A) {
56 M->operator+=(A);
57 return M;
58}
59
60template<sp_d DT, sp_i IT> const unique_ptr<MetaMat<DT>>& operator+=(const unique_ptr<MetaMat<DT>>& M, const triplet_form<DT, IT>& A) {
61 M->operator+=(A);
62 return M;
63}
64
65template<sp_d T> const shared_ptr<MetaMat<T>>& operator+=(const shared_ptr<MetaMat<T>>& M, unique_ptr<MetaMat<T>>&& A) {
66 M->operator+=(std::forward<unique_ptr<MetaMat<T>>>(A));
67 return M;
68}
69
70template<sp_d T> unique_ptr<MetaMat<T>> operator+(unique_ptr<MetaMat<T>>&& A, unique_ptr<MetaMat<T>>&& B) {
71 if(nullptr == A && nullptr == B) return nullptr;
72
73 if(nullptr != A) {
74 A->operator+=(std::forward<unique_ptr<MetaMat<T>>>(B));
75 return std::forward<unique_ptr<MetaMat<T>>>(A);
76 }
77
78 return std::forward<unique_ptr<MetaMat<T>>>(B);
79}
80
81template<sp_d T> unique_ptr<MetaMat<T>> operator+(const shared_ptr<MetaMat<T>>& A, unique_ptr<MetaMat<T>>&& B) {
82 B->operator+=(A);
83 return std::forward<unique_ptr<MetaMat<T>>>(B);
84}
85
86template<sp_d T> unique_ptr<MetaMat<T>> operator+(unique_ptr<MetaMat<T>>&& B, const shared_ptr<MetaMat<T>>& A) {
87 B->operator+=(A);
88 return std::forward<unique_ptr<MetaMat<T>>>(B);
89}
90
91template<sp_d T> const shared_ptr<MetaMat<T>>& operator-=(const shared_ptr<MetaMat<T>>& M, const shared_ptr<MetaMat<T>>& A) {
92 M->operator-=(A);
93 return M;
94}
95
96template<sp_d T> const shared_ptr<MetaMat<T>>& operator-=(const shared_ptr<MetaMat<T>>& M, unique_ptr<MetaMat<T>>&& A) {
97 M->operator-=(std::forward<unique_ptr<MetaMat<T>>>(A));
98 return M;
99}
100
101template<sp_d T> unique_ptr<MetaMat<T>> operator-(unique_ptr<MetaMat<T>>&& A, unique_ptr<MetaMat<T>>&& B) {
102 if(nullptr == A && nullptr == B) return nullptr;
103
104 if(nullptr != A) {
105 A->operator-=(std::forward<unique_ptr<MetaMat<T>>>(B));
106 return std::forward<unique_ptr<MetaMat<T>>>(A);
107 }
108
109 return std::forward<unique_ptr<MetaMat<T>>>(B);
110}
111
112template<sp_d T> unique_ptr<MetaMat<T>> operator-(const shared_ptr<MetaMat<T>>& A, unique_ptr<MetaMat<T>>&& B) {
113 B->operator-=(A);
114 return std::forward<unique_ptr<MetaMat<T>>>(B);
115}
116
117template<sp_d T> unique_ptr<MetaMat<T>> operator-(unique_ptr<MetaMat<T>>&& B, const shared_ptr<MetaMat<T>>& A) {
118 B->operator-=(A);
119 return std::forward<unique_ptr<MetaMat<T>>>(B);
120}
121
122template<sp_d T> Mat<T> operator*(const shared_ptr<MetaMat<T>>& M, const Mat<T>& A) {
123 if(nullptr == M) return nullptr;
124
125 return M->operator*(A);
126}
127
128template<sp_d T> Mat<T> operator*(const unique_ptr<MetaMat<T>>& M, const Mat<T>& A) {
129 if(nullptr == M) return nullptr;
130
131 return M->operator*(A);
132}
133
134template<sp_d T> Mat<T> operator*(const Mat<T>& A, const FullMat<T>& B) {
135 Mat<T> C(A.n_rows, A.n_cols);
136
137 constexpr auto TRAN = 'N';
138
139 const auto M = static_cast<int>(A.n_rows);
140 const auto N = static_cast<int>(B.n_cols);
141 const auto K = static_cast<int>(A.n_cols);
142 T ALPHA = 1.;
143 const auto LDA = M;
144 const auto LDB = K;
145 T BETA = 0.;
146 const auto LDC = M;
147
148 if(std::is_same_v<T, float>) {
149 using E = float;
150 arma_fortran(arma_sgemm)(&TRAN, &TRAN, &M, &N, &K, (E*)&ALPHA, (E*)A.memptr(), &LDA, (E*)B.memptr(), &LDB, (E*)&BETA, (E*)C.memptr(), &LDC);
151 }
152 else if(std::is_same_v<T, double>) {
153 using E = double;
154 arma_fortran(arma_dgemm)(&TRAN, &TRAN, &M, &N, &K, (E*)&ALPHA, (E*)A.memptr(), &LDA, (E*)B.memptr(), &LDB, (E*)&BETA, (E*)C.memptr(), &LDC);
155 }
156
157 return C;
158}
159
160template<sp_d T, sp_i IT> triplet_form<T, IT> operator*(const T value, const triplet_form<T, IT>& M) {
161 auto N = M;
162 N *= value;
163 return N;
164}
165
166template<sp_d T, sp_i IT> triplet_form<T, IT> operator*(const T value, triplet_form<T, IT>&& M) {
167 M *= value;
168 return M;
169}
170
171#endif // OPERATOR_TIMES_HPP
const T * memptr() const override
Definition: DenseMat.hpp:91
A FullMat class that holds matrices.
Definition: FullMat.hpp:35
A MetaMat class that holds matrices.
Definition: MetaMat.hpp:39
const uword n_cols
Definition: MetaMat.hpp:86
Definition: triplet_form.hpp:62
unique_ptr< MetaMat< T > > operator-(unique_ptr< MetaMat< T > > &&A, unique_ptr< MetaMat< T > > &&B)
Definition: operator_times.hpp:101
const shared_ptr< MetaMat< T > > & operator-=(const shared_ptr< MetaMat< T > > &M, const shared_ptr< MetaMat< T > > &A)
Definition: operator_times.hpp:91
const shared_ptr< MetaMat< T > > & operator*=(const shared_ptr< MetaMat< T > > &M, const T value)
Definition: operator_times.hpp:40
unique_ptr< MetaMat< T > > operator*(const T value, const unique_ptr< MetaMat< T > > &M)
Definition: operator_times.hpp:24
const shared_ptr< MetaMat< T > > & operator+=(const shared_ptr< MetaMat< T > > &M, const shared_ptr< MetaMat< T > > &A)
Definition: operator_times.hpp:50
unique_ptr< MetaMat< T > > operator+(unique_ptr< MetaMat< T > > &&A, unique_ptr< MetaMat< T > > &&B)
Definition: operator_times.hpp:70