From dc2070527ae156d1e6d52fdab3277bf084fbfcac Mon Sep 17 00:00:00 2001 From: Samuel William Nehrer Date: Fri, 30 Aug 2024 16:41:38 +0200 Subject: [PATCH 01/25] Add LogExpFunction softmax --- Project.toml | 4 +++- src/ActiveInference.jl | 1 + src/pomdp/POMDP.jl | 4 ++-- src/pomdp/inference.jl | 8 ++++---- src/utils/maths.jl | 22 ++-------------------- src/utils/utils.jl | 8 ++++---- 6 files changed, 16 insertions(+), 31 deletions(-) diff --git a/Project.toml b/Project.toml index 796a568..ac3eef5 100644 --- a/Project.toml +++ b/Project.toml @@ -8,12 +8,14 @@ ActionModels = "320cf53b-cc3b-4b34-9a10-0ecb113566a3" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] -julia = "1.10" ActionModels = "0.5" Distributions = "0.25" IterTools = "1.10" LinearAlgebra = "1" +LogExpFunctions = "0.3" Random = "1" +julia = "1.10" diff --git a/src/ActiveInference.jl b/src/ActiveInference.jl index 5b4d185..c6b843a 100644 --- a/src/ActiveInference.jl +++ b/src/ActiveInference.jl @@ -5,6 +5,7 @@ using LinearAlgebra using IterTools using Random using Distributions +using LogExpFunctions include("utils/maths.jl") include("pomdp/struct.jl") diff --git a/src/pomdp/POMDP.jl b/src/pomdp/POMDP.jl index 9935074..07cab84 100644 --- a/src/pomdp/POMDP.jl +++ b/src/pomdp/POMDP.jl @@ -45,7 +45,7 @@ function action_pomdp!(agent::Agent, obs::Vector{Int64}) ### Pass action marginals through softmax function to get action probabilities for factor in 1:n_factors - action_p[factor] = softmax(log_action_marginals[factor] * alpha) + action_p[factor] = softmax(log_action_marginals[factor] * alpha, dims=1) action_distribution[factor] = Distributions.Categorical(action_p[factor]) end @@ -75,7 +75,7 @@ function action_pomdp!(aif::AIF, obs::Vector{Int64}) ### Pass action marginals through softmax function to get action probabilities for factor in 1:n_factors - action_p[factor] = softmax(log_action_marginals[factor] * alpha) + action_p[factor] = softmax(log_action_marginals[factor] * alpha, dims=1) action_distribution[factor] = Distributions.Categorical(action_p[factor]) end diff --git a/src/pomdp/inference.jl b/src/pomdp/inference.jl index 2822b74..6de9329 100644 --- a/src/pomdp/inference.jl +++ b/src/pomdp/inference.jl @@ -113,7 +113,7 @@ function fixed_point_iteration(A::Vector{Array{<:Real}}, obs::Vector{Vector{Real # Single factor condition if n_factors == 1 qL = dot_product(likelihood, qs[1]) - return [softmax(qL .+ prior[1])] + return [softmax(qL .+ prior[1], dims=1)] else # Run Iteration curr_iter = 0 @@ -129,7 +129,7 @@ function fixed_point_iteration(A::Vector{Array{<:Real}}, obs::Vector{Vector{Real for i in 1:size(qs[factor], 1) qL[i] = sum([LL_tensor[indices...] / qs[factor][i] for indices in Iterators.product([1:size(LL_tensor, dim) for dim in 1:n_factors]...) 
if indices[factor] == i]) end - qs[factor] = softmax(qL + prior[factor]) + qs[factor] = softmax(qL + prior[factor], dims=1) end # Recompute free energy @@ -238,7 +238,7 @@ function update_posterior_policies( end - q_pi = softmax(G * gamma + lnE) + q_pi = softmax(G * gamma + lnE, dims=1) return q_pi, G end @@ -371,7 +371,7 @@ function sample_action(q_pi, policies, num_controls; action_selection="stochasti selected_policy[factor_i] = select_highest(action_marginals[factor_i]) elseif action_selection == "stochastic" log_marginal_f = capped_log(action_marginals[factor_i]) - p_actions = softmax(log_marginal_f * alpha) + p_actions = softmax(log_marginal_f * alpha, dims=1) selected_policy[factor_i] = action_select(p_actions) end end diff --git a/src/utils/maths.jl b/src/utils/maths.jl index a1b16e7..1ef1e59 100644 --- a/src/utils/maths.jl +++ b/src/utils/maths.jl @@ -4,25 +4,6 @@ function normalize_distribution(distribution) end -"""Sampling Function""" -function sample_category(probabilities) - rand_num = rand() - cum_probabilities = cumsum(probabilities) - category = findfirst(x -> x > rand_num, cum_probabilities) - return category -end - -"""Softmax Function""" -function softmax(dist) - - output = dist .- maximum(dist, dims = 1) - output = exp.(output) - output = output ./ sum(output, dims = 1) - - return output -end - - """ capped_log(x::Real) @@ -103,7 +84,8 @@ function softmax_array(arr) # Iterate through each index in arr and apply softmax for idx in eachindex(arr) - output[idx] = softmax(arr[idx]) + output[idx] = softmax(arr[idx], dims=1) + end return output diff --git a/src/utils/utils.jl b/src/utils/utils.jl index b8422ac..8f14702 100644 --- a/src/utils/utils.jl +++ b/src/utils/utils.jl @@ -127,17 +127,17 @@ function check_probability_distribution(Array::Vector{<:Array{T, N}}) where {T<: end """ -Check if the array of real number arrays is a proper probability distribution. +Check if the vector of vectors is a proper probability distribution. # Arguments -- (Array::Vector{Array{T}}) where T<:Real +- (Array::Vector{Vector{T}}) where T<:Real Throws an error if the array is not a valid probability distribution: - The values must be non-negative. - The sum of the values must be approximately 1. """ -function check_probability_distribution(Array::Vector{Array{T}}) where T<:Real +function check_probability_distribution(Array::Vector{Vector{T}}) where T<:Real for vector in Array # Check for non-negativity if any(vector .< 0) @@ -145,7 +145,7 @@ function check_probability_distribution(Array::Vector{Array{T}}) where T<:Real end # Check for normalization - if !all(isapprox.(sum(vector), 1.0, rtol=1e-5, atol=1e-8)) + if !all(isapprox.(sum(vector, dims=1), 1.0, rtol=1e-5, atol=1e-8)) throw(ArgumentError("The array is not normalized.")) end end From 1e92872bd53dd839826603c847347d76f7cd876a Mon Sep 17 00:00:00 2001 From: Samuel William Nehrer Date: Sat, 31 Aug 2024 16:09:42 +0200 Subject: [PATCH 02/25] type correction --- src/utils/utils.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/utils/utils.jl b/src/utils/utils.jl index 8f14702..11ce88a 100644 --- a/src/utils/utils.jl +++ b/src/utils/utils.jl @@ -100,17 +100,17 @@ function get_log_action_marginals(aif) end """ -Check if the array of arrays is a proper probability distribution. +Check if the vector of arrays is a proper probability distribution. 
# Arguments -- (Array::Vector{<:Array{T, N}}) where {T<:Real, N} +- (Array::Vector{<:Array{T}}) where T<:Real Throws an error if the array is not a valid probability distribution: - The values must be non-negative. - The sum of the values must be approximately 1. """ -function check_probability_distribution(Array::Vector{<:Array{T, N}}) where {T<:Real, N} +function check_probability_distribution(Array::Vector{<:Array{T}}) where T<:Real for tensor in Array # Check for non-negativity if any(tensor .< 0) From 911ceffefdcb6c2732f9a0db132b6225611bd687 Mon Sep 17 00:00:00 2001 From: unknown Date: Mon, 2 Sep 2024 16:16:44 +0200 Subject: [PATCH 03/25] Fix verbose function --- src/pomdp/struct.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/pomdp/struct.jl b/src/pomdp/struct.jl index 3db58e0..44e171c 100644 --- a/src/pomdp/struct.jl +++ b/src/pomdp/struct.jl @@ -281,8 +281,12 @@ function init_aif(A, B; C=nothing, D=nothing, E=nothing, pA=nothing, pB=nothing, end # Check if parameters are provided or use defaults - if verbose == true && isnothing(parameters) - @warn "No parameters provided, default parameters will be used." + if isnothing(parameters) + + if verbose == true + @warn "No parameters provided, default parameters will be used." + end + parameters = Dict("gamma" => 16.0, "alpha" => 16.0, "lr_pA" => 1.0, From 5754037db76b006167bbfe0433859635292bead9 Mon Sep 17 00:00:00 2001 From: Jonathan7773 Date: Thu, 5 Sep 2024 11:04:24 +0200 Subject: [PATCH 04/25] Fix to Verbose setttings part --- src/pomdp/struct.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/pomdp/struct.jl b/src/pomdp/struct.jl index 44e171c..a9f0460 100644 --- a/src/pomdp/struct.jl +++ b/src/pomdp/struct.jl @@ -263,8 +263,12 @@ function init_aif(A, B; C=nothing, D=nothing, E=nothing, pA=nothing, pB=nothing, end # Check if settings are provided or use defaults - if verbose == true && isnothing(settings) - @warn "No settings provided, default settings will be used." + if isnothing(settings) + + if verbose == true + @warn "No settings provided, default settings will be used." 
+ end + settings = Dict( "policy_len" => 1, "num_controls" => nothing, From 8fff040477af8ea3fe5e9e6f5331c942851bc091 Mon Sep 17 00:00:00 2001 From: Jonathan7773 Date: Thu, 5 Sep 2024 12:55:38 +0200 Subject: [PATCH 05/25] Redundant function array_of_any() fully removed --- src/ActiveInference.jl | 1 - src/Environments/TMazeEnv.jl | 4 ++-- src/pomdp/POMDP.jl | 4 ++-- src/pomdp/inference.jl | 6 +++--- src/utils/utils.jl | 7 +------ 5 files changed, 8 insertions(+), 14 deletions(-) diff --git a/src/ActiveInference.jl b/src/ActiveInference.jl index c6b843a..a87e2e2 100644 --- a/src/ActiveInference.jl +++ b/src/ActiveInference.jl @@ -32,7 +32,6 @@ export # utils/create_matrix_templates.jl normalize_arrays, # utils/utils.jl - array_of_any, array_of_any_zeros, array_of_any_uniform, onehot, diff --git a/src/Environments/TMazeEnv.jl b/src/Environments/TMazeEnv.jl index 7b97bed..bb63fbd 100644 --- a/src/Environments/TMazeEnv.jl +++ b/src/Environments/TMazeEnv.jl @@ -85,7 +85,7 @@ function reset_TMaze!(env::TMazeEnv; state=nothing) env.reward_condition = onehot(env._reward_condition_idx, env.num_reward_conditions) # Initialize the full state array - full_state = array_of_any(env.num_factors) + full_state = Vector{Any}(undef, env.num_factors) full_state[env.location_factor_id] = loc_state full_state[env.trial_factor_id] = env.reward_condition @@ -182,7 +182,7 @@ end function construct_state(env::TMazeEnv, state_tuple) # Create an array of any - state = array_of_any(env.num_factors) + state = Vector{Any}(undef, env.num_factors) # Populate the state array with one-hot encoded vectors for (f, ns) in enumerate(env.num_states) diff --git a/src/pomdp/POMDP.jl b/src/pomdp/POMDP.jl index 07cab84..39e36ff 100644 --- a/src/pomdp/POMDP.jl +++ b/src/pomdp/POMDP.jl @@ -14,7 +14,7 @@ function action_pomdp!(agent::Agent, obs::Vector{Int64}) n_factors = length(aif.settings["num_controls"]) # Initialize empty arrays for action distribution per factor - action_p = array_of_any(n_factors) + action_p = Vector{Any}(undef, n_factors) action_distribution = Vector(undef, n_factors) #If there was a previous action @@ -59,7 +59,7 @@ function action_pomdp!(aif::AIF, obs::Vector{Int64}) n_factors = length(aif.settings["num_controls"]) # Initialize an empty arrays for action distribution per factor - action_p = array_of_any(n_factors) + action_p = Vector{Any}(undef, n_factors) action_distribution = Vector(undef, n_factors) ### Infer states & policies diff --git a/src/pomdp/inference.jl b/src/pomdp/inference.jl index 6de9329..7ad9c74 100644 --- a/src/pomdp/inference.jl +++ b/src/pomdp/inference.jl @@ -248,7 +248,7 @@ function get_expected_obs(qs_pi, A::Vector{Array{<:Real}}) qo_pi = [] for t in 1:n_steps - qo_pi_t = array_of_any(length(A)) + qo_pi_t = Vector{Any}(undef, length(A)) qo_pi = push!(qo_pi, qo_pi_t) end @@ -305,7 +305,7 @@ function calc_pA_info_gain(pA, qo_pi, qs_pi) n_steps = length(qo_pi) num_modalities = length(pA) - wA = array_of_any(num_modalities) + wA = Vector{Any}(undef, num_modalities) for (modality, pA_m) in enumerate(pA) wA[modality] = spm_wnorm(pA[modality]) end @@ -327,7 +327,7 @@ function calc_pB_info_gain(pB, qs_pi, qs_prev, policy) n_steps = length(qs_pi) num_factors = length(pB) - wB = array_of_any(num_factors) + wB = Vector{Any}(undef, num_factors) for (factor, pB_f) in enumerate(pB) wB[factor] = spm_wnorm(pB_f) end diff --git a/src/utils/utils.jl b/src/utils/utils.jl index 11ce88a..5c4d174 100644 --- a/src/utils/utils.jl +++ b/src/utils/utils.jl @@ -1,10 +1,5 @@ """ -------- Utility Functions -------- 
""" -""" Creates an array of "Any" with the desired number of sub-arrays""" -function array_of_any(num_arr::Int) - return Array{Any}(undef, num_arr) #saves it as {Any} e.g. can be any kind of data type. -end - """ Creates an array of "Any" with the desired number of sub-arrays filled with zeros""" function array_of_any_zeros(shape_list) arr = Array{Any}(undef, length(shape_list)) @@ -79,7 +74,7 @@ end function get_log_action_marginals(aif) num_factors = length(aif.num_controls) action_marginals = create_matrix_templates(aif.num_controls, "zeros") - log_action_marginals = array_of_any(num_factors) + log_action_marginals = Vector{Any}(undef, num_factors) q_pi = get_states(aif, "posterior_policies") policies = get_states(aif, "policies") From 596bf9c8b6b7a332c9dcb699a61b753ececf0292 Mon Sep 17 00:00:00 2001 From: Samuel William Nehrer Date: Thu, 19 Sep 2024 17:15:11 +0200 Subject: [PATCH 06/25] Adjustments for TrackedReal types --- src/ActionModelsExtensions/set_parameters.jl | 2 +- src/pomdp/POMDP.jl | 54 +++++++++++++++++++- src/pomdp/inference.jl | 6 ++- src/utils/create_matrix_templates.jl | 8 +-- src/utils/maths.jl | 16 +++--- src/utils/utils.jl | 9 +++- 6 files changed, 78 insertions(+), 17 deletions(-) diff --git a/src/ActionModelsExtensions/set_parameters.jl b/src/ActionModelsExtensions/set_parameters.jl index cff5ef6..2a02310 100644 --- a/src/ActionModelsExtensions/set_parameters.jl +++ b/src/ActionModelsExtensions/set_parameters.jl @@ -12,7 +12,7 @@ Set multiple parameters in the AIF agent using ActionModels # Setting a single parameter -function ActionModels.set_parameters!(aif::AIF, target_param::String, param_value::Any) +function ActionModels.set_parameters!(aif::AIF, target_param::String, param_value::Real) # Update the parameters dictionary aif.parameters[target_param] = param_value diff --git a/src/pomdp/POMDP.jl b/src/pomdp/POMDP.jl index 39e36ff..59ec1ea 100644 --- a/src/pomdp/POMDP.jl +++ b/src/pomdp/POMDP.jl @@ -10,7 +10,59 @@ function action_pomdp!(agent::Agent, obs::Vector{Int64}) aif = agent.substruct ### Get parameters - alpha = agent.parameters["alpha"] + alpha = agent.substruct.parameters["alpha"] + n_factors = length(aif.settings["num_controls"]) + + # Initialize empty arrays for action distribution per factor + action_p = Vector{Any}(undef, n_factors) + action_distribution = Vector(undef, n_factors) + + #If there was a previous action + if !ismissing(agent.states["action"]) + + #Extract it + previous_action = agent.states["action"] + + #If it is not a vector, make it one + if !(previous_action isa Vector) + previous_action = [previous_action] + end + + #Store the action in the AIF substruct + agent.substruct.action = previous_action + end + + ### Infer states & policies + + # Run state inference + infer_states!(aif, obs) + + # Run policy inference + infer_policies!(aif) + + ### Retrieve log marginal probabilities of actions + log_action_marginals = get_log_action_marginals(aif) + + ### Pass action marginals through softmax function to get action probabilities + for factor in 1:n_factors + action_p[factor] = softmax(log_action_marginals[factor] * alpha, dims=1) + action_distribution[factor] = Distributions.Categorical(action_p[factor]) + end + + return n_factors == 1 ? 
action_distribution[1] : action_distribution +end + +### Action Model where the observation is a tuple + +function action_pomdp!(agent::Agent, obs::Tuple{Vararg{Int}}) + + aif = agent.substruct + + # convert observation to vector + obs = collect(obs) + + ### Get parameters + alpha = agent.substruct.parameters["alpha"] n_factors = length(aif.settings["num_controls"]) # Initialize empty arrays for action distribution per factor diff --git a/src/pomdp/inference.jl b/src/pomdp/inference.jl index 7ad9c74..b8c17f3 100644 --- a/src/pomdp/inference.jl +++ b/src/pomdp/inference.jl @@ -355,9 +355,13 @@ end """ Sample Action [Stochastic or Deterministic] """ function sample_action(q_pi, policies, num_controls; action_selection="stochastic", alpha=16.0) num_factors = length(num_controls) - action_marginals = create_matrix_templates(num_controls, "zeros") selected_policy = zeros(Real,num_factors) + eltype_q_pi = eltype(q_pi) + + # Initialize action_marginals with the correct element type + action_marginals = create_matrix_templates(num_controls, "zeros", eltype_q_pi) + for (pol_idx, policy) in enumerate(policies) for (factor_i, action_i) in enumerate(policy[1,:]) action_marginals[factor_i][action_i] += q_pi[pol_idx] diff --git a/src/utils/create_matrix_templates.jl b/src/utils/create_matrix_templates.jl index 5c0282b..256a189 100644 --- a/src/utils/create_matrix_templates.jl +++ b/src/utils/create_matrix_templates.jl @@ -119,19 +119,19 @@ Creates templates based on the specified shapes vector and template type. Templa """ -function create_matrix_templates(shapes::Vector{Int64}, template_type::String) +function create_matrix_templates(shapes::Vector{Int64}, template_type::String, eltype::Type=Float64) if template_type == "uniform" # Create arrays filled with ones and then normalize - return [normalize_distribution(ones(n)) for n in shapes] + return [normalize_distribution(ones(eltype, n)) for n in shapes] elseif template_type == "random" # Create arrays filled with random values - return [normalize_distribution(rand(n)) for n in shapes] + return [normalize_distribution(rand(eltype, n)) for n in shapes] elseif template_type == "zeros" # Create arrays filled with zeros - return [zeros(n) for n in shapes] + return [zeros(eltype, n) for n in shapes] else # Throw error for invalid template type diff --git a/src/utils/maths.jl b/src/utils/maths.jl index 1ef1e59..5a76fe8 100644 --- a/src/utils/maths.jl +++ b/src/utils/maths.jl @@ -25,10 +25,10 @@ end Return the natural logarithm of x, capped at the machine epsilon value of x. """ function capped_log(array::AbstractArray{T}) where T <: Real - # convert Reals to Floats - x_float = float.(array) - # return the capped log of each element in x - return log.(max.(x_float, eps(one(eltype(x_float))))) + epsilon = oftype(array[1], 1e-16) + + # Return the log of the array values capped at epsilon + return log.(max.(array, epsilon)) end ### This method will be deprecated once all types in the package have been made more strict. @@ -41,10 +41,10 @@ end Return the natural logarithm of x, capped at the machine epsilon value of x. 
""" function capped_log(array::Array{Any}) - # convert Reals to Floats - x_float = float.(array) - # return the capped log of each element in x - return log.(max.(x_float, eps(one(eltype(x_float))))) + epsilon = oftype(array[1], 1e-16) + + # Return the log of the array values capped at epsilon + return log.(max.(array, epsilon)) end """ Apply capped_log to array of arrays """ diff --git a/src/utils/utils.jl b/src/utils/utils.jl index 5c4d174..9d56091 100644 --- a/src/utils/utils.jl +++ b/src/utils/utils.jl @@ -73,11 +73,16 @@ end """ Function to get log marginal probabilities of actions """ function get_log_action_marginals(aif) num_factors = length(aif.num_controls) - action_marginals = create_matrix_templates(aif.num_controls, "zeros") - log_action_marginals = Vector{Any}(undef, num_factors) q_pi = get_states(aif, "posterior_policies") policies = get_states(aif, "policies") + # Determine the element type from q_pi + eltype_q_pi = eltype(q_pi) + + # Initialize action_marginals with the correct element type + action_marginals = create_matrix_templates(aif.num_controls, "zeros", eltype_q_pi) + log_action_marginals = Vector{Any}(undef, num_factors) + for (pol_idx, policy) in enumerate(policies) for (factor_i, action_i) in enumerate(policy[1,:]) action_marginals[factor_i][action_i] += q_pi[pol_idx] From bc30f38bc67c19c517d6734d2fea4aeee602ead4 Mon Sep 17 00:00:00 2001 From: Jonathan7773 Date: Sat, 21 Sep 2024 10:43:53 +0200 Subject: [PATCH 07/25] Added modified get_expected_states function, and BMA function --- src/pomdp/inference.jl | 44 +++++++++++++++++++++++++++++++++++++++++- src/utils/maths.jl | 42 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 1 deletion(-) diff --git a/src/pomdp/inference.jl b/src/pomdp/inference.jl index 7ad9c74..836db8b 100644 --- a/src/pomdp/inference.jl +++ b/src/pomdp/inference.jl @@ -3,7 +3,7 @@ #### State Inference #### """ Get Expected States """ -function get_expected_states(qs, B, policy) +function get_expected_states(qs, B, policy::Matrix{Int64}) n_steps, n_factors = size(policy) # initializing posterior predictive density as a list of beliefs over time @@ -21,6 +21,48 @@ function get_expected_states(qs, B, policy) return qs_pi[2:end] end +""" + Multiple dispatch for all expected states given all policies + +Multiple dispatch for getting expected states for all policies based on the agents currently +inferred states and the transition matrices for each factor and action in the policy. + +qs: Vector{Any} \n +B: Vector{Array{<:Real}} \n +policy: Vector{Matrix{Int64}} + +""" +function get_expected_states(qs, B, policy::Vector{Matrix{Int64}}) + + # Extracting the number of steps (policy_length) and factors from the first policy + n_steps, n_factors = size(policy[1]) + + # Number of policies + n_policies = length(policy) + + # Preparing vessel for the expected states for all policies. 
Has number of undefined entries equal to the + # number of policies + qs_pi_all = Vector{Any}(undef, n_policies) + + # Looping through all policies + for (policy_idx, policy_x) in enumerate(policy) + + # initializing posterior predictive density as a list of beliefs over time + qs_pi = [deepcopy(qs) for _ in 1:n_steps+1] + + # expected states over time + for t in 1:n_steps + for control_factor in 1:n_factors + action = policy_x[t, control_factor] + + qs_pi[t+1][control_factor] = B[control_factor][:, :, action] * qs_pi[t][control_factor] + end + end + qs_pi_all[policy_idx] = qs_pi[2:end] + end + return qs_pi_all +end + """ process_observation(observation::Int, n_modalities::Int, n_observations::Vector{Int}) diff --git a/src/utils/maths.jl b/src/utils/maths.jl index 1ef1e59..5d2083a 100644 --- a/src/utils/maths.jl +++ b/src/utils/maths.jl @@ -206,4 +206,46 @@ function spm_wnorm(A) wA = norm .- avg return wA +end + +""" + Calculate Bayesian Model Average (BMA) + +Calculates the Bayesian Model Average (BMA) which is used for the State Action Prediction Error (SAPE). +It is a weighted average of the expected states for all policies weighted by the posterior over policies. +The `qs_pi_all` should be the collection of expected states given all policies. Can be retrieved with the +`get_expected_states` function. + +`qs_pi_all`: Vector{Any} \n +`q_pi`: Vector{Float64} + +""" +function bayesian_model_average(qs_pi_all, q_pi) + + # Extracting the number of factors, states, and timesteps (policy length) from the first policy + n_factors = length(qs_pi_all[1][1]) + n_states = [size(qs_f, 1) for qs_f in qs_pi_all[1][1]] + n_steps = length(qs_pi_all[1]) + + # Preparing vessel for the expected states for all policies. Has number of undefined entries equal to the number of + # n_steps with each entry having the entries equal to the number of factors + qs_bma = [Vector{Vector{Real}}(undef, n_factors) for _ in 1:n_steps] + + # Populating the entries with zeros for each state in each factor for each timestep in policy + for i in 1:n_steps + for f in 1:n_factors + qs_bma[i][f] = zeros(Real, n_states[f]) + end + end + + # Populating the entries with the expected states for all policies weighted by the posterior over policies + for i in 1:n_steps + for (pol_idx, policy_weight) in enumerate(q_pi) + for f in 1:n_factors + qs_bma[i][f] .+= policy_weight .* qs_pi_all[pol_idx][i][f] + end + end + end + + return qs_bma end \ No newline at end of file From b89f4927a30ea01cb6dad075485542d2349464c4 Mon Sep 17 00:00:00 2001 From: Jonathan7773 Date: Mon, 23 Sep 2024 12:59:13 +0200 Subject: [PATCH 08/25] Added SAPE calculation and supporting functions --- src/pomdp/inference.jl | 16 +++++++++++++++- src/pomdp/struct.jl | 4 +++- src/utils/maths.jl | 16 +++++++++++++++- 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/src/pomdp/inference.jl b/src/pomdp/inference.jl index 836db8b..8785394 100644 --- a/src/pomdp/inference.jl +++ b/src/pomdp/inference.jl @@ -435,4 +435,18 @@ function compute_accuracy_new(log_likelihood, qs) end return results -end \ No newline at end of file +end + +""" Calculate SAPE """ +function calc_SAPE(aif::AIF) + + qs_pi_all = get_expected_states(aif.qs_current, aif.B, aif.policies) + qs_bma = bayesian_model_average(qs_pi_all, aif.states["posterior_policies"][1]) + + if length(aif.states["bayesian_model_averages"]) != 0 + sape = kl_div(qs_bma, aif.states["bayesian_model_averages"][end]) + push!(aif.states["SAPE"], sape) + end + + push!(aif.states["bayesian_model_averages"], qs_bma) +end diff 
--git a/src/pomdp/struct.jl b/src/pomdp/struct.jl index a9f0460..fa7f500 100644 --- a/src/pomdp/struct.jl +++ b/src/pomdp/struct.jl @@ -113,7 +113,9 @@ function create_aif(A, B; "prior" => Vector{Any}[], "posterior_policies" => Vector{Any}[], "expected_free_energies" => Vector{Any}[], - "policies" => policies + "policies" => policies, + "bayesian_model_averages" => Vector{Vector{<:Real}}[], + "SAPE" => Vector{<:Real}[] ) # initialize parameters dictionary diff --git a/src/utils/maths.jl b/src/utils/maths.jl index 5d2083a..2413058 100644 --- a/src/utils/maths.jl +++ b/src/utils/maths.jl @@ -248,4 +248,18 @@ function bayesian_model_average(qs_pi_all, q_pi) end return qs_bma -end \ No newline at end of file +end + +function kl_div(P::Vector{Vector{Vector{Real}}}, Q::Vector{Vector{Vector{Real}}}) + eps_val=1e-16 + dkl = 0.0 + for j in 1:length(P) + for i in 1:length(P[j]) + dkl += dot(P[j][i], log.(P[j][i] .+ eps_val) .- log.(Q[j][i] .+ eps_val)) + end + end + return dkl +end + + + From 507e9673213c168e5c8ca27fb9e089ae1c68cb8b Mon Sep 17 00:00:00 2001 From: Jonathan7773 Date: Tue, 24 Sep 2024 14:59:41 +0200 Subject: [PATCH 09/25] Fixed calc_SAPE --- src/pomdp/inference.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pomdp/inference.jl b/src/pomdp/inference.jl index ff19af2..5c57659 100644 --- a/src/pomdp/inference.jl +++ b/src/pomdp/inference.jl @@ -445,7 +445,7 @@ end function calc_SAPE(aif::AIF) qs_pi_all = get_expected_states(aif.qs_current, aif.B, aif.policies) - qs_bma = bayesian_model_average(qs_pi_all, aif.states["posterior_policies"][1]) + qs_bma = bayesian_model_average(qs_pi_all, aif.Q_pi) if length(aif.states["bayesian_model_averages"]) != 0 sape = kl_div(qs_bma, aif.states["bayesian_model_averages"][end]) From d81feac712d152746ec3ff1087124d4a114dea0c Mon Sep 17 00:00:00 2001 From: samuelnehrer02 Date: Tue, 24 Sep 2024 16:43:09 +0200 Subject: [PATCH 10/25] Update ActionModels Version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ac3eef5..8f06a1b 100644 --- a/Project.toml +++ b/Project.toml @@ -12,7 +12,7 @@ LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] -ActionModels = "0.5" +ActionModels = "0.6" Distributions = "0.25" IterTools = "1.10" LinearAlgebra = "1" From cc4f06ea5ab2c27be307224e6bf887542bbba395 Mon Sep 17 00:00:00 2001 From: Samuel William Nehrer Date: Mon, 7 Oct 2024 22:04:10 +0200 Subject: [PATCH 11/25] Change types to Reals --- src/pomdp/inference.jl | 12 ++++++------ src/pomdp/struct.jl | 8 ++++---- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/pomdp/inference.jl b/src/pomdp/inference.jl index b8c17f3..e538725 100644 --- a/src/pomdp/inference.jl +++ b/src/pomdp/inference.jl @@ -78,7 +78,7 @@ function process_observation(observation::Union{Array{Int}, Tuple{Vararg{Int}}}, end """ Update Posterior States """ -function update_posterior_states(A::Vector{Array{<:Real}}, obs::Vector{Int64}; prior::Union{Nothing, Vector{Any}}=nothing, num_iter::Int=num_iter, dF_tol::Float64=dF_tol, kwargs...) +function update_posterior_states(A::Vector{Array{Real}}, obs::Vector{Int64}; prior::Union{Nothing, Vector{Any}}=nothing, num_iter::Int=num_iter, dF_tol::Float64=dF_tol, kwargs...) 
num_obs, num_states, num_modalities, num_factors = get_model_dimensions(A) obs_processed = process_observation(obs, num_modalities, num_obs) @@ -87,7 +87,7 @@ end """ Run State Inference via Fixed-Point Iteration """ -function fixed_point_iteration(A::Vector{Array{<:Real}}, obs::Vector{Vector{Real}}, num_obs::Vector{Int64}, num_states::Vector{Int64}; prior::Union{Nothing, Vector{Any}}=nothing, num_iter::Int=num_iter, dF::Float64=1.0, dF_tol::Float64=dF_tol) +function fixed_point_iteration(A::Vector{Array{Real}}, obs::Vector{Vector{Real}}, num_obs::Vector{Int64}, num_states::Vector{Int64}; prior::Union{Nothing, Vector{Any}}=nothing, num_iter::Int=num_iter, dF::Float64=1.0, dF_tol::Float64=dF_tol) n_modalities = length(num_obs) n_factors = length(num_states) @@ -191,9 +191,9 @@ end """ Update Posterior over Policies """ function update_posterior_policies( qs::Vector{Any}, - A::Vector{Array{<:Real}}, - B::Vector{Array{<:Real}}, - C::Vector{Array{<:Real}}, + A::Vector{Array{Real}}, + B::Vector{Array{Real}}, + C::Vector{Array{Real}}, policies::Vector{Matrix{Int64}}, use_utility::Bool=true, use_states_info_gain::Bool=true, @@ -243,7 +243,7 @@ function update_posterior_policies( end """ Get Expected Observations """ -function get_expected_obs(qs_pi, A::Vector{Array{<:Real}}) +function get_expected_obs(qs_pi, A::Vector{Array{Real}}) n_steps = length(qs_pi) qo_pi = [] diff --git a/src/pomdp/struct.jl b/src/pomdp/struct.jl index a9f0460..26003fd 100644 --- a/src/pomdp/struct.jl +++ b/src/pomdp/struct.jl @@ -1,10 +1,10 @@ """ -------- AIF Mutable Struct -------- """ mutable struct AIF - A::Vector{Array{<:Real}} # A-matrix - B::Vector{Array{<:Real}} # B-matrix - C::Vector{Array{<:Real}} # C-vectors - D::Vector{Array{<:Real}} # D-vectors + A::Vector{Array{Real}} # A-matrix + B::Vector{Array{Real}} # B-matrix + C::Vector{Array{Real}} # C-vectors + D::Vector{Array{Real}} # D-vectors E::Union{Vector{<:Real}, Nothing} # E-vector (Habits) pA::Union{Vector{Array{<:Real}}, Nothing} # Dirichlet priors for A-matrix pB::Union{Vector{Array{<:Real}}, Nothing} # Dirichlet priors for B-matrix From 76e632d3010e93bf7afc60484a8c5ae2afeb9683 Mon Sep 17 00:00:00 2001 From: Samuel William Nehrer Date: Mon, 7 Oct 2024 22:47:33 +0200 Subject: [PATCH 12/25] Further type changes to Real only --- src/pomdp/inference.jl | 2 +- src/pomdp/struct.jl | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/pomdp/inference.jl b/src/pomdp/inference.jl index 3effb7d..e4f781e 100644 --- a/src/pomdp/inference.jl +++ b/src/pomdp/inference.jl @@ -242,7 +242,7 @@ function update_posterior_policies( use_param_info_gain::Bool=false, pA = nothing, pB = nothing, - E = nothing, + E::Union{Vector{Real}, Nothing} = nothing, gamma::Real=16.0 ) n_policies = length(policies) diff --git a/src/pomdp/struct.jl b/src/pomdp/struct.jl index 6462578..1ff2937 100644 --- a/src/pomdp/struct.jl +++ b/src/pomdp/struct.jl @@ -5,10 +5,10 @@ mutable struct AIF B::Vector{Array{Real}} # B-matrix C::Vector{Array{Real}} # C-vectors D::Vector{Array{Real}} # D-vectors - E::Union{Vector{<:Real}, Nothing} # E-vector (Habits) - pA::Union{Vector{Array{<:Real}}, Nothing} # Dirichlet priors for A-matrix - pB::Union{Vector{Array{<:Real}}, Nothing} # Dirichlet priors for B-matrix - pD::Union{Vector{Array{<:Real}}, Nothing} # Dirichlet priors for D-vector + E::Union{Vector{Real}, Nothing} # E-vector (Habits) + pA::Union{Vector{Array{Real}}, Nothing} # Dirichlet priors for A-matrix + pB::Union{Vector{Array{Real}}, Nothing} # Dirichlet priors for B-matrix + 
pD::Union{Vector{Array{Real}}, Nothing} # Dirichlet priors for D-vector lr_pA::Real # pA Learning Parameter fr_pA::Real # pA Forgetting Parameter, 1.0 for no forgetting lr_pB::Real # pB learning Parameter From bd62e696ae5e41d38e21a711bccebae8f1e6a3af Mon Sep 17 00:00:00 2001 From: samuelnehrer02 Date: Tue, 8 Oct 2024 13:20:23 +0200 Subject: [PATCH 13/25] Type Changes --- Project.toml | 2 +- src/pomdp/inference.jl | 22 ++++++++++++---------- src/pomdp/learning.jl | 4 ++-- src/pomdp/struct.jl | 11 ++++++----- src/utils/maths.jl | 11 ++++++++++- 5 files changed, 31 insertions(+), 19 deletions(-) diff --git a/Project.toml b/Project.toml index 8f06a1b..ac3eef5 100644 --- a/Project.toml +++ b/Project.toml @@ -12,7 +12,7 @@ LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] -ActionModels = "0.6" +ActionModels = "0.5" Distributions = "0.25" IterTools = "1.10" LinearAlgebra = "1" diff --git a/src/pomdp/inference.jl b/src/pomdp/inference.jl index e4f781e..5407f73 100644 --- a/src/pomdp/inference.jl +++ b/src/pomdp/inference.jl @@ -3,7 +3,7 @@ #### State Inference #### """ Get Expected States """ -function get_expected_states(qs, B, policy::Matrix{Int64}) +function get_expected_states(qs::Vector{Vector{Real}}, B, policy::Matrix{Int64}) n_steps, n_factors = size(policy) # initializing posterior predictive density as a list of beliefs over time @@ -27,12 +27,12 @@ end Multiple dispatch for getting expected states for all policies based on the agents currently inferred states and the transition matrices for each factor and action in the policy. -qs: Vector{Any} \n +qs::Vector{Vector{Real}} \n B: Vector{Array{<:Real}} \n policy: Vector{Matrix{Int64}} """ -function get_expected_states(qs, B, policy::Vector{Matrix{Int64}}) +function get_expected_states(qs::Vector{Vector{Real}}, B, policy::Vector{Matrix{Int64}}) # Extracting the number of steps (policy_length) and factors from the first policy n_steps, n_factors = size(policy[1]) @@ -191,7 +191,7 @@ end """ Calculate Accuracy Term """ -function compute_accuracy(log_likelihood, qs) +function compute_accuracy(log_likelihood, qs::Vector{Vector{Real}}) n_factors = length(qs) ndims_ll = ndims(log_likelihood) dims = (ndims_ll - n_factors + 1) : ndims_ll @@ -207,7 +207,7 @@ end """ Calculate Free Energy """ -function calc_free_energy(qs, prior, n_factors, likelihood=nothing) +function calc_free_energy(qs::Vector{Vector{Real}}, prior, n_factors, likelihood=nothing) # Initialize free energy free_energy = 0.0 @@ -232,7 +232,7 @@ end #### Policy Inference #### """ Update Posterior over Policies """ function update_posterior_policies( - qs::Vector{Any}, + qs::Vector{Vector{Real}}, A::Vector{Array{Real}}, B::Vector{Array{Real}}, C::Vector{Array{Real}}, @@ -247,10 +247,10 @@ function update_posterior_policies( ) n_policies = length(policies) G = zeros(Real,n_policies) - q_pi = zeros(Real,n_policies, 1) + q_pi = Vector{Real}(undef, n_policies) qs_pi = Vector{Real}[] qo_pi = Vector{Real}[] - + if isnothing(E) lnE = capped_log(ones(Real, n_policies) / n_policies) else @@ -280,7 +280,9 @@ function update_posterior_policies( end - q_pi = softmax(G * gamma + lnE, dims=1) + + q_pi .= softmax(G * gamma + lnE, dims=1) + return q_pi, G end @@ -425,7 +427,7 @@ function sample_action(q_pi, policies, num_controls; action_selection="stochasti end """ Edited Compute Accuracy [Still needs to be nested within Fixed-Point Iteration] """ -function compute_accuracy_new(log_likelihood, qs) +function 
compute_accuracy_new(log_likelihood, qs::Vector{Vector{Real}}) n_factors = length(qs) ndims_ll = ndims(log_likelihood) dims = (ndims_ll - n_factors + 1) : ndims_ll diff --git a/src/pomdp/learning.jl b/src/pomdp/learning.jl index e58e1e2..7509fe6 100644 --- a/src/pomdp/learning.jl +++ b/src/pomdp/learning.jl @@ -26,7 +26,7 @@ function update_obs_likelihood_dirichlet(pA, A, obs, qs; lr = 1.0, fr = 1.0, mod end """ Update state likelihood matrix """ -function update_state_likelihood_dirichlet(pB, B, actions, qs, qs_prev; lr = 1.0, fr = 1.0, factors = "all") +function update_state_likelihood_dirichlet(pB, B, actions, qs::Vector{Vector{Real}}, qs_prev; lr = 1.0, fr = 1.0, factors = "all") num_factors = length(pB) @@ -47,7 +47,7 @@ function update_state_likelihood_dirichlet(pB, B, actions, qs, qs_prev; lr = 1.0 end """ Update prior D matrix """ -function update_state_prior_dirichlet(pD, qs; lr = 1.0, fr = 1.0, factors = "all") +function update_state_prior_dirichlet(pD, qs::Vector{Vector{Real}}; lr = 1.0, fr = 1.0, factors = "all") num_factors = length(pD) diff --git a/src/pomdp/struct.jl b/src/pomdp/struct.jl index 1ff2937..aaef515 100644 --- a/src/pomdp/struct.jl +++ b/src/pomdp/struct.jl @@ -23,9 +23,9 @@ mutable struct AIF num_controls::Array{Int,1} # Number of actions per factor control_fac_idx::Array{Int,1} # Indices of controllable factors policy_len::Int # Policy length - qs_current::Array{Any,1} # Current beliefs about states + qs_current::Vector{Vector{Real}} # Current beliefs about states prior::Array{Any,1} # Prior beliefs about states - Q_pi::Array{Real,1} # Posterior beliefs over policies + Q_pi::Vector{Real} # Posterior beliefs over policies G::Array{Real,1} # Expected free energy of policy action::Vector{Any} # Last action use_utility::Bool # Utility Boolean Flag @@ -457,9 +457,10 @@ function infer_states!(aif::AIF, obs::Vector{Int64}) aif.qs_current = update_posterior_states(aif.A, obs, prior=aif.prior, num_iter=aif.FPI_num_iter, dF_tol=aif.FPI_dF_tol) # Push changes to agent's history - push!(aif.states["prior"], copy(aif.prior)) - push!(aif.states["posterior_states"], copy(aif.qs_current)) + push!(aif.states["prior"], aif.prior) + push!(aif.states["posterior_states"], aif.qs_current) + return aif.qs_current end """ Update the agents's beliefs over policies """ @@ -474,7 +475,7 @@ function infer_policies!(aif::AIF) push!(aif.states["posterior_policies"], copy(aif.Q_pi)) push!(aif.states["expected_free_energies"], copy(aif.G)) - return q_pi, G + return q_pi end """ Sample action from the beliefs over policies """ diff --git a/src/utils/maths.jl b/src/utils/maths.jl index 4181bf7..62dd8f2 100644 --- a/src/utils/maths.jl +++ b/src/utils/maths.jl @@ -26,11 +26,20 @@ Return the natural logarithm of x, capped at the machine epsilon value of x. """ function capped_log(array::AbstractArray{T}) where T <: Real epsilon = oftype(array[1], 1e-16) - + # Return the log of the array values capped at epsilon return log.(max.(array, epsilon)) end +function capped_log(array::Vector{Real}) + + epsilon = 1e-16 + # Return the log of the array values capped at epsilon + array .= log.(max.(array, epsilon)) + + return array +end + ### This method will be deprecated once all types in the package have been made more strict. 
""" capped_log(array::Array{Any}) From a45246f0bc5b2fea00082f6bfe9ec7f889093cc8 Mon Sep 17 00:00:00 2001 From: samuelnehrer02 Date: Tue, 8 Oct 2024 14:19:12 +0200 Subject: [PATCH 14/25] prior type change --- src/pomdp/inference.jl | 8 +++++--- src/pomdp/struct.jl | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/pomdp/inference.jl b/src/pomdp/inference.jl index 5407f73..2028382 100644 --- a/src/pomdp/inference.jl +++ b/src/pomdp/inference.jl @@ -120,7 +120,7 @@ function process_observation(observation::Union{Array{Int}, Tuple{Vararg{Int}}}, end """ Update Posterior States """ -function update_posterior_states(A::Vector{Array{Real}}, obs::Vector{Int64}; prior::Union{Nothing, Vector{Any}}=nothing, num_iter::Int=num_iter, dF_tol::Float64=dF_tol, kwargs...) +function update_posterior_states(A::Vector{Array{Real}}, obs::Vector{Int64}; prior::Union{Nothing, Vector{Vector{Real}}}=nothing, num_iter::Int=num_iter, dF_tol::Float64=dF_tol, kwargs...) num_obs, num_states, num_modalities, num_factors = get_model_dimensions(A) obs_processed = process_observation(obs, num_modalities, num_obs) @@ -129,7 +129,7 @@ end """ Run State Inference via Fixed-Point Iteration """ -function fixed_point_iteration(A::Vector{Array{Real}}, obs::Vector{Vector{Real}}, num_obs::Vector{Int64}, num_states::Vector{Int64}; prior::Union{Nothing, Vector{Any}}=nothing, num_iter::Int=num_iter, dF::Float64=1.0, dF_tol::Float64=dF_tol) +function fixed_point_iteration(A::Vector{Array{Real}}, obs::Vector{Vector{Real}}, num_obs::Vector{Int64}, num_states::Vector{Int64}; prior::Union{Nothing, Vector{Vector{Real}}}=nothing, num_iter::Int=num_iter, dF::Float64=1.0, dF_tol::Float64=dF_tol) n_modalities = length(num_obs) n_factors = length(num_states) @@ -147,6 +147,8 @@ function fixed_point_iteration(A::Vector{Array{Real}}, obs::Vector{Vector{Real}} prior = create_matrix_templates(num_states) end + # Create a copy of the prior to avoid modifying the original + prior = deepcopy(prior) prior = capped_log_array(prior) # Initialize free energy @@ -207,7 +209,7 @@ end """ Calculate Free Energy """ -function calc_free_energy(qs::Vector{Vector{Real}}, prior, n_factors, likelihood=nothing) +function calc_free_energy(qs::Vector{Vector{Real}}, prior::Vector{Vector{Real}}, n_factors, likelihood=nothing) # Initialize free energy free_energy = 0.0 diff --git a/src/pomdp/struct.jl b/src/pomdp/struct.jl index aaef515..7105e0a 100644 --- a/src/pomdp/struct.jl +++ b/src/pomdp/struct.jl @@ -24,7 +24,7 @@ mutable struct AIF control_fac_idx::Array{Int,1} # Indices of controllable factors policy_len::Int # Policy length qs_current::Vector{Vector{Real}} # Current beliefs about states - prior::Array{Any,1} # Prior beliefs about states + prior::Vector{Vector{Real}} # Prior beliefs about states Q_pi::Vector{Real} # Posterior beliefs over policies G::Array{Real,1} # Expected free energy of policy action::Vector{Any} # Last action @@ -460,7 +460,7 @@ function infer_states!(aif::AIF, obs::Vector{Int64}) push!(aif.states["prior"], aif.prior) push!(aif.states["posterior_states"], aif.qs_current) - return aif.qs_current + return nothing end """ Update the agents's beliefs over policies """ From 654ec5d4ea532d72a7445aadac6edbe145722c08 Mon Sep 17 00:00:00 2001 From: samuelnehrer02 Date: Tue, 8 Oct 2024 15:22:15 +0200 Subject: [PATCH 15/25] Further Typing Changes --- src/ActionModelsExtensions/give_inputs.jl | 2 +- src/ActionModelsExtensions/set_parameters.jl | 4 ++-- src/pomdp/inference.jl | 2 +- src/pomdp/struct.jl | 8 ++++---- 4 files changed, 
8 insertions(+), 8 deletions(-) diff --git a/src/ActionModelsExtensions/give_inputs.jl b/src/ActionModelsExtensions/give_inputs.jl index 2d3e58a..a91fda9 100644 --- a/src/ActionModelsExtensions/give_inputs.jl +++ b/src/ActionModelsExtensions/give_inputs.jl @@ -45,7 +45,7 @@ function ActionModels.single_input!(aif::AIF, obs::Vector) end # If the agent has not taken any actions yet if isempty(aif.action) - push!(aif.action, sampled_actions) + aif.action = sampled_actions else # Put the action in the last element of the action vector aif.action[end] = sampled_actions diff --git a/src/ActionModelsExtensions/set_parameters.jl b/src/ActionModelsExtensions/set_parameters.jl index 2a02310..567c3e3 100644 --- a/src/ActionModelsExtensions/set_parameters.jl +++ b/src/ActionModelsExtensions/set_parameters.jl @@ -1,10 +1,10 @@ """ This module extends the "set_parameters!" functionality of the ActionModels package to work with instances of the AIF type. - set_parameters!(aif::AIF, target_param::String, param_value::Any) + set_parameters!(aif::AIF, target_param::String, param_value::Real) Set a single parameter in the AIF agent - set_parameters!(aif::AIF, parameters::Dict{String, Any}) + set_parameters!(aif::AIF, parameters::Dict{String, Real}) Set multiple parameters in the AIF agent """ diff --git a/src/pomdp/inference.jl b/src/pomdp/inference.jl index 2028382..dfb28f7 100644 --- a/src/pomdp/inference.jl +++ b/src/pomdp/inference.jl @@ -399,7 +399,7 @@ end ### Action Sampling ### """ Sample Action [Stochastic or Deterministic] """ -function sample_action(q_pi, policies, num_controls; action_selection="stochastic", alpha=16.0) +function sample_action(q_pi, policies::Vector{Matrix{Int64}}, num_controls; action_selection="stochastic", alpha=16.0) num_factors = length(num_controls) selected_policy = zeros(Real,num_factors) diff --git a/src/pomdp/struct.jl b/src/pomdp/struct.jl index 7105e0a..e5cc060 100644 --- a/src/pomdp/struct.jl +++ b/src/pomdp/struct.jl @@ -19,15 +19,15 @@ mutable struct AIF factors_to_learn::Union{String, Vector{Int64}} # Modalities can be either "all" or "# factor" gamma::Real # Gamma parameter alpha::Real # Alpha parameter - policies::Array # Inferred from the B matrix + policies::Vector{Matrix{Int64}} # Inferred from the B matrix num_controls::Array{Int,1} # Number of actions per factor control_fac_idx::Array{Int,1} # Indices of controllable factors policy_len::Int # Policy length qs_current::Vector{Vector{Real}} # Current beliefs about states prior::Vector{Vector{Real}} # Prior beliefs about states Q_pi::Vector{Real} # Posterior beliefs over policies - G::Array{Real,1} # Expected free energy of policy - action::Vector{Any} # Last action + G::Vector{Real} # Expected free energy of policy + action::Vector{Real} # Last action use_utility::Bool # Utility Boolean Flag use_states_info_gain::Bool # States Information Gain Boolean Flag use_param_info_gain::Bool # Include the novelty value in the learning parameters @@ -460,7 +460,7 @@ function infer_states!(aif::AIF, obs::Vector{Int64}) push!(aif.states["prior"], aif.prior) push!(aif.states["posterior_states"], aif.qs_current) - return nothing + return aif.qs_current end """ Update the agents's beliefs over policies """ From 941f4362838d7b4decd833117a12d40e03552d7a Mon Sep 17 00:00:00 2001 From: samuelnehrer02 Date: Wed, 9 Oct 2024 08:52:39 +0200 Subject: [PATCH 16/25] Type Debugging --- src/utils/maths.jl | 39 ++++++++++----------------------------- 1 file changed, 10 insertions(+), 29 deletions(-) diff --git a/src/utils/maths.jl 
b/src/utils/maths.jl index 62dd8f2..ebb82cf 100644 --- a/src/utils/maths.jl +++ b/src/utils/maths.jl @@ -24,11 +24,13 @@ end Return the natural logarithm of x, capped at the machine epsilon value of x. """ -function capped_log(array::AbstractArray{T}) where T <: Real - epsilon = oftype(array[1], 1e-16) +function capped_log(array::Array{Real}) + epsilon = 1e-16 # Return the log of the array values capped at epsilon - return log.(max.(array, epsilon)) + array .= log.(max.(array, epsilon)) + + return array end function capped_log(array::Vector{Real}) @@ -37,23 +39,7 @@ function capped_log(array::Vector{Real}) # Return the log of the array values capped at epsilon array .= log.(max.(array, epsilon)) - return array -end - -### This method will be deprecated once all types in the package have been made more strict. -""" - capped_log(array::Array{Any}) - -# Arguments -- `array::Array{Any}`: An array of real numbers. - -Return the natural logarithm of x, capped at the machine epsilon value of x. -""" -function capped_log(array::Array{Any}) - epsilon = oftype(array[1], 1e-16) - - # Return the log of the array values capped at epsilon - return log.(max.(array, epsilon)) + return array end """ Apply capped_log to array of arrays """ @@ -88,16 +74,11 @@ function dot_likelihood(A, obs) end """ Softmax Function for array of arrays """ -function softmax_array(arr) - output = Array{Any}(undef, length(arr)) - - # Iterate through each index in arr and apply softmax - for idx in eachindex(arr) - output[idx] = softmax(arr[idx], dims=1) - - end +function softmax_array(array) + # Use map to apply softmax to each element of arr + array .= map(x -> softmax(x, dims=1), array) - return output + return array end From adb102cabb207fcb179a12943bedbbc0b8d59384 Mon Sep 17 00:00:00 2001 From: Samuel William Nehrer Date: Wed, 9 Oct 2024 14:34:55 +0200 Subject: [PATCH 17/25] Changes to Reals and capped_log --- src/pomdp/inference.jl | 9 ++------- src/pomdp/struct.jl | 9 +++++++-- src/utils/maths.jl | 18 +++++++++++++----- src/utils/utils.jl | 2 +- 4 files changed, 23 insertions(+), 15 deletions(-) diff --git a/src/pomdp/inference.jl b/src/pomdp/inference.jl index dfb28f7..b397dc6 100644 --- a/src/pomdp/inference.jl +++ b/src/pomdp/inference.jl @@ -244,7 +244,7 @@ function update_posterior_policies( use_param_info_gain::Bool=false, pA = nothing, pB = nothing, - E::Union{Vector{Real}, Nothing} = nothing, + E::Vector{Real} = nothing, gamma::Real=16.0 ) n_policies = length(policies) @@ -252,12 +252,7 @@ function update_posterior_policies( q_pi = Vector{Real}(undef, n_policies) qs_pi = Vector{Real}[] qo_pi = Vector{Real}[] - - if isnothing(E) - lnE = capped_log(ones(Real, n_policies) / n_policies) - else - lnE = capped_log(E) - end + lnE = capped_log(E) for (idx, policy) in enumerate(policies) qs_pi = get_expected_states(qs, B, policy) diff --git a/src/pomdp/struct.jl b/src/pomdp/struct.jl index e5cc060..475611d 100644 --- a/src/pomdp/struct.jl +++ b/src/pomdp/struct.jl @@ -5,7 +5,7 @@ mutable struct AIF B::Vector{Array{Real}} # B-matrix C::Vector{Array{Real}} # C-vectors D::Vector{Array{Real}} # D-vectors - E::Union{Vector{Real}, Nothing} # E-vector (Habits) + E::Vector{Real} # E-vector (Habits) pA::Union{Vector{Array{Real}}, Nothing} # Dirichlet priors for A-matrix pB::Union{Vector{Array{Real}}, Nothing} # Dirichlet priors for B-matrix pD::Union{Vector{Array{Real}}, Nothing} # Dirichlet priors for D-vector @@ -95,8 +95,13 @@ function create_aif(A, B; policies = construct_policies(num_states, n_controls=num_controls, 
policy_length=policy_len, controllable_factors_indices=control_fac_idx) + # if E-vector is not provided + if isnothing(E) + E = ones(Real, length(policies)) / length(policies) + end + # Throw error if the E-vector does not match the length of policies - if !isnothing(E) && length(E) != length(policies) + if length(E) != length(policies) error("Length of E-vector must match the number of policies.") end diff --git a/src/utils/maths.jl b/src/utils/maths.jl index ebb82cf..7b9020b 100644 --- a/src/utils/maths.jl +++ b/src/utils/maths.jl @@ -1,6 +1,7 @@ """Normalizes a Categorical probability distribution""" function normalize_distribution(distribution) - return distribution ./ sum(distribution, dims=1) + distribution .= distribution ./ sum(distribution, dims=1) + return distribution end @@ -95,6 +96,9 @@ function outer_product(x, y=nothing; remove_singleton_dims=true, args...) end end + # To ensure the types of x and y match the input types. + T = promote_type(eltype(x), eltype(y)) + # If y is provided, perform the cross multiplication. if y !== nothing reshape_dims_x = tuple(size(x)..., ones(Real, ndims(y))...) @@ -104,6 +108,9 @@ function outer_product(x, y=nothing; remove_singleton_dims=true, args...) B = reshape(y, reshape_dims_y) z = A .* B + + # Type convert to the original type + z = convert(Array{T}, z) else z = x end @@ -113,7 +120,7 @@ function outer_product(x, y=nothing; remove_singleton_dims=true, args...) z = outer_product(z, arg; remove_singleton_dims=remove_singleton_dims) end - # remove singleton dimension if true-- + # Remove singleton dimensions if true if remove_singleton_dims z = dropdims(z, dims = tuple(findall(size(z) .== 1)...)) end @@ -156,7 +163,8 @@ end function calculate_bayesian_surprise(A, x) qx = outer_product(x) G = 0.0 - qo = Real[] + qo = Vector{Real}() + T = typeof(qo) idx = [collect(Tuple(indices)) for indices in findall(qx .> exp(-16))] index_vector = [] @@ -171,12 +179,12 @@ function calculate_bayesian_surprise(A, x) end po = vec(po) if isempty(qo) - resize!(qo, length(po)) - fill!(qo, 0.0) + qo = zeros(Real, length(po)) end qo += qx[i...] * po G += qx[i...] 
* dot(po, log.(po .+ exp(-16))) end + qo = convert(T, qo) G = G - dot(qo, capped_log(qo)) return G end diff --git a/src/utils/utils.jl b/src/utils/utils.jl index 9d56091..27fd2df 100644 --- a/src/utils/utils.jl +++ b/src/utils/utils.jl @@ -51,7 +51,7 @@ end """ Selects the highest value from Array -- used for deterministic action sampling """ -function select_highest(options_array::Array{Float64}) +function select_highest(options_array::Vector{T}) where T <: Real options_with_idx = [(i, option) for (i, option) in enumerate(options_array)] max_value = maximum(value for (idx, value) in options_with_idx) same_prob = [idx for (idx, value) in options_with_idx if abs(value - max_value) <= 1e-8] From 88be257c46451f3af7d80a0157daef7979012c84 Mon Sep 17 00:00:00 2001 From: Samuel William Nehrer Date: Thu, 10 Oct 2024 09:01:52 +0200 Subject: [PATCH 18/25] Debugging ReverseDiff --- src/pomdp/POMDP.jl | 15 ++++++--------- src/utils/maths.jl | 27 +++++++++++++++++++++------ 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/src/pomdp/POMDP.jl b/src/pomdp/POMDP.jl index 59ec1ea..c2c2c3d 100644 --- a/src/pomdp/POMDP.jl +++ b/src/pomdp/POMDP.jl @@ -7,15 +7,13 @@ This module contains models of Partially Observable Markov Decision Processes un function action_pomdp!(agent::Agent, obs::Vector{Int64}) - aif = agent.substruct - ### Get parameters alpha = agent.substruct.parameters["alpha"] - n_factors = length(aif.settings["num_controls"]) + n_factors = length(agent.substruct.settings["num_controls"]) # Initialize empty arrays for action distribution per factor - action_p = Vector{Any}(undef, n_factors) - action_distribution = Vector(undef, n_factors) + action_p = Vector{Vector{Real}}(undef, n_factors) + action_distribution = Vector{Distributions.Categorical}(undef, n_factors) #If there was a previous action if !ismissing(agent.states["action"]) @@ -35,14 +33,13 @@ function action_pomdp!(agent::Agent, obs::Vector{Int64}) ### Infer states & policies # Run state inference - infer_states!(aif, obs) + infer_states!(agent.substruct, obs) # Run policy inference - infer_policies!(aif) + infer_policies!(agent.substruct) ### Retrieve log marginal probabilities of actions - log_action_marginals = get_log_action_marginals(aif) - + log_action_marginals = get_log_action_marginals(agent.substruct) ### Pass action marginals through softmax function to get action probabilities for factor in 1:n_factors action_p[factor] = softmax(log_action_marginals[factor] * alpha, dims=1) diff --git a/src/utils/maths.jl b/src/utils/maths.jl index 7b9020b..52a50b4 100644 --- a/src/utils/maths.jl +++ b/src/utils/maths.jl @@ -25,21 +25,31 @@ end Return the natural logarithm of x, capped at the machine epsilon value of x. 
""" -function capped_log(array::Array{Real}) - epsilon = 1e-16 +function capped_log(array::Array{Float64}) + + epsilon = oftype(array[1], 1e-16) # Return the log of the array values capped at epsilon - array .= log.(max.(array, epsilon)) + array = log.(max.(array, epsilon)) return array end -function capped_log(array::Vector{Real}) +function capped_log(array::Array{Real}) - epsilon = 1e-16 + epsilon = oftype(array[1], 1e-16) # Return the log of the array values capped at epsilon - array .= log.(max.(array, epsilon)) + array = log.(max.(array, epsilon)) + + return array +end +function capped_log(array::Vector{Real}) + epsilon = oftype(array[1], 1e-16) + + array = log.(max.(array, epsilon)) + # Return the log of the array values capped at epsilon + #@show typeof(array) return array end @@ -194,6 +204,11 @@ function normalize_arrays(array::Vector{<:Array{<:Real}}) return map(normalize_distribution, array) end +""" Normalizes multiple arrays """ +function normalize_arrays(array::Vector{Any}) + return map(normalize_distribution, array) +end + """ SPM_wnorm """ function spm_wnorm(A) EPS_VAL = 1e-16 From eea1e4d8d076df91b15d01acb47229617b285ca2 Mon Sep 17 00:00:00 2001 From: samuelnehrer02 Date: Thu, 10 Oct 2024 20:55:41 +0200 Subject: [PATCH 19/25] Further ReverseDiff Debugging --- Project.toml | 1 + src/pomdp/POMDP.jl | 5 +++-- src/pomdp/inference.jl | 2 +- src/utils/utils.jl | 2 +- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index ac3eef5..0f9df16 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" [compat] ActionModels = "0.5" diff --git a/src/pomdp/POMDP.jl b/src/pomdp/POMDP.jl index c2c2c3d..97a4d3f 100644 --- a/src/pomdp/POMDP.jl +++ b/src/pomdp/POMDP.jl @@ -12,7 +12,7 @@ function action_pomdp!(agent::Agent, obs::Vector{Int64}) n_factors = length(agent.substruct.settings["num_controls"]) # Initialize empty arrays for action distribution per factor - action_p = Vector{Vector{Real}}(undef, n_factors) + action_p = Vector{Vector{Float64}}(undef, n_factors) action_distribution = Vector{Distributions.Categorical}(undef, n_factors) #If there was a previous action @@ -44,8 +44,9 @@ function action_pomdp!(agent::Agent, obs::Vector{Int64}) for factor in 1:n_factors action_p[factor] = softmax(log_action_marginals[factor] * alpha, dims=1) action_distribution[factor] = Distributions.Categorical(action_p[factor]) + @show action_distribution[factor] end - + return n_factors == 1 ? 
action_distribution[1] : action_distribution end diff --git a/src/pomdp/inference.jl b/src/pomdp/inference.jl index b397dc6..2e37b37 100644 --- a/src/pomdp/inference.jl +++ b/src/pomdp/inference.jl @@ -209,7 +209,7 @@ end """ Calculate Free Energy """ -function calc_free_energy(qs::Vector{Vector{Real}}, prior::Vector{Vector{Real}}, n_factors, likelihood=nothing) +function calc_free_energy(qs::Vector{Vector{Real}}, prior, n_factors, likelihood=nothing) # Initialize free energy free_energy = 0.0 diff --git a/src/utils/utils.jl b/src/utils/utils.jl index 27fd2df..55898d0 100644 --- a/src/utils/utils.jl +++ b/src/utils/utils.jl @@ -81,7 +81,7 @@ function get_log_action_marginals(aif) # Initialize action_marginals with the correct element type action_marginals = create_matrix_templates(aif.num_controls, "zeros", eltype_q_pi) - log_action_marginals = Vector{Any}(undef, num_factors) + log_action_marginals = Vector{Vector{Float64}}(undef, num_factors) for (pol_idx, policy) in enumerate(policies) for (factor_i, action_i) in enumerate(policy[1,:]) From 54d03d6aa5846024071789a9d1cc90e2e3735224 Mon Sep 17 00:00:00 2001 From: Samuel William Nehrer Date: Mon, 14 Oct 2024 17:07:16 +0200 Subject: [PATCH 20/25] Learning Rate Fitting Experiments --- src/ActionModelsExtensions/reset.jl | 8 +++--- src/ActiveInference.jl | 1 + src/pomdp/POMDP.jl | 13 ++++++--- src/pomdp/inference.jl | 42 +++++++++++++++++------------ src/pomdp/learning.jl | 7 +++-- src/pomdp/struct.jl | 26 +++++++++--------- src/utils/maths.jl | 2 +- src/utils/utils.jl | 2 +- 8 files changed, 60 insertions(+), 41 deletions(-) diff --git a/src/ActionModelsExtensions/reset.jl b/src/ActionModelsExtensions/reset.jl index 30d6d44..8ece1d7 100644 --- a/src/ActionModelsExtensions/reset.jl +++ b/src/ActionModelsExtensions/reset.jl @@ -9,11 +9,11 @@ using ActionModels function ActionModels.reset!(aif::AIF) # Reset the agent's state fields to initial conditions - aif.qs_current = array_of_any_uniform([size(aif.B[f], 1) for f in eachindex(aif.B)]) + aif.qs_current = create_matrix_templates([size(aif.B[f], 1) for f in eachindex(aif.B)]) aif.prior = aif.D - aif.Q_pi = ones(Real,length(aif.policies)) / length(aif.policies) - aif.G = zeros(Real,length(aif.policies)) - aif.action = Real[] + aif.Q_pi = ones(length(aif.policies)) / length(aif.policies) + aif.G = zeros(length(aif.policies)) + aif.action = Int[] # Clear the history in the states dictionary for key in keys(aif.states) diff --git a/src/ActiveInference.jl b/src/ActiveInference.jl index a87e2e2..3026393 100644 --- a/src/ActiveInference.jl +++ b/src/ActiveInference.jl @@ -6,6 +6,7 @@ using IterTools using Random using Distributions using LogExpFunctions +using ReverseDiff include("utils/maths.jl") include("pomdp/struct.jl") diff --git a/src/pomdp/POMDP.jl b/src/pomdp/POMDP.jl index 97a4d3f..25d841e 100644 --- a/src/pomdp/POMDP.jl +++ b/src/pomdp/POMDP.jl @@ -12,7 +12,7 @@ function action_pomdp!(agent::Agent, obs::Vector{Int64}) n_factors = length(agent.substruct.settings["num_controls"]) # Initialize empty arrays for action distribution per factor - action_p = Vector{Vector{Float64}}(undef, n_factors) + action_p = Vector{Any}(undef, n_factors) action_distribution = Vector{Distributions.Categorical}(undef, n_factors) #If there was a previous action @@ -35,16 +35,23 @@ function action_pomdp!(agent::Agent, obs::Vector{Int64}) # Run state inference infer_states!(agent.substruct, obs) + if !ismissing(agent.states["action"]) + + #Get the posterior over states from the previous time step + states_posterior = 
get_history(agent.substruct)["posterior_states"][end-1] + + # Update Transition Matrix + update_B!(agent.substruct, states_posterior) + end + # Run policy inference infer_policies!(agent.substruct) - ### Retrieve log marginal probabilities of actions log_action_marginals = get_log_action_marginals(agent.substruct) ### Pass action marginals through softmax function to get action probabilities for factor in 1:n_factors action_p[factor] = softmax(log_action_marginals[factor] * alpha, dims=1) action_distribution[factor] = Distributions.Categorical(action_p[factor]) - @show action_distribution[factor] end return n_factors == 1 ? action_distribution[1] : action_distribution diff --git a/src/pomdp/inference.jl b/src/pomdp/inference.jl index 2e37b37..288b854 100644 --- a/src/pomdp/inference.jl +++ b/src/pomdp/inference.jl @@ -3,7 +3,7 @@ #### State Inference #### """ Get Expected States """ -function get_expected_states(qs::Vector{Vector{Real}}, B, policy::Matrix{Int64}) +function get_expected_states(qs::Vector{Vector{T}} where T <: Real, B, policy::Matrix{Int64}) n_steps, n_factors = size(policy) # initializing posterior predictive density as a list of beliefs over time @@ -120,7 +120,11 @@ function process_observation(observation::Union{Array{Int}, Tuple{Vararg{Int}}}, end """ Update Posterior States """ -function update_posterior_states(A::Vector{Array{Real}}, obs::Vector{Int64}; prior::Union{Nothing, Vector{Vector{Real}}}=nothing, num_iter::Int=num_iter, dF_tol::Float64=dF_tol, kwargs...) +function update_posterior_states( + A::Vector{Array{T,N}} where {T <: Real, N}, + obs::Vector{Int64}; + prior::Union{Nothing, Vector{Vector{T}}} where T <: Real = nothing, + num_iter::Int=num_iter, dF_tol::Float64=dF_tol, kwargs...) num_obs, num_states, num_modalities, num_factors = get_model_dimensions(A) obs_processed = process_observation(obs, num_modalities, num_obs) @@ -129,7 +133,11 @@ end """ Run State Inference via Fixed-Point Iteration """ -function fixed_point_iteration(A::Vector{Array{Real}}, obs::Vector{Vector{Real}}, num_obs::Vector{Int64}, num_states::Vector{Int64}; prior::Union{Nothing, Vector{Vector{Real}}}=nothing, num_iter::Int=num_iter, dF::Float64=1.0, dF_tol::Float64=dF_tol) +function fixed_point_iteration( + A::Vector{Array{T,N}} where {T <: Real, N}, obs::Vector{Vector{Real}}, num_obs::Vector{Int64}, num_states::Vector{Int64}; + prior::Union{Nothing, Vector{Vector{T}}} where T <: Real = nothing, + num_iter::Int=num_iter, dF::Float64=1.0, dF_tol::Float64=dF_tol +) n_modalities = length(num_obs) n_factors = length(num_states) @@ -138,9 +146,9 @@ function fixed_point_iteration(A::Vector{Array{Real}}, obs::Vector{Vector{Real}} likelihood = capped_log(likelihood) # Initialize posterior and prior - qs = Vector{Vector{Real}}(undef, n_factors) + qs = Vector{Vector{Float64}}(undef, n_factors) for factor in 1:n_factors - qs[factor] = ones(Real,num_states[factor]) / num_states[factor] + qs[factor] = ones(num_states[factor]) / num_states[factor] end if prior === nothing @@ -209,7 +217,7 @@ end """ Calculate Free Energy """ -function calc_free_energy(qs::Vector{Vector{Real}}, prior, n_factors, likelihood=nothing) +function calc_free_energy(qs::Vector{Vector{T}} where T <: Real, prior, n_factors, likelihood=nothing) # Initialize free energy free_energy = 0.0 @@ -234,24 +242,24 @@ end #### Policy Inference #### """ Update Posterior over Policies """ function update_posterior_policies( - qs::Vector{Vector{Real}}, - A::Vector{Array{Real}}, - B::Vector{Array{Real}}, - C::Vector{Array{Real}}, + 
qs::Vector{Vector{T}} where T <: Real, + A::Vector{Array{T, N}} where {T <: Real, N}, + B::Vector{Array{T, N}} where {T <: Real, N}, + C::Vector{Array{T}} where T <: Real, policies::Vector{Matrix{Int64}}, use_utility::Bool=true, use_states_info_gain::Bool=true, use_param_info_gain::Bool=false, pA = nothing, pB = nothing, - E::Vector{Real} = nothing, + E::Vector{T} where T <: Real = nothing, gamma::Real=16.0 ) n_policies = length(policies) - G = zeros(Real,n_policies) - q_pi = Vector{Real}(undef, n_policies) - qs_pi = Vector{Real}[] - qo_pi = Vector{Real}[] + G = zeros(n_policies) + q_pi = Vector{Float64}(undef, n_policies) + qs_pi = Vector{Float64}[] + qo_pi = Vector{Float64}[] lnE = capped_log(E) for (idx, policy) in enumerate(policies) @@ -278,13 +286,13 @@ function update_posterior_policies( end - q_pi .= softmax(G * gamma + lnE, dims=1) + q_pi = softmax(G * gamma + lnE, dims=1) return q_pi, G end """ Get Expected Observations """ -function get_expected_obs(qs_pi, A::Vector{Array{Real}}) +function get_expected_obs(qs_pi, A::Vector{Array{T,N}} where {T <: Real, N}) n_steps = length(qs_pi) qo_pi = [] diff --git a/src/pomdp/learning.jl b/src/pomdp/learning.jl index 7509fe6..f29d785 100644 --- a/src/pomdp/learning.jl +++ b/src/pomdp/learning.jl @@ -26,13 +26,16 @@ function update_obs_likelihood_dirichlet(pA, A, obs, qs; lr = 1.0, fr = 1.0, mod end """ Update state likelihood matrix """ -function update_state_likelihood_dirichlet(pB, B, actions, qs::Vector{Vector{Real}}, qs_prev; lr = 1.0, fr = 1.0, factors = "all") +function update_state_likelihood_dirichlet(pB, B, actions, qs::Vector{Vector{T}} where T <: Real, qs_prev; lr = 1.0, fr = 1.0, factors = "all") + + if ReverseDiff.istracked(lr) + lr = ReverseDiff.value(lr) + end num_factors = length(pB) qB = deepcopy(pB) - if factors === "all" factors = collect(1:num_factors) end diff --git a/src/pomdp/struct.jl b/src/pomdp/struct.jl index 475611d..23f1784 100644 --- a/src/pomdp/struct.jl +++ b/src/pomdp/struct.jl @@ -1,13 +1,13 @@ """ -------- AIF Mutable Struct -------- """ mutable struct AIF - A::Vector{Array{Real}} # A-matrix - B::Vector{Array{Real}} # B-matrix + A::Vector{Array{T, N}} where {T <: Real, N} # A-matrix + B::Vector{Array{T, N}} where {T <: Real, N} # B-matrix C::Vector{Array{Real}} # C-vectors - D::Vector{Array{Real}} # D-vectors - E::Vector{Real} # E-vector (Habits) + D::Vector{Vector{Real}} # D-vectors + E::Vector{T} where T <: Real # E-vector (Habits) pA::Union{Vector{Array{Real}}, Nothing} # Dirichlet priors for A-matrix - pB::Union{Vector{Array{Real}}, Nothing} # Dirichlet priors for B-matrix + pB::Union{Vector{Array{T, N}}, Nothing} where {T <: Real, N} # Dirichlet priors for B-matrix pD::Union{Vector{Array{Real}}, Nothing} # Dirichlet priors for D-vector lr_pA::Real # pA Learning Parameter fr_pA::Real # pA Forgetting Parameter, 1.0 for no forgetting @@ -23,11 +23,11 @@ mutable struct AIF num_controls::Array{Int,1} # Number of actions per factor control_fac_idx::Array{Int,1} # Indices of controllable factors policy_len::Int # Policy length - qs_current::Vector{Vector{Real}} # Current beliefs about states - prior::Vector{Vector{Real}} # Prior beliefs about states - Q_pi::Vector{Real} # Posterior beliefs over policies - G::Vector{Real} # Expected free energy of policy - action::Vector{Real} # Last action + qs_current::Vector{Vector{T}} where T <: Real # Current beliefs about states + prior::Vector{Vector{T}} where T <: Real # Prior beliefs about states + Q_pi::Vector{T} where T <:Real # Posterior beliefs over policies + 
G::Vector{T} where T <:Real # Expected free energy of policy + action::Vector{Int} # Last action use_utility::Bool # Utility Boolean Flag use_states_info_gain::Bool # States Information Gain Boolean Flag use_param_info_gain::Bool # Include the novelty value in the learning parameters @@ -107,9 +107,9 @@ function create_aif(A, B; qs_current = create_matrix_templates(num_states) prior = D - Q_pi = ones(Real,length(policies)) / length(policies) - G = zeros(Real,length(policies)) - action = [] + Q_pi = ones(length(policies)) / length(policies) + G = zeros(length(policies)) + action = Int[] # initialize states dictionary states = Dict( diff --git a/src/utils/maths.jl b/src/utils/maths.jl index 52a50b4..9ae82c4 100644 --- a/src/utils/maths.jl +++ b/src/utils/maths.jl @@ -35,7 +35,7 @@ function capped_log(array::Array{Float64}) return array end -function capped_log(array::Array{Real}) +function capped_log(array::Array{T}) where T <: Real epsilon = oftype(array[1], 1e-16) # Return the log of the array values capped at epsilon diff --git a/src/utils/utils.jl b/src/utils/utils.jl index 55898d0..27fd2df 100644 --- a/src/utils/utils.jl +++ b/src/utils/utils.jl @@ -81,7 +81,7 @@ function get_log_action_marginals(aif) # Initialize action_marginals with the correct element type action_marginals = create_matrix_templates(aif.num_controls, "zeros", eltype_q_pi) - log_action_marginals = Vector{Vector{Float64}}(undef, num_factors) + log_action_marginals = Vector{Any}(undef, num_factors) for (pol_idx, policy) in enumerate(policies) for (factor_i, action_i) in enumerate(policy[1,:]) From 922c8ac9b2abdd96c6f107b7095ac2f697f8ca1f Mon Sep 17 00:00:00 2001 From: samuelnehrer02 Date: Tue, 22 Oct 2024 21:41:37 +0200 Subject: [PATCH 21/25] ReverseDiff Bug - Patchy Solution --- Project.toml | 2 +- src/pomdp/POMDP.jl | 10 +++++--- src/pomdp/inference.jl | 53 ++++++++++++++++++++++++++++++++---------- src/pomdp/learning.jl | 12 ++++++++++ src/pomdp/struct.jl | 2 +- src/utils/maths.jl | 11 ++------- 6 files changed, 64 insertions(+), 26 deletions(-) diff --git a/Project.toml b/Project.toml index 0f9df16..5b75a11 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" [compat] -ActionModels = "0.5" +ActionModels = "0.6" Distributions = "0.25" IterTools = "1.10" LinearAlgebra = "1" diff --git a/src/pomdp/POMDP.jl b/src/pomdp/POMDP.jl index 25d841e..879d580 100644 --- a/src/pomdp/POMDP.jl +++ b/src/pomdp/POMDP.jl @@ -35,6 +35,10 @@ function action_pomdp!(agent::Agent, obs::Vector{Int64}) # Run state inference infer_states!(agent.substruct, obs) + update_A!(agent.substruct, obs) + + + #= if !ismissing(agent.states["action"]) #Get the posterior over states from the previous time step @@ -43,7 +47,7 @@ function action_pomdp!(agent::Agent, obs::Vector{Int64}) # Update Transition Matrix update_B!(agent.substruct, states_posterior) end - + =# # Run policy inference infer_policies!(agent.substruct) ### Retrieve log marginal probabilities of actions @@ -80,9 +84,9 @@ function action_pomdp!(agent::Agent, obs::Tuple{Vararg{Int}}) #Extract it previous_action = agent.states["action"] - #If it is not a vector, make it one + # If it is not a vector, make it one if !(previous_action isa Vector) - previous_action = [previous_action] + previous_action = collect(previous_action) end #Store the action in the AIF substruct diff --git a/src/pomdp/inference.jl b/src/pomdp/inference.jl index 288b854..1fb1b9c 100644 --- 
a/src/pomdp/inference.jl +++ b/src/pomdp/inference.jl @@ -104,7 +104,7 @@ Process observation with multiple modalities and return them in a one-hot encode function process_observation(observation::Union{Array{Int}, Tuple{Vararg{Int}}}, n_modalities::Int, n_observations::Vector{Int}) # Initialize the processed_observation vector - processed_observation = Vector{Vector{Real}}(undef, n_modalities) + processed_observation = Vector{Vector{Float64}}(undef, n_modalities) # Check if the length of observation matches the number of modalities if length(observation) == n_modalities @@ -134,7 +134,7 @@ end """ Run State Inference via Fixed-Point Iteration """ function fixed_point_iteration( - A::Vector{Array{T,N}} where {T <: Real, N}, obs::Vector{Vector{Real}}, num_obs::Vector{Int64}, num_states::Vector{Int64}; + A::Vector{Array{T,N}} where {T <: Real, N}, obs::Vector{Vector{Float64}}, num_obs::Vector{Int64}, num_states::Vector{Int64}; prior::Union{Nothing, Vector{Vector{T}}} where T <: Real = nothing, num_iter::Int=num_iter, dF::Float64=1.0, dF_tol::Float64=dF_tol ) @@ -167,9 +167,8 @@ function fixed_point_iteration( qL = dot_product(likelihood, qs[1]) return [softmax(qL .+ prior[1], dims=1)] else - # Run Iteration - curr_iter = 0 - while curr_iter < num_iter && dF >= dF_tol + # Run for a given number of iterations + for iteration in 1:num_iter qs_all = qs[1] for factor in 2:n_factors qs_all = qs_all .* reshape(qs[factor], tuple(ones(Real, factor - 1)..., :, 1)) @@ -181,7 +180,14 @@ function fixed_point_iteration( for i in 1:size(qs[factor], 1) qL[i] = sum([LL_tensor[indices...] / qs[factor][i] for indices in Iterators.product([1:size(LL_tensor, dim) for dim in 1:n_factors]...) if indices[factor] == i]) end - qs[factor] = softmax(qL + prior[factor], dims=1) + # If qs is tracked by ReverseDiff, get the value + if ReverseDiff.istracked(softmax(qL .+ prior[factor], dims=1)) + qs[factor] = ReverseDiff.value(softmax(qL .+ prior[factor], dims=1)) + + # Otherwise, proceed as normal + else + qs[factor] = softmax(qL .+ prior[factor], dims=1) + end end # Recompute free energy @@ -190,8 +196,6 @@ function fixed_point_iteration( # Update stopping condition dF = abs(prev_vfe - vfe) prev_vfe = vfe - - curr_iter += 1 end return qs @@ -201,7 +205,7 @@ end """ Calculate Accuracy Term """ -function compute_accuracy(log_likelihood, qs::Vector{Vector{Real}}) +function compute_accuracy(log_likelihood, qs::Vector{Vector{T}} where T <: Real) n_factors = length(qs) ndims_ll = ndims(log_likelihood) dims = (ndims_ll - n_factors + 1) : ndims_ll @@ -266,18 +270,43 @@ function update_posterior_policies( qs_pi = get_expected_states(qs, B, policy) qo_pi = get_expected_obs(qs_pi, A) + # Calculate expected utility if use_utility - G[idx] += calc_expected_utility(qo_pi, C) + # If ReverseDiff is tracking the expected utility, get the value + if ReverseDiff.istracked(calc_expected_utility(qo_pi, C)) + G[idx] += ReverseDiff.value(calc_expected_utility(qo_pi, C)) + + # Otherwise calculate the expected utility and add it to the G vector + else + G[idx] += calc_expected_utility(qo_pi, C) + end end + # Calculate expected information gain of states if use_states_info_gain - G[idx] += calc_states_info_gain(A, qs_pi) + # If ReverseDiff is tracking the information gain, get the value + if ReverseDiff.istracked(calc_states_info_gain(A, qs_pi)) + G[idx] += ReverseDiff.value(calc_states_info_gain(A, qs_pi)) + + # Otherwise calculate it and add it to the G vector + else + G[idx] += calc_states_info_gain(A, qs_pi) + end end + # Calculate expected 
information gain of parameters (learning) if use_param_info_gain if pA !== nothing - G[idx] += calc_pA_info_gain(pA, qo_pi, qs_pi) + + # if ReverseDiff is tracking pA information gain, get the value + if ReverseDiff.istracked(calc_pA_info_gain(pA, qo_pi, qs_pi)) + G[idx] += ReverseDiff.value(calc_pA_info_gain(pA, qo_pi, qs_pi)) + # Otherwise calculate it and add it to the G vector + else + G[idx] += calc_pA_info_gain(pA, qo_pi, qs_pi) + end end + if pB !== nothing G[idx] += calc_pB_info_gain(pB, qs_pi, qs, policy) end diff --git a/src/pomdp/learning.jl b/src/pomdp/learning.jl index f29d785..181cafe 100644 --- a/src/pomdp/learning.jl +++ b/src/pomdp/learning.jl @@ -1,6 +1,15 @@ """ Update obs likelihood matrix """ function update_obs_likelihood_dirichlet(pA, A, obs, qs; lr = 1.0, fr = 1.0, modalities = "all") + # If reverse diff is tracking the learning rate, get the value + if ReverseDiff.istracked(lr) + lr = ReverseDiff.value(lr) + end + # If reverse diff is tracking the forgetting rate, get the value + if ReverseDiff.istracked(fr) + fr = ReverseDiff.value(fr) + end + # Extracting the number of modalities and observations from the dirichlet: pA num_modalities = length(pA) num_observations = [size(pA[modality + 1], 1) for modality in 0:(num_modalities - 1)] @@ -31,6 +40,9 @@ function update_state_likelihood_dirichlet(pB, B, actions, qs::Vector{Vector{T}} if ReverseDiff.istracked(lr) lr = ReverseDiff.value(lr) end + if ReverseDiff.istracked(fr) + fr = ReverseDiff.value(fr) + end num_factors = length(pB) diff --git a/src/pomdp/struct.jl b/src/pomdp/struct.jl index 23f1784..65db7a4 100644 --- a/src/pomdp/struct.jl +++ b/src/pomdp/struct.jl @@ -6,7 +6,7 @@ mutable struct AIF C::Vector{Array{Real}} # C-vectors D::Vector{Vector{Real}} # D-vectors E::Vector{T} where T <: Real # E-vector (Habits) - pA::Union{Vector{Array{Real}}, Nothing} # Dirichlet priors for A-matrix + pA::Union{Vector{Array{T, N}}, Nothing} where {T <: Real, N} # Dirichlet priors for A-matrix pB::Union{Vector{Array{T, N}}, Nothing} where {T <: Real, N} # Dirichlet priors for B-matrix pD::Union{Vector{Array{Real}}, Nothing} # Dirichlet priors for D-vector lr_pA::Real # pA Learning Parameter diff --git a/src/utils/maths.jl b/src/utils/maths.jl index 9ae82c4..7a374bc 100644 --- a/src/utils/maths.jl +++ b/src/utils/maths.jl @@ -106,9 +106,6 @@ function outer_product(x, y=nothing; remove_singleton_dims=true, args...) end end - # To ensure the types of x and y match the input types. - T = promote_type(eltype(x), eltype(y)) - # If y is provided, perform the cross multiplication. if y !== nothing reshape_dims_x = tuple(size(x)..., ones(Real, ndims(y))...) @@ -119,8 +116,6 @@ function outer_product(x, y=nothing; remove_singleton_dims=true, args...) z = A .* B - # Type convert to the original type - z = convert(Array{T}, z) else z = x end @@ -173,8 +168,7 @@ end function calculate_bayesian_surprise(A, x) qx = outer_product(x) G = 0.0 - qo = Vector{Real}() - T = typeof(qo) + qo = Vector{Float64}() idx = [collect(Tuple(indices)) for indices in findall(qx .> exp(-16))] index_vector = [] @@ -189,12 +183,11 @@ function calculate_bayesian_surprise(A, x) end po = vec(po) if isempty(qo) - qo = zeros(Real, length(po)) + qo = zeros(length(po)) end qo += qx[i...] * po G += qx[i...] 
* dot(po, log.(po .+ exp(-16))) end - qo = convert(T, qo) G = G - dot(qo, capped_log(qo)) return G end From f88e9f7b2712033cae0dd9a2d8dae6a7e261c573 Mon Sep 17 00:00:00 2001 From: samuelnehrer02 Date: Wed, 23 Oct 2024 19:52:56 +0200 Subject: [PATCH 22/25] minor adjustments --- src/pomdp/learning.jl | 2 +- src/utils/utils.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pomdp/learning.jl b/src/pomdp/learning.jl index 181cafe..336f392 100644 --- a/src/pomdp/learning.jl +++ b/src/pomdp/learning.jl @@ -62,7 +62,7 @@ function update_state_likelihood_dirichlet(pB, B, actions, qs::Vector{Vector{T}} end """ Update prior D matrix """ -function update_state_prior_dirichlet(pD, qs::Vector{Vector{Real}}; lr = 1.0, fr = 1.0, factors = "all") +function update_state_prior_dirichlet(pD, qs::Vector{Vector{T}} where T <: Real; lr = 1.0, fr = 1.0, factors = "all") num_factors = length(pD) diff --git a/src/utils/utils.jl b/src/utils/utils.jl index 27fd2df..cbdda25 100644 --- a/src/utils/utils.jl +++ b/src/utils/utils.jl @@ -21,7 +21,7 @@ end """ Creates a onehot encoded vector """ function onehot(index::Int, vector_length::Int) - vector = zeros(Real, vector_length) + vector = zeros(vector_length) vector[index] = 1.0 return vector end From d818e3a8225a4774b1712188ba2d656a1847369b Mon Sep 17 00:00:00 2001 From: Samuel William Nehrer Date: Mon, 28 Oct 2024 12:28:44 +0100 Subject: [PATCH 23/25] -updates for version 0.1.0 --- .github/workflows/CI_small.yml | 2 + Project.toml | 3 +- src/ActionModelsExtensions/give_inputs.jl | 2 +- src/pomdp/POMDP.jl | 109 +++++++++++++++++----- src/pomdp/inference.jl | 33 +++++-- src/pomdp/struct.jl | 6 +- src/utils/maths.jl | 29 ++++-- 7 files changed, 140 insertions(+), 44 deletions(-) diff --git a/.github/workflows/CI_small.yml b/.github/workflows/CI_small.yml index 013b90b..9e687eb 100644 --- a/.github/workflows/CI_small.yml +++ b/.github/workflows/CI_small.yml @@ -3,10 +3,12 @@ on: push: branches: - dev + - fitting tags: ['*'] pull_request: branches: - dev + - fitting workflow_dispatch: concurrency: # Skip intermediate builds: always. diff --git a/Project.toml b/Project.toml index 5b75a11..3772e4e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ActiveInference" uuid = "688b0e7a-0122-4325-8669-5ff08899a59e" authors = ["Jonathan Ehrenreich Laursen", "Samuel William Nehrer"] -version = "0.0.4" +version = "0.1.0" [deps] ActionModels = "320cf53b-cc3b-4b34-9a10-0ecb113566a3" @@ -19,4 +19,5 @@ IterTools = "1.10" LinearAlgebra = "1" LogExpFunctions = "0.3" Random = "1" +ReverseDiff = "1.15" julia = "1.10" diff --git a/src/ActionModelsExtensions/give_inputs.jl b/src/ActionModelsExtensions/give_inputs.jl index a91fda9..08923a5 100644 --- a/src/ActionModelsExtensions/give_inputs.jl +++ b/src/ActionModelsExtensions/give_inputs.jl @@ -22,7 +22,7 @@ function ActionModels.single_input!(aif::AIF, obs::Vector) # if there is only one factor if num_factors == 1 # Sample action from the action distribution - action = rand(action_distributions[1]) + action = rand(action_distributions) # If the agent has not taken any actions yet if isempty(aif.action) diff --git a/src/pomdp/POMDP.jl b/src/pomdp/POMDP.jl index 879d580..31ff02b 100644 --- a/src/pomdp/POMDP.jl +++ b/src/pomdp/POMDP.jl @@ -1,6 +1,16 @@ """ -This module contains models of Partially Observable Markov Decision Processes under Active Inference - + action_pomdp!(agent, obs) +This function wraps the POMDP action-perception loop used for simulating and fitting the data. 
+ +Arguments: +- `agent::Agent`: An instance of ActionModels `Agent` type, which contains AIF type object as a substruct. +- `obs::Vector{Int64}`: A vector of observations, where each observation is an integer. +- `obs::Tuple{Vararg{Int}}`: A tuple of observations, where each observation is an integer. +- `obs::Int64`: A single observation, which is an integer. +- `aif::AIF`: An instance of the `AIF` type, which contains the agent's state, parameters, and substructures. + +Outputs: +- Returns a `Distributions.Categorical` distribution or a vector of distributions, representing the probability distributions for actions per each state factor. """ ### Action Model: Returns probability distributions for actions per factor @@ -21,9 +31,9 @@ function action_pomdp!(agent::Agent, obs::Vector{Int64}) #Extract it previous_action = agent.states["action"] - #If it is not a vector, make it one + # If it is not a vector, make it one if !(previous_action isa Vector) - previous_action = [previous_action] + previous_action = collect(previous_action) end #Store the action in the AIF substruct @@ -35,23 +45,34 @@ function action_pomdp!(agent::Agent, obs::Vector{Int64}) # Run state inference infer_states!(agent.substruct, obs) - update_A!(agent.substruct, obs) - + # If action is empty, update D vectors + if ismissing(agent.states["action"]) && agent.substruct.pD !== nothing + qs_t1 = get_history(agent.substruct)["posterior_states"][1] + update_D!(aif, qs_t1) + end - #= - if !ismissing(agent.states["action"]) + # If learning of the B matrix is enabled and agent has a previous action + if !ismissing(agent.states["action"]) && agent.substruct.pB !== nothing - #Get the posterior over states from the previous time step + # Get the posterior over states from the previous time step states_posterior = get_history(agent.substruct)["posterior_states"][end-1] # Update Transition Matrix update_B!(agent.substruct, states_posterior) end - =# + + # If learning of the A matrix is enabled + if agent.substruct.pA !== nothing + update_A!(agent.substruct, obs) + end + # Run policy inference infer_policies!(agent.substruct) + + ### Retrieve log marginal probabilities of actions log_action_marginals = get_log_action_marginals(agent.substruct) + ### Pass action marginals through softmax function to get action probabilities for factor in 1:n_factors action_p[factor] = softmax(log_action_marginals[factor] * alpha, dims=1) @@ -72,14 +93,14 @@ function action_pomdp!(agent::Agent, obs::Tuple{Vararg{Int}}) ### Get parameters alpha = agent.substruct.parameters["alpha"] - n_factors = length(aif.settings["num_controls"]) + n_factors = length(agent.substruct.settings["num_controls"]) # Initialize empty arrays for action distribution per factor action_p = Vector{Any}(undef, n_factors) - action_distribution = Vector(undef, n_factors) + action_distribution = Vector{Distributions.Categorical}(undef, n_factors) - #If there was a previous action - if !ismissing(agent.states["action"]) + #If there was a previous action + if !ismissing(agent.states["action"]) #Extract it previous_action = agent.states["action"] @@ -96,20 +117,42 @@ function action_pomdp!(agent::Agent, obs::Tuple{Vararg{Int}}) ### Infer states & policies # Run state inference - infer_states!(aif, obs) + infer_states!(agent.substruct, obs) + + # If action is empty and pD is not nothing, update D vectors + if ismissing(agent.states["action"]) && agent.substruct.pD !== nothing + qs_t1 = get_history(agent.substruct)["posterior_states"][1] + update_D!(aif, qs_t1) + end + + # If learning 
of the B matrix is enabled and agent has a previous action + if !ismissing(agent.states["action"]) && agent.substruct.pB !== nothing + + # Get the posterior over states from the previous time step + states_posterior = get_history(agent.substruct)["posterior_states"][end-1] + + # Update Transition Matrix + update_B!(agent.substruct, states_posterior) + end + + # If learning of the A matrix is enabled + if agent.substruct.pA !== nothing + update_A!(agent.substruct, obs) + end # Run policy inference - infer_policies!(aif) + infer_policies!(agent.substruct) - ### Retrieve log marginal probabilities of actions - log_action_marginals = get_log_action_marginals(aif) + ### Retrieve log marginal probabilities of actions + log_action_marginals = get_log_action_marginals(agent.substruct) + ### Pass action marginals through softmax function to get action probabilities for factor in 1:n_factors action_p[factor] = softmax(log_action_marginals[factor] * alpha, dims=1) action_distribution[factor] = Distributions.Categorical(action_p[factor]) end - + return n_factors == 1 ? action_distribution[1] : action_distribution end @@ -119,18 +162,40 @@ function action_pomdp!(aif::AIF, obs::Vector{Int64}) alpha = aif.parameters["alpha"] n_factors = length(aif.settings["num_controls"]) - # Initialize an empty arrays for action distribution per factor + # Initialize empty arrays for action distribution per factor action_p = Vector{Any}(undef, n_factors) - action_distribution = Vector(undef, n_factors) + action_distribution = Vector{Distributions.Categorical}(undef, n_factors) ### Infer states & policies # Run state inference infer_states!(aif, obs) + # If action is empty, update D vectors + if ismissing(get_states(aif)["action"]) && aif.pD !== nothing + qs_t1 = get_history(aif)["posterior_states"][1] + update_D!(aif, qs_t1) + end + + # If learning of the B matrix is enabled and agent has a previous action + if !ismissing(get_states(aif)["action"]) && aif.pB !== nothing + + # Get the posterior over states from the previous time step + states_posterior = get_history(aif)["posterior_states"][end-1] + + # Update Transition Matrix + update_B!(aif, states_posterior) + end + + # If learning of the A matrix is enabled + if aif.pA !== nothing + update_A!(aif, obs) + end + # Run policy inference infer_policies!(aif) + ### Retrieve log marginal probabilities of actions log_action_marginals = get_log_action_marginals(aif) @@ -140,7 +205,7 @@ function action_pomdp!(aif::AIF, obs::Vector{Int64}) action_distribution[factor] = Distributions.Categorical(action_p[factor]) end - return action_distribution + return n_factors == 1 ? 
action_distribution[1] : action_distribution
 end
 
 function action_pomdp!(agent::Agent, obs::Int64)
diff --git a/src/pomdp/inference.jl b/src/pomdp/inference.jl
index 1fb1b9c..4c099c4 100644
--- a/src/pomdp/inference.jl
+++ b/src/pomdp/inference.jl
@@ -138,6 +138,7 @@ function fixed_point_iteration(
     prior::Union{Nothing, Vector{Vector{T}}} where T <: Real = nothing,
     num_iter::Int=num_iter, dF::Float64=1.0, dF_tol::Float64=dF_tol
 )
+    # Get model dimensions (NOTE Sam: We need to save model dimensions in the AIF struct in the future)
     n_modalities = length(num_obs)
     n_factors = length(num_states)
 
@@ -151,6 +152,7 @@ function fixed_point_iteration(
         qs[factor] = ones(num_states[factor]) / num_states[factor]
     end
 
+    # If no prior is provided, create a default prior with uniform distribution
     if prior === nothing
         prior = create_matrix_templates(num_states)
     end
@@ -166,26 +168,38 @@ function fixed_point_iteration(
     if n_factors == 1
         qL = dot_product(likelihood, qs[1])
         return [softmax(qL .+ prior[1], dims=1)]
+
+    # If there are more factors
     else
-        # Run for a given number of iterations
-        for iteration in 1:num_iter
+        ### Fixed-Point Iteration ###
+        curr_iter = 0
+        ### Sam NOTE: We need to check if ReverseDiff might potentially have issues with this while loop ###
+        while curr_iter < num_iter && dF >= dF_tol
             qs_all = qs[1]
+            # Loop over each factor starting from the second one
             for factor in 2:n_factors
+                # Reshape and multiply qs_all with the current factor's qs
                 qs_all = qs_all .* reshape(qs[factor], tuple(ones(Real, factor - 1)..., :, 1))
             end
+
+            # Compute the log-likelihood
             LL_tensor = likelihood .* qs_all
 
+            # Update each factor's qs
             for factor in 1:n_factors
-                qL = zeros(Real,size(qs[factor]))
+                # Initialize qL for the current factor
+                qL = zeros(Real, size(qs[factor]))
+
+                # Compute qL for each state in the current factor
                 for i in 1:size(qs[factor], 1)
                     qL[i] = sum([LL_tensor[indices...] / qs[factor][i] for indices in Iterators.product([1:size(LL_tensor, dim) for dim in 1:n_factors]...)
if indices[factor] == i]) end + # If qs is tracked by ReverseDiff, get the value if ReverseDiff.istracked(softmax(qL .+ prior[factor], dims=1)) qs[factor] = ReverseDiff.value(softmax(qL .+ prior[factor], dims=1)) - - # Otherwise, proceed as normal else + # Otherwise, proceed as normal qs[factor] = softmax(qL .+ prior[factor], dims=1) end end @@ -196,6 +210,9 @@ function fixed_point_iteration( # Update stopping condition dF = abs(prev_vfe - vfe) prev_vfe = vfe + + # Increment iteration + curr_iter += 1 end return qs @@ -477,14 +494,14 @@ function compute_accuracy_new(log_likelihood, qs::Vector{Vector{Real}}) return results end -""" Calculate SAPE """ -function calc_SAPE(aif::AIF) +""" Calculate State-Action Prediction Error """ +function calculate_SAPE(aif::AIF) qs_pi_all = get_expected_states(aif.qs_current, aif.B, aif.policies) qs_bma = bayesian_model_average(qs_pi_all, aif.Q_pi) if length(aif.states["bayesian_model_averages"]) != 0 - sape = kl_div(qs_bma, aif.states["bayesian_model_averages"][end]) + sape = kl_divergence(qs_bma, aif.states["bayesian_model_averages"][end]) push!(aif.states["SAPE"], sape) end diff --git a/src/pomdp/struct.jl b/src/pomdp/struct.jl index 65db7a4..adba01f 100644 --- a/src/pomdp/struct.jl +++ b/src/pomdp/struct.jl @@ -26,7 +26,7 @@ mutable struct AIF qs_current::Vector{Vector{T}} where T <: Real # Current beliefs about states prior::Vector{Vector{T}} where T <: Real # Prior beliefs about states Q_pi::Vector{T} where T <:Real # Posterior beliefs over policies - G::Vector{T} where T <:Real # Expected free energy of policy + G::Vector{T} where T <:Real # Expected free energy of policies action::Vector{Int} # Last action use_utility::Bool # Utility Boolean Flag use_states_info_gain::Bool # States Information Gain Boolean Flag @@ -56,8 +56,8 @@ function create_aif(A, B; fr_pD = 1.0, modalities_to_learn = "all", factors_to_learn = "all", - gamma=16.0, - alpha=16.0, + gamma=1.0, + alpha=1.0, policy_len=1, num_controls=nothing, control_fac_idx=nothing, diff --git a/src/utils/maths.jl b/src/utils/maths.jl index 7a374bc..215d523 100644 --- a/src/utils/maths.jl +++ b/src/utils/maths.jl @@ -18,12 +18,12 @@ function capped_log(x::Real) end """ - capped_log(x::AbstractArray{T}) where T <: Real - -# Arguments -- `array::AbstractArray{T}`: An array of real numbers. + capped_log(array::Array{Float64}) + capped_log(array::Array{T}) where T <: Real + capped_log(array::Vector{Real}) Return the natural logarithm of x, capped at the machine epsilon value of x. + """ function capped_log(array::Array{Float64}) @@ -49,7 +49,6 @@ function capped_log(array::Vector{Real}) array = log.(max.(array, epsilon)) # Return the log of the array values capped at epsilon - #@show typeof(array) return array end @@ -256,15 +255,27 @@ function bayesian_model_average(qs_pi_all, q_pi) return qs_bma end -function kl_div(P::Vector{Vector{Vector{Real}}}, Q::Vector{Vector{Vector{Real}}}) - eps_val=1e-16 - dkl = 0.0 +""" + kl_divergence(P::Vector{Vector{Vector{Float64}}}, Q::Vector{Vector{Vector{Float64}}}) + +# Arguments +- `P::Vector{Vector{Vector{Real}}}` +- `Q::Vector{Vector{Vector{Real}}}` + +Return the Kullback-Leibler (KL) divergence between two probability distributions. 
+""" +function kl_divergence(P::Vector{Vector{Vector{Real}}}, Q::Vector{Vector{Vector{Real}}}) + eps_val = 1e-16 # eps constant to avoid log(0) + dkl = 0.0 # Initialize KL divergence to zero + for j in 1:length(P) for i in 1:length(P[j]) + # Compute the dot product of P[j][i] and the difference of logs of P[j][i] and Q[j][i] dkl += dot(P[j][i], log.(P[j][i] .+ eps_val) .- log.(Q[j][i] .+ eps_val)) end end - return dkl + + return dkl # Return KL divergence end From 5bf0b8cd340e925f5e067b299b3089615d5ee925 Mon Sep 17 00:00:00 2001 From: Samuel William Nehrer Date: Mon, 28 Oct 2024 17:01:01 +0100 Subject: [PATCH 24/25] cleanup and readme updates --- README.md | 8 ++++++++ src/ActionModelsExtensions/get_history.jl | 2 +- src/ActionModelsExtensions/get_parameters.jl | 3 +-- src/ActionModelsExtensions/get_states.jl | 2 +- src/ActionModelsExtensions/give_inputs.jl | 2 +- src/ActionModelsExtensions/set_parameters.jl | 2 +- src/ActiveInference.jl | 1 - src/Environments/EpistChainEnv.jl | 18 ++++++++++++++---- src/utils/maths.jl | 14 +++++++------- src/utils/utils.jl | 10 ---------- 10 files changed, 34 insertions(+), 28 deletions(-) diff --git a/README.md b/README.md index 719ce86..1b5c1b5 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,14 @@ ActiveInference.jl is a new Julia package for the computational modeling of acti Left: A synthetic agent wants to reach the end of the maze environment while avoiding dark-colored locations. Right: The agent's noisy prior expectations about the state of the environment parameterized by Dirichlet distributions are updated dynamically as it moves through the maze. +## News +#### Version 0.1.0 - October 2024 +- All parameters are now recoverable using ReverseDiff +- Now fully compatible with ActionModels.jl v0.6.6 + +#### Coming Soon +- **Documentation**: Full documentation will be available within the next few weeks, including examples and tutorials. + ## Installation Install ActiveInference.jl using the Julia package manager: ````@example Introduction diff --git a/src/ActionModelsExtensions/get_history.jl b/src/ActionModelsExtensions/get_history.jl index 4d032fb..46518a3 100644 --- a/src/ActionModelsExtensions/get_history.jl +++ b/src/ActionModelsExtensions/get_history.jl @@ -1,5 +1,5 @@ """ -This module extends the "get_history" functionality of the ActionModels package to work specifically with instances of the AIF type. +This extends the "get_history" function of the ActionModels package to work specifically with instances of the AIF type. get_history(aif::AIF, target_states::Vector{String}) Retrieves a history for multiple states of an AIF agent. diff --git a/src/ActionModelsExtensions/get_parameters.jl b/src/ActionModelsExtensions/get_parameters.jl index 41e7c1f..9c7b0a0 100644 --- a/src/ActionModelsExtensions/get_parameters.jl +++ b/src/ActionModelsExtensions/get_parameters.jl @@ -1,6 +1,5 @@ """ -This module extends the "get_parameters" functionality of the ActionModels package to work specifically with instances of the AIF type. - +This extends the "get_parameters" function of the ActionModels package to work specifically with instances of the AIF type. get_parameters(aif::AIF, target_parameters::Vector{String}) Retrieves multiple target parameters from an AIF agent. 
diff --git a/src/ActionModelsExtensions/get_states.jl b/src/ActionModelsExtensions/get_states.jl
index 5728fe8..28ed731 100644
--- a/src/ActionModelsExtensions/get_states.jl
+++ b/src/ActionModelsExtensions/get_states.jl
@@ -1,5 +1,5 @@
 """
-This module extends the "get_states" functionality of the ActionModels package to work specifically with instances of the AIF type.
+This extends the "get_states" function of the ActionModels package to work specifically with instances of the AIF type.
 
     get_states(aif::AIF, target_states::Vector{String})
 Retrieves multiple states from an AIF agent.
diff --git a/src/ActionModelsExtensions/give_inputs.jl b/src/ActionModelsExtensions/give_inputs.jl
index 08923a5..77636b9 100644
--- a/src/ActionModelsExtensions/give_inputs.jl
+++ b/src/ActionModelsExtensions/give_inputs.jl
@@ -1,6 +1,6 @@
 """
-This is an experimental module to extend the give_inputs! functionality of ActionsModels.jl to work with instances of the AIF type.
+This extends the give_inputs! function of ActionModels.jl to work with instances of the AIF type.
 
     single_input!(aif::AIF, obs)
 Give a single observation to an AIF agent.
diff --git a/src/ActionModelsExtensions/set_parameters.jl b/src/ActionModelsExtensions/set_parameters.jl
index 567c3e3..45bec80 100644
--- a/src/ActionModelsExtensions/set_parameters.jl
+++ b/src/ActionModelsExtensions/set_parameters.jl
@@ -1,5 +1,5 @@
 """
-This module extends the "set_parameters!" functionality of the ActionModels package to work with instances of the AIF type.
+This extends the "set_parameters!" function of the ActionModels package to work with instances of the AIF type.
 
     set_parameters!(aif::AIF, target_param::String, param_value::Real)
 Set a single parameter in the AIF agent
diff --git a/src/ActiveInference.jl b/src/ActiveInference.jl
index 3026393..398fbca 100644
--- a/src/ActiveInference.jl
+++ b/src/ActiveInference.jl
@@ -34,7 +34,6 @@ export
     # utils/create_matrix_templates.jl
     # utils/utils.jl
     array_of_any_zeros,
-    array_of_any_uniform,
     onehot,
     get_model_dimensions,
 
diff --git a/src/Environments/EpistChainEnv.jl b/src/Environments/EpistChainEnv.jl
index f068323..81a9107 100644
--- a/src/Environments/EpistChainEnv.jl
+++ b/src/Environments/EpistChainEnv.jl
@@ -16,9 +16,11 @@ mutable struct EpistChainEnv
 end
 
 function step!(env::EpistChainEnv, action_label::String)
+    # Get current location
     y, x = env.current_loc
     next_y, next_x = y, x
 
+    # Update location based on action
     if action_label == "DOWN"
         next_y = y < env.len_y ? y + 1 : y
     elseif action_label == "UP"
         next_y = y > 1 ? y - 1 : y
     elseif action_label == "LEFT"
         next_x = x > 1 ? x - 1 : x
     elseif action_label == "RIGHT"
         next_x = x < env.len_x ?
x + 1 : x - elseif action_label == "STAY" + elseif action_label == "STAY" + # No change in location end + # Set new location env.current_loc = (next_y, next_x) + # Observations loc_obs = env.current_loc cue2_names = ["Null", "reward_on_top", "reward_on_bottom"] cue2_loc_names = ["L1","L2","L3","L4"] cue2_locs = [(1, 3), (2, 4), (4, 4), (5, 3)] + # Map cue2 location names to indices cue2_loc_idx = Dict(cue2_loc_names[1] => 1, cue2_loc_names[2] => 2, cue2_loc_names[3] => 3, cue2_loc_names[4] => 4) + # Get cue2 location cue2_loc = cue2_locs[cue2_loc_idx[env.cue2]] - #------------------------------------------------------- + # Determine cue1 observation if env.current_loc == env.cue1_loc cue1_obs = env.cue2 else cue1_obs = "Null" end + # Reward conditions and locations reward_conditions = ["TOP", "BOTTOM"] reward_locations = [(2,6), (4,6)] rew_cond_idx = Dict(reward_conditions[1] => 1, reward_conditions[2] => 2) - #------------------------------------------------------- + # Determine cue2 observation if env.current_loc == cue2_loc cue2_obs = cue2_names[rew_cond_idx[env.reward_condition] + 1] else cue2_obs = "Null" end + # Determine reward observation if env.current_loc == reward_locations[1] if env.reward_condition == "TOP" reward_obs = "Cheese" @@ -75,11 +84,12 @@ function step!(env::EpistChainEnv, action_label::String) reward_obs = "Null" end - + # Return observations return loc_obs, cue1_obs, cue2_obs, reward_obs end function reset_env!(env::EpistChainEnv) + # Reset environment to initial location env.current_loc = env.init_loc println("Re-initialized location to $(env.init_loc)") return env.current_loc diff --git a/src/utils/maths.jl b/src/utils/maths.jl index 215d523..e287145 100644 --- a/src/utils/maths.jl +++ b/src/utils/maths.jl @@ -18,14 +18,8 @@ function capped_log(x::Real) end """ - capped_log(array::Array{Float64}) - capped_log(array::Array{T}) where T <: Real - capped_log(array::Vector{Real}) - -Return the natural logarithm of x, capped at the machine epsilon value of x. 
- + capped_log(array::Array{Float64}) """ - function capped_log(array::Array{Float64}) epsilon = oftype(array[1], 1e-16) @@ -35,6 +29,9 @@ function capped_log(array::Array{Float64}) return array end +""" + capped_log(array::Array{T}) where T <: Real +""" function capped_log(array::Array{T}) where T <: Real epsilon = oftype(array[1], 1e-16) @@ -44,6 +41,9 @@ function capped_log(array::Array{T}) where T <: Real return array end +""" + capped_log(array::Vector{Real}) +""" function capped_log(array::Vector{Real}) epsilon = oftype(array[1], 1e-16) diff --git a/src/utils/utils.jl b/src/utils/utils.jl index cbdda25..f9b0141 100644 --- a/src/utils/utils.jl +++ b/src/utils/utils.jl @@ -9,16 +9,6 @@ function array_of_any_zeros(shape_list) return arr end -""" Creates an array of "Any" as a uniform categorical distribution""" -function array_of_any_uniform(shape_list) - arr = Array{Any}(undef, length(shape_list)) - for i in eachindex(shape_list) - shape = shape_list[i] - arr[i] = normalize_distribution(ones(Real, shape)) - end - return arr -end - """ Creates a onehot encoded vector """ function onehot(index::Int, vector_length::Int) vector = zeros(vector_length) From 38201ca7fe3bb0ce8cac8fd7a7535bc6919eaf37 Mon Sep 17 00:00:00 2001 From: Samuel William Nehrer Date: Mon, 28 Oct 2024 17:05:13 +0100 Subject: [PATCH 25/25] Readme --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index 1b5c1b5..99ef330 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,6 @@ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) [![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) -## Julia Package for Active Inference. ActiveInference.jl is a new Julia package for the computational modeling of active inference. We provide the necessary infrastructure for defining active inference models, currently implemented as partially observable Markov decision processes. After defining the generative model, you can simulate actions using agent-based simulations. We also provide the functionality to fit experimental data to active inference models for parameter recovery. ![Maze Animation](.github/animation_maze.gif)
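
A minimal usage sketch of the interface these patches converge on, for orientation only: it builds a toy generative model with one observation modality and one hidden-state factor, constructs an agent with `create_aif`, and draws an action from the distribution returned by `action_pomdp!` (which, after PATCH 23, returns a `Distributions.Categorical` per control factor). The names `create_aif` and `action_pomdp!` and the defaults `gamma = 1.0`, `alpha = 1.0` come from the diffs above; the toy A and B arrays, the single uncontrolled transition, and the assumption that the constructor fills in C, D, E and policies by default are illustrative assumptions, not part of the patch series.

```julia
# Hedged sketch of the post-patch API; the toy model below is an assumption for illustration.
using ActiveInference
using Distributions

# Likelihood A: one modality, P(obs | state); columns sum to 1
A = [[0.9 0.1;
      0.1 0.9]]

# Transitions B: one factor, P(s' | s, action) with a single control state
B = [reshape([1.0 0.0;
              0.0 1.0], 2, 2, 1)]

# Construct the agent; gamma and alpha default to 1.0 after this patch series
aif = create_aif(A, B)

# One observation index for the single modality
obs = [1]

# Runs state and policy inference and returns a Categorical distribution over actions
action_distribution = action_pomdp!(aif, obs)

# Stochastic action selection from that distribution
action = rand(action_distribution)
```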