From 1666440b94b799436922c26497fbcc978f43f530 Mon Sep 17 00:00:00 2001 From: thorek1 Date: Fri, 6 Oct 2023 19:56:53 +0200 Subject: [PATCH] spawn approach - still fails --- src/MacroModelling.jl | 151 +++++++++++++++++++++++------------------- 1 file changed, 82 insertions(+), 69 deletions(-) diff --git a/src/MacroModelling.jl b/src/MacroModelling.jl index aed79f3d..c626e027 100644 --- a/src/MacroModelling.jl +++ b/src/MacroModelling.jl @@ -2927,86 +2927,99 @@ function write_functions_mapping!(๐“‚::โ„ณ, max_perturbation_order::Int) end end - first_order = [] - second_order = [] - third_order = [] - row1 = Int[] - row2 = Int[] - row3 = Int[] - column1 = Int[] - column2 = Int[] - column3 = Int[] - # column3ext = Int[] - i1 = 1 - i2 = 1 - i3 = 1 - - sparse_init() = [[], [], [], [], [], [], [], [], []] - Polyester.@batch per=thread threadlocal = sparse_init() for c1 in 1:length(vars) - var1 = vars[c1] - # for (c1,var1) in enumerate(vars) - for (r,eq) in enumerate(eqs) - if Symbol(var1) โˆˆ Symbol.(Symbolics.get_variables(eq)) - deriv_first = Symbolics.derivative(eq,var1) - # if deriv_first != 0 - # deriv_expr = Meta.parse(string(deriv_first.subs(SPyPyC.PI,SPyPyC.N(SPyPyC.PI)))) - # push!(first_order, :($(postwalk(x -> x isa Expr ? x.args[1] == :conjugate ? x.args[2] : x : x, deriv_expr)))) - push!(threadlocal[1], Symbolics.toexpr(deriv_first)) - push!(threadlocal[2],r) - push!(threadlocal[3],c1) - i1 += 1 - if max_perturbation_order >= 2 - for (c2,var2) in enumerate(vars) - # if Symbol(var2) โˆˆ Symbol.(Symbolics.get_variables(deriv_first)) - if (((c1 - 1) * length(vars) + c2) โˆˆ second_order_idxs) && (Symbol(var2) โˆˆ Symbol.(Symbolics.get_variables(deriv_first))) - deriv_second = Symbolics.derivative(deriv_first,var2) - # if deriv_second != 0 - # deriv_expr = Meta.parse(string(deriv_second.subs(SPyPyC.PI,SPyPyC.N(SPyPyC.PI)))) - # push!(second_order, :($(postwalk(x -> x isa Expr ? x.args[1] == :conjugate ? x.args[2] : x : x, deriv_expr)))) - push!(threadlocal[4],Symbolics.toexpr(deriv_second)) - push!(threadlocal[5],r) - # push!(column2,(c1 - 1) * length(vars) + c2) - push!(threadlocal[6], Int.(indexin([(c1 - 1) * length(vars) + c2], second_order_idxs))...) - i2 += 1 - if max_perturbation_order == 3 - for (c3,var3) in enumerate(vars) - # if Symbol(var3) โˆˆ Symbol.(Symbolics.get_variables(deriv_second)) - # push!(column3ext,(c1 - 1) * length(vars)^2 + (c2 - 1) * length(vars) + c3) - if (((c1 - 1) * length(vars)^2 + (c2 - 1) * length(vars) + c3) โˆˆ third_order_idxs) && (Symbol(var3) โˆˆ Symbol.(Symbolics.get_variables(deriv_second))) - deriv_third = Symbolics.derivative(deriv_second,var3) - # if deriv_third != 0 - # deriv_expr = Meta.parse(string(deriv_third.subs(SPyPyC.PI,SPyPyC.N(SPyPyC.PI)))) - # push!(third_order, :($(postwalk(x -> x isa Expr ? x.args[1] == :conjugate ? x.args[2] : x : x, deriv_expr)))) - push!(threadlocal[7],Symbolics.toexpr(deriv_third)) - push!(threadlocal[8],r) - # push!(column3,(c1 - 1) * length(vars)^2 + (c2 - 1) * length(vars) + c3) - push!(threadlocal[9], Int.(indexin([(c1 - 1) * length(vars)^2 + (c2 - 1) * length(vars) + c3], third_order_idxs))...) - i3 += 1 + tasks_per_thread = 200 # customize this as needed. More tasks have more overhead, but better + # load balancing + + chunk_size = max(1, length(vars) รท (tasks_per_thread * Threads.nthreads())) + data_chunks = Iterators.partition(vars, chunk_size) # partition your data into chunks that + # individual tasks will deal with + #See also ChunkSplitters.jl and SplittablesBase.jl for partitioning data + + tasks = map(data_chunks) do chunk + # Each chunk of your data gets its own spawned task that does its own local, sequential work + # and then returns the result + Threads.@spawn begin + for var1 in chunk + c1 = Int(indexin(var1,vars)...) + + first_order = [] + second_order = [] + third_order = [] + + row1 = Int[] + row2 = Int[] + row3 = Int[] + + column1 = Int[] + column2 = Int[] + column3 = Int[] + + + for (r,eq) in enumerate(eqs) + if Symbol(var1) โˆˆ Symbol.(Symbolics.get_variables(eq)) + deriv_first = Symbolics.derivative(eq,var1) + push!(first_order, Symbolics.toexpr(deriv_first)) + push!(row1,r) + push!(column1,c1) + if max_perturbation_order >= 2 + for (c2,var2) in enumerate(vars) + # if Symbol(var2) โˆˆ Symbol.(Symbolics.get_variables(deriv_first)) + if (((c1 - 1) * length(vars) + c2) โˆˆ second_order_idxs) && (Symbol(var2) โˆˆ Symbol.(Symbolics.get_variables(deriv_first))) + deriv_second = Symbolics.derivative(deriv_first,var2) + # if deriv_second != 0 + # deriv_expr = Meta.parse(string(deriv_second.subs(SPyPyC.PI,SPyPyC.N(SPyPyC.PI)))) + # push!(second_order, :($(postwalk(x -> x isa Expr ? x.args[1] == :conjugate ? x.args[2] : x : x, deriv_expr)))) + push!(second_order,Symbolics.toexpr(deriv_second)) + push!(row2,r) + # push!(column2,(c1 - 1) * length(vars) + c2) + push!(column2, Int.(indexin([(c1 - 1) * length(vars) + c2], second_order_idxs))...) + if max_perturbation_order == 3 + for (c3,var3) in enumerate(vars) + # if Symbol(var3) โˆˆ Symbol.(Symbolics.get_variables(deriv_second)) + # push!(column3ext,(c1 - 1) * length(vars)^2 + (c2 - 1) * length(vars) + c3) + if (((c1 - 1) * length(vars)^2 + (c2 - 1) * length(vars) + c3) โˆˆ third_order_idxs) && (Symbol(var3) โˆˆ Symbol.(Symbolics.get_variables(deriv_second))) + deriv_third = Symbolics.derivative(deriv_second,var3) + # if deriv_third != 0 + # deriv_expr = Meta.parse(string(deriv_third.subs(SPyPyC.PI,SPyPyC.N(SPyPyC.PI)))) + # push!(third_order, :($(postwalk(x -> x isa Expr ? x.args[1] == :conjugate ? x.args[2] : x : x, deriv_expr)))) + push!(third_order,Symbolics.toexpr(deriv_third)) + push!(row3,r) + # push!(column3,(c1 - 1) * length(vars)^2 + (c2 - 1) * length(vars) + c3) + push!(column3, Int.(indexin([(c1 - 1) * length(vars)^2 + (c2 - 1) * length(vars) + c3], third_order_idxs))...) + # end + end # end end - # end - end + end + # end end - # end + end end - end + # end end - # end + end + end + + return first_order, second_order, third_order, row1, row2, row3, column1, column2, column3 end end - first_order = threadlocal[1][1] - row1 = Int.(threadlocal[1][2]) - column1 = Int.(threadlocal[1][3]) + # println(tasks) + + states = fetch.(tasks) + + first_order = vcat(states[1]...) + second_order = vcat(states[2]...) + third_order = vcat(states[3]...) - second_order = threadlocal[1][4] - row2 = Int.(threadlocal[1][5]) - column2 = Int.(threadlocal[1][6]) + row1 = vcat(states[4]...) + row2 = vcat(states[5]...) + row3 = vcat(states[6]...) - third_order = threadlocal[1][7] - row3 = Int.(threadlocal[1][8]) - column3 = Int.(threadlocal[1][9]) + column1 = vcat(states[7]...) + column2 = vcat(states[8]...) + column3 = vcat(states[9]...) mod_func3 = :(function model_jacobian(X::Vector, params::Vector{Real}, Xฬ„::Vector)