Skip to content

Commit

Permalink
Add an extra table and vertexClustering function for GNN b-jet tagging
Browse files Browse the repository at this point in the history
  • Loading branch information
chchoi committed Nov 19, 2024
1 parent 9d72c0c commit f0013a2
Show file tree
Hide file tree
Showing 2 changed files with 398 additions and 1 deletion.
176 changes: 176 additions & 0 deletions PWGJE/Core/JetTaggingUtilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,182 @@ bool isTaggedJetSV(T const jet, U const& /*prongs*/, float const& prongChi2PCAMi
return true;
}

/**
* Clusters jet constituent tracks into groups of tracks originating from same mcParticle position (trkVtxIndex), and finds each track origin (trkOrigin). (for GNN b-jet tagging)
* @param trkLabels Track labels for GNN vertex and track origin predictions. trkVtxIndex: The index value of each vertex (cluster) which is determined by the function. trkOrigin: The category of the track origin (0: not physical primary, 1: charm, 2: beauty, 3: primary vertex, 4: other secondary vertex).
* @param vtxResParam Vertex resolution parameter which determines the cluster size. (cm)
* @param trackPtMin Minimum value of track pT.
* @return The number of vertices (clusters) in the jet.
*/
template <typename AnyCollision, typename AnalysisJet, typename AnyTracks, typename AnyParticles, typename AnyOriginalParticles>
int vertexClustering(AnyCollision const& collision, AnalysisJet const& jet, AnyTracks const&, AnyParticles const& particles, AnyOriginalParticles const&, std::unordered_map<std::string, std::vector<int>>& trkLabels, float vtxResParam = 0.01 /* 0.01cm = 100um */, float trackPtMin = 0.5)
{
const auto& tracks = jet.template tracks_as<AnyTracks>();
const int n_trks = tracks.size();

// trkVtxIndex

std::vector<int> tempTrkVtxIndex;

int i=0;
for (const auto& constituent : tracks) {
if (!constituent.has_mcParticle() || !constituent.template mcParticle_as<AnyParticles>().isPhysicalPrimary() || constituent.pt() < trackPtMin)
tempTrkVtxIndex.push_back(-1);
else
tempTrkVtxIndex.push_back(i++);
}
tempTrkVtxIndex.push_back(i); // temporary index for PV
if (n_trks < 1) { // the process should be done for n_trks == 1 as well
trkLabels["trkVtxIndex"] = tempTrkVtxIndex;
return n_trks;
}

int n_pos = n_trks + 1;
std::vector<float> dists(n_pos * (n_pos - 1) / 2);
auto trk_pair_idx = [n_pos](int ti, int tj) {
if (ti==tj || ti>=n_pos || tj>=n_pos || ti<0 || tj<0) {
LOGF(info,"Track pair index out of range");
return -1;
}
else
return (ti < tj) ? (ti * n_pos - (ti * (ti + 1)) / 2 + tj - ti - 1) : (tj * n_pos - (tj * (tj + 1)) / 2 + ti - tj - 1);
}; // index n_trks is for PV

for (int ti=0; ti<n_pos-1; ti++)
for (int tj=ti+1; tj<n_pos; tj++) {
std::array<float, 3> posi, posj;

if (tj < n_trks) {
if (tracks[tj].has_mcParticle()) {
const auto& pj = tracks[tj].template mcParticle_as<AnyParticles>().template mcParticle_as<AnyOriginalParticles>();
posj = std::array<float, 3>{pj.vx(), pj.vy(), pj.vz()};
}
else {
dists[trk_pair_idx(ti, tj)] = std::numeric_limits<float>::max();
continue;
}
}
else {
posj = std::array<float, 3>{collision.posX(), collision.posY(), collision.posZ()};
}

if (tracks[ti].has_mcParticle()) {
const auto& pi = tracks[ti].template mcParticle_as<AnyParticles>().template mcParticle_as<AnyOriginalParticles>();
posi = std::array<float, 3>{pi.vx(), pi.vy(), pi.vz()};
}
else {
dists[trk_pair_idx(ti, tj)] = std::numeric_limits<float>::max();
continue;
}

dists[trk_pair_idx(ti, tj)] = RecoDecay::distance(posi, posj);
}

int clusteri = -1, clusterj = -1;
float min_min_dist = -1.f; // If there is an not-merge-able min_dist pair, check the 2nd-min_dist pair.
while (true) {

float min_dist = -1.f; // Get min_dist pair
for (int ti=0; ti<n_pos-1; ti++)
for (int tj=ti+1; tj<n_pos; tj++)
if (tempTrkVtxIndex[ti] != tempTrkVtxIndex[tj] && tempTrkVtxIndex[ti]>=0 && tempTrkVtxIndex[tj]>=0) {
float dist = dists[trk_pair_idx(ti, tj)];
if ((dist < min_dist || min_dist < 0.f) && dist > min_min_dist) {
min_dist = dist;
clusteri = ti;
clusterj = tj;
}
}
if (clusteri < 0 || clusterj < 0)
break;

bool mrg = true; // Merge-ability check
for (int ti=0; ti<n_pos && mrg; ti++)
if (tempTrkVtxIndex[ti] == tempTrkVtxIndex[clusteri] && tempTrkVtxIndex[ti]>=0)
for (int tj=0; tj<n_pos && mrg; tj++)
if (tj != ti && tempTrkVtxIndex[tj] == tempTrkVtxIndex[clusterj] && tempTrkVtxIndex[tj]>=0)
if (dists[trk_pair_idx(ti, tj)] > vtxResParam) { // If there is more distant pair compared to vtx_res between two clusters, they cannot be merged.
mrg = false;
min_min_dist = min_dist;
}
if (min_dist > vtxResParam || min_dist < 0.f)
break;

if (mrg) { // Merge two clusters
int old_index = tempTrkVtxIndex[clusterj];
for (int t=0; t<n_pos; t++)
if (tempTrkVtxIndex[t] == old_index)
tempTrkVtxIndex[t] = tempTrkVtxIndex[clusteri];
}
}

int n_vertices = 0;

// Sort the indices from PV (as 0) to the most distant SV (as 1~).
int idxPV = tempTrkVtxIndex[n_trks];
for (int t=0; t<n_trks; t++)
if (tempTrkVtxIndex[t] == idxPV) {
tempTrkVtxIndex[t] = -2;
n_vertices = 1; // There is a track originating from PV
}

std::unordered_map<int, float> avgDistances;
std::unordered_map<int, int> count;
for (int t=0; t<n_trks; t++) {
if (tempTrkVtxIndex[t] >= 0) {
avgDistances[tempTrkVtxIndex[t]] += dists[trk_pair_idx(t, n_trks)];
count[tempTrkVtxIndex[t]]++;
}
}

trkLabels["trkVtxIndex"] = std::vector<int>(n_trks, -1);
if (count.size() != 0) { // If there is any SV cluster not only PV cluster
for (auto& [idx, avgDistance] : avgDistances)
avgDistance /= count[idx];

n_vertices += avgDistances.size();

std::vector<std::pair<int, float>> sortedIndices(avgDistances.begin(), avgDistances.end());
std::sort(sortedIndices.begin(), sortedIndices.end(), [](const auto& a, const auto& b) { return a.second < b.second; });
int rank = 1;
for (const auto& [idx, avgDistance] : sortedIndices) {
bool found = false;
for (int t=0; t<n_trks; t++)
if (tempTrkVtxIndex[t] == idx) {
trkLabels["trkVtxIndex"][t] = rank;
found = true;
}
rank += found;
}
}

for (int t=0; t<n_trks; t++)
if (tempTrkVtxIndex[t] == -2)
trkLabels["trkVtxIndex"][t] = 0;

// trkOrigin

int trkIdx = 0;
for (auto &constituent : jet.template tracks_as<AnyTracks>())
{
if (!constituent.has_mcParticle() || !constituent.template mcParticle_as<AnyParticles>().isPhysicalPrimary() || constituent.pt() < trackPtMin)
{
trkLabels["trkOrigin"].push_back(0);
}
else
{
const auto &particle = constituent.template mcParticle_as<AnyParticles>();
int orig = RecoDecay::getParticleOrigin(particles, particle, true);
trkLabels["trkOrigin"].push_back((orig > 0) ? orig :
(trkLabels["trkVtxIndex"][trkIdx] == 0) ? 3 : 4);
}

trkIdx++;
}

return n_vertices;
}

}; // namespace jettaggingutilities

#endif // PWGJE_CORE_JETTAGGINGUTILITIES_H_
Loading

0 comments on commit f0013a2

Please sign in to comment.