Skip to content

Commit

Permalink
Fitter interface (#226)
Browse files Browse the repository at this point in the history
* Add FitterInputHandler interface

* Example use of FitterInputHandler using fitcentroid

* Rename GetNHits to GetNPEs, add new GetNHits for event

---------

Co-authored-by: James Shen <[email protected]>
  • Loading branch information
JamesJieranShen and James Shen authored Jan 24, 2025
1 parent b1d0fa8 commit 820cb3c
Show file tree
Hide file tree
Showing 5 changed files with 242 additions and 14 deletions.
7 changes: 7 additions & 0 deletions ratdb/FITTER.ratdb
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
{
"name": "FIT_COMMON",
"index": "",
"mode": 0, // 0 = use DS::PMT; 1 = use DS::DigitPMT DigitTime/Charge; 2 = use WaveformAnalysisResult
"waveform_analyzer": "Lognormal", // only required if mode = 2
}

{
name: "Fitter",
index: "FitTensor",
Expand Down
3 changes: 3 additions & 0 deletions src/ds/include/RAT/DS/WaveformAnalysisResult.hh
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ class WaveformAnalysisResult : public TObject {
virtual Double_t getCharge(size_t idx) { return charges.at(idx); }
virtual Double_t getFOM(std::string key, size_t idx) { return figures_of_merit.at(key).at(idx); }
virtual int getNhits() { return times.size(); }

virtual const std::vector<Double_t>& getTimes() { return times; }
virtual const std::vector<Double_t>& getCharges() { return charges; }
ClassDef(WaveformAnalysisResult, 1);

protected:
Expand Down
8 changes: 5 additions & 3 deletions src/fit/include/RAT/FitCentroidProc.hh
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef __RAT_FitCentroidProc__
#define __RAT_FitCentroidProc__

#include <RAT/FitterInputHandler.hh>
#include <RAT/Processor.hh>
#include <string>

Expand All @@ -13,7 +14,7 @@ class EV;

class FitCentroidProc : public Processor {
public:
FitCentroidProc();
FitCentroidProc() : Processor("fitcentroid"), inputHandler(){};
virtual ~FitCentroidProc() {}

/** param = "power", value = exponent to raise charge to when averaging
Expand All @@ -23,8 +24,9 @@ class FitCentroidProc : public Processor {
virtual Processor::Result Event(DS::Root *ds, DS::EV *ev);

protected:
double fPower;
double fRescale;
double fPower = 2.0;
double fRescale = 1.0;
FitterInputHandler inputHandler;
};

} // namespace RAT
Expand Down
223 changes: 223 additions & 0 deletions src/fit/include/RAT/FitterInputHandler.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
/** @class FitterInputHandler
* Interface for handling the input stream of a standard fitter.
*
* @author James Shen <[email protected]>
*
* By default, the handler gets configured to one stream of input. This "steram"
* can either be the DS::PMTs, DS::DigitPMTs, waveform analysis results, etc.
*
* */

#ifndef __RAT_FitterInputHandler__
#define __RAT_FitterInputHandler__

#include <RAT/DB.hh>
#include <RAT/DS/EV.hh>
#include <RAT/Log.hh>
#include <algorithm>
#include <string>

#include "RAT/DS/DigitPMT.hh"
#include "RAT/DS/PMT.hh"

namespace RAT {

class FitterInputHandler {
public:
enum class Mode {
kPMT = 0,
kDigitPMT = 1,
kWaveformAnalysis = 2,
};

Mode mode;
std::string wfm_ana_name;

/**
* Default constructor. Configures the input based on the FIT_COMMON entry.
* */
FitterInputHandler() : FitterInputHandler(""){};
FitterInputHandler(const std::string& index) { Configure(index); };

/**
* Configures the class based on FIT_COMMON[index]
* @param index ratdb index to configure the class with.
* */
void Configure(const std::string& index) {
DBLinkPtr tbl = DB::Get()->GetLink("FIT_COMMON", index);
mode = static_cast<Mode>(tbl->GetI("mode"));
if (mode == Mode::kWaveformAnalysis) wfm_ana_name = tbl->GetS("waveform_analyzer");
}

/**
* Register an event to the input handler, so that we know to return the hits from this event.
* @param _ev event to register.
* */
void RegisterEvent(DS::EV* _ev) {
ev = _ev;
hitPMTChannels.clear();
// NOTE: class implementation assumes GetAllPMTIDs and GetAllDigitPMTIDs returns _sorted_ results.
// This is true since ev->digitpmt and ev->pmt are both std::maps.
hitPMTChannels = mode == Mode::kPMT ? ev->GetAllPMTIDs() : ev->GetAllDigitPMTIDs();
}

/**
* @brief Get PMTIDs for all pmts in the event.
* PMT will not be in the list if it never created a hit on DS::PMT or if the digitized waveform never crossed
* threshold.
*
* @return vector of all PMTs in event.
*/
const std::vector<Int_t>& GetAllHitPMTIDs() {
if (!ev) Log::Die("FitterInputHandler: Trying to acccess event info without registering the event.");
return hitPMTChannels;
}

/**
* @brief Get number of hit channels in the event.
*
*/
size_t GetNHits() {
if (!ev) Log::Die("FitterInputHandler: Trying to acccess event info without registering the event.");
return hitPMTChannels.size();
}

/**
* @brief Get the charge of a pmt.
* In the case where a waveoform analyzer created multiple hits on the PMT (multi-PE), this method only returns
* information about this first hit. To get information about all the hits, use getCharges.
*
* @param id PMT ID.
* @return charge of the first hit.
*/
double GetCharge(Int_t id) {
if (!ev) Log::Die("FitterInputHandler: Trying to acccess event info without registering the event.");
if (!std::binary_search(hitPMTChannels.begin(), hitPMTChannels.end(), id))
Log::Die("FitterInputHandler: Trying to access a channel with no hit registered!");

switch (mode) {
case Mode::kPMT:
return ev->GetOrCreatePMT(id)->GetCharge();
case Mode::kDigitPMT:
return ev->GetOrCreateDigitPMT(id)->GetDigitizedCharge();
case Mode::kWaveformAnalysis: {
DS::DigitPMT* digitpmt = ev->GetOrCreateDigitPMT(id);
std::vector<std::string> fitterNames = digitpmt->GetFitterNames();
if (std::find(fitterNames.begin(), fitterNames.end(), wfm_ana_name) == fitterNames.end()) {
info << "FitResult not found for pmt id " << id << " " << wfm_ana_name << newline;
}
return digitpmt->GetOrCreateWaveformAnalysisResult(wfm_ana_name)->getCharge(0);
}
default:
Log::Die("INVALID TYPE! Should never reach here.");
}
}

/**
* @brief Get the charge of all hits registered on a PMT.
* To only get information about the first hit (assume SPE), see getCharge.
*
* @param id PMT ID.
* @return vector of the charges registered on all hits on the PMT.
*/
std::vector<double> GetCharges(Int_t id) {
if (mode != Mode::kWaveformAnalysis) return std::vector<double>{GetCharge(id)};
if (!ev) Log::Die("FitterInputHandler: Trying to acccess event info without registering the event.");
if (!std::binary_search(hitPMTChannels.begin(), hitPMTChannels.end(), id))
Log::Die("FitterInputHandler: Trying to access a channel with no hit registered!");
DS::DigitPMT* digitpmt = ev->GetOrCreateDigitPMT(id);
std::vector<std::string> fitterNames = digitpmt->GetFitterNames();
if (std::find(fitterNames.begin(), fitterNames.end(), wfm_ana_name) == fitterNames.end()) {
info << "FitResult not found for pmt id " << id << " " << wfm_ana_name << newline;
}
return digitpmt->GetOrCreateWaveformAnalysisResult(wfm_ana_name)->getCharges();
}

/**
* @brief Get the time of a pmt hit.
* In the case where a waveoform analyzer created multiple hits on the PMT (multi-PE), this method only returns
* information about this first hit. To get information about all the hits, use getTimes.
* @param id PMT ID.
*
* @return time of the first hit.
*/
double GetTime(Int_t id) {
if (!ev) Log::Die("FitterInputHandler: Trying to acccess event info without registering the event.");
if (!std::binary_search(hitPMTChannels.begin(), hitPMTChannels.end(), id))
Log::Die("FitterInputHandler: Trying to access a channel with no hit registered!");

switch (mode) {
case Mode::kPMT:
return ev->GetOrCreatePMT(id)->GetTime();
case Mode::kDigitPMT:
return ev->GetOrCreateDigitPMT(id)->GetDigitizedTime();
case Mode::kWaveformAnalysis: {
DS::DigitPMT* digitpmt = ev->GetOrCreateDigitPMT(id);
std::vector<std::string> fitterNames = digitpmt->GetFitterNames();
if (std::find(fitterNames.begin(), fitterNames.end(), wfm_ana_name) == fitterNames.end()) {
info << "FitResult not found for pmt id " << id << " " << wfm_ana_name << newline;
}
return digitpmt->GetOrCreateWaveformAnalysisResult(wfm_ana_name)->getTime(0);
}
default:
Log::Die("INVALID TYPE! Should never reach here.");
}
}

/**
* @brief Get the time of all hits registered on a PMT.
* To only get information about the first hit (assume SPE), see getTime.
*
* @param id PMT ID.
* @return vector of the times registered on all hits on the PMT.
*/
std::vector<double> GetTimes(Int_t id) {
if (mode != Mode::kWaveformAnalysis) return std::vector<double>{GetTime(id)};
if (!ev) Log::Die("FitterInputHandler: Trying to acccess event info without registering the event.");
if (!std::binary_search(hitPMTChannels.begin(), hitPMTChannels.end(), id))
Log::Die("FitterInputHandler: Trying to access a channel with no hit registered!");
DS::DigitPMT* digitpmt = ev->GetOrCreateDigitPMT(id);
std::vector<std::string> fitterNames = digitpmt->GetFitterNames();
if (std::find(fitterNames.begin(), fitterNames.end(), wfm_ana_name) == fitterNames.end()) {
info << "FitResult not found for pmt id " << id << " " << wfm_ana_name << newline;
}
return digitpmt->GetOrCreateWaveformAnalysisResult(wfm_ana_name)->getTimes();
}

/**
* @brief Return the (approximate) number of hits registered on a PMT.
* Behavior is different depending on the mode.
* If mode is set to kPMT, always return 1 since no information about nhit is given.
* If mode is set to kDigitPMT, return the number of times that the waveform crosses threshold.
* If mode is set to kWaveformAnalysis, return the number of hits created by the analyzer.
*
* @param id PMT ID.
*/
unsigned int GetNPEs(Int_t id) {
if (!ev) Log::Die("FitterInputHandler: Trying to acccess event info without registering the event.");
if (!std::binary_search(hitPMTChannels.begin(), hitPMTChannels.end(), id))
Log::Die("FitterInputHandler: Trying to access a channel with no hit registered!");

switch (mode) {
case Mode::kPMT:
return 1; // no nhit information
case Mode::kDigitPMT:
return ev->GetOrCreateDigitPMT(id)->GetNCrossings(); // approximate
case Mode::kWaveformAnalysis:
DS::DigitPMT* digitpmt = ev->GetOrCreateDigitPMT(id);
std::vector<std::string> fitterNames = digitpmt->GetFitterNames();
if (std::find(fitterNames.begin(), fitterNames.end(), wfm_ana_name) == fitterNames.end()) {
info << "FitResult not found for pmt id " << id << " " << wfm_ana_name << newline;
}
return digitpmt->GetOrCreateWaveformAnalysisResult(wfm_ana_name)->getNhits();
}
}

protected:
DS::EV* ev = nullptr;
std::vector<Int_t> hitPMTChannels;
};

} // namespace RAT

#endif
15 changes: 4 additions & 11 deletions src/fit/src/FitCentroidProc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,6 @@
#include <string>

namespace RAT {

FitCentroidProc::FitCentroidProc() : Processor("fitcentroid") {
fPower = 2.0;
fRescale = 1.0;
}

void FitCentroidProc::SetD(std::string param, double value) {
if (param == "power") {
fPower = value;
Expand All @@ -29,19 +23,18 @@ void FitCentroidProc::SetD(std::string param, double value) {
}

Processor::Result FitCentroidProc::Event(DS::Root *ds, DS::EV *ev) {
inputHandler.RegisterEvent(ev);
double totalQ = 0;
TVector3 centroid(0.0, 0.0, 0.0);

for (int i : ev->GetAllPMTIDs()) {
DS::PMT *pmt = ev->GetOrCreatePMT(i);

for (int pmtid : inputHandler.GetAllHitPMTIDs()) {
double Qpow = 0.0;
Qpow = pow(pmt->GetCharge(), fPower);
Qpow = pow(inputHandler.GetCharge(pmtid), fPower);
totalQ += Qpow;

DS::Run *run = DS::RunStore::Get()->GetRun(ds);
DS::PMTInfo *pmtinfo = run->GetPMTInfo();
TVector3 pmtpos = pmtinfo->GetPosition(pmt->GetID());
TVector3 pmtpos = pmtinfo->GetPosition(pmtid);

if (fRescale != 1.0) {
pmtpos.SetMag(pmtpos.Mag() * fRescale);
Expand Down

0 comments on commit 820cb3c

Please sign in to comment.