Skip to content

Commit

Permalink
update: bc
Browse files Browse the repository at this point in the history
  • Loading branch information
aMahanna committed Apr 28, 2024
1 parent 50c3371 commit ad0a355
Showing 1 changed file with 69 additions and 55 deletions.
124 changes: 69 additions & 55 deletions nx_arangodb/algorithms/centrality/betweenness.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
from networkx.algorithms.centrality import betweenness as nxbc

from nx_arangodb.convert import _to_graph
from nx_arangodb.convert import _to_graph as _to_nx_arangodb_graph
from nx_arangodb.utils import networkx_algorithm

# import pylibcugraph as plc
# from nx_cugraph.utils import networkx_algorithm, _seed_to_int
try:
import pylibcugraph as plc
from nx_cugraph.convert import _to_graph as _to_nx_cugraph_graph
from nx_cugraph.utils import _seed_to_int

GPU_ENABLED = True
except ModuleNotFoundError:
GPU_ENABLED = False


__all__ = ["betweenness_centrality"]

# 1. If GPU is enabled, call nx-cugraph bc() after converting to a nx_cugraph graph (in-memory graph)
# 2. If GPU is not enabled, call networkx bc() after converting to a networkx graph (in-memory graph)
# 3. If GPU is not enabled, call networkx bc() **without** converting to a networkx graph (remote graph)


@networkx_algorithm(
is_incomplete=True,
Expand All @@ -19,59 +29,63 @@
def betweenness_centrality(
G, k=None, normalized=True, weight=None, endpoints=False, seed=None
):
# We're just calling the original function from networkx here
# to test things out for now. i.e no nx-cugraph stuff here

print("ANTHONY: Calling betweenness_centrality from nx_arangodb")
G = _to_graph(G)

##############################
betweenness = dict.fromkeys(G, 0.0) # b[v]=0 for v in G
if k is None:
nodes = G
# 1.
if GPU_ENABLED:
print("ANTHONY: GPU is enabled. Using nx-cugraph bc()")

if weight is not None:
raise NotImplementedError(
"Weighted implementation of betweenness centrality not currently supported"
)

seed = _seed_to_int(seed)
G = _to_nx_cugraph_graph(G, weight)
node_ids, values = plc.betweenness_centrality(
resource_handle=plc.ResourceHandle(),
graph=G._get_plc_graph(),
k=k,
random_state=seed,
normalized=normalized,
include_endpoints=endpoints,
do_expensive_check=False,
)

return G._nodearrays_to_dict(node_ids, values)

# 2.
else:
nodes = seed.sample(list(G.nodes()), k)
for s in nodes:
# single source shortest paths
if weight is None: # use BFS
S, P, sigma, _ = nxbc._single_source_shortest_path_basic(G, s)
else: # use Dijkstra's algorithm
S, P, sigma, _ = nxbc._single_source_dijkstra_path_basic(G, s, weight)
# accumulation
if endpoints:
betweenness, _ = nxbc._accumulate_endpoints(betweenness, S, P, sigma, s)
print("ANTHONY: GPU is disabled. Using nx bc()")

G = _to_nx_arangodb_graph(G)

betweenness = dict.fromkeys(G, 0.0) # b[v]=0 for v in G
if k is None:
nodes = G
else:
betweenness, _ = nxbc._accumulate_basic(betweenness, S, P, sigma, s)
# rescaling
betweenness = nxbc._rescale(
betweenness,
len(G),
normalized=normalized,
directed=G.is_directed(),
k=k,
endpoints=endpoints,
)

return betweenness
##############################

# if weight is not None:
# raise NotImplementedError(
# "Weighted implementation of betweenness centrality not currently supported"
# )

# seed = _seed_to_int(seed)

# G = _nx_arangodb_graph_to_nx_cugraph_graph(G, weight)

# node_ids, values = plc.betweenness_centrality(
# resource_handle=plc.ResourceHandle(),
# graph=G._get_plc_graph(),
# k=k,
# random_state=seed,
# normalized=normalized,
# include_endpoints=endpoints,
# do_expensive_check=False,
# )

# return G._nodearrays_to_dict(node_ids, values)
nodes = seed.sample(list(G.nodes()), k)
for s in nodes:
# single source shortest paths
if weight is None: # use BFS
S, P, sigma, _ = nxbc._single_source_shortest_path_basic(G, s)
else: # use Dijkstra's algorithm
S, P, sigma, _ = nxbc._single_source_dijkstra_path_basic(G, s, weight)
# accumulation
if endpoints:
betweenness, _ = nxbc._accumulate_endpoints(betweenness, S, P, sigma, s)
else:
betweenness, _ = nxbc._accumulate_basic(betweenness, S, P, sigma, s)

betweenness = nxbc._rescale(
betweenness,
len(G),
normalized=normalized,
directed=G.is_directed(),
k=k,
endpoints=endpoints,
)

return betweenness

# 3. TODO

0 comments on commit ad0a355

Please sign in to comment.