Skip to content

Commit

Permalink
Merge pull request #10 from ilabcode/fitting
Browse files Browse the repository at this point in the history
merge fitting branch to master
  • Loading branch information
samuelnehrer02 authored Oct 28, 2024
2 parents 5da4fe9 + 38201ca commit 400fb0c
Show file tree
Hide file tree
Showing 19 changed files with 537 additions and 200 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/CI_small.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ on:
push:
branches:
- dev
- fitting
tags: ['*']
pull_request:
branches:
- dev
- fitting
workflow_dispatch:
concurrency:
# Skip intermediate builds: always.
Expand Down
10 changes: 7 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
name = "ActiveInference"
uuid = "688b0e7a-0122-4325-8669-5ff08899a59e"
authors = ["Jonathan Ehrenreich Laursen", "Samuel William Nehrer"]
version = "0.0.4"
version = "0.1.0"

[deps]
ActionModels = "320cf53b-cc3b-4b34-9a10-0ecb113566a3"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"

[compat]
julia = "1.10"
ActionModels = "0.5"
ActionModels = "0.6"
Distributions = "0.25"
IterTools = "1.10"
LinearAlgebra = "1"
LogExpFunctions = "0.3"
Random = "1"
ReverseDiff = "1.15"
julia = "1.10"
9 changes: 8 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,21 @@
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)

## Julia Package for Active Inference.
ActiveInference.jl is a new Julia package for the computational modeling of active inference. We provide the necessary infrastructure for defining active inference models, currently implemented as partially observable Markov decision processes. After defining the generative model, you can simulate actions using agent-based simulations. We also provide the functionality to fit experimental data to active inference models for parameter recovery.

