Skip to content

Commit

Permalink
create GNNLux.jl package (#460)
Browse files Browse the repository at this point in the history
* create GNNLux

* create GNNLux.jl

* fix ci
  • Loading branch information
CarloLucibello authored Jul 26, 2024
1 parent cafc1bc commit 79515e9
Show file tree
Hide file tree
Showing 10 changed files with 272 additions and 3 deletions.
48 changes: 48 additions & 0 deletions .github/workflows/test_GNNLux.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
name: GNNLux
on:
pull_request:
branches:
- master
push:
branches:
- master
jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
version:
- '1.10' # Replace this with the minimum Julia version that your package supports.
# - '1' # '1' will automatically expand to the latest stable 1.x release of Julia.
# - 'pre'
os:
- ubuntu-latest
arch:
- x64

steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: julia-actions/cache@v2
- uses: julia-actions/julia-buildpkg@v1
- name: Install Julia dependencies and run tests
shell: julia --project=monorepo {0}
run: |
using Pkg
# dev mono repo versions
pkg"registry up"
Pkg.update()
pkg"dev ./GNNGraphs ./GNNlib ./GNNLux"
Pkg.test("GNNLux"; coverage=true)
- uses: julia-actions/julia-processcoverage@v1
with:
# directories: ./GNNLux/src, ./GNNLux/ext
directories: ./GNNLux/src
- uses: codecov/codecov-action@v4
with:
files: lcov.info
21 changes: 21 additions & 0 deletions GNNLux/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Carlo Lucibello <[email protected]> and contributors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
35 changes: 35 additions & 0 deletions GNNLux/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
name = "GNNLux"
uuid = "e8545f4d-a905-48ac-a8c4-ca114b98986d"
authors = ["Carlo Lucibello and contributors"]
version = "0.1.0"

[deps]
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ConcreteStructs = "0.2.3"
Lux = "0.5.61"
LuxCore = "0.1.20"
NNlib = "0.9.21"
Reexport = "1.2"
julia = "1.10"

[extras]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "ComponentArrays", "Functors", "LuxTestUtils", "ReTestItems", "StableRNGs", "Zygote"]
2 changes: 2 additions & 0 deletions GNNLux/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# GNNLux.jl

15 changes: 15 additions & 0 deletions GNNLux/src/GNNLux.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
module GNNLux
using ConcreteStructs: @concrete
using NNlib: NNlib
using LuxCore: LuxCore, AbstractExplicitLayer
using Lux: glorot_uniform, zeros32
using Reexport: @reexport
using Random: AbstractRNG
using GNNlib: GNNlib
@reexport using GNNGraphs

include("layers/conv.jl")
export GraphConv

end #module

93 changes: 93 additions & 0 deletions GNNLux/src/layers/conv.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@

@doc raw"""
GraphConv(in => out, σ=identity; aggr=+, bias=true, init=glorot_uniform)
Graph convolution layer from Reference: [Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks](https://arxiv.org/abs/1810.02244).
Performs:
```math
\mathbf{x}_i' = W_1 \mathbf{x}_i + \square_{j \in \mathcal{N}(i)} W_2 \mathbf{x}_j
```
where the aggregation type is selected by `aggr`.
# Arguments
- `in`: The dimension of input features.
- `out`: The dimension of output features.
- `σ`: Activation function.
- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
- `bias`: Add learnable bias.
- `init`: Weights' initializer.
# Examples
```julia
# create data
s = [1,1,2,3]
t = [2,3,1,1]
in_channel = 3
out_channel = 5
g = GNNGraph(s, t)
x = randn(Float32, 3, g.num_nodes)
# create layer
l = GraphConv(in_channel => out_channel, relu, bias = false, aggr = mean)
# forward pass
y = l(g, x)
```
"""
@concrete struct GraphConv <: AbstractExplicitLayer
in_dims::Int
out_dims::Int
use_bias::Bool
init_weight::Function
init_bias::Function
σ
aggr
end


