Skip to content

Commit

Permalink
spawn approach - still fails
Browse files Browse the repository at this point in the history
  • Loading branch information
thorek1 committed Oct 6, 2023
1 parent 054fd8b commit 1666440
Showing 1 changed file with 82 additions and 69 deletions.
151 changes: 82 additions & 69 deletions src/MacroModelling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2927,86 +2927,99 @@ function write_functions_mapping!(𝓂::ℳ, max_perturbation_order::Int)
end
end

first_order = []
second_order = []
third_order = []
row1 = Int[]
row2 = Int[]
row3 = Int[]
column1 = Int[]
column2 = Int[]
column3 = Int[]
# column3ext = Int[]
i1 = 1
i2 = 1
i3 = 1

sparse_init() = [[], [], [], [], [], [], [], [], []]
Polyester.@batch per=thread threadlocal = sparse_init() for c1 in 1:length(vars)
var1 = vars[c1]
# for (c1,var1) in enumerate(vars)
for (r,eq) in enumerate(eqs)
if Symbol(var1) Symbol.(Symbolics.get_variables(eq))
deriv_first = Symbolics.derivative(eq,var1)
# if deriv_first != 0
# deriv_expr = Meta.parse(string(deriv_first.subs(SPyPyC.PI,SPyPyC.N(SPyPyC.PI))))
# push!(first_order, :($(postwalk(x -> x isa Expr ? x.args[1] == :conjugate ? x.args[2] : x : x, deriv_expr))))
push!(threadlocal[1], Symbolics.toexpr(deriv_first))
push!(threadlocal[2],r)
push!(threadlocal[3],c1)
i1 += 1
if max_perturbation_order >= 2
for (c2,var2) in enumerate(vars)
# if Symbol(var2) ∈ Symbol.(Symbolics.get_variables(deriv_first))
if (((c1 - 1) * length(vars) + c2) second_order_idxs) && (Symbol(var2) Symbol.(Symbolics.get_variables(deriv_first)))
deriv_second = Symbolics.derivative(deriv_first,var2)
# if deriv_second != 0
# deriv_expr = Meta.parse(string(deriv_second.subs(SPyPyC.PI,SPyPyC.N(SPyPyC.PI))))
# push!(second_order, :($(postwalk(x -> x isa Expr ? x.args[1] == :conjugate ? x.args[2] : x : x, deriv_expr))))
push!(threadlocal[4],Symbolics.toexpr(deriv_second))
push!(threadlocal[5],r)
# push!(column2,(c1 - 1) * length(vars) + c2)
push!(threadlocal[6], Int.(indexin([(c1 - 1) * length(vars) + c2], second_order_idxs))...)
i2 += 1
if max_perturbation_order == 3
for (c3,var3) in enumerate(vars)
# if Symbol(var3) ∈ Symbol.(Symbolics.get_variables(deriv_second))
# push!(column3ext,(c1 - 1) * length(vars)^2 + (c2 - 1) * length(vars) + c3)
if (((c1 - 1) * length(vars)^2 + (c2 - 1) * length(vars) + c3) third_order_idxs) && (Symbol(var3) Symbol.(Symbolics.get_variables(deriv_second)))
deriv_third = Symbolics.derivative(deriv_second,var3)
# if deriv_third != 0
# deriv_expr = Meta.parse(string(deriv_third.subs(SPyPyC.PI,SPyPyC.N(SPyPyC.PI))))
# push!(third_order, :($(postwalk(x -> x isa Expr ? x.args[1] == :conjugate ? x.args[2] : x : x, deriv_expr))))
push!(threadlocal[7],Symbolics.toexpr(deriv_third))
push!(threadlocal[8],r)
# push!(column3,(c1 - 1) * length(vars)^2 + (c2 - 1) * length(vars) + c3)
push!(threadlocal[9], Int.(indexin([(c1 - 1) * length(vars)^2 + (c2 - 1) * length(vars) + c3], third_order_idxs))...)
i3 += 1
tasks_per_thread = 200 # customize this as needed. More tasks have more overhead, but better
# load balancing

chunk_size = max(1, length(vars) ÷ (tasks_per_thread * Threads.nthreads()))
data_chunks = Iterators.partition(vars, chunk_size) # partition your data into chunks that
# individual tasks will deal with
#See also ChunkSplitters.jl and SplittablesBase.jl for partitioning data

