From fd5857337a99abefc0ca73259307cc9e35f43a0c Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 1 Jun 2026 17:00:30 +0200 Subject: [PATCH] Handle mixed types for inner properly --- ext/VectorInterfaceEnzymeExt.jl | 14 ++++++++++++-- ext/VectorInterfaceMooncakeExt.jl | 13 +++++++++++-- test/enzyme.jl | 8 ++++---- test/mooncake.jl | 8 ++++---- 4 files changed, 31 insertions(+), 12 deletions(-) diff --git a/ext/VectorInterfaceEnzymeExt.jl b/ext/VectorInterfaceEnzymeExt.jl index d08aea9..60c9642 100644 --- a/ext/VectorInterfaceEnzymeExt.jl +++ b/ext/VectorInterfaceEnzymeExt.jl @@ -215,6 +215,16 @@ function EnzymeRules.forward( end end +function project_add!(C, A, α) + TC = Base.promote_op(+, scalartype(A), scalartype(α)) + return if !(TC <: Real) && scalartype(C) <: Real + add!(C, real(add!(zerovector(C, TC), A, α))) + else + add!(C, A, α) + end +end + + function EnzymeRules.augmented_primal( config::EnzymeRules.RevConfigWidth{1}, func::Const{typeof(inner)}, @@ -241,8 +251,8 @@ function EnzymeRules.reverse( ) ΔS = dret.val Aval, Bval = cache - !isa(A, Const) && add!(A.dval, Bval, conj(ΔS)) - !isa(B, Const) && add!(B.dval, Aval, ΔS) + !isa(A, Const) && project_add!(A.dval, Bval, conj(ΔS)) + !isa(B, Const) && project_add!(B.dval, Aval, ΔS) return (nothing, nothing) end diff --git a/ext/VectorInterfaceMooncakeExt.jl b/ext/VectorInterfaceMooncakeExt.jl index 124583f..0286c57 100644 --- a/ext/VectorInterfaceMooncakeExt.jl +++ b/ext/VectorInterfaceMooncakeExt.jl @@ -140,6 +140,15 @@ end # inner # ----- +function project_add!(C, A, α) + TC = Base.promote_op(+, scalartype(A), scalartype(α)) + return if !(TC <: Real) && scalartype(C) <: Real + add!(C, real(add!(zerovector(C, TC), A, α))) + else + add!(C, A, α) + end +end + @is_primitive DefaultCtx Tuple{typeof(inner), AbstractArray, AbstractArray} function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual, B_ΔB::CoDual) @@ -151,8 +160,8 @@ function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual, B_ΔB::CoDual) s = inner(A, B) function inner_pullback(Δs) - add!(ΔA, B, conj(Δs)) - add!(ΔB, A, Δs) + project_add!(ΔA, B, conj(Δs)) + project_add!(ΔB, A, Δs) return NoRData(), NoRData(), NoRData() end diff --git a/test/enzyme.jl b/test/enzyme.jl index 8d22bca..cb1d154 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -104,13 +104,13 @@ end end end -@testset "inner ($T)" for T in eltypes +@testset "inner ($Tx, $Ty)" for Tx in eltypes, Ty in eltypes n = 12 - atol = rtol = n * precision(T) + atol = rtol = n * max(precision(Tx), precision(Ty)) # Vector - x = randn(T, n) - y = randn(T, n) + x = randn(Tx, n) + y = randn(Ty, n) for RT in (Const, Active) test_reverse(inner, RT, (x, Duplicated), (y, Duplicated); atol, rtol) end diff --git a/test/mooncake.jl b/test/mooncake.jl index 02e3305..92468a3 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -73,13 +73,13 @@ end Mooncake.TestUtils.test_rule(rng, add!!, my, mx, α, β; atol, rtol, is_primitive = false) end -@testset "inner pullbacks ($T)" for T in eltypes +@testset "inner pullbacks ($Tx, $Ty)" for Tx in eltypes, Ty in eltypes n = 12 - atol = rtol = n * precision(T) + atol = rtol = n * max(precision(Tx), precision(Ty)) # Vector - x = randn(T, n) - y = randn(T, n) + x = randn(Tx, n) + y = randn(Ty, n) Mooncake.TestUtils.test_rule(rng, inner, x, y; atol, rtol, is_primitive = false) # MinimalMVec