diff --git a/test/mooncake/qr.jl b/test/mooncake/qr.jl index bbb9a8d1..c4f0df9e 100644 --- a/test/mooncake/qr.jl +++ b/test/mooncake/qr.jl @@ -20,4 +20,11 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.test_mooncake_qr(AT, (m, m); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end end + if T ∈ BLASFloats && CUDA.functional() + TestSuite.test_mooncake_qr(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + if m == n + AT = Diagonal{T, CuVector{T}} + TestSuite.test_mooncake_qr(AT, (m, m); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + end + end end