tasks = map(data_chunks) do chunk
# Each chunk of your data gets its own spawned task that does its own local, sequential work
# and then returns the result
Threads.@spawn begin
for var1 in chunk
c1 = Int(indexin(var1,vars)...)

first_order = []
second_order = []
third_order = []

row1 = Int[]
row2 = Int[]
row3 = Int[]

column1 = Int[]
column2 = Int[]
column3 = Int[]


for (r,eq) in enumerate(eqs)
if Symbol(var1) Symbol.(Symbolics.get_variables(eq))
deriv_first = Symbolics.derivative(eq,var1)
push!(first_order, Symbolics.toexpr(deriv_first))
push!(row1,r)
push!(column1,c1)
if max_perturbation_order >= 2
for (c2,var2) in enumerate(vars)
# if Symbol(var2) ∈ Symbol.(Symbolics.get_variables(deriv_first))
if (((c1 - 1) * length(vars) + c2) second_order_idxs) && (Symbol(var2) Symbol.(Symbolics.get_variables(deriv_first)))
deriv_second = Symbolics.derivative(deriv_first,var2)
# if deriv_second != 0
# deriv_expr = Meta.parse(string(deriv_second.subs(SPyPyC.PI,SPyPyC.N(SPyPyC.PI))))
# push!(second_order, :($(postwalk(x -> x isa Expr ? x.args[1] == :conjugate ? x.args[2] : x : x, deriv_expr))))
push!(second_order,Symbolics.toexpr(deriv_second))
push!(row2,r)
# push!(column2,(c1 - 1) * length(vars) + c2)
push!(column2, Int.(indexin([(c1 - 1) * length(vars) + c2], second_order_idxs))...)
if max_perturbation_order == 3
for (c3,var3) in enumerate(vars)
# if Symbol(var3) ∈ Symbol.(Symbolics.get_variables(deriv_second))
# push!(column3ext,(c1 - 1) * length(vars)^2 + (c2 - 1) * length(vars) + c3)
if (((c1 - 1) * length(vars)^2 + (c2 - 1) * length(vars) + c3) third_order_idxs) && (Symbol(var3) Symbol.(Symbolics.get_variables(deriv_second)))
deriv_third = Symbolics.derivative(deriv_second,var3)
# if deriv_third != 0
# deriv_expr = Meta.parse(string(deriv_third.subs(SPyPyC.PI,SPyPyC.N(SPyPyC.PI))))
# push!(third_order, :($(postwalk(x -> x isa Expr ? x.args[1] == :conjugate ? x.args[2] : x : x, deriv_expr))))
push!(third_order,Symbolics.toexpr(deriv_third))
push!(row3,r)
# push!(column3,(c1 - 1) * length(vars)^2 + (c2 - 1) * length(vars) + c3)
push!(column3, Int.(indexin([(c1 - 1) * length(vars)^2 + (c2 - 1) * length(vars) + c3], third_order_idxs))...)
# end
end
# end
end
# end
end
end
# end
end
# end
end
end
end
# end
end
# end
end

end

return first_order, second_order, third_order, row1, row2, row3, column1, column2, column3
end
end

first_order = threadlocal[1][1]
row1 = Int.(threadlocal[1][2])
column1 = Int.(threadlocal[1][3])
# println(tasks)

states = fetch.(tasks)

first_order = vcat(states[1]...)
second_order = vcat(states[2]...)
third_order = vcat(states[3]...)

second_order = threadlocal[1][4]
row2 = Int.(threadlocal[1][5])
column2 = Int.(threadlocal[1][6])
row1 = vcat(states[4]...)
row2 = vcat(states[5]...)
row3 = vcat(states[6]...)

third_order = threadlocal[1][7]
row3 = Int.(threadlocal[1][8])
column3 = Int.(threadlocal[1][9])
column1 = vcat(states[7]...)
column2 = vcat(states[8]...)
column3 = vcat(states[9]...)


mod_func3 = :(function model_jacobian(X::Vector, params::Vector{Real}, X̄::Vector)
Expand Down

0 comments on commit 1666440

Please sign in to comment.