Skip to content

Commit

Permalink
Added basic training components
Browse files Browse the repository at this point in the history
  • Loading branch information
simonge committed May 6, 2024
1 parent 03850ad commit 067f01c
Show file tree
Hide file tree
Showing 3 changed files with 214 additions and 0 deletions.
5 changes: 5 additions & 0 deletions benchmarks/LOWQ2/config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
training:
# Training configuration goes here

analysis:
# Analysis configuration goes here
57 changes: 57 additions & 0 deletions benchmarks/LOWQ2/training/Snakefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Snakemake file for training a new neural network for LOW-Q2 tagger electron momentum reconstruction
from itertools import product

import os
import shutil
from snakemake.remote.S3 import RemoteProvider as S3RemoteProvider

S3 = S3RemoteProvider(
endpoint_url="https://eics3.sdcc.bnl.gov:9000",
access_key_id=os.environ["S3_ACCESS_KEY"],
secret_access_key=os.environ["S3_SECRET_KEY"],
)


SIM_DIRECTORY = "/scratch/EIC/G4out/S3/"
RECON_DIRECTORY = "/scratch/EIC/ReconOut/S3/"
MODEL_DIRECTORY = "/scratch/EIC/LowQ2Model/"
REMOTE_DIRECTORY = "eictest/EPIC/EVGEN/SIDIS/pythia6-eic/1.0.0/10x100/q2_0to1/"
FILE_BASE = "pythia_ep_noradcor_10x100_q2_0.000000001_1.0_run"
XML_FILE = "/home/simong/EIC/epic/epic_18x275.xml"

rule download_input:
input:
S3.remote(REMOTE_DIRECTORY+FILE_BASE+"{index}"+EVENT_EXTENSION),
output:
EVENTS_DIRECTORY+FILE_BASE+"{index}"+EVENT_EXTENSION,
run:
shutil.move(input[0], output[0])

rule run_reconstruction:
input:
SIM_DIRECTORY+FILE_BASE+"{index}"+SIM_EXTENSION,
"/home/simong/EIC/EICrecon/bin/eicrecon",
params:
XML=XML_FILE,
collections="LowQ2Tracks,ScatteredElectron",
output:
RECON_DIRECTORY+FILE_BASE+"{index}_reco.{tag}"+SIM_EXTENSION,
shell: """
/home/simong/EIC/EICrecon/bin/eicrecon {input[0]} -Pjana:nevents=400 -Pdd4hep:xml_files={params.XML} -Ppodio:output_include_collections={params.collections} -Ppodio:output_file={output[0]} -PLOWQ2:LowQ2Trajectories:electron_beamE=18
"""

rule low_q2_train_network:
params:
beam_energy="18",
type_name="LowQ2MomentumRegression",
method_name="DNN"
input:
train_data=RECON_DIRECTORY+FILE_BASE+"{index}_reco.{tag}"+SIM_EXTENSION
output:
root_output=MODEL_DIRECTORY+trainedData.root",
model_output=MODEL_DIRECTORY+"weights/"
shell:
"""
root -l -b -q 'TaggerRegressionEICrecon.C+("{input.train_data}", "{output.root_output}", MODEL_DIRECTORY, "{params.beam_energy}", "{params.type_name}", "{params.method_name}")'
"""

152 changes: 152 additions & 0 deletions benchmarks/LOWQ2/training/TaggerRegressionEICrecon.C
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
#include <cstdlib>
#include <iostream>
#include <map>
#include <string>

#include "TChain.h"
#include "TFile.h"
#include "TTree.h"
#include "TString.h"
#include "TObjString.h"
#include "TSystem.h"
#include "TROOT.h"

#include "TMVA/MethodDNN.h"
#include "TMVA/Reader.h"
#include "TMVA/Tools.h"
#include "TMVA/Factory.h"
#include "TMVA/DataLoader.h"
#include "TMVA/TMVARegGui.h"

using namespace TMVA;

// The training currently requires files which only contained a single DIS scattered electron to have been simulated e.g. generated using GETaLM
// The scattered electron must be the 3rd particle in the file after the two beam particles
// At least one track reconstructed by EIC algorithms in the LOWQ2 tagger is needed.