![Maze Animation](.github/animation_maze.gif)
* Example visualization of an agent navigating a maze, inspired by the one described in [Bruineberg et al., 2018](https://www.sciencedirect.com/science/article/pii/S0022519318303151?via%3Dihub).
Left: A synthetic agent wants to reach the end of the maze environment while avoiding dark-colored locations.
Right: The agent's noisy prior expectations about the state of the environment parameterized by Dirichlet distributions are updated dynamically as it moves through the maze.

## News
#### Version 0.1.0 - October 2024
- All parameters are now recoverable using ReverseDiff
- Now fully compatible with ActionModels.jl v0.6.6

#### Coming Soon
- **Documentation**: Full documentation will be available within the next few weeks, including examples and tutorials.

## Installation
Install ActiveInference.jl using the Julia package manager:
````@example Introduction
Expand Down
2 changes: 1 addition & 1 deletion src/ActionModelsExtensions/get_history.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
This module extends the "get_history" functionality of the ActionModels package to work specifically with instances of the AIF type.
This extends the "get_history" function of the ActionModels package to work specifically with instances of the AIF type.
get_history(aif::AIF, target_states::Vector{String})
Retrieves a history for multiple states of an AIF agent.
Expand Down
3 changes: 1 addition & 2 deletions src/ActionModelsExtensions/get_parameters.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""
This module extends the "get_parameters" functionality of the ActionModels package to work specifically with instances of the AIF type.
This extends the "get_parameters" function of the ActionModels package to work specifically with instances of the AIF type.
get_parameters(aif::AIF, target_parameters::Vector{String})
Retrieves multiple target parameters from an AIF agent.
Expand Down
2 changes: 1 addition & 1 deletion src/ActionModelsExtensions/get_states.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
This module extends the "get_states" functionality of the ActionModels package to work specifically with instances of the AIF type.
This extends the "get_states" function of the ActionModels package to work specifically with instances of the AIF type.
get_states(aif::AIF, target_states::Vector{String})
Retrieves multiple states from an AIF agent.
Expand Down
6 changes: 3 additions & 3 deletions src/ActionModelsExtensions/give_inputs.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
This is an experimental module to extend the give_inputs! functionality of ActionsModels.jl to work with instances of the AIF type.
This extends the give_inputs! function of ActionModels.jl to work with instances of the AIF type.
single_input!(aif::AIF, obs)
Give a single observation to an AIF agent.
Expand All @@ -22,7 +22,7 @@ function ActionModels.single_input!(aif::AIF, obs::Vector)
# if there is only one factor
if num_factors == 1
# Sample action from the action distribution
action = rand(action_distributions[1])
action = rand(action_distributions)

# If the agent has not taken any actions yet
if isempty(aif.action)
Expand All @@ -45,7 +45,7 @@ function ActionModels.single_input!(aif::AIF, obs::Vector)
end
# If the agent has not taken any actions yet
if isempty(aif.action)
push!(aif.action, sampled_actions)
aif.action = sampled_actions
else
# Put the action in the last element of the action vector
aif.action[end] = sampled_actions
Expand Down
8 changes: 4 additions & 4 deletions src/ActionModelsExtensions/reset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ using ActionModels

function ActionModels.reset!(aif::AIF)
# Reset the agent's state fields to initial conditions
aif.qs_current = array_of_any_uniform([size(aif.B[f], 1) for f in eachindex(aif.B)])
aif.qs_current = create_matrix_templates([size(aif.B[f], 1) for f in eachindex(aif.B)])
aif.prior = aif.D
aif.Q_pi = ones(Real,length(aif.policies)) / length(aif.policies)
aif.G = zeros(Real,length(aif.policies))
aif.action = Real[]
aif.Q_pi = ones(length(aif.policies)) / length(aif.policies)
aif.G = zeros(length(aif.policies))
aif.action = Int[]

# Clear the history in the states dictionary
for key in keys(aif.states)
Expand Down
8 changes: 4 additions & 4 deletions src/ActionModelsExtensions/set_parameters.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
"""
This module extends the "set_parameters!" functionality of the ActionModels package to work with instances of the AIF type.
This extends the "set_parameters!" function of the ActionModels package to work with instances of the AIF type.
set_parameters!(aif::AIF, target_param::String, param_value::Any)
set_parameters!(aif::AIF, target_param::String, param_value::Real)
Set a single parameter in the AIF agent
set_parameters!(aif::AIF, parameters::Dict{String, Any})
set_parameters!(aif::AIF, parameters::Dict{String, Real})
Set multiple parameters in the AIF agent
"""

using ActionModels

# Setting a single parameter
function ActionModels.set_parameters!(aif::AIF, target_param::String, param_value::Any)
function ActionModels.set_parameters!(aif::AIF, target_param::String, param_value::Real)
# Update the parameters dictionary
aif.parameters[target_param] = param_value

Expand Down
4 changes: 2 additions & 2 deletions src/ActiveInference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ using LinearAlgebra
using IterTools
using Random
using Distributions
using LogExpFunctions
using ReverseDiff

include("utils/maths.jl")
include("pomdp/struct.jl")
Expand All @@ -31,9 +33,7 @@ export # utils/create_matrix_templates.jl
normalize_arrays,

# utils/utils.jl
array_of_any,
array_of_any_zeros,
array_of_any_uniform,
onehot,
get_model_dimensions,

Expand Down
18 changes: 14 additions & 4 deletions src/Environments/EpistChainEnv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ mutable struct EpistChainEnv
end

function step!(env::EpistChainEnv, action_label::String)
# Get current location
y, x = env.current_loc
next_y, next_x = y, x

# Update location based on action
if action_label == "DOWN"
next_y = y < env.len_y ? y + 1 : y
elseif action_label == "UP"
Expand All @@ -27,38 +29,45 @@ function step!(env::EpistChainEnv, action_label::String)
next_x = x > 1 ? x - 1 : x
elseif action_label == "RIGHT"
next_x = x < env.len_x ? x + 1 : x
elseif action_label == "STAY"
elseif action_label == "STAY"
# No change in location
end

# Set new location
env.current_loc = (next_y, next_x)

# Observations
loc_obs = env.current_loc
cue2_names = ["Null", "reward_on_top", "reward_on_bottom"]
cue2_loc_names = ["L1","L2","L3","L4"]
cue2_locs = [(1, 3), (2, 4), (4, 4), (5, 3)]

# Map cue2 location names to indices
cue2_loc_idx = Dict(cue2_loc_names[1] => 1, cue2_loc_names[2] => 2, cue2_loc_names[3] => 3, cue2_loc_names[4] => 4)

# Get cue2 location
cue2_loc = cue2_locs[cue2_loc_idx[env.cue2]]
#-------------------------------------------------------

# Determine cue1 observation
if env.current_loc == env.cue1_loc
cue1_obs = env.cue2
else
cue1_obs = "Null"
end

# Reward conditions and locations
reward_conditions = ["TOP", "BOTTOM"]
reward_locations = [(2,6), (4,6)]
rew_cond_idx = Dict(reward_conditions[1] => 1, reward_conditions[2] => 2)
#-------------------------------------------------------

# Determine cue2 observation
if env.current_loc == cue2_loc
cue2_obs = cue2_names[rew_cond_idx[env.reward_condition] + 1]
else
cue2_obs = "Null"
end

# Determine reward observation
if env.current_loc == reward_locations[1]
if env.reward_condition == "TOP"
reward_obs = "Cheese"
Expand All @@ -75,11 +84,12 @@ function step!(env::EpistChainEnv, action_label::String)
reward_obs = "Null"
end


# Return observations
return loc_obs, cue1_obs, cue2_obs, reward_obs
end

function reset_env!(env::EpistChainEnv)
# Reset environment to initial location
env.current_loc = env.init_loc
println("Re-initialized location to $(env.init_loc)")
return env.current_loc
Expand Down
4 changes: 2 additions & 2 deletions src/Environments/TMazeEnv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ function reset_TMaze!(env::TMazeEnv; state=nothing)
env.reward_condition = onehot(env._reward_condition_idx, env.num_reward_conditions)

# Initialize the full state array
full_state = array_of_any(env.num_factors)
full_state = Vector{Any}(undef, env.num_factors)
full_state[env.location_factor_id] = loc_state
full_state[env.trial_factor_id] = env.reward_condition

Expand Down Expand Up @@ -182,7 +182,7 @@ end

function construct_state(env::TMazeEnv, state_tuple)
# Create an array of any
state = array_of_any(env.num_factors)
state = Vector{Any}(undef, env.num_factors)

# Populate the state array with one-hot encoded vectors
for (f, ns) in enumerate(env.num_states)
Expand Down
Loading

0 comments on commit 400fb0c

Please sign in to comment.