Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 136 additions & 0 deletions tests/cpp/operator/test_swizzle.cu
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,142 @@ void performTestGroupedSwizzleUnswizzleRoundtrip(const int num_tensors, const si
num_tensors * col_numel);
}

void performTestGroupedSwizzleMXFP8Variable(const std::vector<std::pair<size_t, size_t>>& shapes) {
using namespace transformer_engine;
using namespace test;

int num_tensors = shapes.size();
std::vector<std::unique_ptr<Tensor>> input_tensors;
std::vector<std::unique_ptr<Tensor>> output_tensors;
std::vector<Tensor*> input_ptrs;
std::vector<Tensor*> output_ptrs;
input_tensors.reserve(num_tensors);
output_tensors.reserve(num_tensors);
input_ptrs.reserve(num_tensors);
output_ptrs.reserve(num_tensors);

constexpr size_t BLOCK_SIZE = 32;
for (int i = 0; i < num_tensors; ++i) {
const std::vector<size_t> shape{shapes[i].first, shapes[i].second};
auto input = std::make_unique<Tensor>("input_" + std::to_string(i), shape,
DType::kFloat8E4M3, true, true,
NVTE_MXFP8_1D_SCALING);
auto output = std::make_unique<Tensor>("output_" + std::to_string(i), shape,
DType::kFloat8E4M3, true, true,
NVTE_MXFP8_1D_SCALING);
fillUniform(input.get());
fillUniform(output.get());

// Zero padding
input->to_cpu();
const NVTEShape rs = input->rowwise_scale_inv_shape();
zero_scale_inv_padding(input->rowwise_cpu_scale_inv_ptr<uint8_t>(),
rs.data[0], rs.data[1],
shapes[i].first, (shapes[i].second + BLOCK_SIZE - 1) / BLOCK_SIZE);
const NVTEShape cs = input->columnwise_scale_inv_shape();
zero_scale_inv_padding(input->columnwise_cpu_scale_inv_ptr<uint8_t>(),
cs.data[0], cs.data[1],
(shapes[i].first + BLOCK_SIZE - 1) / BLOCK_SIZE, shapes[i].second);
input->from_cpu();

input_ptrs.push_back(input.get());
output_ptrs.push_back(output.get());
input_tensors.emplace_back(std::move(input));
output_tensors.emplace_back(std::move(output));
}

GroupedBuffers grouped_input = build_grouped_tensor(input_ptrs, NVTE_MXFP8_1D_SCALING);
GroupedBuffers grouped_output = build_grouped_tensor(output_ptrs, NVTE_MXFP8_1D_SCALING);

const uint8_t input_swizzled = 0;
nvte_set_grouped_tensor_param(grouped_input.get_handle(),
kNVTEGroupedWithGEMMSwizzledScales,
&input_swizzled, sizeof(input_swizzled));
const uint8_t output_swizzled = 1;
nvte_set_grouped_tensor_param(grouped_output.get_handle(),
kNVTEGroupedWithGEMMSwizzledScales,
&output_swizzled, sizeof(output_swizzled));

nvte_swizzle_grouped_scaling_factors(grouped_input.get_handle(),
grouped_output.get_handle(),
0);

cudaDeviceSynchronize();
NVTE_CHECK_CUDA(cudaGetLastError());

// Verification
size_t row_offset = 0;
size_t col_offset = 0;
for (int i = 0; i < num_tensors; ++i) {
const NVTEShape row_shape = input_tensors[i]->rowwise_scale_inv_shape();
const NVTEShape col_shape = input_tensors[i]->columnwise_scale_inv_shape();
const size_t row_numel = row_shape.data[0] * row_shape.data[1];
const size_t col_numel = col_shape.data[0] * col_shape.data[1];

std::vector<uint8_t> output_row_host(row_numel);
std::vector<uint8_t> output_col_host(col_numel);
NVTE_CHECK_CUDA(cudaMemcpy(output_row_host.data(),
static_cast<uint8_t*>(grouped_output.scale_inv.get()) + row_offset,
row_numel, cudaMemcpyDeviceToHost));
NVTE_CHECK_CUDA(cudaMemcpy(output_col_host.data(),
static_cast<uint8_t*>(grouped_output.columnwise_scale_inv.get()) + col_offset,
col_numel, cudaMemcpyDeviceToHost));

std::vector<uint8_t> ref_row(row_numel);
std::vector<uint8_t> ref_col(col_numel);
compute_ref_swizzle<128, 4, true>(input_tensors[i]->rowwise_cpu_scale_inv_ptr<uint8_t>(),
ref_row.data(),
row_shape.data[0], row_shape.data[1]);
compute_ref_swizzle<128, 4, false>(
input_tensors[i]->columnwise_cpu_scale_inv_ptr<uint8_t>(),
ref_col.data(),
col_shape.data[1], col_shape.data[0]);

compareResults("grouped_swizzle_variable_rowwise_" + std::to_string(i),
output_row_host.data(), ref_row.data(), row_numel);
compareResults("grouped_swizzle_variable_colwise_" + std::to_string(i),
output_col_host.data(), ref_col.data(), col_numel);

row_offset += row_numel;
col_offset += col_numel;
}
}

