Skip to content

Commit

Permalink
added sageconv lux (#500)
Browse files Browse the repository at this point in the history
  • Loading branch information
rbSparky authored Oct 3, 2024
1 parent a034753 commit 3fe2c76
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 1 deletion.
2 changes: 1 addition & 1 deletion GNNLux/src/GNNLux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ export AGNNConv,
MEGNetConv,
NNConv,
ResGatedGraphConv,
# SAGEConv,
SAGEConv,
SGConv
# TAGConv,
# TransformerConv
Expand Down
48 changes: 48 additions & 0 deletions GNNLux/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -844,3 +844,51 @@ function Base.show(io::IO, l::ResGatedGraphConv)
l.use_bias || print(io, ", use_bias=false")
print(io, ")")
end

@concrete struct SAGEConv <: GNNLayer
in_dims::Int
out_dims::Int
use_bias::Bool
init_weight
init_bias
σ
aggr
end

function SAGEConv(ch::Pair{Int, Int}, σ = identity;
aggr = mean,
init_weight = glorot_uniform,
init_bias = zeros32,
use_bias::Bool = true)
in_dims, out_dims = ch
σ = NNlib.fast_act(σ)
return SAGEConv(in_dims, out_dims, use_bias, init_weight, init_bias, σ, aggr)
end

function LuxCore.initialparameters(rng::AbstractRNG, l::SAGEConv)
weight = l.init_weight(rng, l.out_dims, 2 * l.in_dims)
if l.use_bias
bias = l.init_bias(rng, l.out_dims)
return (; weight, bias)
else
return (; weight)
end
end

LuxCore.parameterlength(l::SAGEConv) = l.use_bias ? l.out_dims * 2 * l.in_dims + l.out_dims :
l.out_dims * 2 * l.in_dims
LuxCore.outputsize(d::SAGEConv) = (d.out_dims,)

function Base.show(io::IO, l::SAGEConv)
print(io, "SAGEConv(", l.in_dims, " => ", l.out_dims)
(l.σ == identity) || print(io, ", ", l.σ)
(l.aggr == mean) || print(io, ", aggr=", l.aggr)
l.use_bias || print(io, ", use_bias=false")
print(io, ")")
end

function (l::SAGEConv)(g, x, ps, st)
m = (; ps.weight, bias = _getbias(ps),
l.σ, l.aggr)
return GNNlib.sage_conv(m, g, x), st
end
5 changes: 5 additions & 0 deletions GNNLux/test/layers/conv_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,9 @@
l = ResGatedGraphConv(in_dims => out_dims, tanh)
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
end

@testset "SAGEConv" begin
l = SAGEConv(in_dims => out_dims, tanh)
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
end
end

0 comments on commit 3fe2c76

Please sign in to comment.