Skip to content

Commit

Permalink
Implement lazy import
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Oct 26, 2024
1 parent b2c599d commit cfb53d4
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 12 deletions.
1 change: 1 addition & 0 deletions scripts/test.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/bin/bash

pip install -e .
python setup.py build_ext --inplace
pip install pytest
python -m pytest tests
25 changes: 25 additions & 0 deletions tensor_bridge/imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import TYPE_CHECKING, Optional, Type

if TYPE_CHECKING:
import jax
import torch

__all__ = ["get_torch", "get_jax"]


def get_torch() -> Optional["torch"]:
try:
import torch

return torch
except ImportError:
return None


def get_jax() -> Optional["jax"]:
try:
import jax

return jax
except ImportError:
return None
10 changes: 6 additions & 4 deletions tensor_bridge/tensor_bridge.pyx
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
import jax
import torch
from .imports import get_jax, get_torch

from _tensor_bridge cimport DataPtr, native_copy_tensor
from libcpp.pair cimport pair

torch = get_torch()
jax = get_jax()


cdef DataPtr get_ptr_and_size(data):
cdef unsigned long ptr
cdef unsigned long size
cdef DataPtr ret
if isinstance(data, torch.Tensor):
if torch is not None and isinstance(data, torch.Tensor):
ret.ptr = data.data_ptr()
ret.size = torch.numel(data) * data.element_size()
ret.device = data.device.index
elif isinstance(data, jax.Array):
elif jax is not None and isinstance(data, jax.Array):
ret.ptr = data.unsafe_buffer_pointer()
ret.size = data.size * data.dtype.itemsize
ret.device = next(iter(data.devices())).id
Expand Down
11 changes: 7 additions & 4 deletions tensor_bridge/types.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from typing import Any, Union
from typing import TYPE_CHECKING, Any, Union

import jax
import numpy as np
import torch

if TYPE_CHECKING:
import jax
import torch


__all__ = ["NumpyArray", "Array"]


NumpyArray = np.ndarray[Any, Any]
Array = Union[torch.Tensor, jax.Array]
Array = Union["torch.Tensor", "jax.Array"]
11 changes: 7 additions & 4 deletions tensor_bridge/utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
import jax
import numpy as np
import torch

from .imports import get_jax, get_torch
from .types import Array, NumpyArray

__all__ = ["get_numpy_data"]


torch = get_torch()
jax = get_jax()


def get_numpy_data(tensor: Array) -> NumpyArray:
if isinstance(tensor, torch.Tensor):
if torch is not None and isinstance(tensor, torch.Tensor):
return tensor.cpu().detach().numpy() # type: ignore
elif isinstance(tensor, jax.Array):
elif jax is not None and isinstance(tensor, jax.Array):
return np.array(tensor)
else:
raise ValueError(f"Unsupported tensor type: {type(tensor)}")

0 comments on commit cfb53d4

Please sign in to comment.