diff options
author | Andrea Arteaga <andyspiros@gmail.com> | 2012-09-30 02:03:51 +0200 |
---|---|---|
committer | Andrea Arteaga <andyspiros@gmail.com> | 2012-09-30 02:03:51 +0200 |
commit | cb193b088a585330d86c73498fad309665b929bd (patch) | |
tree | 51fd248ed6fc7e5baac6318f39d2b423662c8cde /btl/NumericInterface/NI_internal/CInterface.hpp | |
parent | Updated the (C)BLAS module to use the new interfaces. (diff) | |
download | auto-numerical-bench-cb193b088a585330d86c73498fad309665b929bd.tar.gz auto-numerical-bench-cb193b088a585330d86c73498fad309665b929bd.tar.bz2 auto-numerical-bench-cb193b088a585330d86c73498fad309665b929bd.zip |
Complete the implementation of the CInterface for BLAS actions.
Diffstat (limited to 'btl/NumericInterface/NI_internal/CInterface.hpp')
-rw-r--r-- | btl/NumericInterface/NI_internal/CInterface.hpp | 164 |
1 files changed, 156 insertions, 8 deletions
diff --git a/btl/NumericInterface/NI_internal/CInterface.hpp b/btl/NumericInterface/NI_internal/CInterface.hpp index a067144..c6db258 100644 --- a/btl/NumericInterface/NI_internal/CInterface.hpp +++ b/btl/NumericInterface/NI_internal/CInterface.hpp @@ -23,6 +23,11 @@ #define CBLASFUNC(F) CAT(cblas_, CAT(NI_SCALARPREFIX, F)) +#ifndef NI_NAME +# define NI_NAME CAT(CAT(CInterface,$),NI_SCALAR) +#endif + + template<> class NumericInterface<NI_SCALAR> { @@ -32,19 +37,162 @@ public: public: static std::string name() { - std::string name = "CInterface<"; - name += MAKE_STRING(NI_SCALAR); - name += ">"; + std::string name = MAKE_STRING(NI_NAME); return name; } - static void matrixVector( - const int& M, const int& N, const Scalar& alpha, const Scalar* A, - const Scalar* x, const Scalar& beta, Scalar* y - ) + + /**************** + * LEVEL 1 BLAS * + ****************/ + + static void rot(const int& N, Scalar* x, Scalar* y, + const Scalar& cosine, const Scalar& sine) + { + CBLASFUNC(rot)(N, x, 1, y, 1, cosine, sine); + } + + + static void axpy(const int& N, const Scalar& alpha, + const Scalar* x, Scalar* y) + { + CBLASFUNC(axpy)(N, alpha, x, 1, y, 1); + } + + static Scalar dot(const int& N, const Scalar* x, const Scalar* y) + { + return CBLASFUNC(dot)(N, x, 1, y, 1); + } + + static Scalar norm(const int& N, const Scalar* x) + { + return CBLASFUNC(nrm2)(N, x, 1); + } + + + + /**************** + * LEVEL 2 BLAS * + ****************/ + + static void MatrixVector(const bool& trans, const int& M, const int& N, + const Scalar& alpha, const Scalar* A, const Scalar* x, + const Scalar& beta, Scalar* y) { - CBLASFUNC(gemv)(CblasColMajor, CblasNoTrans, M, N, alpha, A, M, + const int LDA = trans ? N : M; + const int tA = trans ? CblasTrans : CblasNoTrans; + CBLASFUNC(gemv)(CblasColMajor, tA, M, N, alpha, A, LDA, x, 1, beta, y, 1); } + + static void SymMatrixVector(const char& uplo, const int& N, + const Scalar& alpha, const Scalar* A, const Scalar* x, + const Scalar& beta, Scalar* y) + { + int uplo_ = -1; + if (uplo == 'u' || uplo == 'U') + uplo_ = CblasUpper; + else if (uplo == 'l' || uplo == 'L') + uplo_ = CblasLower; + + CBLASFUNC(symv)(CblasColMajor, uplo_, N, alpha, A, N, x, 1, beta, y, 1); + } + + static void TriMatrixVector(const char& uplo, const int& N, + const Scalar* A, Scalar* x) + { + int uplo_ = -1; + if (uplo == 'u' || uplo == 'U') + uplo_ = CblasUpper; + else if (uplo == 'l' || uplo == 'L') + uplo_ = CblasLower; + + CBLASFUNC(trmv)(CblasColMajor, uplo_, CblasNoTrans, CblasNonUnit, N, + A, N, x, 1); + } + + static void TriSolveVector(const char& uplo, + const int& N, const Scalar* A, Scalar* x) + { + int uplo_ = -1; + if (uplo == 'u' || uplo == 'U') + uplo_ = CblasUpper; + else if (uplo == 'l' || uplo == 'L') + uplo_ = CblasLower; + + CBLASFUNC(trsv)(CblasColMajor, uplo_, CblasNoTrans, CblasNonUnit, N, + A, N, x, 1); + } + + static void Rank1Update(const int& M, const int& N, const Scalar& alpha, + const Scalar* x, const Scalar* y, Scalar* A) + { + CBLASFUNC(ger)(CblasColMajor, M, N, alpha, x, 1, y, 1, A, M); + } + + static void Rank2Update(const char& uplo, const int& N, const Scalar& alpha, + const Scalar* x, const Scalar* y, Scalar* A) + { + int uplo_ = -1; + if (uplo == 'u' || uplo == 'U') + uplo_ = CblasUpper; + else if (uplo == 'l' || uplo == 'L') + uplo_ = CblasLower; + + CBLASFUNC(syr2)(CblasColMajor, uplo_, N, alpha, x, 1, y, 1, A, N); + } + + + + /**************** + * LEVEL 3 BLAS * + ****************/ + + static void MatrixMatrix(const bool& transA, const bool& transB, + const int& M, const int& N, const int& K, + const Scalar& alpha, const Scalar* A, const Scalar* B, + const Scalar& beta, Scalar* C) + { + int LDA = M, LDB = K; + int tA = CblasNoTrans, tB = CblasNoTrans; + + if (transA) { + LDA = K; + tA = CblasTrans; + } + if (transB) { + LDB = N; + tB = CblasTrans; + } + + CBLASFUNC(gemm)(CblasColMajor, tA, tB, M, N, K, alpha, A, LDA, B, LDB, + beta, C, M); + } + + static void TriMatrixMatrix(const char& uplo, + const int& M, const int& N, const Scalar* A, Scalar* B) + { + int uplo_ = -1; + if (uplo == 'u' || uplo == 'U') + uplo_ = CblasUpper; + else if (uplo == 'l' || uplo == 'L') + uplo_ = CblasLower; + + CBLASFUNC(trmm)(CblasColMajor, CblasLeft, uplo_, CblasNoTrans, + CblasNonUnit, M, N, 1., A, M, B, M); + } + + static void TriSolveMatrix(const char& uplo, + const int& M, const int& N, const Scalar* A, Scalar *B) + { + int uplo_ = -1; + if (uplo == 'u' || uplo == 'U') + uplo_ = CblasUpper; + else if (uplo == 'l' || uplo == 'L') + uplo_ = CblasLower; + + CBLASFUNC(trsm)(CblasColMajor, CblasLeft, uplo_, CblasNoTrans, + CblasNonUnit, M, N, 1., A, M, B, M); + } }; |