Commit

WIP
rbSparky committed Aug 2, 2024
1 parent 83b6b7e commit 786b200
Showing 1 changed file with 71 additions and 0 deletions.
71 changes: 71 additions & 0 deletions GNNLux/src/layers/conv.jl
@@ -628,3 +628,74 @@ function Base.show(io::IO, l::GINConv)
    print(io, ", $(l.ϵ)")
    print(io, ")")
end

@concrete struct NNConv <: GNNContainerLayer{(:nn,)}
    nn <: AbstractExplicitLayer
    aggr
    in_dims::Int
    out_dims::Int
    use_bias::Bool
    add_self_loops::Bool
    use_edge_weight::Bool
    init_weight
    init_bias
    σ
end

"""
function NNConv(ch::Pair{Int, Int}, σ = identity;
init_weight = glorot_uniform,
init_bias = zeros32,
use_bias::Bool = true,
add_self_loops::Bool = true,
use_edge_weight::Bool = false,
allow_fast_activation::Bool = true)
"""
# fix args order
function NNConv(ch::Pair{Int, Int}, nn, σ = identity;
aggr = +,
init_bias = zeros32,
use_bias::Bool = true,
init_weight = glorot_uniform,
add_self_loops::Bool = true,
use_edge_weight::Bool = false,
allow_fast_activation::Bool = true)
in_dims, out_dims = ch
σ = allow_fast_activation ? NNlib.fast_act(σ) : σ
return NNConv(nn, aggr, in_dims, out_dims, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias, σ)
end

function (l::NNConv)(g, x, edge_weight, ps, st)
    # `ps.nn` holds the inner network's parameters; since this container layer
    # has a single sublayer, `st` is the inner network's state itself.
    nn = StatefulLuxLayer{true}(l.nn, ps.nn, st)

    # The order of the entries below does not matter: GNNlib.nn_conv reads
    # the fields it needs (nn, aggr, weight, bias, σ) by name.
    m = (; nn, l.aggr, ps.weight, bias = _getbias(ps),
           l.add_self_loops, l.use_edge_weight, l.σ)
    y = GNNlib.nn_conv(m, g, x, edge_weight)
    stnew = _getstate(nn)
    return y, stnew
end

function LuxCore.initialparameters(rng::AbstractRNG, l::NNConv)
    weight = l.init_weight(rng, l.out_dims, l.in_dims)
    # The inner network's parameters are part of this container layer's `ps`.
    ps = (; nn = LuxCore.initialparameters(rng, l.nn), weight)
    if l.use_bias
        ps = (; ps..., bias = l.init_bias(rng, l.out_dims))
    end
    return ps
end

# The inner network's parameters count as well, since they are part of `ps`.
LuxCore.parameterlength(l::NNConv) = l.in_dims * l.out_dims + (l.use_bias ? l.out_dims : 0) + LuxCore.parameterlength(l.nn)
LuxCore.outputsize(d::NNConv) = (d.out_dims,)


function Base.show(io::IO, l::NNConv)
    print(io, "NNConv($(l.nn)")
    print(io, ", aggr=", l.aggr)
    l.σ == identity || print(io, ", ", l.σ)
    l.use_bias || print(io, ", use_bias=false")
    l.add_self_loops || print(io, ", add_self_loops=false")
    !l.use_edge_weight || print(io, ", use_edge_weight=true")
    print(io, ")")
end
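
For reference, a minimal usage sketch of the layer added above (not part of the commit). It assumes the GNNlib.nn_conv semantics — the inner network maps each edge's features to `out_dims * in_dims` entries that are reshaped into a per-edge weight matrix — and that GNNGraphs and Lux are available; the name `edge_nn` and the dimensions are illustrative only.

using Lux, LuxCore, Random
using GNNGraphs: rand_graph
using GNNLux

rng = Random.default_rng()

nin, nout, ein = 3, 4, 2                    # node-input, node-output, edge-feature dims
edge_nn = Dense(ein => nout * nin)          # one (nout × nin) weight matrix per edge
l = GNNLux.NNConv(nin => nout, edge_nn, relu; aggr = +)

ps, st = LuxCore.setup(rng, l)

g = rand_graph(5, 6)                        # 5 nodes, 6 edges
x = rand(Float32, nin, 5)                   # node features
e = rand(Float32, ein, 6)                   # edge features

y, st = l(g, x, e, ps, st)                  # y has size (nout, 5)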
