diff options
author | Andrea Arteaga <andyspiros@gmail.com> | 2012-09-30 02:03:51 +0200 |
---|---|---|
committer | Andrea Arteaga <andyspiros@gmail.com> | 2012-09-30 02:03:51 +0200 |
commit | cb193b088a585330d86c73498fad309665b929bd (patch) | |
tree | 51fd248ed6fc7e5baac6318f39d2b423662c8cde | |
parent | Updated the (C)BLAS module to use the new interfaces. (diff) | |
download | auto-numerical-bench-cb193b088a585330d86c73498fad309665b929bd.tar.gz auto-numerical-bench-cb193b088a585330d86c73498fad309665b929bd.tar.bz2 auto-numerical-bench-cb193b088a585330d86c73498fad309665b929bd.zip |
Complete the implementation of the CInterface for BLAS actions.
-rw-r--r-- | btl/NumericInterface/NI_internal/CDeclarations.hpp | 62 | ||||
-rw-r--r-- | btl/NumericInterface/NI_internal/CInterface.hpp | 164 | ||||
-rw-r--r-- | btl/NumericInterface/NI_internal/FortranDeclarations.hpp | 1 | ||||
-rw-r--r-- | btl/NumericInterface/NI_internal/FortranInterface.hpp | 16 | ||||
-rw-r--r-- | btl/actions/action_TriSolveVector.hpp | 14 |
5 files changed, 227 insertions, 30 deletions
diff --git a/btl/NumericInterface/NI_internal/CDeclarations.hpp b/btl/NumericInterface/NI_internal/CDeclarations.hpp index 0109559..872cd2b 100644 --- a/btl/NumericInterface/NI_internal/CDeclarations.hpp +++ b/btl/NumericInterface/NI_internal/CDeclarations.hpp @@ -40,11 +40,65 @@ const int CblasRight = 142; // Cblas functions extern "C" { - void cblas_sgemv(int, int, int, int, float, const float*, int, - const float*, int, float, float*, int); - void cblas_dgemv(int, int, int, int, double, const double*, int, - const double*, int, double, double*, int); + + +/**************** + * LEVEL 1 BLAS * + ****************/ + + void cblas_srot(int, float*, int, float*, int, float, float); + void cblas_drot(int, double*, int, double*, int, double, double); + + void cblas_saxpy(int, float, const float*, int, float*, int); + void cblas_daxpy(int, double, const double*, int, double*, int); + + float cblas_sdot(int, const float*, int, const float*, int); + double cblas_ddot(int, const double*, int, const double*, int); + + float cblas_snrm2(int, const float*, int); + double cblas_dnrm2(int, const double*, int); + + + + + /**************** + * LEVEL 2 BLAS * + ****************/ + + void cblas_sgemv(int, int, int, int, float, const float*, int, const float*, int, float, float*, int); + void cblas_dgemv(int, int, int, int, double, const double*, int, const double*, int, double, double*, int); + + void cblas_ssymv(int, int, int, float, const float*, int, const float*, int, float, float*, int); + void cblas_dsymv(int, int, int, double, const double*, int, const double*, int, double, double*, int); + + void cblas_strmv(int, int, int, int, int, const float*, int, float*, int); + void cblas_dtrmv(int, int, int, int, int, const double*, int, double*, int); + + void cblas_strsv(int, int, int, int, int, const float*, int, float*, int); + void cblas_dtrsv(int, int, int, int, int, const double*, int, double*, int); + + void cblas_sger(int, int, int, float, const float*, int, const float*, int, float*, int); + void cblas_dger(int, int, int, double, const double*, int, const double*, int, double*, int); + + void cblas_ssyr2(int, int, int, float, const float*, int, const float*, int, float*, int); + void cblas_dsyr2(int, int, int, double, const double*, int, const double*, int, double*, int); + + + + + /**************** + * LEVEL 3 BLAS * + ****************/ + + void cblas_sgemm(int, int, int, int, int, int, float, const float*, int, const float*, int, float, float*, int); + void cblas_dgemm(int, int, int, int, int, int, double, const double*, int, const double*, int, double, double*, int); + + void cblas_strmm(int, int, int, int, int, int, int, float, const float*, int, float*, int); + void cblas_dtrmm(int, int, int, int, int, int, int, double, const double*, int, double*, int); + + void cblas_strsm(int, int, int, int, int, int, int, float, const float*, int, float*, int); + void cblas_dtrsm(int, int, int, int, int, int, int, double, const double*, int, double*, int); } diff --git a/btl/NumericInterface/NI_internal/CInterface.hpp b/btl/NumericInterface/NI_internal/CInterface.hpp index a067144..c6db258 100644 --- a/btl/NumericInterface/NI_internal/CInterface.hpp +++ b/btl/NumericInterface/NI_internal/CInterface.hpp @@ -23,6 +23,11 @@ #define CBLASFUNC(F) CAT(cblas_, CAT(NI_SCALARPREFIX, F)) +#ifndef NI_NAME +# define NI_NAME CAT(CAT(CInterface,$),NI_SCALAR) +#endif + + template<> class NumericInterface<NI_SCALAR> { @@ -32,19 +37,162 @@ public: public: static std::string name() { - std::string name = "CInterface<"; - name += MAKE_STRING(NI_SCALAR); - name += ">"; + std::string name = MAKE_STRING(NI_NAME); return name; } - static void matrixVector( - const int& M, const int& N, const Scalar& alpha, const Scalar* A, - const Scalar* x, const Scalar& beta, Scalar* y - ) + + /**************** + * LEVEL 1 BLAS * + ****************/ + + static void rot(const int& N, Scalar* x, Scalar* y, + const Scalar& cosine, const Scalar& sine) + { + CBLASFUNC(rot)(N, x, 1, y, 1, cosine, sine); + } + + + static void axpy(const int& N, const Scalar& alpha, + const Scalar* x, Scalar* y) + { + CBLASFUNC(axpy)(N, alpha, x, 1, y, 1); + } + + static Scalar dot(const int& N, const Scalar* x, const Scalar* y) + { + return CBLASFUNC(dot)(N, x, 1, y, 1); + } + + static Scalar norm(const int& N, const Scalar* x) + { + return CBLASFUNC(nrm2)(N, x, 1); + } + + + + /**************** + * LEVEL 2 BLAS * + ****************/ + + static void MatrixVector(const bool& trans, const int& M, const int& N, + const Scalar& alpha, const Scalar* A, const Scalar* x, + const Scalar& beta, Scalar* y) { - CBLASFUNC(gemv)(CblasColMajor, CblasNoTrans, M, N, alpha, A, M, + const int LDA = trans ? N : M; + const int tA = trans ? CblasTrans : CblasNoTrans; + CBLASFUNC(gemv)(CblasColMajor, tA, M, N, alpha, A, LDA, x, 1, beta, y, 1); } + + static void SymMatrixVector(const char& uplo, const int& N, + const Scalar& alpha, const Scalar* A, const Scalar* x, + const Scalar& beta, Scalar* y) + { + int uplo_ = -1; + if (uplo == 'u' || uplo == 'U') + uplo_ = CblasUpper; + else if (uplo == 'l' || uplo == 'L') + uplo_ = CblasLower; + + CBLASFUNC(symv)(CblasColMajor, uplo_, N, alpha, A, N, x, 1, beta, y, 1); + } + + static void TriMatrixVector(const char& uplo, const int& N, + const Scalar* A, Scalar* x) + { + int uplo_ = -1; + if (uplo == 'u' || uplo == 'U') + uplo_ = CblasUpper; + else if (uplo == 'l' || uplo == 'L') + uplo_ = CblasLower; + + CBLASFUNC(trmv)(CblasColMajor, uplo_, CblasNoTrans, CblasNonUnit, N, + A, N, x, 1); + } + + static void TriSolveVector(const char& uplo, + const int& N, const Scalar* A, Scalar* x) + { + int uplo_ = -1; + if (uplo == 'u' || uplo == 'U') + uplo_ = CblasUpper; + else if (uplo == 'l' || uplo == 'L') + uplo_ = CblasLower; + + CBLASFUNC(trsv)(CblasColMajor, uplo_, CblasNoTrans, CblasNonUnit, N, + A, N, x, 1); + } + + static void Rank1Update(const int& M, const int& N, const Scalar& alpha, + const Scalar* x, const Scalar* y, Scalar* A) + { + CBLASFUNC(ger)(CblasColMajor, M, N, alpha, x, 1, y, 1, A, M); + } + + static void Rank2Update(const char& uplo, const int& N, const Scalar& alpha, + const Scalar* x, const Scalar* y, Scalar* A) + { + int uplo_ = -1; + if (uplo == 'u' || uplo == 'U') + uplo_ = CblasUpper; + else if (uplo == 'l' || uplo == 'L') + uplo_ = CblasLower; + + CBLASFUNC(syr2)(CblasColMajor, uplo_, N, alpha, x, 1, y, 1, A, N); + } + + + + /**************** + * LEVEL 3 BLAS * + ****************/ + + static void MatrixMatrix(const bool& transA, const bool& transB, + const int& M, const int& N, const int& K, + const Scalar& alpha, const Scalar* A, const Scalar* B, + const Scalar& beta, Scalar* C) + { + int LDA = M, LDB = K; + int tA = CblasNoTrans, tB = CblasNoTrans; + + if (transA) { + LDA = K; + tA = CblasTrans; + } + if (transB) { + LDB = N; + tB = CblasTrans; + } + + CBLASFUNC(gemm)(CblasColMajor, tA, tB, M, N, K, alpha, A, LDA, B, LDB, + beta, C, M); + } + + static void TriMatrixMatrix(const char& uplo, + const int& M, const int& N, const Scalar* A, Scalar* B) + { + int uplo_ = -1; + if (uplo == 'u' || uplo == 'U') + uplo_ = CblasUpper; + else if (uplo == 'l' || uplo == 'L') + uplo_ = CblasLower; + + CBLASFUNC(trmm)(CblasColMajor, CblasLeft, uplo_, CblasNoTrans, + CblasNonUnit, M, N, 1., A, M, B, M); + } + + static void TriSolveMatrix(const char& uplo, + const int& M, const int& N, const Scalar* A, Scalar *B) + { + int uplo_ = -1; + if (uplo == 'u' || uplo == 'U') + uplo_ = CblasUpper; + else if (uplo == 'l' || uplo == 'L') + uplo_ = CblasLower; + + CBLASFUNC(trsm)(CblasColMajor, CblasLeft, uplo_, CblasNoTrans, + CblasNonUnit, M, N, 1., A, M, B, M); + } }; diff --git a/btl/NumericInterface/NI_internal/FortranDeclarations.hpp b/btl/NumericInterface/NI_internal/FortranDeclarations.hpp index 76dfc01..35bce39 100644 --- a/btl/NumericInterface/NI_internal/FortranDeclarations.hpp +++ b/btl/NumericInterface/NI_internal/FortranDeclarations.hpp @@ -78,6 +78,7 @@ extern "C" { void strsm_(const char*, const char*, const char*, const char*, const int*, const int*, const float*, const float*, const int*, float*, const int*); void dtrsm_(const char*, const char*, const char*, const char*, const int*, const int*, const double*, const double*, const int*, double*, const int*); + } diff --git a/btl/NumericInterface/NI_internal/FortranInterface.hpp b/btl/NumericInterface/NI_internal/FortranInterface.hpp index 576145d..eea5898 100644 --- a/btl/NumericInterface/NI_internal/FortranInterface.hpp +++ b/btl/NumericInterface/NI_internal/FortranInterface.hpp @@ -129,7 +129,7 @@ public: * LEVEL 3 BLAS * ****************/ - static void matrixMatrix(const bool& transA, const bool& transB, + static void MatrixMatrix(const bool& transA, const bool& transB, const int& M, const int& N, const int& K, const Scalar& alpha, const Scalar* A, const Scalar* B, const Scalar& beta, Scalar* C) @@ -150,13 +150,13 @@ public: &beta, C, &M); } - static void triangularMatrixMatrix(const char& uplo, + static void TriMatrixMatrix(const char& uplo, const int& M, const int& N, const Scalar* A, Scalar* B) { FORTFUNC(trmm)("L", &uplo, "N", "N", &M, &N, &fONE, A, &M, B, &M); } - static void triangularSolveMatrix(const char& uplo, + static void TriSolveMatrix(const char& uplo, const int& M, const int& N, const Scalar* A, Scalar *B) { FORTFUNC(trsm)("L", &uplo, "N", "N", &M, &N, &fONE, A, &M, B, &M); @@ -168,13 +168,3 @@ const int NumericInterface<NI_SCALAR>::ONE = 1; const NI_SCALAR NumericInterface<NI_SCALAR>::fONE = 1.; const char NumericInterface<NI_SCALAR>::NoTrans = 'N'; const char NumericInterface<NI_SCALAR>::Trans = 'T'; - - - - - - - - - - diff --git a/btl/actions/action_TriSolveVector.hpp b/btl/actions/action_TriSolveVector.hpp index 2bfefb8..6be18b0 100644 --- a/btl/actions/action_TriSolveVector.hpp +++ b/btl/actions/action_TriSolveVector.hpp @@ -37,10 +37,14 @@ public: // Constructor Action_TriSolveVector(int size) : _size(size), lc(10), - A(lc.fillVector<Scalar>(size*size)), x(lc.fillVector<Scalar>(size)), + A(lc.fillVector<Scalar>(size*size)), b(lc.fillVector<Scalar>(size)), x_work(size) { MESSAGE("Action_TriSolveVector Constructor"); + + // Adding size to the diagonal of A to make it invertible + for (int i = 0; i < size; ++i) + A[i+size*i] += size; } // Action name @@ -54,7 +58,7 @@ public: } inline void initialize(){ - std::copy(x.begin(), x.end(), x_work.begin()); + std::copy(b.begin(), b.end(), x_work.begin()); } inline void calculate() { @@ -65,15 +69,15 @@ public: initialize(); calculate(); Interface::TriMatrixVector('U', _size, &A[0], &x_work[0]); - Interface::axpy(_size, -1., &x[0], &x_work[0]); + Interface::axpy(_size, -1., &b[0], &x_work[0]); return Interface::norm(_size, &x_work[0]); } -private: +//private: const int _size; LinearCongruential<> lc; - const vector_t A, x; + vector_t A, b; vector_t x_work; }; |