From c4188aba6fcd8765e30b4550cd72fce9f8ac3828 Mon Sep 17 00:00:00 2001 From: Aurora Rossi <65721467+aurorarossi@users.noreply.github.com> Date: Wed, 11 Oct 2023 11:07:20 +0200 Subject: [PATCH] Add TemporalBrains dataset (#222) * init * Add create_dataset * Export TemporalBrains * add struct and constructor * Optimized version * Update docs * Add tests * Add spaces Co-authored-by: Carlo Lucibello * Add link * Improve docstring * Add `TemporalBrains` to docs * Improve * Fix & for `julia 1.6` --------- Co-authored-by: Carlo Lucibello --- docs/src/datasets/graphs.md | 1 + src/MLDatasets.jl | 3 + src/datasets/graphs/temporalbrains.jl | 81 +++++++++++++++++++++++++++ test/datasets/graphs_no_ci.jl | 13 +++++ 4 files changed, 98 insertions(+) create mode 100644 src/datasets/graphs/temporalbrains.jl diff --git a/docs/src/datasets/graphs.md b/docs/src/datasets/graphs.md index 47f732d7..d235c6f6 100644 --- a/docs/src/datasets/graphs.md +++ b/docs/src/datasets/graphs.md @@ -32,4 +32,5 @@ Reddit TUDataset METRLA PEMSBAY +TemporalBrains ``` diff --git a/src/MLDatasets.jl b/src/MLDatasets.jl index 6b1e751a..412b79c6 100644 --- a/src/MLDatasets.jl +++ b/src/MLDatasets.jl @@ -135,6 +135,8 @@ include("datasets/graphs/metrla.jl") export METRLA include("datasets/graphs/pemsbay.jl") export PEMSBAY +include("datasets/graphs/temporalbrains.jl") +export TemporalBrains # Meshes @@ -156,6 +158,7 @@ function __init__() __init__tudataset() __init__metrla() __init__pemsbay() + __init__temporalbrains() # misc __init__iris() diff --git a/src/datasets/graphs/temporalbrains.jl b/src/datasets/graphs/temporalbrains.jl new file mode 100644 index 00000000..0e9ce533 --- /dev/null +++ b/src/datasets/graphs/temporalbrains.jl @@ -0,0 +1,81 @@ +function __init__temporalbrains() + DEPNAME = "TemporalBrains" + LINK = "http://www-sop.inria.fr/members/Aurora.Rossi/index.html" + register(ManualDataDep(DEPNAME, + """ + Dataset: $DEPNAME + Website : $LINK + """)) +end + + +function tb_datadir(dir = nothing) + dir = isnothing(dir) ? datadep"TemporalBrains" : dir + LINK = "http://www-sop.inria.fr/members/Aurora.Rossi/data/LabelledTBN.zip" + if length(readdir((dir))) == 0 + DataDeps.fetch_default(LINK, dir) + currdir = pwd() + cd(dir) # Needed since `unpack` extracts in working dir + DataDeps.unpack(joinpath(dir, "LabelledTBN.zip")) + # conditions when unzipped folder is our required data dir + cd(currdir) + end + @assert isdir(dir) + return dir +end + + +function create_tbdataset(dir, thre) + name_filelabels = joinpath(dir, "LabelledTBN", "labels.txt") + filelabels = open(name_filelabels, "r") + temporalgraphs = Vector{MLDatasets.TemporalSnapshotsGraph}(undef, 1000) + + for (i,line) in enumerate(eachline(filelabels)) + id, gender, age = split(line) + name_network_file = joinpath(dir, "LabelledTBN", "networks", id * "_ws60_wo30_tuk0_pearson_schaefer_100.txt") + + data = readdlm(name_network_file,',',Float32; skipstart = 1) + + data_thre = view(data,view(data,:,4) .> thre,:) + data_thre_int = Int.(view(data_thre,:,1:3)) + + activation = [zeros(Float32, 102) for _ in 1:27] + for t in 1:27 + for n in 1:102 + rows = ((view(data_thre_int,:,1).==n) .& (view(data_thre_int,:,3).==t)) + activation[t][n] = mean(view(data_thre,rows,4)) + end + end + + temporalgraphs[i] = TemporalSnapshotsGraph(num_nodes=ones(Int,27)*102, edge_index = (data_thre_int[:,1], data_thre_int[:,2], data_thre_int[:,3]), node_data= activation, graph_data= (g = gender, a = age)) + end + return temporalgraphs +end + +""" + TemporalBrains(; dir = nothing, threshold_value = 0.6) + +The TemporalBrains dataset contains a collection of temporal brain networks (as `TemporalSnapshotsGraph`s) of 1000 subjects obtained from resting-state fMRI data from the [Human Connectome Project (HCP)](https://www.humanconnectome.org/study/hcp-young-adult/document/extensively-processed-fmri-data-documentation). + +The number of nodes is fixed for each of the 27 snapshots at 102, while the edges change over time. + +For each `Graph` snapshot, the feature of a node represents the average activation of the node during that snapshot and it is contained in `Graphs.node_data`. + +Each `TemporalSnapshotsGraph` has a label representing their gender ("M" for male and "F" for female) and age range (22-25, 26-30, 31-35 and 36+) contained as a named tuple in `graph_data`. + +The `threshold_value` is used to binarize the edge weights and is set to 0.6 by default. +""" +struct TemporalBrains <: AbstractDataset + graphs::Vector{MLDatasets.TemporalSnapshotsGraph} +end + +function TemporalBrains(;threshold_value = 0.6, dir=nothing) + create_default_dir("TemporalBrains") + dir = tb_datadir(dir) + graphs = create_tbdataset(dir, threshold_value) + return TemporalBrains(graphs) +end + +Base.length(d::TemporalBrains) = length(d.graphs) +Base.getindex(d::TemporalBrains, ::Colon) = d.graphs[1] +Base.getindex(d::TemporalBrains, i) = getindex(d.graphs, i) diff --git a/test/datasets/graphs_no_ci.jl b/test/datasets/graphs_no_ci.jl index 5b98f69c..17da44f8 100644 --- a/test/datasets/graphs_no_ci.jl +++ b/test/datasets/graphs_no_ci.jl @@ -363,4 +363,17 @@ end @test g.num_nodes == 325 @test g.num_edges == 2694 @test all(g.node_data.features[1][:,:,1][2:end,1] == g.node_data.targets[1][:,:,1][1:end-1]) +end + +@testset "TemporalBrains" begin + data = TemporalBrains() + @test data isa AbstractDataset + @test length(data) == 1000 + g = data[1] + @test g isa MLDatasets.TemporalSnapshotsGraph + + @test g.num_nodes == [102 for _ in 1:27] + @test g.num_snapshots == 27 + @test g.snapshots[1] isa MLDatasets.Graph + @test length(g.snapshots[1].node_data) == 102 end \ No newline at end of file