Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fitter interface #226

Merged
merged 3 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions ratdb/FITTER.ratdb
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
{
"name": "FIT_COMMON",
"index": "",
"mode": 0, // 0 = use DS::PMT; 1 = use DS::DigitPMT DigitTime/Charge; 2 = use WaveformAnalysisResult
"waveform_analyzer": "Lognormal", // only required if mode = 2
}

{
name: "Fitter",
index: "FitTensor",
Expand Down
3 changes: 3 additions & 0 deletions src/ds/include/RAT/DS/WaveformAnalysisResult.hh
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ class WaveformAnalysisResult : public TObject {
virtual Double_t getCharge(size_t idx) { return charges.at(idx); }
virtual Double_t getFOM(std::string key, size_t idx) { return figures_of_merit.at(key).at(idx); }
virtual int getNhits() { return times.size(); }

virtual const std::vector<Double_t>& getTimes() { return times; }
virtual const std::vector<Double_t>& getCharges() { return charges; }
ClassDef(WaveformAnalysisResult, 1);

protected:
Expand Down
8 changes: 5 additions & 3 deletions src/fit/include/RAT/FitCentroidProc.hh
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef __RAT_FitCentroidProc__
#define __RAT_FitCentroidProc__

#include <RAT/FitterInputHandler.hh>
#include <RAT/Processor.hh>
#include <string>

Expand All @@ -13,7 +14,7 @@ class EV;

class FitCentroidProc : public Processor {
public:
FitCentroidProc();
FitCentroidProc() : Processor("fitcentroid"), inputHandler(){};
virtual ~FitCentroidProc() {}

/** param = "power", value = exponent to raise charge to when averaging
Expand All @@ -23,8 +24,9 @@ class FitCentroidProc : public Processor {
virtual Processor::Result Event(DS::Root *ds, DS::EV *ev);

protected:
double fPower;
double fRescale;
double fPower = 2.0;
double fRescale = 1.0;
FitterInputHandler inputHandler;
};

} // namespace RAT
Expand Down
214 changes: 214 additions & 0 deletions src/fit/include/RAT/FitterInputHandler.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
/** @class FitterInputHandler
* Interface for handling the input stream of a standard fitter.
*
* @author James Shen <[email protected]>
*
* By default, the handler gets configured to one stream of input. This "steram"
* can either be the DS::PMTs, DS::DigitPMTs, waveform analysis results, etc.
*
* */

#ifndef __RAT_FitterInputHandler__
#define __RAT_FitterInputHandler__

#include <RAT/DB.hh>
#include <RAT/DS/EV.hh>
#include <RAT/Log.hh>
#include <algorithm>
#include <string>

#include "RAT/DS/DigitPMT.hh"
#include "RAT/DS/PMT.hh"

namespace RAT {

class FitterInputHandler {
public:
enum class Mode {
kPMT = 0,
kDigitPMT = 1,
kWaveformAnalysis = 2,
};

Mode mode;
std::string wfm_ana_name;

/**
* Default constructor. Configures the input based on the FIT_COMMON entry.
* */
FitterInputHandler() : FitterInputHandler(""){};
FitterInputHandler(const std::string& index) { Configure(index); };

/**
* Configures the class based on FIT_COMMON[index]
* @param index ratdb index to configure the class with.
* */
void Configure(const std::string& index) {
DBLinkPtr tbl = DB::Get()->GetLink("FIT_COMMON", index);
mode = static_cast<Mode>(tbl->GetI("mode"));
if (mode == Mode::kWaveformAnalysis) wfm_ana_name = tbl->GetS("waveform_analyzer");
}

/**
* Register an event to the input handler, so that we know to return the hits from this event.
* @param _ev event to register.
* */
void RegisterEvent(DS::EV* _ev) {
ev = _ev;
hitPMTChannels.clear();
// NOTE: class implementation assumes GetAllPMTIDs and GetAllDigitPMTIDs returns _sorted_ results.
// This is true since ev->digitpmt and ev->pmt are both std::maps.
hitPMTChannels = mode == Mode::kPMT ? ev->GetAllPMTIDs() : ev->GetAllDigitPMTIDs();
}

/**
* @brief Get PMTIDs for all pmts in the event.
* PMT will not be in the list if it never created a hit on DS::PMT or if the digitized waveform never crossed
* threshold.
*
* @return vector of all PMTs in event.
*/
const std::vector<Int_t>& GetAllHitPMTIDs() {
if (!ev) Log::Die("FitterInputHandler: Trying to acccess event info without registering the event.");
return hitPMTChannels;
}

/**
* @brief Get the charge of a pmt.
* In the case where a waveoform analyzer created multiple hits on the PMT (multi-PE), this method only returns
* information about this first hit. To get information about all the hits, use getCharges.
*
* @param id PMT ID.
* @return charge of the first hit.
*/
double GetCharge(Int_t id) {
if (!ev) Log::Die("FitterInputHandler: Trying to acccess event info without registering the event.");
if (!std::binary_search(hitPMTChannels.begin(), hitPMTChannels.end(), id))
Log::Die("FitterInputHandler: Trying to access a channel with no hit registered!");

switch (mode) {
case Mode::kPMT:
return ev->GetOrCreatePMT(id)->GetCharge();
case Mode::kDigitPMT:
return ev->GetOrCreateDigitPMT(id)->GetDigitizedCharge();
case Mode::kWaveformAnalysis: {
DS::DigitPMT* digitpmt = ev->GetOrCreateDigitPMT(id);
std::vector<std::string> fitterNames = digitpmt->GetFitterNames();
if (std::find(fitterNames.begin(), fitterNames.end(), wfm_ana_name) == fitterNames.end()) {
info << "FitResult not found for pmt id " << id << " " << wfm_ana_name << newline;
}
return digitpmt->GetOrCreateWaveformAnalysisResult(wfm_ana_name)->getCharge(0);
}
default:
Log::Die("INVALID TYPE! Should never reach here.");
}
}

/**
* @brief Get the charge of all hits registered on a PMT.
* To only get information about the first hit (assume SPE), see getCharge.
*
* @param id PMT ID.
* @return vector of the charges registered on all hits on the PMT.
*/
std::vector<double> GetCharges(Int_t id) {
if (mode != Mode::kWaveformAnalysis) return std::vector<double>{GetCharge(id)};
if (!ev) Log::Die("FitterInputHandler: Trying to acccess event info without registering the event.");
if (!std::binary_search(hitPMTChannels.begin(), hitPMTChannels.end(), id))
Log::Die("FitterInputHandler: Trying to access a channel with no hit registered!");
DS::DigitPMT* digitpmt = ev->GetOrCreateDigitPMT(id);
std::vector<std::string> fitterNames = digitpmt->GetFitterNames();
if (std::find(fitterNames.begin(), fitterNames.end(), wfm_ana_name) == fitterNames.end()) {
info << "FitResult not found for pmt id " << id << " " << wfm_ana_name << newline;
}
return digitpmt->GetOrCreateWaveformAnalysisResult(wfm_ana_name)->getCharges();
}

/**
* @brief Get the time of a pmt hit.
* In the case where a waveoform analyzer created multiple hits on the PMT (multi-PE), this method only returns
* information about this first hit. To get information about all the hits, use getTimes.
* @param id PMT ID.
*
* @return time of the first hit.
*/
double GetTime(Int_t id) {
if (!ev) Log::Die("FitterInputHandler: Trying to acccess event info without registering the event.");
if (!std::binary_search(hitPMTChannels.begin(), hitPMTChannels.end(), id))
Log::Die("FitterInputHandler: Trying to access a channel with no hit registered!");

switch (mode) {
case Mode::kPMT:
return ev->GetOrCreatePMT(id)->GetTime();
case Mode::kDigitPMT:
return ev->GetOrCreateDigitPMT(id)->GetDigitizedTime();
case Mode::kWaveformAnalysis: {
DS::DigitPMT* digitpmt = ev->GetOrCreateDigitPMT(id);
std::vector<std::string> fitterNames = digitpmt->GetFitterNames();
if (std::find(fitterNames.begin(), fitterNames.end(), wfm_ana_name) == fitterNames.end()) {
info << "FitResult not found for pmt id " << id << " " << wfm_ana_name << newline;
}
return digitpmt->GetOrCreateWaveformAnalysisResult(wfm_ana_name)->getTime(0);
}
default:
Log::Die("INVALID TYPE! Should never reach here.");
}
}

/**
* @brief Get the time of all hits registered on a PMT.
* To only get information about the first hit (assume SPE), see getTime.
*
* @param id PMT ID.
* @return vector of the times registered on all hits on the PMT.
*/
std::vector<double> GetTimes(Int_t id) {
if (mode != Mode::kWaveformAnalysis) return std::vector<double>{GetTime(id)};
if (!ev) Log::Die("FitterInputHandler: Trying to acccess event info without registering the event.");
if (!std::binary_search(hitPMTChannels.begin(), hitPMTChannels.end(), id))
Log::Die("FitterInputHandler: Trying to access a channel with no hit registered!");
DS::DigitPMT* digitpmt = ev->GetOrCreateDigitPMT(id);
std::vector<std::string> fitterNames = digitpmt->GetFitterNames();
if (std::find(fitterNames.begin(), fitterNames.end(), wfm_ana_name) == fitterNames.end()) {
info << "FitResult not found for pmt id " << id << " " << wfm_ana_name << newline;
}
return digitpmt->GetOrCreateWaveformAnalysisResult(wfm_ana_name)->getTimes();
}

/**
* @brief Return the (approximate) number of hits registered on a PMT.
* Behavior is different depending on the mode.
* If mode is set to kPMT, always return 1 since no information about nhit is given.
* If mode is set to kDigitPMT, return the number of times that the waveform crosses threshold.
* If mode is set to kWaveformAnalysis, return the number of hits created by the analyzer.
*
* @param id PMT ID.
*/
unsigned int GetNhits(Int_t id) {
Copy link
Contributor

@tannerbk tannerbk Jan 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is confusingly named. "Nhits" refers to the number of PMT hits in an event, not the number of "hits" on a PMT. This could be more aptly named "GetNPEs".

Additionally, you could add a method for nhits that returns the length of the appropriate hit PMT list.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both implemented -- the old GetNHits is now called GetNPEs. There's now a new GetNHits method.

if (!ev) Log::Die("FitterInputHandler: Trying to acccess event info without registering the event.");
if (!std::binary_search(hitPMTChannels.begin(), hitPMTChannels.end(), id))
Log::Die("FitterInputHandler: Trying to access a channel with no hit registered!");

switch (mode) {
case Mode::kPMT:
return 1; // no nhit information
case Mode::kDigitPMT:
return ev->GetOrCreateDigitPMT(id)->GetNCrossings(); // approximate
case Mode::kWaveformAnalysis:
DS::DigitPMT* digitpmt = ev->GetOrCreateDigitPMT(id);
std::vector<std::string> fitterNames = digitpmt->GetFitterNames();
if (std::find(fitterNames.begin(), fitterNames.end(), wfm_ana_name) == fitterNames.end()) {
info << "FitResult not found for pmt id " << id << " " << wfm_ana_name << newline;
}
return digitpmt->GetOrCreateWaveformAnalysisResult(wfm_ana_name)->getNhits();
}
}

protected:
DS::EV* ev = nullptr;
std::vector<Int_t> hitPMTChannels;
};

} // namespace RAT

#endif
15 changes: 4 additions & 11 deletions src/fit/src/FitCentroidProc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,6 @@
#include <string>

namespace RAT {

FitCentroidProc::FitCentroidProc() : Processor("fitcentroid") {
fPower = 2.0;
fRescale = 1.0;
}

void FitCentroidProc::SetD(std::string param, double value) {
if (param == "power") {
fPower = value;
Expand All @@ -29,19 +23,18 @@ void FitCentroidProc::SetD(std::string param, double value) {
}

Processor::Result FitCentroidProc::Event(DS::Root *ds, DS::EV *ev) {
inputHandler.RegisterEvent(ev);
double totalQ = 0;
TVector3 centroid(0.0, 0.0, 0.0);

for (int i : ev->GetAllPMTIDs()) {
DS::PMT *pmt = ev->GetOrCreatePMT(i);

for (int pmtid : inputHandler.GetAllHitPMTIDs()) {
double Qpow = 0.0;
Qpow = pow(pmt->GetCharge(), fPower);
Qpow = pow(inputHandler.GetCharge(pmtid), fPower);
totalQ += Qpow;

DS::Run *run = DS::RunStore::Get()->GetRun(ds);
DS::PMTInfo *pmtinfo = run->GetPMTInfo();
TVector3 pmtpos = pmtinfo->GetPosition(pmt->GetID());
TVector3 pmtpos = pmtinfo->GetPosition(pmtid);

if (fRescale != 1.0) {
pmtpos.SetMag(pmtpos.Mag() * fRescale);
Expand Down
Loading