Commit df1d0e9: Don't always compile fp8

EricLBuehler committed Nov 20, 2024 (1 parent: cb8082b)

Showing 8 changed files with 13 additions and 1 deletion.
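In short: compatibility.cuh, which the kernel sources share, previously included cuda_fp8.h unconditionally, even though every fp8 kernel already sits behind an "#if __CUDA_ARCH__ >= 800" guard. This commit deletes that unconditional include and adds it inside each file's guarded block instead, so cuda_fp8.h is only pulled in when fp8 code is actually being compiled. A minimal sketch of the resulting pattern (the fill_f8_e4m3 kernel is illustrative, not a kernel from the repository):

// __CUDA_ARCH__ is defined only while nvcc compiles device code for a
// concrete target, and an undefined identifier evaluates to 0 inside #if,
// so this block is skipped both on the host pass and for any target below
// sm_80; builds that never target an fp8-capable GPU never see cuda_fp8.h.
#if __CUDA_ARCH__ >= 800
#include "cuda_fp8.h"

// Hypothetical fp8 kernel, compiled only for sm_80 and newer.
extern "C" __global__ void fill_f8_e4m3(__nv_fp8_e4m3 *out, const size_t n,
                                        const float v) {
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
         i += blockDim.x * gridDim.x) {
        out[i] = __nv_fp8_e4m3(v);  // narrow the float to fp8 e4m3
    }
}
#endif

Because the host pass never sees the guarded declarations, such kernels cannot be launched with the <<<...>>> syntax from the same translation unit; that is no obstacle here, since candle loads these kernels from compiled PTX by their extern "C" names at runtime rather than linking them into host code.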
candle-kernels/src/affine.cu (2 additions, 0 deletions)

@@ -29,6 +29,8 @@ extern "C" __global__ void FN_NAME( \
 } \
 
 #if __CUDA_ARCH__ >= 800
+#include "cuda_fp8.h"
+
 AFFINE_OP(__nv_bfloat16, affine_bf16, x * mul + add)
 
 #define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3))
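The F8E4M3_TO_FLOAT macro visible in this hunk widens an fp8 value to float by converting its raw storage byte (x.__x) to half and then to float. A sketch of how an elementwise affine kernel can use it, assuming contiguous data; the name affine_f8_e4m3_sketch and the flat signature are illustrative, not taken from the repository (whose kernels also handle strided layouts):

// Belongs inside the same "#if __CUDA_ARCH__ >= 800" block as the macro.
// Computes y = x * mul + add elementwise: widen to float, transform,
// then narrow back with the __nv_fp8_e4m3 float constructor.
extern "C" __global__ void affine_f8_e4m3_sketch(
    const size_t n, const __nv_fp8_e4m3 *x, __nv_fp8_e4m3 *y,
    const float mul, const float add) {
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
         i += blockDim.x * gridDim.x) {
        y[i] = __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x[i]) * mul + add);
    }
}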
candle-kernels/src/binary.cu (2 additions, 0 deletions)

@@ -2,6 +2,8 @@
 #include<stdint.h>
 
 #if __CUDA_ARCH__ >= 800
+#include "cuda_fp8.h"
+
 BINARY_OP(__nv_bfloat16, badd_bf16, x + y)
 BINARY_OP(__nv_bfloat16, bdiv_bf16, x / y)
 BINARY_OP(__nv_bfloat16, bmul_bf16, x * y)
candle-kernels/src/cast.cu (1 addition, 0 deletions)

@@ -142,6 +142,7 @@ extern "C" __global__ void FN_NAME( \
 } \
 
 #if __CUDA_ARCH__ >= 800
+#include "cuda_fp8.h"
 CAST_OP(__nv_bfloat16, __nv_bfloat16, cast_bf16_bf16)
 CAST_OP(__nv_fp8_e4m3, __nv_fp8_e4m3, cast_f8_e4m3_f8_e4m3)
 
candle-kernels/src/compatibility.cuh (0 additions, 1 deletion)

@@ -1,6 +1,5 @@
 #include "cuda_fp16.h"
 #include "cuda_bf16.h"
-#include "cuda_fp8.h"
 
 // Table showing which features are supported on which compute capability
 // https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications
candle-kernels/src/indexing.cu (2 additions, 0 deletions)

@@ -221,6 +221,8 @@ extern "C" __global__ void FN_NAME( \
 
 
 #if __CUDA_ARCH__ >= 800
+#include "cuda_fp8.h"
+
 IS_OP(__nv_bfloat16, int16_t, is_i16_bf16)
 IS_OP(__nv_bfloat16, int32_t, is_i32_bf16)
 IS_OP(__nv_bfloat16, int64_t, is_i64_bf16)
candle-kernels/src/kvconcat.cu (2 additions, 0 deletions)

@@ -49,6 +49,8 @@ KVCONCAT_OP(__half, kvconcat_f16)
 #endif
 
 #if __CUDA_ARCH__ >= 800
+#include "cuda_fp8.h"
+
 KVCONCAT_OP(__nv_bfloat16, kvconcat_bf16)
 KVCONCAT_OP(__nv_fp8_e4m3, kvconcat_f8_e4m3)
 #endif
candle-kernels/src/ternary.cu (2 additions, 0 deletions)

@@ -33,6 +33,8 @@ extern "C" __global__ void FN_NAME( \
 } \
 
 #if __CUDA_ARCH__ >= 800
+#include "cuda_fp8.h"
+
 WHERE_OP(__nv_bfloat16, int16_t, where_i16_bf16)
 WHERE_OP(__nv_bfloat16, int32_t, where_i32_bf16)
 WHERE_OP(__nv_bfloat16, int64_t, where_i64_bf16)
candle-kernels/src/unary.cu (2 additions, 0 deletions)

@@ -98,6 +98,8 @@ __device__ T sign_(T t) {
 
 
 #if __CUDA_ARCH__ >= 800
+#include "cuda_fp8.h"
+
 UNARY_OP(__nv_bfloat16, ucopy_bf16, x)
 UNARY_OP(__nv_bfloat16, uneg_bf16, -x)
 UNARY_OP(__nv_bfloat16, urecip_bf16, recipg(x))
