Skip to content

Commit

Permalink
Deviceptr.equal and Deviceptr.hash
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Dec 16, 2024
1 parent f9d02b3 commit 0ad175a
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 15 deletions.
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

- Documentation for `Deviceptr.mem_free`: mention it's safe to call multiple times on the same pointer.

### Added

- `Deviceptr.equal` and `Deviceptr.hash`.

## [0.6.1] 2024-12-02

### Added
Expand Down
8 changes: 6 additions & 2 deletions cudajit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1387,6 +1387,11 @@ let get_size_in_bytes ?kind ?length ?size_in_bytes provenance =
module Deviceptr = struct
type t = deviceptr [@@deriving sexp_of]

let equal (Deviceptr { ptr = ptr1; freed = _ }) (Deviceptr { ptr = ptr2; freed = _ }) =
Unsigned.UInt64.equal ptr1 ptr2

let hash (Deviceptr { ptr; freed = _ }) = Unsigned.UInt64.to_int ptr

let string_of (Deviceptr { ptr; freed }) =
let addr = string_of_memptr ptr in
if Atomic.get freed then addr ^ "/FREED" else addr
Expand Down Expand Up @@ -1693,8 +1698,7 @@ module Stream = struct
| Double of float
[@@deriving sexp_of]

let no_stream =
{ args_lifetimes = []; owned_events = []; stream = Ctypes.(coerce (ptr void) cu_stream null) }
let no_stream = no_stream

let total_unreleased_unfinished_delimited_events stream =
List.fold_left
Expand Down
7 changes: 7 additions & 0 deletions cudajit.mli
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,13 @@ module Deviceptr : sig
{{:https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1g183f7b0d8ad008ea2a5fd552537ace4e}
CUdeviceptr}. *)

val equal : t -> t -> bool
(** Compares the pointer values for equality. *)

val hash : t -> int
(** Converts the pointer to an OCaml int using {!Unsigned.UInt64.to_int} (truncating bits as
needed). *)

val string_of : t -> string
(** Hexadecimal representation of the pointer. *)

Expand Down
31 changes: 18 additions & 13 deletions test/saxpy.ml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,24 @@ let%expect_test "SAXPY" =
let dY = Cu.Deviceptr.alloc_and_memcpy hY in
let hOut = Host.create Bigarray.Float32 Bigarray.C_layout [| size |] in
let dOut = Cu.Deviceptr.alloc_and_memcpy hOut in
let eq = Cu.Deviceptr.equal in
Printf.printf "dX = dX %b; dX = dY %b; dY = dOut %b.\n" (eq dX dX) (eq dX dY) (eq dY dOut);
[%expect
{|
cu_init
cu_device_get_count
cu_device_get
cu_ctx_create
cu_ctx_get_current
cu_module_load_data_ex
cu_module_get_function
cu_mem_alloc
cu_memcpy_H_to_D
cu_mem_alloc
cu_memcpy_H_to_D
cu_mem_alloc
cu_memcpy_H_to_D
dX = dX true; dX = dY false; dY = dOut false.|}];
Cu.Stream.launch_kernel kernel ~grid_dim_x:num_blocks ~block_dim_x:num_threads
~shared_mem_bytes:0 Cu.Stream.no_stream
[
Expand All @@ -64,19 +82,6 @@ let%expect_test "SAXPY" =
Gc.full_major ();
[%expect
{|
cu_init
cu_device_get_count
cu_device_get
cu_ctx_create
cu_ctx_get_current
cu_module_load_data_ex
cu_module_get_function
cu_mem_alloc
cu_memcpy_H_to_D
cu_mem_alloc
cu_memcpy_H_to_D
cu_mem_alloc
cu_memcpy_H_to_D
cu_launch_kernel
cu_ctx_synchronize
cu_memcpy_D_to_H
Expand Down

0 comments on commit 0ad175a

Please sign in to comment.