Skip to content

Commit b06250e

Browse files
Merge pull request SciML#61 from ChrisRackauckas-Claude/interface-check-20251230-000251
Add type validation and interface compatibility tests
2 parents 4003555 + 1fc6813 commit b06250e

File tree

4 files changed

+169
-0
lines changed

4 files changed

+169
-0
lines changed

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,17 @@ MATLABDiffEq.jl is simply a solver on the DiffEq common interface, so for detail
2626
However, the only options implemented are those for error calculations
2727
(`timeseries_errors`), `saveat`, and tolerances.
2828

29+
### Type Requirements
30+
31+
Since this package sends data to MATLAB for computation, it only supports types
32+
that MATLAB can handle:
33+
34+
- **Supported types:** `Float64`, integers (`Int64`, etc.), and `Complex{Float64}`
35+
- **Not supported:** `BigFloat`, `Float32`, GPU arrays (`CuArray`, `JLArray`), or other custom array types
36+
37+
If you need arbitrary precision or GPU computing, use the native Julia solvers
38+
from [DifferentialEquations.jl](https://docs.sciml.ai/DiffEqDocs/stable/) instead.
39+
2940
Note that the algorithms are defined to have the same name as the MATLAB algorithms,
3041
but are not exported. Thus to use `ode45`, you would specify the algorithm as
3142
`MATLABDiffEq.ode45()`.

src/MATLABDiffEq.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,39 @@ using Reexport
55
using MATLAB, ModelingToolkit
66
using PrecompileTools
77

8+
# MATLAB only supports Float64 arrays. Check if a type is MATLAB-compatible.
9+
_is_matlab_compatible_eltype(::Type{Float64}) = true
10+
_is_matlab_compatible_eltype(::Type{<:Integer}) = true # MATLAB can convert integers
11+
_is_matlab_compatible_eltype(::Type{<:Complex{Float64}}) = true
12+
_is_matlab_compatible_eltype(::Type) = false
13+
14+
function _check_matlab_compatible(u0, tspan)
15+
T = eltype(u0)
16+
if !_is_matlab_compatible_eltype(T)
17+
throw(ArgumentError(
18+
"MATLABDiffEq.jl requires Float64-compatible element types. " *
19+
"Got eltype(u0) = $T. MATLAB does not support arbitrary precision " *
20+
"(BigFloat) or GPU arrays (JLArrays, CuArrays). Please convert your " *
21+
"initial conditions to Float64: u0 = Float64.(u0)"
22+
))
23+
end
24+
tT = eltype(tspan)
25+
if !_is_matlab_compatible_eltype(tT)
26+
throw(ArgumentError(
27+
"MATLABDiffEq.jl requires Float64-compatible time span types. " *
28+
"Got eltype(tspan) = $tT. MATLAB does not support arbitrary precision " *
29+
"(BigFloat). Please use Float64 for tspan: tspan = Float64.(tspan)"
30+
))
31+
end
32+
# Check that the array type itself is a standard Julia array
33+
if !(u0 isa Array || u0 isa Number)
34+
@warn "MATLABDiffEq.jl works best with standard Julia Arrays. " *
35+
"Got $(typeof(u0)). The array will be converted to a standard Array " *
36+
"before being sent to MATLAB."
37+
end
38+
return nothing
39+
end
40+
841
# Handle ModelingToolkit API changes: states -> unknowns
942
if isdefined(ModelingToolkit, :unknowns)
1043
const mtk_states = ModelingToolkit.unknowns
@@ -35,6 +68,9 @@ function DiffEqBase.__solve(
3568
callback = nothing,
3669
kwargs...
3770
) where {uType, tupType, isinplace, AlgType <: MATLABAlgorithm}
71+
# Validate that input types are MATLAB-compatible
72+
_check_matlab_compatible(prob.u0, prob.tspan)
73+
3874
tType = eltype(tupType)
3975

4076
if prob.tspan[end] - prob.tspan[1] < tType(0)
@@ -179,6 +215,14 @@ end
179215

180216
# Also precompile with missing keys (common case)
181217
_ = buildDEStats(Dict{String, Any}())
218+
219+
# Precompile type compatibility checks
220+
_ = _is_matlab_compatible_eltype(Float64)
221+
_ = _is_matlab_compatible_eltype(Int64)
222+
_ = _is_matlab_compatible_eltype(Complex{Float64})
223+
_ = _is_matlab_compatible_eltype(BigFloat)
224+
_ = _check_matlab_compatible([1.0, 2.0], (0.0, 1.0))
225+
_ = _check_matlab_compatible(1.0, (0.0, 1.0))
182226
end
183227
end
184228

test/interface_tests.jl

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# Interface compatibility tests for MATLABDiffEq.jl
2+
# These tests verify type checking and interface compliance
3+
4+
using Test
5+
6+
@testset "Interface Compatibility" begin
7+
@testset "MATLAB type compatibility checks" begin
8+
# Test _is_matlab_compatible_eltype function
9+
@test MATLABDiffEq._is_matlab_compatible_eltype(Float64) == true
10+
@test MATLABDiffEq._is_matlab_compatible_eltype(Float32) == false
11+
@test MATLABDiffEq._is_matlab_compatible_eltype(Int64) == true
12+
@test MATLABDiffEq._is_matlab_compatible_eltype(Int32) == true
13+
@test MATLABDiffEq._is_matlab_compatible_eltype(Complex{Float64}) == true
14+
@test MATLABDiffEq._is_matlab_compatible_eltype(Complex{Float32}) == false
15+
@test MATLABDiffEq._is_matlab_compatible_eltype(BigFloat) == false
16+
@test MATLABDiffEq._is_matlab_compatible_eltype(BigInt) == false
17+
@test MATLABDiffEq._is_matlab_compatible_eltype(Rational{Int}) == false
18+
end
19+
20+
@testset "_check_matlab_compatible validation" begin
21+
# Valid Float64 inputs should pass
22+
@test MATLABDiffEq._check_matlab_compatible([1.0, 2.0], (0.0, 1.0)) === nothing
23+
@test MATLABDiffEq._check_matlab_compatible(1.0, (0.0, 1.0)) === nothing
24+
@test MATLABDiffEq._check_matlab_compatible([1, 2, 3], (0, 10)) === nothing # Integers OK
25+
26+
# Complex Float64 should pass
27+
@test MATLABDiffEq._check_matlab_compatible([1.0 + 2.0im], (0.0, 1.0)) === nothing
28+
29+
# BigFloat u0 should throw ArgumentError
30+
@test_throws ArgumentError MATLABDiffEq._check_matlab_compatible(
31+
BigFloat[1.0, 2.0], (0.0, 1.0)
32+
)
33+
34+
# BigFloat tspan should throw ArgumentError
35+
@test_throws ArgumentError MATLABDiffEq._check_matlab_compatible(
36+
[1.0, 2.0], (BigFloat(0.0), BigFloat(1.0))
37+
)
38+
39+
# Float32 should throw ArgumentError (MATLAB expects Float64)
40+
@test_throws ArgumentError MATLABDiffEq._check_matlab_compatible(
41+
Float32[1.0, 2.0], (0.0, 1.0)
42+
)
43+
end
44+
45+
@testset "Error messages are helpful" begin
46+
# Test that error messages contain useful information
47+
try
48+
MATLABDiffEq._check_matlab_compatible(BigFloat[1.0], (0.0, 1.0))
49+
@test false # Should not reach here
50+
catch e
51+
@test e isa ArgumentError
52+
@test occursin("BigFloat", e.msg)
53+
@test occursin("Float64", e.msg)
54+
@test occursin("MATLABDiffEq", e.msg)
55+
end
56+
57+
try
58+
MATLABDiffEq._check_matlab_compatible([1.0], (BigFloat(0.0), BigFloat(1.0)))
59+
@test false # Should not reach here
60+
catch e
61+
@test e isa ArgumentError
62+
@test occursin("tspan", lowercase(e.msg))
63+
end
64+
end
65+
66+
@testset "buildDEStats is type-generic" begin
67+
# Test that buildDEStats works with different Dict types
68+
stats1 = Dict{String, Any}("nfevals" => 100, "nsteps" => 50)
69+
result1 = MATLABDiffEq.buildDEStats(stats1)
70+
@test result1.nf == 100
71+
@test result1.naccept == 50
72+
73+
# Test with empty dict
74+
stats2 = Dict{String, Any}()
75+
result2 = MATLABDiffEq.buildDEStats(stats2)
76+
@test result2.nf == 0
77+
@test result2.naccept == 0
78+
79+
# Test with all fields
80+
stats3 = Dict{String, Any}(
81+
"nfevals" => 200,
82+
"nfailed" => 10,
83+
"nsteps" => 190,
84+
"nsolves" => 100,
85+
"npds" => 20,
86+
"ndecomps" => 15
87+
)
88+
result3 = MATLABDiffEq.buildDEStats(stats3)
89+
@test result3.nf == 200
90+
@test result3.nreject == 10
91+
@test result3.naccept == 190
92+
@test result3.nsolve == 100
93+
@test result3.njacs == 20
94+
@test result3.nw == 15
95+
end
96+
97+
@testset "Algorithm structs instantiation" begin
98+
# Test that all algorithm structs can be instantiated
99+
@test MATLABDiffEq.ode23() isa MATLABDiffEq.MATLABAlgorithm
100+
@test MATLABDiffEq.ode45() isa MATLABDiffEq.MATLABAlgorithm
101+
@test MATLABDiffEq.ode113() isa MATLABDiffEq.MATLABAlgorithm
102+
@test MATLABDiffEq.ode23s() isa MATLABDiffEq.MATLABAlgorithm
103+
@test MATLABDiffEq.ode23t() isa MATLABDiffEq.MATLABAlgorithm
104+
@test MATLABDiffEq.ode23tb() isa MATLABDiffEq.MATLABAlgorithm
105+
@test MATLABDiffEq.ode15s() isa MATLABDiffEq.MATLABAlgorithm
106+
@test MATLABDiffEq.ode15i() isa MATLABDiffEq.MATLABAlgorithm
107+
end
108+
end

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
using DiffEqBase, MATLABDiffEq, ParameterizedFunctions, Test
22

3+
# Interface tests - these test type validation without needing MATLAB runtime
4+
include("interface_tests.jl")
5+
6+
# The following tests require MATLAB runtime to be available
7+
# They test the actual ODE solving functionality
8+
39
f = @ode_def_bare LotkaVolterra begin
410
dx = a * x - b * x * y
511
dy = -c * y + d * x * y

0 commit comments

Comments
 (0)