Skip to content

Commit

Permalink
Add graphsummary and make it the defaul show for CompGraphs
Browse files Browse the repository at this point in the history
  • Loading branch information
DrChainsaw committed Jun 30, 2024
1 parent 894daa9 commit 2a921a5
Show file tree
Hide file tree
Showing 12 changed files with 269 additions and 35 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

Expand Down
6 changes: 3 additions & 3 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ using Documenter, Literate, NaiveNASlib, NaiveNASlib.Advanced, NaiveNASlib.Exten

const nndir = joinpath(dirname(pathof(NaiveNASlib)), "..")

function literate_example(sourcefile; rootdir=nndir, sourcedir = "test/examples", destdir="docs/src/examples")
fullpath = Literate.markdown(joinpath(rootdir, sourcedir, sourcefile), joinpath(rootdir, destdir); flavor=Literate.DocumenterFlavor(), mdstrings=true, codefence="````julia" => "````")
function literate_example(sourcefile; rootdir=nndir, sourcedir = "test/examples", destdir="docs/src/examples", kwargs...)
fullpath = Literate.markdown(joinpath(rootdir, sourcedir, sourcefile), joinpath(rootdir, destdir); flavor=Literate.DocumenterFlavor(), mdstrings=true, kwargs...)
dirs = splitpath(fullpath)
srcind = findfirst(==("src"), dirs)
joinpath(dirs[srcind+1:end]...)
end

quicktutorial = literate_example("quicktutorial.jl")
advancedtutorial = literate_example("advancedtutorial.jl")
advancedtutorial = literate_example("advancedtutorial.jl"; codefence="````julia" => "````")

makedocs( sitename="NaiveNASlib",
root = joinpath(nndir, "docs"),
Expand Down
1 change: 1 addition & 0 deletions docs/src/reference/simple/graph.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ outputs(::CompGraph)
vertices
nvertices
findvertices
graphsummary
```
3 changes: 2 additions & 1 deletion src/NaiveNASlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ using JuMP: @variable, @constraint, @objective, @expression, MOI, MOI.INFEASIBLE
import HiGHS
import Functors
using Functors: @functor, functor
import PrettyTables

# Computation graph
export CompGraph, nvertices, vertices, findvertices, inputs, outputs, name
export CompGraph, nvertices, vertices, findvertices, inputs, outputs, name, graphsummary

# Vertex size operations
export nin, nout, Δnin!, Δnout!, Δsize!, relaxed
Expand Down
2 changes: 1 addition & 1 deletion src/api/Extend.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ using Reexport: @reexport
@reexport using ..NaiveNASlib: AbstractAlignSizeStrategy
@reexport using ..NaiveNASlib: AbstractConnectStrategy

@reexport using ..NaiveNASlib: base, parselect, vertex
@reexport using ..NaiveNASlib: base, parselect, vertex, op

end
21 changes: 18 additions & 3 deletions src/api/vertex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,12 @@ julia> name(v)
"""
invariantvertex(args...; traitdecoration=identity) = vertex(traitdecoration(SizeInvariant()), args...)

struct Concat{D}
dims::D
end
(c::Concat)(x...) = cat(x...; dims=c.dims)
Base.show(io::IO, c::Concat) = print(io, "cat(x..., dims=", c.dims, ')')

"""
conc(v::AbstractVertex, vs::AbstractVertex...; dims, traitdecoration=identity, outwrap=identity)
conc(vname::AbstractString, v::AbstractVertex, vs::AbstractVertex...; dims, traitdecoration=identity, outwrap=identity)
Expand Down Expand Up @@ -201,10 +207,10 @@ julia> v([1], [2, 3], [4, 5, 6])
```
"""
function conc(v::AbstractVertex, vs::AbstractVertex...; dims, traitdecoration=identity, outwrap=identity)
vertex(traitdecoration(SizeStack()), outwrap((x...) -> cat(x..., dims=dims)), v, vs...)
vertex(traitdecoration(SizeStack()), outwrap(Concat(dims)), v, vs...)
end
function conc(vname::AbstractString, v::AbstractVertex, vs::AbstractVertex...; dims, traitdecoration=identity, outwrap=identity)
vertex(traitdecoration(SizeStack()), vname, outwrap((x...) -> cat(x..., dims=dims)), v, vs...)
vertex(traitdecoration(SizeStack()), vname, outwrap(Concat(dims)), v, vs...)
end


Expand Down Expand Up @@ -238,10 +244,19 @@ Shortcut for [`VertexConf(;outwrap=o)`](@ref).
outwrapconf(o) = VertexConf(outwrap=o)
VertexConf(;traitdecoration = identity, outwrap = identity)= VertexConf(traitdecoration, outwrap)

struct ElementWiseOp{F}
op::F
end
(e::ElementWiseOp)(x...) = e.op.(x...)
function Base.show(io::IO, e::ElementWiseOp)
show(io, e.op)
print(io, " (element wise)")
end

# Common wiring for all elementwise operations
function elemwise(op, conf::VertexConf, vs::AbstractVertex...)
all(vi -> nout(vi) == nout(vs[1]), vs) || throw(DimensionMismatch("nout of all vertices input to elementwise vertex must be equal! Got $(nout.(vs))"))
invariantvertex(conf.outwrap((x...) -> op.(x...)), vs...; conf.traitdecoration)
invariantvertex(conf.outwrap(ElementWiseOp(op)), vs...; conf.traitdecoration)
end


Expand Down
2 changes: 1 addition & 1 deletion src/compgraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ julia> vertices(graph)
vertices(g::CompGraph{<:Any, <:Tuple}) = unique(mapfoldl(ancestors, vcat, outputs(g)))
vertices(g::CompGraph{<:Any, <:AbstractVertex}) = ancestors(g.outputs)

## Non-public stuff to compute the CompGraph in a Zygote (and hopefully generally reverse-AD friendly) manner
## Non-public stuff to compute the CompGraph in a Zygote (and hopefully generally reverse-AD) friendly manner

compute_graph(memo, v::AbstractVertex) = last(output_with_memo(memo, v))
compute_graph(memo, vs::Tuple) = last(_calc_outs(memo, vs))
Expand Down
2 changes: 2 additions & 0 deletions src/mutation/vertex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ Return the vertex wrapped in `v` (if any).
"""
function base(::AbstractVertex) end

op(v::AbstractVertex) = op(base(v))

"""
OutputsVertex
Expand Down
111 changes: 111 additions & 0 deletions src/prettyprint.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,114 @@
Base.show(g::CompGraph, args...; kwargs...) = show(stdout, g, args...; kwargs...)
function Base.show(io::IO, g::CompGraph, args...; kwargs...)
# Don't print the summary table if we are printing some iterable (as indicated by presence of :SHOWN_SET)
haskey(io, :SHOWN_SET) && return print(io, "CompGraph(", nvertices(g)," vertices)")
graphsummary(io, g, args...; title="CompGraph with graphsummary:", kwargs...)
end

"""
graphsummary([io], graph, extracolumns...; [inputhl], [outputhl], kwargs...)
Prints a summary table of `graph` to `io` using `PrettyTables.pretty_table`.
Extra columns can be added to the table by providing any number of `extracolumns` which can be one of the following:
* a function (or any callable object) which takes a vertex as input and returns the column content
* a `Pair` where the first element is the column name and the other element is what previous bullet describes
The keyword arguments `inputhl` (default `crayon"fg:black bg:249"`) and `outputhl` (default `inputhl`) can be used
to set the highlighting of the inputs and outputs to `graph` respectively. If set to `nothing` no special highlighting
will be used.
All other keyword arguments are forwarded to `PrettyTables.pretty_table`. Note that this allows for overriding the
default formatting, alignment and highlighting.
!!! warning "API Stability"
While this function is part of the public API for natural reasons, the exact shape of its output shall not be considered stable.
`Base.show` for `CompGraph`s just forwards all arguments and keyword arguments to this method. This might change in the future.
### Examples
```jldoctest
julia> using NaiveNASlib
julia> g = let
v1 = "v1" >> inputvertex("in1", 1) + inputvertex("in2", 1)
v2 = invariantvertex("v2", sin, v1)
v3 = conc("v3", v1, v2; dims=1)
CompGraph(inputs(v1), v3)
end;
julia> graphsummary(g)
┌────────────────┬───────────┬────────────────┬───────────────────┐
│ Graph Position │ Vertex Nr │ Input Vertices │ Op │
├────────────────┼───────────┼────────────────┼───────────────────┤
│ Input │ 1 │ │ │
│ Input │ 2 │ │ │
│ Hidden │ 3 │ 1,2 │ + (element wise) │
│ Hidden │ 4 │ 3 │ sin │
│ Output │ 5 │ 3,4 │ cat(x..., dims=1) │
└────────────────┴───────────┴────────────────┴───────────────────┘
julia> graphsummary(g, name, "input sizes" => nin, "output sizes" => nout)
┌────────────────┬───────────┬────────────────┬───────────────────┬──────┬─────────────┬──────────────┐
│ Graph Position │ Vertex Nr │ Input Vertices │ Op │ Name │ input sizes │ output sizes │
├────────────────┼───────────┼────────────────┼───────────────────┼──────┼─────────────┼──────────────┤
│ Input │ 1 │ │ │ in1 │ │ 1 │
│ Input │ 2 │ │ │ in2 │ │ 1 │
│ Hidden │ 3 │ 1,2 │ + (element wise) │ v1 │ 1,1 │ 1 │
│ Hidden │ 4 │ 3 │ sin │ v2 │ 1 │ 1 │
│ Output │ 5 │ 3,4 │ cat(x..., dims=1) │ v3 │ 1,1 │ 2 │
└────────────────┴───────────┴────────────────┴───────────────────┴──────┴─────────────┴──────────────┘
```
"""
graphsummary(g::CompGraph, extracolumns...; kwargs...) = graphsummary(stdout, g, extracolumns...; kwargs...)
function graphsummary(io, g::CompGraph, extracolumns...;
inputhl=PrettyTables.crayon"fg:black bg:249",
outputhl=inputhl,
kwargs...)
t = summarytable(g, extracolumns...)

# Default formatting
arraytostr = (x, args...) -> x isa AbstractVector ? join(x, ",") : isnothing(x) ? "" : x
rowhighligts = PrettyTables.Highlighter(Returns(true), function(h, x, i, j)
!isnothing(inputhl) && i <= length(inputs(g)) && return inputhl
!isnothing(outputhl) && i > length(t[1]) - length(outputs(g)) && return outputhl
length(t[1]) > 7 && iseven(i - isnothing(inputhl) * length(inputs(g))) && return PrettyTables.crayon"fg:white bold bg:dark_gray"
PrettyTables.crayon"default"
end)

PrettyTables.pretty_table(io, t;
show_subheader=false,
formatters=arraytostr,
highlighters = rowhighligts,
alignment = :l,
kwargs...)
end

function summarytable(g::CompGraph, extracols...)
vs = vertices(g)

inds = sort(collect(eachindex(vs)); by = function(i)
vs[i] in inputs(g) && return i - length(vs)
vs[i] in outputs(g) && return i + length(vs)
i
end)

vs_roworder = vs[inds]

NamedTuple((
Symbol("Graph Position") => map(v -> v in inputs(g) ? :Input : v in outputs(g) ? :Output : :Hidden, vs_roworder),
Symbol("Vertex Nr") => inds,
Symbol("Input Vertices") => map(v -> something.(indexin(inputs(v), vs), -1), vs_roworder),
:Op => op.(vs_roworder),
map(c -> _createextracol(c, vs_roworder), extracols)...
))
end

_createextracol(f, vs) = Symbol(uppercasefirst(string(f))) => f.(vs)
_createextracol(p::Pair, vs) = Symbol(first(p)) => last(p).(vs)

## Other stuff related to printing long arrays of numbers assuming patterns which often happen
## when mutating, typically long streaks of -1 and ascending integers.

compressed_string(x) = string(x)
struct RangeState
Expand Down
6 changes: 5 additions & 1 deletion src/vertex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,4 +145,8 @@ Will return a generic string describing `v` if no name has been given to `v`.
Note that names in a graph don't have to be unique.
"""
name(v::AbstractVertex) = string(nameof(typeof(v)))
name(v::InputVertex) = v.name
name(v::InputVertex) = v.name

op(::InputVertex) = nothing
op(v::CompVertex) = op(v.computation)
op(f) = f
Loading

0 comments on commit 2a921a5

Please sign in to comment.