diff --git a/domain_tests/BUILD b/domain_tests/BUILD index 0dfd652b..e2b7215d 100644 --- a/domain_tests/BUILD +++ b/domain_tests/BUILD @@ -46,11 +46,13 @@ cc_test( deps = [ ":domain_testing", "@abseil-cpp//absl/container:flat_hash_map", + "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/random", "@abseil-cpp//absl/status", "@com_google_fuzztest//fuzztest:domain", "@com_google_fuzztest//fuzztest:flatbuffers", "@com_google_fuzztest//fuzztest/internal:meta", + "@com_google_fuzztest//fuzztest/internal:serialization", "@com_google_fuzztest//fuzztest/internal:test_flatbuffers_cc_fbs", "@flatbuffers//:runtime_cc", "@googletest//:gtest_main", diff --git a/domain_tests/arbitrary_domains_flatbuffers_test.cc b/domain_tests/arbitrary_domains_flatbuffers_test.cc index 508ad269..e87ff736 100644 --- a/domain_tests/arbitrary_domains_flatbuffers_test.cc +++ b/domain_tests/arbitrary_domains_flatbuffers_test.cc @@ -19,11 +19,14 @@ #include #include #include +#include +#include #include #include "gmock/gmock.h" #include "gtest/gtest.h" #include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" #include "absl/random/random.h" #include "absl/status/status.h" #include "flatbuffers/base.h" @@ -32,21 +35,74 @@ #include "flatbuffers/reflection_generated.h" #include "flatbuffers/string.h" #include "flatbuffers/vector.h" +#include "flatbuffers/verifier.h" #include "./fuzztest/domain.h" #include "./domain_tests/domain_testing.h" #include "./fuzztest/flatbuffers.h" #include "./fuzztest/internal/meta.h" +#include "./fuzztest/internal/serialization.h" +#include "./fuzztest/internal/test_flatbuffers_64bits_generated.h" #include "./fuzztest/internal/test_flatbuffers_generated.h" namespace fuzztest { +namespace internal { +template +void AbslStringify(Sink& sink, const Union& e) { + absl::Format(&sink, "%s", EnumNameUnion(e)); +} +template +void AbslStringify(Sink& sink, const ByteEnum& e) { + absl::Format(&sink, "%s", EnumNameByteEnum(e)); +} +template +void AbslStringify(Sink& sink, const ShortEnum& e) { + absl::Format(&sink, "%s", EnumNameShortEnum(e)); +} +template +void AbslStringify(Sink& sink, const IntEnum& e) { + absl::Format(&sink, "%s", EnumNameIntEnum(e)); +} +template +void AbslStringify(Sink& sink, const LongEnum& e) { + absl::Format(&sink, "%s", EnumNameLongEnum(e)); +} +template +void AbslStringify(Sink& sink, const UByteEnum& e) { + absl::Format(&sink, "%s", EnumNameUByteEnum(e)); +} +template +void AbslStringify(Sink& sink, const UShortEnum& e) { + absl::Format(&sink, "%s", EnumNameUShortEnum(e)); +} +template +void AbslStringify(Sink& sink, const UIntEnum& e) { + absl::Format(&sink, "%s", EnumNameUIntEnum(e)); +} +template +void AbslStringify(Sink& sink, const ULongEnum& e) { + absl::Format(&sink, "%s", EnumNameULongEnum(e)); +} +} // namespace internal namespace { +using ::fuzztest::internal::BoolStruct; using ::fuzztest::internal::BoolTable; +using ::fuzztest::internal::ByteEnum; +using ::fuzztest::internal::DefaultStruct; using ::fuzztest::internal::DefaultTable; +using ::fuzztest::internal::DefaultTable64; +using ::fuzztest::internal::IntEnum; +using ::fuzztest::internal::LongEnum; using ::fuzztest::internal::OptionalTable; using ::fuzztest::internal::RecursiveTable; using ::fuzztest::internal::RequiredTable; -using ::fuzztest::internal::UnsupportedTypesTable; +using ::fuzztest::internal::ShortEnum; +using ::fuzztest::internal::StringTable; +using ::fuzztest::internal::UByteEnum; +using ::fuzztest::internal::UIntEnum; +using ::fuzztest::internal::ULongEnum; +using ::fuzztest::internal::UnionTable; +using ::fuzztest::internal::UShortEnum; using ::testing::_; using ::testing::AllOf; using ::testing::Each; @@ -70,6 +126,40 @@ inline bool Eq(const T* lhs, const T* rhs) { return Eq(*lhs, *rhs); } +template <> +inline bool Eq(const BoolStruct& lhs, const BoolStruct& rhs) { + return Eq(lhs.b(), rhs.b()); +} + +template <> +inline bool Eq(const DefaultStruct& lhs, + const DefaultStruct& rhs) { + const bool eq_b = lhs.b() == rhs.b(); + const bool eq_i8 = lhs.i8() == rhs.i8(); + const bool eq_i16 = lhs.i16() == rhs.i16(); + const bool eq_i32 = lhs.i32() == rhs.i32(); + const bool eq_i64 = lhs.i64() == rhs.i64(); + const bool eq_u8 = lhs.u8() == rhs.u8(); + const bool eq_u16 = lhs.u16() == rhs.u16(); + const bool eq_u32 = lhs.u32() == rhs.u32(); + const bool eq_u64 = lhs.u64() == rhs.u64(); + const bool eq_f = lhs.f() == rhs.f(); + const bool eq_d = lhs.d() == rhs.d(); + const bool eq_ei8 = lhs.ei8() == rhs.ei8(); + const bool eq_ei16 = lhs.ei16() == rhs.ei16(); + const bool eq_ei32 = lhs.ei32() == rhs.ei32(); + const bool eq_ei64 = lhs.ei64() == rhs.ei64(); + const bool eq_eu8 = lhs.eu8() == rhs.eu8(); + const bool eq_eu16 = lhs.eu16() == rhs.eu16(); + const bool eq_eu32 = lhs.eu32() == rhs.eu32(); + const bool eq_eu64 = lhs.eu64() == rhs.eu64(); + const bool eq_s = Eq(lhs.s(), rhs.s()); + + return eq_b && eq_i8 && eq_i16 && eq_i32 && eq_i64 && eq_u8 && eq_u16 && + eq_u32 && eq_u64 && eq_f && eq_d && eq_ei8 && eq_ei16 && eq_ei32 && + eq_ei64 && eq_eu8 && eq_eu16 && eq_eu32 && eq_eu64 && eq_s; +} + template <> inline bool Eq(const flatbuffers::String& lhs, const flatbuffers::String& rhs) { @@ -82,6 +172,78 @@ inline bool Eq(const BoolTable& lhs, const BoolTable& rhs) { return lhs.b() == rhs.b(); } +template +inline bool Eq(const flatbuffers::Vector& lhs, + const flatbuffers::Vector& rhs) { + if (lhs.size() != rhs.size()) return false; + for (int i = 0; i < lhs.size(); ++i) { + if (!Eq(lhs.Get(i), rhs.Get(i))) return false; + } + return true; +} + +template <> +inline bool Eq(const StringTable& lhs, const StringTable& rhs) { + return Eq(lhs.str(), rhs.str()); +} + +template <> +inline bool Eq>( + const std::pair& lhs, + const std::pair& rhs) { + if (lhs.first == internal::Union_NONE && rhs.first == internal::Union_NONE) { + return true; + } + if (lhs.first != rhs.first) return false; + + switch (lhs.first) { + case internal::Union_BoolTable: + return Eq(static_cast(lhs.second), + static_cast(rhs.second)); + case internal::Union_StringTable: + return Eq(static_cast(lhs.second), + static_cast(rhs.second)); + case internal::Union_BoolStruct: + return Eq(static_cast(rhs.second), + static_cast(lhs.second)); + default: + CHECK(false) << "Unsupported union type"; + } +} + +template +inline bool Eq(const flatbuffers::Vector* lhs, + const flatbuffers::Vector* rhs) { + if (lhs == nullptr && rhs == nullptr) return true; + if (lhs == nullptr || rhs == nullptr) return false; + return Eq(*lhs, *rhs); +} + +template <> +inline bool +Eq*, + const flatbuffers::Vector<::flatbuffers::Offset>*>>( + const std::pair*, + const flatbuffers::Vector<::flatbuffers::Offset>*>& + lhs, + const std::pair*, + const flatbuffers::Vector<::flatbuffers::Offset>*>& + rhs) { + if (!Eq(lhs.first, rhs.first)) return false; + if (lhs.second == nullptr && rhs.second == nullptr) return true; + if (lhs.second == nullptr || rhs.second == nullptr) return false; + if (lhs.first->size() != lhs.second->size()) return false; + if (lhs.second->size() != rhs.second->size()) return false; + + for (int i = 0; i < lhs.second->size(); ++i) { + if (!Eq(std::pair(lhs.first->Get(i), lhs.second->Get(i)), + std::pair(rhs.first->Get(i), rhs.second->Get(i)))) { + return false; + } + } + return true; +} + template <> inline bool Eq(const DefaultTable& lhs, const DefaultTable& rhs) { const bool eq_b = lhs.b() == rhs.b(); @@ -105,14 +267,111 @@ inline bool Eq(const DefaultTable& lhs, const DefaultTable& rhs) { const bool eq_eu32 = lhs.eu32() == rhs.eu32(); const bool eq_eu64 = lhs.eu64() == rhs.eu64(); const bool eq_t = Eq(lhs.t(), rhs.t()); + const bool eq_u = Eq(std::pair(static_cast(lhs.u_type()), lhs.u()), + std::pair(static_cast(rhs.u_type()), rhs.u())); + const bool eq_s = Eq(lhs.s(), rhs.s()); + const bool eq_v_b = Eq(lhs.v_b(), rhs.v_b()); + const bool eq_v_i8 = Eq(lhs.v_i8(), rhs.v_i8()); + const bool eq_v_i16 = Eq(lhs.v_i16(), rhs.v_i16()); + const bool eq_v_i32 = Eq(lhs.v_i32(), rhs.v_i32()); + const bool eq_v_i64 = Eq(lhs.v_i64(), rhs.v_i64()); + const bool eq_v_u8 = Eq(lhs.v_u8(), rhs.v_u8()); + const bool eq_v_u16 = Eq(lhs.v_u16(), rhs.v_u16()); + const bool eq_v_u32 = Eq(lhs.v_u32(), rhs.v_u32()); + const bool eq_v_u64 = Eq(lhs.v_u64(), rhs.v_u64()); + const bool eq_v_f = Eq(lhs.v_f(), rhs.v_f()); + const bool eq_v_d = Eq(lhs.v_d(), rhs.v_d()); + const bool eq_v_str = Eq(lhs.v_str(), rhs.v_str()); + const bool eq_v_ei8 = Eq(lhs.v_ei8(), rhs.v_ei8()); + const bool eq_v_ei16 = Eq(lhs.v_ei16(), rhs.v_ei16()); + const bool eq_v_ei32 = Eq(lhs.v_ei32(), rhs.v_ei32()); + const bool eq_v_ei64 = Eq(lhs.v_ei64(), rhs.v_ei64()); + const bool eq_v_eu8 = Eq(lhs.v_eu8(), rhs.v_eu8()); + const bool eq_v_eu16 = Eq(lhs.v_eu16(), rhs.v_eu16()); + const bool eq_v_eu32 = Eq(lhs.v_eu32(), rhs.v_eu32()); + const bool eq_v_eu64 = Eq(lhs.v_eu64(), rhs.v_eu64()); + const bool eq_v_t = Eq(lhs.v_t(), rhs.v_t()); + const bool eq_v_u_type = Eq(lhs.v_u_type(), rhs.v_u_type()); + const bool eq_v_u = Eq(std::make_pair(lhs.v_u_type(), lhs.v_u()), + std::make_pair(rhs.v_u_type(), rhs.v_u())); + const bool eq_v_s = Eq(lhs.v_s(), rhs.v_s()); return eq_b && eq_i8 && eq_i16 && eq_i32 && eq_i64 && eq_u8 && eq_u16 && eq_u32 && eq_u64 && eq_f && eq_d && eq_str && eq_ei8 && eq_ei16 && - eq_ei32 && eq_ei64 && eq_eu8 && eq_eu16 && eq_eu32 && eq_eu64 && eq_t; + eq_ei32 && eq_ei64 && eq_eu8 && eq_eu16 && eq_eu32 && eq_eu64 && + eq_t && eq_u && eq_s && eq_v_b && eq_v_i8 && eq_v_i16 && eq_v_i32 && + eq_v_i64 && eq_v_u8 && eq_v_u16 && eq_v_u32 && eq_v_u64 && eq_v_f && + eq_v_d && eq_v_str && eq_v_ei8 && eq_v_ei16 && eq_v_ei32 && + eq_v_ei64 && eq_v_eu8 && eq_v_eu16 && eq_v_eu32 && eq_v_eu64 && + eq_v_t && eq_v_u_type && eq_v_u && eq_v_s; } const internal::DefaultTable* CreateDefaultTable( flatbuffers::FlatBufferBuilder& fbb) { auto bool_table_offset = internal::CreateBoolTable(fbb, true); + auto string_table_offset = + internal::CreateStringTableDirect(fbb, "foo bar baz"); + DefaultStruct s{ + true, // b + 1, // i8 + 2, // i16 + 3, // i32 + 4, // i64 + 5, // u8 + 6, // u16 + 7, // u32 + 8, // u64 + 9, // f + 10.0, // d + internal::ByteEnum_First, // ei8 + internal::ShortEnum_First, // ei16 + internal::IntEnum_First, // ei32 + internal::LongEnum_First, // ei64 + internal::UByteEnum_First, // eu8 + internal::UShortEnum_First, // eu16 + internal::UIntEnum_First, // eu32 + internal::ULongEnum_First, // eu64 + BoolStruct{true} // s + }; + std::vector v_b{true, false}; + std::vector v_i8{1, 2, 3}; + std::vector v_i16{1, 2, 3}; + std::vector v_i32{1, 2, 3}; + std::vector v_i64{1, 2, 3}; + std::vector v_u8{1, 2, 3}; + std::vector v_u16{1, 2, 3}; + std::vector v_u32{1, 2, 3}; + std::vector v_u64{1, 2, 3}; + std::vector v_f{1, 2, 3}; + std::vector v_d{1, 2, 3}; + std::vector> v_str{ + fbb.CreateString("foo"), fbb.CreateString("bar"), + fbb.CreateString("baz")}; + std::vector> v_ei8{ + internal::ByteEnum_First, internal::ByteEnum_Second}; + std::vector> v_ei16{ + internal::ShortEnum_First, internal::ShortEnum_Second}; + std::vector> v_ei32{internal::IntEnum_First, + internal::IntEnum_Second}; + std::vector> v_ei64{ + internal::LongEnum_First, internal::LongEnum_Second}; + std::vector> v_eu8{ + internal::UByteEnum_First, internal::UByteEnum_Second}; + std::vector> v_eu16{ + internal::UShortEnum_First, internal::UShortEnum_Second}; + std::vector> v_eu32{ + internal::UIntEnum_First, internal::UIntEnum_Second}; + std::vector> v_eu64{ + internal::ULongEnum_First, internal::ULongEnum_Second}; + std::vector> v_t{bool_table_offset}; + std::vector> v_u_type{ + internal::Union_BoolTable, + internal::Union_StringTable, + }; + std::vector> v_u{ + bool_table_offset.Union(), + string_table_offset.Union(), + }; + std::vector v_s{s}; auto table_offset = internal::CreateDefaultTableDirect(fbb, /*b=*/true, @@ -135,7 +394,34 @@ const internal::DefaultTable* CreateDefaultTable( /*eu16=*/internal::UShortEnum_Second, /*eu32=*/internal::UIntEnum_Second, /*eu64=*/internal::ULongEnum_Second, - /*t=*/bool_table_offset); + /*t=*/bool_table_offset, + /*u_type=*/internal::Union_BoolTable, + /*u=*/bool_table_offset.Union(), + /*s=*/&s, + /*v_b=*/&v_b, + /*v_i8=*/&v_i8, + /*v_i16=*/&v_i16, + /*v_i32=*/&v_i32, + /*v_i64=*/&v_i64, + /*v_u8=*/&v_u8, + /*v_u16=*/&v_u16, + /*v_u32=*/&v_u32, + /*v_u64=*/&v_u64, + /*v_f=*/&v_f, + /*v_d=*/&v_d, + /*v_str=*/&v_str, + /*v_ei8=*/&v_ei8, + /*v_ei16=*/&v_ei16, + /*v_ei32=*/&v_ei32, + /*v_ei64=*/&v_ei64, + /*v_eu8=*/&v_eu8, + /*v_eu16=*/&v_eu16, + /*v_eu32=*/&v_eu32, + /*v_eu64=*/&v_eu64, + /*v_t=*/&v_t, + /*v_u_type=*/&v_u_type, + /*v_u=*/&v_u, + /*v_s=*/&v_s); fbb.Finish(table_offset); return flatbuffers::GetRoot(fbb.GetBufferPointer()); } @@ -261,6 +547,7 @@ TEST(FlatbuffersTableDomainImplTest, DefaultTableValueRoundTrip) { EXPECT_EQ(new_table->u64(), 8); EXPECT_EQ(new_table->f(), 9.0); EXPECT_EQ(new_table->d(), 10.0); + ASSERT_THAT(new_table->str(), NotNull()); EXPECT_EQ(new_table->str()->str(), "foo bar baz"); EXPECT_EQ(new_table->ei8(), internal::ByteEnum_Second); EXPECT_EQ(new_table->ei16(), internal::ShortEnum_Second); @@ -272,6 +559,159 @@ TEST(FlatbuffersTableDomainImplTest, DefaultTableValueRoundTrip) { EXPECT_EQ(new_table->eu64(), internal::ULongEnum_Second); ASSERT_THAT(new_table->t(), NotNull()); EXPECT_EQ(new_table->t()->b(), true); + EXPECT_EQ(new_table->u_type(), internal::Union_BoolTable); + ASSERT_THAT(new_table->u(), NotNull()); + EXPECT_EQ(new_table->u_as_BoolTable()->b(), true); + ASSERT_THAT(new_table->s(), NotNull()); + EXPECT_EQ(new_table->s()->b(), true); + EXPECT_EQ(new_table->s()->i8(), 1); + EXPECT_EQ(new_table->s()->i16(), 2); + EXPECT_EQ(new_table->s()->i32(), 3); + EXPECT_EQ(new_table->s()->i64(), 4); + EXPECT_EQ(new_table->s()->u8(), 5); + EXPECT_EQ(new_table->s()->u16(), 6); + EXPECT_EQ(new_table->s()->u32(), 7); + EXPECT_EQ(new_table->s()->u64(), 8); + EXPECT_EQ(new_table->s()->f(), 9.0); + EXPECT_EQ(new_table->s()->d(), 10.0); + EXPECT_EQ(new_table->s()->ei8(), internal::ByteEnum_First); + EXPECT_EQ(new_table->s()->ei16(), internal::ShortEnum_First); + EXPECT_EQ(new_table->s()->ei32(), internal::IntEnum_First); + EXPECT_EQ(new_table->s()->ei64(), internal::LongEnum_First); + EXPECT_EQ(new_table->s()->eu8(), internal::UByteEnum_First); + EXPECT_EQ(new_table->s()->eu16(), internal::UShortEnum_First); + EXPECT_EQ(new_table->s()->eu32(), internal::UIntEnum_First); + EXPECT_EQ(new_table->s()->eu64(), internal::ULongEnum_First); + EXPECT_EQ(new_table->s()->s().b(), true); + ASSERT_THAT(new_table->v_b(), NotNull()); + EXPECT_EQ(new_table->v_b()->size(), 2); + EXPECT_EQ(new_table->v_b()->Get(0), true); + EXPECT_EQ(new_table->v_b()->Get(1), false); + ASSERT_THAT(new_table->v_i8(), NotNull()); + EXPECT_EQ(new_table->v_i8()->size(), 3); + EXPECT_EQ(new_table->v_i8()->Get(0), 1); + EXPECT_EQ(new_table->v_i8()->Get(1), 2); + EXPECT_EQ(new_table->v_i8()->Get(2), 3); + ASSERT_THAT(new_table->v_i16(), NotNull()); + EXPECT_EQ(new_table->v_i16()->size(), 3); + EXPECT_EQ(new_table->v_i16()->Get(0), 1); + EXPECT_EQ(new_table->v_i16()->Get(1), 2); + EXPECT_EQ(new_table->v_i16()->Get(2), 3); + ASSERT_THAT(new_table->v_i32(), NotNull()); + EXPECT_EQ(new_table->v_i32()->size(), 3); + EXPECT_EQ(new_table->v_i32()->Get(0), 1); + EXPECT_EQ(new_table->v_i32()->Get(1), 2); + EXPECT_EQ(new_table->v_i32()->Get(2), 3); + ASSERT_THAT(new_table->v_i64(), NotNull()); + EXPECT_EQ(new_table->v_i64()->size(), 3); + EXPECT_EQ(new_table->v_i64()->Get(0), 1); + EXPECT_EQ(new_table->v_i64()->Get(1), 2); + EXPECT_EQ(new_table->v_i64()->Get(2), 3); + ASSERT_THAT(new_table->v_u8(), NotNull()); + EXPECT_EQ(new_table->v_u8()->size(), 3); + EXPECT_EQ(new_table->v_u8()->Get(0), 1); + EXPECT_EQ(new_table->v_u8()->Get(1), 2); + EXPECT_EQ(new_table->v_u8()->Get(2), 3); + ASSERT_THAT(new_table->v_u16(), NotNull()); + EXPECT_EQ(new_table->v_u16()->size(), 3); + EXPECT_EQ(new_table->v_u16()->Get(0), 1); + EXPECT_EQ(new_table->v_u16()->Get(1), 2); + EXPECT_EQ(new_table->v_u16()->Get(2), 3); + ASSERT_THAT(new_table->v_u32(), NotNull()); + EXPECT_EQ(new_table->v_u32()->size(), 3); + EXPECT_EQ(new_table->v_u32()->Get(0), 1); + EXPECT_EQ(new_table->v_u32()->Get(1), 2); + EXPECT_EQ(new_table->v_u32()->Get(2), 3); + ASSERT_THAT(new_table->v_u64(), NotNull()); + EXPECT_EQ(new_table->v_u64()->size(), 3); + EXPECT_EQ(new_table->v_u64()->Get(0), 1); + EXPECT_EQ(new_table->v_u64()->Get(1), 2); + EXPECT_EQ(new_table->v_u64()->Get(2), 3); + ASSERT_THAT(new_table->v_f(), NotNull()); + EXPECT_EQ(new_table->v_f()->size(), 3); + EXPECT_EQ(new_table->v_f()->Get(0), 1); + EXPECT_EQ(new_table->v_f()->Get(1), 2); + EXPECT_EQ(new_table->v_f()->Get(2), 3); + ASSERT_THAT(new_table->v_d(), NotNull()); + EXPECT_EQ(new_table->v_d()->size(), 3); + EXPECT_EQ(new_table->v_d()->Get(0), 1); + EXPECT_EQ(new_table->v_d()->Get(1), 2); + EXPECT_EQ(new_table->v_d()->Get(2), 3); + EXPECT_EQ(new_table->v_str()->size(), 3); + EXPECT_EQ(new_table->v_str()->Get(0)->str(), "foo"); + EXPECT_EQ(new_table->v_str()->Get(1)->str(), "bar"); + EXPECT_EQ(new_table->v_str()->Get(2)->str(), "baz"); + ASSERT_THAT(new_table->v_ei8(), NotNull()); + EXPECT_EQ(new_table->v_ei8()->size(), 2); + EXPECT_EQ(new_table->v_ei8()->Get(0), internal::ByteEnum_First); + EXPECT_EQ(new_table->v_ei8()->Get(1), internal::ByteEnum_Second); + ASSERT_THAT(new_table->v_ei16(), NotNull()); + EXPECT_EQ(new_table->v_ei16()->size(), 2); + EXPECT_EQ(new_table->v_ei16()->Get(0), internal::ShortEnum_First); + EXPECT_EQ(new_table->v_ei16()->Get(1), internal::ShortEnum_Second); + ASSERT_THAT(new_table->v_ei32(), NotNull()); + EXPECT_EQ(new_table->v_ei32()->size(), 2); + EXPECT_EQ(new_table->v_ei32()->Get(0), internal::IntEnum_First); + EXPECT_EQ(new_table->v_ei32()->Get(1), internal::IntEnum_Second); + ASSERT_THAT(new_table->v_ei64(), NotNull()); + EXPECT_EQ(new_table->v_ei64()->size(), 2); + EXPECT_EQ(new_table->v_ei64()->Get(0), internal::LongEnum_First); + EXPECT_EQ(new_table->v_ei64()->Get(1), internal::LongEnum_Second); + ASSERT_THAT(new_table->v_eu8(), NotNull()); + EXPECT_EQ(new_table->v_eu8()->size(), 2); + EXPECT_EQ(new_table->v_eu8()->Get(0), internal::UByteEnum_First); + EXPECT_EQ(new_table->v_eu8()->Get(1), internal::UByteEnum_Second); + ASSERT_THAT(new_table->v_eu16(), NotNull()); + EXPECT_EQ(new_table->v_eu16()->size(), 2); + EXPECT_EQ(new_table->v_eu16()->Get(0), internal::UShortEnum_First); + EXPECT_EQ(new_table->v_eu16()->Get(1), internal::UShortEnum_Second); + ASSERT_THAT(new_table->v_eu32(), NotNull()); + EXPECT_EQ(new_table->v_eu32()->size(), 2); + EXPECT_EQ(new_table->v_eu32()->Get(0), internal::UIntEnum_First); + EXPECT_EQ(new_table->v_eu32()->Get(1), internal::UIntEnum_Second); + ASSERT_THAT(new_table->v_t(), NotNull()); + EXPECT_EQ(new_table->v_t()->size(), 1); + ASSERT_THAT(new_table->v_t()->Get(0), NotNull()); + EXPECT_EQ(new_table->v_t()->Get(0)->b(), true); + ASSERT_THAT(new_table->v_u_type(), NotNull()); + ASSERT_EQ(new_table->v_u_type()->size(), 2); + EXPECT_EQ(new_table->v_u_type()->Get(0), internal::Union_BoolTable); + EXPECT_EQ(new_table->v_u_type()->Get(1), internal::Union_StringTable); + ASSERT_THAT(new_table->v_u(), NotNull()); + EXPECT_EQ(new_table->v_u()->size(), 2); + auto v_u_0 = + static_cast(new_table->v_u()->Get(0)); + ASSERT_THAT(v_u_0, NotNull()); + EXPECT_EQ(v_u_0->b(), true); + auto v_u_1 = + static_cast(new_table->v_u()->Get(1)); + ASSERT_THAT(v_u_1, NotNull()); + ASSERT_THAT(v_u_1->str(), NotNull()); + EXPECT_EQ(v_u_1->str()->str(), "foo bar baz"); + ASSERT_THAT(new_table->v_s(), NotNull()); + ASSERT_EQ(new_table->v_s()->size(), 1); + auto v_s_0 = new_table->v_s()->Get(0); + ASSERT_THAT(v_s_0, NotNull()); + EXPECT_EQ(v_s_0->b(), true); + EXPECT_EQ(v_s_0->i8(), 1); + EXPECT_EQ(v_s_0->i16(), 2); + EXPECT_EQ(v_s_0->i32(), 3); + EXPECT_EQ(v_s_0->i64(), 4); + EXPECT_EQ(v_s_0->u8(), 5); + EXPECT_EQ(v_s_0->u16(), 6); + EXPECT_EQ(v_s_0->u32(), 7); + EXPECT_EQ(v_s_0->u64(), 8); + EXPECT_EQ(v_s_0->f(), 9.0); + EXPECT_EQ(v_s_0->d(), 10.0); + EXPECT_EQ(v_s_0->ei8(), internal::ByteEnum_First); + EXPECT_EQ(v_s_0->ei16(), internal::ShortEnum_First); + EXPECT_EQ(v_s_0->ei32(), internal::IntEnum_First); + EXPECT_EQ(v_s_0->ei64(), internal::LongEnum_First); + EXPECT_EQ(v_s_0->eu8(), internal::UByteEnum_First); + EXPECT_EQ(v_s_0->eu16(), internal::UShortEnum_First); + EXPECT_EQ(v_s_0->eu32(), internal::UIntEnum_First); + EXPECT_EQ(v_s_0->eu64(), internal::ULongEnum_First); + EXPECT_EQ(v_s_0->s().b(), true); } TEST(FlatbuffersTableDomainImplTest, InitGeneratesSeeds) { @@ -291,12 +731,22 @@ TEST(FlatbuffersTableDomainImplTest, InitGeneratesSeeds) { TEST(FlatbuffersTableDomainImplTest, CanMutateAnyTableField) { absl::flat_hash_map mutated_fields{ - {"b", false}, {"i8", false}, {"i16", false}, {"i32", false}, - {"i64", false}, {"u8", false}, {"u16", false}, {"u32", false}, - {"u64", false}, {"f", false}, {"d", false}, {"str", false}, - {"ei8", false}, {"ei16", false}, {"ei32", false}, {"ei64", false}, - {"eu8", false}, {"eu16", false}, {"eu32", false}, {"eu64", false}, - {"t", false}, + {"b", false}, {"i8", false}, {"i16", false}, + {"i32", false}, {"i64", false}, {"u8", false}, + {"u16", false}, {"u32", false}, {"u64", false}, + {"f", false}, {"d", false}, {"str", false}, + {"ei8", false}, {"ei16", false}, {"ei32", false}, + {"ei64", false}, {"eu8", false}, {"eu16", false}, + {"eu32", false}, {"eu64", false}, {"t", false}, + {"u_type", false}, {"u", false}, {"s", false}, + {"v_b", false}, {"v_i8", false}, {"v_i16", false}, + {"v_i32", false}, {"v_i64", false}, {"v_u8", false}, + {"v_u16", false}, {"v_u32", false}, {"v_u64", false}, + {"v_f", false}, {"v_d", false}, {"v_str", false}, + {"v_ei8", false}, {"v_ei16", false}, {"v_ei32", false}, + {"v_ei64", false}, {"v_eu8", false}, {"v_eu16", false}, + {"v_eu32", false}, {"v_eu64", false}, {"v_t", false}, + {"v_u_type", false}, {"v_u", false}, {"v_s", false}, }; auto domain = Arbitrary(); @@ -332,6 +782,36 @@ TEST(FlatbuffersTableDomainImplTest, CanMutateAnyTableField) { mutated_fields["eu32"] |= mut->eu32() != init->eu32(); mutated_fields["eu64"] |= mut->eu64() != init->eu64(); mutated_fields["t"] |= !Eq(mut->t(), init->t()); + mutated_fields["u_type"] |= !Eq(mut->u_type(), init->u_type()); + mutated_fields["u"] |= + !Eq(std::pair(static_cast(mut->u_type()), mut->u()), + std::pair(static_cast(init->u_type()), init->u())); + mutated_fields["s"] |= !Eq(mut->s(), init->s()); + mutated_fields["v_b"] |= !Eq(mut->v_b(), init->v_b()); + mutated_fields["v_i8"] |= !Eq(mut->v_i8(), init->v_i8()); + mutated_fields["v_i16"] |= !Eq(mut->v_i16(), init->v_i16()); + mutated_fields["v_i32"] |= !Eq(mut->v_i32(), init->v_i32()); + mutated_fields["v_i64"] |= !Eq(mut->v_i64(), init->v_i64()); + mutated_fields["v_u8"] |= !Eq(mut->v_u8(), init->v_u8()); + mutated_fields["v_u16"] |= !Eq(mut->v_u16(), init->v_u16()); + mutated_fields["v_u32"] |= !Eq(mut->v_u32(), init->v_u32()); + mutated_fields["v_u64"] |= !Eq(mut->v_u64(), init->v_u64()); + mutated_fields["v_f"] |= !Eq(mut->v_f(), init->v_f()); + mutated_fields["v_d"] |= !Eq(mut->v_d(), init->v_d()); + mutated_fields["v_str"] |= !Eq(mut->v_str(), init->v_str()); + mutated_fields["v_ei8"] |= !Eq(mut->v_ei8(), init->v_ei8()); + mutated_fields["v_ei16"] |= !Eq(mut->v_ei16(), init->v_ei16()); + mutated_fields["v_ei32"] |= !Eq(mut->v_ei32(), init->v_ei32()); + mutated_fields["v_ei64"] |= !Eq(mut->v_ei64(), init->v_ei64()); + mutated_fields["v_eu8"] |= !Eq(mut->v_eu8(), init->v_eu8()); + mutated_fields["v_eu16"] |= !Eq(mut->v_eu16(), init->v_eu16()); + mutated_fields["v_eu32"] |= !Eq(mut->v_eu32(), init->v_eu32()); + mutated_fields["v_eu64"] |= !Eq(mut->v_eu64(), init->v_eu64()); + mutated_fields["v_t"] |= !Eq(mut->v_str(), init->v_str()); + mutated_fields["v_u_type"] |= !Eq(mut->v_u_type(), init->v_u_type()); + mutated_fields["v_u"] |= !Eq(std::make_pair(mut->v_u_type(), mut->v_u()), + std::make_pair(init->v_u_type(), init->v_u())); + mutated_fields["v_s"] |= !Eq(mut->v_s(), init->v_s()); if (std::all_of(mutated_fields.begin(), mutated_fields.end(), [](const auto& p) { return p.second; })) { @@ -345,6 +825,32 @@ TEST(FlatbuffersTableDomainImplTest, CanMutateAnyTableField) { TEST(FlatbuffersTableDomainImplTest, OptionalTableEventuallyBecomeEmpty) { flatbuffers::FlatBufferBuilder fbb; auto bool_table_offset = internal::CreateBoolTable(fbb, true); + DefaultStruct s; + std::vector v_b{true, false}; + std::vector v_i8{}; + std::vector v_i16{}; + std::vector v_i32{}; + std::vector v_i64{}; + std::vector v_u8{}; + std::vector v_u16{}; + std::vector v_u32{}; + std::vector v_u64{}; + std::vector v_f{}; + std::vector v_d{}; + std::vector> v_str{ + fbb.CreateString(""), fbb.CreateString(""), fbb.CreateString("")}; + std::vector> v_ei8{}; + std::vector> v_ei16{}; + std::vector> v_ei32{}; + std::vector> v_ei64{}; + std::vector> v_eu8{}; + std::vector> v_eu16{}; + std::vector> v_eu32{}; + std::vector> v_eu64{}; + std::vector> v_t{}; + std::vector> v_u_type{}; + std::vector> v_u{}; + std::vector v_s{}; auto table_offset = internal::CreateOptionalTableDirect(fbb, true, // b @@ -367,7 +873,34 @@ TEST(FlatbuffersTableDomainImplTest, OptionalTableEventuallyBecomeEmpty) { internal::UShortEnum_Second, // eu16 internal::UIntEnum_Second, // eu32 internal::ULongEnum_Second, // eu64 - bool_table_offset // t + bool_table_offset, // t + internal::Union_BoolTable, // u_type + bool_table_offset.Union(), // u + &s, // s + &v_b, // v_b + &v_i8, // v_i8 + &v_i16, // v_i16 + &v_i32, // v_i32 + &v_i64, // v_i64 + &v_u8, // v_u8 + &v_u16, // v_u16 + &v_u32, // v_u32 + &v_u64, // v_u64 + &v_f, // v_f + &v_d, // v_d + &v_str, // v_str + &v_ei8, // v_ei8 + &v_ei16, // v_ei16 + &v_ei32, // v_ei32 + &v_ei64, // v_ei64 + &v_eu8, // v_eu8 + &v_eu16, // v_eu16 + &v_eu32, // v_eu32 + &v_eu64, // v_eu64 + &v_t, // v_t + &v_u_type, // v_u_type + &v_u, // v_u + &v_s // v_s ); fbb.Finish(table_offset); auto table = flatbuffers::GetRoot(fbb.GetBufferPointer()); @@ -377,12 +910,22 @@ TEST(FlatbuffersTableDomainImplTest, OptionalTableEventuallyBecomeEmpty) { absl::BitGen bitgen; absl::flat_hash_map null_fields{ - {"b", false}, {"i8", false}, {"i16", false}, {"i32", false}, - {"i64", false}, {"u8", false}, {"u16", false}, {"u32", false}, - {"u64", false}, {"f", false}, {"d", false}, {"str", false}, - {"ei8", false}, {"ei16", false}, {"ei32", false}, {"ei64", false}, - {"eu8", false}, {"eu16", false}, {"eu32", false}, {"eu64", false}, - {"t", false}, + {"b", false}, {"i8", false}, {"i16", false}, + {"i32", false}, {"i64", false}, {"u8", false}, + {"u16", false}, {"u32", false}, {"u64", false}, + {"f", false}, {"d", false}, {"str", false}, + {"ei8", false}, {"ei16", false}, {"ei32", false}, + {"ei64", false}, {"eu8", false}, {"eu16", false}, + {"eu32", false}, {"eu64", false}, {"t", false}, + {"u_type", false}, {"u", false}, {"s", false}, + {"v_b", false}, {"v_i8", false}, {"v_i16", false}, + {"v_i32", false}, {"v_i64", false}, {"v_u8", false}, + {"v_u16", false}, {"v_u32", false}, {"v_u64", false}, + {"v_f", false}, {"v_d", false}, {"v_str", false}, + {"v_ei8", false}, {"v_ei16", false}, {"v_ei32", false}, + {"v_ei64", false}, {"v_eu8", false}, {"v_eu16", false}, + {"v_eu32", false}, {"v_eu64", false}, {"v_t", false}, + {"v_u_type", false}, {"v_u", false}, {"v_s", false}, }; // Optional fields are mutated to null with probability 1/100. @@ -413,6 +956,33 @@ TEST(FlatbuffersTableDomainImplTest, OptionalTableEventuallyBecomeEmpty) { null_fields["eu32"] |= !v->eu32().has_value(); null_fields["eu64"] |= !v->eu64().has_value(); null_fields["t"] |= v->t() == nullptr; + null_fields["u_type"] |= v->u_type() == internal::Union_NONE; + null_fields["u"] |= v->u() == nullptr; + null_fields["s"] |= v->s() == nullptr; + null_fields["v_b"] |= v->v_b() == nullptr; + null_fields["v_i8"] |= v->v_i8() == nullptr; + null_fields["v_i16"] |= v->v_i16() == nullptr; + null_fields["v_i32"] |= v->v_i32() == nullptr; + null_fields["v_i64"] |= v->v_i64() == nullptr; + null_fields["v_u8"] |= v->v_u8() == nullptr; + null_fields["v_u16"] |= v->v_u16() == nullptr; + null_fields["v_u32"] |= v->v_u32() == nullptr; + null_fields["v_u64"] |= v->v_u64() == nullptr; + null_fields["v_f"] |= v->v_f() == nullptr; + null_fields["v_d"] |= v->v_d() == nullptr; + null_fields["v_str"] |= v->v_str() == nullptr; + null_fields["v_ei8"] |= v->v_ei8() == nullptr; + null_fields["v_ei16"] |= v->v_ei16() == nullptr; + null_fields["v_ei32"] |= v->v_ei32() == nullptr; + null_fields["v_ei64"] |= v->v_ei64() == nullptr; + null_fields["v_eu8"] |= v->v_eu8() == nullptr; + null_fields["v_eu16"] |= v->v_eu16() == nullptr; + null_fields["v_eu32"] |= v->v_eu32() == nullptr; + null_fields["v_eu64"] |= v->v_eu64() == nullptr; + null_fields["v_t"] |= v->v_t() == nullptr; + null_fields["v_u_type"] |= v->v_u_type() == nullptr; + null_fields["v_u"] |= v->v_u() == nullptr; + null_fields["v_s"] |= v->v_s() == nullptr; if (std::all_of(null_fields.begin(), null_fields.end(), [](const auto& p) { return p.second; })) { @@ -451,83 +1021,164 @@ TEST(FlatbuffersTableDomainImplTest, Printer) { printer.PrintCorpusValue(*corpus, &out, domain_implementor::PrintMode::kHumanReadable); - EXPECT_THAT(out, AllOf(HasSubstr("b: (true)"), // b - HasSubstr("i8: (1)"), // i8 - HasSubstr("i16: (2)"), // i16 - HasSubstr("i32: (3)"), // i32 - HasSubstr("i64: (4)"), // i64 - HasSubstr("u8: (5)"), // u8 - HasSubstr("u16: (6)"), // u16 - HasSubstr("u32: (7)"), // u32 - HasSubstr("u64: (8)"), // u64 - HasSubstr("f: (9.f)"), // f - HasSubstr("d: (10.)"), // d - HasSubstr("str: (\"foo bar baz\")"), // str - HasSubstr("ei8: (Second)"), // ei8 - HasSubstr("ei16: (Second)"), // ei16 - HasSubstr("ei32: (Second)"), // ei32 - HasSubstr("ei64: (Second)"), // ei64 - HasSubstr("eu8: (Second)"), // eu8 - HasSubstr("eu16: (Second)"), // eu16 - HasSubstr("eu32: (Second)"), // eu32 - HasSubstr("eu64: (Second)"), // eu64 - HasSubstr("t: ({b: (true)})") // t - )); -} - -TEST(FlatbuffersTableDomainImplTest, UnsupportedTypesRemainNull) { - absl::flat_hash_map null_fields{ - {"u", true}, {"s", true}, {"v_b", true}, {"v_i8", true}, - {"v_i16", true}, {"v_i32", true}, {"v_i64", true}, {"v_u8", true}, - {"v_u16", true}, {"v_u32", true}, {"v_u64", true}, {"v_f", true}, - {"v_d", true}, {"v_str", true}, {"v_ei8", true}, {"v_ei16", true}, - {"v_ei32", true}, {"v_ei64", true}, {"v_eu8", true}, {"v_eu16", true}, - {"v_eu32", true}, {"v_eu64", true}, {"v_t", true}, {"v_u", true}, - {"v_s", true}}; - - auto domain = Arbitrary(); + EXPECT_THAT( + out, + AllOf(HasSubstr("b: (true)"), // b + HasSubstr("i8: (1)"), // i8 + HasSubstr("i16: (2)"), // i16 + HasSubstr("i32: (3)"), // i32 + HasSubstr("i64: (4)"), // i64 + HasSubstr("u8: (5)"), // u8 + HasSubstr("u16: (6)"), // u16 + HasSubstr("u32: (7)"), // u32 + HasSubstr("u64: (8)"), // u64 + HasSubstr("f: (9.f)"), // f + HasSubstr("d: (10.)"), // d + HasSubstr("str: (\"foo bar baz\")"), // str + HasSubstr("ei8: (Second)"), // ei8 + HasSubstr("ei16: (Second)"), // ei16 + HasSubstr("ei32: (Second)"), // ei32 + HasSubstr("ei64: (Second)"), // ei64 + HasSubstr("eu8: (Second)"), // eu8 + HasSubstr("eu16: (Second)"), // eu16 + HasSubstr("eu32: (Second)"), // eu32 + HasSubstr("eu64: (Second)"), // eu64 + HasSubstr("t: ({b: (true)})"), // t + HasSubstr("u: (({b: (true)}))"), // u + HasSubstr( + "s: ({b: true, i8: 1, i16: 2, i32: 3, i64: 4, u8: 5, u16: 6, " + "u32: 7, u64: 8, f: 9.f, d: 10., ei8: First, ei16: First, " + "ei32: First, ei64: First, eu8: First, eu16: First, eu32: " + "First, eu64: First, s: {b: true}})"), // s + HasSubstr("v_b: ({true, false})"), // v_b + HasSubstr("v_i8: ({1, 2, 3})"), // v_i8 + HasSubstr("v_i16: ({1, 2, 3})"), // v_i16 + HasSubstr("v_i32: ({1, 2, 3})"), // v_i32 + HasSubstr("v_i64: ({1, 2, 3})"), // v_i64 + HasSubstr("v_u8: ({1, 2, 3})"), // v_u8 + HasSubstr("v_u16: ({1, 2, 3})"), // v_u16 + HasSubstr("v_u32: ({1, 2, 3})"), // v_u32 + HasSubstr("v_u64: ({1, 2, 3})"), // v_u64 + HasSubstr("v_f: ({1.f, 2.f, 3.f})"), // v_f + HasSubstr("v_d: ({1., 2., 3.})"), // v_d + HasSubstr("v_str: ({\"foo\", \"bar\", \"baz\"})"), // v_str + HasSubstr("v_ei8: ({First, Second})"), // v_ei8 + HasSubstr("v_ei16: ({First, Second})"), // v_ei16 + HasSubstr("v_ei32: ({First, Second})"), // v_ei32 + HasSubstr("v_ei64: ({First, Second})"), // v_ei64 + HasSubstr("v_eu8: ({First, Second})"), // v_eu8 + HasSubstr("v_eu16: ({First, Second})"), // v_eu16 + HasSubstr("v_eu32: ({First, Second})"), // v_eu32 + HasSubstr("v_eu64: ({First, Second})"), // v_eu64 + HasSubstr("v_t: ({{b: (true)}})"), // v_t + HasSubstr("v_u: ({({b: (true)}), " + "({str: (\"foo bar baz\")})})"), // v_u + HasSubstr( + "v_s: ({{b: true, i8: 1, i16: 2, i32: 3, i64: 4, u8: 5, u16: " + "6, u32: 7, u64: 8, f: 9.f, d: 10., ei8: First, ei16: First, " + "ei32: First, ei64: First, eu8: First, eu16: First, eu32: " + "First, eu64: First, s: {b: true}}})") // v_s + )); +} - absl::BitGen bitgen; - for (size_t i = 0; - i < IterationsToHitAll(null_fields.size(), 1.0 / null_fields.size()); - ++i) { - Value val(domain, bitgen); - val.Mutate(domain, bitgen, {}, false); - const auto& mut = val.user_value; +TEST(FlatbuffersTableDomainImplTest, MutateSelectedField) { + flatbuffers::FlatBufferBuilder fbb; + auto table = CreateDefaultTable(fbb); + auto domain = Arbitrary(); + absl::BitGen prng; - null_fields["u"] &= mut->u() == nullptr; - null_fields["s"] &= mut->s() == nullptr; - null_fields["v_b"] &= mut->v_b() == nullptr; - null_fields["v_i8"] &= mut->v_i8() == nullptr; - null_fields["v_i16"] &= mut->v_i16() == nullptr; - null_fields["v_i32"] &= mut->v_i32() == nullptr; - null_fields["v_i64"] &= mut->v_i64() == nullptr; - null_fields["v_u8"] &= mut->v_u8() == nullptr; - null_fields["v_u16"] &= mut->v_u16() == nullptr; - null_fields["v_u32"] &= mut->v_u32() == nullptr; - null_fields["v_u64"] &= mut->v_u64() == nullptr; - null_fields["v_f"] &= mut->v_f() == nullptr; - null_fields["v_d"] &= mut->v_d() == nullptr; - null_fields["v_str"] &= mut->v_str() == nullptr; - null_fields["v_ei8"] &= mut->v_ei8() == nullptr; - null_fields["v_ei16"] &= mut->v_ei16() == nullptr; - null_fields["v_ei32"] &= mut->v_ei32() == nullptr; - null_fields["v_ei64"] &= mut->v_ei64() == nullptr; - null_fields["v_eu8"] &= mut->v_eu8() == nullptr; - null_fields["v_eu16"] &= mut->v_eu16() == nullptr; - null_fields["v_eu32"] &= mut->v_eu32() == nullptr; - null_fields["v_eu64"] &= mut->v_eu64() == nullptr; - null_fields["v_t"] &= mut->v_t() == nullptr; - null_fields["v_u"] &= mut->v_u() == nullptr; - null_fields["v_s"] &= mut->v_s() == nullptr; - - if (std::any_of(null_fields.begin(), null_fields.end(), - [](const auto& p) { return !p.second; })) { - break; + { + // Mutate scalar field (b). + auto corpus = domain.FromValue(table); + domain.MutateSelectedField(*corpus, prng, {}, false, 1); + EXPECT_NE(domain.GetValue(*corpus)->b(), table->b()); + } + { + // Mutate nested table field (t.b). + auto corpus = domain.FromValue(table); + domain.MutateSelectedField(*corpus, prng, {}, false, 22); + EXPECT_NE(domain.GetValue(*corpus)->t()->b(), table->t()->b()); + } + { + // Mutate union type field (u_type). + auto corpus = domain.FromValue(table); + domain.MutateSelectedField(*corpus, prng, {}, false, 24); + EXPECT_NE(domain.GetValue(*corpus)->u_type(), table->u_type()); + } + { + // Mutate union field (u.*) content. + auto corpus = domain.FromValue(table); + domain.MutateSelectedField(*corpus, prng, {}, false, 25); + switch (domain.GetValue(*corpus)->u_type()) { + case internal::Union_BoolTable: + EXPECT_NE(domain.GetValue(*corpus)->u_as_BoolTable()->b(), + table->u_as_BoolTable()->b()); + break; + case internal::Union_StringTable: + EXPECT_NE( + domain.GetValue(*corpus)->u_as_StringTable()->str()->string_view(), + table->u_as_StringTable()->str()->string_view()); + break; + default: + FAIL() << "Unexpected union type: " + << domain.GetValue(*corpus)->u_type(); + } + } + { + // Mutate vector of tables field (v_t). + auto corpus = domain.FromValue(table); + domain.MutateSelectedField(*corpus, prng, {}, false, 68); + auto user_value = domain.GetValue(*corpus); + EXPECT_FALSE(Eq(user_value->v_t(), table->v_t())); + } + { + // Mutate vector of tables field (v_t[0].b). + auto corpus = domain.FromValue(table); + domain.MutateSelectedField(*corpus, prng, {}, false, 69); + EXPECT_NE(domain.GetValue(*corpus)->v_t()->Get(0)->b(), + table->v_t()->Get(0)->b()); + } + { + // Mutate vector of unions field type (v_u_type[0]). + auto corpus = domain.FromValue(table); + domain.MutateSelectedField(*corpus, prng, {}, false, 71); + EXPECT_NE(domain.GetValue(*corpus)->v_u_type()->Get(0), + table->v_u_type()->Get(0)); + } + { + // Mutate vector of unions field (v_u[0].*). + auto corpus = domain.FromValue(table); + domain.MutateSelectedField(*corpus, prng, {}, false, 72); + EXPECT_EQ(domain.GetValue(*corpus)->v_u_type()->size(), + table->v_u_type()->size()); + EXPECT_EQ(domain.GetValue(*corpus)->v_u_type()->Get(0), + table->v_u_type()->Get(0)); + if (!domain.GetValue(*corpus)->v_u_type()->empty()) { + switch (domain.GetValue(*corpus)->v_u_type()->Get(0)) { + case internal::Union_BoolTable: + EXPECT_NE( + static_cast( + domain.GetValue(*corpus)->v_u()->Get(0)) + ->b(), + static_cast(table->v_u()->Get(0)) + ->b()); + break; + case internal::Union_StringTable: + EXPECT_NE( + static_cast( + domain.GetValue(*corpus)->v_u()->Get(0)) + ->str() + ->string_view(), + static_cast(table->v_u()->Get(0)) + ->str() + ->string_view()); + break; + default: + FAIL() << "Unexpected union type: " + << domain.GetValue(*corpus)->v_u_type()->Get(0); + } } } - - EXPECT_THAT(null_fields, Each(Pair(_, true))); } TEST(FlatbuffersTableDomainImplTest, MutateAlwaysChangesValues) { @@ -538,23 +1189,28 @@ TEST(FlatbuffersTableDomainImplTest, MutateAlwaysChangesValues) { schema->objects()->LookupByKey(DefaultTable::GetFullyQualifiedName()); absl::BitGen bitgen; - size_t iterations = IterationsToHitAll(object->fields()->size(), - 1.0 / object->fields()->size()); + const uint32_t field_count = object->fields()->size(); + int iterations = IterationsToHitAll(field_count, 1.0 / field_count); typename decltype(domain)::corpus_type corpus = domain.Init(bitgen); - for (size_t i = 0; i < iterations; ++i) { + for (int i = 0; i < iterations; ++i) { auto mutated_corpus = corpus; domain.Mutate(mutated_corpus, bitgen, {}, false); - EXPECT_FALSE(Eq(domain.GetValue(mutated_corpus), domain.GetValue(corpus))); + if (Eq(domain.GetValue(mutated_corpus), domain.GetValue(corpus))) { + auto printer = domain.GetPrinter(); + std::string corpus_str; + printer.PrintCorpusValue(corpus, &corpus_str, + domain_implementor::PrintMode::kHumanReadable); + std::string mutated_corpus_str; + printer.PrintCorpusValue(mutated_corpus, &mutated_corpus_str, + domain_implementor::PrintMode::kHumanReadable); + FAIL() << "Mutated corpus is equal to the original corpus." + << "\nOriginal: " << corpus_str + << "\nMutated: " << mutated_corpus_str; + } corpus = mutated_corpus; } } -TEST(FlatbuffersTableDomainImplTest, UnsupportedFieldsCountIsZero) { - auto domain = Arbitrary(); - auto corpus = domain.Init(absl::BitGen()); - EXPECT_EQ(domain.CountNumberOfFields(corpus), 0); -} - TEST(FlatbuffersTableDomainImplTest, CountNumberOfFieldsWithNull) { flatbuffers::FlatBufferBuilder fbb; auto table_offset = internal::CreateOptionalTableDirect(fbb); @@ -564,7 +1220,72 @@ TEST(FlatbuffersTableDomainImplTest, CountNumberOfFieldsWithNull) { auto domain = Arbitrary(); auto corpus = domain.FromValue(table); ASSERT_TRUE(corpus.has_value()); - EXPECT_EQ(domain.CountNumberOfFields(corpus.value()), 21); + EXPECT_EQ(domain.CountNumberOfFields(corpus.value()), 46); +} + +TEST(FlatbuffersUnionDomainImpl, ParseCorpusRejectsInvalidValues) { + auto domain = Arbitrary(); + { + flatbuffers::FlatBufferBuilder fbb; + internal::CreateUnionTable(fbb, internal::Union_BoolTable, 0); + fbb.Finish(internal::CreateUnionTable(fbb, internal::Union_BoolTable, 0)); + auto table = flatbuffers::GetRoot(fbb.GetBufferPointer()); + flatbuffers::Verifier verifier(fbb.GetBufferPointer(), fbb.GetSize()); + ASSERT_TRUE(verifier.VerifyBuffer()); + + auto corpus = domain.FromValue(table); + ASSERT_TRUE(corpus.has_value()); + EXPECT_FALSE(domain.ValidateCorpusValue(corpus.value()).ok()); + } + { + internal::IRObject ir_object; + auto& subs = ir_object.MutableSubs(); + subs.reserve(2); + + auto& u_obj = subs.emplace_back(); + auto& u_subs = u_obj.MutableSubs(); + u_subs.reserve(2); + u_subs.emplace_back(1); // id + auto& u_opt_value = u_subs.emplace_back(); // value + auto& u_opt_value_subs = u_opt_value.MutableSubs(); + u_opt_value_subs.reserve(2); + u_opt_value_subs.emplace_back(1); // has value + auto& u_inner_value = u_opt_value_subs.emplace_back(); + + u_inner_value.MutableSubs().reserve(2); + u_inner_value.MutableSubs().emplace_back(-1); // type (invalid) + u_inner_value.MutableSubs().emplace_back(); // value + + auto corpus = domain.ParseCorpus(ir_object); + ASSERT_FALSE(corpus.has_value()); + } + { + internal::IRObject ir_object; + auto& subs = ir_object.MutableSubs(); + subs.reserve(2); + + auto& u_obj = subs.emplace_back(); + auto& u_subs = u_obj.MutableSubs(); + u_subs.reserve(2); + u_subs.emplace_back(1); // id + auto& u_opt_value = u_subs.emplace_back(); // value + auto& u_opt_value_subs = u_opt_value.MutableSubs(); + u_opt_value_subs.reserve(2); + u_opt_value_subs.emplace_back(1); // has value + auto& u_inner_value = u_opt_value_subs.emplace_back(); + + u_inner_value.MutableSubs().reserve(2); + u_inner_value.MutableSubs().emplace_back( + internal::Union_BoolTable); // type + auto& bool_table = u_inner_value.MutableSubs().emplace_back(); // value + auto& bool_table_subs = bool_table.MutableSubs(); + bool_table_subs.reserve(2); + bool_table_subs.emplace_back(200); // id (invalid) + u_subs.emplace_back(); // value + + auto corpus = domain.ParseCorpus(ir_object); + ASSERT_FALSE(corpus.has_value()); + } } TEST(FlatbuffersTableDomainImplTest, RecursiveTable) { @@ -592,5 +1313,37 @@ TEST(FlatbuffersTableDomainImplTest, RecursiveTable) { ASSERT_THAT(new_table, IsNull()); } +TEST(FlatbuffersTableDomainImplTest, DefaultTable64ValueRoundTrip) { + flatbuffers::FlatBufferBuilder64 fbb; + auto str_offset = fbb.CreateString("foo bar baz"); + std::vector v_u8 = {1, 2, 3}; + auto v_u8_offset = fbb.CreateVector64(v_u8); + auto table_offset = + internal::CreateDefaultTable64(fbb, str_offset, v_u8_offset); + fbb.Finish(table_offset); + auto table = flatbuffers::GetRoot(fbb.GetBufferPointer()); + + auto domain = Arbitrary(); + auto corpus = domain.FromValue(table); + ASSERT_TRUE(corpus.has_value()); + ASSERT_OK(domain.ValidateCorpusValue(*corpus)); + + auto ir = domain.SerializeCorpus(corpus.value()); + + auto new_corpus = domain.ParseCorpus(ir); + ASSERT_TRUE(new_corpus.has_value()); + ASSERT_OK(domain.ValidateCorpusValue(*new_corpus)); + + auto new_table = domain.GetValue(*new_corpus); + ASSERT_THAT(new_table, NotNull()); + ASSERT_THAT(new_table->str(), NotNull()); + EXPECT_EQ(new_table->str()->str(), "foo bar baz"); + ASSERT_THAT(new_table->v_u8(), NotNull()); + ASSERT_EQ(new_table->v_u8()->size(), 3); + EXPECT_EQ(new_table->v_u8()->Get(0), 1); + EXPECT_EQ(new_table->v_u8()->Get(1), 2); + EXPECT_EQ(new_table->v_u8()->Get(2), 3); +} + } // namespace } // namespace fuzztest diff --git a/fuzztest/internal/BUILD b/fuzztest/internal/BUILD index f5155910..2bf470ff 100644 --- a/fuzztest/internal/BUILD +++ b/fuzztest/internal/BUILD @@ -616,8 +616,13 @@ cc_test( flatbuffer_library_public( name = "test_flatbuffers_fbs", - srcs = ["test_flatbuffers.fbs"], + srcs = [ + "test_flatbuffers.fbs", + "test_flatbuffers_64bits.fbs", + ], outs = [ + "test_flatbuffers_64bits_bfbs_generated.h", + "test_flatbuffers_64bits_generated.h", "test_flatbuffers_bfbs_generated.h", "test_flatbuffers_generated.h", ], diff --git a/fuzztest/internal/CMakeLists.txt b/fuzztest/internal/CMakeLists.txt index 0eb75aba..5c1173e6 100644 --- a/fuzztest/internal/CMakeLists.txt +++ b/fuzztest/internal/CMakeLists.txt @@ -574,6 +574,7 @@ if (FUZZTEST_BUILD_FLATBUFFERS) test_flatbuffers_headers SCHEMAS "test_flatbuffers.fbs" + "test_flatbuffers_64bits.fbs" FLAGS --bfbs-gen-embed --gen-name-strings TESTONLY diff --git a/fuzztest/internal/domains/BUILD b/fuzztest/internal/domains/BUILD index 545fef36..30790fc5 100644 --- a/fuzztest/internal/domains/BUILD +++ b/fuzztest/internal/domains/BUILD @@ -187,9 +187,9 @@ cc_library( hdrs = ["flatbuffers_domain_impl.h"], deps = [ ":core_domains_impl", - "@abseil-cpp//absl/algorithm:container", "@abseil-cpp//absl/base:core_headers", "@abseil-cpp//absl/base:nullability", + "@abseil-cpp//absl/container:btree", "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/container:flat_hash_set", "@abseil-cpp//absl/random:bit_gen_ref", @@ -204,6 +204,7 @@ cc_library( "@com_google_fuzztest//fuzztest/internal:meta", "@com_google_fuzztest//fuzztest/internal:serialization", "@com_google_fuzztest//fuzztest/internal:status", + "@com_google_fuzztest//fuzztest/internal:type_support", "@flatbuffers//:runtime_cc", ], ) diff --git a/fuzztest/internal/domains/flatbuffers_domain_impl.cc b/fuzztest/internal/domains/flatbuffers_domain_impl.cc index 7e8e144e..866c4bdb 100644 --- a/fuzztest/internal/domains/flatbuffers_domain_impl.cc +++ b/fuzztest/internal/domains/flatbuffers_domain_impl.cc @@ -11,276 +11,490 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + #include "./fuzztest/internal/domains/flatbuffers_domain_impl.h" #include #include #include +#include #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/random/bit_gen_ref.h" #include "absl/random/distributions.h" #include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "flatbuffers/base.h" #include "flatbuffers/flatbuffer_builder.h" -#include "flatbuffers/reflection.h" #include "flatbuffers/reflection_generated.h" +#include "flatbuffers/struct.h" +#include "flatbuffers/table.h" +#include "./common/logging.h" #include "./fuzztest/domain_core.h" #include "./fuzztest/internal/any.h" #include "./fuzztest/internal/domains/domain_base.h" -#include "./fuzztest/internal/domains/domain_type_erasure.h" +#include "./fuzztest/internal/meta.h" #include "./fuzztest/internal/serialization.h" namespace fuzztest::internal { -FlatbuffersTableUntypedDomainImpl::FlatbuffersTableUntypedDomainImpl( - const reflection::Schema* absl_nonnull schema, - const reflection::Object* absl_nonnull table_object) - : schema_(schema), table_object_(table_object) {} +// Gets a domain for a specific struct type. +template <> +auto FlatbuffersUnionDomainImpl::GetDefaultDomainForType( + const reflection::EnumVal& enum_value) const { + const reflection::Object* object = + schema_->objects()->Get(enum_value.union_type()->index()); + return Domain( + FlatbuffersStructUntypedDomainImpl{schema_, object}); +} -FlatbuffersTableUntypedDomainImpl::FlatbuffersTableUntypedDomainImpl( - const FlatbuffersTableUntypedDomainImpl& other) - : DomainBase(other), - schema_(other.schema_), - table_object_(other.table_object_) { - absl::MutexLock l_other(other.mutex_); - absl::MutexLock l_this(mutex_); - domains_ = other.domains_; +FlatbuffersUnionDomainImpl::FlatbuffersUnionDomainImpl( + const reflection::Schema* schema, const reflection::Enum* union_def) + : schema_(schema), union_def_(union_def), type_domain_(union_def) { + type_domain_.WithExcludedValues({0 /* NONE */}); } -FlatbuffersTableUntypedDomainImpl& FlatbuffersTableUntypedDomainImpl::operator=( - const FlatbuffersTableUntypedDomainImpl& other) { - DomainBase::operator=(other); - schema_ = other.schema_; - table_object_ = other.table_object_; +FlatbuffersUnionDomainImpl::FlatbuffersUnionDomainImpl( + const FlatbuffersUnionDomainImpl& other) + : schema_(other.schema_), + union_def_(other.union_def_), + type_domain_(other.type_domain_) { + absl::MutexLock l(mutex_); absl::MutexLock l_other(other.mutex_); - absl::MutexLock l_this(mutex_); domains_ = other.domains_; - return *this; } -FlatbuffersTableUntypedDomainImpl::FlatbuffersTableUntypedDomainImpl( - FlatbuffersTableUntypedDomainImpl&& other) - : schema_(other.schema_), table_object_(other.table_object_) { +FlatbuffersUnionDomainImpl::FlatbuffersUnionDomainImpl( + FlatbuffersUnionDomainImpl&& other) + : schema_(other.schema_), + union_def_(other.union_def_), + type_domain_(std::move(other.type_domain_)) { + absl::MutexLock l(mutex_); absl::MutexLock l_other(other.mutex_); - absl::MutexLock l_this(mutex_); domains_ = std::move(other.domains_); - DomainBase::operator=(std::move(other)); } -FlatbuffersTableUntypedDomainImpl& FlatbuffersTableUntypedDomainImpl::operator=( - FlatbuffersTableUntypedDomainImpl&& other) { - schema_ = other.schema_; - table_object_ = other.table_object_; - absl::MutexLock l_other(other.mutex_); - absl::MutexLock l_this(mutex_); - domains_ = std::move(other.domains_); - DomainBase::operator=(std::move(other)); - return *this; +// Get a domain for a specific table type. +template <> +auto FlatbuffersUnionDomainImpl::GetDefaultDomainForType( + const reflection::EnumVal& enum_value) const { + const reflection::Object* object = + schema_->objects()->Get(enum_value.union_type()->index()); + return Domain( + FlatbuffersTableUntypedDomainImpl{schema_, object}); } -FlatbuffersTableUntypedDomainImpl::corpus_type -FlatbuffersTableUntypedDomainImpl::Init(absl::BitGenRef prng) { +FlatbuffersUnionDomainImpl::corpus_type FlatbuffersUnionDomainImpl::Init( + absl::BitGenRef prng) { if (auto seed = this->MaybeGetRandomSeed(prng)) { return *seed; } + + // Unions are encoded as the combination of two fields: an enum representing + // the union choice and the offset to the actual element. + // + // The following code follows that logic. corpus_type val; - for (const auto* field : *table_object_->fields()) { - VisitFlatbufferField(schema_, field, InitializeVisitor{*this, prng, val}); + + val.type = type_domain_.Init(prng); + auto type_value = type_domain_.GetValue(val.type); + + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) { + return val; + } + + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + if (object->is_struct()) { + auto inner_val = + GetCachedDomain(*type_enumval).Init(prng); + val.value = std::move(inner_val); + } else { + auto inner_val = + GetCachedDomain(*type_enumval).Init(prng); + val.value = std::move(inner_val); } return val; } // Mutates the corpus value. -void FlatbuffersTableUntypedDomainImpl::Mutate( - corpus_type& val, absl::BitGenRef prng, +void FlatbuffersUnionDomainImpl::Mutate( + corpus_type& corpus_value, absl::BitGenRef prng, const domain_implementor::MutationMetadata& metadata, bool only_shrink) { - uint64_t field_count = 0; - for (const auto* field : *table_object_->fields()) { - VisitFlatbufferField(schema_, field, - CountNumberOfMutableFieldsVisitor{*this, field_count, - val, only_shrink}); + auto type_value = type_domain_.GetValue(corpus_value.type); + + // Mutate the type with probability 1%. + if (absl::Bernoulli(prng, 0.01)) { + // Mutate the type. + type_domain_.Mutate(corpus_value.type, prng, metadata, only_shrink); + type_value = type_domain_.GetValue(corpus_value.type); + + // If the union is set after type mutation, init the value corpus value. + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) return; + + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + if (object->is_struct()) { + corpus_value.value = + GetCachedDomain(*type_enumval).Init(prng); + } else { + corpus_value.value = + GetCachedDomain(*type_enumval).Init(prng); + } + return; } - if (field_count == 0) return; - auto selected_field_index = absl::Uniform(prng, 1ul, field_count + 1); - MutateSelectedField(val, prng, metadata, only_shrink, selected_field_index); + // Mutate the value if the union is set. + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) return; + + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + if (object->is_struct()) { + auto domain = GetCachedDomain(*type_enumval); + domain.Mutate(corpus_value.value, prng, metadata, only_shrink); + } else { + GetCachedDomain(*type_enumval) + .Mutate(corpus_value.value, prng, metadata, only_shrink); + } } -uint64_t FlatbuffersTableUntypedDomainImpl::CountNumberOfFields( - corpus_type& val) { +uint64_t FlatbuffersUnionDomainImpl::CountNumberOfFields( + corpus_type& corpus_value) { uint64_t field_count = 0; - for (const auto* field : *table_object_->fields()) { - VisitFlatbufferField( - schema_, field, - CountNumberOfMutableFieldsVisitor{*this, field_count, val}); + + // If the union has only one type (besides NONE), the type is not counted + // as mutable field. + if (union_def_->values()->size() <= 2) { + return field_count; + } + + // The first field is the union type. + ++field_count; + + auto type_value = type_domain_.GetValue(corpus_value.type); + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) { + return field_count; + } + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + if (object->is_struct()) { + auto domain = GetCachedDomain(*type_enumval); + field_count += domain.CountNumberOfFields(corpus_value.value); + } else { + auto domain = GetCachedDomain(*type_enumval); + field_count += domain.CountNumberOfFields(corpus_value.value); } return field_count; } -uint64_t FlatbuffersTableUntypedDomainImpl::MutateSelectedField( - corpus_type& val, absl::BitGenRef prng, +uint64_t FlatbuffersUnionDomainImpl::MutateSelectedField( + corpus_type& corpus_value, absl::BitGenRef prng, const domain_implementor::MutationMetadata& metadata, bool only_shrink, uint64_t selected_field_index) { - uint64_t field_counter = 0; - uint64_t fields_count = CountNumberOfFields(val); - if (fields_count < selected_field_index) { - return fields_count; - } + uint64_t field_count = 0; - for (const auto* field : *table_object_->fields()) { - if (!IsSupportedField(field)) { - if (only_shrink && !val.contains(field->id())) continue; - } + // If the union has only one type (besides NONE), the type is not counted + // as mutable field. + if (union_def_->values()->size() <= 2) { + return field_count; + } - ++field_counter; - if (field_counter == selected_field_index) { - VisitFlatbufferField( - schema_, field, - MutateVisitor{*this, prng, metadata, only_shrink, val}); - return field_counter; + // The first field is the union type. + ++field_count; + if (selected_field_index == field_count) { + type_domain_.Mutate(corpus_value.type, prng, metadata, only_shrink); + auto type_value = type_domain_.GetValue(corpus_value.type); + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) return selected_field_index; + + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + if (object->is_struct()) { + corpus_value.value = + GetCachedDomain(*type_enumval).Init(prng); + } else { + corpus_value.value = + GetCachedDomain(*type_enumval).Init(prng); } + return field_count; + } - if (field->type()->base_type() == reflection::BaseType::Obj) { - auto sub_object = schema_->objects()->Get(field->type()->index()); - if (!sub_object->is_struct()) { - field_counter += - GetCachedDomain(field).MutateSelectedField( - val[field->id()], prng, metadata, only_shrink, - selected_field_index - field_counter); - } - // TODO: Add support for structs. - } + auto type_value = type_domain_.GetValue(corpus_value.type); - if (field_counter >= selected_field_index) { - return field_counter; - } + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) { + return 0; + } + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + if (object->is_struct()) { + auto domain = GetCachedDomain(*type_enumval); + field_count += domain.MutateSelectedField( + corpus_value.value, prng, metadata, only_shrink, + selected_field_index - field_count); + } else { + auto domain = GetCachedDomain(*type_enumval); + field_count += domain.MutateSelectedField( + corpus_value.value, prng, metadata, only_shrink, + selected_field_index - field_count); } - return field_counter; + return field_count; } -absl::Status FlatbuffersTableUntypedDomainImpl::ValidateCorpusValue( +absl::Status FlatbuffersUnionDomainImpl::ValidateCorpusValue( const corpus_type& corpus_value) const { - for (const auto* field : *table_object_->fields()) { - absl::Status result; - GenericDomainCorpusType field_corpus; - if (auto it = corpus_value.find(field->id()); it != corpus_value.end()) { - field_corpus = it->second; - } - if (!field_corpus.has_value()) continue; - VisitFlatbufferField(schema_, field, - ValidateVisitor{*this, field_corpus, result}); - if (!result.ok()) return result; + // Unions are encoded as the combination of two fields: an enum representing + // the union choice and the offset to the actual element. + // + // Both type and value should be validated. + // + // Start with the type validation. + auto type_value = type_domain_.GetValue(corpus_value.type); + + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid union type: ", type_value)); + } + + // Validate the value. + if (!corpus_value.value.has_value()) { + return absl::InvalidArgumentError("Union value is not set."); + } + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + if (object->is_struct()) { + auto domain = GetCachedDomain(*type_enumval); + return domain.ValidateCorpusValue(corpus_value.value); + } else { + auto domain = GetCachedDomain(*type_enumval); + return domain.ValidateCorpusValue(corpus_value.value); } - return absl::OkStatus(); } -std::optional -FlatbuffersTableUntypedDomainImpl::FromValue(const value_type& value) const { - if (value == nullptr) { +// Converts the value to a corpus value. +std::optional +FlatbuffersUnionDomainImpl::FromValue(const value_type& value) const { + auto out = std::make_optional(); + auto type_corpus = type_domain_.FromValue(value.type); + if (type_corpus.has_value()) { + out->type = *type_corpus; + } + auto type_enumval = union_def_->values()->LookupByKey(value.type); + if (type_enumval == nullptr) { return std::nullopt; } - corpus_type ret; - for (const auto* field : *table_object_->fields()) { - VisitFlatbufferField(schema_, field, FromValueVisitor{*this, value, ret}); + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + std::optional inner_corpus; + if (object->is_struct()) { + auto domain = GetCachedDomain(*type_enumval); + inner_corpus = + domain.FromValue(static_cast(value.value)); + } else { + auto domain = GetCachedDomain(*type_enumval); + inner_corpus = + domain.FromValue(static_cast(value.value)); } - return ret; + if (inner_corpus.has_value()) { + out->value = std::move(inner_corpus.value()); + } + return out; } -std::optional -FlatbuffersTableUntypedDomainImpl::ParseCorpus(const IRObject& obj) const { +// Converts the IRObject to a corpus value. +std::optional +FlatbuffersUnionDomainImpl::ParseCorpus(const IRObject& obj) const { + // Follows the structure created by `SerializeCorpus` to deserialize the + // IRObject. corpus_type out; auto subs = obj.Subs(); if (!subs) { return std::nullopt; } - // Follows the structure created by `SerializeCorpus` to deserialize the - // IRObject. - // subs->size() represents the number of fields in the table. - out.reserve(subs->size()); - for (const auto& sub : *subs) { - auto pair_subs = sub.Subs(); - // Each field is represented by a pair of field id and the serialized - // corpus value. - if (!pair_subs.has_value() || pair_subs->size() != 2) { - return std::nullopt; - } + // We expect 2 fields: the type and the value. + if (subs->size() != 2) { + return std::nullopt; + } - // Deserialize the field id. - auto id = (*pair_subs)[0].GetScalar(); - if (!id.has_value()) { - return std::nullopt; - } + // Parse the type which is stored in the first field of the IRObject subs. + auto type_corpus = type_domain_.ParseCorpus((*subs)[0]); + if (!type_corpus.has_value()) { + return std::nullopt; + } + if (auto status = type_domain_.ValidateCorpusValue(*type_corpus); + !status.ok()) { + FUZZTEST_LOG(ERROR) << "Failed to validate type corpus: " + << status.message(); + return std::nullopt; + } + out.type = *type_corpus; + auto type_value = type_domain_.GetValue(out.type); + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) { + return std::nullopt; + } - // Get information about the field from reflection. - const reflection::Field* absl_nullable field = GetFieldById(*id); - if (field == nullptr) { - return std::nullopt; - } + // Parse the value. + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + if (object == nullptr) { + return std::nullopt; + } + std::optional inner_corpus; + if (object->is_struct()) { + auto domain = GetCachedDomain(*type_enumval); + // The value is stored in the second field of the IRObject subs. + inner_corpus = domain.ParseCorpus((*subs)[1]); + } else { + auto domain = GetCachedDomain(*type_enumval); + // The value is stored in the second field of the IRObject subs. + inner_corpus = domain.ParseCorpus((*subs)[1]); + } - // Deserialize the field corpus value. - std::optional inner_parsed; - VisitFlatbufferField(schema_, field, - ParseVisitor{*this, (*pair_subs)[1], inner_parsed}); - if (!inner_parsed) { - return std::nullopt; - } - out[id.value()] = *std::move(inner_parsed); + if (inner_corpus.has_value()) { + out.value = std::move(inner_corpus.value()); } return out; } -IRObject FlatbuffersTableUntypedDomainImpl::SerializeCorpus( - const corpus_type& value) const { +// Converts the corpus value to an IRObject. +IRObject FlatbuffersUnionDomainImpl::SerializeCorpus( + const corpus_type& corpus_value) const { IRObject out; - auto& subs = out.MutableSubs(); - subs.reserve(value.size()); - - // Each field is represented by a pair of field id and the serialized - // corpus value. - for (const auto& [id, field_corpus] : value) { - // Get information about the field from reflection. - const reflection::Field* absl_nullable field = GetFieldById(id); - if (field == nullptr) { - continue; - } - IRObject& pair = subs.emplace_back(); - auto& pair_subs = pair.MutableSubs(); - pair_subs.reserve(2); + auto type_value = type_domain_.GetValue(corpus_value.type); - // Serialize the field id. - pair_subs.emplace_back(field->id()); + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) { + return out; + } - // Serialize the field corpus value. - VisitFlatbufferField( - schema_, field, - SerializeVisitor{*this, field_corpus, pair_subs.emplace_back()}); + auto& pair = out.MutableSubs(); + // We have 2 fields: the type and the value. + pair.reserve(2); + + // Serialize the type. + pair.push_back(type_domain_.SerializeCorpus(corpus_value.type)); + + // Serialize the value. + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + if (object->is_struct()) { + auto domain = GetCachedDomain(*type_enumval); + pair.push_back(domain.SerializeCorpus(corpus_value.value)); + } else { + auto domain = GetCachedDomain(*type_enumval); + pair.push_back(domain.SerializeCorpus(corpus_value.value)); } return out; } -bool FlatbuffersTableUntypedDomainImpl::IsSupportedField( - const reflection::Field* absl_nonnull field) const { - auto base_type = field->type()->base_type(); - if (base_type == reflection::BaseType::UType) return false; - if (flatbuffers::IsScalar(base_type)) return true; - if (base_type == reflection::BaseType::String) return true; - if (base_type == reflection::BaseType::Obj) { - auto sub_object = schema_->objects()->Get(field->type()->index()); - return !sub_object->is_struct(); - }; - return false; +std::optional FlatbuffersUnionDomainImpl::BuildValue( + const corpus_type& corpus_value, + flatbuffers::FlatBufferBuilder64& builder) const { + // Get the object type. + auto type_value = type_domain_.GetValue(corpus_value.type); + auto type_enumval = union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr || !corpus_value.value.has_value()) { + return std::nullopt; + } + const reflection::Object* object = + schema_->objects()->Get(type_enumval->union_type()->index()); + if (object == nullptr) { + return std::nullopt; + } + if (object->is_struct()) { + FlatbuffersStructUntypedDomainImpl domain{schema_, object}; + return domain.BuildValue( + corpus_value.value + .GetAs>(), + builder); + } else { + FlatbuffersTableUntypedDomainImpl domain{schema_, object}; + return domain.BuildTable( + corpus_value.value + .GetAs>(), + builder); + } +} + +void FlatbuffersUnionDomainImpl::Printer::PrintCorpusValue( + const corpus_type& value, domain_implementor::RawSink out, + domain_implementor::PrintMode mode) const { + auto type_value = self.type_domain_.GetValue(value.type); + auto type_enumval = self.union_def_->values()->LookupByKey(type_value); + if (type_enumval == nullptr) { + return; + } + absl::Format(out, "<%s>(", type_enumval->name()->str()); + + const reflection::Object* object = + self.schema_->objects()->Get(type_enumval->union_type()->index()); + if (object->is_struct()) { + auto domain = self.GetCachedDomain(*type_enumval); + domain_implementor::PrintValue(domain, value.value, out, mode); + } else { + auto domain = self.GetCachedDomain(*type_enumval); + domain_implementor::PrintValue(domain, value.value, out, mode); + } + absl::Format(out, ")"); +} + +std::optional +FlatbuffersStructUntypedDomainImpl::FromValue(const value_type& value) const { + if (value == nullptr) { + return std::nullopt; + } + corpus_type val; + for (const auto& [_, field] : fields_by_id_) { + VisitFlatbufferField(schema_, field, FromValueVisitor{*this, value, val}); + } + return val; +} + +std::optional +FlatbuffersStructUntypedDomainImpl::BuildValue( + const corpus_type& value, flatbuffers::FlatBufferBuilder64& builder) const { + std::vector buf(object_->bytesize()); + BuildValue(value, buf.data()); + builder.StartStruct(object_->minalign()); + builder.PushBytes(buf.data(), buf.size()); + return builder.EndStruct(); +} + +void FlatbuffersStructUntypedDomainImpl::BuildValue(const corpus_type& value, + uint8_t* buf) const { + for (const auto& [_, field] : fields_by_id_) { + VisitFlatbufferField(schema_, field, BuildValueVisitor{*this, value, buf}); + } +} + +std::optional +FlatbuffersTableUntypedDomainImpl::FromValue(const value_type& value) const { + if (value == nullptr) { + return std::nullopt; + } + corpus_type ret; + for (const auto& [_, field] : fields_by_id_) { + VisitFlatbufferField(schema_, field, FromValueVisitor{*this, value, ret}); + } + return ret; } uint32_t FlatbuffersTableUntypedDomainImpl::BuildTable( - const corpus_type& value, flatbuffers::FlatBufferBuilder& builder) const { + const corpus_type& value, flatbuffers::FlatBufferBuilder64& builder) const { // Add all the fields to the builder. // Offsets is the map of field id to its offset in the table. - absl::flat_hash_map + absl::flat_hash_map offsets; // Some fields are stored inline in the flatbuffer table itself (a.k.a diff --git a/fuzztest/internal/domains/flatbuffers_domain_impl.h b/fuzztest/internal/domains/flatbuffers_domain_impl.h index 1482a598..61a80b93 100644 --- a/fuzztest/internal/domains/flatbuffers_domain_impl.h +++ b/fuzztest/internal/domains/flatbuffers_domain_impl.h @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -26,9 +27,9 @@ #include #include -#include "absl/algorithm/container.h" #include "absl/base/nullability.h" #include "absl/base/thread_annotations.h" +#include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/random/bit_gen_ref.h" @@ -37,21 +38,27 @@ #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "flatbuffers/base.h" +#include "flatbuffers/buffer.h" #include "flatbuffers/flatbuffer_builder.h" +#include "flatbuffers/reflection.h" #include "flatbuffers/reflection_generated.h" #include "flatbuffers/string.h" +#include "flatbuffers/struct.h" #include "flatbuffers/table.h" +#include "flatbuffers/vector.h" #include "flatbuffers/verifier.h" #include "./common/logging.h" #include "./fuzztest/domain_core.h" #include "./fuzztest/internal/any.h" #include "./fuzztest/internal/domains/arbitrary_impl.h" +#include "./fuzztest/internal/domains/container_of_impl.h" #include "./fuzztest/internal/domains/domain_base.h" #include "./fuzztest/internal/domains/domain_type_erasure.h" #include "./fuzztest/internal/domains/element_of_impl.h" #include "./fuzztest/internal/meta.h" #include "./fuzztest/internal/serialization.h" #include "./fuzztest/internal/status.h" +#include "./fuzztest/internal/type_support.h" namespace fuzztest::internal { @@ -68,81 +75,160 @@ struct FlatbuffersEnumTag { template struct is_flatbuffers_enum_tag : std::false_type {}; -template -struct is_flatbuffers_enum_tag> +template +struct is_flatbuffers_enum_tag> : std::true_type {}; template inline constexpr bool is_flatbuffers_enum_tag_v = is_flatbuffers_enum_tag::value; +// +// Flatbuffers vector detection. +// +template +struct FlatbuffersVectorTag { + using value_type = T; +}; + +template +struct is_flatbuffers_vector_tag : std::false_type {}; + +template +struct is_flatbuffers_vector_tag> : std::true_type {}; + +template +inline constexpr bool is_flatbuffers_vector_tag_v = + is_flatbuffers_vector_tag::value; + +template +struct FlatbuffersVector64Tag { + using value_type = T; +}; + +template +struct is_flatbuffers_vector64_tag : std::false_type {}; + +template +struct is_flatbuffers_vector64_tag> : std::true_type { +}; + +template +inline constexpr bool is_flatbuffers_vector64_tag_v = + is_flatbuffers_vector64_tag::value; + +template +inline constexpr bool is_any_flatbuffers_vector_tag_v = + is_flatbuffers_vector_tag_v || is_flatbuffers_vector64_tag_v; + +template +struct flatbuffers_vector_tag_offset; + +template +struct flatbuffers_vector_tag_offset> { + using type = flatbuffers::uoffset_t; +}; + +template +struct flatbuffers_vector_tag_offset> { + using type = flatbuffers::uoffset64_t; +}; + +template +using flatbuffers_vector_tag_offset_t = + typename flatbuffers_vector_tag_offset::type; + struct FlatbuffersArrayTag; + +// Flatbuffers container element type detection. +template +inline constexpr bool is_flatbuffers_container_of_v = []() constexpr { + if constexpr (is_flatbuffers_vector_tag_v || + is_flatbuffers_vector64_tag_v) { + return std::is_same_v; + } else { + return false; + } +}(); + struct FlatbuffersTableTag; struct FlatbuffersStructTag; struct FlatbuffersUnionTag; -struct FlatbuffersVectorTag; + +// Helper to wrap the visitor with the correct tag type. +template