diff --git a/Project.toml b/Project.toml index 748fe8748..0510ea2ce 100644 --- a/Project.toml +++ b/Project.toml @@ -26,6 +26,12 @@ TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" Unityper = "a7c27f48-0311-42f6-a7f8-2c11e75eb415" +[weakdeps] +LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" + +[extensions] +SymbolicUtilsLabelledArraysExt = "LabelledArrays" + [compat] AbstractTrees = "0.4" Bijections = "0.1.2" @@ -51,6 +57,7 @@ julia = "1.3" [extras] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -59,4 +66,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["BenchmarkTools", "Documenter", "Pkg", "PkgBenchmark", "Random", "ReferenceTests", "Test", "Zygote"] +test = ["BenchmarkTools", "Documenter", "LabelledArrays", "Pkg", "PkgBenchmark", "Random", "ReferenceTests", "Test", "Zygote"] diff --git a/ext/SymbolicUtilsLabelledArraysExt.jl b/ext/SymbolicUtilsLabelledArraysExt.jl new file mode 100644 index 000000000..9004729bf --- /dev/null +++ b/ext/SymbolicUtilsLabelledArraysExt.jl @@ -0,0 +1,25 @@ +module SymbolicUtilsLabelledArraysExt + +using LabelledArrays +using LabelledArrays.StaticArrays +using SymbolicUtils + +@inline function SymbolicUtils.Code.create_array(A::Type{<:SLArray}, T, nd::Val, d::Val{dims}, elems...) where {dims} + a = SymbolicUtils.Code.create_array(SArray, T, nd, d, elems...) + if nfields(dims) === ndims(A) + similar_type(A, eltype(a), Size(dims))(a) + else + a + end +end + +@inline function SymbolicUtils.Code.create_array(A::Type{<:LArray}, T, nd::Val, d::Val{dims}, elems...) where {dims} + data = SymbolicUtils.Code.create_array(Array, T, nd, d, elems...) + if nfields(dims) === ndims(A) + LArray{eltype(data),nfields(dims),typeof(data),LabelledArrays.symnames(A)}(data) + else + data + end +end + +end diff --git a/src/code.jl b/src/code.jl index 9eaac38e6..4128a39fd 100644 --- a/src/code.jl +++ b/src/code.jl @@ -1,6 +1,6 @@ module Code -using StaticArrays, LabelledArrays, SparseArrays, LinearAlgebra, NaNMath, SpecialFunctions +using StaticArrays, SparseArrays, LinearAlgebra, NaNMath, SpecialFunctions export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr, SetArray, MakeArray, MakeSparseArray, MakeTuple, AtIndex, @@ -578,25 +578,6 @@ end MArray{Tuple{dims...}, T}(elems...) end -## LabelledArrays -@inline function create_array(A::Type{<:SLArray}, T, nd::Val, d::Val{dims}, elems...) where {dims} - a = create_array(SArray, T, nd, d, elems...) - if nfields(dims) === ndims(A) - similar_type(A, eltype(a), Size(dims))(a) - else - a - end -end - -@inline function create_array(A::Type{<:LArray}, T, nd::Val, d::Val{dims}, elems...) where {dims} - data = create_array(Array, T, nd, d, elems...) - if nfields(dims) === ndims(A) - LArray{eltype(data),nfields(dims),typeof(data),LabelledArrays.symnames(A)}(data) - else - data - end -end - ## We use a separate type for Sparse Arrays to sidestep the need for ## iszero to be defined on the expression type @matchable struct MakeSparseArray{S<:AbstractSparseArray}