-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
92 additions
and
72 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
; Library stanza for the cudajit inline tests; the test code is preprocessed with ppx_expect.
(library | ||
(name cudajit_test_no_device) | ||
(inline_tests) | ||
; str is listed explicitly because the test code calls Str.regexp and Str.global_replace directly;
; relying on another library to pull it in transitively breaks under (implicit_transitive_deps false).
(libraries dynlink cudajit str) | ||
(preprocess | ||
(pps ppx_expect)) | ||
(modes native)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
(* If this test fails, verify your CUDA and NVRTC installation by working through the NVRTC SAXPY example: | ||
https://docs.nvidia.com/cuda/nvrtc/index.html#code-saxpy-cpp | ||
and check where the OCaml version diverges from it. *) | ||
|
||
(** CUDA C source of the [saxpy] kernel, kept as a quoted string literal.
    Each thread computes [tid = blockIdx.x * blockDim.x + threadIdx.x] and, when [tid < n],
    stores [a * x[tid] + y[tid]] into [out[tid]].
    The expect test in this file passes this string as [~cu_src] to [Cudajit.compile_to_ptx].
    NOTE(review): the trailing "| ||" on each line looks like an artifact of a rendered diff
    table rather than part of the kernel text - confirm against the real file. *)
let kernel = | ||
{| | ||
extern "C" __global__ void saxpy(float a, float *x, float *y, float *out, size_t n) { | ||
size_t tid = blockIdx.x * blockDim.x + threadIdx.x; | ||
if (tid < n) { | ||
out[tid] = a * x[tid] + y[tid]; | ||
} | ||
} | ||
|} | ||
|
||
(* Expect test that compiles [kernel] to PTX and checks what gets printed.
   Steps, as written below:
   - call [Cudajit.compile_to_ptx] with [~name:"saxpy"], the single option "--use_fast_math"
     and [~with_debug:true];
   - if [prog.log] is [Some log], print it under the label "CUDA Compile log:";
   - print [Cudajit.string_from_ptx prog] after rewriting version-like substrings to "NNN".
   NOTE(review): the trailing "| ||" on each line, and the lines holding only "|" or "||",
   look like artifacts of a rendered diff table rather than program text - confirm against
   the real file before trusting the exact expected output. *)
let%expect_test "SAXPY compilation" = | ||
let prog = | ||
Cudajit.compile_to_ptx ~cu_src:kernel ~name:"saxpy" ~options:[ "--use_fast_math" ] ~with_debug:true | ||
in | ||
(match prog.log with None -> () | Some log -> Format.printf "\nCUDA Compile log: %s\n%!" log); | ||
(* Only the label is expected here, i.e. a log that is present but has no visible text -
   TODO confirm this is what the compiler returns when with_debug is true. *)
[%expect {| CUDA Compile log: |}]; | ||
(* The regexp matches "CL-<digits>", "release <a>.<b>", "V<a>.<b>.<c>" and "NVVM <a>.<b>.<c>";
   every match is replaced by "NNN", which is what the header lines of the expected PTX contain.
   NOTE(review): ".version 8.3" and ".target sm_52" are not matched by this regexp, so they are
   compared literally - confirm that pinning them is intended. *)
Format.printf "PTX: %s%!" | ||
@@ Str.global_replace | ||
(Str.regexp {|CL-[0-9]+\|release [0-9]+\.[0-9]+\|V[0-9]+\.[0-9]+\.[0-9]+\|NVVM [0-9]+\.[0-9]+\.[0-9]+|}) | ||
"NNN" | ||
@@ Cudajit.string_from_ptx prog; | ||
[%expect | ||
{| | ||
PTX: // | ||
// Generated by NVIDIA NVVM Compiler | ||
// | ||
// Compiler Build ID: NNN | ||
// Cuda compilation tools, NNN, NNN | ||
// Based on NNN | ||
// | ||
|
||
.version 8.3 | ||
.target sm_52 | ||
.address_size 64 | ||
|
||
// .globl saxpy | ||
|
||
.visible .entry saxpy( | ||
.param .f32 saxpy_param_0, | ||
.param .u64 saxpy_param_1, | ||
.param .u64 saxpy_param_2, | ||
.param .u64 saxpy_param_3, | ||
.param .u64 saxpy_param_4 | ||
) | ||
{ | ||
.reg .pred %p<2>; | ||
.reg .f32 %f<5>; | ||
.reg .b32 %r<5>; | ||
.reg .b64 %rd<13>; | ||
|
||
|
||
ld.param.f32 %f1, [saxpy_param_0]; | ||
ld.param.u64 %rd2, [saxpy_param_1]; | ||
ld.param.u64 %rd3, [saxpy_param_2]; | ||
ld.param.u64 %rd4, [saxpy_param_3]; | ||
ld.param.u64 %rd5, [saxpy_param_4]; | ||
mov.u32 %r1, %ctaid.x; | ||
mov.u32 %r2, %ntid.x; | ||
mov.u32 %r3, %tid.x; | ||
mad.lo.s32 %r4, %r1, %r2, %r3; | ||
cvt.u64.u32 %rd1, %r4; | ||
setp.ge.u64 %p1, %rd1, %rd5; | ||
@%p1 bra $L__BB0_2; | ||
|
||
cvta.to.global.u64 %rd6, %rd2; | ||
shl.b64 %rd7, %rd1, 2; | ||
add.s64 %rd8, %rd6, %rd7; | ||
ld.global.f32 %f2, [%rd8]; | ||
cvta.to.global.u64 %rd9, %rd3; | ||
add.s64 %rd10, %rd9, %rd7; | ||
ld.global.f32 %f3, [%rd10]; | ||
fma.rn.ftz.f32 %f4, %f2, %f1, %f3; | ||
cvta.to.global.u64 %rd11, %rd4; | ||
add.s64 %rd12, %rd11, %rd7; | ||
st.global.f32 [%rd12], %f4; | ||
|
||
$L__BB0_2: | ||
ret; | ||
|
||
} |}] |