diff --git a/Project.toml b/Project.toml index c85cb0d..68d4325 100644 --- a/Project.toml +++ b/Project.toml @@ -1,13 +1,12 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.4.1" +version = "1.4.2" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -47,14 +46,13 @@ MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"] [compat] AMDGPU = "0.9.6, 1" -Adapt = "4" +Adapt = "4.1" CUDA = "5.2" ChainRulesCore = "1.23" Compat = "4.15" FillArrays = "1" Functors = "0.4.8" GPUArrays = "10, 11" -LinearAlgebra = "1.10" MLUtils = "0.4.4" Metal = "1" Preferences = "1.4" diff --git a/ext/MLDataDevicesChainRulesCoreExt.jl b/ext/MLDataDevicesChainRulesCoreExt.jl index 6a770b8..518ff20 100644 --- a/ext/MLDataDevicesChainRulesCoreExt.jl +++ b/ext/MLDataDevicesChainRulesCoreExt.jl @@ -1,24 +1,27 @@ module MLDataDevicesChainRulesCoreExt using Adapt: Adapt -using ChainRulesCore: ChainRulesCore, NoTangent, @non_differentiable +using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, @non_differentiable using MLDataDevices: AbstractDevice, UnknownDevice, get_device, get_device_type @non_differentiable get_device(::Any) @non_differentiable get_device_type(::Any) -function ChainRulesCore.rrule( - ::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray) - ∇adapt_storage = let dev = get_device(x) - if dev === nothing || dev isa UnknownDevice +function ChainRulesCore.rrule(::typeof(Adapt.adapt), to::AbstractDevice, x::AbstractArray) + dev = get_device(x) + y = Adapt.adapt_storage(to, x) + if dev === nothing || dev isa UnknownDevice + dev isa UnknownDevice && @warn "`get_device(::$(typeof(x)))` returned `$(dev)`." maxlog=1 - Δ -> (NoTangent(), NoTangent(), Δ) - else - Δ -> (NoTangent(), NoTangent(), dev(Δ)) + ∇adapt_storage_unknown = Δ -> (NoTangent(), NoTangent(), Δ) + return y, ∇adapt_storage_unknown + else + ∇adapt_storage = let dev = dev, x = x + Δ -> (NoTangent(), NoTangent(), ProjectTo(x)(dev(Δ))) end + return Adapt.adapt_storage(to, x), ∇adapt_storage end - return Adapt.adapt_storage(to, x), ∇adapt_storage end end diff --git a/src/MLDataDevices.jl b/src/MLDataDevices.jl index c837887..108d8bf 100644 --- a/src/MLDataDevices.jl +++ b/src/MLDataDevices.jl @@ -5,7 +5,6 @@ using Functors: Functors, fleaves using Preferences: @delete_preferences!, @load_preference, @set_preferences! using Random: AbstractRNG, Random using Compat: @compat -using LinearAlgebra: Transpose, Adjoint abstract type AbstractDevice <: Function end abstract type AbstractCPUDevice <: AbstractDevice end diff --git a/src/public.jl b/src/public.jl index 104a424..6440ddb 100644 --- a/src/public.jl +++ b/src/public.jl @@ -342,8 +342,10 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :XLA) ldev = Symbol(dev, :Device) @eval begin function (D::$(ldev))(x::AbstractArray{T}) where {T} - return (isbitstype(T) || Internal.special_aos(x)) ? Adapt.adapt(D, x) : - map(D, x) + if isbitstype(T) || Internal.special_aos(x) || x isa Adapt.WrappedArray + return Adapt.adapt(D, x) + end + return map(D, x) end (D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x) function (D::$(ldev))(x) @@ -373,14 +375,6 @@ for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice, XLADevice) end end -Adapt.adapt_storage(::CPUDevice, x::AbstractRange) = x -Adapt.adapt_storage(::XLADevice, x::AbstractRange) = x -# Prevent Ambiguity -for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice, - CUDADevice{Nothing}, MetalDevice, oneAPIDevice) - @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x)) -end - """ isleaf(x) -> Bool @@ -399,4 +393,4 @@ If `MLDataDevices.isleaf(x::T)` is not defined, then it will fall back to `Funct isleaf(x) = Functors.isleaf(x) isleaf(::AbstractArray{T}) where {T} = isbitstype(T) -isleaf(::Union{Transpose, Adjoint, PermutedDimsArray}) = false +isleaf(::Adapt.WrappedArray) = false diff --git a/test/amdgpu_tests.jl b/test/amdgpu_tests.jl index 41a8797..a771ada 100644 --- a/test/amdgpu_tests.jl +++ b/test/amdgpu_tests.jl @@ -53,7 +53,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.mixed[1] isa Float32 @test ps_xpu.mixed[2] isa Float64 @test ps_xpu.mixed[3] isa aType - @test ps_xpu.range isa aType + @test ps_xpu.range isa AbstractRange @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType @@ -83,7 +83,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.mixed[1] isa Float32 @test ps_cpu.mixed[2] isa Float64 @test ps_cpu.mixed[3] isa Array - @test ps_cpu.range isa Array + @test ps_cpu.range isa AbstractRange @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG diff --git a/test/cuda_tests.jl b/test/cuda_tests.jl index 1f95831..2fce480 100644 --- a/test/cuda_tests.jl +++ b/test/cuda_tests.jl @@ -52,7 +52,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.mixed[1] isa Float32 @test ps_xpu.mixed[2] isa Float64 @test ps_xpu.mixed[3] isa aType - @test ps_xpu.range isa aType + @test ps_xpu.range isa AbstractRange @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType @@ -82,7 +82,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.mixed[1] isa Float32 @test ps_cpu.mixed[2] isa Float64 @test ps_cpu.mixed[3] isa Array - @test ps_cpu.range isa Array + @test ps_cpu.range isa AbstractRange @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG diff --git a/test/metal_tests.jl b/test/metal_tests.jl index aeb596a..2bc8845 100644 --- a/test/metal_tests.jl +++ b/test/metal_tests.jl @@ -51,7 +51,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.mixed[1] isa Float32 @test ps_xpu.mixed[2] isa Float64 @test ps_xpu.mixed[3] isa aType - @test ps_xpu.range isa aType + @test ps_xpu.range isa AbstractRange @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType @@ -81,7 +81,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.mixed[1] isa Float32 @test ps_cpu.mixed[2] isa Float64 @test ps_cpu.mixed[3] isa Array - @test ps_cpu.range isa Array + @test ps_cpu.range isa AbstractRange @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG diff --git a/test/misc_tests.jl b/test/misc_tests.jl index 9bec386..28275d3 100644 --- a/test/misc_tests.jl +++ b/test/misc_tests.jl @@ -50,17 +50,17 @@ end @testset "CRC Tests" begin dev = cpu_device() # Other devices don't work with FiniteDifferences.jl - test_rrule(Adapt.adapt_storage, dev, randn(Float64, 10); check_inferred=true) + test_rrule(Adapt.adapt, dev, randn(Float64, 10); check_inferred=true) gdev = gpu_device() if !(gdev isa MetalDevice) # On intel devices causes problems x = randn(10) - ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt_storage, gdev, x) + ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt, gdev, x) @test ∂dev === nothing @test ∂x ≈ ones(10) x = randn(10) |> gdev - ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt_storage, cpu_device(), x) + ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt, cpu_device(), x) @test ∂dev === nothing @test ∂x ≈ gdev(ones(10)) @test get_device(∂x) isa parameterless_type(typeof(gdev)) @@ -181,7 +181,6 @@ end end @testset "shared parameters" begin - # from x = rand(1) m = (; a=x, b=x') count = Ref(0) @@ -199,7 +198,7 @@ end y::Float64 end - for x in [1.0, 'a', BitsType(1, 2.0)] + @testset for x in [1.0, 'a', BitsType(1, 2.0)] @test MLDataDevices.isleaf([x]) @test !MLDataDevices.isleaf([x]') @test !MLDataDevices.isleaf(transpose([x])) @@ -207,3 +206,16 @@ end end end end + +@testset "Zygote.gradient(wrapped arrays)" begin + using Zygote + + x = rand(4, 4) + cdev = cpu_device() + + @test only(Zygote.gradient(x -> sum(abs2, cdev(x)), x')) isa Matrix{Float64} + + gdev = gpu_device() + + @test only(Zygote.gradient(x -> sum(abs2, gdev(x)), x')) isa Matrix{Float64} +end diff --git a/test/oneapi_tests.jl b/test/oneapi_tests.jl index 8bb6026..2169869 100644 --- a/test/oneapi_tests.jl +++ b/test/oneapi_tests.jl @@ -51,7 +51,7 @@ using FillArrays, Zygote # Extensions @test ps_xpu.mixed[1] isa Float32 @test ps_xpu.mixed[2] isa Float64 @test ps_xpu.mixed[3] isa aType - @test ps_xpu.range isa aType + @test ps_xpu.range isa AbstractRange @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType @@ -81,7 +81,7 @@ using FillArrays, Zygote # Extensions @test ps_cpu.mixed[1] isa Float32 @test ps_cpu.mixed[2] isa Float64 @test ps_cpu.mixed[3] isa Array - @test ps_cpu.range isa Array + @test ps_cpu.range isa AbstractRange @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG