diff --git a/PWGJE/Core/JetTaggingUtilities.h b/PWGJE/Core/JetTaggingUtilities.h index 7ed9d0b2db7..6e87e5db48b 100644 --- a/PWGJE/Core/JetTaggingUtilities.h +++ b/PWGJE/Core/JetTaggingUtilities.h @@ -634,7 +634,7 @@ bool isTaggedJetSV(T const jet, U const& /*prongs*/, float const& prongChi2PCAMi * @return The number of vertices (clusters) in the jet. */ template -int vertexClustering(AnyCollision const& collision, AnalysisJet const& jet, AnyTracks const&, AnyParticles const& particles, AnyOriginalParticles const&, std::unordered_map>& trkLabels, float vtxResParam = 0.01 /* 0.01cm = 100um */, float trackPtMin = 0.5) +int vertexClustering(AnyCollision const& collision, AnalysisJet const& jet, AnyTracks const&, AnyParticles const& particles, AnyOriginalParticles const&, std::unordered_map>& trkLabels, bool searchUpToQuark, float vtxResParam = 0.01 /* 0.01cm = 100um */, float trackPtMin = 0.5) { const auto& tracks = jet.template tracks_as(); const int n_trks = tracks.size(); @@ -791,7 +791,7 @@ int vertexClustering(AnyCollision const& collision, AnalysisJet const& jet, AnyT else { const auto &particle = constituent.template mcParticle_as(); - int orig = RecoDecay::getParticleOrigin(particles, particle, true); + int orig = RecoDecay::getParticleOrigin(particles, particle, searchUpToQuark); trkLabels["trkOrigin"].push_back((orig > 0) ? orig : (trkLabels["trkVtxIndex"][trkIdx] == 0) ? 3 : 4); } diff --git a/PWGJE/Tasks/bjetTreeCreator.cxx b/PWGJE/Tasks/bjetTreeCreator.cxx index a415c92ebfb..b1daabc029e 100644 --- a/PWGJE/Tasks/bjetTreeCreator.cxx +++ b/PWGJE/Tasks/bjetTreeCreator.cxx @@ -735,7 +735,7 @@ struct BJetTreeCreator { //+ TrackLabelMap trkLabels{{"trkVtxIndex", {}}, {"trkOrigin", {}}}; - int nVertices = jettaggingutilities::vertexClustering(collision.template mcCollision_as(), analysisJet, allTracks, MCParticles, origParticles, trkLabels, vtxRes, trackPtMin); + int nVertices = jettaggingutilities::vertexClustering(collision.template mcCollision_as(), analysisJet, allTracks, MCParticles, origParticles, trkLabels, true, vtxRes, trackPtMin); analyzeJetTrackInfoForGNN(collision, analysisJet, allTracks, origTracks, tracksIndices, jetFlavor, eventWeight, &trkLabels); registry.fill(HIST("h2_jetMass_jetpT"), analysisJet.pt(), analysisJet.mass(), eventWeight);