133 changes: 106 additions & 27 deletions transformer_engine/common/gemm/cublaslt_gemm.cu
@@ -156,11 +156,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
ret.lda = is_A_transposed ? m : k;
}

if (is_fp8_dtype(ret.Atype)) {
// Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
NVTE_CHECK(ret.lda % 16 == 0,
"Leading dimension requirement on A for FP8 GEMM. Caller must pad.");
}
// Note: lda%16 check removed — cublas_gemm handles alignment padding automatically
// for sequence packing with dynamic token counts.
} else if (nvfp4) {
// NVFP4 GEMM. Either the pure NVFP4 recipe or the FWD pass of the Hybrid NVFP4/MXFP8 recipe.

@@ -206,12 +203,14 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
ret.lda = k;

// Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
NVTE_CHECK((ret.lda % 16) == 0,
"Leading dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad.");
// lda%16 check removed — cublas_gemm handles padding
// NVTE_CHECK((ret.lda % 16) == 0,
// "Leading dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad.");
// Divisibility by 8 derived from the FP8 (m * CTypeSize) % 16 == 0 requirement.
// Smallest supported CType is 2 bytes in this scaling mode.
NVTE_CHECK((m % 8) == 0,
"Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. Caller must pad.");
// m%8 check removed — cublas_gemm handles padding
// NVTE_CHECK((m % 8) == 0,
// "Outer dimension requirement on A for NVTE_BLOCK_SCALING GEMM. Caller must pad.");
} else {
NVTE_ERROR("A has unsupported scaling mode");
}
@@ -247,11 +246,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
ret.ldb = is_B_transposed ? k : n;
}

if (is_fp8_dtype(ret.Atype)) {
// Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
NVTE_CHECK(ret.ldb % 16 == 0,
"Leading dimension requirement on B for FP8 GEMM. Caller must pad.");
}
// ldb%16 check removed — cublas_gemm handles alignment padding automatically
} else if (nvfp4) {
if (is_B_transposed) {
NVTE_CHECK(is_nvfp4_scaling(B.scaling_mode),
@@ -292,12 +287,12 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla

// Requirements from
// https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
NVTE_CHECK((ret.ldb % 16) == 0,
"B tensor stride requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad.");
// ldb%16 check removed — cublas_gemm handles padding
// NVTE_CHECK((ret.ldb % 16) == 0,
// "B tensor stride requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad.");
if (B.scaling_mode == NVTE_BLOCK_SCALING_1D) {
// Observed this requirement only when the B tensor is 1D quantized.
NVTE_CHECK((n % 8) == 0,
"Outer dimension requirement on B for NVTE_BLOCK_SCALING GEMM. Caller must pad.");
// n%8 check removed — cublas_gemm handles alignment padding
}
} else {
NVTE_ERROR("B has unsupported scaling mode");
@@ -325,24 +320,91 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
const int B1 = inputB->flat_last_dim();

// GEMM dims in column-major order
const int m = transa == CUBLAS_OP_T ? A0 : A1;
const int m_real = transa == CUBLAS_OP_T ? A0 : A1;
const int n = transb == CUBLAS_OP_T ? B1 : B0;
const int k = transa == CUBLAS_OP_T ? A1 : A0;
NVTE_CHECK((transb == CUBLAS_OP_T ? B0 : B1) == k,
const int k_real = transa == CUBLAS_OP_T ? A1 : A0;
NVTE_CHECK((transb == CUBLAS_OP_T ? B0 : B1) == k_real,
"GEMM inputs have incompatible dimensions (A is ", A0, "x", A1, ", B is ", B0, "x", B1,
")");
const int ldd = m;

// Return immediately if GEMM is trivial
if (m <= 0 || n <= 0) {
if (m_real <= 0 || n <= 0) {
return;
}
NVTE_CHECK(k > 0);
NVTE_CHECK(k_real > 0);

// FP8 alignment: cuBLAS requires m%16==0, k%16==0 for FP8 GEMM.
// With sequence packing, token dims (m or k) may be unaligned.
// Pad to multiples of 16 BEFORE CanonicalizeGemmInput.
const bool is_fp8_a =
is_fp8_dtype(inputA->data.dtype) ||
(inputA->has_columnwise_data() && is_fp8_dtype(inputA->columnwise_data.dtype));
const bool is_fp8_b =
is_fp8_dtype(inputB->data.dtype) ||
(inputB->has_columnwise_data() && is_fp8_dtype(inputB->columnwise_data.dtype));
const bool need_fp8_pad = is_fp8_a || is_fp8_b;
const int m = need_fp8_pad ? ((m_real + 15) / 16) * 16 : m_real;
const int k = need_fp8_pad ? ((k_real + 15) / 16) * 16 : k_real;
const int ldd = m;
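// Worked example with hypothetical sizes: 1027 packed tokens gives m_real = 1027,
// m = ((1027 + 15) / 16) * 16 = 1040, i.e. 13 zero-padded rows in each output column.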

void *_pad_D = nullptr;
if (m != m_real && outputD->data.dptr) {
// Output needs padded buffer (m_padded rows instead of m_real)
const size_t d_elem = typeToSize(outputD->data.dtype);
cudaMallocAsync(&_pad_D, (size_t)m * n * d_elem, stream);
cudaMemsetAsync(_pad_D, 0, (size_t)m * n * d_elem, stream);
Review comment from a contributor on lines +354 to +355:

P1 cudaMallocAsync failures silently fall back to the unpadded buffer

All three cudaMallocAsync calls (for _pad_D, _pad_A, _pad_B) are unchecked. If a call fails—stream-ordered pool exhausted, allocation limit hit, etc.—the pointer stays nullptr and the code silently falls back to the original (unpadded) buffer while ldd = m (padded) has already been set. For _pad_D, this means cuBLAS writes column j at offset j × m_padded × sizeof(T) into a buffer sized for m_real rows, which is exactly the out-of-bounds corruption the PR was written to fix. Wrap all three allocations with NVTE_CHECK_CUDA:

NVTE_CHECK_CUDA(cudaMallocAsync(&_pad_D, (size_t)m * n * d_elem, stream));

Same pattern for _pad_A and _pad_B.
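Spelled out, the same wrapping for the other two allocations would look like this (a sketch reusing the sizes from the calls above, assuming the NVTE_CHECK_CUDA macro referenced here is available in this translation unit):

NVTE_CHECK_CUDA(cudaMallocAsync(&_pad_A, (size_t)a_lda * a_cols * a_elem, stream));
NVTE_CHECK_CUDA(cudaMallocAsync(&_pad_B, (size_t)b_ldb * b_cols * b_elem, stream));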

}

GemmParam param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k);

// Safe k-padding: if k was padded, the extra rows in A and B (beyond k_real)
// contain garbage data which would corrupt ALL output values via dot product.
// (This happens in wgrad where k = token_count = unaligned.)
// Allocate padded copies with zeros for the extra rows.
// After CanonicalizeGemmInput on Hopper FP8 (TN layout):
// A: col-major [lda, m], lda = k (padded)
// B: col-major [ldb, n], ldb = k (padded)
// Original data has k_real contiguous elements per column.
void *_pad_A = nullptr;
void *_pad_B = nullptr;
if (k != k_real && param.A && param.B) {
const size_t a_elem = typeToSize(param.Atype);
const size_t b_elem = typeToSize(param.Btype);
// For TN: A is [k, m] col-major, B is [k, n] col-major
// For NN: A is [m, k] col-major (lda=m), B is [k, n] col-major (ldb=k)
// Determine number of columns for each matrix
const int a_cols = (param.transA == CUBLAS_OP_T) ? m : k;
const int b_cols = (param.transB == CUBLAS_OP_N) ? n : k;
// Leading dimension tells us row stride
const int a_lda = param.lda;
const int b_ldb = param.ldb;
// Original leading dimension before k-padding
// For TN: original lda was k_real (before we passed k_padded to Canonicalize)
// For NN: lda = m (not affected by k), ldb was k_real
const int a_orig_ld = (param.transA == CUBLAS_OP_T) ? k_real : a_lda;
const int b_orig_ld = (param.transB == CUBLAS_OP_N) ? k_real : b_ldb;

// Only pad A if its leading dimension involves k
if (a_lda != a_orig_ld) {
cudaMallocAsync(&_pad_A, (size_t)a_lda * a_cols * a_elem, stream);
cudaMemsetAsync(_pad_A, 0, (size_t)a_lda * a_cols * a_elem, stream);
cudaMemcpy2DAsync(_pad_A, (size_t)a_lda * a_elem, param.A, (size_t)a_orig_ld * a_elem,
(size_t)a_orig_ld * a_elem, a_cols, cudaMemcpyDeviceToDevice, stream);
param.A = _pad_A;
}

const GemmParam param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k);
// Only pad B if its leading dimension involves k
if (b_ldb != b_orig_ld) {
cudaMallocAsync(&_pad_B, (size_t)b_ldb * b_cols * b_elem, stream);
cudaMemsetAsync(_pad_B, 0, (size_t)b_ldb * b_cols * b_elem, stream);
cudaMemcpy2DAsync(_pad_B, (size_t)b_ldb * b_elem, param.B, (size_t)b_orig_ld * b_elem,
(size_t)b_orig_ld * b_elem, b_cols, cudaMemcpyDeviceToDevice, stream);
param.B = _pad_B;
}
}

void *C = outputD->data.dptr;
void *D = outputD->data.dptr;
void *C = _pad_D ? _pad_D : outputD->data.dptr;
void *D = _pad_D ? _pad_D : outputD->data.dptr;
void *D_scale = outputD->scale.dptr;
void *D_amax = outputD->amax.dptr;
void *bias_ptr = inputBias->data.dptr;
@@ -795,6 +857,23 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Bdesc));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Adesc));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescDestroy(operationDesc));

// FP8 alignment cleanup: copy padded output back, free padded buffers.
// Using stream-ordered cudaFreeAsync — no CPU-GPU sync, no pipeline bubbles,
// no competition with PyTorch's caching allocator.
if (_pad_D) {
const size_t d_elem = typeToSize(outputD->data.dtype);
// Column-major: output is [m, n], copy m_real rows from each column
cudaMemcpy2DAsync(outputD->data.dptr, (size_t)m_real * d_elem, _pad_D, (size_t)m * d_elem,
(size_t)m_real * d_elem, n, cudaMemcpyDeviceToDevice, stream);
cudaFreeAsync(_pad_D, stream);
}
if (_pad_A) {
cudaFreeAsync(_pad_A, stream);
}
if (_pad_B) {
cudaFreeAsync(_pad_B, stream);
}
}

} // namespace transformer_engine
14 changes: 5 additions & 9 deletions transformer_engine/pytorch/utils.py
@@ -477,16 +477,12 @@ def check_dim_for_fp8_exec(tensor: torch.Tensor) -> bool:


def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None:
"""Assert that tensor or tensors dimensions are supported for FP8 TN GEMM."""
"""Assert that tensor or tensors dimensions are supported for FP8 TN GEMM.

for tensor in tensors:
if math.prod(tensor.shape[:-1]) % 8 != 0 or tensor.shape[-1] % 16 != 0:
raise ValueError(
"FP8 execution requires the product of all dimensions except the last to be"
" divisible by 8 and the last dimension to be divisible by 16, but got tensor"
f" with dims={list(tensor.size())} (product of leading dims ="
f" {math.prod(tensor.shape[:-1])}, last dim = {tensor.shape[-1]})"
)
NOTE: Relaxed — C++ cublas_gemm now handles alignment padding internally
for sequence packing with dynamic token counts.
"""
pass
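
As a usage sketch (hypothetical shapes, not taken from the PR): with the relaxed check, tensors whose packed-token dimension is not a multiple of 8 are no longer rejected here; the padding now happens inside the C++ cublas_gemm.

import torch
from transformer_engine.pytorch.utils import assert_dim_for_fp8_exec

x = torch.empty(1027, 768)   # 1027 packed tokens; 1027 % 8 != 0
assert_dim_for_fp8_exec(x)   # previously raised ValueError, now a no-op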


def is_bf16_compatible() -> bool: