diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2a5adde..9ca3df6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -36,4 +36,4 @@ jobs: - run: opam depext -yt cudajit - run: opam install . --deps-only --with-test --with-doc - run: opam exec -- dune build - - run: opam exec -- dune runtest + - run: opam exec -- dune test test_no_device diff --git a/test/saxpy.ml b/test/saxpy.ml index 0fcbb39..da6f6ef 100644 --- a/test/saxpy.ml +++ b/test/saxpy.ml @@ -12,77 +12,6 @@ extern "C" __global__ void saxpy(float a, float *x, float *y, float *out, size_t } |} -let%expect_test "SAXPY compilation" = - let prog = - Cudajit.compile_to_ptx ~cu_src:kernel ~name:"saxpy" ~options:[ "--use_fast_math" ] ~with_debug:true - in - (match prog.log with None -> () | Some log -> Format.printf "\nCUDA Compile log: %s\n%!" log); - [%expect {| CUDA Compile log: |}]; - Format.printf "PTX: %s%!" - @@ Str.global_replace - (Str.regexp {|CL-[0-9]+\|release [0-9]+\.[0-9]+\|V[0-9]+\.[0-9]+\.[0-9]+\|NVVM [0-9]+\.[0-9]+\.[0-9]+|}) - "NNN" - @@ Cudajit.string_from_ptx prog; - [%expect - {| - PTX: // - // Generated by NVIDIA NVVM Compiler - // - // Compiler Build ID: NNN - // Cuda compilation tools, NNN, NNN - // Based on NNN - // - - .version 8.3 - .target sm_52 - .address_size 64 - - // .globl saxpy - - .visible .entry saxpy( - .param .f32 saxpy_param_0, - .param .u64 saxpy_param_1, - .param .u64 saxpy_param_2, - .param .u64 saxpy_param_3, - .param .u64 saxpy_param_4 - ) - { - .reg .pred %p<2>; - .reg .f32 %f<5>; - .reg .b32 %r<5>; - .reg .b64 %rd<13>; - - - ld.param.f32 %f1, [saxpy_param_0]; - ld.param.u64 %rd2, [saxpy_param_1]; - ld.param.u64 %rd3, [saxpy_param_2]; - ld.param.u64 %rd4, [saxpy_param_3]; - ld.param.u64 %rd5, [saxpy_param_4]; - mov.u32 %r1, %ctaid.x; - mov.u32 %r2, %ntid.x; - mov.u32 %r3, %tid.x; - mad.lo.s32 %r4, %r1, %r2, %r3; - cvt.u64.u32 %rd1, %r4; - setp.ge.u64 %p1, %rd1, %rd5; - @%p1 bra $L__BB0_2; - - cvta.to.global.u64 %rd6, %rd2; - shl.b64 %rd7, %rd1, 2; - add.s64 %rd8, %rd6, %rd7; - ld.global.f32 %f2, [%rd8]; - cvta.to.global.u64 %rd9, %rd3; - add.s64 %rd10, %rd9, %rd7; - ld.global.f32 %f3, [%rd10]; - fma.rn.ftz.f32 %f4, %f2, %f1, %f3; - cvta.to.global.u64 %rd11, %rd4; - add.s64 %rd12, %rd11, %rd7; - st.global.f32 [%rd12], %f4; - - $L__BB0_2: - ret; - - } |}] - let%expect_test "SAXPY" = let num_blocks = 32 in let num_threads = 128 in diff --git a/test_no_device/dune b/test_no_device/dune new file mode 100644 index 0000000..8f3944e --- /dev/null +++ b/test_no_device/dune @@ -0,0 +1,7 @@ +(library + (name cudajit_test_no_device) + (inline_tests) + (libraries dynlink cudajit) + (preprocess + (pps ppx_expect)) + (modes native)) diff --git a/test_no_device/saxpy_ptx.ml b/test_no_device/saxpy_ptx.ml new file mode 100644 index 0000000..6304d62 --- /dev/null +++ b/test_no_device/saxpy_ptx.ml @@ -0,0 +1,84 @@ +(* If the test fails, to verify your CUDA and NVRTC installation, follow the following instructions: + https://docs.nvidia.com/cuda/nvrtc/index.html#code-saxpy-cpp + and see where the OCaml version diverges. *) + +let kernel = + {| +extern "C" __global__ void saxpy(float a, float *x, float *y, float *out, size_t n) { + size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) { + out[tid] = a * x[tid] + y[tid]; + } +} +|} + +let%expect_test "SAXPY compilation" = + let prog = + Cudajit.compile_to_ptx ~cu_src:kernel ~name:"saxpy" ~options:[ "--use_fast_math" ] ~with_debug:true + in + (match prog.log with None -> () | Some log -> Format.printf "\nCUDA Compile log: %s\n%!" log); + [%expect {| CUDA Compile log: |}]; + Format.printf "PTX: %s%!" + @@ Str.global_replace + (Str.regexp {|CL-[0-9]+\|release [0-9]+\.[0-9]+\|V[0-9]+\.[0-9]+\.[0-9]+\|NVVM [0-9]+\.[0-9]+\.[0-9]+|}) + "NNN" + @@ Cudajit.string_from_ptx prog; + [%expect + {| + PTX: // + // Generated by NVIDIA NVVM Compiler + // + // Compiler Build ID: NNN + // Cuda compilation tools, NNN, NNN + // Based on NNN + // + + .version 8.3 + .target sm_52 + .address_size 64 + + // .globl saxpy + + .visible .entry saxpy( + .param .f32 saxpy_param_0, + .param .u64 saxpy_param_1, + .param .u64 saxpy_param_2, + .param .u64 saxpy_param_3, + .param .u64 saxpy_param_4 + ) + { + .reg .pred %p<2>; + .reg .f32 %f<5>; + .reg .b32 %r<5>; + .reg .b64 %rd<13>; + + + ld.param.f32 %f1, [saxpy_param_0]; + ld.param.u64 %rd2, [saxpy_param_1]; + ld.param.u64 %rd3, [saxpy_param_2]; + ld.param.u64 %rd4, [saxpy_param_3]; + ld.param.u64 %rd5, [saxpy_param_4]; + mov.u32 %r1, %ctaid.x; + mov.u32 %r2, %ntid.x; + mov.u32 %r3, %tid.x; + mad.lo.s32 %r4, %r1, %r2, %r3; + cvt.u64.u32 %rd1, %r4; + setp.ge.u64 %p1, %rd1, %rd5; + @%p1 bra $L__BB0_2; + + cvta.to.global.u64 %rd6, %rd2; + shl.b64 %rd7, %rd1, 2; + add.s64 %rd8, %rd6, %rd7; + ld.global.f32 %f2, [%rd8]; + cvta.to.global.u64 %rd9, %rd3; + add.s64 %rd10, %rd9, %rd7; + ld.global.f32 %f3, [%rd10]; + fma.rn.ftz.f32 %f4, %f2, %f1, %f3; + cvta.to.global.u64 %rd11, %rd4; + add.s64 %rd12, %rd11, %rd7; + st.global.f32 [%rd12], %f4; + + $L__BB0_2: + ret; + + } |}]