diff --git a/ext/VectorInterfaceEnzymeExt.jl b/ext/VectorInterfaceEnzymeExt.jl index 60c9642..a61c775 100644 --- a/ext/VectorInterfaceEnzymeExt.jl +++ b/ext/VectorInterfaceEnzymeExt.jl @@ -7,6 +7,15 @@ using Enzyme using Enzyme.EnzymeCore using Enzyme.EnzymeCore: EnzymeRules +function project_add!(C, A, α) + TC = Base.promote_op(+, scalartype(A), scalartype(α)) + return if !(TC <: Real) && scalartype(C) <: Real + add!(C, real(add!(zerovector(C, TC), A, α))) + else + add!(C, A, α) + end +end + """ project_scalar(x::Number, dx::Number) @@ -104,7 +113,7 @@ function EnzymeRules.reverse( α::Annotation{<:Number}, ) where {RT} Aval, αval = cache - !isa(A, Const) && !isa(C, Const) && add!(A.dval, C.dval, conj(αval)) + !isa(A, Const) && !isa(C, Const) && project_add!(A.dval, C.dval, conj(αval)) Δα = if !isa(α, Const) && !isa(C, Const) project_scalar(α.val, inner(Aval, C.dval)) elseif !isa(α, Const) @@ -186,7 +195,7 @@ function EnzymeRules.reverse( else nothing end - !isa(A, Const) && !isa(C, Const) && add!(A.dval, C.dval, conj(αval)) + !isa(A, Const) && !isa(C, Const) && project_add!(A.dval, C.dval, conj(αval)) !isa(C, Const) && scale!(C.dval, conj(βval)) return (nothing, nothing, Δα, Δβ) end @@ -215,15 +224,6 @@ function EnzymeRules.forward( end end -function project_add!(C, A, α) - TC = Base.promote_op(+, scalartype(A), scalartype(α)) - return if !(TC <: Real) && scalartype(C) <: Real - add!(C, real(add!(zerovector(C, TC), A, α))) - else - add!(C, A, α) - end -end - function EnzymeRules.augmented_primal( config::EnzymeRules.RevConfigWidth{1}, diff --git a/ext/VectorInterfaceMooncakeExt.jl b/ext/VectorInterfaceMooncakeExt.jl index 0286c57..00ad349 100644 --- a/ext/VectorInterfaceMooncakeExt.jl +++ b/ext/VectorInterfaceMooncakeExt.jl @@ -17,6 +17,15 @@ For example, we might compute a complex `dx` but only require the real part. project_scalar(x::Number, dx::Number) = oftype(x, dx) project_scalar(x::Real, dx::Complex) = project_scalar(x, real(dx)) +function project_add!(C, A, α) + TC = Base.promote_op(+, scalartype(A), scalartype(α)) + return if !(TC <: Real) && scalartype(C) <: Real + add!(C, real(add!(zerovector(C, TC), A, α))) + else + add!(C, A, α) + end +end + _needs_tangent(x) = _needs_tangent(typeof(x)) _needs_tangent(::Type{T}) where {T <: Number} = Mooncake.rdata_type(Mooncake.tangent_type(T)) !== NoRData @@ -72,7 +81,7 @@ function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual, A_ΔA::CoDual function scale_pullback(::NoRData) copy!(C, C_cache) - add!(ΔA, ΔC, conj(α)) + project_add!(ΔA, ΔC, conj(α)) Δαr = _needs_tangent(α) ? project_scalar(α, inner(A, ΔC)) : NoRData() zerovector!(ΔC) return NoRData(), NoRData(), NoRData(), Δαr @@ -114,7 +123,7 @@ function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual, A_ΔA::CoDual, Δαr = _needs_tangent(α) ? project_scalar(α, inner(A, ΔC)) : NoRData() Δβr = _needs_tangent(β) ? project_scalar(β, inner(C, ΔC)) : NoRData() - add!(ΔA, ΔC, conj(α)) + project_add!(ΔA, ΔC, conj(α)) scale!(ΔC, conj(β)) return NoRData(), NoRData(), NoRData(), Δαr, Δβr @@ -140,15 +149,6 @@ end # inner # ----- -function project_add!(C, A, α) - TC = Base.promote_op(+, scalartype(A), scalartype(α)) - return if !(TC <: Real) && scalartype(C) <: Real - add!(C, real(add!(zerovector(C, TC), A, α))) - else - add!(C, A, α) - end -end - @is_primitive DefaultCtx Tuple{typeof(inner), AbstractArray, AbstractArray} function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual, B_ΔB::CoDual) diff --git a/test/enzyme.jl b/test/enzyme.jl index cb1d154..fbc9f5a 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -61,13 +61,14 @@ eltypes = (Float64, ComplexF64) end end -@testset "add ($T)" for T in eltypes +@testset "add ($Tx, $Ty)" for Ty in eltypes, Tx in eltypes n = 12 - atol = rtol = n * precision(T) + atol = rtol = n * max(precision(Tx), precision(Ty)) + T = Base.promote_op(+, Tx, Ty) # Vector - x = randn(T, n) - y = randn(T, n) + x = randn(Tx, n) + y = randn(Ty, n) α = randn(T) β = randn(T) for Tα in (Const, Active), Tβ in (Const, Active) diff --git a/test/mooncake.jl b/test/mooncake.jl index 92468a3..5ffa7a7 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -48,13 +48,13 @@ eltypes = (Float32, Float64, ComplexF64) Mooncake.TestUtils.test_rule(rng, scale!!, my, mx, α; atol, rtol, is_primitive = false) end -@testset "add pullbacks ($T)" for T in eltypes +@testset "add ($Tx, $Ty)" for Ty in eltypes, Tx in eltypes n = 12 - atol = rtol = n * precision(T) - + atol = rtol = n * max(precision(Tx), precision(Ty)) + T = Base.promote_op(+, Tx, Ty) # Vector - x = randn(T, n) - y = randn(T, n) + x = randn(Tx, n) + y = randn(Ty, n) α = randn(T) β = randn(T) Mooncake.TestUtils.test_rule(rng, add, y, x, α, β; atol, rtol, is_primitive = false)