function GraphConv(ch::Pair{Int, Int}, σ = identity;
aggr = +,
init_weight = glorot_uniform,
init_bias = zeros32,
use_bias::Bool = true,
allow_fast_activation::Bool = true)
in_dims, out_dims = ch
σ = allow_fast_activation ? NNlib.fast_act(σ) : σ
return GraphConv(in_dims, out_dims, use_bias, init_weight, init_bias, σ, aggr)
end

function LuxCore.initialparameters(rng::AbstractRNG, l::GraphConv)
weight1 = l.init_weight(rng, l.out_dims, l.in_dims)
weight2 = l.init_weight(rng, l.out_dims, l.in_dims)
if l.use_bias
bias = l.init_bias(rng, l.out_dims)
else
bias = false
end
return (; weight1, weight2, bias)
end

function LuxCore.parameterlength(l::GraphConv)
if l.use_bias
return 2 * l.in_dims * l.out_dims + l.out_dims
else
return 2 * l.in_dims * l.out_dims
end
end

LuxCore.statelength(d::GraphConv) = 0
LuxCore.outputsize(d::GraphConv) = (d.out_dims,)

function Base.show(io::IO, l::GraphConv)
print(io, "GraphConv(", l.in_dims, " => ", l.out_dims)
(l.σ == identity) || print(io, ", ", l.σ)
(l.aggr == +) || print(io, ", aggr=", l.aggr)
l.use_bias || print(io, ", use_bias=false")
print(io, ")")
end

(l::GraphConv)(g::GNNGraph, x, ps, st) = GNNlib.graph_conv(l, g, x, ps), st
19 changes: 19 additions & 0 deletions GNNLux/test/layers/conv_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
@testitem "layers/conv" setup=[SharedTestSetup] begin
rng = StableRNG(1234)
g = rand_graph(10, 30, seed=1234)
x = randn(rng, Float32, 3, 10)

@testset "GraphConv" begin
l = GraphConv(3 => 5, relu)
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
@test Lux.statelength(l) == Lux.statelength(st)

y, _ = l(g, x, ps, st)
@test Lux.outputsize(l) == (5,)
@test size(y) == (5, 10)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3
end
end
10 changes: 10 additions & 0 deletions GNNLux/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
using Test
using Lux
using GNNLux
using Random, Statistics

using ReTestItems
# using Pkg, Preferences, Test
# using InteractiveUtils, Hwloc

runtests(GNNLux)
23 changes: 23 additions & 0 deletions GNNLux/test/shared_testsetup.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
@testsetup module SharedTestSetup

import Reexport: @reexport

@reexport using Lux, Functors
@reexport using ComponentArrays, LuxCore, LuxTestUtils, Random, StableRNGs, Test,
Zygote, Statistics
@reexport using LuxTestUtils: @jet, @test_gradients, check_approx

# Some Helper Functions
function get_default_rng(mode::String)
dev = mode == "cpu" ? LuxCPUDevice() :
mode == "cuda" ? LuxCUDADevice() : mode == "amdgpu" ? LuxAMDGPUDevice() : nothing
rng = default_device_rng(dev)
return rng isa TaskLocalRNG ? copy(rng) : deepcopy(rng)
end

export get_default_rng

# export BACKEND_GROUP, MODES, cpu_testing, cuda_testing, amdgpu_testing, get_default_rng,
# StableRNG, maybe_rewrite_to_crosscor

end
9 changes: 6 additions & 3 deletions GNNlib/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,15 @@ function cheb_conv(c, g::GNNGraph, X::AbstractMatrix{T}) where {T}
return Y .+ c.bias
end

function graph_conv(l, g::AbstractGNNGraph, x)
function graph_conv(l, g::AbstractGNNGraph, x, ps)
check_num_nodes(g, x)
xj, xi = expand_srcdst(g, x)
m = propagate(copy_xj, g, l.aggr, xj = xj)
x = l.σ.(l.weight1 * xi .+ l.weight2 * m .+ l.bias)
return x
x = ps.weight1 * xi .+ ps.weight2 * m
if l.use_bias
x = x .+ ps.bias
end
return l.σ.(x)
end

function gat_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} = nothing)
Expand Down

0 comments on commit 79515e9

Please sign in to comment.