Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
rbSparky committed Aug 1, 2024
1 parent d2ab349 commit 96eb264
Showing 1 changed file with 39 additions and 51 deletions.
90 changes: 39 additions & 51 deletions GNNGraphs/src/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,57 +149,6 @@ function remove_self_loops(g::GNNGraph{<:ADJMAT_T})
g.ndata, g.edata, g.gdata)
end

"""
remove_edges(g::GNNGraph, edges_to_remove::AbstractVector{<:Integer})
Remove specified edges from a GNNGraph.
# Arguments
- `g`: The input graph from which edges will be removed.
- `edges_to_remove`: Vector of edge indices to be removed.
# Returns
A new GNNGraph with the specified edges removed.
# Example
```julia
julia> using GraphNeuralNetworks
# Construct a GNNGraph
julia> g = GNNGraph([1, 1, 2, 2, 3], [2, 3, 1, 3, 1])
GNNGraph:
num_nodes: 3
num_edges: 5
# Remove the second edge
julia> g_new = remove_edges(g, [2]);
julia> g_new
GNNGraph:
num_nodes: 3
num_edges: 4
```
"""
function remove_edges(g::GNNGraph{<:COO_T}, edges_to_remove::AbstractVector{<:Integer})
s, t = edge_index(g)
w = get_edge_weight(g)
edata = g.edata

mask_to_keep = trues(length(s))

mask_to_keep[edges_to_remove] .= false

s = s[mask_to_keep]
t = t[mask_to_keep]
edata = getobs(edata, mask_to_keep)
w = isnothing(w) ? nothing : getobs(w, mask_to_keep)

return GNNGraph((s, t, w),
g.num_nodes, length(s), g.num_graphs,
g.graph_indicator,
g.ndata, edata, g.gdata)
end

"""
remove_edges(g::GNNGraph, edges_to_remove::AbstractVector{<:Integer})
remove_edges(g::GNNGraph, p::Float64=0.5)
Expand Down Expand Up @@ -275,6 +224,45 @@ function remove_edges(g::GNNGraph{<:COO_T}, p = 0.5)
g.ndata, edata, g.gdata)
end

"""
remove_multi_edges(g::GNNGraph; aggr=+)
Remove multiple edges (also called parallel edges or repeated edges) from graph `g`.
Possible edge features are aggregated according to `aggr`, that can take value
`+`,`min`, `max` or `mean`.
See also [`remove_self_loops`](@ref), [`has_multi_edges`](@ref), and [`to_bidirected`](@ref).
"""
function remove_multi_edges(g::GNNGraph{<:COO_T}; aggr = +)
s, t = edge_index(g)
w = get_edge_weight(g)
edata = g.edata
num_edges = g.num_edges
idxs, idxmax = edge_encoding(s, t, g.num_nodes)

perm = sortperm(idxs)
idxs = idxs[perm]
s, t = s[perm], t[perm]
edata = getobs(edata, perm)
w = isnothing(w) ? nothing : getobs(w, perm)
idxs = [-1; idxs]
mask = idxs[2:end] .> idxs[1:(end - 1)]
if !all(mask)
s, t = s[mask], t[mask]
idxs = similar(s, num_edges)
idxs .= 1:num_edges
idxs .= idxs .- cumsum(.!mask)
num_edges = length(s)
w = _scatter(aggr, w, idxs, num_edges)
edata = _scatter(aggr, edata, idxs, num_edges)
end

return GNNGraph((s, t, w),
g.num_nodes, num_edges, g.num_graphs,
g.graph_indicator,
g.ndata, edata, g.gdata)
end

"""
remove_nodes(g::GNNGraph, nodes_to_remove::AbstractVector)
Expand Down

0 comments on commit 96eb264

Please sign in to comment.