diff --git a/domain_tests/BUILD b/domain_tests/BUILD index 0dfd652b..e2b7215d 100644 --- a/domain_tests/BUILD +++ b/domain_tests/BUILD @@ -46,11 +46,13 @@ cc_test( deps = [ ":domain_testing", "@abseil-cpp//absl/container:flat_hash_map", + "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/random", "@abseil-cpp//absl/status", "@com_google_fuzztest//fuzztest:domain", "@com_google_fuzztest//fuzztest:flatbuffers", "@com_google_fuzztest//fuzztest/internal:meta", + "@com_google_fuzztest//fuzztest/internal:serialization", "@com_google_fuzztest//fuzztest/internal:test_flatbuffers_cc_fbs", "@flatbuffers//:runtime_cc", "@googletest//:gtest_main", diff --git a/domain_tests/arbitrary_domains_flatbuffers_test.cc b/domain_tests/arbitrary_domains_flatbuffers_test.cc index 508ad269..5bb96570 100644 --- a/domain_tests/arbitrary_domains_flatbuffers_test.cc +++ b/domain_tests/arbitrary_domains_flatbuffers_test.cc @@ -19,11 +19,14 @@ #include #include #include +#include +#include #include #include "gmock/gmock.h" #include "gtest/gtest.h" #include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" #include "absl/random/random.h" #include "absl/status/status.h" #include "flatbuffers/base.h" @@ -32,21 +35,73 @@ #include "flatbuffers/reflection_generated.h" #include "flatbuffers/string.h" #include "flatbuffers/vector.h" +#include "flatbuffers/verifier.h" #include "./fuzztest/domain.h" #include "./domain_tests/domain_testing.h" #include "./fuzztest/flatbuffers.h" #include "./fuzztest/internal/meta.h" +#include "./fuzztest/internal/serialization.h" +#include "./fuzztest/internal/test_flatbuffers_64bits_generated.h" #include "./fuzztest/internal/test_flatbuffers_generated.h" namespace fuzztest { +namespace internal { +template +void AbslStringify(Sink& sink, const Union& e) { + absl::Format(&sink, "%s", EnumNameUnion(e)); +} +template +void AbslStringify(Sink& sink, const ByteEnum& e) { + absl::Format(&sink, "%s", EnumNameByteEnum(e)); +} +template +void AbslStringify(Sink& sink, const ShortEnum& e) { + absl::Format(&sink, "%s", EnumNameShortEnum(e)); +} +template +void AbslStringify(Sink& sink, const IntEnum& e) { + absl::Format(&sink, "%s", EnumNameIntEnum(e)); +} +template +void AbslStringify(Sink& sink, const LongEnum& e) { + absl::Format(&sink, "%s", EnumNameLongEnum(e)); +} +template +void AbslStringify(Sink& sink, const UByteEnum& e) { + absl::Format(&sink, "%s", EnumNameUByteEnum(e)); +} +template +void AbslStringify(Sink& sink, const UShortEnum& e) { + absl::Format(&sink, "%s", EnumNameUShortEnum(e)); +} +template +void AbslStringify(Sink& sink, const UIntEnum& e) { + absl::Format(&sink, "%s", EnumNameUIntEnum(e)); +} +template +void AbslStringify(Sink& sink, const ULongEnum& e) { + absl::Format(&sink, "%s", EnumNameULongEnum(e)); +} +} // namespace internal namespace { using ::fuzztest::internal::BoolTable; +using ::fuzztest::internal::ByteEnum; using ::fuzztest::internal::DefaultTable; +using ::fuzztest::internal::DefaultTable64; +using ::fuzztest::internal::IntEnum; +using ::fuzztest::internal::LongEnum; using ::fuzztest::internal::OptionalTable; using ::fuzztest::internal::RecursiveTable; using ::fuzztest::internal::RequiredTable; +using ::fuzztest::internal::ShortEnum; +using ::fuzztest::internal::StringTable; +using ::fuzztest::internal::UByteEnum; +using ::fuzztest::internal::UIntEnum; +using ::fuzztest::internal::ULongEnum; +using ::fuzztest::internal::UnionTable; using ::fuzztest::internal::UnsupportedTypesTable; +using ::fuzztest::internal::UShortEnum; using ::testing::_; using ::testing::AllOf; using ::testing::Each; @@ -82,6 +137,75 @@ inline bool Eq(const BoolTable& lhs, const BoolTable& rhs) { return lhs.b() == rhs.b(); } +template +inline bool Eq(const flatbuffers::Vector& lhs, + const flatbuffers::Vector& rhs) { + if (lhs.size() != rhs.size()) return false; + for (int i = 0; i < lhs.size(); ++i) { + if (!Eq(lhs.Get(i), rhs.Get(i))) return false; + } + return true; +} + +template <> +inline bool Eq(const StringTable& lhs, const StringTable& rhs) { + return Eq(lhs.str(), rhs.str()); +} + +template <> +inline bool Eq>( + const std::pair& lhs, + const std::pair& rhs) { + if (lhs.first == internal::Union_NONE && rhs.first == internal::Union_NONE) { + return true; + } + if (lhs.first != rhs.first) return false; + + switch (lhs.first) { + case internal::Union_BoolTable: + return Eq(static_cast(lhs.second), + static_cast(rhs.second)); + case internal::Union_StringTable: + return Eq(static_cast(lhs.second), + static_cast(rhs.second)); + default: + CHECK(false) << "Unsupported union type"; + } +} + +template +inline bool Eq(const flatbuffers::Vector* lhs, + const flatbuffers::Vector* rhs) { + if (lhs == nullptr && rhs == nullptr) return true; + if (lhs == nullptr || rhs == nullptr) return false; + return Eq(*lhs, *rhs); +} + +template <> +inline bool +Eq*, + const flatbuffers::Vector<::flatbuffers::Offset>*>>( + const std::pair*, + const flatbuffers::Vector<::flatbuffers::Offset>*>& + lhs, + const std::pair*, + const flatbuffers::Vector<::flatbuffers::Offset>*>& + rhs) { + if (!Eq(lhs.first, rhs.first)) return false; + if (lhs.second == nullptr && rhs.second == nullptr) return true; + if (lhs.second == nullptr || rhs.second == nullptr) return false; + if (lhs.first->size() != lhs.second->size()) return false; + if (lhs.second->size() != rhs.second->size()) return false; + + for (int i = 0; i < lhs.second->size(); ++i) { + if (!Eq(std::pair(lhs.first->Get(i), lhs.second->Get(i)), + std::pair(rhs.first->Get(i), rhs.second->Get(i)))) { + return false; + } + } + return true; +} + template <> inline bool Eq(const DefaultTable& lhs, const DefaultTable& rhs) { const bool eq_b = lhs.b() == rhs.b(); @@ -105,14 +229,86 @@ inline bool Eq(const DefaultTable& lhs, const DefaultTable& rhs) { const bool eq_eu32 = lhs.eu32() == rhs.eu32(); const bool eq_eu64 = lhs.eu64() == rhs.eu64(); const bool eq_t = Eq(lhs.t(), rhs.t()); + const bool eq_u = Eq(std::pair(static_cast(lhs.u_type()), lhs.u()), + std::pair(static_cast(rhs.u_type()), rhs.u())); + const bool eq_v_b = Eq(lhs.v_b(), rhs.v_b()); + const bool eq_v_i8 = Eq(lhs.v_i8(), rhs.v_i8()); + const bool eq_v_i16 = Eq(lhs.v_i16(), rhs.v_i16()); + const bool eq_v_i32 = Eq(lhs.v_i32(), rhs.v_i32()); + const bool eq_v_i64 = Eq(lhs.v_i64(), rhs.v_i64()); + const bool eq_v_u8 = Eq(lhs.v_u8(), rhs.v_u8()); + const bool eq_v_u16 = Eq(lhs.v_u16(), rhs.v_u16()); + const bool eq_v_u32 = Eq(lhs.v_u32(), rhs.v_u32()); + const bool eq_v_u64 = Eq(lhs.v_u64(), rhs.v_u64()); + const bool eq_v_f = Eq(lhs.v_f(), rhs.v_f()); + const bool eq_v_d = Eq(lhs.v_d(), rhs.v_d()); + const bool eq_v_str = Eq(lhs.v_str(), rhs.v_str()); + const bool eq_v_ei8 = Eq(lhs.v_ei8(), rhs.v_ei8()); + const bool eq_v_ei16 = Eq(lhs.v_ei16(), rhs.v_ei16()); + const bool eq_v_ei32 = Eq(lhs.v_ei32(), rhs.v_ei32()); + const bool eq_v_ei64 = Eq(lhs.v_ei64(), rhs.v_ei64()); + const bool eq_v_eu8 = Eq(lhs.v_eu8(), rhs.v_eu8()); + const bool eq_v_eu16 = Eq(lhs.v_eu16(), rhs.v_eu16()); + const bool eq_v_eu32 = Eq(lhs.v_eu32(), rhs.v_eu32()); + const bool eq_v_eu64 = Eq(lhs.v_eu64(), rhs.v_eu64()); + const bool eq_v_t = Eq(lhs.v_t(), rhs.v_t()); + const bool eq_v_u_type = Eq(lhs.v_u_type(), rhs.v_u_type()); + const bool eq_v_u = Eq(std::make_pair(lhs.v_u_type(), lhs.v_u()), + std::make_pair(rhs.v_u_type(), rhs.v_u())); return eq_b && eq_i8 && eq_i16 && eq_i32 && eq_i64 && eq_u8 && eq_u16 && eq_u32 && eq_u64 && eq_f && eq_d && eq_str && eq_ei8 && eq_ei16 && - eq_ei32 && eq_ei64 && eq_eu8 && eq_eu16 && eq_eu32 && eq_eu64 && eq_t; + eq_ei32 && eq_ei64 && eq_eu8 && eq_eu16 && eq_eu32 && eq_eu64 && + eq_u && eq_t && eq_v_b && eq_v_i8 && eq_v_i16 && eq_v_i32 && + eq_v_i64 && eq_v_u8 && eq_v_u16 && eq_v_u32 && eq_v_u64 && eq_v_f && + eq_v_d && eq_v_str && eq_v_ei8 && eq_v_ei16 && eq_v_ei32 && + eq_v_ei64 && eq_v_eu8 && eq_v_eu16 && eq_v_eu32 && eq_v_eu64 && + eq_v_t && eq_v_u_type && eq_v_u; } const internal::DefaultTable* CreateDefaultTable( flatbuffers::FlatBufferBuilder& fbb) { auto bool_table_offset = internal::CreateBoolTable(fbb, true); + auto string_table_offset = + internal::CreateStringTableDirect(fbb, "foo bar baz"); + std::vector v_b{true, false}; + std::vector v_i8{1, 2, 3}; + std::vector v_i16{1, 2, 3}; + std::vector v_i32{1, 2, 3}; + std::vector v_i64{1, 2, 3}; + std::vector v_u8{1, 2, 3}; + std::vector v_u16{1, 2, 3}; + std::vector v_u32{1, 2, 3}; + std::vector v_u64{1, 2, 3}; + std::vector v_f{1, 2, 3}; + std::vector v_d{1, 2, 3}; + std::vector> v_str{ + fbb.CreateString("foo"), fbb.CreateString("bar"), + fbb.CreateString("baz")}; + std::vector> v_ei8{ + internal::ByteEnum_First, internal::ByteEnum_Second}; + std::vector> v_ei16{ + internal::ShortEnum_First, internal::ShortEnum_Second}; + std::vector> v_ei32{internal::IntEnum_First, + internal::IntEnum_Second}; + std::vector> v_ei64{ + internal::LongEnum_First, internal::LongEnum_Second}; + std::vector> v_eu8{ + internal::UByteEnum_First, internal::UByteEnum_Second}; + std::vector> v_eu16{ + internal::UShortEnum_First, internal::UShortEnum_Second}; + std::vector> v_eu32{ + internal::UIntEnum_First, internal::UIntEnum_Second}; + std::vector> v_eu64{ + internal::ULongEnum_First, internal::ULongEnum_Second}; + std::vector> v_t{bool_table_offset}; + std::vector> v_u_type{ + internal::Union_BoolTable, + internal::Union_StringTable, + }; + std::vector> v_u{ + bool_table_offset.Union(), + string_table_offset.Union(), + }; auto table_offset = internal::CreateDefaultTableDirect(fbb, /*b=*/true, @@ -135,7 +331,32 @@ const internal::DefaultTable* CreateDefaultTable( /*eu16=*/internal::UShortEnum_Second, /*eu32=*/internal::UIntEnum_Second, /*eu64=*/internal::ULongEnum_Second, - /*t=*/bool_table_offset); + /*t=*/bool_table_offset, + /*u_type=*/internal::Union_BoolTable, + /*u=*/bool_table_offset.Union(), + /*v_b=*/&v_b, + /*v_i8=*/&v_i8, + /*v_i16=*/&v_i16, + /*v_i32=*/&v_i32, + /*v_i64=*/&v_i64, + /*v_u8=*/&v_u8, + /*v_u16=*/&v_u16, + /*v_u32=*/&v_u32, + /*v_u64=*/&v_u64, + /*v_f=*/&v_f, + /*v_d=*/&v_d, + /*v_str=*/&v_str, + /*v_ei8=*/&v_ei8, + /*v_ei16=*/&v_ei16, + /*v_ei32=*/&v_ei32, + /*v_ei64=*/&v_ei64, + /*v_eu8=*/&v_eu8, + /*v_eu16=*/&v_eu16, + /*v_eu32=*/&v_eu32, + /*v_eu64=*/&v_eu64, + /*v_t=*/&v_t, + /*v_u_type=*/&v_u_type, + /*v_u=*/&v_u); fbb.Finish(table_offset); return flatbuffers::GetRoot(fbb.GetBufferPointer()); } @@ -261,6 +482,7 @@ TEST(FlatbuffersTableDomainImplTest, DefaultTableValueRoundTrip) { EXPECT_EQ(new_table->u64(), 8); EXPECT_EQ(new_table->f(), 9.0); EXPECT_EQ(new_table->d(), 10.0); + ASSERT_THAT(new_table->str(), NotNull()); EXPECT_EQ(new_table->str()->str(), "foo bar baz"); EXPECT_EQ(new_table->ei8(), internal::ByteEnum_Second); EXPECT_EQ(new_table->ei16(), internal::ShortEnum_Second); @@ -272,6 +494,114 @@ TEST(FlatbuffersTableDomainImplTest, DefaultTableValueRoundTrip) { EXPECT_EQ(new_table->eu64(), internal::ULongEnum_Second); ASSERT_THAT(new_table->t(), NotNull()); EXPECT_EQ(new_table->t()->b(), true); + EXPECT_EQ(new_table->u_type(), internal::Union_BoolTable); + ASSERT_THAT(new_table->u(), NotNull()); + EXPECT_EQ(new_table->u_as_BoolTable()->b(), true); + ASSERT_THAT(new_table->v_b(), NotNull()); + EXPECT_EQ(new_table->v_b()->size(), 2); + EXPECT_EQ(new_table->v_b()->Get(0), true); + EXPECT_EQ(new_table->v_b()->Get(1), false); + ASSERT_THAT(new_table->v_i8(), NotNull()); + EXPECT_EQ(new_table->v_i8()->size(), 3); + EXPECT_EQ(new_table->v_i8()->Get(0), 1); + EXPECT_EQ(new_table->v_i8()->Get(1), 2); + EXPECT_EQ(new_table->v_i8()->Get(2), 3); + ASSERT_THAT(new_table->v_i16(), NotNull()); + EXPECT_EQ(new_table->v_i16()->size(), 3); + EXPECT_EQ(new_table->v_i16()->Get(0), 1); + EXPECT_EQ(new_table->v_i16()->Get(1), 2); + EXPECT_EQ(new_table->v_i16()->Get(2), 3); + ASSERT_THAT(new_table->v_i32(), NotNull()); + EXPECT_EQ(new_table->v_i32()->size(), 3); + EXPECT_EQ(new_table->v_i32()->Get(0), 1); + EXPECT_EQ(new_table->v_i32()->Get(1), 2); + EXPECT_EQ(new_table->v_i32()->Get(2), 3); + ASSERT_THAT(new_table->v_i64(), NotNull()); + EXPECT_EQ(new_table->v_i64()->size(), 3); + EXPECT_EQ(new_table->v_i64()->Get(0), 1); + EXPECT_EQ(new_table->v_i64()->Get(1), 2); + EXPECT_EQ(new_table->v_i64()->Get(2), 3); + ASSERT_THAT(new_table->v_u8(), NotNull()); + EXPECT_EQ(new_table->v_u8()->size(), 3); + EXPECT_EQ(new_table->v_u8()->Get(0), 1); + EXPECT_EQ(new_table->v_u8()->Get(1), 2); + EXPECT_EQ(new_table->v_u8()->Get(2), 3); + ASSERT_THAT(new_table->v_u16(), NotNull()); + EXPECT_EQ(new_table->v_u16()->size(), 3); + EXPECT_EQ(new_table->v_u16()->Get(0), 1); + EXPECT_EQ(new_table->v_u16()->Get(1), 2); + EXPECT_EQ(new_table->v_u16()->Get(2), 3); + ASSERT_THAT(new_table->v_u32(), NotNull()); + EXPECT_EQ(new_table->v_u32()->size(), 3); + EXPECT_EQ(new_table->v_u32()->Get(0), 1); + EXPECT_EQ(new_table->v_u32()->Get(1), 2); + EXPECT_EQ(new_table->v_u32()->Get(2), 3); + ASSERT_THAT(new_table->v_u64(), NotNull()); + EXPECT_EQ(new_table->v_u64()->size(), 3); + EXPECT_EQ(new_table->v_u64()->Get(0), 1); + EXPECT_EQ(new_table->v_u64()->Get(1), 2); + EXPECT_EQ(new_table->v_u64()->Get(2), 3); + ASSERT_THAT(new_table->v_f(), NotNull()); + EXPECT_EQ(new_table->v_f()->size(), 3); + EXPECT_EQ(new_table->v_f()->Get(0), 1); + EXPECT_EQ(new_table->v_f()->Get(1), 2); + EXPECT_EQ(new_table->v_f()->Get(2), 3); + ASSERT_THAT(new_table->v_d(), NotNull()); + EXPECT_EQ(new_table->v_d()->size(), 3); + EXPECT_EQ(new_table->v_d()->Get(0), 1); + EXPECT_EQ(new_table->v_d()->Get(1), 2); + EXPECT_EQ(new_table->v_d()->Get(2), 3); + EXPECT_EQ(new_table->v_str()->size(), 3); + EXPECT_EQ(new_table->v_str()->Get(0)->str(), "foo"); + EXPECT_EQ(new_table->v_str()->Get(1)->str(), "bar"); + EXPECT_EQ(new_table->v_str()->Get(2)->str(), "baz"); + ASSERT_THAT(new_table->v_ei8(), NotNull()); + EXPECT_EQ(new_table->v_ei8()->size(), 2); + EXPECT_EQ(new_table->v_ei8()->Get(0), internal::ByteEnum_First); + EXPECT_EQ(new_table->v_ei8()->Get(1), internal::ByteEnum_Second); + ASSERT_THAT(new_table->v_ei16(), NotNull()); + EXPECT_EQ(new_table->v_ei16()->size(), 2); + EXPECT_EQ(new_table->v_ei16()->Get(0), internal::ShortEnum_First); + EXPECT_EQ(new_table->v_ei16()->Get(1), internal::ShortEnum_Second); + ASSERT_THAT(new_table->v_ei32(), NotNull()); + EXPECT_EQ(new_table->v_ei32()->size(), 2); + EXPECT_EQ(new_table->v_ei32()->Get(0), internal::IntEnum_First); + EXPECT_EQ(new_table->v_ei32()->Get(1), internal::IntEnum_Second); + ASSERT_THAT(new_table->v_ei64(), NotNull()); + EXPECT_EQ(new_table->v_ei64()->size(), 2); + EXPECT_EQ(new_table->v_ei64()->Get(0), internal::LongEnum_First); + EXPECT_EQ(new_table->v_ei64()->Get(1), internal::LongEnum_Second); + ASSERT_THAT(new_table->v_eu8(), NotNull()); + EXPECT_EQ(new_table->v_eu8()->size(), 2); + EXPECT_EQ(new_table->v_eu8()->Get(0), internal::UByteEnum_First); + EXPECT_EQ(new_table->v_eu8()->Get(1), internal::UByteEnum_Second); + ASSERT_THAT(new_table->v_eu16(), NotNull()); + EXPECT_EQ(new_table->v_eu16()->size(), 2); + EXPECT_EQ(new_table->v_eu16()->Get(0), internal::UShortEnum_First); + EXPECT_EQ(new_table->v_eu16()->Get(1), internal::UShortEnum_Second); + ASSERT_THAT(new_table->v_eu32(), NotNull()); + EXPECT_EQ(new_table->v_eu32()->size(), 2); + EXPECT_EQ(new_table->v_eu32()->Get(0), internal::UIntEnum_First); + EXPECT_EQ(new_table->v_eu32()->Get(1), internal::UIntEnum_Second); + ASSERT_THAT(new_table->v_t(), NotNull()); + EXPECT_EQ(new_table->v_t()->size(), 1); + ASSERT_THAT(new_table->v_t()->Get(0), NotNull()); + EXPECT_EQ(new_table->v_t()->Get(0)->b(), true); + ASSERT_THAT(new_table->v_u_type(), NotNull()); + ASSERT_EQ(new_table->v_u_type()->size(), 2); + EXPECT_EQ(new_table->v_u_type()->Get(0), internal::Union_BoolTable); + EXPECT_EQ(new_table->v_u_type()->Get(1), internal::Union_StringTable); + ASSERT_THAT(new_table->v_u(), NotNull()); + EXPECT_EQ(new_table->v_u()->size(), 2); + auto v_u_0 = + static_cast(new_table->v_u()->Get(0)); + ASSERT_THAT(v_u_0, NotNull()); + EXPECT_EQ(v_u_0->b(), true); + auto v_u_1 = + static_cast(new_table->v_u()->Get(1)); + ASSERT_THAT(v_u_1, NotNull()); + ASSERT_THAT(v_u_1->str(), NotNull()); + EXPECT_EQ(v_u_1->str()->str(), "foo bar baz"); } TEST(FlatbuffersTableDomainImplTest, InitGeneratesSeeds) { @@ -291,12 +621,22 @@ TEST(FlatbuffersTableDomainImplTest, InitGeneratesSeeds) { TEST(FlatbuffersTableDomainImplTest, CanMutateAnyTableField) { absl::flat_hash_map mutated_fields{ - {"b", false}, {"i8", false}, {"i16", false}, {"i32", false}, - {"i64", false}, {"u8", false}, {"u16", false}, {"u32", false}, - {"u64", false}, {"f", false}, {"d", false}, {"str", false}, - {"ei8", false}, {"ei16", false}, {"ei32", false}, {"ei64", false}, - {"eu8", false}, {"eu16", false}, {"eu32", false}, {"eu64", false}, - {"t", false}, + {"b", false}, {"i8", false}, {"i16", false}, + {"i32", false}, {"i64", false}, {"u8", false}, + {"u16", false}, {"u32", false}, {"u64", false}, + {"f", false}, {"d", false}, {"str", false}, + {"ei8", false}, {"ei16", false}, {"ei32", false}, + {"ei64", false}, {"eu8", false}, {"eu16", false}, + {"eu32", false}, {"eu64", false}, {"t", false}, + {"u_type", false}, {"u", false}, {"v_b", false}, + {"v_i8", false}, {"v_i16", false}, {"v_i32", false}, + {"v_i64", false}, {"v_u8", false}, {"v_u16", false}, + {"v_u32", false}, {"v_u64", false}, {"v_f", false}, + {"v_d", false}, {"v_str", false}, {"v_ei8", false}, + {"v_ei16", false}, {"v_ei32", false}, {"v_ei64", false}, + {"v_eu8", false}, {"v_eu16", false}, {"v_eu32", false}, + {"v_eu64", false}, {"v_t", false}, {"v_u_type", false}, + {"v_u", false}, }; auto domain = Arbitrary(); @@ -332,6 +672,34 @@ TEST(FlatbuffersTableDomainImplTest, CanMutateAnyTableField) { mutated_fields["eu32"] |= mut->eu32() != init->eu32(); mutated_fields["eu64"] |= mut->eu64() != init->eu64(); mutated_fields["t"] |= !Eq(mut->t(), init->t()); + mutated_fields["u_type"] |= !Eq(mut->u_type(), init->u_type()); + mutated_fields["u"] |= + !Eq(std::pair(static_cast(mut->u_type()), mut->u()), + std::pair(static_cast(init->u_type()), init->u())); + mutated_fields["v_b"] |= !Eq(mut->v_b(), init->v_b()); + mutated_fields["v_i8"] |= !Eq(mut->v_i8(), init->v_i8()); + mutated_fields["v_i16"] |= !Eq(mut->v_i16(), init->v_i16()); + mutated_fields["v_i32"] |= !Eq(mut->v_i32(), init->v_i32()); + mutated_fields["v_i64"] |= !Eq(mut->v_i64(), init->v_i64()); + mutated_fields["v_u8"] |= !Eq(mut->v_u8(), init->v_u8()); + mutated_fields["v_u16"] |= !Eq(mut->v_u16(), init->v_u16()); + mutated_fields["v_u32"] |= !Eq(mut->v_u32(), init->v_u32()); + mutated_fields["v_u64"] |= !Eq(mut->v_u64(), init->v_u64()); + mutated_fields["v_f"] |= !Eq(mut->v_f(), init->v_f()); + mutated_fields["v_d"] |= !Eq(mut->v_d(), init->v_d()); + mutated_fields["v_str"] |= !Eq(mut->v_str(), init->v_str()); + mutated_fields["v_ei8"] |= !Eq(mut->v_ei8(), init->v_ei8()); + mutated_fields["v_ei16"] |= !Eq(mut->v_ei16(), init->v_ei16()); + mutated_fields["v_ei32"] |= !Eq(mut->v_ei32(), init->v_ei32()); + mutated_fields["v_ei64"] |= !Eq(mut->v_ei64(), init->v_ei64()); + mutated_fields["v_eu8"] |= !Eq(mut->v_eu8(), init->v_eu8()); + mutated_fields["v_eu16"] |= !Eq(mut->v_eu16(), init->v_eu16()); + mutated_fields["v_eu32"] |= !Eq(mut->v_eu32(), init->v_eu32()); + mutated_fields["v_eu64"] |= !Eq(mut->v_eu64(), init->v_eu64()); + mutated_fields["v_t"] |= !Eq(mut->v_str(), init->v_str()); + mutated_fields["v_u_type"] |= !Eq(mut->v_u_type(), init->v_u_type()); + mutated_fields["v_u"] |= !Eq(std::make_pair(mut->v_u_type(), mut->v_u()), + std::make_pair(init->v_u_type(), init->v_u())); if (std::all_of(mutated_fields.begin(), mutated_fields.end(), [](const auto& p) { return p.second; })) { @@ -345,6 +713,30 @@ TEST(FlatbuffersTableDomainImplTest, CanMutateAnyTableField) { TEST(FlatbuffersTableDomainImplTest, OptionalTableEventuallyBecomeEmpty) { flatbuffers::FlatBufferBuilder fbb; auto bool_table_offset = internal::CreateBoolTable(fbb, true); + std::vector v_b{true, false}; + std::vector v_i8{}; + std::vector v_i16{}; + std::vector v_i32{}; + std::vector v_i64{}; + std::vector v_u8{}; + std::vector v_u16{}; + std::vector v_u32{}; + std::vector v_u64{}; + std::vector v_f{}; + std::vector v_d{}; + std::vector> v_str{ + fbb.CreateString(""), fbb.CreateString(""), fbb.CreateString("")}; + std::vector> v_ei8{}; + std::vector> v_ei16{}; + std::vector> v_ei32{}; + std::vector> v_ei64{}; + std::vector> v_eu8{}; + std::vector> v_eu16{}; + std::vector> v_eu32{}; + std::vector> v_eu64{}; + std::vector> v_t{}; + std::vector> v_u_type{}; + std::vector> v_u{}; auto table_offset = internal::CreateOptionalTableDirect(fbb, true, // b @@ -367,7 +759,32 @@ TEST(FlatbuffersTableDomainImplTest, OptionalTableEventuallyBecomeEmpty) { internal::UShortEnum_Second, // eu16 internal::UIntEnum_Second, // eu32 internal::ULongEnum_Second, // eu64 - bool_table_offset // t + bool_table_offset, // t + internal::Union_BoolTable, // u_type + bool_table_offset.Union(), // u + &v_b, // v_b + &v_i8, // v_i8 + &v_i16, // v_i16 + &v_i32, // v_i32 + &v_i64, // v_i64 + &v_u8, // v_u8 + &v_u16, // v_u16 + &v_u32, // v_u32 + &v_u64, // v_u64 + &v_f, // v_f + &v_d, // v_d + &v_str, // v_str + &v_ei8, // v_ei8 + &v_ei16, // v_ei16 + &v_ei32, // v_ei32 + &v_ei64, // v_ei64 + &v_eu8, // v_eu8 + &v_eu16, // v_eu16 + &v_eu32, // v_eu32 + &v_eu64, // v_eu64 + &v_t, // v_t + &v_u_type, // v_u_type + &v_u // v_u ); fbb.Finish(table_offset); auto table = flatbuffers::GetRoot(fbb.GetBufferPointer()); @@ -377,12 +794,22 @@ TEST(FlatbuffersTableDomainImplTest, OptionalTableEventuallyBecomeEmpty) { absl::BitGen bitgen; absl::flat_hash_map null_fields{ - {"b", false}, {"i8", false}, {"i16", false}, {"i32", false}, - {"i64", false}, {"u8", false}, {"u16", false}, {"u32", false}, - {"u64", false}, {"f", false}, {"d", false}, {"str", false}, - {"ei8", false}, {"ei16", false}, {"ei32", false}, {"ei64", false}, - {"eu8", false}, {"eu16", false}, {"eu32", false}, {"eu64", false}, - {"t", false}, + {"b", false}, {"i8", false}, {"i16", false}, + {"i32", false}, {"i64", false}, {"u8", false}, + {"u16", false}, {"u32", false}, {"u64", false}, + {"f", false}, {"d", false}, {"str", false}, + {"ei8", false}, {"ei16", false}, {"ei32", false}, + {"ei64", false}, {"eu8", false}, {"eu16", false}, + {"eu32", false}, {"eu64", false}, {"t", false}, + {"u_type", false}, {"u", false}, {"v_b", false}, + {"v_i8", false}, {"v_i16", false}, {"v_i32", false}, + {"v_i64", false}, {"v_u8", false}, {"v_u16", false}, + {"v_u32", false}, {"v_u64", false}, {"v_f", false}, + {"v_d", false}, {"v_str", false}, {"v_ei8", false}, + {"v_ei16", false}, {"v_ei32", false}, {"v_ei64", false}, + {"v_eu8", false}, {"v_eu16", false}, {"v_eu32", false}, + {"v_eu64", false}, {"v_t", false}, {"v_u_type", false}, + {"v_u", false}, }; // Optional fields are mutated to null with probability 1/100. @@ -413,6 +840,31 @@ TEST(FlatbuffersTableDomainImplTest, OptionalTableEventuallyBecomeEmpty) { null_fields["eu32"] |= !v->eu32().has_value(); null_fields["eu64"] |= !v->eu64().has_value(); null_fields["t"] |= v->t() == nullptr; + null_fields["u_type"] |= v->u_type() == internal::Union_NONE; + null_fields["u"] |= v->u() == nullptr; + null_fields["v_b"] |= v->v_b() == nullptr; + null_fields["v_i8"] |= v->v_i8() == nullptr; + null_fields["v_i16"] |= v->v_i16() == nullptr; + null_fields["v_i32"] |= v->v_i32() == nullptr; + null_fields["v_i64"] |= v->v_i64() == nullptr; + null_fields["v_u8"] |= v->v_u8() == nullptr; + null_fields["v_u16"] |= v->v_u16() == nullptr; + null_fields["v_u32"] |= v->v_u32() == nullptr; + null_fields["v_u64"] |= v->v_u64() == nullptr; + null_fields["v_f"] |= v->v_f() == nullptr; + null_fields["v_d"] |= v->v_d() == nullptr; + null_fields["v_str"] |= v->v_str() == nullptr; + null_fields["v_ei8"] |= v->v_ei8() == nullptr; + null_fields["v_ei16"] |= v->v_ei16() == nullptr; + null_fields["v_ei32"] |= v->v_ei32() == nullptr; + null_fields["v_ei64"] |= v->v_ei64() == nullptr; + null_fields["v_eu8"] |= v->v_eu8() == nullptr; + null_fields["v_eu16"] |= v->v_eu16() == nullptr; + null_fields["v_eu32"] |= v->v_eu32() == nullptr; + null_fields["v_eu64"] |= v->v_eu64() == nullptr; + null_fields["v_t"] |= v->v_t() == nullptr; + null_fields["v_u_type"] |= v->v_u_type() == nullptr; + null_fields["v_u"] |= v->v_u() == nullptr; if (std::all_of(null_fields.begin(), null_fields.end(), [](const auto& p) { return p.second; })) { @@ -451,39 +903,59 @@ TEST(FlatbuffersTableDomainImplTest, Printer) { printer.PrintCorpusValue(*corpus, &out, domain_implementor::PrintMode::kHumanReadable); - EXPECT_THAT(out, AllOf(HasSubstr("b: (true)"), // b - HasSubstr("i8: (1)"), // i8 - HasSubstr("i16: (2)"), // i16 - HasSubstr("i32: (3)"), // i32 - HasSubstr("i64: (4)"), // i64 - HasSubstr("u8: (5)"), // u8 - HasSubstr("u16: (6)"), // u16 - HasSubstr("u32: (7)"), // u32 - HasSubstr("u64: (8)"), // u64 - HasSubstr("f: (9.f)"), // f - HasSubstr("d: (10.)"), // d - HasSubstr("str: (\"foo bar baz\")"), // str - HasSubstr("ei8: (Second)"), // ei8 - HasSubstr("ei16: (Second)"), // ei16 - HasSubstr("ei32: (Second)"), // ei32 - HasSubstr("ei64: (Second)"), // ei64 - HasSubstr("eu8: (Second)"), // eu8 - HasSubstr("eu16: (Second)"), // eu16 - HasSubstr("eu32: (Second)"), // eu32 - HasSubstr("eu64: (Second)"), // eu64 - HasSubstr("t: ({b: (true)})") // t - )); + EXPECT_THAT( + out, + AllOf(HasSubstr("b: (true)"), // b + HasSubstr("i8: (1)"), // i8 + HasSubstr("i16: (2)"), // i16 + HasSubstr("i32: (3)"), // i32 + HasSubstr("i64: (4)"), // i64 + HasSubstr("u8: (5)"), // u8 + HasSubstr("u16: (6)"), // u16 + HasSubstr("u32: (7)"), // u32 + HasSubstr("u64: (8)"), // u64 + HasSubstr("f: (9.f)"), // f + HasSubstr("d: (10.)"), // d + HasSubstr("str: (\"foo bar baz\")"), // str + HasSubstr("ei8: (Second)"), // ei8 + HasSubstr("ei16: (Second)"), // ei16 + HasSubstr("ei32: (Second)"), // ei32 + HasSubstr("ei64: (Second)"), // ei64 + HasSubstr("eu8: (Second)"), // eu8 + HasSubstr("eu16: (Second)"), // eu16 + HasSubstr("eu32: (Second)"), // eu32 + HasSubstr("eu64: (Second)"), // eu64 + HasSubstr("t: ({b: (true)})"), // t + HasSubstr("u: (({b: (true)}))"), // u + HasSubstr("v_b: ({true, false})"), // v_b + HasSubstr("v_i8: ({1, 2, 3})"), // v_i8 + HasSubstr("v_i16: ({1, 2, 3})"), // v_i16 + HasSubstr("v_i32: ({1, 2, 3})"), // v_i32 + HasSubstr("v_i64: ({1, 2, 3})"), // v_i64 + HasSubstr("v_u8: ({1, 2, 3})"), // v_u8 + HasSubstr("v_u16: ({1, 2, 3})"), // v_u16 + HasSubstr("v_u32: ({1, 2, 3})"), // v_u32 + HasSubstr("v_u64: ({1, 2, 3})"), // v_u64 + HasSubstr("v_f: ({1.f, 2.f, 3.f})"), // v_f + HasSubstr("v_d: ({1., 2., 3.})"), // v_d + HasSubstr("v_str: ({\"foo\", \"bar\", \"baz\"})"), // v_str + HasSubstr("v_ei8: ({First, Second})"), // v_ei8 + HasSubstr("v_ei16: ({First, Second})"), // v_ei16 + HasSubstr("v_ei32: ({First, Second})"), // v_ei32 + HasSubstr("v_ei64: ({First, Second})"), // v_ei64 + HasSubstr("v_eu8: ({First, Second})"), // v_eu8 + HasSubstr("v_eu16: ({First, Second})"), // v_eu16 + HasSubstr("v_eu32: ({First, Second})"), // v_eu32 + HasSubstr("v_eu64: ({First, Second})"), // v_eu64 + HasSubstr("v_t: ({{b: (true)}})"), // v_t + HasSubstr("v_u: ({({b: (true)}), " + "({str: (\"foo bar baz\")})})") // v_u + )); } TEST(FlatbuffersTableDomainImplTest, UnsupportedTypesRemainNull) { - absl::flat_hash_map null_fields{ - {"u", true}, {"s", true}, {"v_b", true}, {"v_i8", true}, - {"v_i16", true}, {"v_i32", true}, {"v_i64", true}, {"v_u8", true}, - {"v_u16", true}, {"v_u32", true}, {"v_u64", true}, {"v_f", true}, - {"v_d", true}, {"v_str", true}, {"v_ei8", true}, {"v_ei16", true}, - {"v_ei32", true}, {"v_ei64", true}, {"v_eu8", true}, {"v_eu16", true}, - {"v_eu32", true}, {"v_eu64", true}, {"v_t", true}, {"v_u", true}, - {"v_s", true}}; + absl::flat_hash_map null_fields{{"s", true}, + {"v_s", true}}; auto domain = Arbitrary(); @@ -495,30 +967,7 @@ TEST(FlatbuffersTableDomainImplTest, UnsupportedTypesRemainNull) { val.Mutate(domain, bitgen, {}, false); const auto& mut = val.user_value; - null_fields["u"] &= mut->u() == nullptr; null_fields["s"] &= mut->s() == nullptr; - null_fields["v_b"] &= mut->v_b() == nullptr; - null_fields["v_i8"] &= mut->v_i8() == nullptr; - null_fields["v_i16"] &= mut->v_i16() == nullptr; - null_fields["v_i32"] &= mut->v_i32() == nullptr; - null_fields["v_i64"] &= mut->v_i64() == nullptr; - null_fields["v_u8"] &= mut->v_u8() == nullptr; - null_fields["v_u16"] &= mut->v_u16() == nullptr; - null_fields["v_u32"] &= mut->v_u32() == nullptr; - null_fields["v_u64"] &= mut->v_u64() == nullptr; - null_fields["v_f"] &= mut->v_f() == nullptr; - null_fields["v_d"] &= mut->v_d() == nullptr; - null_fields["v_str"] &= mut->v_str() == nullptr; - null_fields["v_ei8"] &= mut->v_ei8() == nullptr; - null_fields["v_ei16"] &= mut->v_ei16() == nullptr; - null_fields["v_ei32"] &= mut->v_ei32() == nullptr; - null_fields["v_ei64"] &= mut->v_ei64() == nullptr; - null_fields["v_eu8"] &= mut->v_eu8() == nullptr; - null_fields["v_eu16"] &= mut->v_eu16() == nullptr; - null_fields["v_eu32"] &= mut->v_eu32() == nullptr; - null_fields["v_eu64"] &= mut->v_eu64() == nullptr; - null_fields["v_t"] &= mut->v_t() == nullptr; - null_fields["v_u"] &= mut->v_u() == nullptr; null_fields["v_s"] &= mut->v_s() == nullptr; if (std::any_of(null_fields.begin(), null_fields.end(), @@ -530,6 +979,106 @@ TEST(FlatbuffersTableDomainImplTest, UnsupportedTypesRemainNull) { EXPECT_THAT(null_fields, Each(Pair(_, true))); } +TEST(FlatbuffersTableDomainImplTest, MutateSelectedField) { + flatbuffers::FlatBufferBuilder fbb; + auto table = CreateDefaultTable(fbb); + auto domain = Arbitrary(); + absl::BitGen prng; + + { + // Mutate scalar field (b). + auto corpus = domain.FromValue(table); + domain.MutateSelectedField(*corpus, prng, {}, false, 1); + EXPECT_NE(domain.GetValue(*corpus)->b(), table->b()); + } + { + // Mutate nested table field (t.b). + auto corpus = domain.FromValue(table); + domain.MutateSelectedField(*corpus, prng, {}, false, 22); + EXPECT_NE(domain.GetValue(*corpus)->t()->b(), table->t()->b()); + } + { + // Mutate union type field (u_type). + auto corpus = domain.FromValue(table); + domain.MutateSelectedField(*corpus, prng, {}, false, 24); + EXPECT_NE(domain.GetValue(*corpus)->u_type(), table->u_type()); + } + { + // Mutate union field (u.*) content. + auto corpus = domain.FromValue(table); + domain.MutateSelectedField(*corpus, prng, {}, false, 25); + switch (domain.GetValue(*corpus)->u_type()) { + case internal::Union_BoolTable: + EXPECT_NE(domain.GetValue(*corpus)->u_as_BoolTable()->b(), + table->u_as_BoolTable()->b()); + break; + case internal::Union_StringTable: + EXPECT_NE( + domain.GetValue(*corpus)->u_as_StringTable()->str()->string_view(), + table->u_as_StringTable()->str()->string_view()); + break; + default: + FAIL() << "Unexpected union type: " + << domain.GetValue(*corpus)->u_type(); + } + } + { + // Mutate vector of tables field (v_t). + auto corpus = domain.FromValue(table); + domain.MutateSelectedField(*corpus, prng, {}, false, 46); + auto user_value = domain.GetValue(*corpus); + EXPECT_FALSE(Eq(user_value->v_t(), table->v_t())); + } + { + // Mutate vector of tables field (v_t[0].b). + auto corpus = domain.FromValue(table); + domain.MutateSelectedField(*corpus, prng, {}, false, 47); + EXPECT_NE(domain.GetValue(*corpus)->v_t()->Get(0)->b(), + table->v_t()->Get(0)->b()); + } + { + // Mutate vector of unions field type (v_u_type[0]). + auto corpus = domain.FromValue(table); + domain.MutateSelectedField(*corpus, prng, {}, false, 49); + EXPECT_NE(domain.GetValue(*corpus)->v_u_type()->Get(0), + table->v_u_type()->Get(0)); + } + { + // Mutate vector of unions field (v_u[0].*). + auto corpus = domain.FromValue(table); + domain.MutateSelectedField(*corpus, prng, {}, false, 50); + EXPECT_EQ(domain.GetValue(*corpus)->v_u_type()->size(), + table->v_u_type()->size()); + EXPECT_EQ(domain.GetValue(*corpus)->v_u_type()->Get(0), + table->v_u_type()->Get(0)); + if (!domain.GetValue(*corpus)->v_u_type()->empty()) { + switch (domain.GetValue(*corpus)->v_u_type()->Get(0)) { + case internal::Union_BoolTable: + EXPECT_NE( + static_cast( + domain.GetValue(*corpus)->v_u()->Get(0)) + ->b(), + static_cast(table->v_u()->Get(0)) + ->b()); + break; + case internal::Union_StringTable: + EXPECT_NE( + static_cast( + domain.GetValue(*corpus)->v_u()->Get(0)) + ->str() + ->string_view(), + static_cast(table->v_u()->Get(0)) + ->str() + ->string_view()); + break; + default: + FAIL() << "Unexpected union type: " + << domain.GetValue(*corpus)->v_u_type()->Get(0); + } + } + } +} + TEST(FlatbuffersTableDomainImplTest, MutateAlwaysChangesValues) { auto domain = Arbitrary(); const reflection::Schema* schema = @@ -538,13 +1087,24 @@ TEST(FlatbuffersTableDomainImplTest, MutateAlwaysChangesValues) { schema->objects()->LookupByKey(DefaultTable::GetFullyQualifiedName()); absl::BitGen bitgen; - size_t iterations = IterationsToHitAll(object->fields()->size(), - 1.0 / object->fields()->size()); + const uint32_t field_count = object->fields()->size(); + int iterations = IterationsToHitAll(field_count, 1.0 / field_count); typename decltype(domain)::corpus_type corpus = domain.Init(bitgen); - for (size_t i = 0; i < iterations; ++i) { + for (int i = 0; i < iterations; ++i) { auto mutated_corpus = corpus; domain.Mutate(mutated_corpus, bitgen, {}, false); - EXPECT_FALSE(Eq(domain.GetValue(mutated_corpus), domain.GetValue(corpus))); + if (Eq(domain.GetValue(mutated_corpus), domain.GetValue(corpus))) { + auto printer = domain.GetPrinter(); + std::string corpus_str; + printer.PrintCorpusValue(corpus, &corpus_str, + domain_implementor::PrintMode::kHumanReadable); + std::string mutated_corpus_str; + printer.PrintCorpusValue(mutated_corpus, &mutated_corpus_str, + domain_implementor::PrintMode::kHumanReadable); + FAIL() << "Mutated corpus is equal to the original corpus." + << "\nOriginal: " << corpus_str + << "\nMutated: " << mutated_corpus_str; + } corpus = mutated_corpus; } } @@ -564,7 +1124,72 @@ TEST(FlatbuffersTableDomainImplTest, CountNumberOfFieldsWithNull) { auto domain = Arbitrary(); auto corpus = domain.FromValue(table); ASSERT_TRUE(corpus.has_value()); - EXPECT_EQ(domain.CountNumberOfFields(corpus.value()), 21); + EXPECT_EQ(domain.CountNumberOfFields(corpus.value()), 44); +} + +TEST(FlatbuffersUnionDomainImpl, ParseCorpusRejectsInvalidValues) { + auto domain = Arbitrary(); + { + flatbuffers::FlatBufferBuilder fbb; + internal::CreateUnionTable(fbb, internal::Union_BoolTable, 0); + fbb.Finish(internal::CreateUnionTable(fbb, internal::Union_BoolTable, 0)); + auto table = flatbuffers::GetRoot(fbb.GetBufferPointer()); + flatbuffers::Verifier verifier(fbb.GetBufferPointer(), fbb.GetSize()); + ASSERT_TRUE(verifier.VerifyBuffer()); + + auto corpus = domain.FromValue(table); + ASSERT_TRUE(corpus.has_value()); + EXPECT_FALSE(domain.ValidateCorpusValue(corpus.value()).ok()); + } + { + internal::IRObject ir_object; + auto& subs = ir_object.MutableSubs(); + subs.reserve(2); + + auto& u_obj = subs.emplace_back(); + auto& u_subs = u_obj.MutableSubs(); + u_subs.reserve(2); + u_subs.emplace_back(1); // id + auto& u_opt_value = u_subs.emplace_back(); // value + auto& u_opt_value_subs = u_opt_value.MutableSubs(); + u_opt_value_subs.reserve(2); + u_opt_value_subs.emplace_back(1); // has value + auto& u_inner_value = u_opt_value_subs.emplace_back(); + + u_inner_value.MutableSubs().reserve(2); + u_inner_value.MutableSubs().emplace_back(-1); // type (invalid) + u_inner_value.MutableSubs().emplace_back(); // value + + auto corpus = domain.ParseCorpus(ir_object); + ASSERT_FALSE(corpus.has_value()); + } + { + internal::IRObject ir_object; + auto& subs = ir_object.MutableSubs(); + subs.reserve(2); + + auto& u_obj = subs.emplace_back(); + auto& u_subs = u_obj.MutableSubs(); + u_subs.reserve(2); + u_subs.emplace_back(1); // id + auto& u_opt_value = u_subs.emplace_back(); // value + auto& u_opt_value_subs = u_opt_value.MutableSubs(); + u_opt_value_subs.reserve(2); + u_opt_value_subs.emplace_back(1); // has value + auto& u_inner_value = u_opt_value_subs.emplace_back(); + + u_inner_value.MutableSubs().reserve(2); + u_inner_value.MutableSubs().emplace_back( + internal::Union_BoolTable); // type + auto& bool_table = u_inner_value.MutableSubs().emplace_back(); // value + auto& bool_table_subs = bool_table.MutableSubs(); + bool_table_subs.reserve(2); + bool_table_subs.emplace_back(200); // id (invalid) + u_subs.emplace_back(); // value + + auto corpus = domain.ParseCorpus(ir_object); + ASSERT_FALSE(corpus.has_value()); + } } TEST(FlatbuffersTableDomainImplTest, RecursiveTable) { @@ -592,5 +1217,37 @@ TEST(FlatbuffersTableDomainImplTest, RecursiveTable) { ASSERT_THAT(new_table, IsNull()); } +TEST(FlatbuffersTableDomainImplTest, DefaultTable64ValueRoundTrip) { + flatbuffers::FlatBufferBuilder64 fbb; + auto str_offset = fbb.CreateString("foo bar baz"); + std::vector v_u8 = {1, 2, 3}; + auto v_u8_offset = fbb.CreateVector64(v_u8); + auto table_offset = + internal::CreateDefaultTable64(fbb, str_offset, v_u8_offset); + fbb.Finish(table_offset); + auto table = flatbuffers::GetRoot(fbb.GetBufferPointer()); + + auto domain = Arbitrary(); + auto corpus = domain.FromValue(table); + ASSERT_TRUE(corpus.has_value()); + ASSERT_OK(domain.ValidateCorpusValue(*corpus)); + + auto ir = domain.SerializeCorpus(corpus.value()); + + auto new_corpus = domain.ParseCorpus(ir); + ASSERT_TRUE(new_corpus.has_value()); + ASSERT_OK(domain.ValidateCorpusValue(*new_corpus)); + + auto new_table = domain.GetValue(*new_corpus); + ASSERT_THAT(new_table, NotNull()); + ASSERT_THAT(new_table->str(), NotNull()); + EXPECT_EQ(new_table->str()->str(), "foo bar baz"); + ASSERT_THAT(new_table->v_u8(), NotNull()); + ASSERT_EQ(new_table->v_u8()->size(), 3); + EXPECT_EQ(new_table->v_u8()->Get(0), 1); + EXPECT_EQ(new_table->v_u8()->Get(1), 2); + EXPECT_EQ(new_table->v_u8()->Get(2), 3); +} + } // namespace } // namespace fuzztest diff --git a/fuzztest/internal/BUILD b/fuzztest/internal/BUILD index f5155910..2bf470ff 100644 --- a/fuzztest/internal/BUILD +++ b/fuzztest/internal/BUILD @@ -616,8 +616,13 @@ cc_test( flatbuffer_library_public( name = "test_flatbuffers_fbs", - srcs = ["test_flatbuffers.fbs"], + srcs = [ + "test_flatbuffers.fbs", + "test_flatbuffers_64bits.fbs", + ], outs = [ + "test_flatbuffers_64bits_bfbs_generated.h", + "test_flatbuffers_64bits_generated.h", "test_flatbuffers_bfbs_generated.h", "test_flatbuffers_generated.h", ], diff --git a/fuzztest/internal/CMakeLists.txt b/fuzztest/internal/CMakeLists.txt index 0eb75aba..5c1173e6 100644 --- a/fuzztest/internal/CMakeLists.txt +++ b/fuzztest/internal/CMakeLists.txt @@ -574,6 +574,7 @@ if (FUZZTEST_BUILD_FLATBUFFERS) test_flatbuffers_headers SCHEMAS "test_flatbuffers.fbs" + "test_flatbuffers_64bits.fbs" FLAGS --bfbs-gen-embed --gen-name-strings TESTONLY diff --git a/fuzztest/internal/domains/BUILD b/fuzztest/internal/domains/BUILD index 545fef36..30790fc5 100644 --- a/fuzztest/internal/domains/BUILD +++ b/fuzztest/internal/domains/BUILD @@ -187,9 +187,9 @@ cc_library( hdrs = ["flatbuffers_domain_impl.h"], deps = [ ":core_domains_impl", - "@abseil-cpp//absl/algorithm:container", "@abseil-cpp//absl/base:core_headers", "@abseil-cpp//absl/base:nullability", + "@abseil-cpp//absl/container:btree", "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/container:flat_hash_set", "@abseil-cpp//absl/random:bit_gen_ref", @@ -204,6 +204,7 @@ cc_library( "@com_google_fuzztest//fuzztest/internal:meta", "@com_google_fuzztest//fuzztest/internal:serialization", "@com_google_fuzztest//fuzztest/internal:status", + "@com_google_fuzztest//fuzztest/internal:type_support", "@flatbuffers//:runtime_cc", ], ) diff --git a/fuzztest/internal/domains/flatbuffers_domain_impl.cc b/fuzztest/internal/domains/flatbuffers_domain_impl.cc index 7e8e144e..28458161 100644 --- a/fuzztest/internal/domains/flatbuffers_domain_impl.cc +++ b/fuzztest/internal/domains/flatbuffers_domain_impl.cc @@ -11,6 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + #include "./fuzztest/internal/domains/flatbuffers_domain_impl.h" #include @@ -22,29 +23,420 @@ #include "absl/random/bit_gen_ref.h" #include "absl/random/distributions.h" #include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "flatbuffers/base.h" #include "flatbuffers/flatbuffer_builder.h" #include "flatbuffers/reflection.h" #include "flatbuffers/reflection_generated.h" +#include "flatbuffers/table.h" +#include "./common/logging.h" #include "./fuzztest/domain_core.h" #include "./fuzztest/internal/any.h" #include "./fuzztest/internal/domains/domain_base.h" #include "./fuzztest/internal/domains/domain_type_erasure.h" +#include "./fuzztest/internal/meta.h" #include "./fuzztest/internal/serialization.h" namespace fuzztest::internal { +FlatbuffersUnionDomainImpl::FlatbuffersUnionDomainImpl( + const reflection::Schema* schema, const reflection::Enum* union_def) + : schema_(schema), union_def_(union_def), type_domain_(union_def) { + type_domain_.WithExcludedValues({0 /* NONE */}); +} + +FlatbuffersUnionDomainImpl::FlatbuffersUnionDomainImpl( + const FlatbuffersUnionDomainImpl& other) + : schema_(other.schema_), + union_def_(other.union_def_), + type_domain_(other.type_domain_) { + absl::MutexLock l(mutex_); + absl::MutexLock l_other(other.mutex_); + domains_ = other.domains_; +} + +FlatbuffersUnionDomainImpl::FlatbuffersUnionDomainImpl( + FlatbuffersUnionDomainImpl&& other) + : schema_(other.schema_), + union_def_(other.union_def_), + type_domain_(std::move(other.type_domain_)) { + absl::MutexLock l(mutex_); + absl::MutexLock l_other(other.mutex_); + domains_ = std::move(other.domains_); +} + +// Get a domain for a specific table type. +template <> +auto FlatbuffersUnionDomainImpl::GetDefaultDomainForType( + const reflection::EnumVal& enum_value) const { + const reflection::Object* object = + schema_->objects()->Get(enum_value.union_type()->index()); + return Domain( + FlatbuffersTableUntypedDomainImpl{schema_, object}); +} + +FlatbuffersUnionDomainImpl::corpus_type FlatbuffersUnionDomainImpl::Init( + absl::BitGenRef prng) { + if (auto seed = this->MaybeGetRandomSeed(prng)) { + return *seed; + } + + // Unions are encoded as the combination of two fields: an enum representing + // the union choice and the offset to the actual element. + // + // The following code follows that logic. + corpus_type val; + + val.type = type_domain_.Init(prng); + auto type_value = type_domain_.GetValue(val.type); + + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) { + return val; + } + + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + if (object->is_struct()) { + // TODO (b/405939014): Support structs. + } else { + auto inner_val = + GetCachedDomain(*type_enumval).Init(prng); + val.value = std::move(inner_val); + } + return val; +} + +// Mutates the corpus value. +void FlatbuffersUnionDomainImpl::Mutate( + corpus_type& corpus_value, absl::BitGenRef prng, + const domain_implementor::MutationMetadata& metadata, bool only_shrink) { + auto type_value = type_domain_.GetValue(corpus_value.type); + + // Mutate the type with probability 1%. + if (absl::Bernoulli(prng, 0.01)) { + // Mutate the type. + type_domain_.Mutate(corpus_value.type, prng, metadata, only_shrink); + type_value = type_domain_.GetValue(corpus_value.type); + + // If the union is set after type mutation, init the value corpus value. + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) return; + + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + if (object->is_struct()) { + // TODO (b/405939014): Support structs. + return; + } else { + corpus_value.value = + GetCachedDomain(*type_enumval).Init(prng); + } + return; + } + + // Mutate the value if the union is set. + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) return; + + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + if (object->is_struct()) { + // TODO (b/405939014): Support structs. + return; + } else { + GetCachedDomain(*type_enumval) + .Mutate(corpus_value.value, prng, metadata, only_shrink); + } +} + +uint64_t FlatbuffersUnionDomainImpl::CountNumberOfFields( + corpus_type& corpus_value) { + uint64_t field_count = 0; + + // If the union has only one type (besides NONE), the type is not counted + // as mutable field. + if (union_def_->values()->size() <= 2) { + return field_count; + } + + // The first field is the union type. + ++field_count; + + auto type_value = type_domain_.GetValue(corpus_value.type); + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) { + return field_count; + } + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + if (object->is_struct()) { + // TODO (b/405939014): Support structs. + } else { + auto domain = GetCachedDomain(*type_enumval); + field_count += domain.CountNumberOfFields(corpus_value.value); + } + return field_count; +} + +uint64_t FlatbuffersUnionDomainImpl::MutateSelectedField( + corpus_type& corpus_value, absl::BitGenRef prng, + const domain_implementor::MutationMetadata& metadata, bool only_shrink, + uint64_t selected_field_index) { + uint64_t field_count = 0; + + // If the union has only one type (besides NONE), the type is not counted + // as mutable field. + if (union_def_->values()->size() <= 2) { + return field_count; + } + + // The first field is the union type. + ++field_count; + if (selected_field_index == field_count) { + type_domain_.Mutate(corpus_value.type, prng, metadata, only_shrink); + auto type_value = type_domain_.GetValue(corpus_value.type); + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) return selected_field_index; + + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + if (object->is_struct()) { + // TODO (b/405939014): Support structs. + } else { + corpus_value.value = + GetCachedDomain(*type_enumval).Init(prng); + } + return field_count; + } + + auto type_value = type_domain_.GetValue(corpus_value.type); + + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) { + return 0; + } + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + if (object->is_struct()) { + // TODO (b/405939014): Support structs. + } else { + auto domain = GetCachedDomain(*type_enumval); + field_count += domain.MutateSelectedField( + corpus_value.value, prng, metadata, only_shrink, + selected_field_index - field_count); + } + return field_count; +} + +absl::Status FlatbuffersUnionDomainImpl::ValidateCorpusValue( + const corpus_type& corpus_value) const { + // Unions are encoded as the combination of two fields: an enum representing + // the union choice and the offset to the actual element. + // + // Both type and value should be validated. + // + // Start with the type validation. + auto type_value = type_domain_.GetValue(corpus_value.type); + + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid union type: ", type_value)); + } + + // Validate the value. + if (!corpus_value.value.has_value()) { + return absl::InvalidArgumentError("Union value is not set."); + } + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + if (object->is_struct()) { + // TODO (b/405939014): Support structs. + return absl::OkStatus(); + } else { + auto domain = GetCachedDomain(*type_enumval); + return domain.ValidateCorpusValue(corpus_value.value); + } +} + +// Converts the value to a corpus value. +std::optional +FlatbuffersUnionDomainImpl::FromValue(const value_type& value) const { + auto out = std::make_optional(); + auto type_corpus = type_domain_.FromValue(value.type); + if (type_corpus.has_value()) { + out->type = *type_corpus; + } + auto type_enumval = union_def_->values()->LookupByKey(value.type); + if (type_enumval == nullptr) { + return std::nullopt; + } + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + std::optional inner_corpus; + if (object->is_struct()) { + // TODO (b/405939014): Support structs. + } else { + auto domain = GetCachedDomain(*type_enumval); + inner_corpus = + domain.FromValue(static_cast(value.value)); + } + if (inner_corpus.has_value()) { + out->value = std::move(inner_corpus.value()); + } + return out; +} + +// Converts the IRObject to a corpus value. +std::optional +FlatbuffersUnionDomainImpl::ParseCorpus(const IRObject& obj) const { + // Follows the structure created by `SerializeCorpus` to deserialize the + // IRObject. + corpus_type out; + auto subs = obj.Subs(); + if (!subs) { + return std::nullopt; + } + + // We expect 2 fields: the type and the value. + if (subs->size() != 2) { + return std::nullopt; + } + + // Parse the type which is stored in the first field of the IRObject subs. + auto type_corpus = type_domain_.ParseCorpus((*subs)[0]); + if (!type_corpus.has_value()) { + return std::nullopt; + } + if (auto status = type_domain_.ValidateCorpusValue(*type_corpus); + !status.ok()) { + FUZZTEST_LOG(ERROR) << "Failed to validate type corpus: " + << status.message(); + return std::nullopt; + } + out.type = *type_corpus; + auto type_value = type_domain_.GetValue(out.type); + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) { + return std::nullopt; + } + + // Parse the value. + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + if (object == nullptr) { + return std::nullopt; + } + std::optional inner_corpus; + if (object->is_struct()) { + // TODO (b/405939014): Support structs. + } else { + auto domain = GetCachedDomain(*type_enumval); + // The value is stored in the second field of the IRObject subs. + inner_corpus = domain.ParseCorpus((*subs)[1]); + } + + if (inner_corpus.has_value()) { + out.value = std::move(inner_corpus.value()); + } + return out; +} + +// Converts the corpus value to an IRObject. +IRObject FlatbuffersUnionDomainImpl::SerializeCorpus( + const corpus_type& corpus_value) const { + IRObject out; + auto type_value = type_domain_.GetValue(corpus_value.type); + + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) { + return out; + } + + auto& pair = out.MutableSubs(); + // We have 2 fields: the type and the value. + pair.reserve(2); + + // Serialize the type. + pair.push_back(type_domain_.SerializeCorpus(corpus_value.type)); + + // Serialize the value. + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + if (object->is_struct()) { + // TODO (b/405939014): Support structs. + } else { + auto domain = GetCachedDomain(*type_enumval); + pair.push_back(domain.SerializeCorpus(corpus_value.value)); + } + return out; +} + +std::optional FlatbuffersUnionDomainImpl::BuildValue( + const corpus_type& corpus_value, + flatbuffers::FlatBufferBuilder64& builder) const { + // Get the object type. + auto type_value = type_domain_.GetValue(corpus_value.type); + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr || !corpus_value.value.has_value()) { + return std::nullopt; + } + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + if (object == nullptr) { + return std::nullopt; + } + if (object->is_struct()) { + // TODO (b/405939014): Support structs. + return std::nullopt; + } else { + FlatbuffersTableUntypedDomainImpl domain{schema_, object}; + return domain.BuildTable( + corpus_value.value + .GetAs>(), + builder); + } +} + +void FlatbuffersUnionDomainImpl::Printer::PrintCorpusValue( + const corpus_type& value, domain_implementor::RawSink out, + domain_implementor::PrintMode mode) const { + auto type_value = self.type_domain_.GetValue(value.type); + auto type_enumval = self.union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) { + return; + } + absl::Format(out, "<%s>(", type_enumval->name()->str()); + + const reflection::Object* object = + self.schema_->objects()->Get(type_enumval->union_type()->index()); + if (object->is_struct()) { + absl::Format(out, "UNSUPPORTED_UNION_TYPE"); + } else { + auto domain = self.GetCachedDomain(*type_enumval); + domain_implementor::PrintValue(domain, value.value, out, mode); + } + absl::Format(out, ")"); +} + FlatbuffersTableUntypedDomainImpl::FlatbuffersTableUntypedDomainImpl( const reflection::Schema* absl_nonnull schema, const reflection::Object* absl_nonnull table_object) - : schema_(schema), table_object_(table_object) {} + : schema_(schema), table_object_(table_object) { + for (const auto& field : *table_object_->fields()) { + fields_by_id_[field->id()] = field; + } +} FlatbuffersTableUntypedDomainImpl::FlatbuffersTableUntypedDomainImpl( const FlatbuffersTableUntypedDomainImpl& other) : DomainBase(other), schema_(other.schema_), - table_object_(other.table_object_) { + table_object_(other.table_object_), + fields_by_id_(other.fields_by_id_) { absl::MutexLock l_other(other.mutex_); absl::MutexLock l_this(mutex_); domains_ = other.domains_; @@ -55,6 +447,7 @@ FlatbuffersTableUntypedDomainImpl& FlatbuffersTableUntypedDomainImpl::operator=( DomainBase::operator=(other); schema_ = other.schema_; table_object_ = other.table_object_; + fields_by_id_ = other.fields_by_id_; absl::MutexLock l_other(other.mutex_); absl::MutexLock l_this(mutex_); domains_ = other.domains_; @@ -63,7 +456,9 @@ FlatbuffersTableUntypedDomainImpl& FlatbuffersTableUntypedDomainImpl::operator=( FlatbuffersTableUntypedDomainImpl::FlatbuffersTableUntypedDomainImpl( FlatbuffersTableUntypedDomainImpl&& other) - : schema_(other.schema_), table_object_(other.table_object_) { + : schema_(other.schema_), + table_object_(other.table_object_), + fields_by_id_(std::move(other.fields_by_id_)) { absl::MutexLock l_other(other.mutex_); absl::MutexLock l_this(mutex_); domains_ = std::move(other.domains_); @@ -74,6 +469,7 @@ FlatbuffersTableUntypedDomainImpl& FlatbuffersTableUntypedDomainImpl::operator=( FlatbuffersTableUntypedDomainImpl&& other) { schema_ = other.schema_; table_object_ = other.table_object_; + fields_by_id_ = std::move(other.fields_by_id_); absl::MutexLock l_other(other.mutex_); absl::MutexLock l_this(mutex_); domains_ = std::move(other.domains_); @@ -87,7 +483,7 @@ FlatbuffersTableUntypedDomainImpl::Init(absl::BitGenRef prng) { return *seed; } corpus_type val; - for (const auto* field : *table_object_->fields()) { + for (const auto& [_, field] : fields_by_id_) { VisitFlatbufferField(schema_, field, InitializeVisitor{*this, prng, val}); } return val; @@ -98,7 +494,7 @@ void FlatbuffersTableUntypedDomainImpl::Mutate( corpus_type& val, absl::BitGenRef prng, const domain_implementor::MutationMetadata& metadata, bool only_shrink) { uint64_t field_count = 0; - for (const auto* field : *table_object_->fields()) { + for (const auto& [_, field] : fields_by_id_) { VisitFlatbufferField(schema_, field, CountNumberOfMutableFieldsVisitor{*this, field_count, val, only_shrink}); @@ -112,7 +508,7 @@ void FlatbuffersTableUntypedDomainImpl::Mutate( uint64_t FlatbuffersTableUntypedDomainImpl::CountNumberOfFields( corpus_type& val) { uint64_t field_count = 0; - for (const auto* field : *table_object_->fields()) { + for (const auto& [_, field] : fields_by_id_) { VisitFlatbufferField( schema_, field, CountNumberOfMutableFieldsVisitor{*this, field_count, val}); @@ -130,29 +526,11 @@ uint64_t FlatbuffersTableUntypedDomainImpl::MutateSelectedField( return fields_count; } - for (const auto* field : *table_object_->fields()) { - if (!IsSupportedField(field)) { - if (only_shrink && !val.contains(field->id())) continue; - } - - ++field_counter; - if (field_counter == selected_field_index) { - VisitFlatbufferField( - schema_, field, - MutateVisitor{*this, prng, metadata, only_shrink, val}); - return field_counter; - } - - if (field->type()->base_type() == reflection::BaseType::Obj) { - auto sub_object = schema_->objects()->Get(field->type()->index()); - if (!sub_object->is_struct()) { - field_counter += - GetCachedDomain(field).MutateSelectedField( - val[field->id()], prng, metadata, only_shrink, - selected_field_index - field_counter); - } - // TODO: Add support for structs. - } + for (const auto& [_, field] : fields_by_id_) { + VisitFlatbufferField( + schema_, field, + MutateSelectedFieldVisitor{*this, field_counter, val, prng, metadata, + only_shrink, selected_field_index}); if (field_counter >= selected_field_index) { return field_counter; @@ -163,7 +541,7 @@ uint64_t FlatbuffersTableUntypedDomainImpl::MutateSelectedField( absl::Status FlatbuffersTableUntypedDomainImpl::ValidateCorpusValue( const corpus_type& corpus_value) const { - for (const auto* field : *table_object_->fields()) { + for (const auto& [_, field] : fields_by_id_) { absl::Status result; GenericDomainCorpusType field_corpus; if (auto it = corpus_value.find(field->id()); it != corpus_value.end()) { @@ -183,7 +561,7 @@ FlatbuffersTableUntypedDomainImpl::FromValue(const value_type& value) const { return std::nullopt; } corpus_type ret; - for (const auto* field : *table_object_->fields()) { + for (const auto& [_, field] : fields_by_id_) { VisitFlatbufferField(schema_, field, FromValueVisitor{*this, value, ret}); } return ret; @@ -265,6 +643,8 @@ IRObject FlatbuffersTableUntypedDomainImpl::SerializeCorpus( bool FlatbuffersTableUntypedDomainImpl::IsSupportedField( const reflection::Field* absl_nonnull field) const { auto base_type = field->type()->base_type(); + // Union types are supported via the FlatbuffersUnionDomainImpl, but not + // directly in the table domain. if (base_type == reflection::BaseType::UType) return false; if (flatbuffers::IsScalar(base_type)) return true; if (base_type == reflection::BaseType::String) return true; @@ -272,15 +652,27 @@ bool FlatbuffersTableUntypedDomainImpl::IsSupportedField( auto sub_object = schema_->objects()->Get(field->type()->index()); return !sub_object->is_struct(); }; + if (base_type == reflection::BaseType::Union) return true; + if (base_type == reflection::BaseType::Vector || + base_type == reflection::BaseType::Vector64) { + auto elem_type = field->type()->element(); + if (flatbuffers::IsScalar(elem_type)) return true; + if (elem_type == reflection::BaseType::String) return true; + if (elem_type == reflection::BaseType::Obj) { + auto sub_object = schema_->objects()->Get(field->type()->index()); + return !sub_object->is_struct(); + } + if (elem_type == reflection::BaseType::Union) return true; + } return false; } uint32_t FlatbuffersTableUntypedDomainImpl::BuildTable( - const corpus_type& value, flatbuffers::FlatBufferBuilder& builder) const { + const corpus_type& value, flatbuffers::FlatBufferBuilder64& builder) const { // Add all the fields to the builder. // Offsets is the map of field id to its offset in the table. - absl::flat_hash_map + absl::flat_hash_map offsets; // Some fields are stored inline in the flatbuffer table itself (a.k.a diff --git a/fuzztest/internal/domains/flatbuffers_domain_impl.h b/fuzztest/internal/domains/flatbuffers_domain_impl.h index 1482a598..a7fcc67a 100644 --- a/fuzztest/internal/domains/flatbuffers_domain_impl.h +++ b/fuzztest/internal/domains/flatbuffers_domain_impl.h @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -26,9 +27,9 @@ #include #include -#include "absl/algorithm/container.h" #include "absl/base/nullability.h" #include "absl/base/thread_annotations.h" +#include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/random/bit_gen_ref.h" @@ -37,21 +38,26 @@ #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "flatbuffers/base.h" +#include "flatbuffers/buffer.h" #include "flatbuffers/flatbuffer_builder.h" +#include "flatbuffers/reflection.h" #include "flatbuffers/reflection_generated.h" #include "flatbuffers/string.h" #include "flatbuffers/table.h" +#include "flatbuffers/vector.h" #include "flatbuffers/verifier.h" #include "./common/logging.h" #include "./fuzztest/domain_core.h" #include "./fuzztest/internal/any.h" #include "./fuzztest/internal/domains/arbitrary_impl.h" +#include "./fuzztest/internal/domains/container_of_impl.h" #include "./fuzztest/internal/domains/domain_base.h" #include "./fuzztest/internal/domains/domain_type_erasure.h" #include "./fuzztest/internal/domains/element_of_impl.h" #include "./fuzztest/internal/meta.h" #include "./fuzztest/internal/serialization.h" #include "./fuzztest/internal/status.h" +#include "./fuzztest/internal/type_support.h" namespace fuzztest::internal { @@ -76,73 +82,152 @@ template inline constexpr bool is_flatbuffers_enum_tag_v = is_flatbuffers_enum_tag::value; +// +// Flatbuffers vector detection. +// +template +struct FlatbuffersVectorTag { + using value_type = T; +}; + +template +struct is_flatbuffers_vector_tag : std::false_type {}; + +template +struct is_flatbuffers_vector_tag> : std::true_type {}; + +template +inline constexpr bool is_flatbuffers_vector_tag_v = + is_flatbuffers_vector_tag::value; + +template +struct FlatbuffersVector64Tag { + using value_type = T; +}; + +template +struct is_flatbuffers_vector64_tag : std::false_type {}; + +template +struct is_flatbuffers_vector64_tag> : std::true_type { +}; + +template +inline constexpr bool is_flatbuffers_vector64_tag_v = + is_flatbuffers_vector64_tag::value; + +template +inline constexpr bool is_any_flatbuffers_vector_tag_v = + is_flatbuffers_vector_tag_v || is_flatbuffers_vector64_tag_v; + +template +struct flatbuffers_vector_tag_offset; + +template +struct flatbuffers_vector_tag_offset> { + using type = flatbuffers::uoffset_t; +}; + +template +struct flatbuffers_vector_tag_offset> { + using type = flatbuffers::uoffset64_t; +}; + +template +using flatbuffers_vector_tag_offset_t = + typename flatbuffers_vector_tag_offset::type; + struct FlatbuffersArrayTag; + +// Flatbuffers container element type detection. +template +inline constexpr bool is_flatbuffers_container_of_v = []() constexpr { + if constexpr (is_flatbuffers_vector_tag_v || + is_flatbuffers_vector64_tag_v) { + return std::is_same_v; + } else { + return false; + } +}(); + struct FlatbuffersTableTag; struct FlatbuffersStructTag; struct FlatbuffersUnionTag; -struct FlatbuffersVectorTag; + +// Helper to wrap the visitor with the correct tag type. +template