Skip to content

Commit

Permalink
Clean up the rocBLAS API a bit. (#112)
Browse files Browse the repository at this point in the history
  • Loading branch information
benson31 authored Oct 7, 2020
1 parent 8e830ea commit 9fab17c
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 6 deletions.
3 changes: 3 additions & 0 deletions include/hydrogen/device/gpu/rocm/rocBLASMeta.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ template <>
struct IsSupportedType_Base<rocblas_half, BLAS_Op::AXPY> : std::true_type {};
template <>
struct IsSupportedType_Base<rocblas_half, BLAS_Op::GEMM> : std::true_type {};
/** Marks rocblas_half as a supported type for the GEMMSTRIDEDBATCHED
 *  operation by deriving the specialization from std::true_type.
 */
template <>
struct IsSupportedType_Base<rocblas_half, BLAS_Op::GEMMSTRIDEDBATCHED>
: std::true_type {};
#endif // HYDROGEN_GPU_USE_FP16

/** @class IsSupportedType
Expand Down
40 changes: 34 additions & 6 deletions include/hydrogen/device/gpu/rocm/rocBLAS_API.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ namespace rocblas
rocblas_int n, ScalarType const* X, rocblas_int incx, \
ScalarType* Y, rocblas_int incy)

#define ADD_DOT_DECL(ScalarType) \
void Dot(rocblasHandle_t handle, \
rocblas_int n, \
ScalarType const* X, rocblas_int incx, \
ScalarType const* Y, rocblas_int incy, \
#define ADD_DOT_DECL(ScalarType) \
void Dot(rocblas_handle handle, \
rocblas_int n, \
ScalarType const* X, rocblas_int incx, \
ScalarType const* Y, rocblas_int incy, \
ScalarType* output)

#define ADD_NRM2_DECL(ScalarType) \
void Nrm2(rocblasHandle_t handle, \
void Nrm2(rocblas_handle handle, \
rocblas_int n, \
ScalarType const* X, rocblas_int incx, \
ScalarType* output)
Expand All @@ -51,6 +51,12 @@ ADD_AXPY_DECL(double);
ADD_COPY_DECL(float);
ADD_COPY_DECL(double);

// Declare the Dot and Nrm2 overloads for the float and double scalar types.
ADD_DOT_DECL(float);
ADD_DOT_DECL(double);

ADD_NRM2_DECL(float);
ADD_NRM2_DECL(double);

#ifdef HYDROGEN_GPU_USE_FP16
ADD_SCALE_DECL(rocblas_half);
#endif // HYDROGEN_GPU_USE_FP16
Expand Down Expand Up @@ -120,6 +126,22 @@ ADD_GEMV_DECL(double);
ScalarType const& beta, \
ScalarType* C, rocblas_int ldc)

/** @brief Declare the GemmStridedBatched overload for @c ScalarType.
 *
 *  Expands to a function declaration only; the trailing semicolon is
 *  supplied at the point of use. The scalars @c alpha and @c beta are
 *  taken by pointer. Each of the operands A, B, and C is described by
 *  a base pointer, a leading dimension, and a rocblas_stride, followed
 *  by a single @c batchCount for the whole call.
 */
#define ADD_GEMM_STRIDED_BATCHED_DECL(ScalarType) \
    void GemmStridedBatched(                      \
        rocblas_handle handle,                    \
        rocblas_operation transpA,                \
        rocblas_operation transpB,                \
        rocblas_int m,                            \
        rocblas_int n,                            \
        rocblas_int k,                            \
        ScalarType const* alpha,                  \
        ScalarType const* A,                      \
        rocblas_int lda,                          \
        rocblas_stride strideA,                   \
        ScalarType const* B,                      \
        rocblas_int ldb,                          \
        rocblas_stride strideB,                   \
        ScalarType const* beta,                   \
        ScalarType* C,                            \
        rocblas_int ldc,                          \
        rocblas_stride strideC,                   \
        rocblas_int batchCount)

ADD_HERK_DECL(rocblas_float_complex, float);
ADD_HERK_DECL(rocblas_double_complex, double);

Expand All @@ -139,6 +161,12 @@ ADD_GEMM_DECL(rocblas_half);
ADD_GEMM_DECL(float);
ADD_GEMM_DECL(double);

// Declare the GemmStridedBatched overloads. The rocblas_half overload is
// only declared when HYDROGEN_GPU_USE_FP16 is defined; float and double are
// always declared.
#ifdef HYDROGEN_GPU_USE_FP16
ADD_GEMM_STRIDED_BATCHED_DECL(rocblas_half);
#endif // HYDROGEN_GPU_USE_FP16
ADD_GEMM_STRIDED_BATCHED_DECL(float);
ADD_GEMM_STRIDED_BATCHED_DECL(double);

///@}
/** @name BLAS-like Extension Routines */
///@{
Expand Down

0 comments on commit 9fab17c

Please sign in to comment.