Skip to content

Commit

Permalink
Only test PTX compilation on GitHub
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed May 9, 2024
1 parent faa525d commit b9330c7
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 72 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,4 @@ jobs:
- run: opam depext -yt cudajit
- run: opam install . --deps-only --with-test --with-doc
- run: opam exec -- dune build
- run: opam exec -- dune runtest
- run: opam exec -- dune test test_no_device
71 changes: 0 additions & 71 deletions test/saxpy.ml
Original file line number Diff line number Diff line change
Expand Up @@ -12,77 +12,6 @@ extern "C" __global__ void saxpy(float a, float *x, float *y, float *out, size_t
}
|}

let%expect_test "SAXPY compilation" =
let prog =
Cudajit.compile_to_ptx ~cu_src:kernel ~name:"saxpy" ~options:[ "--use_fast_math" ] ~with_debug:true
in
(match prog.log with None -> () | Some log -> Format.printf "\nCUDA Compile log: %s\n%!" log);
[%expect {| CUDA Compile log: |}];
Format.printf "PTX: %s%!"
@@ Str.global_replace
(Str.regexp {|CL-[0-9]+\|release [0-9]+\.[0-9]+\|V[0-9]+\.[0-9]+\.[0-9]+\|NVVM [0-9]+\.[0-9]+\.[0-9]+|})
"NNN"
@@ Cudajit.string_from_ptx prog;
[%expect
{|
PTX: //
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: NNN
// Cuda compilation tools, NNN, NNN
// Based on NNN
//

.version 8.3
.target sm_52
.address_size 64

// .globl saxpy

.visible .entry saxpy(
.param .f32 saxpy_param_0,
.param .u64 saxpy_param_1,
.param .u64 saxpy_param_2,
.param .u64 saxpy_param_3,
.param .u64 saxpy_param_4
)
{
.reg .pred %p<2>;
.reg .f32 %f<5>;
.reg .b32 %r<5>;
.reg .b64 %rd<13>;


ld.param.f32 %f1, [saxpy_param_0];
ld.param.u64 %rd2, [saxpy_param_1];
ld.param.u64 %rd3, [saxpy_param_2];
ld.param.u64 %rd4, [saxpy_param_3];
ld.param.u64 %rd5, [saxpy_param_4];
mov.u32 %r1, %ctaid.x;
mov.u32 %r2, %ntid.x;
mov.u32 %r3, %tid.x;
mad.lo.s32 %r4, %r1, %r2, %r3;
cvt.u64.u32 %rd1, %r4;
setp.ge.u64 %p1, %rd1, %rd5;
@%p1 bra $L__BB0_2;

cvta.to.global.u64 %rd6, %rd2;
shl.b64 %rd7, %rd1, 2;
add.s64 %rd8, %rd6, %rd7;
ld.global.f32 %f2, [%rd8];
cvta.to.global.u64 %rd9, %rd3;
add.s64 %rd10, %rd9, %rd7;
ld.global.f32 %f3, [%rd10];
fma.rn.ftz.f32 %f4, %f2, %f1, %f3;
cvta.to.global.u64 %rd11, %rd4;
add.s64 %rd12, %rd11, %rd7;
st.global.f32 [%rd12], %f4;

$L__BB0_2:
ret;

} |}]

let%expect_test "SAXPY" =
let num_blocks = 32 in
let num_threads = 128 in
Expand Down
7 changes: 7 additions & 0 deletions test_no_device/dune
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
(library
(name cudajit_test_no_device)
(inline_tests)
(libraries dynlink cudajit)
(preprocess
(pps ppx_expect))
(modes native))
84 changes: 84 additions & 0 deletions test_no_device/saxpy_ptx.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
(* If the test fails, to verify your CUDA and NVRTC installation, follow the following instructions:
https://docs.nvidia.com/cuda/nvrtc/index.html#code-saxpy-cpp
and see where the OCaml version diverges. *)

let kernel =
{|
extern "C" __global__ void saxpy(float a, float *x, float *y, float *out, size_t n) {
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < n) {
out[tid] = a * x[tid] + y[tid];
}
}
|}

let%expect_test "SAXPY compilation" =
let prog =
Cudajit.compile_to_ptx ~cu_src:kernel ~name:"saxpy" ~options:[ "--use_fast_math" ] ~with_debug:true
in
(match prog.log with None -> () | Some log -> Format.printf "\nCUDA Compile log: %s\n%!" log);
[%expect {| CUDA Compile log: |}];
Format.printf "PTX: %s%!"
@@ Str.global_replace
(Str.regexp {|CL-[0-9]+\|release [0-9]+\.[0-9]+\|V[0-9]+\.[0-9]+\.[0-9]+\|NVVM [0-9]+\.[0-9]+\.[0-9]+|})
"NNN"
@@ Cudajit.string_from_ptx prog;
[%expect
{|
PTX: //
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: NNN
// Cuda compilation tools, NNN, NNN
// Based on NNN
//

.version 8.3
.target sm_52
.address_size 64

// .globl saxpy

.visible .entry saxpy(
.param .f32 saxpy_param_0,
.param .u64 saxpy_param_1,
.param .u64 saxpy_param_2,
.param .u64 saxpy_param_3,
.param .u64 saxpy_param_4
)
{
.reg .pred %p<2>;
.reg .f32 %f<5>;
.reg .b32 %r<5>;
.reg .b64 %rd<13>;


ld.param.f32 %f1, [saxpy_param_0];
ld.param.u64 %rd2, [saxpy_param_1];
ld.param.u64 %rd3, [saxpy_param_2];
ld.param.u64 %rd4, [saxpy_param_3];
ld.param.u64 %rd5, [saxpy_param_4];
mov.u32 %r1, %ctaid.x;
mov.u32 %r2, %ntid.x;
mov.u32 %r3, %tid.x;
mad.lo.s32 %r4, %r1, %r2, %r3;
cvt.u64.u32 %rd1, %r4;
setp.ge.u64 %p1, %rd1, %rd5;
@%p1 bra $L__BB0_2;

cvta.to.global.u64 %rd6, %rd2;
shl.b64 %rd7, %rd1, 2;
add.s64 %rd8, %rd6, %rd7;
ld.global.f32 %f2, [%rd8];
cvta.to.global.u64 %rd9, %rd3;
add.s64 %rd10, %rd9, %rd7;
ld.global.f32 %f3, [%rd10];
fma.rn.ftz.f32 %f4, %f2, %f1, %f3;
cvta.to.global.u64 %rd11, %rd4;
add.s64 %rd12, %rd11, %rd7;
st.global.f32 [%rd12], %f4;

$L__BB0_2:
ret;

} |}]

0 comments on commit b9330c7

Please sign in to comment.