diff --git a/tests/cpp/operator/test_swizzle.cu b/tests/cpp/operator/test_swizzle.cu
index 1ea82f19cd..8ba2cd4573 100644
--- a/tests/cpp/operator/test_swizzle.cu
+++ b/tests/cpp/operator/test_swizzle.cu
@@ -506,6 +506,142 @@ void performTestGroupedSwizzleUnswizzleRoundtrip(const int num_tensors, const si
                  num_tensors * col_numel);
 }
 
+void performTestGroupedSwizzleMXFP8Variable(const std::vector<std::pair<size_t, size_t>>& shapes) {
+  using namespace transformer_engine;
+  using namespace test;
+
+  int num_tensors = shapes.size();
+  std::vector<std::unique_ptr<Tensor>> input_tensors;
+  std::vector<std::unique_ptr<Tensor>> output_tensors;
+  std::vector<Tensor*> input_ptrs;
+  std::vector<Tensor*> output_ptrs;
+  input_tensors.reserve(num_tensors);
+  output_tensors.reserve(num_tensors);
+  input_ptrs.reserve(num_tensors);
+  output_ptrs.reserve(num_tensors);
+
+  constexpr size_t BLOCK_SIZE = 32;
+  for (int i = 0; i < num_tensors; ++i) {
+    const std::vector<size_t> shape{shapes[i].first, shapes[i].second};
+    auto input = std::make_unique<Tensor>("input_" + std::to_string(i), shape,
+                                          DType::kFloat8E4M3, true, true,
+                                          NVTE_MXFP8_1D_SCALING);
+    auto output = std::make_unique<Tensor>("output_" + std::to_string(i), shape,
+                                           DType::kFloat8E4M3, true, true,
+                                           NVTE_MXFP8_1D_SCALING);
+    fillUniform(input.get());
+    fillUniform(output.get());
+
+    // Zero padding
+    input->to_cpu();
+    const NVTEShape rs = input->rowwise_scale_inv_shape();
+    zero_scale_inv_padding(input->rowwise_cpu_scale_inv_ptr(),
+                           rs.data[0], rs.data[1],
+                           shapes[i].first, (shapes[i].second + BLOCK_SIZE - 1) / BLOCK_SIZE);
+    const NVTEShape cs = input->columnwise_scale_inv_shape();
+    zero_scale_inv_padding(input->columnwise_cpu_scale_inv_ptr(),
+                           cs.data[0], cs.data[1],
+                           (shapes[i].first + BLOCK_SIZE - 1) / BLOCK_SIZE, shapes[i].second);
+    input->from_cpu();
+
+    input_ptrs.push_back(input.get());
+    output_ptrs.push_back(output.get());
+    input_tensors.emplace_back(std::move(input));
+    output_tensors.emplace_back(std::move(output));
+  }
+
+  GroupedBuffers grouped_input = build_grouped_tensor(input_ptrs, NVTE_MXFP8_1D_SCALING);
+  GroupedBuffers grouped_output = build_grouped_tensor(output_ptrs, NVTE_MXFP8_1D_SCALING);
+
+  const uint8_t input_swizzled = 0;
+  nvte_set_grouped_tensor_param(grouped_input.get_handle(),
+                                kNVTEGroupedWithGEMMSwizzledScales,
+                                &input_swizzled, sizeof(input_swizzled));
+  const uint8_t output_swizzled = 1;
+  nvte_set_grouped_tensor_param(grouped_output.get_handle(),
+                                kNVTEGroupedWithGEMMSwizzledScales,
+                                &output_swizzled, sizeof(output_swizzled));
+
+  nvte_swizzle_grouped_scaling_factors(grouped_input.get_handle(),
+                                       grouped_output.get_handle(),
+                                       0);
+
+  cudaDeviceSynchronize();
+  NVTE_CHECK_CUDA(cudaGetLastError());
+
+  // Verification
+  size_t row_offset = 0;
+  size_t col_offset = 0;
+  for (int i = 0; i < num_tensors; ++i) {
+    const NVTEShape row_shape = input_tensors[i]->rowwise_scale_inv_shape();
+    const NVTEShape col_shape = input_tensors[i]->columnwise_scale_inv_shape();
+    const size_t row_numel = row_shape.data[0] * row_shape.data[1];
+    const size_t col_numel = col_shape.data[0] * col_shape.data[1];
+
+    std::vector<uint8_t> output_row_host(row_numel);
+    std::vector<uint8_t> output_col_host(col_numel);
+    NVTE_CHECK_CUDA(cudaMemcpy(output_row_host.data(),
+                               static_cast<uint8_t*>(grouped_output.scale_inv.get()) + row_offset,
+                               row_numel, cudaMemcpyDeviceToHost));
+    NVTE_CHECK_CUDA(cudaMemcpy(output_col_host.data(),
+                               static_cast<uint8_t*>(grouped_output.columnwise_scale_inv.get()) + col_offset,
+                               col_numel, cudaMemcpyDeviceToHost));
+
+    std::vector<uint8_t> ref_row(row_numel);
+    std::vector<uint8_t> ref_col(col_numel);
+    compute_ref_swizzle<128, 4, true>(input_tensors[i]->rowwise_cpu_scale_inv_ptr(),
+                                      ref_row.data(),
+                                      row_shape.data[0], row_shape.data[1]);
+    compute_ref_swizzle<128, 4, false>(
+        input_tensors[i]->columnwise_cpu_scale_inv_ptr(),
+        ref_col.data(),
+        col_shape.data[1], col_shape.data[0]);
+
+    compareResults("grouped_swizzle_variable_rowwise_" + std::to_string(i),
+                   output_row_host.data(), ref_row.data(), row_numel);
+    compareResults("grouped_swizzle_variable_colwise_" + std::to_string(i),
+                   output_col_host.data(), ref_col.data(), col_numel);
+
+    row_offset += row_numel;
+    col_offset += col_numel;
+  }
+}
+
+class SwizzleGroupedVariableTestSuite
+    : public ::testing::TestWithParam<std::vector<std::pair<size_t, size_t>>> {};
+
+TEST_P(SwizzleGroupedVariableTestSuite, TestGroupedSwizzleMXFP8Variable) {
+  const auto shapes = GetParam();
+  performTestGroupedSwizzleMXFP8Variable(shapes);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    OperatorTest,
+    SwizzleGroupedVariableTestSuite,
+    ::testing::Values(
+        // Case 1: num_tensors = 1 (n+3 = 4, even). Check simple alignment.
+        std::vector<std::pair<size_t, size_t>>{{1024, 1024}},
+
+        // Case 2: num_tensors = 2 (n+3 = 5, odd). Forces padding logic to trigger.
+        std::vector<std::pair<size_t, size_t>>{{128, 128}, {256, 256}},
+
+        // Case 3: Mixed small/irregular shapes.
+        std::vector<std::pair<size_t, size_t>>{{200, 160}, {33, 64}, {1, 32}},
+
+        // Case 4: Large workload to verify persistent grid.
+        std::vector<std::pair<size_t, size_t>>(10, {4096, 4096}),
+
+        // Case 5: Variable M, Uniform K (Semi-variable)
+        std::vector<std::pair<size_t, size_t>>{{128, 256}, {512, 256}, {64, 256}},
+
+        // Case 6: Uniform M, Variable K (Semi-variable)
+        std::vector<std::pair<size_t, size_t>>{{512, 128}, {512, 1024}, {512, 32}}
+    ),
+    [](const testing::TestParamInfo<SwizzleGroupedVariableTestSuite::ParamType>& info) {
+      return "VariableShapes_" + std::to_string(info.index) + "_N" + std::to_string(info.param.size());
+    }
+);
+
 class SwizzleGroupedTestSuite : public ::testing::TestWithParam> {};
 
diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu
index 5196684118..b8bc38935f 100644
--- a/tests/cpp/test_common.cu
+++ b/tests/cpp/test_common.cu
@@ -1099,7 +1099,7 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors,
   const bool same_last = std::all_of(last_dims.begin(), last_dims.end(),
                                      [&](int64_t v) { return v == last_dims[0]; });
 
-  std::vector<int64_t> offsets(num_tensors, 0);
+  std::vector<int64_t> offsets(num_tensors + 1, 0);
   auto random_padding = [&]() -> int64_t {
     // Random padding ensuring 16-byte alignment regardless of element size
     // cuBLAS requires aligned pointers for vectorized loads
@@ -1118,12 +1118,11 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors,
   const bool need_offsets = !same_first || !same_last;
   const bool use_random_padding = need_offsets && scaling_mode != NVTE_MXFP8_1D_SCALING;
   if (need_offsets) {
-    offsets[0] = 0;
-    for (size_t i = 1; i < num_tensors; ++i) {
+    for (size_t i = 1; i < num_tensors + 1; ++i) {
       offsets[i] = offsets[i - 1] + numel(i - 1) + (use_random_padding ? random_padding() : 0);
     }
   } else {
-    for (size_t i = 0; i < num_tensors; ++i) {
+    for (size_t i = 0; i < num_tensors + 1; ++i) {
       offsets[i] = static_cast<int64_t>(i) * numel(0);
     }
   }
@@ -1211,10 +1210,11 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors,
   }
 
   if (!same_first || !same_last) {
-    grouped.offsets_dev = cuda_alloc(num_tensors * sizeof(int64_t));
+    size_t num_off = num_tensors + 1;
+    grouped.offsets_dev = cuda_alloc(num_off * sizeof(int64_t));
     NVTE_CHECK_CUDA(cudaMemcpy(grouped.offsets_dev.get(), offsets.data(),
-                               num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice));
-    NVTEShape off_shape = nvte_make_shape(&num_tensors, 1);
+                               num_off * sizeof(int64_t), cudaMemcpyHostToDevice));
+    NVTEShape off_shape = nvte_make_shape(&num_off, 1);
     NVTEBasicTensor off_tensor{grouped.offsets_dev.get(), kNVTEInt64, off_shape};
     nvte_set_grouped_tensor_param(h, kNVTEGroupedTensorOffsets, &off_tensor, sizeof(off_tensor));
   }
diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu
index de4fdbb040..8619d91220 100644
--- a/transformer_engine/common/swizzle/swizzle.cu
+++ b/transformer_engine/common/swizzle/swizzle.cu
@@ -8,10 +8,13 @@
 #include
 #include
+#include
 #include
 #include
+#include
 #include "../common.h"
+#include "../util/cuda_runtime.h"
 #include "../util/logging.h"
 #include "transformer_engine/transformer_engine.h"
@@ -1882,6 +1885,151 @@ void nvte_multi_tensor_unswizzle_scaling_factors(const NVTETensor* inputs, NVTET
 
 namespace transformer_engine {
 
+template <int SF_TILE_DIM_M, int SF_TILE_DIM_K>
+__global__ void __launch_bounds__(TB_DIM* TB_DIM)
+    grouped_swizzle_scaling_variable_shape_kernel(const void* input, void* output,
+                                                  const int64_t* m_array, const int64_t* k_array,
+                                                  int num_tensors, bool rowwise,
+                                                  size_t scale_elem_size, size_t common_m,
+                                                  size_t common_k) {
+  extern __shared__ int s_metadata[];
+  int* s_total_blocks = &s_metadata[0];
+
+  // Warp reduction to compute total workload
+  if (threadIdx.x < 32 && threadIdx.y == 0) {
+    int local_blocks = 0;
+    for (int i = threadIdx.x; i < num_tensors; i += 32) {
+      size_t m = rowwise ? (m_array ? m_array[i] : common_m) : (k_array ? k_array[i] : common_k);
+      size_t k = rowwise ? (k_array ? k_array[i] : common_k) : (m_array ? m_array[i] : common_m);
+
+      size_t padded_m = round_up_to_multiple(m, 128);
+      size_t padded_k = round_up_to_multiple(DIVUP(k, static_cast<size_t>(MXFP8_BLOCK_SIZE)), 4);
+
+      int num_tiles_m = padded_m / SF_TILE_DIM_M;
+      int num_tiles_k = padded_k / SF_TILE_DIM_K;
+
+      int vec_load_size = (rowwise ? ((num_tiles_k - 1) % 4 + 1) : ((num_tiles_m - 1) % 4 + 1));
+      if (vec_load_size == 3) vec_load_size = 1;
+      int n_tiles_in_tb = TB_DIM * vec_load_size;
+
+      int grid_dim_x = rowwise ? DIVUP(num_tiles_k, n_tiles_in_tb) : DIVUP(num_tiles_k, TB_DIM);
+      int grid_dim_y = rowwise ? num_tiles_m : DIVUP(num_tiles_m, vec_load_size);
+      local_blocks += grid_dim_x * grid_dim_y;
+    }
+
+    for (int offset = 16; offset > 0; offset /= 2) {
+      local_blocks += __shfl_down_sync(0xffffffff, local_blocks, offset);
+    }
+    if (threadIdx.x == 0) *s_total_blocks = local_blocks;
+  }
+  __syncthreads();
+
+  const int total_blocks = *s_total_blocks;
+
+  // Persistent-grid loop
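+  // Each persistent CTA walks the flattened block space: for every linear block id it
+  // re-derives the owning tensor and its local (block_x, block_y) by scanning the per-tensor
+  // grid sizes recomputed from m_array/k_array, then dispatches to the row/col swizzle impls.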
+  for (int linear_block_id = blockIdx.x; linear_block_id < total_blocks;
+       linear_block_id += gridDim.x) {
+    // Discover tensor_id and local_block_id via linear scan
+    int tensor_id = 0;
+    int current_block_base = 0;
+    size_t current_scale_base = 0;
+    int grid_dim_x = 0;
+    int grid_dim_y = 0;
+    size_t M = 0, K = 0;
+    int vec_load_size = 0;
+
+    for (int i = 0; i < num_tensors; ++i) {
+      M = rowwise ? (m_array ? m_array[i] : common_m) : (k_array ? k_array[i] : common_k);
+      K = rowwise ? (k_array ? k_array[i] : common_k) : (m_array ? m_array[i] : common_m);
+
+      size_t padded_m = round_up_to_multiple(M, 128);
+      size_t padded_k = round_up_to_multiple(DIVUP(K, static_cast<size_t>(MXFP8_BLOCK_SIZE)), 4);
+
+      int num_tiles_m = padded_m / SF_TILE_DIM_M;
+      int num_tiles_k = padded_k / SF_TILE_DIM_K;
+
+      vec_load_size = (rowwise ? ((num_tiles_k - 1) % 4 + 1) : ((num_tiles_m - 1) % 4 + 1));
+      if (vec_load_size == 3) vec_load_size = 1;
+      int n_tiles_in_tb = TB_DIM * vec_load_size;
+
+      grid_dim_x = rowwise ? DIVUP(num_tiles_k, n_tiles_in_tb) : DIVUP(num_tiles_k, TB_DIM);
+      grid_dim_y = rowwise ? num_tiles_m : DIVUP(num_tiles_m, vec_load_size);
+      int blocks_i = grid_dim_x * grid_dim_y;
+
+      if (linear_block_id < current_block_base + blocks_i) {
+        tensor_id = i;
+        break;
+      }
+      current_block_base += blocks_i;
+      current_scale_base += padded_m * padded_k * scale_elem_size;
+    }
+
+    int local_block_id = linear_block_id - current_block_base;
+    int block_x = local_block_id % grid_dim_x;
+    int block_y = local_block_id / grid_dim_x;
+
+    const uint8_t* input_base = reinterpret_cast<const uint8_t*>(input) + current_scale_base;
+    uint8_t* output_base = reinterpret_cast<uint8_t*>(output) + current_scale_base;
+
+    size_t padded_m = round_up_to_multiple(M, 128);
+    size_t padded_k = round_up_to_multiple(DIVUP(K, static_cast<size_t>(MXFP8_BLOCK_SIZE)), 4);
+    int original_M = static_cast<int>(M);
+    int original_K = static_cast<int>(DIVUP(K, static_cast<size_t>(MXFP8_BLOCK_SIZE)));
+
+    if (rowwise) {
+      if (vec_load_size == 4) {
+        swizzle_row_scaling_kernel_impl<SF_TILE_DIM_M, SF_TILE_DIM_K, 4>(
+            input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y,
+            grid_dim_x, grid_dim_y);
+      } else if (vec_load_size == 2) {
+        swizzle_row_scaling_kernel_impl<SF_TILE_DIM_M, SF_TILE_DIM_K, 2>(
+            input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y,
+            grid_dim_x, grid_dim_y);
+      } else {
+        swizzle_row_scaling_kernel_impl<SF_TILE_DIM_M, SF_TILE_DIM_K, 1>(
+            input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y,
+            grid_dim_x, grid_dim_y);
+      }
+    } else {
+      if (vec_load_size == 4) {
+        swizzle_col_scaling_kernel_impl<SF_TILE_DIM_M, SF_TILE_DIM_K, 4>(
+            input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y,
+            grid_dim_x, grid_dim_y);
+      } else if (vec_load_size == 2) {
+        swizzle_col_scaling_kernel_impl<SF_TILE_DIM_M, SF_TILE_DIM_K, 2>(
+            input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y,
+            grid_dim_x, grid_dim_y);
+      } else {
+        swizzle_col_scaling_kernel_impl<SF_TILE_DIM_M, SF_TILE_DIM_K, 1>(
+            input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y,
+            grid_dim_x, grid_dim_y);
+      }
+    }
+  }
+}
+
+template <int SF_TILE_DIM_M, int SF_TILE_DIM_K>
+int grouped_swizzle_variable_max_active_blocks_per_sm(int device_id) {
+  static std::vector<int> cache(cuda::num_devices(), -1);
+  static std::vector<std::once_flag> flags(cuda::num_devices());
+  NVTE_CHECK(0 <= device_id && device_id < cuda::num_devices(), "invalid CUDA device ID");
+
+  auto init = [&]() {
+    constexpr int metadata_shmem = sizeof(int);  // s_total_blocks
+    constexpr int dynamic_smem_size =
+        TB_DIM * 4 * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t) + metadata_shmem;
+    int max_active_blocks_per_sm;
+    NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &max_active_blocks_per_sm,
+        grouped_swizzle_scaling_variable_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K>,
+        TB_DIM * TB_DIM, dynamic_smem_size));
+    NVTE_CHECK(max_active_blocks_per_sm > 0, "Occupancy query returned 0 blocks per SM.");
+    cache[device_id] = max_active_blocks_per_sm;
+  };
+  std::call_once(flags[device_id], init);
+  return cache[device_id];
+}
+
 void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* output,
                                      cudaStream_t stream) {
   // Check scaling mode
@@ -1903,138 +2051,190 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor*
     return;
   }
 
-  // Only support uniform shapes for graph-safe grouped swizzle
-  NVTE_CHECK(input->all_same_shape(), "Grouped swizzle requires uniform tensor shapes.");
-  NVTE_CHECK(input->all_same_last_dim() && input->all_same_first_dim(),
-             "Grouped swizzle requires uniform tensor shapes.");
-
-  // Assumption is that all the tensors share the same shapes and are contgiuous.
-  // And so we dont need to pass array of input/output pointers(due to conttiguity)
-  // as well as array of shapes(due to uniform shapes).
-  const size_t first_dim = input->get_common_first_dim();
-  const size_t last_dim = input->get_common_last_dim();
-
-  constexpr int SF_TILE_DIM_M = 128;
-  constexpr int SF_TILE_DIM_K = 4;
-  const dim3 block_size(TB_DIM, TB_DIM);
+  const int64_t* m_array = reinterpret_cast<const int64_t*>(input->first_dims.dptr);
+  const int64_t* k_array = reinterpret_cast<const int64_t*>(input->last_dims.dptr);
+  const bool is_variable_shape = !input->all_same_shape();
+
+  if (!is_variable_shape) {
+    // Fallback to uniform shape implementation
+    // Assumption is that all the tensors share the same shapes and are contiguous.
+    // And so we don't need to pass array of input/output pointers (due to contiguity)
+    // as well as array of shapes (due to uniform shapes).
+    const size_t first_dim = input->get_common_first_dim();
+    const size_t last_dim = input->get_common_last_dim();
+
+    constexpr int SF_TILE_DIM_M = 128;
+    constexpr int SF_TILE_DIM_K = 4;
+    const dim3 block_size(TB_DIM, TB_DIM);
+
+    auto launch_grouped_swizzle = [&](bool rowwise) {
+      const size_t m = rowwise ? first_dim : last_dim;
+      const size_t k = rowwise ? last_dim : first_dim;
+      const size_t padded_m = round_up_to_multiple(m, 128);
+      const size_t padded_k =
+          round_up_to_multiple(DIVUP(k, static_cast<size_t>(MXFP8_BLOCK_SIZE)), 4);
+      const size_t scale_elems = padded_m * padded_k;
+
+      const size_t scale_elem_size = rowwise ? typeToSize(input->scale_inv.dtype)
+                                             : typeToSize(input->columnwise_scale_inv.dtype);
+      const size_t scale_stride_bytes = scale_elems * scale_elem_size;
+
+      if (rowwise) {
+        NVTE_CHECK(input->scale_inv.numel() == input->num_tensors * scale_elems,
+                   "Grouped input scale_inv size does not match expected packed size.");
+        NVTE_CHECK(output->scale_inv.numel() == output->num_tensors * scale_elems,
+                   "Grouped output scale_inv size does not match expected packed size.");
+      } else {
+        NVTE_CHECK(input->columnwise_scale_inv.numel() == input->num_tensors * scale_elems,
+                   "Grouped input columnwise_scale_inv size does not match expected packed size.");
+        NVTE_CHECK(output->columnwise_scale_inv.numel() == output->num_tensors * scale_elems,
+                   "Grouped output columnwise_scale_inv size does not match expected packed size.");
+      }
-  auto launch_grouped_swizzle = [&](bool rowwise) {
-    const size_t m = rowwise ? first_dim : last_dim;
-    const size_t k = rowwise ? last_dim : first_dim;
-    const size_t padded_m = round_up_to_multiple(m, 128);
-    const size_t padded_k =
-        round_up_to_multiple(DIVUP(k, static_cast<size_t>(MXFP8_BLOCK_SIZE)), 4);
-    const size_t scale_elems = padded_m * padded_k;
+      const int num_tiles_m = padded_m / SF_TILE_DIM_M;
+      const int num_tiles_k = padded_k / SF_TILE_DIM_K;
+      int vec_load_size = (rowwise ? ((num_tiles_k - 1) % 4 + 1) : ((num_tiles_m - 1) % 4 + 1));
+      if (vec_load_size == 3) vec_load_size = 1;
+      const int n_tiles_in_tb = TB_DIM * vec_load_size;
-    const size_t scale_elem_size = rowwise ? typeToSize(input->scale_inv.dtype)
-                                           : typeToSize(input->columnwise_scale_inv.dtype);
-    const size_t scale_stride_bytes = scale_elems * scale_elem_size;
+      dim3 num_blocks;
+      if (rowwise) {
+        num_blocks = dim3(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m, input->num_tensors);
+      } else {
+        num_blocks =
+            dim3(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size), input->num_tensors);
+      }
+      const int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
+
+      const int original_M = static_cast<int>(rowwise ? first_dim : last_dim);
+      const int original_K = static_cast<int>(DIVUP(k, static_cast<size_t>(MXFP8_BLOCK_SIZE)));
+      const void* input_ptr = rowwise ? input->scale_inv.dptr : input->columnwise_scale_inv.dptr;
+      void* output_ptr = rowwise ? output->scale_inv.dptr : output->columnwise_scale_inv.dptr;
+
+      if (rowwise) {
+        switch (vec_load_size) {
+          case 4:
+            NVTE_CHECK_CUDA(cudaFuncSetAttribute(
+                grouped_swizzle_row_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, 4>,
+                cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
+            grouped_swizzle_row_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, 4>
+                <<<num_blocks, block_size, slm_size, stream>>>(input_ptr, output_ptr, padded_m,
+                                                               padded_k, original_M, original_K,
+                                                               scale_stride_bytes);
+            break;
+          case 2:
+            NVTE_CHECK_CUDA(cudaFuncSetAttribute(
+                grouped_swizzle_row_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, 2>,
+                cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
+            grouped_swizzle_row_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, 2>
+                <<<num_blocks, block_size, slm_size, stream>>>(input_ptr, output_ptr, padded_m,
+                                                               padded_k, original_M, original_K,
+                                                               scale_stride_bytes);
+            break;
+          case 1:
+            NVTE_CHECK_CUDA(cudaFuncSetAttribute(
+                grouped_swizzle_row_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, 1>,
+                cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
+            grouped_swizzle_row_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, 1>
+                <<<num_blocks, block_size, slm_size, stream>>>(input_ptr, output_ptr, padded_m,
+                                                               padded_k, original_M, original_K,
+                                                               scale_stride_bytes);
+            break;
+          default:
+            NVTE_ERROR("Not valid vec_load_size.");
+        }
+      } else {
+        switch (vec_load_size) {
+          case 4:
+            NVTE_CHECK_CUDA(cudaFuncSetAttribute(
+                grouped_swizzle_col_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, 4>,
+                cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
+            grouped_swizzle_col_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, 4>
+                <<<num_blocks, block_size, slm_size, stream>>>(input_ptr, output_ptr, padded_m,
+                                                               padded_k, original_M, original_K,
+                                                               scale_stride_bytes);
+            break;
+          case 2:
+            NVTE_CHECK_CUDA(cudaFuncSetAttribute(
+                grouped_swizzle_col_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, 2>,
+                cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
+            grouped_swizzle_col_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, 2>
+                <<<num_blocks, block_size, slm_size, stream>>>(input_ptr, output_ptr, padded_m,
+                                                               padded_k, original_M, original_K,
+                                                               scale_stride_bytes);
+            break;
+          case 1:
+            NVTE_CHECK_CUDA(cudaFuncSetAttribute(
+                grouped_swizzle_col_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, 1>,
+                cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
+            grouped_swizzle_col_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, 1>
+                <<<num_blocks, block_size, slm_size, stream>>>(input_ptr, output_ptr, padded_m,
+                                                               padded_k, original_M, original_K,
+                                                               scale_stride_bytes);
+            break;
+          default:
+            NVTE_ERROR("Not valid vec_load_size.");
+        }
+      }
+      NVTE_CHECK_CUDA(cudaGetLastError());
+    };
-    if (rowwise) {
-      NVTE_CHECK(input->scale_inv.numel() == input->num_tensors * scale_elems,
-                 "Grouped input scale_inv size does not match expected packed size.");
-      NVTE_CHECK(output->scale_inv.numel() == output->num_tensors * scale_elems,
-                 "Grouped output scale_inv size does not match expected packed size.");
-    } else {
-      NVTE_CHECK(input->columnwise_scale_inv.numel() == input->num_tensors * scale_elems,
-                 "Grouped input columnwise_scale_inv size does not match expected packed size.");
-      NVTE_CHECK(output->columnwise_scale_inv.numel() == output->num_tensors * scale_elems,
-                 "Grouped output columnwise_scale_inv size does not match expected packed size.");
+    if (has_rowwise_scale_inv) {
+      launch_grouped_swizzle(true);
     }
-
-    const int num_tiles_m = padded_m / SF_TILE_DIM_M;
-    const int num_tiles_k = padded_k / SF_TILE_DIM_K;
-    int vec_load_size = (rowwise ? ((num_tiles_k - 1) % 4 + 1) : ((num_tiles_m - 1) % 4 + 1));
-    if (vec_load_size == 3) vec_load_size = 1;
-    const int n_tiles_in_tb = TB_DIM * vec_load_size;
-
-    dim3 num_blocks;
-    if (rowwise) {
-      num_blocks = dim3(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m, input->num_tensors);
-    } else {
-      num_blocks =
-          dim3(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size), input->num_tensors);
+    if (has_columnwise_scale_inv) {
+      launch_grouped_swizzle(false);
     }
-    const int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
-
-    const int original_M = static_cast<int>(rowwise ? first_dim : last_dim);
-    const int original_K = static_cast<int>(DIVUP(k, static_cast<size_t>(MXFP8_BLOCK_SIZE)));
-    const void* input_ptr = rowwise ? input->scale_inv.dptr : input->columnwise_scale_inv.dptr;
-    void* output_ptr = rowwise ? output->scale_inv.dptr : output->columnwise_scale_inv.dptr;
-
-    if (rowwise) {
-      switch (vec_load_size) {
-        case 4:
-          NVTE_CHECK_CUDA(cudaFuncSetAttribute(
-              grouped_swizzle_row_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, 4>,
-              cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
-          grouped_swizzle_row_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, 4>
-              <<<num_blocks, block_size, slm_size, stream>>>(input_ptr, output_ptr, padded_m,
-                                                             padded_k, original_M, original_K,
-                                                             scale_stride_bytes);
-          break;
-        case 2:
-          NVTE_CHECK_CUDA(cudaFuncSetAttribute(
-              grouped_swizzle_row_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, 2>,
-              cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
-          grouped_swizzle_row_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, 2>
-              <<<num_blocks, block_size, slm_size, stream>>>(input_ptr, output_ptr, padded_m,
-                                                             padded_k, original_M, original_K,
-                                                             scale_stride_bytes);
-          break;
-        case 1:
-          NVTE_CHECK_CUDA(cudaFuncSetAttribute(
-              grouped_swizzle_row_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, 1>,
-              cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
-          grouped_swizzle_row_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, 1>
-              <<<num_blocks, block_size, slm_size, stream>>>(input_ptr, output_ptr, padded_m,
-                                                             padded_k, original_M, original_K,
-                                                             scale_stride_bytes);
-          break;
-        default:
-          NVTE_ERROR("Not valid vec_load_size.");
-      }
-    } else {
-      switch (vec_load_size) {
-        case 4:
-          NVTE_CHECK_CUDA(cudaFuncSetAttribute(
-              grouped_swizzle_col_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, 4>,
-              cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
-          grouped_swizzle_col_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, 4>
-              <<<num_blocks, block_size, slm_size, stream>>>(input_ptr, output_ptr, padded_m,
-                                                             padded_k, original_M, original_K,
-                                                             scale_stride_bytes);
-          break;
-        case 2:
-          NVTE_CHECK_CUDA(cudaFuncSetAttribute(
-              grouped_swizzle_col_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, 2>,
-              cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
-          grouped_swizzle_col_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, 2>
-              <<<num_blocks, block_size, slm_size, stream>>>(input_ptr, output_ptr, padded_m,
-                                                             padded_k, original_M, original_K,
-                                                             scale_stride_bytes);
-          break;
-        case 1:
-          NVTE_CHECK_CUDA(cudaFuncSetAttribute(
-              grouped_swizzle_col_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, 1>,
-              cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
-          grouped_swizzle_col_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, 1>
-              <<<num_blocks, block_size, slm_size, stream>>>(input_ptr, output_ptr, padded_m,
-                                                             padded_k, original_M, original_K,
-                                                             scale_stride_bytes);
-          break;
-        default:
-          NVTE_ERROR("Not valid vec_load_size.");
-      }
+  } else {
+    // Variable shape implementation using Device-Side Block Scheduler
+    size_t num_tensors = input->num_tensors;
+
+    constexpr int SF_TILE_DIM_M = 128;
+    constexpr int SF_TILE_DIM_K = 4;
+    const dim3 block_size(TB_DIM, TB_DIM);
+    const int max_slm_size = TB_DIM * 4 * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
+    const int metadata_shmem = sizeof(int);  // s_total_blocks
+    const int dynamic_smem_size = max_slm_size + metadata_shmem;
+
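+    // common_m/common_k are only consulted by the kernel when m_array/k_array are null,
+    // i.e. when that dimension is uniform across the group.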
+    size_t common_m = input->all_same_first_dim() ? input->get_common_first_dim() : 0;
+    size_t common_k = input->all_same_last_dim() ? input->get_common_last_dim() : 0;
+
+    NVTE_CHECK_CUDA(cudaFuncSetAttribute(
+        grouped_swizzle_scaling_variable_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K>,
+        cudaFuncAttributeMaxDynamicSharedMemorySize, dynamic_smem_size));
+
+    const int device_id = cuda::current_device();
+    const int num_SMs = cuda::sm_count(device_id);
+    const int max_active_blocks_per_sm =
+        grouped_swizzle_variable_max_active_blocks_per_sm<SF_TILE_DIM_M, SF_TILE_DIM_K>(device_id);
+    const int persistent_blocks = num_SMs * max_active_blocks_per_sm;
+    const dim3 num_blocks(persistent_blocks);
+
+    auto launch_grouped_swizzle_variable = [&](bool rowwise) {
+      const size_t scale_elem_size = rowwise ? typeToSize(input->scale_inv.dtype)
+                                             : typeToSize(input->columnwise_scale_inv.dtype);
+
+      const void* input_ptr = rowwise ? input->scale_inv.dptr : input->columnwise_scale_inv.dptr;
+      void* output_ptr = rowwise ? output->scale_inv.dptr : output->columnwise_scale_inv.dptr;
+
+      grouped_swizzle_scaling_variable_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K>
+          <<<num_blocks, block_size, dynamic_smem_size, stream>>>(
+              input_ptr, output_ptr, m_array, k_array, num_tensors, rowwise, scale_elem_size,
+              common_m, common_k);
+
+      NVTE_CHECK_CUDA(cudaGetLastError());
+    };
+
+    if (has_rowwise_scale_inv) {
+      launch_grouped_swizzle_variable(true);
+    }
+    if (has_columnwise_scale_inv) {
+      launch_grouped_swizzle_variable(false);
     }
-    NVTE_CHECK_CUDA(cudaGetLastError());
-  };
-
-  if (has_rowwise_scale_inv) {
-    launch_grouped_swizzle(true);
-  }
-  if (has_columnwise_scale_inv) {
-    launch_grouped_swizzle(false);
-  }
+  }
 }
diff --git a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp
index cbaabaad17..7a198ca70b 100644
--- a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp
@@ -379,13 +379,6 @@ std::optional maybe_swizzle_grouped_tensor(GroupedTensorW
   if (!swizzle_rowwise && !swizzle_columnwise) {
     return std::nullopt;
   }
-  const auto first_dims = input.get_first_dims();
-  const auto last_dims = input.get_last_dims();
-  if (first_dims.data_ptr != nullptr || last_dims.data_ptr != nullptr) {
-    NVTE_ERROR(
-        "Grouped GEMM swizzle requires uniform shapes for now (first_dims/last_dims must be "
-        "absent).");
-  }
 
   std::optional rowwise_scales_pyt;
   std::optional columnwise_scales_pyt;
@@ -427,6 +420,7 @@ std::optional maybe_swizzle_grouped_tensor(GroupedTensorW
   }
 
   swizzle_output.set_with_gemm_swizzled_scales(true);
+
   NVTE_SCOPED_GIL_RELEASE({
     nvte_swizzle_grouped_scaling_factors(swizzle_input.data(), swizzle_output.data(),
                                          at::cuda::getCurrentCUDAStream());