Basix
Loading...
Searching...
No Matches
math.h
1// Copyright (C) 2021-2024 Igor Baratta and Garth N. Wells
2//
3// This file is part of DOLFINx (https://www.fenicsproject.org)
4//
5// SPDX-License-Identifier: LGPL-3.0-or-later
6
7#pragma once
8
#include "mdspan.hpp"
#include "types.h"
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <concepts>
#include <limits>
#include <span>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
20
// Fortran BLAS/LAPACK entry points, called through the trailing-underscore
// symbol convention. Fortran passes every argument by reference, hence all
// scalars are pointers. The `s` routines are single precision and the `d`
// routines double precision.
extern "C"
{
  // xSYEVD: eigenvalues (and optionally eigenvectors) of a real symmetric
  // matrix, divide-and-conquer algorithm.
  void ssyevd_(char* jobz, char* uplo, int* n, float* a, int* lda, float* w,
               float* work, int* lwork, int* iwork, int* liwork, int* info);
  void dsyevd_(char* jobz, char* uplo, int* n, double* a, int* lda, double* w,
               double* work, int* lwork, int* iwork, int* liwork, int* info);

  // xGESV: solve A X = B for a general square A via LU with partial
  // pivoting.
  void sgesv_(int* n, int* nrhs, float* a, int* lda, int* ipiv, float* b,
              int* ldb, int* info);
  void dgesv_(int* n, int* nrhs, double* a, int* lda, int* ipiv, double* b,
              int* ldb, int* info);

  // xGEMM: C = alpha op(A) op(B) + beta C.
  void sgemm_(char* transa, char* transb, int* m, int* n, int* k,
              float* alpha, float* a, int* lda, float* b, int* ldb,
              float* beta, float* c, int* ldc);
  void dgemm_(char* transa, char* transb, int* m, int* n, int* k,
              double* alpha, double* a, int* lda, double* b, int* ldb,
              double* beta, double* c, int* ldc);

  // xGETRF: LU factorisation with partial pivoting of a general matrix.
  // NOTE(review): LAPACK defines xGETRF as a SUBROUTINE; the `int` return
  // declared here is kept for compatibility but must never be read.
  int sgetrf_(const int* m, const int* n, float* a, const int* lda,
              int* ipiv, int* info);
  int dgetrf_(const int* m, const int* n, double* a, const int* lda,
              int* ipiv, int* info);
}
45
50namespace basix::math
51{
52namespace impl
53{
59template <std::floating_point T>
60void dot_blas(std::span<const T> A, std::array<std::size_t, 2> Ashape,
61 std::span<const T> B, std::array<std::size_t, 2> Bshape,
62 std::span<T> C, T alpha = 1, T beta = 0)
63{
64 static_assert(std::is_same_v<T, float> or std::is_same_v<T, double>);
65
66 assert(Ashape[1] == Bshape[0]);
67 assert(C.size() == Ashape[0] * Bshape[1]);
68
69 int M = Ashape[0];
70 int N = Bshape[1];
71 int K = Ashape[1];
72
73 int lda = K;
74 int ldb = N;
75 int ldc = N;
76 char trans = 'N';
77 if constexpr (std::is_same_v<T, float>)
78 {
79 sgemm_(&trans, &trans, &N, &M, &K, &alpha, const_cast<T*>(B.data()), &ldb,
80 const_cast<T*>(A.data()), &lda, &beta, C.data(), &ldc);
81 }
82 else if constexpr (std::is_same_v<T, double>)
83 {
84 dgemm_(&trans, &trans, &N, &M, &K, &alpha, const_cast<T*>(B.data()), &ldb,
85 const_cast<T*>(A.data()), &lda, &beta, C.data(), &ldc);
86 }
87}
88
89} // namespace impl
90
/// @brief Compute the outer product of the vectors u and v.
/// @param[in] u First vector, of length m.
/// @param[in] v Second vector, of length n.
/// @return Row-major data of the `m x n` matrix whose (i, j) entry is
/// `u[i] * v[j]`, together with the shape `{m, n}`.
template <typename U, typename V>
std::pair<std::vector<typename U::value_type>, std::array<std::size_t, 2>>
outer(const U& u, const V& v)
{
  const std::size_t rows = u.size();
  const std::size_t cols = v.size();

  std::vector<typename U::value_type> data;
  data.reserve(rows * cols);
  for (std::size_t r = 0; r < rows; ++r)
  {
    for (std::size_t c = 0; c < cols; ++c)
      data.push_back(u[r] * v[c]);
  }

  return {std::move(data), {rows, cols}};
}
105
/// @brief Compute the cross product u x v of two 3-vectors.
/// @param[in] u First vector (must have size 3).
/// @param[in] v Second vector (must have size 3).
/// @return The components of u x v, stored with the value type of `U`.
template <typename U, typename V>
std::array<typename U::value_type, 3> cross(const U& u, const V& v)
{
  assert(u.size() == 3);
  assert(v.size() == 3);

  // Component k is u[k+1] v[k+2] - u[k+2] v[k+1], with indices taken mod 3.
  const auto x = u[1] * v[2] - u[2] * v[1];
  const auto y = u[2] * v[0] - u[0] * v[2];
  const auto z = u[0] * v[1] - u[1] * v[0];
  return {x, y, z};
}
118
126template <std::floating_point T>
127std::pair<std::vector<T>, std::vector<T>> eigh(std::span<const T> A,
128 std::size_t n)
129{
130 // Copy A
131 std::vector<T> M(A.begin(), A.end());
132
133 // Allocate storage for eigenvalues
134 std::vector<T> w(n, 0);
135
136 int N = n;
137 char jobz = 'V'; // Compute eigenvalues and eigenvectors
138 char uplo = 'L'; // Lower
139 int ldA = n;
140 int lwork = -1;
141 int liwork = -1;
142 int info;
143 std::vector<T> work(1);
144 std::vector<int> iwork(1);
145
146 // Query optimal workspace size
147 if constexpr (std::is_same_v<T, float>)
148 {
149 ssyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
150 iwork.data(), &liwork, &info);
151 }
152 else if constexpr (std::is_same_v<T, double>)
153 {
154 dsyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
155 iwork.data(), &liwork, &info);
156 }
157
158 if (info != 0)
159 throw std::runtime_error("Could not find workspace size for syevd.");
160
161 // Solve eigen problem
162 work.resize(work[0]);
163 iwork.resize(iwork[0]);
164 lwork = work.size();
165 liwork = iwork.size();
166 if constexpr (std::is_same_v<T, float>)
167 {
168 ssyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
169 iwork.data(), &liwork, &info);
170 }
171 else if constexpr (std::is_same_v<T, double>)
172 {
173 dsyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
174 iwork.data(), &liwork, &info);
175 }
176 if (info != 0)
177 throw std::runtime_error("Eigenvalue computation did not converge.");
178
179 return {std::move(w), std::move(M)};
180}
181
186template <std::floating_point T>
187std::vector<T> solve(md::mdspan<const T, md::dextents<std::size_t, 2>> A,
188 md::mdspan<const T, md::dextents<std::size_t, 2>> B)
189{
190 // Copy A and B to column-major storage
191 mdex::mdarray<T, md::dextents<std::size_t, 2>, md::layout_left> _A(
192 A.extents()),
193 _B(B.extents());
194 for (std::size_t i = 0; i < A.extent(0); ++i)
195 for (std::size_t j = 0; j < A.extent(1); ++j)
196 _A(i, j) = A(i, j);
197 for (std::size_t i = 0; i < B.extent(0); ++i)
198 for (std::size_t j = 0; j < B.extent(1); ++j)
199 _B(i, j) = B(i, j);
200
201 int N = _A.extent(0);
202 int nrhs = _B.extent(1);
203 int lda = _A.extent(0);
204 int ldb = _B.extent(0);
205
206 // Pivot indices that define the permutation matrix for the LU solver
207 std::vector<int> piv(N);
208 int info;
209 if constexpr (std::is_same_v<T, float>)
210 sgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), _B.data(), &ldb, &info);
211 else if constexpr (std::is_same_v<T, double>)
212 dgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), _B.data(), &ldb, &info);
213 if (info != 0)
214 throw std::runtime_error("Call to dgesv failed: " + std::to_string(info));
215
216 // Copy result to row-major storage
217 std::vector<T> rb(_B.extent(0) * _B.extent(1));
218 md::mdspan<T, md::dextents<std::size_t, 2>> r(rb.data(), _B.extents());
219 for (std::size_t i = 0; i < _B.extent(0); ++i)
220 for (std::size_t j = 0; j < _B.extent(1); ++j)
221 r(i, j) = _B(i, j);
222
223 return rb;
224}
225
229template <std::floating_point T>
230bool is_singular(md::mdspan<const T, md::dextents<std::size_t, 2>> A)
231{
232 // Copy to column major matrix
233 mdex::mdarray<T, md::dextents<std::size_t, 2>, md::layout_left> _A(
234 A.extents());
235 for (std::size_t i = 0; i < A.extent(0); ++i)
236 for (std::size_t j = 0; j < A.extent(1); ++j)
237 _A(i, j) = A(i, j);
238
239 std::vector<T> B(A.extent(1), 1);
240 int N = _A.extent(0);
241 int nrhs = 1;
242 int lda = _A.extent(0);
243 int ldb = B.size();
244
245 // Pivot indices that define the permutation matrix for the LU solver
246 std::vector<int> piv(N);
247 int info;
248 if constexpr (std::is_same_v<T, float>)
249 sgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), B.data(), &ldb, &info);
250 else if constexpr (std::is_same_v<T, double>)
251 dgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), B.data(), &ldb, &info);
252
253 if (info < 0)
254 {
255 throw std::runtime_error("dgesv failed due to invalid value: "
256 + std::to_string(info));
257 }
258 else if (info > 0)
259 return true;
260 else
261 return false;
262}
263
269template <std::floating_point T>
270std::vector<std::size_t>
271transpose_lu(std::pair<std::vector<T>, std::array<std::size_t, 2>>& A)
272{
273 std::size_t dim = A.second[0];
274 assert(dim == A.second[1]);
275 int N = dim;
276 int info;
277 std::vector<int> lu_perm(dim);
278
279 // Comput LU decomposition of M
280 if constexpr (std::is_same_v<T, float>)
281 sgetrf_(&N, &N, A.first.data(), &N, lu_perm.data(), &info);
282 else if constexpr (std::is_same_v<T, double>)
283 dgetrf_(&N, &N, A.first.data(), &N, lu_perm.data(), &info);
284
285 if (info != 0)
286 {
287 throw std::runtime_error("LU decomposition failed: "
288 + std::to_string(info));
289 }
290
291 std::vector<std::size_t> perm(dim);
292 for (std::size_t i = 0; i < dim; ++i)
293 perm[i] = static_cast<std::size_t>(lu_perm[i] - 1);
294
295 return perm;
296}
297
305template <typename U, typename V, typename W>
306void dot(const U& A, const V& B, W&& C,
307 typename std::decay_t<U>::value_type alpha = 1,
308 typename std::decay_t<U>::value_type beta = 0)
309{
310 using T = typename std::decay_t<U>::value_type;
311
312 assert(A.extent(1) == B.extent(0));
313 assert(C.extent(0) == A.extent(0));
314 assert(C.extent(1) == B.extent(1));
315 if (A.extent(0) * B.extent(1) * A.extent(1) < 256)
316 {
317 for (std::size_t i = 0; i < A.extent(0); ++i)
318 {
319 for (std::size_t j = 0; j < B.extent(1); ++j)
320 {
321 T C0 = C(i, j);
322 C(i, j) = 0;
323 T& _C = C(i, j);
324 for (std::size_t k = 0; k < A.extent(1); ++k)
325 _C += A(i, k) * B(k, j);
326 _C = alpha * _C + beta * C0;
327 }
328 }
329 }
330 else
331 {
332 static_assert(std::is_same_v<typename std::decay_t<U>::layout_type,
333 md::layout_right>);
334 static_assert(std::is_same_v<typename std::decay_t<V>::layout_type,
335 md::layout_right>);
336 static_assert(std::is_same_v<typename std::decay_t<W>::layout_type,
337 md::layout_right>);
338 static_assert(std::is_same_v<typename std::decay_t<V>::value_type, T>);
339 static_assert(std::is_same_v<typename std::decay_t<W>::value_type, T>);
340 impl::dot_blas<T>(
341 std::span(A.data_handle(), A.size()), {A.extent(0), A.extent(1)},
342 std::span(B.data_handle(), B.size()), {B.extent(0), B.extent(1)},
343 std::span(C.data_handle(), C.size()), alpha, beta);
344 }
345}
346
350template <std::floating_point T>
351std::vector<T> eye(std::size_t n)
352{
353 std::vector<T> I(n * n, 0);
354 md::mdspan<T, md::dextents<std::size_t, 2>> Iview(I.data(), n, n);
355 for (std::size_t i = 0; i < n; ++i)
356 Iview(i, i) = 1;
357 return I;
358}
359
364template <std::floating_point T>
365void orthogonalise(md::mdspan<T, md::dextents<std::size_t, 2>> wcoeffs,
366 std::size_t start = 0)
367{
368 for (std::size_t i = start; i < wcoeffs.extent(0); ++i)
369 {
370 T norm = 0;
371 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
372 norm += wcoeffs(i, k) * wcoeffs(i, k);
373
374 norm = std::sqrt(norm);
375 if (norm < 2 * std::numeric_limits<T>::epsilon())
376 {
377 throw std::runtime_error("Cannot orthogonalise the rows of a matrix "
378 "with incomplete row rank");
379 }
380
381 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
382 wcoeffs(i, k) /= norm;
383
384 for (std::size_t j = i + 1; j < wcoeffs.extent(0); ++j)
385 {
386 T a = 0;
387 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
388 a += wcoeffs(i, k) * wcoeffs(j, k);
389 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
390 wcoeffs(j, k) -= a * wcoeffs(i, k);
391 }
392 }
393}
394} // namespace basix::math
Mathematical functions.
Definition math.h:51
bool is_singular(md::mdspan< const T, md::dextents< std::size_t, 2 > > A)
Check if A is a singular matrix.
Definition math.h:230
std::array< typename U::value_type, 3 > cross(const U &u, const V &v)
Compute the cross product u x v.
Definition math.h:111
std::vector< std::size_t > transpose_lu(std::pair< std::vector< T >, std::array< std::size_t, 2 > > &A)
Compute the LU decomposition of the transpose of a square matrix A.
Definition math.h:271
void orthogonalise(md::mdspan< T, md::dextents< std::size_t, 2 > > wcoeffs, std::size_t start=0)
Orthogonalise the rows of a matrix (in place).
Definition math.h:365
std::vector< T > solve(md::mdspan< const T, md::dextents< std::size_t, 2 > > A, md::mdspan< const T, md::dextents< std::size_t, 2 > > B)
Solve A X = B.
Definition math.h:187
void dot(const U &A, const V &B, W &&C, typename std::decay_t< U >::value_type alpha=1, typename std::decay_t< U >::value_type beta=0)
Compute C = alpha A * B + beta C.
Definition math.h:306
std::pair< std::vector< T >, std::vector< T > > eigh(std::span< const T > A, std::size_t n)
Compute the eigenvalues and eigenvectors of a square Hermitian matrix A.
Definition math.h:127
std::vector< T > eye(std::size_t n)
Build an identity matrix.
Definition math.h:351
std::pair< std::vector< typename U::value_type >, std::array< std::size_t, 2 > > outer(const U &u, const V &v)
Compute the outer product of vectors u and v.
Definition math.h:97