summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Andrea Arteaga <andyspiros@gmail.com> 2012-09-30 02:03:51 +0200
committer: Andrea Arteaga <andyspiros@gmail.com> 2012-09-30 02:03:51 +0200
commit: cb193b088a585330d86c73498fad309665b929bd (patch)
tree: 51fd248ed6fc7e5baac6318f39d2b423662c8cde
parent: Updated the (C)BLAS module to use the new interfaces. (diff)
downloadauto-numerical-bench-cb193b088a585330d86c73498fad309665b929bd.tar.gz
auto-numerical-bench-cb193b088a585330d86c73498fad309665b929bd.tar.bz2
auto-numerical-bench-cb193b088a585330d86c73498fad309665b929bd.zip
Complete the implementation of the CInterface for BLAS actions.
-rw-r--r-- btl/NumericInterface/NI_internal/CDeclarations.hpp 62
-rw-r--r-- btl/NumericInterface/NI_internal/CInterface.hpp 164
-rw-r--r-- btl/NumericInterface/NI_internal/FortranDeclarations.hpp 1
-rw-r--r-- btl/NumericInterface/NI_internal/FortranInterface.hpp 16
-rw-r--r-- btl/actions/action_TriSolveVector.hpp 14
5 files changed, 227 insertions, 30 deletions
diff --git a/btl/NumericInterface/NI_internal/CDeclarations.hpp b/btl/NumericInterface/NI_internal/CDeclarations.hpp
index 0109559..872cd2b 100644
--- a/btl/NumericInterface/NI_internal/CDeclarations.hpp
+++ b/btl/NumericInterface/NI_internal/CDeclarations.hpp
@@ -40,11 +40,65 @@ const int CblasRight = 142;
// Cblas functions
extern "C" {
- void cblas_sgemv(int, int, int, int, float, const float*, int,
- const float*, int, float, float*, int);
- void cblas_dgemv(int, int, int, int, double, const double*, int,
- const double*, int, double, double*, int);
+
+
+/****************
+ * LEVEL 1 BLAS *
+ ****************
+ * Vector routines, each declared twice: for float (s prefix) and for
+ * double (d prefix). The leading int is the element count; every vector
+ * pointer is followed by an int that the wrappers in CInterface.hpp
+ * always pass as 1.
+ * NOTE(review): these are hand-written prototypes, not an include of
+ * cblas.h -- confirm they match the CBLAS library that gets linked.
+ */
+
+ void cblas_srot(int, float*, int, float*, int, float, float);
+ void cblas_drot(int, double*, int, double*, int, double, double);
+
+ void cblas_saxpy(int, float, const float*, int, float*, int);
+ void cblas_daxpy(int, double, const double*, int, double*, int);
+
+ float cblas_sdot(int, const float*, int, const float*, int);
+ double cblas_ddot(int, const double*, int, const double*, int);
+
+ float cblas_snrm2(int, const float*, int);
+ double cblas_dnrm2(int, const double*, int);
+
+
+
+
+ /****************
+ * LEVEL 2 BLAS *
+ ****************
+ * Matrix-vector routines, each declared for float (s) and double (d).
+ * The first int receives CblasColMajor from every wrapper in
+ * CInterface.hpp; the ints that follow receive Cblas* flag constants
+ * (transpose / uplo / diag, depending on the routine) and then the
+ * matrix dimensions. The int after a matrix pointer is its leading
+ * dimension, the int after a vector pointer is passed as 1 by callers.
+ * NOTE(review): flag arguments are declared as plain int here -- confirm
+ * this is compatible with the enum types of the linked cblas.h.
+ */
+
+ void cblas_sgemv(int, int, int, int, float, const float*, int, const float*, int, float, float*, int);
+ void cblas_dgemv(int, int, int, int, double, const double*, int, const double*, int, double, double*, int);
+
+ void cblas_ssymv(int, int, int, float, const float*, int, const float*, int, float, float*, int);
+ void cblas_dsymv(int, int, int, double, const double*, int, const double*, int, double, double*, int);
+
+ void cblas_strmv(int, int, int, int, int, const float*, int, float*, int);
+ void cblas_dtrmv(int, int, int, int, int, const double*, int, double*, int);
+
+ void cblas_strsv(int, int, int, int, int, const float*, int, float*, int);
+ void cblas_dtrsv(int, int, int, int, int, const double*, int, double*, int);
+
+ void cblas_sger(int, int, int, float, const float*, int, const float*, int, float*, int);
+ void cblas_dger(int, int, int, double, const double*, int, const double*, int, double*, int);
+
+ void cblas_ssyr2(int, int, int, float, const float*, int, const float*, int, float*, int);
+ void cblas_dsyr2(int, int, int, double, const double*, int, const double*, int, double*, int);
+
+
+
+
+ /****************
+ * LEVEL 3 BLAS *
+ ****************
+ * Matrix-matrix routines, each declared for float (s) and double (d).
+ * As used by CInterface.hpp: gemm takes (order, transA, transB, M, N, K,
+ * alpha, A, lda, B, ldb, beta, C, ldc); trmm and trsm take (order, side,
+ * uplo, trans, diag, M, N, alpha, A, lda, B, ldb) and overwrite B.
+ * NOTE(review): flag arguments are declared as plain int here -- confirm
+ * this is compatible with the enum types of the linked cblas.h.
+ */
+
+ void cblas_sgemm(int, int, int, int, int, int, float, const float*, int, const float*, int, float, float*, int);
+ void cblas_dgemm(int, int, int, int, int, int, double, const double*, int, const double*, int, double, double*, int);
+
+ void cblas_strmm(int, int, int, int, int, int, int, float, const float*, int, float*, int);
+ void cblas_dtrmm(int, int, int, int, int, int, int, double, const double*, int, double*, int);
+
+ void cblas_strsm(int, int, int, int, int, int, int, float, const float*, int, float*, int);
+ void cblas_dtrsm(int, int, int, int, int, int, int, double, const double*, int, double*, int);
}
diff --git a/btl/NumericInterface/NI_internal/CInterface.hpp b/btl/NumericInterface/NI_internal/CInterface.hpp
index a067144..c6db258 100644
--- a/btl/NumericInterface/NI_internal/CInterface.hpp
+++ b/btl/NumericInterface/NI_internal/CInterface.hpp
@@ -23,6 +23,11 @@
#define CBLASFUNC(F) CAT(cblas_, CAT(NI_SCALARPREFIX, F))
+#ifndef NI_NAME // a name defined before this header is included takes precedence
+# define NI_NAME CAT(CAT(CInterface,$),NI_SCALAR) // NOTE(review): pastes '$' into an identifier, which relies on a compiler extension -- confirm portability
+#endif
+
+
template<>
class NumericInterface<NI_SCALAR>
{
@@ -32,19 +37,162 @@ public:
public:
static std::string name()
{
- std::string name = "CInterface<";
- name += MAKE_STRING(NI_SCALAR);
- name += ">";
+ std::string name = MAKE_STRING(NI_NAME); // NOTE(review): assumes MAKE_STRING expands NI_NAME before stringifying it -- confirm against its definition
return name;
}
- static void matrixVector(
- const int& M, const int& N, const Scalar& alpha, const Scalar* A,
- const Scalar* x, const Scalar& beta, Scalar* y
- )
+
+ /****************
+ * LEVEL 1 BLAS *
+ ****************/
+
+ // Forwards to the CBLAS ?rot routine (plane rotation): the N-element
+ // vectors x and y are both modified in place, cosine and sine are passed
+ // through unchanged. Both vectors are accessed with stride 1.
+ static void rot(const int& N, Scalar* x, Scalar* y,
+ const Scalar& cosine, const Scalar& sine)
+ {
+ const int stride = 1;
+ CBLASFUNC(rot)(N, x, stride, y, stride, cosine, sine);
+ }
+
+
+ // Forwards to the CBLAS ?axpy routine (y := alpha * x + y) for N-element
+ // vectors. x is read-only, y is updated in place; both use stride 1.
+ static void axpy(const int& N, const Scalar& alpha,
+ const Scalar* x, Scalar* y)
+ {
+ const int stride = 1;
+ CBLASFUNC(axpy)(N, alpha, x, stride, y, stride);
+ }
+
+ // Returns the result of the CBLAS ?dot routine (dot product) for the
+ // N-element vectors x and y, both read with stride 1.
+ static Scalar dot(const int& N, const Scalar* x, const Scalar* y)
+ {
+ const int stride = 1;
+ const Scalar product = CBLASFUNC(dot)(N, x, stride, y, stride);
+ return product;
+ }
+
+ // Returns the result of the CBLAS ?nrm2 routine (Euclidean norm) for the
+ // N-element vector x, read with stride 1.
+ static Scalar norm(const int& N, const Scalar* x)
+ {
+ const int stride = 1;
+ const Scalar length = CBLASFUNC(nrm2)(N, x, stride);
+ return length;
+ }
+
+
+
+ /****************
+ * LEVEL 2 BLAS *
+ ****************/
+
+ // General matrix-vector product through the CBLAS ?gemv routine:
+ // y := alpha * op(A) * x + beta * y, with op(A) = A^T when trans is true
+ // and op(A) = A otherwise.
+ // M, N - rows and columns of A as stored (column-major); they are
+ // forwarded to gemv unchanged.
+ // A - M x N matrix, read-only.
+ // x, y - input vector and in/out result vector, both with stride 1.
+ static void MatrixVector(const bool& trans, const int& M, const int& N,
+ const Scalar& alpha, const Scalar* A, const Scalar* x,
+ const Scalar& beta, Scalar* y)
{
- CBLASFUNC(gemv)(CblasColMajor, CblasNoTrans, M, N, alpha, A, M,
+ // The leading dimension describes how A is laid out in memory, not
+ // which operation is applied to it: for a column-major M x N matrix
+ // it is M whether or not the transpose is requested. Using N for the
+ // transposed case breaks every non-square call.
+ const int LDA = M;
+ const int tA = trans ? CblasTrans : CblasNoTrans;
+ CBLASFUNC(gemv)(CblasColMajor, tA, M, N, alpha, A, LDA,
x, 1, beta, y, 1);
}
+
+ // Symmetric matrix-vector product through the CBLAS ?symv routine:
+ // y := alpha * A * x + beta * y for an N x N column-major matrix A.
+ // uplo selects the referenced triangle: 'u'/'U' for upper, 'l'/'L' for
+ // lower. Any other character is forwarded to CBLAS as -1.
+ static void SymMatrixVector(const char& uplo, const int& N,
+ const Scalar& alpha, const Scalar* A, const Scalar* x,
+ const Scalar& beta, Scalar* y)
+ {
+ int triangle = -1;
+ switch (uplo) {
+ case 'u': case 'U': triangle = CblasUpper; break;
+ case 'l': case 'L': triangle = CblasLower; break;
+ default: break;
+ }
+
+ const int LDA = N;
+ CBLASFUNC(symv)(CblasColMajor, triangle, N, alpha, A, LDA, x, 1, beta, y, 1);
+ }
+
+ // Triangular matrix-vector product through the CBLAS ?trmv routine:
+ // x := A * x, A being an N x N column-major, non-transposed, non-unit
+ // triangular matrix. uplo is 'u'/'U' for upper or 'l'/'L' for lower; any
+ // other character is forwarded to CBLAS as -1.
+ static void TriMatrixVector(const char& uplo, const int& N,
+ const Scalar* A, Scalar* x)
+ {
+ const bool upper = (uplo == 'u' || uplo == 'U');
+ const bool lower = (uplo == 'l' || uplo == 'L');
+ const int triangle = upper ? CblasUpper : (lower ? CblasLower : -1);
+
+ CBLASFUNC(trmv)(CblasColMajor, triangle, CblasNoTrans, CblasNonUnit, N,
+ A, N, x, 1);
+ }
+
+ // Triangular solve through the CBLAS ?trsv routine: on entry x holds the
+ // right-hand side, on exit the solution of A * x = b. A is an N x N
+ // column-major, non-transposed, non-unit triangular matrix. uplo is
+ // 'u'/'U' for upper or 'l'/'L' for lower; any other character is
+ // forwarded to CBLAS as -1.
+ static void TriSolveVector(const char& uplo,
+ const int& N, const Scalar* A, Scalar* x)
+ {
+ int triangle = -1;
+ switch (uplo) {
+ case 'u': case 'U': triangle = CblasUpper; break;
+ case 'l': case 'L': triangle = CblasLower; break;
+ default: break;
+ }
+
+ CBLASFUNC(trsv)(CblasColMajor, triangle, CblasNoTrans, CblasNonUnit, N,
+ A, N, x, 1);
+ }
+
+ // Rank-1 update through the CBLAS ?ger routine: A := alpha * x * y^T + A,
+ // where A is an M x N column-major matrix updated in place and x, y are
+ // read with stride 1.
+ static void Rank1Update(const int& M, const int& N, const Scalar& alpha,
+ const Scalar* x, const Scalar* y, Scalar* A)
+ {
+ const int stride = 1;
+ const int LDA = M;
+ CBLASFUNC(ger)(CblasColMajor, M, N, alpha, x, stride, y, stride, A, LDA);
+ }
+
+ // Symmetric rank-2 update through the CBLAS ?syr2 routine:
+ // A := alpha * x * y^T + alpha * y * x^T + A for an N x N column-major
+ // matrix A updated in place. uplo is 'u'/'U' for upper or 'l'/'L' for
+ // lower; any other character is forwarded to CBLAS as -1.
+ static void Rank2Update(const char& uplo, const int& N, const Scalar& alpha,
+ const Scalar* x, const Scalar* y, Scalar* A)
+ {
+ const bool upper = (uplo == 'u' || uplo == 'U');
+ const bool lower = (uplo == 'l' || uplo == 'L');
+ const int triangle = upper ? CblasUpper : (lower ? CblasLower : -1);
+
+ CBLASFUNC(syr2)(CblasColMajor, triangle, N, alpha, x, 1, y, 1, A, N);
+ }
+
+
+
+ /****************
+ * LEVEL 3 BLAS *
+ ****************/
+
+ // General matrix-matrix product through the CBLAS ?gemm routine:
+ // C := alpha * op(A) * op(B) + beta * C, all matrices column-major.
+ // op(A) is M x K, op(B) is K x N, C is M x N.
+ // transA / transB request the transpose of A / B.
+ static void MatrixMatrix(const bool& transA, const bool& transB,
+ const int& M, const int& N, const int& K,
+ const Scalar& alpha, const Scalar* A, const Scalar* B,
+ const Scalar& beta, Scalar* C)
+ {
+ // A transposed operand is stored with its dimensions swapped, so its
+ // leading dimension is the row count of the stored matrix.
+ const int tA = transA ? CblasTrans : CblasNoTrans;
+ const int tB = transB ? CblasTrans : CblasNoTrans;
+ const int LDA = transA ? K : M;
+ const int LDB = transB ? N : K;
+ const int LDC = M;
+
+ CBLASFUNC(gemm)(CblasColMajor, tA, tB, M, N, K, alpha, A, LDA, B, LDB,
+ beta, C, LDC);
+ }
+
+ // Triangular matrix-matrix product through the CBLAS ?trmm routine:
+ // B := A * B with alpha fixed to 1, A applied from the left,
+ // non-transposed and non-unit. A and B are both passed with leading
+ // dimension M; B is overwritten. uplo is 'u'/'U' for upper or 'l'/'L'
+ // for lower; any other character is forwarded to CBLAS as -1.
+ static void TriMatrixMatrix(const char& uplo,
+ const int& M, const int& N, const Scalar* A, Scalar* B)
+ {
+ int triangle = -1;
+ switch (uplo) {
+ case 'u': case 'U': triangle = CblasUpper; break;
+ case 'l': case 'L': triangle = CblasLower; break;
+ default: break;
+ }
+
+ CBLASFUNC(trmm)(CblasColMajor, CblasLeft, triangle, CblasNoTrans,
+ CblasNonUnit, M, N, 1., A, M, B, M);
+ }
+
+ // Triangular solve with multiple right-hand sides through the CBLAS
+ // ?trsm routine: B is overwritten with the solution X of A * X = B, with
+ // alpha fixed to 1, A applied from the left, non-transposed and non-unit.
+ // A and B are both passed with leading dimension M. uplo is 'u'/'U' for
+ // upper or 'l'/'L' for lower; any other character is forwarded to CBLAS
+ // as -1.
+ static void TriSolveMatrix(const char& uplo,
+ const int& M, const int& N, const Scalar* A, Scalar *B)
+ {
+ const bool upper = (uplo == 'u' || uplo == 'U');
+ const bool lower = (uplo == 'l' || uplo == 'L');
+ const int triangle = upper ? CblasUpper : (lower ? CblasLower : -1);
+
+ CBLASFUNC(trsm)(CblasColMajor, CblasLeft, triangle, CblasNoTrans,
+ CblasNonUnit, M, N, 1., A, M, B, M);
+ }
};
diff --git a/btl/NumericInterface/NI_internal/FortranDeclarations.hpp b/btl/NumericInterface/NI_internal/FortranDeclarations.hpp
index 76dfc01..35bce39 100644
--- a/btl/NumericInterface/NI_internal/FortranDeclarations.hpp
+++ b/btl/NumericInterface/NI_internal/FortranDeclarations.hpp
@@ -78,6 +78,7 @@ extern "C" {
void strsm_(const char*, const char*, const char*, const char*, const int*, const int*, const float*, const float*, const int*, float*, const int*);
void dtrsm_(const char*, const char*, const char*, const char*, const int*, const int*, const double*, const double*, const int*, double*, const int*);
+
}
diff --git a/btl/NumericInterface/NI_internal/FortranInterface.hpp b/btl/NumericInterface/NI_internal/FortranInterface.hpp
index 576145d..eea5898 100644
--- a/btl/NumericInterface/NI_internal/FortranInterface.hpp
+++ b/btl/NumericInterface/NI_internal/FortranInterface.hpp
@@ -129,7 +129,7 @@ public:
* LEVEL 3 BLAS *
****************/
- static void matrixMatrix(const bool& transA, const bool& transB,
+ static void MatrixMatrix(const bool& transA, const bool& transB,
const int& M, const int& N, const int& K,
const Scalar& alpha, const Scalar* A, const Scalar* B,
const Scalar& beta, Scalar* C)
@@ -150,13 +150,13 @@ public:
&beta, C, &M);
}
- static void triangularMatrixMatrix(const char& uplo,
+ static void TriMatrixMatrix(const char& uplo, // renamed; same method name as in CInterface.hpp
const int& M, const int& N, const Scalar* A, Scalar* B)
{
FORTFUNC(trmm)("L", &uplo, "N", "N", &M, &N, &fONE, A, &M, B, &M); // flags "L","N","N" are fixed; alpha is fONE; A and B both use leading dimension M
}
- static void triangularSolveMatrix(const char& uplo,
+ static void TriSolveMatrix(const char& uplo,
const int& M, const int& N, const Scalar* A, Scalar *B)
{
FORTFUNC(trsm)("L", &uplo, "N", "N", &M, &N, &fONE, A, &M, B, &M);
@@ -168,13 +168,3 @@ const int NumericInterface<NI_SCALAR>::ONE = 1;
const NI_SCALAR NumericInterface<NI_SCALAR>::fONE = 1.;
const char NumericInterface<NI_SCALAR>::NoTrans = 'N';
const char NumericInterface<NI_SCALAR>::Trans = 'T';
-
-
-
-
-
-
-
-
-
-
diff --git a/btl/actions/action_TriSolveVector.hpp b/btl/actions/action_TriSolveVector.hpp
index 2bfefb8..6be18b0 100644
--- a/btl/actions/action_TriSolveVector.hpp
+++ b/btl/actions/action_TriSolveVector.hpp
@@ -37,10 +37,14 @@ public:
// Constructor
Action_TriSolveVector(int size)
: _size(size), lc(10),
- A(lc.fillVector<Scalar>(size*size)), x(lc.fillVector<Scalar>(size)),
+ A(lc.fillVector<Scalar>(size*size)), b(lc.fillVector<Scalar>(size)), // b is copied into x_work by initialize() and subtracted again in the residual check
x_work(size)
{
MESSAGE("Action_TriSolveVector Constructor");
+
+ // Adding size to the diagonal of A to make it invertible
+ for (int i = 0; i < size; ++i)
+ A[i+size*i] += size; // index i+size*i is the diagonal entry (i, i) of the size x size matrix
}
// Action name
@@ -54,7 +58,7 @@ public:
}
inline void initialize(){
- std::copy(x.begin(), x.end(), x_work.begin());
+ std::copy(b.begin(), b.end(), x_work.begin()); // reload x_work from b; x_work is the buffer the interface calls modify in place
}
inline void calculate() {
@@ -65,15 +69,15 @@ public:
initialize();
calculate();
Interface::TriMatrixVector('U', _size, &A[0], &x_work[0]);
- Interface::axpy(_size, -1., &x[0], &x_work[0]);
+ Interface::axpy(_size, -1., &b[0], &x_work[0]);
return Interface::norm(_size, &x_work[0]);
}
-private:
+//private: NOTE(review): with this specifier commented out the members below keep the preceding access level -- confirm this is intentional and not a debugging leftover
const int _size;
LinearCongruential<> lc;
- const vector_t A, x;
+ vector_t A, b; // no longer const: the constructor modifies the diagonal of A
vector_t x_work;
};