v2.0.3 - Improvements to binned classification
- Added MultiClass support for binned classification.
- Added `ANNZ::deriveHisClsPrb()`, which changes how probabilities are
calculated in binned classification.
- Added `ANNZ_PDF_max` - the most likely value of a PDF (the peak of
the PDF), to the outputs of regression.
- Various bug fixes.
IftachSadeh committed Feb 25, 2015
1 parent 3e4db48 commit e39c31c
Showing 19 changed files with 648 additions and 174 deletions.
13 changes: 11 additions & 2 deletions CHANGELOG.md
@@ -1,8 +1,17 @@
# Changelog

## Master version
<!-- ## Master version -->

- Added `ANNZ_PDF_max`, the most likely value of a PDF (the peak of the PDF), to the outputs of regression.
## ANNZ 2.0.3 (20/2/2015)

- **Added *MultiClass* support to binned classification:** The new option is controlled by setting the `doMultiCls` flag. In this mode, multiple background samples can be trained simultaneously against the signal. In the context of binned classification, this means that each classification bin acts as an independent sample during the training.

- **Added the function `ANNZ::deriveHisClsPrb()`:** Binned classification has been modified so that all classification probabilities are calculated by hand, instead of using the `CreateMVAPdfs` option of `TMVA::Factory`. By default, the new calculation takes into account the relative size of the signal in each classification bin, compared to the number of objects in the entire training sample. The latter feature may be turned off by setting:
```python
glob.annz["useBinClsPrior"] = False
```
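The role of the prior can be pictured with a short sketch (an illustration of the general idea only, not the actual ANNZ implementation; the function and variable names below are invented for the example): each bin's classification probability is weighted by the fraction of the full training sample that belongs to that bin, and the weighted values are then normalized into a PDF.
```python
# illustrative sketch of a prior-weighted binned-classification PDF
def binned_cls_pdf(cls_probs, n_sig_per_bin, n_train_total, use_prior=True):
    # weight each bin's classification probability by its prior
    weighted = []
    for prb, n_sig in zip(cls_probs, n_sig_per_bin):
        prior = n_sig / float(n_train_total) if use_prior else 1.0
        weighted.append(prb * prior)
    # normalize the weighted probabilities into a PDF
    norm = sum(weighted)
    return [w / norm for w in weighted] if norm > 0 else weighted
```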

- Added `ANNZ_PDF_max`, the most likely value of a PDF (the peak of the PDF), to the outputs of regression.

- Fixed compatibility issues with ROOT v6.02.

22 changes: 20 additions & 2 deletions README.md
@@ -1,4 +1,4 @@
# ANNZ 2.0.2
# ANNZ 2.0.3

## Introduction
ANNZ uses both regression and classification techniques for the estimation of single-value photo-z solutions and PDFs (or for any other regression problem). In addition, it is suitable for classification problems, such as star/galaxy classification.
@@ -265,6 +265,23 @@ The `scripts/generalSettings.py` script includes the following two functions:

The syntax for defining MLM options is explained in the [TMVA wiki](http://tmva.sourceforge.net/optionRef.html) and in the [TMVA manual](http://tmva.sourceforge.net/docu/TMVAUsersGuide.pdf) (in the chapter *The TMVA Methods*). The options may be specified by the user with the `glob.annz["userMLMopts"]` variable. The only requirement not defined nominally in TMVA is `ANNZ_MLM`. This is an internal variable in ANNZ which specifies the type of MLM requested by the user.

The following is the list of all available MLM algorithms:

- **`CUTS`:** Rectangular cut optimization.
- **`Likelihood`:** Projective likelihood estimator (PDE approach).
- **`PDERS`:** Multidimensional likelihood estimator (PDE range-search approach).
- **`PDEFoam`:** Likelihood estimator using self-adapting phase-space binning.
- **`KNN`:** k-Nearest Neighbors.
- **`HMatrix`:** H-Matrix discriminant.
- **`Fisher`:** Fisher discriminants (linear discriminant analysis).
- **`LD`:** Linear discriminant analysis.
- **`FDA`:** Function discriminant analysis.
- **`ANN`:** Artificial neural networks (nonlinear discriminant analysis) using an MLP neural network (recommended type of neural network). Also available are the Clermont-Ferrand neural network (**`CFMlpANN`**) and the original ROOT implementation of a neural network (**`TMlpANN`**).
- **`SVM`:** Support vector machine.
- **`BDT`:** Boosted decision and regression trees.
- **`RuleFit`:** Predictive learning via rule ensembles.


Here are a couple of examples:

- For instance, we can define a BDT with 110 decision trees, using the AdaBoost (adaptive boost) algorithm:
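A minimal sketch of such a definition (illustrative only; further TMVA options may be appended to the string):
```python
glob.annz["userMLMopts"] = "ANNZ_MLM=BDT:NTrees=110:BoostType=AdaBoost"
```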
@@ -311,7 +328,7 @@ We define the following name-tags:

5. **ANNZ PDF 0 (`PDF_0_*`)**: The full PDF solution.

6. **`ANNZ_MLM_avg_1`, `ANNZ_PDF_max_1` and `PDF_1_*`**: The corresponding estimators for the second PDF.
6. **`ANNZ_MLM_avg_1`, `ANNZ_PDF_avg_1`, `ANNZ_PDF_max_1` and `PDF_1_*`**: The corresponding estimators for the second PDF.

#### Optimization

@@ -450,6 +467,7 @@ A few notes:
```
will ensure that any MLM which has a scatter higher than `0.04` will not be included in the PDF.

- **`doMultiCls`:** Using the *MultiClass* option of binned classification, multiple background samples can be trained simultaneously against the signal. This means that each classification bin acts as an independent sample during the training. The MultiClass option is only compatible with four MLM algorithms: `BDT`, `ANN`, `FDA` and `PDEFoam`. For `BDT`, only the gradient boosted decision trees are available. That is, one may set `:BoostType=Grad`, but not `:BoostType=Bagging` or `:BoostType=AdaBoost`, as part of the `userMLMopts` option.
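For example (an illustrative configuration, not the only valid one), the mode may be enabled together with a gradient-boosted BDT:
```python
glob.annz["doMultiCls"]  = True
glob.annz["userMLMopts"] = "ANNZ_MLM=BDT:NTrees=300:BoostType=Grad"
```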

- By default, a progress bar is drawn during training. If the output is being written to a log file, it is important to suppress the progress bar, as it will otherwise cause the log file to become very large. One can either add `--isBatch` when running the example scripts, or set in `generalSettings.py` (or elsewhere),
```python
glob.annz["isBatch"] = True  # batch mode, suppresses the progress bar
```
24 changes: 22 additions & 2 deletions examples/scripts/annz_binCls_advanced.py
@@ -174,7 +174,7 @@
# --------------------------------------------------------------------------------------------------
clsBinType = 2
if clsBinType == 0:
# derive the binning scheme on the fly, such that 60 bins, each no wider than 0.05, are chosen
# derive the binning scheme on the fly, such that 60 bins, each no wider than 0.02, are chosen
glob.annz["binCls_nBins"] = 60
glob.annz["binCls_maxBinW"] = 0.02
elif clsBinType == 1:
@@ -204,6 +204,26 @@
# --------------------------------------------------------------------------------------------------
glob.annz["rndOptTypes"] = "BDT"

# --------------------------------------------------------------------------------------------------
# doMultiCls:
# --------------------------------------------------------------------------------------------------
# - Using the MultiClass option of binned classification, multiple background samples can be trained
# simultaneously against the signal. This means that each classification bin acts as an independent sample during
# the training. The MultiClass option is only compatible with four MLM algorithms: BDT, ANN, FDA and PDEFoam.
# For BDT, only the gradient boosted decision trees are available. That is, one may set ":BoostType=Grad",
# but not ":BoostType=Bagging" or ":BoostType=AdaBoost", as part of the userMLMopts option.
# - examples:
# - glob.annz["userMLMopts_0"] = "ANNZ_MLM=FDA:Formula=(0)+(1)*x0+(2)*x1+(3)*x2+(4)*x3:" \
# +"ParRanges=(-1,1);(-10,10);(-10,10);(-10,10);(-10,10):" \
# +"FitMethod=GA:PopSize=300:Cycles=3:Steps=20:Trim=True:SaveBestGen=1"
# - glob.annz["userMLMopts_1"] = "ANNZ_MLM=PDEFoam:nActiveCells=500:nSampl=2000:nBin=5:Nmin=100:Kernel=None:Compress=T"
# - Using the MultiClass option, the binCls_bckShiftMin,binCls_bckShiftMax,binCls_bckSubsetRange
# options are ignored.
# - Using the MultiClass option, training is much slower; it is therefore recommended to set a low
# value (<3) of binCls_nTries.
# --------------------------------------------------------------------------------------------------
glob.annz["doMultiCls"] = False

# --------------------------------------------------------------------------------------------------
# - binCls_bckShiftMin,binCls_bckShiftMax (Optional training setting):
# --------------------------------------------------------------------------------------------------
@@ -472,7 +492,7 @@
# evalDirPostfix - if not empty, this string will be added to the name of the evaluation directory
# (can be used to prevent multiple evaluation of different input files from overwriting each other)
glob.annz["evalDirPostfix"] = ""

# run ANNZ with the current settings
runANNZ()

2 changes: 1 addition & 1 deletion examples/scripts/annz_binCls_quick.py
@@ -139,7 +139,7 @@
# evalDirPostfix - if not empty, this string will be added to the name of the evaluation directory
# (can be used to prevent multiple evaluation of different input files from overwriting each other)
glob.annz["evalDirPostfix"] = ""

# run ANNZ with the current settings
runANNZ()

6 changes: 3 additions & 3 deletions examples/scripts/annz_rndReg_advanced.py
@@ -32,12 +32,12 @@
# zTrg - the name of the target variable of the regression
# minValZ,maxValZ - the minimal and maximal values of the target variable (zTrg)
glob.annz["zTrg"] = "Z"
glob.annz["minValZ"] = 0.05
glob.annz["minValZ"] = 0.0
glob.annz["maxValZ"] = 0.8

# set the number of near-neighbours used to compute the KNN error estimator
# (should be around 100 - set here to a very low number just to speed the example up... !)
glob.annz["nErrKNN"] = 20
glob.annz["nErrKNN"] = 50

# --------------------------------------------------------------------------------------------------
# pre-processing of the input dataset
@@ -304,7 +304,7 @@
hasUserPdfBins = False
if hasUserPdfBins:
# use a pre-defined set of PDF bins
glob.annz["userPdfBins"] = "0.05;0.1;0.2;0.24;0.3;0.52;0.6;0.7;0.8"
glob.annz["userPdfBins"] = "0.0;0.1;0.2;0.24;0.3;0.52;0.6;0.7;0.8"
else:
# nPDFbins - number of PDF bins (equal width bins between minValZ and maxValZ)
glob.annz["nPDFbins"] = 90
4 changes: 2 additions & 2 deletions examples/scripts/annz_rndReg_quick.py
@@ -32,12 +32,12 @@
# zTrg - the name of the target variable of the regression
# minValZ,maxValZ - the minimal and maximal values of the target variable (zTrg)
glob.annz["zTrg"] = "Z"
glob.annz["minValZ"] = 0.05
glob.annz["minValZ"] = 0.0
glob.annz["maxValZ"] = 0.8

# set the number of near-neighbours used to compute the KNN error estimator
# (should be around 100 - set here to a very low number just to speed the example up... !)
glob.annz["nErrKNN"] = 20
glob.annz["nErrKNN"] = 50

# --------------------------------------------------------------------------------------------------
# pre-processing of the input dataset
4 changes: 2 additions & 2 deletions examples/scripts/annz_singleReg_quick.py
@@ -32,12 +32,12 @@
# zTrg - the name of the target variable of the regression
# minValZ,maxValZ - the minimal and maximal values of the target variable (zTrg)
glob.annz["zTrg"] = "Z"
glob.annz["minValZ"] = 0.05
glob.annz["minValZ"] = 0.0
glob.annz["maxValZ"] = 0.8

# set the number of near-neighbours used to compute the KNN error estimator
# (should be around 100 - set here to a very low number just to speed the example up... !)
glob.annz["nErrKNN"] = 20
glob.annz["nErrKNN"] = 50

# --------------------------------------------------------------------------------------------------
# pre-processing of the input dataset
2 changes: 1 addition & 1 deletion examples/scripts/generalSettings.py
@@ -120,7 +120,7 @@ def genRndOpts(aSeed):

nTreeFact = 3 if(rndAr[1] < 0.2) else 1
nTreesAdd = int(floor(rndAr[2]*300/10.) * 10) * nTreeFact
nTrees = ":NTrees="+str(int(50+max(0,min(nTreesAdd,800))))
nTrees = ":NTrees="+str(int(250+max(0,min(nTreesAdd,800))))

boostType = ":BoostType="
if (rndAr[3] < 0.4): boostType += "Bagging";
5 changes: 4 additions & 1 deletion include/ANNZ.hpp
@@ -109,7 +109,7 @@ class ANNZ : public BaseClass {
void prepFactory(int nMLMnow = -1, TMVA::Factory * factory = NULL);
void doFactoryTrain(TMVA::Factory * factory);
void clearReaders(Log::LOGtypes logLevel = Log::DEBUG_1);
void loadReaders(map <TString,bool> & mlmSkipNow);
void loadReaders(map <TString,bool> & mlmSkipNow, bool needMcPRB = true);
double getReader(VarMaps * var = NULL, ANNZ_readType readType = ANNZ_readType::NUN, bool forceUpdate = false, int nMLMnow = -1);
void setupTypesTMVA();
TMVA::Types::EMVA getTypeMLMbyName(TString typeName);
@@ -133,6 +133,7 @@
void makeTreeRegClsAllMLM();
void makeTreeRegClsOneMLM(int nMLMnow = -1);
double getSeparation(TH1 * hisSig, TH1 * hisBck);
void deriveHisClsPrb(int nMLMnow = -1);
TChain * mergeTreeFriends(TChain * aChain = NULL, TChain * aChainFriend = NULL, vector<TString> * chainFriendFileNameV = NULL,
vector <TString> * acceptV = NULL, vector <TString> * rejectV = NULL, TCut aCut = "");
void verifyIndicesMLM(TChain * aChain = NULL);
@@ -173,9 +174,11 @@
vector < pair<TString,Float_t> > readerInptV;
vector < vector<int> > readerInptIndexV;
vector < TMVA::Reader* > regReaders;
vector < TMVA::Types::EAnalysisType > anlysTypes;
vector < TMVA::Types::EMVA > typeMLM, allANNZtypes;
map < TMVA::Types::EMVA,TString > typeToNameMLM;
map < TString,TMVA::Types::EMVA > nameToTypeMLM;
vector < TH1* > hisClsPrbV;

};
#endif // #define ANNZ_h
6 changes: 5 additions & 1 deletion src/ANNZ.cpp
@@ -45,7 +45,11 @@ ANNZ::~ANNZ() {
inNamesVar.clear(); inNamesErr.clear(); userWgtsM.clear(); mlmTagErrKNN.clear();
zClos_binE.clear(); zClos_binC.clear(); zBinCls_binE.clear(); zBinCls_binC.clear();
typeMLM.clear(); allANNZtypes.clear(); typeToNameMLM.clear(); nameToTypeMLM.clear();
bestMLMname.clear();
bestMLMname.clear(); anlysTypes.clear(); readerInptV.clear();

for(int nHisNow=0; nHisNow<(int)regReaders.size(); nHisNow++) DELNULL(regReaders[nHisNow]);
for(int nHisNow=0; nHisNow<(int)hisClsPrbV.size(); nHisNow++) DELNULL(hisClsPrbV[nHisNow]);
regReaders.clear(); hisClsPrbV.clear();

return;
}
69 changes: 55 additions & 14 deletions src/ANNZ_TMVA.cpp
@@ -90,7 +90,10 @@ void ANNZ::clearReaders(Log::LOGtypes logLevel) {

DELNULL_(LOCATION,regReaders[nMLMnow],(TString)"regReaders["+utils->intToStr(nMLMnow)+"]",verb);
}
regReaders.clear(); readerInptV.clear(); readerInptIndexV.clear();

for(int nMLMnow=0; nMLMnow<(int)hisClsPrbV.size(); nMLMnow++) DELNULL(hisClsPrbV[nMLMnow]);

regReaders.clear(); readerInptV.clear(); readerInptIndexV.clear(); anlysTypes.clear(); hisClsPrbV.clear();

return;
}
@@ -105,18 +108,21 @@ void ANNZ::clearReaders(Log::LOGtypes logLevel) {
* so that the readers may be evaluated.
*
* @param mlmSkipNow - Map for determining which MLM is accepted.
* @param needMcPRB - whether or not to load the multiclass probability PDF
*/
// ===========================================================================================================
void ANNZ::loadReaders(map <TString,bool> & mlmSkipNow) {
// ======================================================
void ANNZ::loadReaders(map <TString,bool> & mlmSkipNow, bool needMcPRB) {
// ======================================================================
aLOG(Log::DEBUG_1) <<coutWhiteOnBlack<<coutYellow<<" - starting ANNZ::loadReaders() ... "<<coutDef<<endl;

int nMLMs = glob->GetOptI("nMLMs");
int nMLMs = glob->GetOptI("nMLMs");
bool isBinCls = glob->GetOptB("doBinnedCls");

// cleanup containers before initializing new readers
clearReaders(Log::DEBUG_2);

readerInptV.clear(); regReaders.resize(nMLMs,NULL); readerInptIndexV.resize(nMLMs);
anlysTypes.resize(nMLMs,TMVA::Types::kNoAnalysisType); hisClsPrbV.resize(glob->GetOptI("nMLMs"),NULL);

// -----------------------------------------------------------------------------------------------------------
// initialize readerInptV and add all required variables by input variables (formulae) - using
@@ -180,13 +186,31 @@
bool foundReader = (dynamic_cast<TMVA::MethodBase*>(regReaders[nMLMnow]->FindMVA(MLMname)));

if(foundReader) {
TString methodName = (dynamic_cast<TMVA::MethodBase*>(regReaders[nMLMnow] ->FindMVA(MLMname)))->GetMethodTypeName();
TMVA::Types::EMVA methodType = (dynamic_cast<TMVA::MethodBase*>(regReaders[nMLMnow] ->FindMVA(MLMname)))->GetMethodType();
TString methodName = (dynamic_cast<TMVA::MethodBase*>(regReaders[nMLMnow]->FindMVA(MLMname)))->GetMethodTypeName();
TMVA::Types::EMVA methodType = (dynamic_cast<TMVA::MethodBase*>(regReaders[nMLMnow]->FindMVA(MLMname)))->GetMethodType();
anlysTypes[nMLMnow] = (dynamic_cast<TMVA::MethodBase*>(regReaders[nMLMnow]->FindMVA(MLMname)))->GetAnalysisType();

VERIFY(LOCATION,(TString)"Found inconsistent settings (configSave_type = \""+typeToNameMLM[typeMLM[nMLMnow]]
+"\" from the settings file, but "+MLMname+" is of type \""+typeToNameMLM[methodType]+"\""
,(typeMLM[nMLMnow] == methodType));

// load the classification response histogram for Multiclass readers
if(needMcPRB && (isBinCls || anlysTypes[nMLMnow] == TMVA::Types::kMulticlass)) {
TString hisClsPrbFileName = getKeyWord(MLMname,"postTrain","hisClsPrbFile");
TString hisName = getKeyWord(MLMname,"postTrain","hisClsPrbHis");

TFile * hisClsPrbFile = new TFile(hisClsPrbFileName,"READ");

hisClsPrbV[nMLMnow] = dynamic_cast<TH1*>(hisClsPrbFile->Get(hisName));
VERIFY(LOCATION,(TString)"Could not find hisClsPrbV[nMLMnow = "+utils->intToStr(nMLMnow)+"] in "
+hisClsPrbFileName+" ?!",(dynamic_cast<TH1*>(hisClsPrbV[nMLMnow])));

hisClsPrbV[nMLMnow] = (TH1*)hisClsPrbV[nMLMnow]->Clone((TString)hisName+"_cln");
hisClsPrbV[nMLMnow]->SetDirectory(0);

hisClsPrbFile->Close(); DELNULL(hisClsPrbFile);
}

if(nReadIn < 5) aLOG(Log::DEBUG) <<coutYellow<<" - Found "<<methodName<<" Reader("<<coutRed<<nMLMnow<<coutYellow<<") ... "<<coutDef<<endl;
if(nReadIn == 5) aLOG(Log::DEBUG) <<coutYellow<<" - Suppressing further messages ... "<<coutDef<<endl;
nReadIn++;
@@ -219,14 +243,31 @@ double ANNZ::getReader(VarMaps * var, ANNZ_readType readType, bool forceUpdate, int nMLMnow) {
VERIFY(LOCATION,(TString)"Memory leak for regReaders[nMLMnow = "+utils->intToStr(nMLMnow)+"] ?! ",(dynamic_cast<TMVA::Reader*>(regReaders[nMLMnow])));
VERIFY(LOCATION,(TString)"unknown readType (\""+utils->intToStr((int)readType)+"\") ...",(nMLMnow < glob->GetOptI("nMLMs")));

TString MLMname = getTagName(nMLMnow);

var->updateReaderFormulae(readerInptV,forceUpdate);

if (readType == ANNZ_readType::REG) return (regReaders[nMLMnow]->EvaluateRegression(MLMname))[0];
else if(readType == ANNZ_readType::PRB) return max(min(regReaders[nMLMnow]->GetProba(MLMname),1.),0.);
else if(readType == ANNZ_readType::CLS) return regReaders[nMLMnow]->EvaluateMVA(MLMname);
else VERIFY(LOCATION,(TString)"un-supported readType (\""+utils->intToStr((int)readType)+"\") ...",false);

TString MLMname = getTagName(nMLMnow);
bool isBinCls = glob->GetOptB("doBinnedCls");
bool isMC = (anlysTypes[nMLMnow] == TMVA::Types::kMulticlass);
double readVal = 0;

if(isMC || isBinCls) {
double clsVal = isMC ? (regReaders[nMLMnow]->EvaluateMulticlass(MLMname))[0] : regReaders[nMLMnow]->EvaluateMVA(MLMname);

if (readType == ANNZ_readType::PRB) {
VERIFY(LOCATION,(TString)"Memory leak for hisClsPrbV[nMLMnow = "+utils->intToStr(nMLMnow)+"] ?! ",(dynamic_cast<TH1*>(hisClsPrbV[nMLMnow])));
readVal = max(min( hisClsPrbV[nMLMnow]->GetBinContent( hisClsPrbV[nMLMnow]->GetXaxis()->FindBin(clsVal) ) ,1.),0.);
}
else if(readType == ANNZ_readType::CLS) readVal = clsVal;
else VERIFY(LOCATION,(TString)"un-supported readType (\""+utils->intToStr((int)readType)+"\") ...",false);
}
else {
if (readType == ANNZ_readType::REG) readVal = (regReaders[nMLMnow]->EvaluateRegression(MLMname))[0];
else if(readType == ANNZ_readType::PRB) readVal = max(min(regReaders[nMLMnow]->GetProba(MLMname),1.),0.);
else if(readType == ANNZ_readType::CLS) readVal = regReaders[nMLMnow]->EvaluateMVA(MLMname);
else VERIFY(LOCATION,(TString)"un-supported readType (\""+utils->intToStr((int)readType)+"\") ...",false);
}

return readVal;
}

// ===========================================================================================================
@@ -320,7 +361,7 @@ bool ANNZ::verifyXML(TString outXmlFileName) {

DELNULL(testFile);
}

return isGoodXML;
}
