60void dot_blas(std::span<const T> A, std::array<std::size_t, 2> Ashape,
61 std::span<const T> B, std::array<std::size_t, 2> Bshape,
79 sgemm_(&trans, &trans, &N, &M, &K, &alpha,
const_cast<T*
>(B.data()), &ldb,
80 const_cast<T*
>(A.data()), &lda, &beta, C.data(), &ldc);
84 dgemm_(&trans, &trans, &N, &M, &K, &alpha,
const_cast<T*
>(B.data()), &ldb,
85 const_cast<T*
>(A.data()), &lda, &beta, C.data(), &ldc);
99 std::vector<typename U::value_type> result(u.size() * v.size());
100 for (std::size_t i = 0; i < u.size(); ++i)
101 for (std::size_t j = 0; j < v.size(); ++j)
102 result[i * v.size() + j] = u[i] * v[j];
103 return {std::move(result), {u.size(), v.size()}};
127std::pair<std::vector<T>, std::vector<T>>
eigh(std::span<const T> A,
131 std::vector<T> M(A.begin(), A.end());
134 std::vector<T> w(n, 0);
143 std::vector<T> work(1);
144 std::vector<int> iwork(1);
147 if constexpr (std::is_same_v<T, float>)
149 ssyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
150 iwork.data(), &liwork, &info);
152 else if constexpr (std::is_same_v<T, double>)
154 dsyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
155 iwork.data(), &liwork, &info);
159 throw std::runtime_error(
"Could not find workspace size for syevd.");
162 work.resize(work[0]);
163 iwork.resize(iwork[0]);
165 liwork = iwork.size();
166 if constexpr (std::is_same_v<T, float>)
168 ssyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
169 iwork.data(), &liwork, &info);
171 else if constexpr (std::is_same_v<T, double>)
173 dsyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
174 iwork.data(), &liwork, &info);
177 throw std::runtime_error(
"Eigenvalue computation did not converge.");
179 return {std::move(w), std::move(M)};
187std::vector<T>
solve(md::mdspan<
const T, md::dextents<std::size_t, 2>> A,
188 md::mdspan<
const T, md::dextents<std::size_t, 2>> B)
191 mdex::mdarray<T, md::dextents<std::size_t, 2>, md::layout_left> _A(
194 for (std::size_t i = 0; i < A.extent(0); ++i)
195 for (std::size_t j = 0; j < A.extent(1); ++j)
197 for (std::size_t i = 0; i < B.extent(0); ++i)
198 for (std::size_t j = 0; j < B.extent(1); ++j)
201 int N = _A.extent(0);
202 int nrhs = _B.extent(1);
203 int lda = _A.extent(0);
204 int ldb = _B.extent(0);
207 std::vector<int> piv(N);
209 if constexpr (std::is_same_v<T, float>)
210 sgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), _B.data(), &ldb, &info);
211 else if constexpr (std::is_same_v<T, double>)
212 dgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), _B.data(), &ldb, &info);
214 throw std::runtime_error(
"Call to dgesv failed: " + std::to_string(info));
217 std::vector<T> rb(_B.extent(0) * _B.extent(1));
218 md::mdspan<T, md::dextents<std::size_t, 2>> r(rb.data(), _B.extents());
219 for (std::size_t i = 0; i < _B.extent(0); ++i)
220 for (std::size_t j = 0; j < _B.extent(1); ++j)
230bool is_singular(md::mdspan<
const T, md::dextents<std::size_t, 2>> A)
233 mdex::mdarray<T, md::dextents<std::size_t, 2>, md::layout_left> _A(
235 for (std::size_t i = 0; i < A.extent(0); ++i)
236 for (std::size_t j = 0; j < A.extent(1); ++j)
239 std::vector<T> B(A.extent(1), 1);
240 int N = _A.extent(0);
242 int lda = _A.extent(0);
246 std::vector<int> piv(N);
248 if constexpr (std::is_same_v<T, float>)
249 sgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), B.data(), &ldb, &info);
250 else if constexpr (std::is_same_v<T, double>)
251 dgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), B.data(), &ldb, &info);
255 throw std::runtime_error(
"dgesv failed due to invalid value: "
256 + std::to_string(info));
273 std::size_t dim = A.second[0];
274 assert(dim == A.second[1]);
277 std::vector<int> lu_perm(dim);
280 if constexpr (std::is_same_v<T, float>)
281 sgetrf_(&N, &N, A.first.data(), &N, lu_perm.data(), &info);
282 else if constexpr (std::is_same_v<T, double>)
283 dgetrf_(&N, &N, A.first.data(), &N, lu_perm.data(), &info);
287 throw std::runtime_error(
"LU decomposition failed: "
288 + std::to_string(info));
291 std::vector<std::size_t> perm(dim);
292 for (std::size_t i = 0; i < dim; ++i)
293 perm[i] =
static_cast<std::size_t
>(lu_perm[i] - 1);
306void dot(
const U& A,
const V& B, W&& C,
307 typename std::decay_t<U>::value_type alpha = 1,
308 typename std::decay_t<U>::value_type beta = 0)
310 using T =
typename std::decay_t<U>::value_type;
312 assert(A.extent(1) == B.extent(0));
313 assert(C.extent(0) == A.extent(0));
314 assert(C.extent(1) == B.extent(1));
315 if (A.extent(0) * B.extent(1) * A.extent(1) < 256)
317 for (std::size_t i = 0; i < A.extent(0); ++i)
319 for (std::size_t j = 0; j < B.extent(1); ++j)
324 for (std::size_t k = 0; k < A.extent(1); ++k)
325 _C += A(i, k) * B(k, j);
326 _C = alpha * _C + beta * C0;
332 static_assert(std::is_same_v<typename std::decay_t<U>::layout_type,
334 static_assert(std::is_same_v<typename std::decay_t<V>::layout_type,
336 static_assert(std::is_same_v<typename std::decay_t<W>::layout_type,
338 static_assert(std::is_same_v<typename std::decay_t<V>::value_type, T>);
339 static_assert(std::is_same_v<typename std::decay_t<W>::value_type, T>);
341 std::span(A.data_handle(), A.size()), {A.extent(0), A.extent(1)},
342 std::span(B.data_handle(), B.size()), {B.extent(0), B.extent(1)},
343 std::span(C.data_handle(), C.size()), alpha, beta);
351std::vector<T>
eye(std::size_t n)
353 std::vector<T> I(n * n, 0);
354 md::mdspan<T, md::dextents<std::size_t, 2>> Iview(I.data(), n, n);
355 for (std::size_t i = 0; i < n; ++i)
366 std::size_t start = 0)
368 for (std::size_t i = start; i < wcoeffs.extent(0); ++i)
371 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
372 norm += wcoeffs(i, k) * wcoeffs(i, k);
374 norm = std::sqrt(norm);
375 if (norm < 2 * std::numeric_limits<T>::epsilon())
377 throw std::runtime_error(
"Cannot orthogonalise the rows of a matrix "
378 "with incomplete row rank");
381 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
382 wcoeffs(i, k) /= norm;
384 for (std::size_t j = i + 1; j < wcoeffs.extent(0); ++j)
387 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
388 a += wcoeffs(i, k) * wcoeffs(j, k);
389 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
390 wcoeffs(j, k) -= a * wcoeffs(i, k);