From e2cd5098985c2d966b86d0f507b36e6c98bfbacd Mon Sep 17 00:00:00 2001 From: Viktoria Zemliak Date: Sun, 19 May 2024 21:09:36 +0200 Subject: [PATCH 1/4] add multithreading to speed up --- src/pattern_detection.jl | 10 +++++----- src/runner.jl | 16 ++++++++++++++++ 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/src/pattern_detection.jl b/src/pattern_detection.jl index 5e276b14..994a633a 100644 --- a/src/pattern_detection.jl +++ b/src/pattern_detection.jl @@ -4,7 +4,7 @@ function slow_filter(img) end function fast_filter!(dat_filtered, dat, kernel) # - r = Images.ImageFiltering.ComputationalResources.CPU1(Images.ImageFiltering.FIR()) + #r = Images.ImageFiltering.ComputationalResources.CPU1(Images.ImageFiltering.FIR()) DSP.filt!(dat_filtered, kernel[1].data.parent, dat) return dat_filtered end @@ -42,9 +42,9 @@ function mult_chan_pattern_detector_probability(dat, stat_function, evts; n_perm d_perm = similar(dat, size(dat, 1), n_permutations) @debug "starting permutation loop" # We permute data for all events in advance - for ch = 1:size(dat, 1) - for perm = 1:n_permutations - + + Threads.@threads for perm = 1:n_permutations + for ch = 1:size(dat, 1) sortix = shuffle(1:size(dat_filtered, 1)) d_perm[ch, perm] = stat_function( fast_filter!(dat_filtered, @view(dat_padded[ch, sortix, :]), kernel), @@ -54,7 +54,7 @@ function mult_chan_pattern_detector_probability(dat, stat_function, evts; n_perm end mean_d_perm = mean(d_perm, dims = 2)[:, 1] - for n in names(evts) + Threads.@threads for n in names(evts) sortix = sortperm(evts[!, n]) col = fill(NaN, size(dat, 1)) for ch = 1:size(dat, 1) diff --git a/src/runner.jl b/src/runner.jl index 19999d11..48e49a38 100644 --- a/src/runner.jl +++ b/src/runner.jl @@ -1,3 +1,7 @@ +# FOR MULTITHREADING: +# run: >julia -t [n_threads] +# instead of [n_threads] write a desired number of threads (<= amount of CPU cores) + include("setup.jl") include("pattern_detection.jl") @@ -29,6 +33,13 @@ fid = h5open("data/mult.hdf5", "r") dat2 = read(fid["data"]["mult.hdf5"]) close(fid) +# Data for multiple channels (only fixations) +# 128 channels x 769 time x 2508 events + +fid = h5open("data_fixations/data_fixations.hdf5", "r") +dat_fix = read(fid["data"]["data_fixations.hdf5"]) +close(fid) + # PATTERN DECTECTION 1 # for single channel data @@ -65,6 +76,11 @@ evts_init = CSV.read("data/events_init.csv", DataFrame) evts_d = mult_chan_pattern_detector_probability(dat2[:, :, ix], Images.entropy, evts) end +# PATTERN DETECTION 4 (FOR FIXATIONS ONLY) +@time begin + evts_d = mult_chan_pattern_detector_probability(dat_fix, Images.entropy, evts) +end + begin f = Figure() ax = CairoMakie.Axis(f[1, 1], xlabel = "Channels", ylabel = "Sorting event variables") From eb9a3e55c64f5b81e56319da2871abba8e0b4b5c Mon Sep 17 00:00:00 2001 From: Viktoria Zemliak Date: Sun, 19 May 2024 22:49:34 +0200 Subject: [PATCH 2/4] remove inplace changes where multithreading --- src/pattern_detection.jl | 10 +++++----- src/runner.jl | 3 ++- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/pattern_detection.jl b/src/pattern_detection.jl index d2a90bfc..d53e9f51 100644 --- a/src/pattern_detection.jl +++ b/src/pattern_detection.jl @@ -4,10 +4,10 @@ function slow_filter(img) end -function fast_filter!(dat_filtered, kernel, dat) # +function fast_filter(kernel, dat) # #r = Images.ImageFiltering.ComputationalResources.CPU1(Images.ImageFiltering.FIR()) - DSP.filt!(dat_filtered, kernel[1].data.parent, dat) - return dat_filtered + filter_result = DSP.filt(kernel[1].data.parent, dat) + return filter_result end function single_chan_pattern_detector(dat, func, evts) @@ -48,7 +48,7 @@ function mult_chan_pattern_detector_probability(dat, stat_function, evts; n_perm for ch = 1:size(dat, 1) sortix = shuffle(1:size(dat_filtered, 1)) d_perm[ch, perm] = stat_function( - fast_filter!(dat_filtered, kernel, @view(dat_padded[ch, sortix, :])), + fast_filter(kernel, @view(dat_padded[ch, sortix, :])), ) @show ch, perm end @@ -59,7 +59,7 @@ function mult_chan_pattern_detector_probability(dat, stat_function, evts; n_perm sortix = sortperm(evts[!, n]) col = fill(NaN, size(dat, 1)) for ch = 1:size(dat, 1) - fast_filter!(dat_filtered, kernel, @view(dat_padded[ch, sortix, :])) + fast_filter(kernel, @view(dat_padded[ch, sortix, :])) d_emp = stat_function(dat_filtered) col[ch] = abs(d_emp - mean_d_perm[ch]) diff --git a/src/runner.jl b/src/runner.jl index 48e49a38..66d5fa60 100644 --- a/src/runner.jl +++ b/src/runner.jl @@ -36,7 +36,7 @@ close(fid) # Data for multiple channels (only fixations) # 128 channels x 769 time x 2508 events -fid = h5open("data_fixations/data_fixations.hdf5", "r") +fid = h5open("data/data_fixations.hdf5", "r") dat_fix = read(fid["data"]["data_fixations.hdf5"]) close(fid) @@ -77,6 +77,7 @@ evts_init = CSV.read("data/events_init.csv", DataFrame) end # PATTERN DETECTION 4 (FOR FIXATIONS ONLY) +# 10 cores: 50 s @time begin evts_d = mult_chan_pattern_detector_probability(dat_fix, Images.entropy, evts) end From d0b06a6461ea7db521473ed406fb96273e3cb1b3 Mon Sep 17 00:00:00 2001 From: Viktoria Zemliak Date: Sun, 19 May 2024 22:53:17 +0200 Subject: [PATCH 3/4] remove inplace change from one more function --- src/pattern_detection.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/pattern_detection.jl b/src/pattern_detection.jl index d53e9f51..5cebeee5 100644 --- a/src/pattern_detection.jl +++ b/src/pattern_detection.jl @@ -59,8 +59,9 @@ function mult_chan_pattern_detector_probability(dat, stat_function, evts; n_perm sortix = sortperm(evts[!, n]) col = fill(NaN, size(dat, 1)) for ch = 1:size(dat, 1) - fast_filter(kernel, @view(dat_padded[ch, sortix, :])) - d_emp = stat_function(dat_filtered) + d_emp = stat_function( + fast_filter(kernel, @view(dat_padded[ch, sortix, :])) + ) col[ch] = abs(d_emp - mean_d_perm[ch]) print(ch, " ") From 5af4bb5036b10c02ba3cad0a8d7500d144c49b5b Mon Sep 17 00:00:00 2001 From: Viktoria Zemliak Date: Mon, 20 May 2024 00:17:53 +0200 Subject: [PATCH 4/4] remove unnecessary variable d_emp in mult_chan_pattern_detector_probability --- src/pattern_detection.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/pattern_detection.jl b/src/pattern_detection.jl index 5cebeee5..01610821 100644 --- a/src/pattern_detection.jl +++ b/src/pattern_detection.jl @@ -59,11 +59,10 @@ function mult_chan_pattern_detector_probability(dat, stat_function, evts; n_perm sortix = sortperm(evts[!, n]) col = fill(NaN, size(dat, 1)) for ch = 1:size(dat, 1) - d_emp = stat_function( + col[ch] = abs(stat_function( fast_filter(kernel, @view(dat_padded[ch, sortix, :])) - ) + ) - mean_d_perm[ch]) - col[ch] = abs(d_emp - mean_d_perm[ch]) print(ch, " ") end println(n)