Skip to content

Commit

Permalink
Added function: uniform--
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelnehrer02 committed Dec 26, 2023
1 parent dfaa280 commit c6bec80
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/ActiveInference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ include("maths.jl")
include("environment.jl")

# From functions.jl
export array_of_any, array_of_any_zeros, plot_beliefs, plot_gridworld, plot_likelihood, create_B_matrix, onehot, plot_point_on_grid, infer_states, get_expected_states, get_expected_observations, calculate_G, run_active_inference_loop, construct_policies, calculate_G_policies, compute_prob_actions,active_inference_with_planning, EpistChainEnv, GridWorldEnv
export array_of_any, array_of_any_zeros, array_of_any_uniform, plot_beliefs, plot_gridworld, plot_likelihood, create_B_matrix, onehot, plot_point_on_grid, infer_states, get_expected_states, get_expected_observations, calculate_G, run_active_inference_loop, construct_policies, calculate_G_policies, compute_prob_actions,active_inference_with_planning, EpistChainEnv, GridWorldEnv

# From maths.jl
module Maths
Expand Down
10 changes: 10 additions & 0 deletions src/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@ function array_of_any_zeros(shape_list)
return arr
end

"""Creates an array of "Any" as a uniform categorical distribution"""
function array_of_any_uniform(shape_list)
arr = Array{Any}(undef, length(shape_list))
for i in eachindex(shape_list)
shape = shape_list[i]
arr[i] = norm_dist(ones(shape))
end
return arr
end

"""Function for Plotting Beliefs"""
function plot_beliefs(belief_dist, title_str="")
if abs(sum(belief_dist) - 1.0) > 1e-6
Expand Down

0 comments on commit c6bec80

Please sign in to comment.