summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'btl/NumericInterface/NI_internal/CInterface.hpp')
-rw-r--r--btl/NumericInterface/NI_internal/CInterface.hpp164
1 files changed, 156 insertions, 8 deletions
diff --git a/btl/NumericInterface/NI_internal/CInterface.hpp b/btl/NumericInterface/NI_internal/CInterface.hpp
index a067144..c6db258 100644
--- a/btl/NumericInterface/NI_internal/CInterface.hpp
+++ b/btl/NumericInterface/NI_internal/CInterface.hpp
@@ -23,6 +23,11 @@
#define CBLASFUNC(F) CAT(cblas_, CAT(NI_SCALARPREFIX, F))
+#ifndef NI_NAME
+# define NI_NAME CAT(CAT(CInterface,$),NI_SCALAR)
+#endif
+
+
template<>
class NumericInterface<NI_SCALAR>
{
@@ -32,19 +37,162 @@ public:
public:
static std::string name()
{
- std::string name = "CInterface<";
- name += MAKE_STRING(NI_SCALAR);
- name += ">";
+ std::string name = MAKE_STRING(NI_NAME);
return name;
}
- static void matrixVector(
- const int& M, const int& N, const Scalar& alpha, const Scalar* A,
- const Scalar* x, const Scalar& beta, Scalar* y
- )
+
+ /****************
+ * LEVEL 1 BLAS *
+ ****************/
+
+ static void rot(const int& N, Scalar* x, Scalar* y,
+ const Scalar& cosine, const Scalar& sine)
+ {
+ CBLASFUNC(rot)(N, x, 1, y, 1, cosine, sine);
+ }
+
+
+ static void axpy(const int& N, const Scalar& alpha,
+ const Scalar* x, Scalar* y)
+ {
+ CBLASFUNC(axpy)(N, alpha, x, 1, y, 1);
+ }
+
+ static Scalar dot(const int& N, const Scalar* x, const Scalar* y)
+ {
+ return CBLASFUNC(dot)(N, x, 1, y, 1);
+ }
+
+ static Scalar norm(const int& N, const Scalar* x)
+ {
+ return CBLASFUNC(nrm2)(N, x, 1);
+ }
+
+
+
+ /****************
+ * LEVEL 2 BLAS *
+ ****************/
+
+ static void MatrixVector(const bool& trans, const int& M, const int& N,
+ const Scalar& alpha, const Scalar* A, const Scalar* x,
+ const Scalar& beta, Scalar* y)
{
- CBLASFUNC(gemv)(CblasColMajor, CblasNoTrans, M, N, alpha, A, M,
+ const int LDA = trans ? N : M;
+ const int tA = trans ? CblasTrans : CblasNoTrans;
+ CBLASFUNC(gemv)(CblasColMajor, tA, M, N, alpha, A, LDA,
x, 1, beta, y, 1);
}
+
+ static void SymMatrixVector(const char& uplo, const int& N,
+ const Scalar& alpha, const Scalar* A, const Scalar* x,
+ const Scalar& beta, Scalar* y)
+ {
+ int uplo_ = -1;
+ if (uplo == 'u' || uplo == 'U')
+ uplo_ = CblasUpper;
+ else if (uplo == 'l' || uplo == 'L')
+ uplo_ = CblasLower;
+
+ CBLASFUNC(symv)(CblasColMajor, uplo_, N, alpha, A, N, x, 1, beta, y, 1);
+ }
+
+ static void TriMatrixVector(const char& uplo, const int& N,
+ const Scalar* A, Scalar* x)
+ {
+ int uplo_ = -1;
+ if (uplo == 'u' || uplo == 'U')
+ uplo_ = CblasUpper;
+ else if (uplo == 'l' || uplo == 'L')
+ uplo_ = CblasLower;
+
+ CBLASFUNC(trmv)(CblasColMajor, uplo_, CblasNoTrans, CblasNonUnit, N,
+ A, N, x, 1);
+ }
+
+ static void TriSolveVector(const char& uplo,
+ const int& N, const Scalar* A, Scalar* x)
+ {
+ int uplo_ = -1;
+ if (uplo == 'u' || uplo == 'U')
+ uplo_ = CblasUpper;
+ else if (uplo == 'l' || uplo == 'L')
+ uplo_ = CblasLower;
+
+ CBLASFUNC(trsv)(CblasColMajor, uplo_, CblasNoTrans, CblasNonUnit, N,
+ A, N, x, 1);
+ }
+
+ static void Rank1Update(const int& M, const int& N, const Scalar& alpha,
+ const Scalar* x, const Scalar* y, Scalar* A)
+ {
+ CBLASFUNC(ger)(CblasColMajor, M, N, alpha, x, 1, y, 1, A, M);
+ }
+
+ static void Rank2Update(const char& uplo, const int& N, const Scalar& alpha,
+ const Scalar* x, const Scalar* y, Scalar* A)
+ {
+ int uplo_ = -1;
+ if (uplo == 'u' || uplo == 'U')
+ uplo_ = CblasUpper;
+ else if (uplo == 'l' || uplo == 'L')
+ uplo_ = CblasLower;
+
+ CBLASFUNC(syr2)(CblasColMajor, uplo_, N, alpha, x, 1, y, 1, A, N);
+ }
+
+
+
+ /****************
+ * LEVEL 3 BLAS *
+ ****************/
+
+ static void MatrixMatrix(const bool& transA, const bool& transB,
+ const int& M, const int& N, const int& K,
+ const Scalar& alpha, const Scalar* A, const Scalar* B,
+ const Scalar& beta, Scalar* C)
+ {
+ int LDA = M, LDB = K;
+ int tA = CblasNoTrans, tB = CblasNoTrans;
+
+ if (transA) {
+ LDA = K;
+ tA = CblasTrans;
+ }
+ if (transB) {
+ LDB = N;
+ tB = CblasTrans;
+ }
+
+ CBLASFUNC(gemm)(CblasColMajor, tA, tB, M, N, K, alpha, A, LDA, B, LDB,
+ beta, C, M);
+ }
+
+ static void TriMatrixMatrix(const char& uplo,
+ const int& M, const int& N, const Scalar* A, Scalar* B)
+ {
+ int uplo_ = -1;
+ if (uplo == 'u' || uplo == 'U')
+ uplo_ = CblasUpper;
+ else if (uplo == 'l' || uplo == 'L')
+ uplo_ = CblasLower;
+
+ CBLASFUNC(trmm)(CblasColMajor, CblasLeft, uplo_, CblasNoTrans,
+ CblasNonUnit, M, N, 1., A, M, B, M);
+ }
+
+ static void TriSolveMatrix(const char& uplo,
+ const int& M, const int& N, const Scalar* A, Scalar *B)
+ {
+ int uplo_ = -1;
+ if (uplo == 'u' || uplo == 'U')
+ uplo_ = CblasUpper;
+ else if (uplo == 'l' || uplo == 'L')
+ uplo_ = CblasLower;
+
+ CBLASFUNC(trsm)(CblasColMajor, CblasLeft, uplo_, CblasNoTrans,
+ CblasNonUnit, M, N, 1., A, M, B, M);
+ }
};