[JAX] Collective GEMM custom op with nvte_cublas_gemm (no comm. overlap) #1307
Why? Normal JAX behavior is to do some gathering.
It seems that the C++ code does not currently handle the batch size. Since JAX uses row-major storage for tensors by default, the batch dimension should probably be combined with the […]
@denera I have some questions about the PR.
# Validate operand layouts
lhs_inner_dim, rhs_inner_dim = map(
    lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim,
@denera should be ndims + inner_dim when inner_dim is negative, right?
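The point of this comment can be illustrated with a minimal sketch. The helper name `normalize_dim` is hypothetical and not part of the PR; it only shows why a negative axis index has to be added to the number of dimensions rather than subtracted from it.

```python
def normalize_dim(dim, ndims):
    """Map a possibly negative axis index onto the range [0, ndims).

    `normalize_dim` is a hypothetical helper used only for illustration.
    """
    if not -ndims <= dim < ndims:
        raise ValueError(f"axis {dim} is out of range for {ndims} dimensions")
    # A negative index counts from the end, so it is ADDED to ndims.
    return ndims + dim if dim < 0 else dim


def normalize_dim_as_in_diff(dim, ndims):
    """The expression from the diff under review, kept for comparison."""
    return (ndims - dim) if dim < 0 else dim


# For a 3D operand, axis -1 is the last axis, i.e. index 2.
last_axis = normalize_dim(-1, 3)
# The subtracting form yields 3 - (-1) = 4, which is out of range for 3 dimensions.
out_of_range_axis = normalize_dim_as_in_diff(-1, 3)
```

With the subtracting form, any negative contracting dimension would resolve to an index past the last axis, so a layout check built on it could not behave as intended.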
rhs_trans = contracting_dims[1] == rhs.ndim - 1
lhs = jnp.matrix_transpose(lhs) if lhs_trans and jax_dtype_is_fp8(lhs.dtype) else lhs
rhs = jnp.matrix_transpose(rhs) if not rhs_trans and jax_dtype_is_fp8(rhs.dtype) else rhs
contracting_dims = (1, 1)
@denera is there a need to hard-code this?
cuBLASLt GEMM requires non-transposed LHS and transposed RHS for FP8 GEMM, but the batcher is not the right place to check/force that. Also, leaving contracting_dims=(1, 1) out of the conditional for FP8 type is a mistake. Thanks for catching it!
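As a rough illustration of the fix described in this reply, the sketch below keeps the `contracting_dims = (1, 1)` override inside the FP8 branch. It is not the PR's code: it assumes 2D operands stored as nested lists, an `is_fp8` flag standing in for the `jax_dtype_is_fp8` checks, and a guessed definition of `lhs_trans`; the function name `to_nt_layout` is hypothetical.

```python
def transpose(mat):
    """Transpose a 2D matrix stored as a list of rows."""
    return [list(col) for col in zip(*mat)]


def to_nt_layout(lhs, rhs, contracting_dims, is_fp8):
    """Force non-transposed LHS / transposed RHS only for FP8 operands.

    Hypothetical helper: `is_fp8` stands in for a dtype check, and operands
    are assumed to be 2D so the last axis has index 1.
    """
    if not is_fp8:
        # Non-FP8 operands keep their layout AND their contracting dims.
        return lhs, rhs, contracting_dims

    lhs_trans = contracting_dims[0] != 1  # assumed definition, not from the diff
    rhs_trans = contracting_dims[1] == 1  # mirrors `rhs.ndim - 1` for 2D input
    if lhs_trans:
        lhs = transpose(lhs)
    if not rhs_trans:
        rhs = transpose(rhs)
    # After the transposes both operands contract over their last axis.
    return lhs, rhs, (1, 1)
```

For example, an (M, K) LHS and a (K, N) RHS with `contracting_dims=(1, 0)` come back as (M, K) and (N, K) with `contracting_dims=(1, 1)` when `is_fp8` is true, and are returned untouched otherwise.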
    grad=grad,
    accumulate=accumulate,
    use_split_accumulator=use_split_accumulator,
)(lhs_bdims, out_amax_bdims, out_scale_bdims, gelu_input_bdims, bias_bdims)
This gives me an error.
Line: https://github.com/NVIDIA/TransformerEngine/pull/1307/files#diff-f5b74ca3c5a70acb3d764e9b8adea40b8bab554fe4d2362f3052b7b932c0464dR187-R194 returns a tuple.
TypeError: 'list' object is not callable
cc @denera
self._amax_list[FP8MetaPackage.OUTPUT_IDX] = output_amax
self._scale_list[FP8MetaPackage.OUTPUT_IDX] = output_scale
Hi,
For the delayed scaling FP8 recipe, the output amax and scale from GEMM are not used anywhere else afterward, so I think we don't need to output and store them.
The FP8 GEMM+RS overlap needs output amax/scale when the communication buffer type is FP8 -- i.e. the overlap algorithms/kernels communicate FP8 GEMM output and fuse BF16 upcast into the sum-reduce.
This PR does not implement TP overlap, but PR #1337 extends the same operations to support TP overlap, so I'm including the output amax/scale infrastructure here.
)
return a, a_q, jnp.reciprocal(a_scale), b, b_q, jnp.reciprocal(b_scale), bias

@pytest.mark.parametrize("m,n,k", GEMM_CASES)
Need to provide a list for test parameter b (batch size)
Actually this PR is not supposed to modify test_custom_call_compute.py. These changes are erroneous and need to be removed. Thank you for catching it!
def test_gemm(self, b, m, n, k, use_bias, do_gelu):
    a, b, bias = self._generate_inputs(b, m, n, k, jnp.bfloat16)

    primitive_out = gemm(a, b, bias=bias if use_bias else None, layout="NT", do_gelu=do_gelu)
Do we really need to provide or use the layout parameter here? On one hand, users or other functions in TE are unlikely to use this argument (I think the C/C++ code would need it, but not the Python code); on the other hand, does it make dist-mem sharding more complicated?
Changes to this file are erroneous and I just pushed up a commit to restore the original.
All tests for the new collective GEMM custom op are written in test_distributed_gemm.py instead.
Added XLA FFI custom op for TE GEMM
finished GEMM custom op primitive and serial unit test
fixed GEMM custom op batcher
fixed output dtype error and contracting dimensions options
AG overlap working but executes scatter to match outer LHS dim
both all-gather and all-reduce are now working
code style
changed kwargs in abstract to be explicit
added fwd/bwd implementation for non-fp8 gemm
Signed-off-by: Alp Dener <[email protected]>
… passing test
…ide the custom op
…xt-parallel LHS operands
… and TP-only meshes
resources.update(dict(dp_resource="dp"))
if parallel_dist == "FSDP_TP":
    fsdp = True
    mesh_shape.update(dict(tp=NUM_DEVICES // 2, dp=1, zp=NUM_DEVICES // 2))
This mesh shape calculation is incorrect. Suggested revision:
if parallel_dist in ["DP_TP", "FSDP_TP"]:
    batched = True
    tp = NUM_DEVICES // 2
    dp = NUM_DEVICES // tp
    mesh_shape.update(dict(tp=tp, dp=dp))
    resources.update(dict(dp_resource="dp"))
    if parallel_dist == "FSDP_TP":
        fsdp = True
        dp = 1
        zp = NUM_DEVICES // tp
        mesh_shape.update(dict(tp=tp, dp=1, zp=zp))
        resources.update(dict(fsdp_resource="zp"))
Description
Implements both old-style and new FFI-based XLA custom calls in C++, and the corresponding JAX primitive including custom partitioning rules.
Custom partitioning rules for a LHS:([B,] M, K) x RHS:([B,] K, N) = OUT:([B,] M, N) batched mat-mul operation where [B] is the batch dimension:
- […] [B] dimension for all operands.
- […] M dimension.
- […] K and N dimensions.
- […] K dimension of LHS to match the partitioning of the K dimension of RHS.
- If the K dimension is partitioned but the M dimension is not, jax.lax.psum (all-reduce) the output over the TP mesh resource.
- If both the M and K dimensions are partitioned, jax.lax.psum_scatter (reduce-scatter) the output over the TP mesh resource.

In practice, the RHS matrix (typically the weight tensor) should be allocated with transposed contracting dimensions ([B,] N, K) for optimal GEMM heuristics in cuBLASLt. This layout is also mandatory for FP8 inputs.

This PR does NOT update fused ops or Flax/Praxis modules to use the new GEMM custom op over the existing XLA pattern matching approach.
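The all-reduce and reduce-scatter rules in the description can be pictured with a small pure-Python simulation. This is only a sketch under stated assumptions: devices are simulated as list entries, matrices are nested lists, and the helper names (`k_sharded_partials`, `all_reduce`, `reduce_scatter`) are hypothetical stand-ins for what jax.lax.psum and jax.lax.psum_scatter do across the TP mesh axis. It is not the PR's partitioning code.

```python
def matmul(a, b):
    """Plain-Python product of an (m x k) and a (k x n) matrix."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add(a, b):
    """Elementwise sum of two equally shaped matrices."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def k_sharded_partials(lhs, rhs, num_devices):
    """Each simulated device multiplies its own slice of the K dimension."""
    step = len(rhs) // num_devices  # assumes K divides evenly
    partials = []
    for dev in range(num_devices):
        lo, hi = dev * step, (dev + 1) * step
        lhs_shard = [row[lo:hi] for row in lhs]  # columns lo:hi of LHS
        rhs_shard = rhs[lo:hi]                   # rows lo:hi of RHS
        partials.append(matmul(lhs_shard, rhs_shard))
    return partials


def all_reduce(partials):
    """psum-like: every device ends up with the full summed output."""
    total = partials[0]
    for part in partials[1:]:
        total = add(total, part)
    return [total for _ in partials]


def reduce_scatter(partials):
    """psum_scatter-like: the summed output is split along M across devices."""
    total = all_reduce(partials)[0]
    rows = len(total) // len(partials)  # assumes M divides evenly
    return [total[dev * rows:(dev + 1) * rows] for dev in range(len(partials))]


lhs = [[1, 2, 3, 4], [5, 6, 7, 8]]    # (M=2, K=4)
rhs = [[1, 0], [0, 1], [1, 1], [2, 0]]  # (K=4, N=2)
partials = k_sharded_partials(lhs, rhs, num_devices=2)
replicated = all_reduce(partials)     # K partitioned, M not partitioned
scattered = reduce_scatter(partials)  # both M and K partitioned
```

Each partial product is only a fragment of the sum over K, which is why some reduction over the tensor-parallel axis is needed whenever K is partitioned; the two collectives differ only in whether each device keeps the whole result or just its slice of M.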
Type of change
Changes
[…] nvte_cublas_gemm.
Checklist: