Skip to content

Commit

Permalink
Reject invalid None in jax.NamedSharding(spec=None).
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 722500631
  • Loading branch information
pschuh authored and Google-ML-Automation committed Feb 3, 2025
1 parent 8bc7fa2 commit 0a5c792
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
4 changes: 4 additions & 0 deletions xla/python/sharding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,10 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec,
parsed_pspec_(std::move(parsed_pspec)),
manual_axes_(std::move(manual_axes)),
logical_device_ids_(std::move(logical_device_ids)) {
if (spec_.is_none()) {
throw nb::type_error(
"Unexpected None passed as spec for NamedSharding. Did you mean P()?");
}
nb::object idl = nb::object(mesh_.attr("_internal_device_list"));
if (idl.is_none()) {
internal_device_list_ = std::nullopt;
Expand Down
2 changes: 1 addition & 1 deletion xla/python/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

# Just an internal arbitrary increasing number to help with backward-compatible
# changes. In JAX, reference this via jax._src.lib.xla_extension_version.
_version = 308
_version = 309

# Version number for MLIR:Python components.
mlir_api_version = 57
Expand Down

0 comments on commit 0a5c792

Please sign in to comment.