ocaml-torch faces several challenges, including:
- binding to thousands of functions
- avoiding memory leaks, even minor ones, in these functions
- promptly freeing the memory of tensors once OCaml is done with them
To address these challenges, we use two steps of code generation. In this diagram, solid arrows represent the code-generation DAG and dashed arrows represent the code-dependency DAG:
At a high level:
- `Declarations.yaml` contains the function signatures for the whole Torch C++ API.
- Custom binding generation reads all the declarations and, whenever possible, generates
  - glue code for crossing between C and C++ (the generated C/C++ API),
  - glue code for using the (yet to be generated) OCaml `foreign` functions in OCaml (the generated OCaml wrapper), and
  - `ctypes` bindings.
- Stub generation uses the `ctypes` library, reading the bindings and generating C and OCaml stubs. These are just glue code to handle the C/OCaml FFI. Note that some manually written C++ functions and bindings also get generated stubs.
  - There is also an extremely small number of manually written stubs (just 1 as of writing) for cases that ctypes cannot handle.
- A combination of the generated OCaml wrapper and the manually written wrapper provides an actually usable OCaml API. These are further built upon in the main library (not pictured).
A large part of this complexity is driven by memory management.
It is challenging to write manual FFI stubs without memory leaks or race conditions, so we use `ctypes` to get this right for the vast majority of functions. Although this requires a second code-generation step, it spares us from reinventing stub generation.
We ensure that tensors are freed when OCaml garbage collects them. To do this, each tensor is equipped with a custom finalizer. This could be done on either the C++ or the OCaml side. However, the API for telling OCaml a tensor's true size in memory (OCaml's custom block API) is only available from the C/C++ side. Without it, OCaml would not know when to garbage collect tensors on the CPU and would easily run out of memory.
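Why the reported size matters can be seen in a toy model. OCaml's C API lets a custom block declare how much off-heap memory it holds (for example, `caml_alloc_custom_mem` takes that size as an argument), and the runtime uses the figure to speed up major collections. The sketch below is NOT the runtime's real heuristic: `ToyGc`, its fixed byte budget, and `collections_for` are hypothetical, chosen only to show that a collector which never hears about off-heap bytes never feels pressure to run.

```cpp
#include <cassert>
#include <cstddef>

// Toy collector (NOT the OCaml runtime): schedules a "major GC" once the
// off-heap bytes reported since the last collection exceed a fixed budget.
struct ToyGc {
  std::size_t budget;       // hypothetical trigger threshold, in bytes
  std::size_t pending = 0;  // off-heap bytes reported since the last collection
  int collections = 0;      // number of collections triggered so far

  explicit ToyGc(std::size_t b) : budget(b) {}

  // Allocate a small heap block that claims to own `reported_bytes` off-heap.
  void alloc_block(std::size_t reported_bytes) {
    pending += reported_bytes;
    if (pending > budget) {  // enough external memory piled up: collect
      ++collections;
      pending = 0;
    }
  }
};

// Allocate `n` tensors of `bytes` each, either reporting their size or not.
int collections_for(int n, std::size_t bytes, bool report_size, std::size_t budget) {
  ToyGc gc(budget);
  for (int i = 0; i < n; ++i) gc.alloc_block(report_size ? bytes : 0);
  return gc.collections;
}
```

With 1 MiB tensors and a 10 MiB budget, `collections_for(100, 1 << 20, true, 10 << 20)` triggers a collection every 11 allocations (9 in total), while the same call with `report_size = false` triggers none: the heap blocks look tiny, so nothing prompts the GC to free the memory behind them.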
Each C++ `torch::Tensor` is essentially an intrusive pointer to a `TensorImpl` that stores the real information about the tensor. Each `TensorImpl` contains a reference count, and whenever a new intrusive pointer to it is created or destroyed, that reference count changes. We can't pass these `torch::Tensor`s directly to OCaml, though, so instead we work with `TensorImpl *`s and use the `release`/`reclaim(_copy)` Torch API for intrusive pointers.
The finalizer on each garbage-collected OCaml tensor just does a `reclaim`, so the refcount is decremented (to 0, if no other references remain) when the resulting `torch::Tensor` goes out of scope.
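The refcount bookkeeping can be sketched with a toy C++ model that mirrors the documented semantics of `release`, `reclaim`, and `reclaim_copy` on `c10::intrusive_ptr`: `release` gives up ownership without decrementing, `reclaim` adopts an owning raw pointer without incrementing, and `reclaim_copy` creates an additional owner. `ToyTensor`, `ToyTensorImpl`, and `toy_finalize` are hypothetical stand-ins, not Torch types.

```cpp
#include <cassert>

// Hypothetical stand-in for torch::TensorImpl: an intrusively refcounted object.
struct ToyTensorImpl {
  int refcount = 0;  // intrusive count (plain int: the toy is single-threaded)
  static int live;   // how many ToyTensorImpl objects currently exist
  ToyTensorImpl() { ++live; }
  ~ToyTensorImpl() { --live; }
};
int ToyTensorImpl::live = 0;

// Hypothetical stand-in for torch::Tensor: an owning intrusive pointer.
class ToyTensor {
 public:
  static ToyTensor make() { return ToyTensor(new ToyTensorImpl(), /*incref=*/true); }
  ToyTensor(const ToyTensor& o) : ToyTensor(o.p_, true) {}         // copy: +1
  ToyTensor(ToyTensor&& o) noexcept : p_(o.p_) { o.p_ = nullptr; }  // move: no change
  ToyTensor& operator=(const ToyTensor&) = delete;
  ~ToyTensor() {                                                    // destroy: -1
    if (p_ && --p_->refcount == 0) delete p_;
  }

  // release: hand out the raw pointer WITHOUT decrementing; this handle stops owning it.
  ToyTensorImpl* release() { ToyTensorImpl* p = p_; p_ = nullptr; return p; }
  // reclaim: adopt an owning raw pointer WITHOUT incrementing.
  static ToyTensor reclaim(ToyTensorImpl* p) { return ToyTensor(p, false); }
  // reclaim_copy: create an additional owner (+1); the raw pointer keeps its reference.
  static ToyTensor reclaim_copy(ToyTensorImpl* p) { return ToyTensor(p, true); }

  int use_count() const { return p_ ? p_->refcount : 0; }

 private:
  ToyTensor(ToyTensorImpl* p, bool incref) : p_(p) { if (p_ && incref) ++p_->refcount; }
  ToyTensorImpl* p_;
};

// What a finalizer does with a released pointer: reclaim it into a temporary
// owner whose destructor drops that one reference (freeing the impl at 0).
void toy_finalize(ToyTensorImpl* p) { ToyTensor owner = ToyTensor::reclaim(p); }
```

In this model, the pointer handed to OCaml carries exactly one reference. `reclaim_copy` fits cases where C++ needs a temporary `torch::Tensor` for a pointer OCaml still owns, while `reclaim` fits the finalizer, which must consume OCaml's reference exactly once.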
Note that OCaml is unaware of GPU memory usage, so GPU users may need to trigger garbage collection manually.
One wrinkle in this setup is that ctypes cannot handle custom blocks. Since we want the bulk of our stubs to be generated by ctypes, we distinguish between `raw_tensor`s and `gc_tensor`s.
| | raw tensor | GC tensor |
|---|---|---|
| has finalizer? | no | yes |
| GC knows its size? | no | yes |
| FFI input for C? | no | yes |
| FFI output from C? | yes | no |
| ctypes type | `void ptr` | `void ptr` |
| C++ type | `TensorImpl *` | `TensorImpl *` |
The only way to convert a `raw_tensor` to a `gc_tensor` is with the hand-written, non-ctypes function `with_tensor_gc`. It is used extensively in the generated OCaml wrapper code to ensure we only surface GC tensors to the user.
The lifecycle of each tensor looks like this:
- Some wrapper function, `let t = Tensor.foo ()`, gets invoked and makes its way into C++.
- C++ returns a `raw_tensor` that goes through a regular ctypes stub and makes its way back to the OCaml `Tensor.foo` call.
- Still in `Tensor.foo`, `with_tensor_gc` gets invoked. This goes back into C++ and copies the pointer (but not the data) of the tensor to a new custom block, which now has a known off-heap size and a finalizer to free its memory. The block is returned to OCaml with the same memory layout ctypes uses, but without going through ctypes.
- Now `let () = Tensor.bar t` gets invoked. This goes through the usual ctypes stubs, since `t` looks just like a regular `void ptr` to ctypes.
- Eventually `t` gets garbage collected. OCaml traverses its blocks and runs the finalizer on each one, freeing the tensor's data once its refcount reaches 0.
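The lifecycle above can be modelled end to end in a toy C++ sketch. Everything here is hypothetical: `toy_foo`, `toy_bar`, and `toy_with_tensor_gc` stand in for `Tensor.foo`, `Tensor.bar`, and `with_tensor_gc`, and `ToyHeap` is a trivial stand-in for the OCaml GC running finalizers. The sketch assumes `with_tensor_gc` transfers the raw tensor's single reference to the new block rather than adding one, since the raw tensor has no finalizer that could ever drop it.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

std::size_t g_freed_bytes = 0;  // toy bookkeeping: storage bytes released so far

// Toy TensorImpl: intrusive refcount plus the size of its (imaginary) storage.
struct ToyImpl {
  int refcount = 1;  // born with exactly one reference
  std::size_t nbytes;
  explicit ToyImpl(std::size_t n) : nbytes(n) {}
};

// Drop one reference; free the impl (and its storage) when the count hits 0.
void decref(ToyImpl* p) {
  if (--p->refcount == 0) { g_freed_bytes += p->nbytes; delete p; }
}

// raw_tensor: what a ctypes stub hands back. No finalizer, size unknown to the GC.
struct RawTensor { ToyImpl* impl; };

// gc_tensor: a "custom block" with a known off-heap size and a finalizer.
struct GcTensor {
  ToyImpl* impl;
  std::size_t reported_bytes;    // what would be reported to the GC
  void (*finalize)(ToyImpl*);    // run once when the block is collected
};

// Stand-in for Tensor.foo: C++ creates a tensor and returns the released pointer.
RawTensor toy_foo(std::size_t nbytes) { return RawTensor{new ToyImpl(nbytes)}; }

// Stand-in for with_tensor_gc: copy the pointer (not the data) into a new block
// that takes over the raw tensor's reference; the refcount does not change.
GcTensor toy_with_tensor_gc(RawTensor r) { return GcTensor{r.impl, r.impl->nbytes, &decref}; }

// Stand-in for Tensor.bar: reads the tensor through the same pointer.
std::size_t toy_bar(const GcTensor& t) { return t.impl->nbytes; }

// Stand-in for the GC: runs each unreachable block's finalizer exactly once.
struct ToyHeap {
  std::vector<GcTensor> blocks;
  void collect_all() {
    for (GcTensor& b : blocks) b.finalize(b.impl);
    blocks.clear();
  }
};
```

Following the steps in order: `toy_foo(4096)` yields a raw tensor with refcount 1, `toy_with_tensor_gc` wraps it in a block reporting 4096 bytes, `toy_bar` uses it like any other pointer, and `collect_all` runs the finalizer, which drops the last reference and frees the 4096 bytes.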
This is a lot of indirection. The memory of each tensor (raw or GC) looks like this:
```
                block 1             block 2
           ------------------     ----------
root ->    | ctypes fat ptr | ->  | void * | ->  torch::TensorImpl  ->  storage
           ------------------     ----------
```
Here's what each link in the chain does:
- block 1: allows ctypes to manage the memory of block 2
- block 2: points to the off-heap memory and, for GC tensors, carries the OCaml finalizer that decrements its refcount
- `torch::TensorImpl`: holds metadata about the tensor's data type, size, etc., as well as a pointer/reference to its heap-allocated data (storage)
- storage: the actual numerical data of the tensor
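The chain can also be written as plain data structures, which makes the number of hops per read explicit. This C++ sketch is a simplified model: `ToyBlock1` collapses a ctypes fat pointer (which carries more fields in reality) to a single pointer, and none of the `Toy*` types match real ctypes or libtorch layouts.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy model of the four-hop chain; NOT the real ctypes or libtorch layouts.
struct ToyStorage { std::vector<float> data; };  // the numbers themselves
struct ToyTensorImpl {                            // metadata plus a storage pointer
  std::vector<long> sizes;
  ToyStorage* storage;
};
struct ToyBlock2 { void* payload; };       // custom block: a single void* to the impl
struct ToyBlock1 { ToyBlock2* managed; };  // stands in for a ctypes fat pointer (simplified)

// Follow root -> block 1 -> block 2 -> TensorImpl -> storage to read one element.
float read_element(const ToyBlock1& root, std::size_t i) {
  ToyBlock2* b2 = root.managed;                              // hop 1: block 1 to block 2
  auto* impl = static_cast<ToyTensorImpl*>(b2->payload);     // hop 2: block 2 to TensorImpl
  return impl->storage->data.at(i);                          // hops 3-4: impl to storage to data
}
```

Every element access from OCaml pays for this indirection, which is one reason bulk operations are best kept on the C++ side.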