class SwizzleGroupedVariableTestSuite
: public ::testing::TestWithParam<std::vector<std::pair<size_t, size_t>>> {};

TEST_P(SwizzleGroupedVariableTestSuite, TestGroupedSwizzleMXFP8Variable) {
const auto shapes = GetParam();
performTestGroupedSwizzleMXFP8Variable(shapes);
}

INSTANTIATE_TEST_SUITE_P(
OperatorTest,
SwizzleGroupedVariableTestSuite,
::testing::Values(
// Case 1: num_tensors = 1 (n+3 = 4, even). Check simple alignment.
std::vector<std::pair<size_t, size_t>>{{1024, 1024}},

// Case 2: num_tensors = 2 (n+3 = 5, odd). Forces padding logic to trigger.
std::vector<std::pair<size_t, size_t>>{{128, 128}, {256, 256}},

// Case 3: Mixed small/irregular shapes.
std::vector<std::pair<size_t, size_t>>{{200, 160}, {33, 64}, {1, 32}},

// Case 4: Large workload to verify persistent grid
std::vector<std::pair<size_t, size_t>>(10, {4096, 4096}),

// Case 5: Variable M, Uniform K (Semi-variable)
std::vector<std::pair<size_t, size_t>>{{128, 256}, {512, 256}, {64, 256}},

// Case 6: Uniform M, Variable K (Semi-variable)
std::vector<std::pair<size_t, size_t>>{{512, 128}, {512, 1024}, {512, 32}}
),
[](const testing::TestParamInfo<SwizzleGroupedVariableTestSuite::ParamType>& info) {
return "VariableShapes_" + std::to_string(info.index) + "_N" + std::to_string(info.param.size());
}
);

class SwizzleGroupedTestSuite
: public ::testing::TestWithParam<std::tuple<int, size_t, size_t>> {};

Expand Down
14 changes: 7 additions & 7 deletions tests/cpp/test_common.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1099,7 +1099,7 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
const bool same_last = std::all_of(last_dims.begin(), last_dims.end(),
[&](int64_t v) { return v == last_dims[0]; });

std::vector<int64_t> offsets(num_tensors, 0);
std::vector<int64_t> offsets(num_tensors + 1, 0);
auto random_padding = [&]() -> int64_t {
// Random padding ensuring 16-byte alignment regardless of element size
// cuBLAS requires aligned pointers for vectorized loads
Expand All @@ -1118,12 +1118,11 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
const bool need_offsets = !same_first || !same_last;
const bool use_random_padding = need_offsets && scaling_mode != NVTE_MXFP8_1D_SCALING;
if (need_offsets) {
offsets[0] = 0;
for (size_t i = 1; i < num_tensors; ++i) {
for (size_t i = 1; i < num_tensors + 1; ++i) {
offsets[i] = offsets[i - 1] + numel(i - 1) + (use_random_padding ? random_padding() : 0);
}
} else {
for (size_t i = 0; i < num_tensors; ++i) {
for (size_t i = 0; i < num_tensors + 1; ++i) {
offsets[i] = static_cast<int64_t>(i) * numel(0);
}
}
Expand Down Expand Up @@ -1211,10 +1210,11 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
}

if (!same_first || !same_last) {
grouped.offsets_dev = cuda_alloc<int64_t>(num_tensors * sizeof(int64_t));
size_t num_off = num_tensors + 1;
grouped.offsets_dev = cuda_alloc<int64_t>(num_off * sizeof(int64_t));
NVTE_CHECK_CUDA(cudaMemcpy(grouped.offsets_dev.get(), offsets.data(),
num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice));
NVTEShape off_shape = nvte_make_shape(&num_tensors, 1);
num_off * sizeof(int64_t), cudaMemcpyHostToDevice));
NVTEShape off_shape = nvte_make_shape(&num_off, 1);
NVTEBasicTensor off_tensor{grouped.offsets_dev.get(), kNVTEInt64, off_shape};
nvte_set_grouped_tensor_param(h, kNVTEGroupedTensorOffsets, &off_tensor, sizeof(off_tensor));
}
Expand Down
Loading
Loading