void TaggerRegressionEICrecon(
TString inDataNames = "/scratch/EIC/ReconOut/qr_18x275_ab/qr_18x275_ab*_recon.edm4hep.root",
TString outDataName = "/scratch/EIC/LowQ2Model/trainedData.root",
TString dataFolderName = "/scratch/EIC/LowQ2Model/",
TString mcBeamEnergy = "18",
TString typeName = "LowQ2MomentumRegression",
TString methodName = "DNN",
TString inWeightName = "dataset/weights/LowQ2Reconstruction_DNN.weights.xml"
)
{

Bool_t loadWeights = 0;

//---------------------------------------------------------------
// This loads the library
TMVA::Tools::Instance();

ROOT::EnableImplicitMT(8);

// --------------------------------------------------------------------------------------------------
// Here the preparation phase begins
// Create a new root output file
TFile* outputFile = TFile::Open( outDataName, "RECREATE" );

// Create the factory object. Later you can choose the methods

TMVA::Factory *factory = new TMVA::Factory( typeName, outputFile,
"!V:!Silent:Color:DrawProgressBar:AnalysisType=Regression" );

;
TMVA::DataLoader *dataloader=new TMVA::DataLoader(dataFolderName);

// Input TrackParameters variables from EICrecon -
TString collectionName = "LowQ2Tracks[0]";
dataloader->AddVariable( collectionName+".loc.a", "fit_position_y", "units", 'F' );
dataloader->AddVariable( collectionName+".loc.b", "fit_position_z", "units", 'F' );
dataloader->AddVariable( "sin("+collectionName+".phi)*sin("+collectionName+".theta)", "fit_vector_x", "units", 'F' );
dataloader->AddVariable( "cos("+collectionName+".phi)*sin("+collectionName+".theta)", "fit_vector_y", "units", 'F' );

// Regression target particle 3-momentum, normalised to beam energy.
// Takes second particle, in the test data this is the scattered electron
// TODO add energy and array element information to be read directly from datafile - EMD4eic and EICrecon changes.
TString mcParticleName = "ScatteredElectron[0]";
dataloader->AddTarget( mcParticleName+".momentum.x/"+mcBeamEnergy );
dataloader->AddTarget( mcParticleName+".momentum.y/"+mcBeamEnergy );
dataloader->AddTarget( mcParticleName+".momentum.z/"+mcBeamEnergy );

std::cout << "--- TMVARegression : Using input files: " << inDataNames << std::endl;

// Register the regression tree
TChain* regChain = new TChain("events");
regChain->Add(inDataNames);
//regChain->SetEntries(8000); // Set smaller sample for tests

// global event weights per tree (see below for setting event-wise weights)
Double_t regWeight = 1.0;

// You can add an arbitrary number of regression trees
dataloader->AddRegressionTree( regChain, regWeight );

// This would set individual event weights (the variables defined in the
// expression need to exist in the original TTree)
// dataloader->SetWeightExpression( "1/(eE)", "Regression" ); // If MC event weights are kept use these
// Apply additional cuts on the data
TCut mycut = "@LowQ2Tracks.size()==1"; // Make sure there's one reconstructed track in event

dataloader->PrepareTrainingAndTestTree(mycut,"nTrain_Regression=0:nTest_Regression=0:SplitMode=Random:SplitSeed=1:NormMode=NumEvents:!V");

// TODO - Optimise layout and training more
TString layoutString("Layout=TANH|1024,TANH|128,TANH|64,TANH|32,LINEAR");

TString trainingStrategyString("TrainingStrategy=");
trainingStrategyString +="LearningRate=1e-4,Momentum=0,MaxEpochs=2000,ConvergenceSteps=200,BatchSize=64,TestRepetitions=1,Regularization=None,Optimizer=ADAM";

TString nnOptions("!H:V:ErrorStrategy=SUMOFSQUARES:WeightInitialization=XAVIERUNIFORM:RandomSeed=1234");

// Use GPU if possible on the machine
TString architectureString("Architecture=GPU");
// Transformation of data prior to training layers - decorrelate and normalise whole dataset
TString transformString("VarTransform=D,N");

nnOptions.Append(":");
nnOptions.Append(architectureString);
nnOptions.Append(":");
nnOptions.Append(transformString);
nnOptions.Append(":");
nnOptions.Append(layoutString);
nnOptions.Append(":");
nnOptions.Append(trainingStrategyString);

TMVA::MethodDNN* method = (MethodDNN*)factory->BookMethod(dataloader, TMVA::Types::kDL, methodName, nnOptions); // NN

// If loading previous model for further training
if(loadWeights){
TMVA::Reader *reader = new TMVA::Reader( "!Color:!Silent" );
reader->BookMVA( methodName, inWeightName );
TMVA::MethodDNN* kl = dynamic_cast<TMVA::MethodDNN*>(reader->FindMVA(methodName));
method = kl;
}


// --------------------------------------------------------------------------------------------------
// Now you can tell the factory to train, test, and evaluate the MVAs

// Train MVAs using the set of training events
factory->TrainAllMethods();

// Evaluate all MVAs using the set of test events
factory->TestAllMethods();

// Evaluate and compare performance of all configured MVAs
factory->EvaluateAllMethods();

// --------------------------------------------------------------

// Save the output
outputFile->Close();

std::cout << "==> Wrote root file: " << outputFile->GetName() << std::endl;
std::cout << "==> TMVARegression is done!" << std::endl;

delete factory;
delete dataloader;

}

0 comments on commit 067f01c

Please sign in to comment.