Skip to content

Commit

Permalink
Move transfer python bindings into jax.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 718724289
  • Loading branch information
pschuh authored and Google-ML-Automation committed Jan 23, 2025
1 parent e9a1a97 commit a2eaef3
Show file tree
Hide file tree
Showing 7 changed files with 489 additions and 1 deletion.
1 change: 1 addition & 0 deletions xla/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1359,6 +1359,7 @@ tsl_pybind_extension(
"//conditions:default": [
"//xla/backends/cpu/collectives:gloo_collectives",
"//xla/backends/cpu/collectives:gloo_kv_store",
"//xla/python/transfer:py_socket_transfer",
"@gloo//:transport_tcp",
],
}) + select({
Expand Down
37 changes: 37 additions & 0 deletions xla/python/transfer/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,40 @@ xla_cc_test(
"@com_google_googletest//:gtest_main",
],
)

cc_library(
name = "py_socket_transfer",
srcs = ["py_socket_transfer.cc"],
hdrs = ["py_socket_transfer.h"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
deps = [
":event_loop",
":socket-server",
":socket_bulk_transport",
":streaming",
":streaming_ifrt",
"//xla:util",
"//xla/pjrt:pjrt_client",
"//xla/pjrt:status_casters",
"//xla/python:nb_class_ptr",
"//xla/python:nb_numpy",
"//xla/python:py_client",
"//xla/python:traceback",
"//xla/python:types",
"//xla/python/ifrt",
"//xla/python/pjrt_ifrt",
"//xla/python/pjrt_ifrt:pjrt_dtype",
"//xla/tsl/concurrency:ref_count",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@nanobind",
],
)
Loading

0 comments on commit a2eaef3

Please sign in to comment.