Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generic strategy switching #35

Merged
merged 5 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ version = "0.2.0"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
FlameGraphs = "08572546-2f56-4bcf-ba4e-bab62c3a3f89"
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Expand Down
58 changes: 23 additions & 35 deletions examples/jetreco.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,21 @@ function profile_code(jet_reconstruction, events, niters)
""",
)
end
"""
Top level call funtion for demonstrating the use of jet reconstruction

This uses the "generic_jet_reconstruct" wrapper, so algorithm swutching
happens inside the JetReconstruction package itself.

Some other ustilities are also supported here, such as profiling and
serialising the reconstructed jet outputs.
"""
function jet_process(
events::Vector{Vector{PseudoJet}};
ptmin::Float64 = 5.0,
distance::Float64 = 0.4,
power::Integer = -1,
strategy::JetRecoStrategy,
strategy::JetRecoStrategy.Strategy,
nsamples::Integer = 1,
gcoff::Bool = false,
profile::Bool = false,
Expand All @@ -69,22 +77,6 @@ function jet_process(
)
@info "Will process $(size(events)[1]) events"

# Strategy
if (strategy == N2Plain)
jet_reconstruction = plain_jet_reconstruct
elseif (strategy == N2Tiled || stragegy == Best)
jet_reconstruction = tiled_jet_reconstruct
else
throw(ErrorException("Strategy not yet implemented"))
end

# Build internal EDM structures for timing measurements, if needed
# For N2Tiled and N2Plain this is unnecessary as both these reconstruction
# methods can process PseudoJets directly
if (strategy == N2Tiled) || (strategy == N2Plain)
event_vector = events
end

# If we are dumping the results, setup the JSON structure
if !isnothing(dump)
jet_collection = FinalJets[]
Expand All @@ -93,21 +85,21 @@ function jet_process(
# Warmup code if we are doing a multi-sample timing run
if nsamples > 1 || profile
@info "Doing initial warm-up run"
for event in event_vector
finaljets, _ = jet_reconstruction(event, R = distance, p = power)
for event in events
finaljets, _ = generic_jet_reconstruct(event, R = distance, p = power, strategy = strategy)
final_jets(finaljets, ptmin)
end
end

if profile
profile_code(jet_reconstruction, event_vector, nsamples)
profile_code(generic_jet_reconstruct, events, nsamples)
return nothing
end

if alloc
println("Memory allocation statistics:")
@timev for event in event_vector
finaljets, _ = jet_reconstruction(event, R = distance, p = power)
@timev for event in events
finaljets, _ = generic_jet_reconstruct(event, R = distance, p = power, strategy = strategy)
final_jets(finaljets, ptmin)
end
return nothing
Expand All @@ -121,8 +113,8 @@ function jet_process(
for irun ∈ 1:nsamples
gcoff && GC.enable(false)
t_start = time_ns()
for (ievt, event) in enumerate(event_vector)
finaljets, _ = jet_reconstruction(event, R = distance, p = power, ptmin=ptmin)
for (ievt, event) in enumerate(events)
finaljets, _ = generic_jet_reconstruct(event, R = distance, p = power, ptmin=ptmin, strategy = strategy)
fj = final_jets(finaljets, ptmin)
# Only print the jet content once
if irun == 1
Expand Down Expand Up @@ -206,9 +198,9 @@ parse_command_line(args) = begin
default = -1

"--strategy"
help = "Strategy for the algorithm, valid values: Best, N2Plain, N2Tiled, N2TiledSoAGlobal, N2TiledSoATile"
arg_type = JetRecoStrategy
default = N2Plain
help = """Strategy for the algorithm, valid values: $(join(JetReconstruction.AllJetRecoStrategies, ", "))"""
arg_type = JetRecoStrategy.Strategy
default = JetRecoStrategy.Best

"--nsamples", "-m"
help = "Number of measurement points to acquire."
Expand Down Expand Up @@ -246,16 +238,12 @@ parse_command_line(args) = begin
end


function ArgParse.parse_item(::Type{JetRecoStrategy}, x::AbstractString)
if (x == "Best")
return JetRecoStrategy(0)
elseif (x == "N2Plain")
return JetRecoStrategy(1)
elseif (x == "N2Tiled")
return JetRecoStrategy(2)
else
function ArgParse.parse_item(::Type{JetRecoStrategy.Strategy}, x::AbstractString)
s = tryparse(JetRecoStrategy.Strategy, x)
if s === nothing
throw(ErrorException("Invalid value for strategy: $(x)"))
end
s
end

main() = begin
Expand Down
22 changes: 22 additions & 0 deletions src/GenericAlgo.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# This is the generic reconstruction algorithm that will
# switch based on the strategy, or based on the event density
# if the "Best" strategy is to be employed

function generic_jet_reconstruct(particles; p = -1, R = 1.0, recombine = +, ptmin = 0.0, strategy = JetRecoStrategy.Best)
# Either map to the fixed algorithm corresponding to the strategy
# or to an optimal choice based on the density of initial particles

if strategy == JetRecoStrategy.Best
# The breakpoint of ~90 is determined empirically on e+e- -> H and 0.5TeV pp -> 5GeV jets
algorithm = length(particles) > 80 ? tiled_jet_reconstruct : plain_jet_reconstruct
elseif strategy == JetRecoStrategy.N2Plain
algorithm = plain_jet_reconstruct
elseif strategy == JetRecoStrategy.N2Tiled
algorithm = tiled_jet_reconstruct
else
throw(ErrorException("Invalid strategy: $(strategy)"))
end

# Now call the chosen algorithm, passing through the other parameters
algorithm(particles; p = p, R = R, recombine = recombine, ptmin = ptmin)
end
14 changes: 14 additions & 0 deletions src/JetRecoStrategies.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using EnumX

# Valid strategy enum (this is a scoped enum)
@enumx T=Strategy JetRecoStrategy Best N2Plain N2Tiled

# Map from string to an enum value (used for CLI parsing)
Base.tryparse(E::Type{<:Enum}, str::String) =
let insts = instances(E) ,
p = findfirst(==(Symbol(str)) ∘ Symbol, insts) ;
p !== nothing ? insts[p] : nothing
end

const AllJetRecoStrategies = [ String(Symbol(x)) for x in instances(JetRecoStrategy.Strategy) ]

16 changes: 9 additions & 7 deletions src/JetReconstruction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,17 @@ py(p::LorentzVectorCyl) = LorentzVectorHEP.py(p)
pz(p::LorentzVectorCyl) = LorentzVectorHEP.pz(p)
energy(p::LorentzVectorCyl) = LorentzVectorHEP.energy(p)

# Philipp's pseudojet
# Philipp's pseudojet type
include("Pseudojet.jl")
export PseudoJet

# Simple HepMC3 reader
include("HepMC3.jl")

# Jet reconstruction strategies
include("JetRecoStrategies.jl")
export JetRecoStrategy

## N2Plain algorithm
# Algorithmic part for simple sequential implementation
include("PlainAlgo.jl")
Expand All @@ -37,11 +41,14 @@ export plain_jet_reconstruct
## N2Tiled algorithm
# Common pieces
include("TiledAlgoUtils.jl")

# Algorithmic part, tiled reconstruction strategy with linked list jet objects
include("TiledAlgoLL.jl")
export tiled_jet_reconstruct

## Generic algorithm, which can switch strategy dynamically
include("GenericAlgo.jl")
export generic_jet_reconstruct

# jet serialisation (saving to file)
include("Serialize.jl")
export savejets, loadjets!, loadjets
Expand All @@ -58,9 +65,4 @@ export jetsplot
include("JSONresults.jl")
export FinalJet, FinalJets, JSON3

# Strategy to be used
## Maybe an enum is not the best idea, use type dispatch instead?
@enum JetRecoStrategy Best N2Plain N2Tiled
export JetRecoStrategy, Best, N2Plain, N2Tiled

end
20 changes: 10 additions & 10 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,26 +43,26 @@ function main()

# Test each stratgy...
for power in keys(algorithms)
do_test_compare_to_fastjet(N2Plain, fastjet_data[power], algname = algorithms[power], power = power)
do_test_compare_to_fastjet(N2Tiled, fastjet_data[power], algname = algorithms[power], power = power)
do_test_compare_to_fastjet(JetRecoStrategy.N2Plain, fastjet_data[power], algname = algorithms[power], power = power)
do_test_compare_to_fastjet(JetRecoStrategy.N2Tiled, fastjet_data[power], algname = algorithms[power], power = power)
end

# Compare inputing data in PseudoJet with using a LorentzVector
do_test_compare_types(N2Plain, algname = algorithms[-1], power = -1)
do_test_compare_types(N2Tiled, algname = algorithms[-1], power = -1)
do_test_compare_types(JetRecoStrategy.N2Plain, algname = algorithms[-1], power = -1)
do_test_compare_types(JetRecoStrategy.N2Tiled, algname = algorithms[-1], power = -1)
end

function do_test_compare_to_fastjet(strategy::JetRecoStrategy, fastjet_jets;
function do_test_compare_to_fastjet(strategy::JetRecoStrategy.Strategy, fastjet_jets;
algname = "Unknown",
ptmin::Float64 = 5.0,
distance::Float64 = 0.4,
power::Integer = -1)

# Strategy
if (strategy == N2Plain)
if (strategy == JetRecoStrategy.N2Plain)
jet_reconstruction = plain_jet_reconstruct
strategy_name = "N2Plain"
elseif (strategy == N2Tiled)
elseif (strategy == JetRecoStrategy.N2Tiled)
jet_reconstruction = tiled_jet_reconstruct
strategy_name = "N2Tiled"
else
Expand Down Expand Up @@ -104,17 +104,17 @@ function do_test_compare_to_fastjet(strategy::JetRecoStrategy, fastjet_jets;
end
end

function do_test_compare_types(strategy::JetRecoStrategy;
function do_test_compare_types(strategy::JetRecoStrategy.Strategy;
algname = "Unknown",
ptmin::Float64 = 5.0,
distance::Float64 = 0.4,
power::Integer = -1)

# Strategy
if (strategy == N2Plain)
if (strategy == JetRecoStrategy.N2Plain)
jet_reconstruction = plain_jet_reconstruct
strategy_name = "N2Plain"
elseif (strategy == N2Tiled)
elseif (strategy == JetRecoStrategy.N2Tiled)
jet_reconstruction = tiled_jet_reconstruct
strategy_name = "N2Tiled"
else
Expand Down
Loading