summaryrefslogtreecommitdiff
blob: 859667f3da306ae21ea7b3372a3d0e1cba3222d9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
Fix uninitialized variable in MIOpenDriver gemm and restore gemmfp16 for testing
Upstream bug: https://github.com/ROCmSoftwarePlatform/MIOpen/issues/2505
--- a/driver/driver.hpp
+++ b/driver/driver.hpp
@@ -141,7 +141,7 @@ inline void PadBufferSize(size_t& sz, int datatype_sz)
     printf("Usage: ./driver *base_arg* *other_args*\n");
     printf("Supported Base Arguments: conv[fp16|int8|bfp16], CBAInfer[fp16], "
            "pool[fp16], lrn[fp16], "
-           "activ[fp16], softmax[fp16], bnorm[fp16], rnn[fp16], gemm, ctc, dropout[fp16], "
+           "activ[fp16], softmax[fp16], bnorm[fp16], rnn[fp16], gemm[fp16], ctc, dropout[fp16], "
            "tensorop[fp16], reduce[fp16,fp64]\n");
     exit(0); // NOLINT (concurrency-mt-unsafe)
 }
@@ -160,7 +160,7 @@ inline std::string ParseBaseArg(int argc, char* argv[])
        arg != "CBAInfer" && arg != "CBAInferfp16" && arg != "pool" && arg != "poolfp16" &&
        arg != "lrn" && arg != "lrnfp16" && arg != "activ" && arg != "activfp16" &&
        arg != "softmax" && arg != "softmaxfp16" && arg != "bnorm" && arg != "bnormfp16" &&
-       arg != "rnn" && arg != "rnnfp16" && arg != "gemm" /*&& arg != "gemmfp16"*/ && arg != "ctc" &&
+       arg != "rnn" && arg != "rnnfp16" && arg != "gemm" && arg != "gemmfp16" && arg != "ctc" &&
        arg != "dropout" && arg != "dropoutfp16" && arg != "tensorop" && arg != "tensoropfp16" &&
        arg != "reduce" && arg != "reducefp16" && arg != "reducefp64" && arg != "--version")
     {
--- a/driver/gemm_driver.hpp
+++ b/driver/gemm_driver.hpp
@@ -207,6 +207,19 @@ int GemmDriver<T>::GetandSetData()
     gemm_desc.strideB = gemm_desc.k * gemm_desc.n;
     gemm_desc.strideC = gemm_desc.m * gemm_desc.n;
 
+    if constexpr (std::is_same_v<T, float>)
+    {
+        gemm_desc.dataType = miopenFloat;
+    }
+    else if constexpr (std::is_same_v<T, float16>)
+    {
+        gemm_desc.dataType = miopenHalf;
+    }
+    else
+    {
+        static_assert(!"unsupported type");
+    }
+
     return (0);
 }
 
@@ -230,9 +243,9 @@ int GemmDriver<T>::AllocateBuffersAndCopy()
     a = std::vector<T>(a_sz);
     b = std::vector<T>(b_sz);
 #if GEMM_DRIVER_DEBUG
-    c = std::vector<T>(c_sz, 1.);
+    c = std::vector<T>(c_sz, static_cast<T>(1.));
 #else
-    c                 = std::vector<T>(c_sz, 0.);
+    c = std::vector<T>(c_sz, static_cast<T>(0.));
 #endif
     chost = c;
 
--- a/driver/main.cpp
+++ b/driver/main.cpp
@@ -125,11 +125,10 @@ int main(int argc, char* argv[])
     {
         drv = new GemmDriver<float>();
     }
-// TODO half is not supported in gemm
-//    else if(base_arg == "gemmfp16")
-//    {
-//        drv = new GemmDriver<float16>();
-//    }
+    else if(base_arg == "gemmfp16")
+    {
+        drv = new GemmDriver<float16>();
+    }
 #endif
     else if(base_arg == "bnorm")
     {