Skip to content

Commit

Permalink
Add TemporalBrains dataset (JuliaML#222)
Browse files Browse the repository at this point in the history
* init

* Add create_dataset

* Export TemporalBrains

* add struct and constructor

* Optimized version

* Update docs

* Add tests

* Add spaces

Co-authored-by: Carlo Lucibello <[email protected]>

* Add link

* Improve docstring

* Add `TemporalBrains` to docs

* Improve

* Fix & for `julia 1.6`

---------

Co-authored-by: Carlo Lucibello <[email protected]>
  • Loading branch information
aurorarossi and CarloLucibello authored Oct 11, 2023
1 parent 60a2f05 commit c4188ab
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/src/datasets/graphs.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,5 @@ Reddit
TUDataset
METRLA
PEMSBAY
TemporalBrains
```
3 changes: 3 additions & 0 deletions src/MLDatasets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ include("datasets/graphs/metrla.jl")
export METRLA
include("datasets/graphs/pemsbay.jl")
export PEMSBAY
include("datasets/graphs/temporalbrains.jl")
export TemporalBrains

# Meshes

Expand All @@ -156,6 +158,7 @@ function __init__()
__init__tudataset()
__init__metrla()
__init__pemsbay()
__init__temporalbrains()

# misc
__init__iris()
Expand Down
81 changes: 81 additions & 0 deletions src/datasets/graphs/temporalbrains.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Register the "TemporalBrains" data dependency with DataDeps.
# It is registered as a `ManualDataDep` whose message only points at the
# dataset website; the archive download itself is handled by `tb_datadir`.
function __init__temporalbrains()
    depname = "TemporalBrains"
    website = "http://www-sop.inria.fr/members/Aurora.Rossi/index.html"
    message = "Dataset: $depname\nWebsite : $website\n"
    return register(ManualDataDep(depname, message))
end


# Return the directory that holds the TemporalBrains files.
#
# `dir` defaults to the "TemporalBrains" data dependency directory. When the
# directory is empty, the `LabelledTBN.zip` archive is downloaded into it and
# unpacked there. A non-empty directory is assumed to already contain the data
# and is returned untouched.
function tb_datadir(dir = nothing)
    dir = isnothing(dir) ? datadep"TemporalBrains" : dir
    LINK = "http://www-sop.inria.fr/members/Aurora.Rossi/data/LabelledTBN.zip"
    if isempty(readdir(dir))
        DataDeps.fetch_default(LINK, dir)
        # Resolve the archive path *before* changing directory, otherwise a
        # relative `dir` would be interpreted relative to itself after the `cd`.
        archive = abspath(joinpath(dir, "LabelledTBN.zip"))
        # `unpack` extracts into the working directory. The function form of
        # `cd` restores the previous working directory even if unpacking throws.
        cd(dir) do
            DataDeps.unpack(archive)
        end
    end
    @assert isdir(dir)
    return dir
end


# Build one `TemporalSnapshotsGraph` per line of `LabelledTBN/labels.txt`.
#
# Arguments:
# - `dir`:  directory containing the unpacked `LabelledTBN` folder.
# - `thre`: edges whose weight (4th column of the network file) is not strictly
#           greater than this value are dropped.
#
# Each label line is split into `id`, `gender` and `age`; `id` selects the
# network file. The first three columns of that file are converted to integers
# and passed as `edge_index` (the third one is matched against the snapshot
# index below); the fourth column is the edge weight.
#
# Returns a `Vector{MLDatasets.TemporalSnapshotsGraph}` with one entry per label
# line, in file order.
function create_tbdataset(dir, thre)
    num_snapshots = 27
    num_nodes = 102
    name_filelabels = joinpath(dir, "LabelledTBN", "labels.txt")

    # Grow the result with `push!` so its length always matches the number of
    # label lines actually read; 1000 is only a capacity hint.
    temporalgraphs = Vector{MLDatasets.TemporalSnapshotsGraph}()
    sizehint!(temporalgraphs, 1000)

    # The `do` form guarantees the labels file is closed, even on error.
    open(name_filelabels, "r") do filelabels
        for line in eachline(filelabels)
            id, gender, age = split(line)
            name_network_file = joinpath(dir, "LabelledTBN", "networks",
                                         id * "_ws60_wo30_tuk0_pearson_schaefer_100.txt")

            data = readdlm(name_network_file, ',', Float32; skipstart = 1)

            # Keep only the rows whose weight exceeds the threshold.
            data_thre = view(data, view(data, :, 4) .> thre, :)
            data_thre_int = Int.(view(data_thre, :, 1:3))

            # activation[t][n] is the mean weight of the retained rows whose
            # first column is `n` and third column is `t`.
            # NOTE(review): when no row matches, `mean` of an empty selection
            # is presumably NaN -- confirm this is the intended feature value.
            activation = [zeros(Float32, num_nodes) for _ in 1:num_snapshots]
            for t in 1:num_snapshots
                for n in 1:num_nodes
                    rows = ((view(data_thre_int, :, 1) .== n) .&
                            (view(data_thre_int, :, 3) .== t))
                    activation[t][n] = mean(view(data_thre, rows, 4))
                end
            end

            graph = TemporalSnapshotsGraph(num_nodes = fill(num_nodes, num_snapshots),
                                           edge_index = (data_thre_int[:, 1],
                                                         data_thre_int[:, 2],
                                                         data_thre_int[:, 3]),
                                           node_data = activation,
                                           graph_data = (g = gender, a = age))
            push!(temporalgraphs, graph)
        end
    end
    return temporalgraphs
end

"""
TemporalBrains(; dir = nothing, threshold_value = 0.6)
The TemporalBrains dataset contains a collection of temporal brain networks (as `TemporalSnapshotsGraph`s) of 1000 subjects obtained from resting-state fMRI data from the [Human Connectome Project (HCP)](https://www.humanconnectome.org/study/hcp-young-adult/document/extensively-processed-fmri-data-documentation).
The number of nodes is fixed for each of the 27 snapshots at 102, while the edges change over time.
For each `Graph` snapshot, the feature of a node represents the average activation of the node during that snapshot and it is contained in `Graphs.node_data`.
Each `TemporalSnapshotsGraph` has a label representing their gender ("M" for male and "F" for female) and age range (22-25, 26-30, 31-35 and 36+) contained as a named tuple in `graph_data`.
The `threshold_value` is used to binarize the edge weights and is set to 0.6 by default.
"""
struct TemporalBrains <: AbstractDataset
graphs::Vector{MLDatasets.TemporalSnapshotsGraph}
end

# Keyword constructor: make sure the default data directory exists, resolve
# (and, if needed, populate) the data directory, then build the graphs from it
# using `threshold_value` as the edge-weight threshold.
function TemporalBrains(; threshold_value = 0.6, dir = nothing)
    create_default_dir("TemporalBrains")
    datadir = tb_datadir(dir)
    return TemporalBrains(create_tbdataset(datadir, threshold_value))
end

# Number of temporal graphs (one per subject) in the dataset.
Base.length(d::TemporalBrains) = length(d.graphs)
# `d[:]` returns every temporal graph, consistent with `length(d)` and with
# integer/range indexing below (previously it returned only the first graph).
Base.getindex(d::TemporalBrains, ::Colon) = d.graphs
# Any other index is forwarded to the underlying vector of graphs.
Base.getindex(d::TemporalBrains, i) = getindex(d.graphs, i)
13 changes: 13 additions & 0 deletions test/datasets/graphs_no_ci.jl
Original file line number Diff line number Diff line change
Expand Up @@ -363,4 +363,17 @@ end
@test g.num_nodes == 325
@test g.num_edges == 2694
@test all(g.node_data.features[1][:,:,1][2:end,1] == g.node_data.targets[1][:,:,1][1:end-1])
end

@testset "TemporalBrains" begin
    # Loading with the defaults downloads the data on first use.
    dataset = TemporalBrains()
    @test dataset isa AbstractDataset
    @test length(dataset) == 1000

    # Each element is a temporal graph with 27 snapshots of 102 nodes.
    tg = dataset[1]
    @test tg isa MLDatasets.TemporalSnapshotsGraph
    @test tg.num_nodes == fill(102, 27)
    @test tg.num_snapshots == 27

    # Every snapshot is a plain graph carrying one feature per node.
    first_snapshot = tg.snapshots[1]
    @test first_snapshot isa MLDatasets.Graph
    @test length(first_snapshot.node_data) == 102
end

0 comments on commit c4188ab

Please sign in to comment.