diff --git a/src/ActiveInference.jl b/src/ActiveInference.jl index a977006..2bd37e7 100644 --- a/src/ActiveInference.jl +++ b/src/ActiveInference.jl @@ -5,7 +5,7 @@ include("maths.jl") include("environment.jl") # From functions.jl -export array_of_any, array_of_any_zeros, plot_beliefs, plot_gridworld, plot_likelihood, create_B_matrix, onehot, plot_point_on_grid, infer_states, get_expected_states, get_expected_observations, calculate_G, run_active_inference_loop, construct_policies, calculate_G_policies, compute_prob_actions,active_inference_with_planning, EpistChainEnv, GridWorldEnv +export array_of_any, array_of_any_zeros, array_of_any_uniform, plot_beliefs, plot_gridworld, plot_likelihood, create_B_matrix, onehot, plot_point_on_grid, infer_states, get_expected_states, get_expected_observations, calculate_G, run_active_inference_loop, construct_policies, calculate_G_policies, compute_prob_actions,active_inference_with_planning, EpistChainEnv, GridWorldEnv # From maths.jl module Maths diff --git a/src/functions.jl b/src/functions.jl index d9b714a..2e457ad 100644 --- a/src/functions.jl +++ b/src/functions.jl @@ -16,6 +16,16 @@ function array_of_any_zeros(shape_list) return arr end +"""Creates an array of "Any" as a uniform categorical distribution""" +function array_of_any_uniform(shape_list) + arr = Array{Any}(undef, length(shape_list)) + for i in eachindex(shape_list) + shape = shape_list[i] + arr[i] = norm_dist(ones(shape)) + end + return arr +end + """Function for Plotting Beliefs""" function plot_beliefs(belief_dist, title_str="") if abs(sum(belief_dist) - 1.0) > 1e-6