Skip to content

Commit

Permalink
Merge pull request #180 from JuliaAI/dev
Browse files Browse the repository at this point in the history
For a 1.9.1 release
  • Loading branch information
ablaom authored Aug 15, 2023
2 parents fe9492d + d33265f commit 776852d
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJModelInterface"
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
authors = ["Thibaut Lienart and Anthony Blaom"]
version = "1.9.0"
version = "1.9.1"

[deps]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
6 changes: 3 additions & 3 deletions src/parameter_inspection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ function params(m, ::Val{true})
return NamedTuple{fields}(Tuple([params(getfield(m, field)) for field in fields]))
end

isamodel(::Any) = false
isamodel(::Model) = true
isnotaleaf(::Any) = false
isnotaleaf(m::Model) = length(propertynames(m)) > 0

"""
flat_params(m::Model)
Expand All @@ -53,7 +53,7 @@ not a hard requirement.
parallel = true,)
"""
flat_params(m; prefix="") = flat_params(m, Val(isamodel(m)); prefix=prefix)
flat_params(m; prefix="") = flat_params(m, Val(isnotaleaf(m)); prefix=prefix)
flat_params(m, ::Val{false}; prefix="") = NamedTuple{(Symbol(prefix),), Tuple{Any}}((m,))
function flat_params(m, ::Val{true}; prefix="")
fields = propertynames(m)
Expand Down
16 changes: 9 additions & 7 deletions test/parameter_inspection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ end
end

struct ChildModel <: Model
x::Int
y::String
r::Int
s
end

struct ParentModel <: Model
Expand All @@ -46,18 +46,20 @@ struct ParentModel <: Model
second_child::ChildModel
end

struct Missy <: Model end

@testset "flat_params method" begin

m = ParentModel(1, "parent", ChildModel(2, "child1"),
ChildModel(3, "child2"))
ChildModel(3, Missy()))

@test MLJModelInterface.flat_params(m) == (
x = 1,
y = "parent",
first_child__x = 2,
first_child__y = "child1",
second_child__x = 3,
second_child__y = "child2"
first_child__r = 2,
first_child__s = "child1",
second_child__r = 3,
second_child__s = Missy()
)
end
true

0 comments on commit 776852d

Please sign in to comment.