-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfunctions.R
412 lines (319 loc) · 12.1 KB
/
functions.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
########################################################
# Introduction #
########################################################
# Authors: Paolo Lapo Cerni, Lorenzo Vigorelli, Arman Singh Bains
# Date: 28/09/2024
# Description: This script contains the functions used to implement the K2 algorithm.
# The K2 algorithm is a scoring-based algorithm used to learn the structure of a Bayesian network from data.
########################################################
# Load Required Libraries #
########################################################
library(dplyr)
library(readr)
library(bnlearn)
library(bnstruct)
library(Rgraphviz)
library(foreach)
library(doParallel)
library(ggplot2)
library(tidyr)
library(purrr)
########################################################
# K2 Scoring Function(s) #
########################################################
## Scoring function for exact results
# Cooper-Herskovits marginal likelihood for node `x_i` under a candidate
# parent set:
#   prod_j [ (r_i - 1)! / (N_j + r_i - 1)! * prod_k N_jk! ]
# where r_i is the number of distinct values of x_i, j runs over the
# observed parent configurations, N_jk is the count of rows with the k-th
# value of x_i under configuration j, and N_j = sum_k N_jk.
#
# Args:
#   data:    data frame of discrete variables (one column per node).
#   x_i:     column name of the node being scored.
#   parents: character vector of parent column names (may be empty).
#
# Returns: the (non-log) score as a single numeric value.
# NOTE: factorial() overflows to Inf once counts exceed ~170, so this
# exact version is only usable on small datasets; see log_scoring_function.
scoring_function <- function(data, x_i, parents){
  # Number of distinct values r_i of attribute x_i
  r_i <- data |> distinct(data[[x_i]]) |> nrow()
  if (length(parents) == 0){
    # No parents: a single "configuration" containing all N rows
    N <- nrow(data)
    # prod_k N_k! over the value counts of x_i
    num1 <- data |>
      group_by(data[[x_i]]) |>
      count() |>
      mutate(a = factorial(n)) |>
      ungroup() |>
      pull(a) |>
      prod()
  # Written to mirror the with-parents branch below
  } else {
    # Counts N_jk for every observed (x_i value, parent configuration) pair
    alpha <- data |> group_by(data[c(x_i, parents)]) |> count()
    # N_j: total row count per parent configuration (one-column tibble)
    N <- alpha |> group_by(alpha[parents]) |>
      summarise(N = sum(n), .groups = "drop") |>
      select(N)
    # prod_k N_jk! per parent configuration
    num1 <- alpha |>
      group_by(alpha[parents]) |>
      summarise(alpha = prod(factorial(n)), .groups = "drop") |>
      pull(alpha)
  }
  # Denominator (N_j + r_i - 1)! for each configuration.  sapply over the
  # one-column tibble N applies the vectorized factorial to its column;
  # in the no-parents branch N is a plain scalar and this still works.
  den <- sapply(N, function(x, r) factorial(x + r - 1), r=r_i)
  num2 <- factorial(r_i - 1)
  # Product over parent configurations of (r_i-1)! / (N_j+r_i-1)! * prod_k N_jk!
  return(prod(num2 / den * num1))
}
## Log scoring function to avoid numerical issues
# Log-space version of scoring_function(): returns
#   sum_j [ log(r_i - 1)! - log(N_j + r_i - 1)! + sum_k log N_jk! ]
# using lfactorial() so large counts do not overflow to Inf.
# Because it is a monotone transform of the exact score, it can be
# maximised in its place (as K2_algorithm does).
#
# Args:
#   data:    data frame of discrete variables (one column per node).
#   x_i:     column name of the node being scored.
#   parents: character vector of parent column names (may be empty).
#
# Returns: the log score as a single numeric value.
log_scoring_function <- function(data, x_i, parents){
  # Number of distinct values r_i of attribute x_i
  r_i <- data |> distinct(data[[x_i]]) |> nrow()
  if (length(parents) == 0){
    # No parents: a single "configuration" containing all N rows
    N <- nrow(data)
    # sum_k log N_k! over the value counts of x_i
    num1 <- data |>
      group_by(data[[x_i]]) |>
      count() |>
      mutate(a = lfactorial(n)) |>
      ungroup() |>
      pull(a) |>
      sum()
  # Written to mirror the with-parents branch below
  } else {
    # Counts N_jk for every observed (x_i value, parent configuration) pair
    alpha <- data |> group_by(data[c(x_i, parents)]) |> count()
    # N_j: total row count per parent configuration (one-column tibble)
    N <- alpha |> group_by(alpha[parents]) |>
      summarise(N = sum(n), .groups = "drop") |>
      select(N)
    # sum_k log N_jk! per parent configuration
    num1 <- alpha |>
      group_by(alpha[parents]) |>
      summarise(alpha = sum(lfactorial(n)), .groups = "drop") |>
      pull(alpha)
  }
  # log (N_j + r_i - 1)! for each configuration; sapply over the
  # one-column tibble N applies the vectorized lfactorial to its column
  den <- sapply(N, function(x, r) lfactorial(x + r - 1), r=r_i)
  num2 <- lfactorial(r_i - 1)
  # Sum of per-configuration log terms (num1 is aligned with den by row)
  return(sum(num1 - den + num2))
}
########################################################
# K2 Algorithm Implementation #
########################################################
K2_algorithm <- function(data, max_parents){
  # Learn a Bayesian-network structure with the greedy K2 search.
  #
  # Nodes are processed in the column order of `data`; only columns that
  # precede a node in that order are eligible parents.  Parents are added
  # greedily while the log marginal-likelihood score strictly improves,
  # up to `max_parents` per node.
  #
  # Args:
  #   data:        data frame of discrete variables (one column per node).
  #   max_parents: maximum number of parents allowed for any node.
  #
  # Returns:
  #   list(names = column names, parents_list = named list mapping each
  #   node to the character vector of its chosen parents).
  names <- colnames(data)
  results <- list()  # node name -> character vector of parents
  for (i in seq_len(ncol(data))) {
    x_i <- names[i]
    parents <- c()
    p_old <- log_scoring_function(data, x_i, parents)
    proceed <- TRUE
    while (proceed) {
      # Stop once the maximum number of parents has been reached
      if (length(parents) >= max_parents) {
        break
      }
      # Candidate parents: predecessors in the ordering not yet selected.
      # seq_len(i - 1) is empty for i == 1 (safer than 0:(i - 1)).
      predecessors <- setdiff(names[seq_len(i - 1)], parents)
      if (length(predecessors) == 0) {
        break
      }
      # Score each candidate when added to the current parent set
      scores <- sapply(predecessors, function(z) log_scoring_function(data, x_i, c(z, parents)))
      p_new <- max(scores)
      # Add the best candidate only if it strictly improves the score.
      # BUG FIX: index into `predecessors`, not `names` -- which.max(scores)
      # is a position within `predecessors`, and once setdiff() has removed
      # an already-chosen parent the two vectors no longer line up, so the
      # old names[which.max(scores)] could pick (and even re-add) the wrong
      # column.
      if (p_new > p_old) {
        p_old <- p_new
        parents <- c(parents, predecessors[which.max(scores)])
      } else {
        proceed <- FALSE
      }
    } # end while
    results[[x_i]] <- parents
  } # end for
  return(list(names=names, parents_list=results))
}
########################################################
# K2 Pipeline Implementation #
########################################################
# Translate a node -> parents mapping into a bnlearn DAG object
get_dag <- function(names, parents_list){
  # Start from a graph with every node and no arcs, then add one directed
  # arc per (parent, child) pair recorded in parents_list.
  dag <- empty.graph(names)
  for (node in names) {
    # Iterating over NULL (no recorded parents) simply runs zero times
    for (p in parents_list[[node]]) {
      dag <- set.arc(dag, from = p, to = node)
    }
  }
  return(dag)
}
K2_to_dag <- function(data, max_parents){
  # Run the K2 search on `data`, materialize the learned parent sets as a
  # bnlearn DAG, and attach the network score of that DAG on the data.
  k2_out <- K2_algorithm(data, max_parents)
  dag <- get_dag(k2_out$names, k2_out$parents_list)
  dag_score <- score(dag, data)
  return(list(dag=dag, score=dag_score))
}
K2_pipeline <- function(data, max_parents, max_iter, mode="local", n_cores=-1, return_history=FALSE){
  # Run the K2 search `max_iter` times, each time with a freshly shuffled
  # column order (K2's result depends on the node ordering), and keep the
  # best-scoring DAG found.
  #
  # Args:
  #   data:           data frame of discrete variables.
  #   max_parents:    maximum number of parents per node.
  #   max_iter:       number of random column orderings to try.
  #   mode:           "local" (sequential) or "parallel" (forked workers).
  #   n_cores:        worker count for parallel mode; -1 = all detected cores.
  #   return_history: local mode only; also return the running best score
  #                   after each iteration.
  #
  # Returns: list(dag, score) plus `history` when requested.
  if (mode != "local" && mode != "parallel") {
    stop("Invalid mode. Please use 'local' or 'parallel'.")
  }
  # Return history only works in local mode
  if (return_history && mode == "parallel") {
    stop("Return history only works in local mode.")
  }
  # If the data does not have column names, assign them
  if (is.null(colnames(data))) {
    colnames(data) <- paste0("X", seq_len(ncol(data)))
  }
  # Randomly shuffle data rows
  data <- data[sample(nrow(data)), ]
  if (mode == "local") {
    history <- array(data = NA, dim = c(max_iter))
    score_best <- -Inf
    dag_best <- NULL
    # Try different random orders of the columns
    for (i in seq_len(max_iter)) {
      # Keep the caller's original column order on the first iteration
      if (i > 1) {
        data <- data |> select(sample(colnames(data)))
      }
      result <- K2_to_dag(data, max_parents)
      # Update the best DAG
      if (result$score > score_best) {
        score_best <- result$score
        dag_best <- result$dag
      }
      # Record the running best score after each iteration
      history[i] <- score_best
    }
    if (return_history) {
      return(list(dag=dag_best, score=score_best, history=history))
    }
    return(list(dag=dag_best, score=score_best))
  }
  # mode == "parallel" (the only remaining possibility after validation,
  # so the old unreachable trailing return(NULL) is gone)
  if (n_cores == -1) {
    n_cores <- detectCores()
  }
  # BUG FIX: the previous version called makeCluster()/registerDoParallel()
  # to build a cluster that mclapply() never used, and never called
  # stopCluster(), leaking worker processes on every call.  mclapply()
  # forks the current session directly, so no cluster setup is needed.
  results <- mclapply(seq_len(max_iter), function(i) {
    data_sampled <- data |> select(sample(colnames(data)))
    return(K2_to_dag(data_sampled, max_parents))
  }, mc.cores = n_cores)
  # Reduce to the best-scoring DAG
  score_best <- -Inf
  dag_best <- NULL
  for (res in results) {
    if (res$score > score_best) {
      score_best <- res$score
      dag_best <- res$dag
    }
  }
  return(list(dag=dag_best, score=score_best))
}
########################################################
# BNstruct function Function(s) #
########################################################
learning <- function (data, maxParent, algo = "k2", percentage = 1, plot = FALSE) {
  # Learn a network structure with bnstruct on the first `percentage`
  # fraction of the rows, then mirror it into a bnlearn object and score it.
  #
  # Args:
  #   data:       data frame of discrete variables.
  #   maxParent:  maximum number of parents per node.
  #   algo:       bnstruct algorithm name (e.g. "k2").
  #   percentage: fraction (0..1) of leading rows to keep.
  #   plot:       if TRUE, plot the learned bnstruct DAG.
  #
  # Returns: list(dag = bnlearn network, score = its score on the data).
  # Check if the percentage is within the valid range
  if (percentage < 0 || percentage > 1) {
    stop("Percentage must be between 0 and 1.")
  }
  # Keep the leading fraction of rows.  NOTE: rows are NOT shuffled here
  # (the old comment claimed otherwise); callers wanting a random
  # subsample must shuffle beforehand.  max(1L, ...) keeps at least one
  # row, matching the old 1:n indexing for very small percentages.
  num_rows <- nrow(data)
  data <- data[seq_len(max(1L, as.integer(percentage * num_rows))), ]
  # bnstruct needs to know whether the discrete coding starts at 0 or 1
  minValue <- min(data, na.rm = TRUE)
  startsFrom <- if (minValue == 0) 0 else 1
  # Cardinality (number of distinct values) of each variable
  sizes <- as.numeric(sapply(data, function(x) length(unique(x))))
  # Create a BNDataset object with the given data
  dataset <- BNDataset(data = data,
                       discreteness = rep(TRUE, ncol(data)),
                       variables = colnames(data),
                       starts.from = startsFrom,
                       node.sizes = sizes)
  # Learn the network structure using the specified algorithm
  dag <- learn.network(algo = algo, x = dataset, max.parents = maxParent)
  # Mirror the learned structure into a bnlearn object via its adjacency
  # matrix so bnlearn::score() can evaluate it
  net <- empty.graph(names(data))
  amat(net) <- dag(dag)
  # bnlearn::score() expects factor columns
  for (name in names(data)) {
    data[, name] <- as.factor(as.character(data[, name]))
  }
  # Calculate the score of the network
  score <- score(net, data = data)
  # Plot the bnstruct DAG if requested
  if (plot) {
    plot(dag)
  }
  # Return the DAG and its score as a list
  return(list(dag = net, score = score))
}
########################################################
# Compute shd and plots Function(s) #
########################################################
# Function to compute the Structural Hamming Distance (SHD) between two networks
computeShdSingle <- function(theor, empir, plot=FALSE) {
  # Structural Hamming Distance between one theoretical and one empirical
  # network, with an optional side-by-side graphviz comparison plot.
  adj_theor <- amat(theor)
  adj_empir <- amat(empir)
  # Distance is computed on the adjacency matrices
  distance <- shd(adj_theor, adj_empir)
  if (plot) {
    graphviz.compare(theor, empir, shape="rectangle")
  }
  return(distance)
}
########################################################
# Compute shd multiples and plots Function(s) #
########################################################
# Function to compute the Structural Hamming Distance (SHD) between a theoretical model and a list of empirical models
computeShd <- function(theor, empirList, plot=FALSE) {
  # Structural Hamming Distance between one theoretical model and each
  # empirical model in `empirList`; optionally draws all graphs side by
  # side with per-graph titles and SHD subtitles.
  theor_adj <- amat(theor)
  # One SHD value per empirical model, in list order
  shd_values <- sapply(empirList, function(empir) shd(theor_adj, amat(empir)))
  if (plot) {
    n_models <- length(empirList)
    # The theoretical model is drawn first, followed by the empirical ones
    all_graphs <- c(list(theor), empirList)
    graph_titles <- c("THEORETICAL MODEL", paste("EMPIRICAL MODEL", seq_len(n_models)))
    # The theoretical model trivially has SHD 0 against itself
    graph_subtitles <- c(paste("SHD =", "0"), paste("SHD =", shd_values))
    # graphviz.compare takes the reference graph, then the others, then
    # the shared display options -- assembled via do.call
    call_args <- c(all_graphs,
                   list(shape = "rectangle",
                        main = graph_titles,
                        sub = graph_subtitles,
                        diff.args = list(tp.lwd = 2, tp.col = "green", fn.col = "orange")))
    do.call(graphviz.compare, call_args)
  }
  return(shd_values)
}
########################################################
# Compute scores and DAGs #
########################################################
# Function to compute scores and DAGs for a given algorithm, data, and percentages
compute_scores_and_dags <- function(algo, data, percentages) {
  # Learn one network per data percentage with the given algorithm and
  # return min-max normalized scores alongside the learned DAGs.
  #
  # Args:
  #   algo:        bnstruct algorithm name passed through to learning().
  #   data:        data frame of discrete variables.
  #   percentages: numeric vector of row fractions (0..1) to train on.
  #
  # Returns: list(scores = normalized scores in [0, 1], dags = DAG list).
  # Apply the learning function to each percentage and store the results
  results <- lapply(percentages, function(p) learning(data = data, algo = algo, maxParent = 3, percentage = p))
  # Extract the scores and the DAGs from the results
  scores <- sapply(results, function(res) res$score)
  dags <- lapply(results, function(res) res$dag)
  # Min-max normalize the scores.
  # BUG FIX: when all scores are equal (e.g. a single percentage) the
  # range is zero and the old expression returned NaN; map to 0 instead.
  min_score <- min(scores)
  score_range <- max(scores) - min_score
  if (score_range == 0) {
    scores_normalized <- rep(0, length(scores))
  } else {
    scores_normalized <- (scores - min_score) / score_range
  }
  # Return a list containing the normalized scores and the DAGs
  return(list(scores = scores_normalized, dags = dags))
}