diff --git a/gunfolds/conversions.py b/gunfolds/conversions.py index 9edcd511..db34f3c0 100644 --- a/gunfolds/conversions.py +++ b/gunfolds/conversions.py @@ -845,3 +845,26 @@ def encode_list_sccs(glist, scc_members=None): # if there is an edge between SCCs in the produced graph and none in the measured for nonsingleton SCCs - no go s += ':- directed(X,Y,U), scc(X,K), scc(Y,L), K != L, sccsize(L,Z), Z > 1, not dag(K,L,N), u(U,N).' return s + +def Glag2CG(results): + """Converts lag graph format to gunfolds graph format, + and A and B matrices representing directed and bidirected edges weights. + + Args: + results (dict): A dictionary containing: + - 'graph': A 3D NumPy array of shape [N, N, 2] representing the graph structure. + - 'val_matrix': A NumPy array of shape [N, N, 2] storing edge weights. + + Returns: + tuple: (graph_dict, A_matrix, B_matrix) + """ + + graph_array = results['graph'] + bidirected_edges = np.where(graph_array == 'o-o', 1, 0).astype(int) + directed_edges = np.where(graph_array == '-->', 1, 0).astype(int) + + graph_dict = adjs2graph(np.transpose(directed_edges[:, :, 1]), np.transpose((bidirected_edges[:, :, 0]))) + A_matrix = results['val_matrix'][:, :, 1] + B_matrix = results['val_matrix'][:, :, 0] + + return graph_dict, A_matrix, B_matrix \ No newline at end of file