From 37f5c4f5d0cbdd3084c13581ce7d785e1b40dc8d Mon Sep 17 00:00:00 2001 From: IftachSadeh Date: Tue, 15 Mar 2016 14:05:38 +0100 Subject: [PATCH] ANNZ 2.1.2 - Improved selection criteria for `ANNZ_best` in randomized regression. The optimization is now based on `glob.annz["optimCondReg"]="sig68"` or `"bias"` (The `"fracSig68"` option is deprecated.) - **Significant speed improvement** for KNN weights and `inTrainFlag` calculations in `CatFormat::addWgtKNNtoTree()`. - Modified `CatFormat::addWgtKNNtoTree()` and `CatFormat::inputToSplitTree_wgtKNN()` so that both training and testing objects are used together as the reference dataset, when deriving KNN weights. This new option is on by default, and may be turned off by setting: ```python glob.annz["trainTestTogether_wgtKNN"] = False ``` - For developers: internal interface change (not backward compatible) - What used to be `CatFormat::addWgtKNNtoTree(TChain * aChainInp, TChain * aChainRef, TString outTreeName)` has been changed to `CatFormat::addWgtKNNtoTree(TChain * aChainInp, TChain * aChainRef, TChain * aChainEvl, TString outTreeName)`. - Cancelled the `splitTypeValid` option, which was not very useful and confusing for users. From now on, input datasets may only be divided into two subsets, one for training and one for testing. The user may define the training/testing samples in one of two ways (see `scripts/annz_rndReg_advanced.py` for details): 1. Automatic splitting: ```python glob.annz["splitType"] = "random" glob.annz["inAsciiFiles"] = "boss_dr10_0.csv;boss_dr10_1.csv" ``` Set a list of input files in `inAsciiFiles`, and use `splitType` to specify the method for splitting the sample. Allowed values for the latter are `serial`, `blocks` or `random`. 2. Splitting by file: ```python glob.annz["splitType"] = "byInFiles" glob.annz["splitTypeTrain"] = "boss_dr10_0.csv" glob.annz["splitTypeTest"] = "boss_dr10_1.csv;boss_dr10_2.csv" ``` Set a list of input files for training in `splitTypeTrain`, and a list of input files for testing in `splitTypeTest`. - Added plotting for the evaluation mode of regression (single regression, randomized regression and binned classification). If the regression target is detected as part of the evaluated dataset, the nominal performance plots are created. For instance, for the `scripts/annz_rndReg_quick.py` script, the plots will be created in `output/test_randReg_quick/regres/eval/plots/`. - Fixed bug in plotting routine from `ANNZ::doMetricPlots()`, when adding user-defined cuts for variables not already present in the input trees. - Simplified the interface for string variables in cut and weight expressions. - For example, given a set of input parameters, ```python glob.annz["inAsciiVars"] = "D:MAG_AUTO_G;D:MAG_AUTO_R;D:MAG_AUTO_I;D:Z_SPEC;C:FIELD" ``` one can now use cuts and weights of the form: ```python glob.annz["userCuts_train"] = " (FIELD == \"FIELD_0\") || (FIELD == \"FIELD_1\")" glob.annz["userCuts_valid"] = " (FIELD == \"FIELD_1\") || (FIELD == \"FIELD_2\")" glob.annz["userWeights_train"] = "1.0*(FIELD == \"FIELD_0\") + 2.0*(FIELD == \"FIELD_1\")" glob.annz["userWeights_valid"] = "1.0*(FIELD == \"FIELD_1\") + 0.1*(FIELD == \"FIELD_2\")" ``` Here, training is only done using `FIELD_0` and `FIELD_1`; validation is weighted such that galaxies from `FIELD_1` have ten times the weight compared to galaxies from `FIELD_2` etc. - The same rules also apply for the weight and cut options for the KNN re-weighting method: `cutInp_wgtKNN`, `cutRef_wgtKNN`, `weightRef_wgtKNN` and `weightInp_wgtKNN`, and for the corresponding variables for the evaluation compatibility test: `cutInp_inTrain`, `cutRef_inTrain`, `weightRef_inTrain` and `weightInp_inTrain`. (Examples for the re-weighting and for the compatibility test using these variables are given in `scripts/annz_rndReg_advanced.py`.) - `ANNZ_PDF_max_0` no longer calculated by default. This may be turned back on by setting ```python glob.annz["addMaxPDF"] = True ``` - Other minor modifications and bug fixes. --- CHANGELOG.md | 55 ++- README.md | 61 +++- examples/scripts/annz_binCls_advanced.py | 26 +- examples/scripts/annz_binCls_quick.py | 3 +- examples/scripts/annz_rndCls_advanced.py | 23 +- examples/scripts/annz_rndReg_advanced.py | 50 ++- examples/scripts/annz_rndReg_quick.py | 7 +- examples/scripts/annz_rndReg_weights.py | 34 +- examples/scripts/annz_singleReg_quick.py | 5 +- examples/scripts/generalSettings.py | 7 +- examples/scripts/helperFuncs.py | 64 ++-- include/ANNZ.hpp | 5 +- include/ANNZLinkDef.hpp | 10 - include/BaseClassLinkDef.hpp | 10 - include/CatFormat.hpp | 2 +- include/CatFormatLinkDef.hpp | 10 - include/OptMaps.hpp | 2 + include/OptMapsLinkDef.hpp | 10 - include/OutMngrLinkDef.hpp | 10 - include/Utils.hpp | 19 +- include/UtilsLinkDef.hpp | 10 - include/VarMaps.hpp | 20 +- include/VarMapsLinkDef.hpp | 10 - src/ANNZ_err.cpp | 11 +- src/ANNZ_loopCls.cpp | 18 +- src/ANNZ_loopReg.cpp | 359 ++++++++++--------- src/ANNZ_loopRegCls.cpp | 126 +++++-- src/ANNZ_train.cpp | 140 ++++---- src/ANNZ_utils.cpp | 173 +++++++--- src/CatFormat_asciiToTree.cpp | 170 ++++----- src/CatFormat_wgtKNN.cpp | 421 +++++++++++++++-------- src/OutMngr_draw.cpp | 8 + src/Utils.cpp | 28 +- src/VarMaps.cpp | 51 ++- src/myANNZ.cpp | 97 ++++-- 35 files changed, 1228 insertions(+), 827 deletions(-) delete mode 100644 include/ANNZLinkDef.hpp delete mode 100644 include/BaseClassLinkDef.hpp delete mode 100644 include/CatFormatLinkDef.hpp delete mode 100644 include/OptMapsLinkDef.hpp delete mode 100644 include/OutMngrLinkDef.hpp delete mode 100644 include/UtilsLinkDef.hpp delete mode 100644 include/VarMapsLinkDef.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 6426833..5bfeef5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,62 @@ # Changelog -## Master version + +## ANNZ 2.1.2 (15/3/2016) + +- Improved selection criteria for `ANNZ_best` in randomized regression. The optimization is now based on `glob.annz["optimCondReg"]="sig68"` or `"bias"` (The `"fracSig68"` option is deprecated.) + +- **Significant speed improvement** for KNN weights and `inTrainFlag` calculations in `CatFormat::addWgtKNNtoTree()`. + +- Modified `CatFormat::addWgtKNNtoTree()` and `CatFormat::inputToSplitTree_wgtKNN()` so that both training and testing objects are used together as the reference dataset, when deriving KNN weights. This new option is on by default, and may be turned off by setting: + ```python + glob.annz["trainTestTogether_wgtKNN"] = False + ``` + - For developers: internal interface change (not backward compatible) - What used to be `CatFormat::addWgtKNNtoTree(TChain * aChainInp, TChain * aChainRef, TString outTreeName)` has been changed to `CatFormat::addWgtKNNtoTree(TChain * aChainInp, TChain * aChainRef, TChain * aChainEvl, TString outTreeName)`. + +- Cancelled the `splitTypeValid` option, which was not very useful and confusing for users. From now on, input datasets may only be divided into two subsets, one for training and one for testing. The user may define the training/testing samples in one of two ways (see `scripts/annz_rndReg_advanced.py` for details): + + 1. Automatic splitting: + ```python + glob.annz["splitType"] = "random" + glob.annz["inAsciiFiles"] = "boss_dr10_0.csv;boss_dr10_1.csv" + ``` + Set a list of input files in `inAsciiFiles`, and use `splitType` to specify the method for splitting the sample. Allowed values for the latter are `serial`, `blocks` or `random`. + + 2. Splitting by file: + ```python + glob.annz["splitType"] = "byInFiles" + glob.annz["splitTypeTrain"] = "boss_dr10_0.csv" + glob.annz["splitTypeTest"] = "boss_dr10_1.csv;boss_dr10_2.csv" + ``` + Set a list of input files for training in `splitTypeTrain`, and a list of input files for testing in `splitTypeTest`. + +- Added plotting for the evaluation mode of regression (single regression, randomized regression and binned classification). If the regression target is detected as part of the evaluated dataset, the nominal performance plots are created. For instance, for the `scripts/annz_rndReg_quick.py` script, the plots will be created in `output/test_randReg_quick/regres/eval/plots/`. - Fixed bug in plotting routine from `ANNZ::doMetricPlots()`, when adding user-defined cuts for variables not already present in the input trees. +- Simplified the interface for string variables in cut and weight expressions. + - For example, given a set of input parameters, + ```python + glob.annz["inAsciiVars"] = "D:MAG_AUTO_G;D:MAG_AUTO_R;D:MAG_AUTO_I;D:Z_SPEC;C:FIELD" + ``` + one can now use cuts and weights of the form: + ```python + glob.annz["userCuts_train"] = " (FIELD == \"FIELD_0\") || (FIELD == \"FIELD_1\")" + glob.annz["userCuts_valid"] = " (FIELD == \"FIELD_1\") || (FIELD == \"FIELD_2\")" + glob.annz["userWeights_train"] = "1.0*(FIELD == \"FIELD_0\") + 2.0*(FIELD == \"FIELD_1\")" + glob.annz["userWeights_valid"] = "1.0*(FIELD == \"FIELD_1\") + 0.1*(FIELD == \"FIELD_2\")" + ``` + Here, training is only done using `FIELD_0` and `FIELD_1`; validation is weighted such that galaxies from `FIELD_1` have ten times the weight compared to galaxies from `FIELD_2` etc. + + - The same rules also apply for the weight and cut options for the KNN re-weighting method: `cutInp_wgtKNN`, `cutRef_wgtKNN`, `weightRef_wgtKNN` and `weightInp_wgtKNN`, and for the corresponding variables for the evaluation compatibility test: `cutInp_inTrain`, `cutRef_inTrain`, `weightRef_inTrain` and `weightInp_inTrain`. (Examples for the re-weighting and for the compatibility test using these variables are given in `scripts/annz_rndReg_advanced.py`.) + +- `ANNZ_PDF_max_0` no longer calculated by default. This may be turned back on by setting +```python +glob.annz["addMaxPDF"] = True +``` + +- Other minor modifications and bug fixes. + ## ANNZ 2.1.1 (15/1/2016) - Fixed bug in generating a name for an internal `TF1` function in `ANNZ::setupKdTreeKNN()`. diff --git a/README.md b/README.md index 215d5df..b47a05f 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# ANNZ 2.1.1 +# ANNZ 2.1.2 ## Introduction ANNZ uses both regression and classification techniques for estimation of single-value photo-z (or any regression problem) solutions and PDFs. In addition it is suitable for classification problems, such as star/galaxy classification. @@ -145,7 +145,7 @@ python scripts/annz_singleReg_quick.py --singleRegression --evaluate python scripts/annz_singleReg_quick.py --singleRegression --optimize ``` - 4. **evaluate**: Evaluate an input dataset using the trained MLM. + 4. **evaluate**: Evaluate an input dataset using the trained MLM. If the regression target is detected as part of the evaluated dataset, performance plots are created as well. ```bash python scripts/annz_singleReg_quick.py --singleRegression --evaluate ``` @@ -167,7 +167,7 @@ python scripts/annz_singleReg_quick.py --singleRegression --evaluate python scripts/annz_rndReg_quick.py --randomRegression --optimize ``` - 4. **evaluate**: Evaluate an input dataset using the derived estimators: the *best* MLM, the PDF(s) and the average of the weighted and the un-weighted PDF solutions. + 4. **evaluate**: Evaluate an input dataset using the derived estimators: the *best* MLM, the PDF(s) and the average of the weighted PDF solutions. If the regression target is detected as part of the evaluated dataset, performance plots are created as well. ```bash python scripts/annz_rndReg_quick.py --randomRegression --evaluate ``` @@ -190,7 +190,7 @@ python scripts/annz_singleReg_quick.py --singleRegression --evaluate python scripts/annz_binCls_quick.py --binnedClassification --verify ``` - 4. **evaluate**: Evaluate an input dataset using the derived estimators: the PDF and the average of the PDF solution. + 4. **evaluate**: Evaluate an input dataset using the derived estimators: the PDF and the average of the PDF solution. If the regression target is detected as part of the evaluated dataset, performance plots are created as well. ```bash python scripts/annz_binCls_quick.py --binnedClassification --evaluate ``` @@ -296,6 +296,30 @@ Here are a couple of examples: See the advanced scripts for additional details. +#### Definition of input samples + +- Machine learning methods require two input samples for the training process. The first of these is the *training sample*, which is used explicitly for the training process - we'll refer to this sample as `S-1`. The second is sometimes called the *validation sample* and sometimes the *testing sample*. It is used for evaluating the result of the trained MLM in each step of the training - we'll refer to this sample is `S-2`. The `S-2` sample should have the same properties as `S-1`, but should be an independent collection of objects. + +- The terminology between *validation* and *testing* is not always consistent. Some authors refer to a third independent sample of objects which is used to check the performance of the trained MLM after training is complete; this third sample may be referred to as a *validation* or as a *testing* sample, thus creating potential confusion with regards to `S-2`. For the purposes of using this software package, the distinction between *validation* and *testing* is not relevant. We define only two samples, `S-1` for training, and `S-2` for checking the performance during each step of the training process. The user may use the evaluation stage (after optimization/verification is complete) in order to derive the solution for any other sample of objects. + +- The user may define the training/testing samples in one of two ways (see `scripts/annz_rndReg_advanced.py` for details): + + 1. Automatic splitting: + ```python + glob.annz["splitType"] = "random" + glob.annz["inAsciiFiles"] = "boss_dr10_0.csv;boss_dr10_1.csv" + ``` + Set a list of input files in `inAsciiFiles`, and use `splitType` to specify the method for splitting the sample. Allowed values for the latter are `serial`, `blocks` or `random`. + + 2. Splitting by file: + ```python + glob.annz["splitType"] = "byInFiles" + glob.annz["splitTypeTrain"] = "boss_dr10_0.csv" + glob.annz["splitTypeTest"] = "boss_dr10_1.csv;boss_dr10_2.csv" + ``` + Set a list of input files for training in `splitTypeTrain`, and a list of input files for testing in `splitTypeTest`. + + #### Definition of signal and background objects in single/randomized classification Signal and background objects may be defined by using at least one pair of variables, either `userCuts_sig` and `userCuts_bck` or `inpFiles_sig` and `inpFiles_bck`. Alternatively, it is also possible to use three of the latter or all four in tandem. For example: @@ -400,9 +424,12 @@ glob.annz["addInTrainFlag"] = True - This output indicates if the an evaluated object is "compatible" with corresponding objects from the training dataset. The compatibility is estimated by comparing the density of objects in the training dataset in the vicinity of the evaluated object. If the evaluated object belongs to an area of parameter-space which is not represented in the training dataset, we will get `inTrainFlag = 0`. In this case, the output of the training is probably unreliable. - - The calculation is performed using a KNN approach, similar to the algorithm used for the `useWgtKNN` calculation. It is possible to generate either binary flags (i.e., `inTrainFlag = 0` or `1`) or to get a floating-point value between zero and one. The binary decision is based on the `maxRelRatioInRef_inTrain` parameter; the latter represents a threshold for the relative density of objects from the training sample in the area of the evaluated object. (If `maxRelRatioInRef_inTrain > 0`, then `inTrainFlag` is binary.) + - The calculation is performed using a KNN approach, similar to the algorithm used for the `useWgtKNN` calculation. It is possible to generate either binary flags (i.e., `inTrainFlag = 0` or `1`) or to get a floating-point value with in the range, `[0,1]`. The binary decision is based on the `maxRelRatioInRef_inTrain` parameter; the latter represents a threshold for the relative density of objects from the training sample in the area of the evaluated object. By default, `maxRelRatioInRef_inTrain = -1`. In this case, the user is expected to study the distribution of the `inTrainFlag` parameter, and decide on an appropriate cut value as a quality criteria. If the user sets ` 0 < maxRelRatioInRef_inTrain < 1`, this is equivalent to choosing the cut parameter in advance. In this case the quality flag is binary (i.e., all objects for which `inTrainFlag < maxRelRatioInRef_inTrain` will get a flag rounded down to `0`, the rest will get `1` ). + + - It is recommended to first generate a floating-point estimate of `inTrainFlag` and to study the distribution. One could e.g., decide to discard objects with a low value of `inTrainFlag`, based on how the bias or scatter increase as `inTrainFlag` decreases. Them, once a proper cut value for `maxRelRatioInRef_inTrain` is determined, `inTrainFlag` may be set to produce a binary decision. - - It is recommended to first generate a floating-point estimate of `inTrainFlag` and to study the distribution. For production, once a proper cut value for `maxRelRatioInRef_inTrain` is determined, `inTrainFlag` should be set to produce a binary decision. +If the regression target is detected as part of the evaluated dataset, performance plots are created as well, found e.g., at `output/test_randReg_quick/regres/eval/plots/`. This is useful if, for instance, one wants to easily check the performance on a dataset which was not used for training/testing. +The plots will also include the dependence of the performance on any added variable in the evaluated dataset. For instance, if one sets `glob.annz["addOutputVars"] = "MAG_Z"`, then the dependence of the bias, scatter etc. on `MAG_Z` wil be plotted. This is particularly useful for assessing the dependence of the performance on the `inTrainFlag` parameter, in order to decide on a cut value for the latter. ### Single regression @@ -452,8 +479,6 @@ A few notes: - The cut variables, `userCuts_train` and `userCuts_valid`, are binary, in the sense that they define the conditions for an object to be accepted for training or validation. On the other hand, the weight variables, `userWeights_train` and `userWeights_valid`, can serve as either cuts or weights; a zero-weight is equivalent to a cut, as it effectively excludes an objects. Therefore it is possible to compose weight expressions which have "boolean" components (such as `(MAGERR_I > 1)`) which are numerically equivalent to `0` or `1`. In principle, we can therefore do everything with `userWeights_train` and `userWeights_valid`. However, due to performance considerations, it is recommended to use `userCuts_train` and `userCuts_valid` for well defined rejection criteria; then use the weight variables for everything else. - - If as part of the ascii or ROOT output an object has a weight which is zero, then the corresponding estimator for that object should not be used! - - The plots provided following optimization or verification all take into account the respective weights of the different estimators. ## General comments @@ -469,7 +494,6 @@ A few notes: ``` where non-detection of magnitudes (usually indicated by setting a magnitude to `100`) are mapped to the magnitude limits in the different bands. This avoids training/evaluating with nonsensical numerical values; it also does not require any special pre-processing of the input dataset, as the conditions are set on the fly during the training and evaluation stages. - - The training phase may be run multiple times, in order to make sure that all MLMs have completed training successfully. By default, if a trained MLM is detected in the output directory, then ANNZ does not overwrite it. In order to force re-training, one may delete the training directory for a particular MLM (for instance, using `scripts/annz_rndReg_quick.py`, this might be `output/test_randReg_quick/regres/train/ANNZ_3`). Alternatively, it's possible to force retraining by setting in the relevant python script the flag, ```python glob.annz["overwriteExistingTrain"] = True @@ -485,7 +509,7 @@ A few notes: - It is possible to use ANNZ to generate object weights, based on a reference dataset. The weights are generated as part of the `--genInputTrees` phase using the `useWgtKNN` option, and are then used for training and optimization; they are also calculated during evaluation, and added as part of the per-object weight which is included in the output of the evaluation. This feature is useful, if e.g., the target dataset for evaluation has a different distribution of input parameters, compared to the training dataset. For instance, for photo-z derivation, it is possible for the spectroscopic training sample to have a different color distribution, compared to the target photometric sample. The derived weights in this case are calculated as the ratio between the number of objects in a given color-box in the reference sample, compared to the training sample. The procedure is implemented in `CatFormat::addWgtKNNtoTree()` (in `src/CatFormat_wgtKNN.cpp`), where a more detailed explanation is also given. See `scripts/annz_rndReg_advanced.py` for a use-example. - - Using the script, `scripts/annz_rndReg_weights.py`, it is possible to generate the weights based on the KNN method (`useWgtKNN`), and/or the `inTrainFlag` quality-flag, without training/evaluating any MLMs. The former are stored to e.g., `output/test_randReg_weights/rootIn/ANNZ_KNN_wANNZ_tree_valid_0000.csv`, and the latter to `output/test_randReg_weights/inTrainFlag/inTrainFlagANNZ_tree_wgtTree_0000.csv`. + - Using the script, `scripts/annz_rndReg_weights.py`, it is possible to generate the weights based on the KNN method (`useWgtKNN`), and/or the `inTrainFlag` quality-flag, without training/evaluating any MLMs. That is, no machine-learning or photo-z training is needed. Instead, this script may be used to simple generate training weights, or conversely derive the `inTrainFlag` quality-flag, for a given evaluated sample, with respect to a specific reference sample. The former are stored to e.g., `output/test_randReg_weights/rootIn/ANNZ_KNN_wANNZ_tree_valid_0000.csv`, and the latter to `output/test_randReg_weights/inTrainFlag/inTrainFlagANNZ_tree_wgtTree_0000.csv`. - The KNN error, weight and quality-flag calculations are nominally performed for rescaled variable distributions; each input variable is mapped by a linear transformation to the range `[-1,1]`, so that the distance in the input parameter space is not biased by the scale (units) of the different parameters. It is possible to prevent the rescalling by setting the following flags to `False`: `doWidthRescale_errKNN`, `doWidthRescale_wgtKNN` and `doWidthRescale_inTrain`. These respectively relate to the KNN error calculation, the reference dataset reweighting, and the training quality-flag. @@ -493,12 +517,25 @@ A few notes: - It is possible to train/optimize MLMs using specific cuts and/or weights, based on any mathematical expression which uses the variables defined in the input dataset (not limited to the variables used for the training). The relevant variables are `userCuts_train`, `userCuts_valid`, `userWeights_train` and `userWeights_valid`. See the advanced scripts for use-examples. - - The syntax for math expressions is defined using the ROOT conventions (see e.g., [TMath](https://root.cern.ch/root/html/TMath.html) and [TFormula](https://root.cern.ch/root/html/TFormula.html)). Acceptable expressions may for instance include the following ridiculous choice: + - The syntax for math expressions is defined using the ROOT conventions (see e.g., [TMath](https://root.cern.ch/root/html524/TMath.html) and [TFormula](https://root.cern.ch/root/html/TFormula.html)). Acceptable expressions may for instance include the following ridiculous choice: ```python glob.annz["userCuts_train"] = "(MAG_R > 22)/MAG_R + (MAG_R <= 22)*1" glob.annz["userCuts_valid"] = "pow(MAG_G,3) + exp(MAG_R)*MAG_I/20. + abs(sin(MAG_Z))" ``` + - Note that training variables (defined in `inputVariables`) can only include expressions containing integer or floating-point variables. However, for the cut and weight variables, it is possible to also use string variables. + + - For instance, let's assume that `inAsciiVars` included the variable `FIELD`, which gives the name of the field for each galaxy in the training and validation datasets. Then, one may e.g., set cuts and weights of the form: + ```python + glob.annz["userCuts_train"] = " (FIELD == \"FIELD_0\") || (FIELD == \"FIELD_1\")" + glob.annz["userCuts_valid"] = " (FIELD == \"FIELD_1\") || (FIELD == \"FIELD_2\")" + glob.annz["userWeights_train"] = "1.0*(FIELD == \"FIELD_0\") + 2.0*(FIELD == \"FIELD_1\")" + glob.annz["userWeights_valid"] = "1.0*(FIELD == \"FIELD_1\") + 0.1*(FIELD == \"FIELD_2\")" + ``` + Here, training is only done using `FIELD_0` and `FIELD_1`; validation is weighted such that galaxies from `FIELD_1` have 10 times the weight compared to galaxies from `FIELD_2` etc. + + - The same rules also apply for the weight and cut options for the KNN re-weighting method: `cutInp_wgtKNN`, `cutRef_wgtKNN`, `weightRef_wgtKNN` and `weightInp_wgtKNN`, and for the corresponding variables for the evaluation compatibility test (the `inTrainFlag` parameter): `cutInp_inTrain`, `cutRef_inTrain`, `weightRef_inTrain` and `weightInp_inTrain`. (Examples for the re-weighting and for the compatibility test using these variables are given in `scripts/annz_rndReg_advanced.py`.) + - By default, the output of evaluation is written to a subdirectory named `eval` in the output directory. An output file may e.g., be `./output/test_randReg_quick/regres/eval/ANNZ_randomReg_0000.csv`. It is possible to set the the `evalDirPostfix` variable in order to change this. For instance, setting ```python glob.annz["evalDirPostfix"] = "cat0" @@ -532,7 +569,7 @@ A few notes: glob.annz["isBatch"] = True ``` - - It is possible to use root input files instead of ascii inputs. In this case, use the `splitTypeTrain`, `splitTypeTest`, `splitTypeValid`, `inAsciiFiles` and `inAsciiFiles_wgtKNN` variables in the same way as for ascii inputs; in addition, specify the name of the tree inside the root files. The latter is done using the variable `inTreeName` (for the nominal set) or `inTreeName_wgtKNN` (for the `inAsciiFiles_wgtKNN` variable). An example is given in `scripts/annz_rndReg_advanced.py`. + - It is possible to use root input files instead of ascii inputs. In this case, use the `splitTypeTrain`, `splitTypeTest`, `inAsciiFiles` and `inAsciiFiles_wgtKNN` variables in the same way as for ascii inputs; in addition, specify the name of the tree inside the root files. The latter is done using the variable `inTreeName` (for the nominal set) or `inTreeName_wgtKNN` (for the `inAsciiFiles_wgtKNN` variable). An example is given in `scripts/annz_rndReg_advanced.py`. - The output of ANNZ includes escape sequences for color. To avoid these, set ```python diff --git a/examples/scripts/annz_binCls_advanced.py b/examples/scripts/annz_binCls_advanced.py index 2224a47..72dbd36 100644 --- a/examples/scripts/annz_binCls_advanced.py +++ b/examples/scripts/annz_binCls_advanced.py @@ -58,34 +58,24 @@ glob.annz["inAsciiVars"] = "F:MAG_U;F:MAGERR_U;F:MAG_G;F:MAGERR_G;F:MAG_R;F:MAGERR_R;F:MAG_I;F:MAGERR_I;F:MAG_Z;F:MAGERR_Z;D:Z" # -------------------------------------------------------------------------------------------------- - # - For training and testing/validation the input is divided into two (test,train) or into three (test,train,valid) - # sub-samples. - # - The user needs to define the number of sub-samples (e.g., nSplit = 1,2 or 3) and the way to divide the + # - For training and testing/validation the input is divided into two (test,train) sub-samples. + # - The user needs to define the way to divide the samples # inputs in one of 4 ways (e.g., splitType = "serial", "blocks", "random" or "byInFiles" (default)): # - serial: -> test;train;valid;test;train;valid;test;train;valid;test;train;valid... # - blocks: -> test;test;test;test;train;train;train;train;valid;valid;valid;valid... # - random: -> valid;test;test;train;valid;test;valid;valid;test;train;valid;train... # - separate input files. Must supplay at least one file in splitTypeTrain and one in splitTypeTest. - # In this case, [nSplit = 2]. Optionally can set [nSplit = 3] and provide a list of files in "splitTypeValid" as well. # - example use: # set inFileOpt and choose one of the following options for input file configuration: # -------------------------------------------------------------------------------------------------- - inFileOpt = 2 - # splitTypeTrain - list of files for training. splitTypeTest - list of files for testing and validation + inFileOpt = 0 + # splitTypeTrain - list of files for training. splitTypeTest - list of files for testing if inFileOpt == 0: - glob.annz["nSplit"] = 2 glob.annz["splitTypeTrain"] = "boss_dr10_0_large.csv" glob.annz["splitTypeTest"] = "boss_dr10_1_large.csv" - # splitTypeTrain - list of files for training. splitTypeTest - list of files for testing. splitTypeValid - list of files for validation - elif inFileOpt == 1: - glob.annz["nSplit"] = 3 - glob.annz["splitTypeTrain"] = "boss_dr10_0_large.csv" - glob.annz["splitTypeTest"] = "boss_dr10_1_large.csv" - glob.annz["splitTypeValid"] = "boss_dr10_2_large.csv" - # inAsciiFiles - one list of input files for training, testing and validation, where the the objects are assigned to a given + # inAsciiFiles - one list of input files for training and testing, where the the objects are assigned to a given # category based on the selection criteria defined by splitType - elif inFileOpt == 2: - glob.annz["nSplit"] = 3 + elif inFileOpt == 1: glob.annz["splitType"] = "serial" # "serial", "blocks" or "random" glob.annz["inAsciiFiles"] = "boss_dr10_0_large.csv;boss_dr10_1_large.csv;boss_dr10_2_large.csv" else: @@ -135,8 +125,8 @@ glob.annz["weightVarNames_wgtKNN"] = "MAG_U;MAG_G;MAG_R;MAG_I;MAG_Z" # optional parameters (may leave empty as default value): - glob.annz["sampleFracInp_wgtKNN"] = 0.1 # fraction of dataset to use (positive number, smaller or equal to 1) - glob.annz["sampleFracRef_wgtKNN"] = 0.2 # fraction of dataset to use (positive number, smaller or equal to 1) + glob.annz["sampleFracInp_wgtKNN"] = 0.9 # fraction of dataset to use (positive number, smaller or equal to 1) + glob.annz["sampleFracRef_wgtKNN"] = 0.8 # fraction of dataset to use (positive number, smaller or equal to 1) glob.annz["outAsciiVars_wgtKNN"] = "MAG_U;MAG_G;MAGERR_U" # write out two additional variables to the output file glob.annz["weightRef_wgtKNN"] = "(MAGERR_R<0.7)*1 + (MAGERR_R>=0.7)/MAGERR_R" # down-weight objects with high MAGERR_R glob.annz["cutRef_wgtKNN"] = "MAGERR_U<200" # only use objects which have small MAGERR_U diff --git a/examples/scripts/annz_binCls_quick.py b/examples/scripts/annz_binCls_quick.py index 0c2bc93..f0d326a 100644 --- a/examples/scripts/annz_binCls_quick.py +++ b/examples/scripts/annz_binCls_quick.py @@ -43,9 +43,8 @@ # file, e.g., [TYPE:NAME] may be [F:MAG_U], with 'F' standing for float. (see advanced example for detailed explanation) glob.annz["inAsciiVars"] = "F:MAG_U;F:MAGERR_U;F:MAG_G;F:MAGERR_G;F:MAG_R;F:MAGERR_R;F:MAG_I;F:MAGERR_I;F:MAG_Z;F:MAGERR_Z;D:Z" - # splitTypeTrain - list of files for training, testing and validation. the entire dataset is split into [nSplit=3] parts (one for + # splitTypeTrain - list of files for training and testing. the entire dataset is split into two parts (one for # each subsample), where splitting is determined by the [splitType="serial"] criteria. (see advanced example for more options/detials.) - glob.annz["nSplit"] = 3 glob.annz["splitType"] = "serial" # "serial", "blocks" or "random" glob.annz["inAsciiFiles"] = "boss_dr10_0_large.csv;boss_dr10_1_large.csv;boss_dr10_2_large.csv" # run ANNZ with the current settings diff --git a/examples/scripts/annz_rndCls_advanced.py b/examples/scripts/annz_rndCls_advanced.py index f529b8c..ef450b3 100644 --- a/examples/scripts/annz_rndCls_advanced.py +++ b/examples/scripts/annz_rndCls_advanced.py @@ -85,35 +85,24 @@ + " F:petroR90_r; F:lnLStar_r; F:lnLExp_r; F:lnLDeV_r; F:mE1_r; F:mE2_r; F:mRrCc_r; I:type_r; I:type" # -------------------------------------------------------------------------------------------------- - # - For training and testing/validation the input is divided into two (test,train) or into three (test,train,valid) - # sub-samples. - # - The user needs to define the number of sub-samples (e.g., nSplit = 1,2 or 3) and the way to divide the + # - For training and testing/validation the input is divided into two (test,train) sub-samples. + # - The user needs to define the way to divide the samples # inputs in one of 4 ways (e.g., splitType = "serial", "blocks", "random" or "byInFiles" (default)): # - serial: -> test;train;valid;test;train;valid;test;train;valid;test;train;valid... # - blocks: -> test;test;test;test;train;train;train;train;valid;valid;valid;valid... # - random: -> valid;test;test;train;valid;test;valid;valid;test;train;valid;train... # - separate input files. Must supplay at least one file in splitTypeTrain and one in splitTypeTest. - # In this case, [nSplit = 2]. Optionally can set [nSplit = 3] and provide a list of files in "splitTypeValid" as well. # - example use: # set inFileOpt and choose one of the following options for input file configuration: # -------------------------------------------------------------------------------------------------- - inFileOpt = 1 - - # splitTypeTrain - list of files for training. splitTypeTest - list of files for testing and validation + inFileOpt = 0 + # splitTypeTrain - list of files for training. splitTypeTest - list of files for testing if inFileOpt == 0: - glob.annz["nSplit"] = 2 glob.annz["splitTypeTrain"] = "sgCatalogue_galaxy_0.txt;sgCatalogue_star_0.txt" glob.annz["splitTypeTest"] = "sgCatalogue_galaxy_1.txt;sgCatalogue_star_1.txt" - # splitTypeTrain - list of files for training. splitTypeTest - list of files for testing. splitTypeValid - list of files for validation - elif inFileOpt == 1: - glob.annz["nSplit"] = 3 - glob.annz["splitTypeTrain"] = "sgCatalogue_galaxy_0.txt;sgCatalogue_star_0.txt" - glob.annz["splitTypeTest"] = "sgCatalogue_galaxy_1.txt;sgCatalogue_star_1.txt" - glob.annz["splitTypeValid"] = "sgCatalogue_galaxy_2.txt;sgCatalogue_star_2.txt;sgCatalogue_star_3.txt" - # inAsciiFiles - one list of input files for training, testing and validation, where the the objects are assigned to a given + # inAsciiFiles - one list of input files for training and testing, where the the objects are assigned to a given # category based on the selection criteria defined by splitType - elif inFileOpt == 2: - glob.annz["nSplit"] = 3 + elif inFileOpt == 1: glob.annz["splitType"] = "random" # "serial", "blocks" or "random" glob.annz["inAsciiFiles"] = "sgCatalogue_galaxy_0.txt;sgCatalogue_galaxy_1.txt;sgCatalogue_star_0.txt;sgCatalogue_star_1.txt" else: diff --git a/examples/scripts/annz_rndReg_advanced.py b/examples/scripts/annz_rndReg_advanced.py index 530f12d..6455794 100644 --- a/examples/scripts/annz_rndReg_advanced.py +++ b/examples/scripts/annz_rndReg_advanced.py @@ -65,42 +65,28 @@ glob.annz["inAsciiVars"] = "F:MAG_U;F:MAGERR_U;F:MAG_G;F:MAGERR_G;F:MAG_R;F:MAGERR_R;F:MAG_I;F:MAGERR_I;F:MAG_Z;F:MAGERR_Z;D:Z" # -------------------------------------------------------------------------------------------------- - # - For training and testing/validation the input is divided into two (test,train) or into three (test,train,valid) - # sub-samples. - # - The user needs to define the number of sub-samples (e.g., nSplit = 1,2 or 3) and the way to divide the + # - For training and testing/validation the input is divided into two (test,train) sub-samples. + # - The user needs to define the way to divide the samples # inputs in one of 4 ways (e.g., splitType = "serial", "blocks", "random" or "byInFiles" (default)): # - serial: -> test;train;valid;test;train;valid;test;train;valid;test;train;valid... # - blocks: -> test;test;test;test;train;train;train;train;valid;valid;valid;valid... # - random: -> valid;test;test;train;valid;test;valid;valid;test;train;valid;train... # - separate input files. Must supplay at least one file in splitTypeTrain and one in splitTypeTest. - # In this case, [nSplit = 2]. Optionally can set [nSplit = 3] and provide a list of files in "splitTypeValid" as well. - # - It is possible to use root input files instead of ascii inputs. In this case, use the "splitTypeTrain", "splitTypeTest", - # "splitTypeValid" and "inAsciiFiles" variables in the same way as for ascii inputs, but in addition, specify the name of the - # tree inside the root files, as the variable "inTreeName". Make sure not to mix ascii and root input files! # - example use: # set inFileOpt and choose one of the following options for input file configuration: # -------------------------------------------------------------------------------------------------- inFileOpt = 1 - # splitTypeTrain - list of files for training. splitTypeTest - list of files for testing and validation + # splitTypeTrain - list of files for training. splitTypeTest - list of files for testing if inFileOpt == 0: - glob.annz["nSplit"] = 2 glob.annz["splitTypeTrain"] = "boss_dr10_0.csv" glob.annz["splitTypeTest"] = "boss_dr10_1.csv;boss_dr10_2.csv" - # splitTypeTrain - list of files for training. splitTypeTest - list of files for testing. splitTypeValid - list of files for validation - elif inFileOpt == 1: - glob.annz["nSplit"] = 3 - glob.annz["splitTypeTrain"] = "boss_dr10_0.csv" - glob.annz["splitTypeTest"] = "boss_dr10_1.csv;boss_dr10_2.csv" - glob.annz["splitTypeValid"] = "boss_dr10_3.csv" - # inAsciiFiles - one list of input files for training, testing and validation, where the the objects are assigned to a given + # inAsciiFiles - one list of input files for training and testing, where the the objects are assigned to a given # category based on the selection criteria defined by splitType - elif inFileOpt == 2: - glob.annz["nSplit"] = 3 + elif inFileOpt == 1: glob.annz["splitType"] = "serial" # "serial", "blocks" or "random" glob.annz["inAsciiFiles"] = "boss_dr10_0.csv;boss_dr10_1.csv;boss_dr10_2.csv;boss_dr10_3.csv" # example ofr using a root tree input file, instead of an ascii input - elif inFileOpt == 3: - glob.annz["nSplit"] = 2 + elif inFileOpt == 2: glob.annz["splitType"] = "serial" # "serial", "blocks" or "random" glob.annz["inTreeName"] = "ANNZ_tree_full" glob.annz["inAsciiFiles"] = "ANNZ_tree_full_00000.root" @@ -156,12 +142,21 @@ glob.annz["weightVarNames_wgtKNN"] = "MAG_U;MAG_G;MAG_R;MAG_I;MAG_Z" # optional parameters (may leave empty as default value): - glob.annz["sampleFracInp_wgtKNN"] = 0.15 # fraction of dataset to use (positive number, smaller or equal to 1) + glob.annz["sampleFracInp_wgtKNN"] = 0.99 # fraction of dataset to use (positive number, smaller or equal to 1) glob.annz["sampleFracRef_wgtKNN"] = 0.95 # fraction of dataset to use (positive number, smaller or equal to 1) glob.annz["outAsciiVars_wgtKNN"] = "MAG_U;MAG_G;MAGERR_U" # write out two additional variables to the output file glob.annz["weightRef_wgtKNN"] = "(MAGERR_R<0.7)*1 + (MAGERR_R>=0.7)/MAGERR_R" # down-weight objects with high MAGERR_R glob.annz["cutRef_wgtKNN"] = "MAGERR_U<200" # only use objects which have small MAGERR_U + # - trainTestTogether_wgtKNN + # by default, the weights are computed for the entire sample [trainTestTogether_wgtKNN = True]. + # That is, the training and the testing samples are used together - we calculate the difference between the + # distribution of input-variables between [train+test samples] and [ref sample]. However, it is possible to + # decide to comput the weights for each separately. That is, to calculate wegiths for [train sample] + # with regards to [ref sample], and to separately get [test sample] with regards to [ref sample]. The latter + # is only recommended if the training and testing samples have different inpput-variable distributions. + glob.annz["trainTestTogether_wgtKNN"] = False + # example for using a root file as input, instead of an ascii input: useRootInputFile = False if useRootInputFile: @@ -333,7 +328,7 @@ glob.annz["nPDFbins"] = 90 elif pdfBinsType == 2: # pdfBinWidth - width of each PDF bin (equal width bins between minValZ and maxValZ - automatically derive nPDFbins) - glob.annz["pdfBinWidth"] = 0.1 + glob.annz["pdfBinWidth"] = 0.01 # -------------------------------------------------------------------------------------------------- # modify_userCuts_valid,modify_userWeights_valid - @@ -421,9 +416,11 @@ # The calculation is performed using a KNN approach, similar to the algorithm used for # the [glob.annz["useWgtKNN"] = True] calculation. # - minNobjInVol_inTrain - The number of reference objects in the reference dataset which are used in the calculation. - # - maxRelRatioInRef_inTrain - Nominally, a number in the range, [0,1] - The minimal threshold of the relative difference between - # distances in the inTrainFlag calculation for accepting an object - Should be a (<0.5) positive number. - # If [maxRelRatioInRef_inTrain < 0] then this number is ignored, and the "inTrainFlag" flag becomes + # - maxRelRatioInRef_inTrain - Nominally [maxRelRatioInRef_inTrain = -1], but can also be + # a number in the range, [0,1] - This is the minimal threshold of the relative + # difference between distances in the inTrainFlag calculation for accepting an object. + # If positive, it should probably be a (<0.5) positive number. If [maxRelRatioInRef_inTrain < 0], + # then this number is ignored, and the "inTrainFlag" flag becomes # a floating-point number in the range [0,1], instead of a binary flag. # - ...._inTrain - The rest of the parameters ending with "_inTrain" have a similar role as # their "_wgtKNN" counterparts, which are used with [glob.annz["useWgtKNN"] = True]. These are: @@ -432,9 +429,10 @@ # -------------------------------------------------------------------------------------------------- addInTrainFlag = False if addInTrainFlag: + glob.annz["inAsciiFiles"] = "boss_dr10_eval0_noZ.csv" # in this case, choose a larger input file glob.annz["addInTrainFlag"] = True glob.annz["minNobjInVol_inTrain"] = 100 - glob.annz["maxRelRatioInRef_inTrain"] = 0.1 + glob.annz["maxRelRatioInRef_inTrain"] = -1 glob.annz["weightVarNames_inTrain"] = "MAG_U;MAG_G;MAG_R;MAG_I;MAG_Z" # glob.annz["weightRef_inTrain"] = "(MAG_Z<20.5 && MAG_R<22 && MAG_U<24)" # cut the reference sample, just to have some difference... diff --git a/examples/scripts/annz_rndReg_quick.py b/examples/scripts/annz_rndReg_quick.py index 1097d0d..41d4f86 100644 --- a/examples/scripts/annz_rndReg_quick.py +++ b/examples/scripts/annz_rndReg_quick.py @@ -50,12 +50,9 @@ # file, e.g., [TYPE:NAME] may be [F:MAG_U], with 'F' standing for float. (see advanced example for detailed explanation) glob.annz["inAsciiVars"] = "F:MAG_U;F:MAGERR_U;F:MAG_G;F:MAGERR_G;F:MAG_R;F:MAGERR_R;F:MAG_I;F:MAGERR_I;F:MAG_Z;F:MAGERR_Z;D:Z" - # splitTypeTrain - list of files for training. splitTypeTest - list of files for testing. splitTypeValid - list of files for validation if - # there is no dedicated validation sample, set [nSplit=2] and ignore splitTypeValid (see advanced example for more options). - glob.annz["nSplit"] = 3 + # splitTypeTrain - list of files for training. splitTypeTest - list of files for testing. glob.annz["splitTypeTrain"] = "boss_dr10_0.csv" - glob.annz["splitTypeTest"] = "boss_dr10_1.csv;boss_dr10_2.csv" - glob.annz["splitTypeValid"] = "boss_dr10_3.csv" + glob.annz["splitTypeTest"] = "boss_dr10_1.csv;boss_dr10_2.csv;boss_dr10_3.csv" # run ANNZ with the current settings runANNZ() diff --git a/examples/scripts/annz_rndReg_weights.py b/examples/scripts/annz_rndReg_weights.py index 1fbd82e..339108c 100644 --- a/examples/scripts/annz_rndReg_weights.py +++ b/examples/scripts/annz_rndReg_weights.py @@ -50,7 +50,7 @@ glob.annz["outDirName"] = "test_randReg_weights" # no splitting of the dataset into training/validation is needed here -glob.annz["nSplit"] = 1 +glob.annz["nSplit"] = 1 # -------------------------------------------------------------------------------------------------- # pre-processing of the input dataset @@ -109,19 +109,31 @@ if useWgtKNN: glob.annz["useWgtKNN"] = True glob.annz["minNobjInVol_wgtKNN"] = 100 - glob.annz["inAsciiFiles_wgtKNN"] = "boss_dr10_colorCuts.csv" - glob.annz["inAsciiVars_wgtKNN"] = "F:MAG_U;F:MAGERR_U;F:MAG_G;F:MAGERR_G;F:MAG_R;F:MAGERR_R;F:MAG_I;F:MAGERR_I;F:MAG_Z;F:MAGERR_Z;D:Z" + # glob.annz["inAsciiFiles_wgtKNN"] = "boss_dr10_colorCuts.csv" + # glob.annz["inAsciiVars_wgtKNN"] = "F:MAG_U;F:MAGERR_U;F:MAG_G;F:MAGERR_G;F:MAG_R;F:MAGERR_R;F:MAG_I;F:MAGERR_I;F:MAG_Z;F:MAGERR_Z;D:Z" + glob.annz["inAsciiFiles_wgtKNN"] = glob.annz["inAsciiFiles"] + glob.annz["inAsciiVars_wgtKNN"] = glob.annz["inAsciiVars"] + glob.annz["weightInp_wgtKNN"] = "1/pow(MAG_G*MAG_U*MAG_R*MAG_I, 5)" + glob.annz["weightRef_wgtKNN"] = "1/MAGERR_G" + glob.annz["weightVarNames_wgtKNN"] = "MAG_U;MAG_G;MAG_R;MAG_I;MAG_Z" # optional parameters (may leave empty as default value): - glob.annz["sampleFracInp_wgtKNN"] = 0.15 # fraction of dataset to use (positive number, smaller or equal to 1) - glob.annz["sampleFracRef_wgtKNN"] = 0.95 # fraction of dataset to use (positive number, smaller or equal to 1) + glob.annz["sampleFracInp_wgtKNN"] = 0.85 # fraction of dataset to use (positive number, smaller or equal to 1) + glob.annz["sampleFracRef_wgtKNN"] = 0.90 # fraction of dataset to use (positive number, smaller or equal to 1) glob.annz["outAsciiVars_wgtKNN"] = "MAG_U;MAG_G;MAGERR_U" # write out two additional variables to the output file glob.annz["weightRef_wgtKNN"] = "(MAGERR_R<0.7)*1 + (MAGERR_R>=0.7)/MAGERR_R" # down-weight objects with high MAGERR_R glob.annz["cutRef_wgtKNN"] = "MAGERR_U<200" # only use objects which have small MAGERR_U glob.annz["doWidthRescale_wgtKNN"] = True - + # - trainTestTogether_wgtKNN + # by default, the weights are computed for the entire sample [trainTestTogether_wgtKNN = True]. + # That is, the training and the testing samples are used together - we calculate the difference between the + # distribution of input-variables between [train+test samples] and [ref sample]. However, it is possible to + # decide to comput the weights for each separately. That is, to calculate wegiths for [train sample] + # with regards to [ref sample], and to separately get [test sample] with regards to [ref sample]. The latter + # is only recommended if the training and testing samples have different inpput-variable distributions. + glob.annz["trainTestTogether_wgtKNN"] = False # run ANNZ with the current settings runANNZ() @@ -134,7 +146,7 @@ # inDirName,inAsciiFiles - directory with files to make the calculations from, and list of input files glob.annz["inDirName"] = "examples/data/photoZ/eval/" - glob.annz["inAsciiFiles"] = "boss_dr10_eval1_noZ.csv" + glob.annz["inAsciiFiles"] = "boss_dr10_eval0_noZ.csv" # inAsciiVars - list of parameters in the input files (doesnt need to be exactly the same as in doGenInputTrees, but must contain all # of the parameers which were used for training) glob.annz["inAsciiVars"] = "F:MAG_U;F:MAGERR_U;F:MAG_G;F:MAGERR_G;F:MAG_R;F:MAGERR_R;F:MAG_I;F:MAGERR_I;F:MAG_Z;F:MAGERR_Z" @@ -149,9 +161,11 @@ # The calculation is performed using a KNN approach, similar to the algorithm used for # the [glob.annz["useWgtKNN"] = True] calculation. # - minNobjInVol_inTrain - The number of reference objects in the reference dataset which are used in the calculation. - # - maxRelRatioInRef_inTrain - Nominally, a number in the range, [0,1] - The minimal threshold of the relative difference between - # distances in the inTrainFlag calculation for accepting an object - Should be a (<0.5) positive number. - # If [maxRelRatioInRef_inTrain < 0] then this number is ignored, and the "inTrainFlag" flag becomes + # - maxRelRatioInRef_inTrain - Nominally [maxRelRatioInRef_inTrain = -1], but can also be + # a number in the range, [0,1] - This is the minimal threshold of the relative + # difference between distances in the inTrainFlag calculation for accepting an object. + # If positive, it should probably be a (<0.5) positive number. If [maxRelRatioInRef_inTrain < 0], + # then this number is ignored, and the "inTrainFlag" flag becomes # a floating-point number in the range [0,1], instead of a binary flag. # - ...._inTrain - The rest of the parameters ending with "_inTrain" have a similar role as # their "_wgtKNN" counterparts, which are used with [glob.annz["useWgtKNN"] = True]. These are: diff --git a/examples/scripts/annz_singleReg_quick.py b/examples/scripts/annz_singleReg_quick.py index 896ffa5..e7db75f 100644 --- a/examples/scripts/annz_singleReg_quick.py +++ b/examples/scripts/annz_singleReg_quick.py @@ -49,12 +49,9 @@ # file, e.g., [TYPE:NAME] may be [F:MAG_U], with 'F' standing for float. (see advanced example for detailed explanation) glob.annz["inAsciiVars"] = "F:MAG_U;F:MAGERR_U;F:MAG_G;F:MAGERR_G;F:MAG_R;F:MAGERR_R;F:MAG_I;F:MAGERR_I;F:MAG_Z;F:MAGERR_Z;D:Z" - # splitTypeTrain - list of files for training. splitTypeTest - list of files for testing. splitTypeValid - list of files for validation if - # there is no dedicated validation sample, set [nSplit=2] and ignore splitTypeValid (see advanced example for more options). - glob.annz["nSplit"] = 3 + # splitTypeTrain - list of files for training. splitTypeTest - list of files for testing. glob.annz["splitTypeTrain"] = "boss_dr10_0.csv" glob.annz["splitTypeTest"] = "boss_dr10_1.csv;boss_dr10_2.csv" - glob.annz["splitTypeValid"] = "boss_dr10_3.csv" # run ANNZ with the current settings runANNZ() diff --git a/examples/scripts/generalSettings.py b/examples/scripts/generalSettings.py index 2eb314d..42c60ed 100644 --- a/examples/scripts/generalSettings.py +++ b/examples/scripts/generalSettings.py @@ -47,7 +47,7 @@ def generalSettings(): # glob.annz["overwriteExistingTrain"] = True # optimCondReg - - # ["bias", "sig68" or "fracSig68"] - used for deciding how to rank MLM performance. the named criteria represents + # ["sig68" or "bias"] - used for deciding how to rank MLM performance. the named criteria represents # the metric which is more significant in deciding which MLM performs "best" (correspondingly, the bias, the # 68th percentile scatter of bias distribution, or the outlier fraction of the latter). The default value is "sig68". # glob.annz["optimCondReg"] = "bias" @@ -60,8 +60,7 @@ def generalSettings(): # correspondingly the value of the scatter of `deltaScaled` instead of that of `delta`. # The selection criteria for prioritizing the bias or the scatter remains the parameter `glob.annz["optimCondReg"]`. # This means that optimCondReg can take the value `bias` (for `delta` or `deltaScaled`), - # or `sig68` (for the 68th percentile scatter of `delta` or of `deltaScaled`), or - # `fracSig68` (for the outlier fraction of `delta` or of `deltaScaled`). The default value is False. + # or `sig68` (for the 68th percentile scatter of `delta` or of `deltaScaled`). The default value is False. # glob.annz["optimWithScaledBias"] = True # use the scaled bias `(zReg-zTrg)/(1+zTrg)` instead of the bias for the figures generated with the plotting @@ -104,7 +103,7 @@ def generalSettings(): # if propagating input-errors - nErrINP is the number of randomly generated MLM values used to propagate # the uncertainty on the input parameters to the MLM-estimator. See getRegClsErrINP() - # glob.annz["nErrINP"] = -1 + # glob.annz["nErrINP"] = -1 # if set to [-1], the default value is used # maximal number of objects in a tree/output ascii file # glob.annz["nObjectsToWrite"] = 1e6 diff --git a/examples/scripts/helperFuncs.py b/examples/scripts/helperFuncs.py index 1cbf38c..4f2dc5d 100644 --- a/examples/scripts/helperFuncs.py +++ b/examples/scripts/helperFuncs.py @@ -22,32 +22,31 @@ def init(): # -------------------------------------------------------------------------------------------------- def initParse(): parser = argparse.ArgumentParser(description="Command line parser:") - parser.add_argument("--make", action='store_true') - parser.add_argument("--clean", action='store_true') - parser.add_argument("--train", action='store_true') - parser.add_argument("--optimize", action='store_true') - parser.add_argument("--verify", action='store_true') - parser.add_argument("--evaluate", action='store_true') - parser.add_argument("--qsub", action='store_true') - parser.add_argument("--genInputTrees", action='store_true') - parser.add_argument("--singleRegression", action='store_true') - parser.add_argument("--randomRegression", action='store_true') - parser.add_argument("--binnedClassification",action='store_true') - parser.add_argument("--singleClassification",action='store_true') - parser.add_argument("--randomClassification",action='store_true') - parser.add_argument("--inTrainFlag" ,action='store_true') - - parser.add_argument("--truncateLog", action='store_true') - parser.add_argument("--isBatch", action='store_true') - - parser.add_argument("--logFileName",type=str, default="") - parser.add_argument("--logLevel", type=str, default="INFO") - - parser.add_argument("--maxNobj", type=float, default=0) - parser.add_argument("--trainIndex", type=int, default=-1) - - parser.add_argument("--fitsToAscii", action='store_true') - parser.add_argument("--asciiToFits", action='store_true') + parser.add_argument("--make", action='store_true') + parser.add_argument("--clean", action='store_true') + parser.add_argument("--train", action='store_true') + parser.add_argument("--optimize", action='store_true') + parser.add_argument("--verify", action='store_true') + parser.add_argument("--evaluate", action='store_true') + parser.add_argument("--qsub", action='store_true') + parser.add_argument("--genInputTrees", action='store_true') + parser.add_argument("--singleRegression", action='store_true') + parser.add_argument("--randomRegression", action='store_true') + parser.add_argument("--binnedClassification", action='store_true') + parser.add_argument("--singleClassification", action='store_true') + parser.add_argument("--randomClassification", action='store_true') + parser.add_argument("--inTrainFlag" , action='store_true') + parser.add_argument("--truncateLog", action='store_true') + parser.add_argument("--isBatch", action='store_true') + parser.add_argument("--fitsToAscii", action='store_true') + parser.add_argument("--asciiToFits", action='store_true') + parser.add_argument("--logFileName", type=str, default="") + parser.add_argument("--logLevel", type=str, default="INFO") + parser.add_argument("--generalOptS", type=str, default="NULL") + parser.add_argument("--makeOpt", type=str, default="NULL") + parser.add_argument("--maxNobj", type=float, default=0) + parser.add_argument("--trainIndex", type=int, default=-1) + parser.add_argument("--generalOptI", type=int, default=-1) glob.pars = vars(parser.parse_args()) @@ -68,7 +67,7 @@ def initParse(): if glob.pars["fitsToAscii"]: nSetups += 1 if glob.pars["asciiToFits"]: nSetups += 1 - if not ((nSetups == 1) or (nSetups == 0 and (hasMake or glob.pars["qsub"]))): + if not ((nSetups == 1) or (nSetups == 0 and (glob.pars["genInputTrees"] or hasMake or glob.pars["qsub"]))): log.warning("Should define exactly one of --singleClassification --randomClassification , --singleRegression " \ +"--randomRegression, --binnedClassification, --fitsToAscii, --asciiToFits !") @@ -85,7 +84,10 @@ def initParse(): if not (nModes == 1 or hasMake): log.warning("Should define exactly one of --genInputTrees --train , --optimize --verify, --evaluate, --inTrainFlag, --fitsToAscii, --asciiToFits !") - glob.pars["onlyMake"] = (((nSetups == 0) or (nModes == 0)) and hasMake) + glob.pars["onlyMake"] = (((nSetups == 0) or (nModes == 0)) and hasMake and (not glob.pars["genInputTrees"])) + + # add make option, e.g., "-j4" + glob.makeOpt = glob.pars["makeOpt"] if glob.pars["makeOpt"] is not "NULL" else "" # set basic values for operational flags # -------------------------------------------------------------------------------------------------- @@ -113,6 +115,10 @@ def initParse(): glob.annz["doFitsToAscii"] = glob.pars["fitsToAscii"] glob.annz["doAsciiToFits"] = glob.pars["asciiToFits"] + # general-use options for developers + glob.annz["generalOptS"] = glob.pars["generalOptS"] + glob.annz["generalOptI"] = glob.pars["generalOptI"] + # default values for options which should be overridden in generalSettings() # -------------------------------------------------------------------------------------------------- glob.annz["isBatch"] = (glob.pars["logFileName"] != "" or glob.pars["isBatch"]) @@ -221,7 +227,7 @@ def doMake(): resetDir(glob.libDirName,isClean) if isMake: log.info(blue(" - Moving to ")+red(glob.libDirName)+blue(" and compiling ANNZ... ")) - cmnd = "cd "+glob.libDirName+" ; make -f "+glob.annzDir+"Makefile" + cmnd = "cd "+glob.libDirName+" ; make "+glob.makeOpt+" -f "+glob.annzDir+"Makefile" cmkdStatus = os.system(cmnd) ; Assert("compilation failed",(cmkdStatus == 0)) if os.path.isfile(glob.exeName): log.info(blue(" - Found ")+red(glob.exeName)+blue(" - compilation seems to have succeded... ")) diff --git a/include/ANNZ.hpp b/include/ANNZ.hpp index 8af7a30..947be98 100644 --- a/include/ANNZ.hpp +++ b/include/ANNZ.hpp @@ -82,10 +82,11 @@ class ANNZ : public BaseClass { TString getErrKNNname(int nMLMnow = -1); int getErrKNNtagNow(TString errKNNname); TString getKeyWord(TString MLMname, TString sequence, TString key); + TString getRegularStrForm(TString strIn = "", VarMaps * var = NULL, TChain * aChain = NULL); void loadOptsMLM(); void setNominalParams(int nMLMnow, TString inputVariables, TString inputVarErrors); void setMethodCuts(VarMaps * var, int nMLMnow, bool verbose = true); - TCut getTrainTestCuts(TString cutType = "", int nMLMnow = -1, int split0 = 0, int split1 = 0); + TCut getTrainTestCuts(TString cutType = "", int nMLMnow = 0, int split0 = -1, int split1 = -1, VarMaps * var = NULL, TChain * aChain = NULL); void selectUserMLMlist(vector & optimMLMv, map & mlmSkipNow); void setInfoBinsZ(); int getBinZ(double valZ, vector & binEdgesV, bool forceCheck = false); @@ -153,7 +154,7 @@ class ANNZ : public BaseClass { void setBinClsPdfBinWeights(vector < vector < pair > > & pdfBinWgt, vector & nClsBinsIn); void getBinClsBiasCorPDF(TChain * aChain, vector & hisPdfBiasCorV); void doEvalReg(TChain * inChain = NULL, TString outDirName = "", vector * selctVarV = NULL); - void doMetricPlots(TChain * inChain = NULL, vector * selctMLMv = NULL); + void doMetricPlots(TChain * inChain = NULL, vector * addPlotVarV = NULL); // ----------------------------------------------------------------------------------------------------------- // ANNZ_loopRegCls.cpp : diff --git a/include/ANNZLinkDef.hpp b/include/ANNZLinkDef.hpp deleted file mode 100644 index 338b7ec..0000000 --- a/include/ANNZLinkDef.hpp +++ /dev/null @@ -1,10 +0,0 @@ -#ifdef __CINT__ - -#pragma link off all globals; -#pragma link off all classes; -#pragma link off all functions; - -//#pragma link C++ class OptMaps+; -//#pragma link C++ class VarMaps+; - -#endif diff --git a/include/BaseClassLinkDef.hpp b/include/BaseClassLinkDef.hpp deleted file mode 100644 index 338b7ec..0000000 --- a/include/BaseClassLinkDef.hpp +++ /dev/null @@ -1,10 +0,0 @@ -#ifdef __CINT__ - -#pragma link off all globals; -#pragma link off all classes; -#pragma link off all functions; - -//#pragma link C++ class OptMaps+; -//#pragma link C++ class VarMaps+; - -#endif diff --git a/include/CatFormat.hpp b/include/CatFormat.hpp index 13211fb..d95fd88 100644 --- a/include/CatFormat.hpp +++ b/include/CatFormat.hpp @@ -59,7 +59,7 @@ class CatFormat : public BaseClass { void parseInputVars(VarMaps * var, TString inAsciiVars, vector & inVarNames, vector & inVarTypes); bool inputLineToVars(TString line, VarMaps * var, vector & inVarNames, vector & inVarTypes); void setSplitVars(VarMaps * var, TRandom * rnd, map & intMap); - void addWgtKNNtoTree(TChain * aChainInp = NULL, TChain * aChainRef = NULL, TString outTreeName = ""); + void addWgtKNNtoTree(TChain * aChainInp = NULL, TChain * aChainRef = NULL, TChain * aChainEvl = NULL, TString outTreeName = ""); }; #endif // #ifndef CatFormat_h diff --git a/include/CatFormatLinkDef.hpp b/include/CatFormatLinkDef.hpp deleted file mode 100644 index 338b7ec..0000000 --- a/include/CatFormatLinkDef.hpp +++ /dev/null @@ -1,10 +0,0 @@ -#ifdef __CINT__ - -#pragma link off all globals; -#pragma link off all classes; -#pragma link off all functions; - -//#pragma link C++ class OptMaps+; -//#pragma link C++ class VarMaps+; - -#endif diff --git a/include/OptMaps.hpp b/include/OptMaps.hpp index d134cc1..fe4e638 100644 --- a/include/OptMaps.hpp +++ b/include/OptMaps.hpp @@ -67,6 +67,8 @@ public : // ----------------------------------------------------------------------------------------------------------- inline TString baseOutDirName() { return (TString)"./output/"; }; // this values must never be changed !!! inline TString baseInDirName() { return (TString)"rootIn/"; }; // this values must never be changed !!! + // base-prefix (may be defined by user) for general naming + inline TString basePrefix() { return (TString)(HasOptC("basePrefix")?GetOptC("basePrefix"):"basePrefix_"); }; void setColors(); diff --git a/include/OptMapsLinkDef.hpp b/include/OptMapsLinkDef.hpp deleted file mode 100644 index 338b7ec..0000000 --- a/include/OptMapsLinkDef.hpp +++ /dev/null @@ -1,10 +0,0 @@ -#ifdef __CINT__ - -#pragma link off all globals; -#pragma link off all classes; -#pragma link off all functions; - -//#pragma link C++ class OptMaps+; -//#pragma link C++ class VarMaps+; - -#endif diff --git a/include/OutMngrLinkDef.hpp b/include/OutMngrLinkDef.hpp deleted file mode 100644 index 338b7ec..0000000 --- a/include/OutMngrLinkDef.hpp +++ /dev/null @@ -1,10 +0,0 @@ -#ifdef __CINT__ - -#pragma link off all globals; -#pragma link off all classes; -#pragma link off all functions; - -//#pragma link C++ class OptMaps+; -//#pragma link C++ class VarMaps+; - -#endif diff --git a/include/Utils.hpp b/include/Utils.hpp index 8b57a6b..2c73f55 100644 --- a/include/Utils.hpp +++ b/include/Utils.hpp @@ -81,13 +81,13 @@ class Utils { return TString::Format((TString)tree->GetName()+"_friend_%d",nTreeFriends); }; // see: http://www.cplusplus.com/reference/cstdio/printf/ - inline TString boolToStr (bool input) { return (TString)(input ? "1" : "0"); }; - inline TString intToStr (int input, TString format = "%d" ) { return TString::Format(format,input); }; - inline TString lIntToStr (long int input, TString format = "%ld" ) { return TString::Format(format,input); }; - inline TString uIntToStr (int input, TString format = "%u" ) { return TString::Format(format,input); }; - inline TString ULIntToStr (long int input, TString format = "%lu" ) { return TString::Format(format,input); }; - inline TString floatToStr (double input, TString format = "%f" ) { return TString::Format(format,input); }; - inline TString doubleToStr(double input, TString format = "%.10g") { return TString::Format(format,input); }; + inline TString boolToStr (bool input) { return (TString)(input ? "1" : "0"); }; + inline TString intToStr (int input, TString format = "%d" ) { return TString::Format(format,input); }; + inline TString lIntToStr (long int input, TString format = "%ld" ) { return TString::Format(format,input); }; + inline TString uIntToStr (unsigned int input, TString format = "%u" ) { return TString::Format(format,input); }; + inline TString ULIntToStr (unsigned long int input, TString format = "%lu" ) { return TString::Format(format,input); }; + inline TString floatToStr (double input, TString format = "%f" ) { return TString::Format(format,input); }; + inline TString doubleToStr(double input, TString format = "%.10g") { return TString::Format(format,input); }; Int_t strToInt (TString input); Long64_t strToLong (TString input); @@ -121,9 +121,9 @@ class Utils { bool validFileExists(TString fileName = "", bool verif = true); void resetDirectory(TString OutDirName = "", bool verbose = false, bool copyCode = false); void checkCmndSafety(TString cmnd = "", bool verbose = false); - void safeRM(TString cmnd = "", bool verbose = false); + void safeRM(TString cmnd = "", bool verbose = false, bool checkExitStatus = true); TString getShellCmndOutput(TString cmnd = "", vector * outV = NULL, bool verbose = false, bool checkExitStatus = true, int * getSysReturn = NULL); - void exeShellCmndOutput(TString cmnd = "", bool verbose = false, bool checkExitStatus = true); + int exeShellCmndOutput(TString cmnd = "", bool verbose = false, bool checkExitStatus = true); TString cleanWeightExpr(TString wgtIn); bool isSameWeightExpr(TString wgt0, TString wgt1); @@ -162,6 +162,7 @@ class Utils { // variables // =========================================================================================================== + TRandom3 * rnd; vector colours, markers, greens, blues, reds, fillStyles; OptMaps * glob, * param; diff --git a/include/UtilsLinkDef.hpp b/include/UtilsLinkDef.hpp deleted file mode 100644 index 338b7ec..0000000 --- a/include/UtilsLinkDef.hpp +++ /dev/null @@ -1,10 +0,0 @@ -#ifdef __CINT__ - -#pragma link off all globals; -#pragma link off all classes; -#pragma link off all functions; - -//#pragma link C++ class OptMaps+; -//#pragma link C++ class VarMaps+; - -#endif diff --git a/include/VarMaps.hpp b/include/VarMaps.hpp index a94c15d..c5e370f 100644 --- a/include/VarMaps.hpp +++ b/include/VarMaps.hpp @@ -107,6 +107,7 @@ public : inline TTree * getTreeRead () { if(dynamic_cast(treeRead )) return treeRead; else return NULL; } void setTreeWrite(TTree * tree = NULL); void setTreeRead (TTree * tree = NULL); + void fillTree(); void createTreeBranches(TTree * tree = NULL, TString prefix = "", TString postfix = "", vector * excludedBranchNames = NULL); void connectTreeBranches(TTree * tree = NULL, vector * excludedBranchNames = NULL); @@ -276,6 +277,20 @@ public : inline void OrVarB (TString aName, Bool_t input) { SetVarB_ (aName, GetVarB_(aName) || input); } // ----------------------------------------------------------------------------------------------------------- + // simplify the interface to string formulae + // ----------------------------------------------------------------------------------------------------------- + inline TString regularizeStringForm(TString strIn) { + TString strOut(strIn), strGetBy(".fString"); + + for(Map ::iterator itr=varC.begin(); itr!=varC.end(); ++itr) { + strOut.ReplaceAll(itr->first,(TString)(itr->first)+strGetBy); + while(strOut.Contains(strGetBy+strGetBy)) { strOut.ReplaceAll(strGetBy+strGetBy,strGetBy); } + } + + return strOut; + } + inline TCut regularizeStringForm(TCut cutIn) { return ((TCut)regularizeStringForm((TString)cutIn)); } + // counter access // ----------------------------------------------------------------------------------------------------------- inline void resetCntr() { cntrMap->resetCntr(); } @@ -296,8 +311,9 @@ public : // internal functions for variable manipulatios // ----------------------------------------------------------------------------------------------------------- inline void NewVarB_ (TString aName, Bool_t input) { verifyType(aName,"B"); DelVarB_ (aName); varB [aName] = input; hasB [aName] = true; } - inline void NewVarC_ (TString aName, TString input) { verifyType(aName,"C"); DelVarC_ (aName); varC [aName] = new TObjString(aName); - varC [aName]->SetString(input); hasC [aName] = true; } + inline void NewVarC_ (TString aName, TString input) { + verifyType(aName,"C"); DelVarC_(aName); varC[aName] = new TObjString(aName); varC[aName]->SetString(input); hasC[aName] = true; + } inline void NewVarS_ (TString aName, Short_t input) { verifyType(aName,"S"); DelVarS_ (aName); varS [aName] = input; hasS [aName] = true; } inline void NewVarI_ (TString aName, Int_t input) { verifyType(aName,"I"); DelVarI_ (aName); varI [aName] = input; hasI [aName] = true; } inline void NewVarL_ (TString aName, Long64_t input) { verifyType(aName,"L"); DelVarL_ (aName); varL [aName] = input; hasL [aName] = true; } diff --git a/include/VarMapsLinkDef.hpp b/include/VarMapsLinkDef.hpp deleted file mode 100644 index 338b7ec..0000000 --- a/include/VarMapsLinkDef.hpp +++ /dev/null @@ -1,10 +0,0 @@ -#ifdef __CINT__ - -#pragma link off all globals; -#pragma link off all classes; -#pragma link off all functions; - -//#pragma link C++ class OptMaps+; -//#pragma link C++ class VarMaps+; - -#endif diff --git a/src/ANNZ_err.cpp b/src/ANNZ_err.cpp index 0670be4..053693b 100644 --- a/src/ANNZ_err.cpp +++ b/src/ANNZ_err.cpp @@ -36,14 +36,7 @@ void ANNZ::createTreeErrKNN(int nMLMnow) { TString MLMname = getTagName(nMLMnow); TString errKNNname = getErrKNNname(nMLMnow); TString MLMname_i = getTagIndex(nMLMnow); - - // ----------------------------------------------------------------------------------------------------------- - // the _train trees are used in all cases: - // - if (separateTestValid == false) then the errors of the _train will be derived from the same source, but - // there is no choice in the matter anyway... - // - if (separateTestValid == true) then the errors for both the testing (ANNZ_tvType<0.5 in _valid) and for the - // validation (ANNZ_tvType>0.5 in _valid) will be derived from the independent source of the _train tree - // ----------------------------------------------------------------------------------------------------------- + VarMaps * var_0 = new VarMaps(glob,utils,"treeErrKNN_0"); VarMaps * var_1 = new VarMaps(glob,utils,"treeErrKNN_1"); @@ -121,7 +114,7 @@ void ANNZ::createTreeErrKNN(int nMLMnow) { var_1->SetVarF(zTrgName,zTrg); } - outTree->Fill(); + var_1->fillTree(); var_0->IncCntr("nObj"); mayWriteObjects = true; } diff --git a/src/ANNZ_loopCls.cpp b/src/ANNZ_loopCls.cpp index 8e9cc4a..b32ad7b 100644 --- a/src/ANNZ_loopCls.cpp +++ b/src/ANNZ_loopCls.cpp @@ -31,7 +31,7 @@ void ANNZ::optimCls() { int nMLMs = glob->GetOptI("nMLMs"); TString sigBckTypeName = glob->GetOptC("sigBckTypeName"); - bool separateTestValid = glob->GetOptB("separateTestValid"); + // bool separateTestValid = glob->GetOptB("separateTestValid"); // deprecated int nANNZtypes = (int)allANNZtypes.size(); // adjust maxNobj to stop loop after maxNobj/2 sig and maxNobj/2 bck objects have been accepted @@ -44,7 +44,8 @@ void ANNZ::optimCls() { // create the chain for the loop // ----------------------------------------------------------------------------------------------------------- - TString treeNamePostfix = (TString)(separateTestValid ? "_valid" : "_train"); + // TString treeNamePostfix = (TString)(separateTestValid ? "_valid" : "_train"); // deprecated + TString treeNamePostfix = (TString)"_train"; TString inTreeName = (TString)glob->GetOptC("treeName")+treeNamePostfix; TString inFileName = (TString)glob->GetOptC("postTrainDirNameFull")+inTreeName+"*.root"; @@ -58,7 +59,7 @@ void ANNZ::optimCls() { VarMaps * var = new VarMaps(glob,utils,"loopRegClsVar"); var->connectTreeBranches(aChain); - if(separateTestValid) var->setTreeCuts("_train",getTrainTestCuts("_valid",0)); + // if(separateTestValid) var->setTreeCuts("_train",getTrainTestCuts("_valid",0,0,0,var)); // deprecated // number of initial bins and rebin factor for classification response histograms hisBinsN = glob->GetOptI("clsResponseHisN"); rebinFactor = glob->GetOptI("clsResponseHisR"); @@ -135,7 +136,7 @@ void ANNZ::optimCls() { } if(breakLoop) break; - if(separateTestValid) { if(var->hasFailedTreeCuts("_train")) continue; } + // if(separateTestValid) { if(var->hasFailedTreeCuts("_train")) continue; } // deprecated int sigBckType = var->GetVarI(sigBckTypeName); if(sigBckType == -1) { @@ -592,8 +593,8 @@ void ANNZ::doEvalCls() { setMethodCuts(varKNN,nMLMnow,false); - TCut cutsNow(varKNN->getTreeCuts("_comn") + varKNN->getTreeCuts(MLMname+"_valid")); - TString wgtCls(userWgtsM[MLMname+"_valid"]); + TCut cutsNow = varKNN->getTreeCuts("_comn") + varKNN->getTreeCuts(MLMname+"_valid"); + TString wgtCls = getRegularStrForm(userWgtsM[MLMname+"_valid"],varKNN); TString inputComboNow = (TString)"[__ANNZ_VAR__]"+inputVariableV[nMLMnow]+"[__ANNZ_WGT__]"+wgtCls+"[__ANNZ_CUT__]"+(TString)cutsNow; inputComboNow.ReplaceAll(" ","").ReplaceAll("[__"," [__").ReplaceAll("__]","__] "); @@ -655,7 +656,8 @@ void ANNZ::doEvalCls() { TString MLMname_w = getTagWeight(nMLMnow); TString MLMname_v = getTagClsVal(nMLMnow); // create MLM-weight formulae for the input variables - var_0->NewForm(MLMname_w,userWgtsM[MLMname+"_valid"]); + TString wgtStr = getRegularStrForm(userWgtsM[MLMname+"_valid"],var_0); + var_0->NewForm(MLMname_w,wgtStr); // formulae for inpput-variable errors, to be used by getRegClsErrINP() if(isErrINPv[nMLMnow]) { @@ -749,7 +751,7 @@ void ANNZ::doEvalCls() { if(hasErrs) { var_1->SetVarF(MLMname_e,clsErr); } } - treeOut->Fill(); + var_1->fillTree(); // to increment the loop-counter, at least one method should have passed the cuts mayWriteObjects = true; var_0->IncCntr("nObj"); if(var_0->GetCntr("nObj") == maxNobj) breakLoop = true; diff --git a/src/ANNZ_loopReg.cpp b/src/ANNZ_loopReg.cpp index 8d6c993..70c1745 100644 --- a/src/ANNZ_loopReg.cpp +++ b/src/ANNZ_loopReg.cpp @@ -40,7 +40,7 @@ void ANNZ::optimReg() { int nMLMs = glob->GetOptI("nMLMs"); int nPDFs = glob->GetOptI("nPDFs"); - bool separateTestValid = glob->GetOptB("separateTestValid"); + // bool separateTestValid = glob->GetOptB("separateTestValid"); // deprecated TString addOutputVars = glob->GetOptC("addOutputVars"); TString indexName = glob->GetOptC("indexName"); TString zTrg = glob->GetOptC("zTrg"); @@ -50,15 +50,17 @@ void ANNZ::optimReg() { TString outDirNameOrig(outputs->GetOutDirName()), outDirName(""), inTreeName(""), inFileName(""); for(int nTrainValidNow=0; nTrainValidNow<2; nTrainValidNow++) { - TCut aCut(""); - TString treeNamePostfix(""); - if(separateTestValid) { - aCut = (TCut) ((nTrainValidNow == 0) ? getTrainTestCuts("_train",0) : getTrainTestCuts("_valid",0)); - treeNamePostfix = (TString)((nTrainValidNow == 0) ? "_valid" : "_valid"); - } - else { - treeNamePostfix = (TString)((nTrainValidNow == 0) ? "_train" : "_valid"); - } + // deprecated + // TCut aCut(""); + // TString treeNamePostfix(""); + // if(separateTestValid) { + // aCut = (TCut) ((nTrainValidNow == 0) ? getTrainTestCuts("_train") : getTrainTestCuts("_valid")); + // treeNamePostfix = (TString)((nTrainValidNow == 0) ? "_valid" : "_valid"); + // } + // else { + // treeNamePostfix = (TString)((nTrainValidNow == 0) ? "_train" : "_valid"); + // } + TString treeNamePostfix = (TString)((nTrainValidNow == 0) ? "_train" : "_valid"); if(isBinCls && nTrainValidNow == 0 && !doBiasCorPDF) continue; @@ -79,6 +81,11 @@ void ANNZ::optimReg() { VERIFY(LOCATION,(TString)"Main and friend chains have different numbers of entries ... Something is horribly wrong !!!" ,(nEntriesChain_0 == nEntriesChain_1)); + // deprecated + // // define a temporary VarMaps so as to regularize the weight expression + // for(int nChainNow=0; nChainNow<2; nChainNow++) { + // aCut = (TCut)getRegularStrForm((TString)aCut,NULL,((nChainNow == 0) ? aChain_0 : aChain_1)); + // } // get the list of branch names from the friend-chain (all ANNZ_* branches), and add the target name // ----------------------------------------------------------------------------------------------------------- @@ -96,7 +103,7 @@ void ANNZ::optimReg() { } // check if any variables used for cuts but not requested by the user need to be included in the tree - TString cutStr = (TString)(getTrainTestCuts("_comn",0)+getTrainTestCuts(getTagName(0)+"_train",0)+getTrainTestCuts(getTagName(0)+"_valid",0)); + TString cutStr = (TString)getTrainTestCuts((TString)"_comn"+";"+getTagName(0)+"_train"+";"+getTagName(0)+"_valid"); vector chain0_nameV, addPlotVarV; utils->getTreeBranchNames(aChain_0,chain0_nameV); @@ -115,7 +122,8 @@ void ANNZ::optimReg() { outputs->InitializeDir(outDirName,glob->GetOptC("baseName")); // create the merged chain - TChain * aChainMerged = mergeTreeFriends(aChain_0,aChain_1,NULL,&acceptV,NULL,aCut); + // TChain * aChainMerged = mergeTreeFriends(aChain_0,aChain_1,NULL,&acceptV,NULL,aCut); // deprecated + TChain * aChainMerged = mergeTreeFriends(aChain_0,aChain_1,NULL,&acceptV); // deprecated verifyIndicesMLM(aChainMerged); @@ -352,14 +360,14 @@ void ANNZ::fillColosureV( map < int,vector > & zRegQnt_nANNZ, map < i if((var->GetCntr("nObj")+1 % nObjectsToWrite == 0) || breakLoop) var->printCntr(aChainName,Log::DEBUG); if(breakLoop) break; + double zTrg = var->GetVarF(zTrgName); for(int nMLMnow=0; nMLMnowGetVarF(MLMname_w); if(weightNow < EPS) continue; - double regValNow = var->GetVarF(MLMname); - int zRegBinN = getBinZ(regValNow,zClos_binE); if(zRegBinN < 0) continue; - double zTrg = var->GetVarF(zTrgName); + double weightNow = var->GetVarF(MLMname_w); if(weightNow < EPS) continue; + double regValNow = var->GetVarF(MLMname); + int zRegBinN = getBinZ(regValNow,zClos_binE); if(zRegBinN < 0) continue; double sclBias(regValNow-zTrg); if(optimWithSclBias) { @@ -370,6 +378,8 @@ void ANNZ::fillColosureV( map < int,vector > & zRegQnt_nANNZ, map < i sumWeightsBin[zRegBinN] += weightNow; closH[nMLMnow][zRegBinN]->Fill(sclBias,weightNow); + + if(nMLMnow == 0) var->IncCntr("nObj with [weight > 0]"); } var->IncCntr("nObj"); if(var->GetCntr("nObj") == maxNobj) breakLoop = true; @@ -436,7 +446,8 @@ void ANNZ::fillColosureV( map < int,vector > & zRegQnt_nANNZ, map < i zRegQnt_sigma68 [-1].push_back(mean_sigma68); zRegQnt_fracSig68[-1].push_back(avgFracSig68); - aLOG(Log::INFO) <,<"<,: "<,<"<,: " + < > & zRegQnt_nANNZ, map < i * @brief - Find the "best" MLM solution in randomized regression. * * @details - Find the "best" MLMs, given the three metrics: bias, sig68 and fracSig68, - * where for optimCondReg (one of the three) the best several methods within the top fracLimNow + * where for optimCondReg (one of the three) the best several methods within the top fracLimNow0 * percentile are chosen, so long as for the other two metrics, the MLMs are within - * the top (fracLimNow + 1_sigma) of the distribution of MLMs. This selects the solution + * the top fracLimNow1 (fracLimNow1 ~> fracLimNow1+0.2) of the distribution of MLMs. This selects the solution * which gives the "best" of the selected metric, which is also not the "worst" of the other two. * * @param zRegQnt_nANNZ - Map of vector, which is filled with the index of MLMs, corresponding to the order of the @@ -478,7 +489,8 @@ void ANNZ::getBestANNZ( map < int,vector > & zRegQnt_nANNZ, map < int map < int,vector > & zRegQnt_sigma68, map < int,vector > & zRegQnt_fracSig68, vector < int > & bestMLMsV, bool onlyInclusiveBin) { // =================================================================================================================== - aLOG(Log::INFO) <GetOptC("optimCondReg")< > & zRegQnt_nANNZ, map < int // ----------------------------------------------------------------------------------------------------------- // find the "best" MLM, which has an "optimal" combination of good metrics // ----------------------------------------------------------------------------------------------------------- - int nAcptBestMLMs(0), nfracLims(37), minNoptimMLMs(1), nBinsZ(-1); + int nAcptBestMLMs(0), nfracLims0(30), nfracLims1(6), minNoptimMLMs(1), nBinsZ(-1); for(map < int,vector >::iterator Itr=zRegQnt_nANNZ.begin(); Itr!=zRegQnt_nANNZ.end(); ++Itr) { nBinsZ++; } map < int,vector > nANNZv; @@ -529,142 +541,150 @@ void ANNZ::getBestANNZ( map < int,vector > & zRegQnt_nANNZ, map < int map < int,map< TString,vector > > bestMLMs; map < int,map > > > bestMLMsPairs; - for(int nfracLimNow=0; nfracLimNowGetOptI("minEleForQuantile")) continue; + for(int nBinNow=-1; nBinNow > metricLowQuants; - for(int nMetricNow=0; nMetricNow<3; nMetricNow++) { - TString metricName(""); - vector fracV, quantV(3,-1); - fracV.push_back(fracLimNow); fracV.push_back(0.16); fracV.push_back(0.84); + nANNZv [nBinNow].push_back(zRegQnt_nANNZ [nBinNow][nAcptANNZNow]); + biasV [nBinNow].push_back(zRegQnt_bias [nBinNow][nAcptANNZNow]); + sigma68V [nBinNow].push_back(zRegQnt_sigma68 [nBinNow][nAcptANNZNow]); + fracSig68V[nBinNow].push_back(zRegQnt_fracSig68[nBinNow][nAcptANNZNow]); + } + } - int hasQuants(0); - utils->param->clearAll(); - if (nMetricNow == 0) { hasQuants = utils->getQuantileV(fracV,quantV,biasV [nBinNow]); metricName = "bias"; } - else if(nMetricNow == 1) { hasQuants = utils->getQuantileV(fracV,quantV,sigma68V [nBinNow]); metricName = "sig68"; } - else if(nMetricNow == 2) { hasQuants = utils->getQuantileV(fracV,quantV,fracSig68V[nBinNow]); metricName = "fracSig68"; } - else assert(false); - - double quantLow = hasQuants ? quantV[0] : -1; - double quantDif = hasQuants ? quantLow + (quantV[2] - quantV[1])/2. : -1; - - metricLowQuants[metricName].push_back(quantLow); - metricLowQuants[metricName].push_back(quantDif); + bestMLMsPairs.clear(); + nAcptBestMLMs = 0; + for(int nBinNow=-1; nBinNowGetOptI("minEleForQuantile")) continue; + + map < TString, vector > metricLowQuants; + for(int nMetricNow=0; nMetricNow<3; nMetricNow++) { + TString metricName(""); + vector fracV, quantV(2,-1); + fracV.push_back(fracLimNow0); fracV.push_back(fracLimNow1); //fracV.push_back(0.16); fracV.push_back(0.84); + + int hasQuants(0); + utils->param->clearAll(); + if (nMetricNow == 0) { hasQuants = utils->getQuantileV(fracV,quantV,biasV [nBinNow]); metricName = "bias"; } + else if(nMetricNow == 1) { hasQuants = utils->getQuantileV(fracV,quantV,sigma68V [nBinNow]); metricName = "sig68"; } + else if(nMetricNow == 2) { hasQuants = utils->getQuantileV(fracV,quantV,fracSig68V[nBinNow]); metricName = "fracSig68"; } + else assert(false); + + double quantLow = hasQuants ? quantV[0] : -1; + double quantDif = hasQuants ? quantV[1] : -1; + // double quantDif = hasQuants ? quantLow + (quantV[2] - quantV[1])/2. : -1; + + metricLowQuants[metricName].push_back(quantLow); + metricLowQuants[metricName].push_back(quantDif); - aLOG(Log::DEBUG_2)< hasMetric; - double frac_bias(0),frac_sig68(0),frac_fracSig68(0); - - int nMLMnow = nANNZv [nBinNow][nAcptANNZNow]; - double mean_bias = biasV [nBinNow][nAcptANNZNow]; - double mean_sigma68 = sigma68V [nBinNow][nAcptANNZNow]; - double mean_fracSig68 = fracSig68V[nBinNow][nAcptANNZNow]; - - // make sure all the quantile calculations for this bin which are OK - bool skip(false); - for(int nEle=0; nEle<2; nEle++) { - for(int nMetricNow=0; nMetricNow<3; nMetricNow++) { - TString metricName(""); - if(nMetricNow == 0) metricName = "bias"; else if(nMetricNow == 1) metricName = "sig68"; else if(nMetricNow == 2) metricName = "fracSig68"; - if(metricLowQuants[metricName][nEle] < 0) skip = true; + for(int nAcptANNZNow=0; nAcptANNZNow hasMetric; + double frac_bias(0),frac_sig68(0),frac_fracSig68(0); + + int nMLMnow = nANNZv [nBinNow][nAcptANNZNow]; + double mean_bias = biasV [nBinNow][nAcptANNZNow]; + double mean_sigma68 = sigma68V [nBinNow][nAcptANNZNow]; + double mean_fracSig68 = fracSig68V[nBinNow][nAcptANNZNow]; + + // make sure all the quantile calculations for this bin which are OK + bool skip(false); + for(int nEle=0; nEle<2; nEle++) { + for(int nMetricNow=0; nMetricNow<3; nMetricNow++) { + TString metricName(""); + if(nMetricNow == 0) metricName = "bias"; else if(nMetricNow == 1) metricName = "sig68"; else if(nMetricNow == 2) metricName = "fracSig68"; + if(metricLowQuants[metricName][nEle] < 0) skip = true; + } } - } - if(skip) continue; + if(skip) continue; - // ----------------------------------------------------------------------------------------------------------- - // check if "bias" is within the top fracLimNow percentile, as well as that "sig68","fracSig68" are within one sigma of fracLimNow - // ----------------------------------------------------------------------------------------------------------- - frac_bias = metricLowQuants["bias"][0]; frac_sig68 = metricLowQuants["sig68"][1]; frac_fracSig68 = metricLowQuants["fracSig68"][1]; - if(mean_bias < frac_bias && mean_sigma68 < frac_sig68 && mean_fracSig68 < frac_fracSig68) { - hasMetric["bias"] = true; - bestMLMsPairs[nBinNow]["bias"].push_back(pair(nMLMnow,mean_bias)); - } - // ----------------------------------------------------------------------------------------------------------- - // check if "sig68" is within the top fracLimNow percentile, as well as that "bias","fracSig68" are within one sigma of fracLimNow - // ----------------------------------------------------------------------------------------------------------- - frac_bias = metricLowQuants["bias"][1]; frac_sig68 = metricLowQuants["sig68"][0]; frac_fracSig68 = metricLowQuants["fracSig68"][1]; - if(mean_bias < frac_bias && mean_sigma68 < frac_sig68 && mean_fracSig68 < frac_fracSig68) { - hasMetric["sig68"] = true; - bestMLMsPairs[nBinNow]["sig68"].push_back(pair(nMLMnow,mean_sigma68)); - } - // ----------------------------------------------------------------------------------------------------------- - // check if "fracSig68" is within the top fracLimNow percentile, as well as that "bias","sig68" are within one sigma of fracLimNow - // ----------------------------------------------------------------------------------------------------------- - frac_bias = metricLowQuants["bias"][1]; frac_sig68 = metricLowQuants["sig68"][1]; frac_fracSig68 = metricLowQuants["fracSig68"][0]; - if(mean_bias < frac_bias && mean_sigma68 < frac_sig68 && mean_fracSig68 < frac_fracSig68) { - hasMetric["fracSig68"] = true; - bestMLMsPairs[nBinNow]["fracSig68"].push_back(pair(nMLMnow,mean_fracSig68)); - } + // ----------------------------------------------------------------------------------------------------------- + // check if "bias" is within the top fracLimNow0 percentile, as well as that "sig68","fracSig68" are within fracLimNow1 + // ----------------------------------------------------------------------------------------------------------- + frac_bias = metricLowQuants["bias"][0]; frac_sig68 = metricLowQuants["sig68"][1]; frac_fracSig68 = metricLowQuants["fracSig68"][1]; + if(mean_bias < frac_bias && mean_sigma68 < frac_sig68 && mean_fracSig68 < frac_fracSig68) { + hasMetric["bias"] = true; + bestMLMsPairs[nBinNow]["bias"].push_back(pair(nMLMnow,mean_bias)); + } + // ----------------------------------------------------------------------------------------------------------- + // check if "sig68" is within the top fracLimNow0 percentile, as well as that "bias","fracSig68" are within fracLimNow1 + // ----------------------------------------------------------------------------------------------------------- + frac_bias = metricLowQuants["bias"][1]; frac_sig68 = metricLowQuants["sig68"][0]; frac_fracSig68 = metricLowQuants["fracSig68"][1]; + if(mean_bias < frac_bias && mean_sigma68 < frac_sig68 && mean_fracSig68 < frac_fracSig68) { + hasMetric["sig68"] = true; + bestMLMsPairs[nBinNow]["sig68"].push_back(pair(nMLMnow,mean_sigma68)); + } + // ----------------------------------------------------------------------------------------------------------- + // check if "fracSig68" is within the top fracLimNow0 percentile, as well as that "bias","sig68" are within fracLimNow1 + // ----------------------------------------------------------------------------------------------------------- + frac_bias = metricLowQuants["bias"][1]; frac_sig68 = metricLowQuants["sig68"][1]; frac_fracSig68 = metricLowQuants["fracSig68"][0]; + if(mean_bias < frac_bias && mean_sigma68 < frac_sig68 && mean_fracSig68 < frac_fracSig68) { + hasMetric["fracSig68"] = true; + bestMLMsPairs[nBinNow]["fracSig68"].push_back(pair(nMLMnow,mean_fracSig68)); + } - if(nBinNow == -1) { if(hasMetric[optimCondReg]) nAcptBestMLMs++; } + if(nBinNow == -1) { if(hasMetric[optimCondReg]) nAcptBestMLMs++; } - TString hasAll = TString(hasMetric["bias"]?(TString)biasTitle+" ":"")+TString(hasMetric["sig68"]?(TString)optimCondTitle+" ":"") - +TString(hasMetric["fracSig68"]?"fracSig68 ":""); - if(hasAll != "") aLOG(Log::DEBUG) < "< "< "< "<= minNoptimMLMs || nfracLimNow == nfracLims-1) { - aLOG(Log::INFO) <= minNoptimMLMs || (nfracLimNow0 == nfracLims0-1 && nfracLimNow1 == nfracLims1-1)) { + aLOG(Log::INFO) < 0)); @@ -1908,6 +1928,7 @@ void ANNZ::doEvalReg(TChain * inChain, TString outDirName, vector * s double minValZ = glob->GetOptF("minValZ"); double maxValZ = glob->GetOptF("maxValZ"); int nSmearsRnd = glob->GetOptI("nSmearsRnd"); + double nSmearUnf = glob->GetOptI("nSmearUnf"); // and cast to double, since we divide by this later TString _typeANNZ = glob->GetOptC("_typeANNZ"); UInt_t seed = glob->GetOptI("initSeedRnd"); if(seed > 0) seed += 11825; TString baseTag_v = glob->GetOptC("baseTag_v"); @@ -1918,8 +1939,9 @@ void ANNZ::doEvalReg(TChain * inChain, TString outDirName, vector * s bool needBinClsErr = glob->GetOptB("needBinClsErr"); bool writePosNegErrs = glob->GetOptB("writePosNegErrs"); bool doBiasCorPDF = glob->GetOptB("doBiasCorPDF"); + bool addMaxPDF = glob->GetOptB("addMaxPDF"); double minWeight = 0.001; - + TRandom * rnd = new TRandom(seed); TString regBestNameVal = getTagBestMLMname(baseTag_v); TString regBestNameErr = getTagBestMLMname(baseTag_e); @@ -1950,9 +1972,10 @@ void ANNZ::doEvalReg(TChain * inChain, TString outDirName, vector * s } } - int nPdfTypes(3); + int nPdfTypes(addMaxPDF ? 3 : 2); vector tagNameV(nPdfTypes); - tagNameV[0] = glob->GetOptC("baseTag_MLM_avg"); tagNameV[1] = glob->GetOptC("baseTag_PDF_avg"); tagNameV[2] = glob->GetOptC("baseTag_PDF_max"); + tagNameV[0] = glob->GetOptC("baseTag_MLM_avg"); tagNameV[1] = glob->GetOptC("baseTag_PDF_avg"); + if(nPdfTypes > 2) tagNameV[2] = glob->GetOptC("baseTag_PDF_max"); // figure out which MLMs to generate an error for, using which method (KNN errors or propagation of user-defined parameter-errors) // ----------------------------------------------------------------------------------------------------------- @@ -2311,8 +2334,8 @@ void ANNZ::doEvalReg(TChain * inChain, TString outDirName, vector * s setMethodCuts(varKNN,nMLMnow,false); - TCut cutsNow(varKNN->getTreeCuts("_comn") + varKNN->getTreeCuts(MLMname+"_valid")); - TString wgtReg(userWgtsM[MLMname+"_valid"]); + TCut cutsNow = varKNN->getTreeCuts("_comn") + varKNN->getTreeCuts(MLMname+"_valid"); + TString wgtReg = getRegularStrForm(userWgtsM[MLMname+"_valid"],varKNN); TString inputComboNow = (TString)"[__ANNZ_VAR__]"+inputVariableV[nMLMnow]+"[__ANNZ_WGT__]"+wgtReg+"[__ANNZ_CUT__]"+(TString)cutsNow; inputComboNow.ReplaceAll(" ","").ReplaceAll("[__"," [__").ReplaceAll("__]","__] "); @@ -2369,7 +2392,8 @@ void ANNZ::doEvalReg(TChain * inChain, TString outDirName, vector * s if(nLoopTypeNow == 1) continue; // create MLM-weight formulae for the input variables - var_0->NewForm(MLMname_w,userWgtsM[MLMname+"_valid"]); + TString wgtStr = getRegularStrForm(userWgtsM[MLMname+"_valid"],var_0); + var_0->NewForm(MLMname_w,wgtStr); // formulae for inpput-variable errors, to be used by getRegClsErrINP() if(isErrINPv[nMLMnow]) { @@ -2600,10 +2624,10 @@ void ANNZ::doEvalReg(TChain * inChain, TString outDirName, vector * s regErrN = var_0->GetVarF(MLMname_eN); regErr = var_0->GetVarF(MLMname_e); regErrP = var_0->GetVarF(MLMname_eP); } - bool hasNoErrNow = (regErrN < 0 || regErr < 0 || regErrP < 0); + bool hasNoErrNow = ((regErrN < 0) || (regErr < 0) || (regErrP < 0)); // in the (hopefully unlikely) event that the error calculation failed for a valid object - if(hasNoErrNow && regWgt > EPS) { + if(hasNoErrNow && (regWgt > EPS)) { nHasNoErr++; if(inLOG(Log::DEBUG_2)) { aLOG(Log::DEBUG_2)< * s // apply the bias-correction to the pdf // ----------------------------------------------------------------------------------------------------------- if(doBiasCorPDF) { - double nSmearUnf = nSmearsRnd * 2; TH1 * hisPDF_w_TMP = (TH1*)hisPDF_w[nPDFnow]->Clone((TString)hisPDF_w[nPDFnow]->GetName()+"_TMP"); for(int nBinXnow=1; nBinXnow * s } } - treeOut->Fill(); + var_1->fillTree(); mayWriteObjects = true; var_0->IncCntr("nObj"); if(var_0->GetCntr("nObj") == maxNobj) breakLoop = true; } @@ -2818,7 +2841,8 @@ void ANNZ::doEvalReg(TChain * inChain, TString outDirName, vector * s if(doStoreToAscii) { TChain * aChainReg = new TChain(outTreeNameV[1][0],outTreeNameV[1][0]); aChainReg->SetDirectory(0); aChainReg->Add(outFileNameV[1][0]); int nEntriesChain = aChainReg->GetEntries(); - aLOG(Log::DEBUG) < * s // =========================================================================================================== /** - * @brief - Create performance plots for regression. + * @brief - Create performance plots for regression. * - * @param aChain - Input chain, the result of doEvalReg(). - * @param selctMLMv - Possible vector of MLM names, which will be added to the list of solutions for which - * plots are generated. + * @param aChain - Input chain, the result of doEvalReg(). + * @param addPlotVarV - Possible vector of MLM or variable names, which will be added to the list of solutions + * for which plots are generated. */ // =========================================================================================================== -void ANNZ::doMetricPlots(TChain * aChain, vector * selctMLMv) { -// ======================================================================= +void ANNZ::doMetricPlots(TChain * aChain, vector * addPlotVarV) { +// ========================================================================= + if(!glob->GetOptB("doPlots")) { + aLOG(Log::DEBUG) <(aChain))); aLOG(Log::INFO) < * selctMLMv) { TString outDirNameFull = glob->GetOptC("outDirNameFull"); TString addOutputVars = glob->GetOptC("addOutputVars"); TString plotExt = glob->GetOptC("printPlotExtension"); - TString baseName_ANNZ = glob->GetOptC("baseName_ANNZ"); + TString basePrefix = glob->GetOptC("basePrefix"); TString baseName_regMLM_avg = glob->GetOptC("baseName_regMLM_avg"); TString baseName_regPDF_max = glob->GetOptC("baseName_regPDF_max"); TString baseName_regPDF_avg = glob->GetOptC("baseName_regPDF_avg"); @@ -2948,6 +2977,7 @@ void ANNZ::doMetricPlots(TChain * aChain, vector * selctMLMv) { double minValZ = glob->GetOptF("minValZ"); double maxValZ = glob->GetOptF("maxValZ"); bool isBinCls = glob->GetOptB("doBinnedCls"); + bool addMaxPDF = glob->GetOptB("addMaxPDF"); TString aChainName = (TString)aChain->GetName(); bool plotWithSclBias = glob->GetOptB("plotWithScaledBias"); @@ -2961,9 +2991,10 @@ void ANNZ::doMetricPlots(TChain * aChain, vector * selctMLMv) { TString regBestNameErr = getTagBestMLMname(baseTag_e); TString regBestNameWgt = getTagBestMLMname(baseTag_w); - int nPdfTypes(3); + int nPdfTypes(addMaxPDF ? 3 : 2); vector tagNameV(nPdfTypes); - tagNameV[0] = glob->GetOptC("baseTag_MLM_avg"); tagNameV[1] = glob->GetOptC("baseTag_PDF_avg"); tagNameV[2] = glob->GetOptC("baseTag_PDF_max"); + tagNameV[0] = glob->GetOptC("baseTag_MLM_avg"); tagNameV[1] = glob->GetOptC("baseTag_PDF_avg"); + if(nPdfTypes > 2) tagNameV[2] = glob->GetOptC("baseTag_PDF_max"); vector < TString > pdfTagWgtV(nPDFs), pdfTagErrV(nPDFs); for(int nPDFnow=0; nPDFnow * selctMLMv) { TString branchName = branchNameV[nBranchNow]; // search for format like "ANNZ_0": - if(branchName.BeginsWith(baseName_ANNZ)) { - TString nMLMstr(branchName); nMLMstr.ReplaceAll(baseName_ANNZ,""); + if(branchName.BeginsWith(basePrefix)) { + TString nMLMstr(branchName); nMLMstr.ReplaceAll(basePrefix,""); if(nMLMstr.IsDigit()) { // if a selection vector is inputed, only accept MLMs which are listed in it - if(selctMLMv) { - if(find(selctMLMv->begin(),selctMLMv->end(), branchName) == selctMLMv->end()) continue; + if(addPlotVarV) { + if(find(addPlotVarV->begin(),addPlotVarV->end(), branchName) == addPlotVarV->end()) continue; } int nMLMnow = utils->strToInt(nMLMstr); @@ -3155,7 +3186,7 @@ void ANNZ::doMetricPlots(TChain * aChain, vector * selctMLMv) { hisName = (TString)"his1_TMP"; // get the cuts (assume here that this function is used for "_valid" only, otherwise, would need to add a flag...) - TString treeCuts = (TString)(getTrainTestCuts("_comn",0)+getTrainTestCuts(getTagName(0)+"_valid",0)); + TString treeCuts = (TString)getTrainTestCuts((TString)"_comn"+";"+getTagName(0)+"_valid",0,0,0,var); for(int nTypeBinNow=0; nTypeBinNow>"+hisName; @@ -3183,7 +3214,7 @@ void ANNZ::doMetricPlots(TChain * aChain, vector * selctMLMv) { // double maxVal = aChain->GetMaximum(plotVars[nTypeBinNow]); // VERIFY(LOCATION,(TString)"Something is horribly wrong ?!?! ",(maxVal > minVal)); - double binW = (maxVal - minVal)/double(nBinsZ); + double binW = (maxVal - minVal)/double(nBinsZ); aLOG(Log::DEBUG) <GetOptB("doRegression")) doEvalReg(); - else if(glob->GetOptB("doClassification")) doEvalCls(); + // ----------------------------------------------------------------------------------------------------------- + // evaluation for regression + // ----------------------------------------------------------------------------------------------------------- + if(glob->GetOptB("doRegression")) { + // ----------------------------------------------------------------------------------------------------------- + // perform the evaluation + // ----------------------------------------------------------------------------------------------------------- + doEvalReg(); + + // ----------------------------------------------------------------------------------------------------------- + // run the plotting function if the regression target is included in the evaluated sample + // ----------------------------------------------------------------------------------------------------------- + if(glob->GetOptB("doPlots")) { + aLOG(Log::DEBUG) <GetOptC("basePrefix"); + TString treeName = glob->GetOptC("treeName"); + TString typeANNZ = glob->GetOptC("_typeANNZ"); + TString outDirNameFull = glob->GetOptC("outDirNameFull"); + TString zTrg = glob->GetOptC("zTrg"); + TString inTreeName = (TString)treeName+typeANNZ; + TString inFileName = (TString)outDirNameFull+inTreeName+"*.root"; + + // create a chain from the output of the evaluation + TChain * aChain = new TChain(inTreeName,inTreeName); aChain->SetDirectory(0); aChain->Add(inFileName); + aLOG(Log::DEBUG) <GetEntries()<<")" + <<" from "< branchNameV, addPlotVarV; + utils->getTreeBranchNames(aChain,branchNameV); + + // check if the chain has the regression target, and collect any other added variables + bool hasZtrg(false); + for(int nBranchNow=0; nBranchNow<(int)branchNameV.size(); nBranchNow++) { + TString addVarName = branchNameV[nBranchNow]; + + if (addVarName == zTrg) { hasZtrg = true; } + else if(!addVarName.BeginsWith(basePrefix)) { addPlotVarV.push_back(addVarName); } + } + + // generate the plots if the regression target was detected + if(hasZtrg) { + aLOG(Log::INFO) <GetOptB("doClassification")) { + doEvalCls(); + } return; } @@ -82,7 +139,7 @@ void ANNZ::makeTreeRegClsAllMLM() { int nMLMs = glob->GetOptI("nMLMs"); TString postTrainDirName = glob->GetOptC("postTrainDirNameFull"); - bool separateTestValid = glob->GetOptB("separateTestValid"); + // bool separateTestValid = glob->GetOptB("separateTestValid"); // deprecated int maxTreesMerge = glob->GetOptI("maxTreesMerge"); bool needBinClsErr = glob->GetOptB("needBinClsErr"); @@ -90,7 +147,7 @@ void ANNZ::makeTreeRegClsAllMLM() { // get the number of entries in the input trees to compare to the generated result-trees // ----------------------------------------------------------------------------------------------------------- for(int nTrainValidNow=0; nTrainValidNow<2; nTrainValidNow++) { - if(separateTestValid && (nTrainValidNow==0)) continue; + // if(separateTestValid && (nTrainValidNow==0)) continue; // deprecated treeNamePostfix = (TString)( (nTrainValidNow == 0) ? "_train" : "_valid" ); inTreeName = (TString)glob->GetOptC("treeName")+treeNamePostfix; @@ -123,7 +180,7 @@ void ANNZ::makeTreeRegClsAllMLM() { // ----------------------------------------------------------------------------------------------------------- bool foundGoodTrees = true; for(int nTrainValidNow=0; nTrainValidNow<2; nTrainValidNow++) { - if(separateTestValid && (nTrainValidNow==0)) continue; + // if(separateTestValid && (nTrainValidNow==0)) continue; // deprecated treeNamePostfix = (TString)( (nTrainValidNow == 0) ? "_train" : "_valid" ); inTreeName = (TString)glob->GetOptC("treeName")+treeNamePostfix; @@ -182,15 +239,15 @@ void ANNZ::makeTreeRegClsAllMLM() { TString weights_train = optMap->GetOptC("userWeights_train"); TString weights_valid = optMap->GetOptC("userWeights_valid"); TString misMatchs(""); - if(cuts_train != (TString)(getTrainTestCuts("_comn",nMLMnow)+getTrainTestCuts(MLMname+"_train",nMLMnow))) misMatchs += "cuts(train) "; - if(cuts_valid != (TString)(getTrainTestCuts("_comn",nMLMnow)+getTrainTestCuts(MLMname+"_valid",nMLMnow))) misMatchs += "cuts(valid) "; - if(weights_train != userWgtsM[MLMname+"_train"] ) misMatchs += "weights(train) "; - if(weights_valid != userWgtsM[MLMname+"_valid"] ) misMatchs += "weights(valid) "; + if(cuts_train != (TString)(getTrainTestCuts((TString)"_comn"+";"+MLMname+"_train",nMLMnow))) misMatchs += "cuts(train) "; + if(cuts_valid != (TString)(getTrainTestCuts((TString)"_comn"+";"+MLMname+"_valid",nMLMnow))) misMatchs += "cuts(valid) "; + if(weights_train != userWgtsM[MLMname+"_train"] ) misMatchs += "weights(train) "; + if(weights_valid != userWgtsM[MLMname+"_valid"] ) misMatchs += "weights(valid) "; - aLOG(Log::DEBUG_2) <<"cuts_train "<GetOptC("treeName")+treeNamePostfix; @@ -298,7 +355,7 @@ void ANNZ::makeTreeRegClsAllMLM() { saveFileName = getKeyWord("","postTrain","configSaveFileName"); //saveFileName = (TString)glob->GetOptC("postTrainDirNameFull")+"saveTime.txt"; for(int nTreeInNow=0; nTreeInNow<3; nTreeInNow++) { - if(separateTestValid && (nTreeInNow==0)) continue; + // if(separateTestValid && (nTreeInNow==0)) continue; // deprecated inTreeName = ""; if (nTreeInNow == 0) inTreeName = (TString)glob->GetOptC("treeName")+"_train"; @@ -438,9 +495,9 @@ void ANNZ::makeTreeRegClsOneMLM(int nMLMnow) { int maxNobj = 0; // maxNobj = glob->GetOptI("maxNobj"); // only allow limits in case of debugging !! TString indexName = glob->GetOptC("indexName"); TString sigBckTypeName = glob->GetOptC("sigBckTypeName"); - TString testValidType = glob->GetOptC("testValidType"); + // TString testValidType = glob->GetOptC("testValidType"); // deprecated UInt_t seed = glob->GetOptI("initSeedRnd"); if(seed > 0) seed += 58606; - bool separateTestValid = glob->GetOptB("separateTestValid"); + // bool separateTestValid = glob->GetOptB("separateTestValid"); // deprecated bool isBinCls = glob->GetOptB("doBinnedCls"); bool isCls = glob->GetOptB("doClassification") || isBinCls; bool needBinClsErr = glob->GetOptB("needBinClsErr"); @@ -491,7 +548,7 @@ void ANNZ::makeTreeRegClsOneMLM(int nMLMnow) { double separation(-1); for(int nTrainValidNow=0; nTrainValidNow<2; nTrainValidNow++) { - if(separateTestValid && (nTrainValidNow==0)) continue; + // if(separateTestValid && (nTrainValidNow==0)) continue; // deprecated TString treeNamePostfix = (TString)( (nTrainValidNow == 0) ? "_train" : "_valid" ); TString baseCutsName = (TString)"_comn"+";"+MLMname+treeNamePostfix; @@ -525,8 +582,8 @@ void ANNZ::makeTreeRegClsOneMLM(int nMLMnow) { setMethodCuts(varKNN,nMLMnow); - TCut cutsNow(varKNN->getTreeCuts("_comn") + varKNN->getTreeCuts(MLMname+treeNamePostfix)); - TString wgtReg(userWgtsM[MLMname+treeNamePostfix]); + TCut cutsNow = varKNN->getTreeCuts("_comn") + varKNN->getTreeCuts(MLMname+treeNamePostfix); + TString wgtReg = getRegularStrForm(userWgtsM[MLMname+treeNamePostfix],varKNN); setupKdTreeKNN(aChainKnn[0],knnErrOutFile,knnErrFactory,knnErrModule,trgIndexV,nMLMnow,cutsNow,wgtReg); } @@ -548,10 +605,11 @@ void ANNZ::makeTreeRegClsOneMLM(int nMLMnow) { int nEntriesChain = aChain->GetEntries(); aLOG(Log::INFO) <NewVarI(testValidType); + // var_1->NewVarI(testValidType); // deprecated // create MLM-weight formulae for the input variables - var_0->NewForm(MLMname_w,userWgtsM[MLMname+treeNamePostfix]); + TString wgtStr = getRegularStrForm(userWgtsM[MLMname+treeNamePostfix],var_0); + var_0->NewForm(MLMname_w,wgtStr); // formulae for inpput-variable errors, to be used by getRegClsErrINP() if(isErrINP) { @@ -623,8 +681,8 @@ void ANNZ::makeTreeRegClsOneMLM(int nMLMnow) { } if(skipObj) continue; // only relevant for classification - var_1->SetVarI(MLMname_i, var_0->GetVarI(indexName) ); - var_1->SetVarI(testValidType,var_0->GetVarI(testValidType)); + var_1->SetVarI(MLMname_i,var_0->GetVarI(indexName)); + // var_1->SetVarI(testValidType,var_0->GetVarI(testValidType)); // deprecated // fill the output tree if(isCls) { @@ -654,7 +712,7 @@ void ANNZ::makeTreeRegClsOneMLM(int nMLMnow) { var_1->SetVarF(MLMname_eN,regErrV[nMLMnow][0]); var_1->SetVarF(MLMname_e,regErrV[nMLMnow][1]); var_1->SetVarF(MLMname_eP,regErrV[nMLMnow][2]); } - treeOut->Fill(); + var_1->fillTree(); // to increment the loop-counter, at least one method should have passed the cuts mayWriteObjects = true; @@ -695,9 +753,9 @@ void ANNZ::makeTreeRegClsOneMLM(int nMLMnow) { TString hisName = (TString)"sepHis"+"_all"; TString drawExprs = (TString)MLMname+">>"+hisName; - TString trainCut = (TString)var_0->getTreeCuts("_train"); TString cutExprs = (TString)"("+MLMname_w+" > 0)"; - if(trainCut != "") cutExprs += (TString)" && ("+trainCut+")"; + // TString trainCut = (TString)var_0->getTreeCuts("_train"); // deprecated + // if(trainCut != "") cutExprs += (TString)" && ("+trainCut+")"; // deprecated int nEvtPass = aChainOut->Draw(drawExprs,cutExprs); if(nEvtPass > 0) his_all = (TH1F*)gDirectory->Get(hisName); his_all->BufferEmpty(); @@ -712,9 +770,9 @@ void ANNZ::makeTreeRegClsOneMLM(int nMLMnow) { TH1 * his1_sb = new TH1F(hisName,hisName,nBins,binL,binH); TString drawExprs = (TString)MLMname+">>+"+hisName; - TString trainCut = (TString)var_0->getTreeCuts("_train"); TString cutExprs = (TString)"("+MLMname_w+" > 0) && ("+sigBckTypeName+sigBckCut+")"; - if(trainCut != "") cutExprs += (TString)" && ("+trainCut+")"; + // TString trainCut = (TString)var_0->getTreeCuts("_train"); // deprecated + // if(trainCut != "") cutExprs += (TString)" && ("+trainCut+")"; // deprecated int nEvtPass = aChainOut->Draw(drawExprs,cutExprs); @@ -760,8 +818,8 @@ void ANNZ::makeTreeRegClsOneMLM(int nMLMnow) { OptMaps * optMap = new OptMaps("localOptMap"); TString saveName = ""; - TString cut_train = (TString)(getTrainTestCuts("_comn",nMLMnow)+getTrainTestCuts(MLMname+"_train",nMLMnow)); - TString cut_valid = (TString)(getTrainTestCuts("_comn",nMLMnow)+getTrainTestCuts(MLMname+"_valid",nMLMnow)); + TString cut_train = (TString)(getTrainTestCuts((TString)"_comn"+";"+MLMname+"_train",nMLMnow)); + TString cut_valid = (TString)(getTrainTestCuts((TString)"_comn"+";"+MLMname+"_valid",nMLMnow)); vector optNames; saveName = "userCuts_train"; optNames.push_back(saveName); optMap->NewOptC(saveName, cut_train); @@ -906,7 +964,8 @@ void ANNZ::deriveHisClsPrb(int nMLMnow) { var_0->setTreeCuts("_bck",bckCuts); // create MLM-weight formulae for the input variables - var_0->NewForm(MLMname_w,userWgtsM[MLMname+"_train"]); + TString wgtStr = getRegularStrForm(userWgtsM[MLMname+"_train"],var_0); + var_0->NewForm(MLMname_w,wgtStr); var_0->connectTreeBranchesForm(aChain,&readerInptV); @@ -1127,8 +1186,7 @@ TChain * ANNZ::mergeTreeFriends(TChain * aChain, TChain * aChainFriend, vectorhasFailedTreeCuts("aCut")) continue; } var_1->copyVarData(var_0,&varTypeNameV); - - mergedTree->Fill(); + var_1->fillTree(); var_0->IncCntr("nObj"); mayWriteObjects = true; } diff --git a/src/ANNZ_train.cpp b/src/ANNZ_train.cpp index d1ce1cd..4d023ac 100644 --- a/src/ANNZ_train.cpp +++ b/src/ANNZ_train.cpp @@ -88,11 +88,7 @@ void ANNZ::Train_singleCls() { aLOG(Log::DEBUG) <GetOptC("testValidType_train")) if has separate sub-sample for training - // convergence and for testing, as this condition is true for all training sample and for half of the testing sample // ----------------------------------------------------------------------------------------------------------- VarMaps * var = new VarMaps(glob,utils,"mainTrainVar"); @@ -100,20 +96,22 @@ void ANNZ::Train_singleCls() { setMethodCuts(var,nMLMnow); - // replace the training/validation cut definition in var, os that cutM["_valid"] will get the testing objects - if(glob->GetOptB("separateTestValid")) { - int nFoundCuts = var->replaceTreeCut(glob->GetOptC("testValidType_valid"),glob->GetOptC("testValidType_train")); - VERIFY(LOCATION,(TString)"Did not find cut \""+glob->GetOptC("testValidType_valid")+"\". Something is horribly wrong... ?!?",(nFoundCuts != 0)); - } + // deprecated + // // replace the training/validation cut definition in var, so that cutM["_valid"] will get the testing objects + // if(glob->GetOptB("separateTestValid")) { + // int nFoundCuts = var->replaceTreeCut(glob->GetOptC("testValidType_valid"),glob->GetOptC("testValidType_train")); + // VERIFY(LOCATION,(TString)"Did not find cut \""+glob->GetOptC("testValidType_valid")+"\". Something is horribly wrong... ?!?",(nFoundCuts != 0)); + // } - TCut cutTrain = var->getTreeCuts(MLMname+"_train"); - TCut cutValid = var->getTreeCuts(MLMname+"_valid"); + TString wgtTrain = getRegularStrForm(userWgtsM[MLMname+"_train"],var); - cutM["_comn"] = var->getTreeCuts("_comn"); - cutM["_sig"] = var->getTreeCuts("_sig"); - cutM["_bck"] = var->getTreeCuts("_bck"); - cutM["_train"] = var->getTreeCuts("_train") + cutTrain; - cutM["_valid"] = var->getTreeCuts("_valid") + cutValid; + cutM["_comn"] = var->getTreeCuts("_comn"); + cutM["_sig"] = var->getTreeCuts("_sig"); + cutM["_bck"] = var->getTreeCuts("_bck"); + cutM["_train"] = var->getTreeCuts(MLMname+"_train"); + cutM["_valid"] = var->getTreeCuts(MLMname+"_valid"); + // cutM["_train"] = var->getTreeCuts("_train") + var->getTreeCuts(MLMname+"_train"); // deprecated + // cutM["_valid"] = var->getTreeCuts("_valid") + var->getTreeCuts(MLMname+"_valid"); // deprecated DELNULL(var); @@ -139,8 +137,8 @@ void ANNZ::Train_singleCls() { factory->AddBackgroundTree(chainM["_valid_bck"],clsWeight,TMVA::Types::kTesting ); // set the sample-weights - factory->SetWeightExpression(userWgtsM[MLMname+"_train"],"Signal"); - factory->SetWeightExpression(userWgtsM[MLMname+"_train"],"Background"); + factory->SetWeightExpression(wgtTrain,"Signal"); + factory->SetWeightExpression(wgtTrain,"Background"); TString trainValidStr = (TString)sigBckStr+":SplitMode=Random:"+factoryNorm; @@ -148,12 +146,12 @@ void ANNZ::Train_singleCls() { aLOG(Log::INFO) <GetOptC("testValidType_train")) if has separate sub-sample for training - // convergence and for testing, as this condition is true for all training sample and for half of the testing sample // ----------------------------------------------------------------------------------------------------------- VarMaps * var = new VarMaps(glob,utils,"mainTrainVar"); @@ -294,23 +288,27 @@ void ANNZ::Train_singleReg() { setMethodCuts(var,nMLMnow); - // replace the training/validation cut definition in var, os that cutM["_valid"] will get the testing objects - if(glob->GetOptB("separateTestValid")) { - int nFoundCuts = var->replaceTreeCut(glob->GetOptC("testValidType_valid"),glob->GetOptC("testValidType_train")); - VERIFY(LOCATION,(TString)"Did not find cut \""+glob->GetOptC("testValidType_valid")+"\". Something is horribly wrong... ?!?",(nFoundCuts != 0)); - } + // deprecated + // // replace the training/validation cut definition in var, so that cutM["_valid"] will get the testing objects + // if(glob->GetOptB("separateTestValid")) { + // int nFoundCuts = var->replaceTreeCut(glob->GetOptC("testValidType_valid"),glob->GetOptC("testValidType_train")); + // VERIFY(LOCATION,(TString)"Did not find cut \""+glob->GetOptC("testValidType_valid")+"\". Something is horribly wrong... ?!?",(nFoundCuts != 0)); + // } cutM["_comn"] = var->getTreeCuts("_comn"); - cutM["_train"] = var->getTreeCuts("_train") + var->getTreeCuts(MLMname+"_train"); - cutM["_valid"] = var->getTreeCuts("_valid") + var->getTreeCuts(MLMname+"_valid"); + cutM["_train"] = var->getTreeCuts(MLMname+"_train"); + cutM["_valid"] = var->getTreeCuts(MLMname+"_valid"); + // cutM["_train"] = var->getTreeCuts("_train") + var->getTreeCuts(MLMname+"_train"); // deprecated + // cutM["_valid"] = var->getTreeCuts("_valid") + var->getTreeCuts(MLMname+"_valid"); // deprecated - TString cutTrain = ((TString)var->getTreeCuts(MLMname+"_train")).ReplaceAll(" ",""); - TString cutValid = ((TString)var->getTreeCuts(MLMname+"_valid")).ReplaceAll(" ",""); + TString wgtTrain = getRegularStrForm(userWgtsM[MLMname+"_train"],var); DELNULL(var); // if the cuts for training and validation are different, create new trees // for each of these with the corresponding cuts. + TString cutTrain((TString)cutM["_train"]); cutTrain.ReplaceAll(" ",""); + TString cutValid((TString)cutM["_valid"]); cutValid.ReplaceAll(" ",""); if(cutTrain != cutValid) { createCutTrainTrees(chainM,cutM,optMap); cutM["_combined"] = ""; @@ -326,7 +324,7 @@ void ANNZ::Train_singleReg() { factory->AddRegressionTree(chainM["_valid_cut"], regWeight, TMVA::Types::kTesting ); // set the sample-weights - factory->SetWeightExpression(userWgtsM[MLMname+"_train"],"Regression"); + factory->SetWeightExpression(wgtTrain,"Regression"); TCanvas * tmpCnvs = new TCanvas("tmpCnvs","tmpCnvs"); int nTrain = chainM["_train_cut"]->Draw(zTrgName,cutM["_combined"]); if(maxNobj > 0 && maxNobj < nTrain) nTrain = maxNobj; @@ -345,8 +343,8 @@ void ANNZ::Train_singleReg() { aLOG(Log::INFO) <GetOptC("testValidType_train")) if has separate sub-sample for training - // convergence and for testing, as this condition is true for all training sample and for half of the testing sample // ----------------------------------------------------------------------------------------------------------- VarMaps * var = new VarMaps(glob,utils,"mainTrainVar"); @@ -561,18 +555,20 @@ void ANNZ::Train_binnedCls() { setMethodCuts(var,nMLMnow); - // replace the training/validation cut definition in var, os that cutM["_valid"] will get the testing objects - if(glob->GetOptB("separateTestValid")) { - int nFoundCuts = var->replaceTreeCut(glob->GetOptC("testValidType_valid"),glob->GetOptC("testValidType_train")); - VERIFY(LOCATION,(TString)"Did not find cut \""+glob->GetOptC("testValidType_valid")+"\". Something is horribly wrong... ?!?",(nFoundCuts != 0)); - } + // deprecated + // // replace the training/validation cut definition in var, so that cutM["_valid"] will get the testing objects + // if(glob->GetOptB("separateTestValid")) { + // int nFoundCuts = var->replaceTreeCut(glob->GetOptC("testValidType_valid"),glob->GetOptC("testValidType_train")); + // VERIFY(LOCATION,(TString)"Did not find cut \""+glob->GetOptC("testValidType_valid")+"\". Something is horribly wrong... ?!?",(nFoundCuts != 0)); + // } - cutTrain = var->getTreeCuts(MLMname+"_train"); - cutValid = var->getTreeCuts(MLMname+"_valid"); + wgtTrain = getRegularStrForm(userWgtsM[MLMname+"_train"],var); cutM["_comn"] = var->getTreeCuts("_comn"); - cutM["_train"] = var->getTreeCuts("_train") + cutTrain; - cutM["_valid"] = var->getTreeCuts("_valid") + cutValid; + cutM["_train"] = var->getTreeCuts(MLMname+"_train"); + cutM["_valid"] = var->getTreeCuts(MLMname+"_valid"); + // cutM["_train"] = var->getTreeCuts("_train") + var->getTreeCuts(MLMname+"_train"); // deprecated + // cutM["_valid"] = var->getTreeCuts("_valid") + var->getTreeCuts(MLMname+"_valid"); // deprecated DELNULL(var); @@ -603,9 +599,9 @@ void ANNZ::Train_binnedCls() { // log-in the signal/background samples in the factory. include user0defined cuts if needed. // for each sample, set the sum of weights to 1 after all cuts // ----------------------------------------------------------------------------------------------------------- - double clsWeight(1); // weight for the entire sample - TCut bckShiftCut(""), bckSubsetCut(""), fullCut(""); // optional user-defined cuts - TString clsWgtExp(userWgtsM[MLMname+"_train"]), fullWgtCut(""); // object-weight expressions + double clsWeight(1); // weight for the entire sample + TCut bckShiftCut(""), bckSubsetCut(""), fullCut(""); // optional user-defined cuts + TString fullWgtCut(""); // object-weight expressions // ----------------------------------------------------------------------------------------------------------- // all background objects treated as one sample @@ -660,12 +656,12 @@ void ANNZ::Train_binnedCls() { } fullCut = bckShiftCut + bckSubsetCut; - fullWgtCut = utils->cleanWeightExpr((TString)"("+(TString)fullCut+")*("+clsWgtExp+")"); + fullWgtCut = utils->cleanWeightExpr((TString)"("+(TString)fullCut+")*("+wgtTrain+")"); - clsWeight = chainM["_train_sig"]->Draw(zTrgName,clsWgtExp); clsWeight = (clsWeight > 0) ? 1/clsWeight : 0; + clsWeight = chainM["_train_sig"]->Draw(zTrgName,wgtTrain); clsWeight = (clsWeight > 0) ? 1/clsWeight : 0; factory->AddTree(chainM["_train_sig"],"Signal",clsWeight,"",TMVA::Types::kTraining); - clsWeight = chainM["_valid_sig"]->Draw(zTrgName,clsWgtExp); clsWeight = (clsWeight > 0) ? 1/clsWeight : 0; + clsWeight = chainM["_valid_sig"]->Draw(zTrgName,wgtTrain); clsWeight = (clsWeight > 0) ? 1/clsWeight : 0; factory->AddTree(chainM["_valid_sig"],"Signal",clsWeight,"",TMVA::Types::kTesting ); clsWeight = chainM["_train_bck"]->Draw(zTrgName,fullWgtCut); clsWeight = (clsWeight > 0) ? 1/clsWeight : 0; @@ -674,8 +670,8 @@ void ANNZ::Train_binnedCls() { clsWeight = chainM["_valid_bck"]->Draw(zTrgName,fullWgtCut); clsWeight = (clsWeight > 0) ? 1/clsWeight : 0; factory->AddTree(chainM["_valid_bck"],"Background",clsWeight,fullCut,TMVA::Types::kTesting ); - factory->SetWeightExpression(clsWgtExp,"Signal"); - factory->SetWeightExpression(clsWgtExp,"Background"); + factory->SetWeightExpression(wgtTrain,"Signal"); + factory->SetWeightExpression(wgtTrain,"Background"); } // ----------------------------------------------------------------------------------------------------------- @@ -687,13 +683,13 @@ void ANNZ::Train_binnedCls() { TCanvas * tmpCnvs = new TCanvas("tmpCnvs","tmpCnvs"); - clsWeight = chainM["_train_sig"]->Draw(zTrgName,clsWgtExp); clsWeight = (clsWeight > 0) ? 1/clsWeight : 0; + clsWeight = chainM["_train_sig"]->Draw(zTrgName,wgtTrain); clsWeight = (clsWeight > 0) ? 1/clsWeight : 0; factory->AddTree(chainM["_train_sig"],"Signal",clsWeight,"",TMVA::Types::kTraining); - clsWeight = chainM["_valid_sig"]->Draw(zTrgName,clsWgtExp); clsWeight = (clsWeight > 0) ? 1/clsWeight : 0; + clsWeight = chainM["_valid_sig"]->Draw(zTrgName,wgtTrain); clsWeight = (clsWeight > 0) ? 1/clsWeight : 0; factory->AddTree(chainM["_valid_sig"],"Signal",clsWeight,"",TMVA::Types::kTesting ); - factory->SetWeightExpression(clsWgtExp,"Signal"); + factory->SetWeightExpression(wgtTrain,"Signal"); for(int nClsBinNow=0; nClsBinNow<(int)zBinCls_binE.size()-1; nClsBinNow++) { if(nClsBinNow == nMLMnow) continue; @@ -701,7 +697,7 @@ void ANNZ::Train_binnedCls() { TString bckName = TString::Format("Background_%d",nClsBinNow); TCut bckCut = (TCut)(TString::Format((TString)"("+zTrgName+" > %f && "+zTrgName+" <= %f)",zBinCls_binE[nClsBinNow],zBinCls_binE[nClsBinNow+1])); - fullWgtCut = utils->cleanWeightExpr((TString)"("+(TString)bckCut+")*("+clsWgtExp+")"); + fullWgtCut = utils->cleanWeightExpr((TString)"("+(TString)bckCut+")*("+wgtTrain+")"); // check that the unweighted number of objects is sufficient clsWeight = chainM["_train_bck"]->Draw(zTrgName,bckCut); @@ -713,7 +709,7 @@ void ANNZ::Train_binnedCls() { clsWeight = chainM["_valid_bck"]->Draw(zTrgName,fullWgtCut); clsWeight = (clsWeight > 0) ? 1/clsWeight : 0; factory->AddTree(chainM["_valid_bck"],bckName,clsWeight,bckCut,TMVA::Types::kTesting ); - factory->SetWeightExpression(clsWgtExp,bckName); + factory->SetWeightExpression(wgtTrain,bckName); } DELNULL(tmpCnvs); @@ -748,8 +744,8 @@ void ANNZ::Train_binnedCls() { aLOG(Log::INFO) <NewOptC(saveName, inputVarErrors); saveName = "userWeights_train"; optNames.push_back(saveName); optMap->NewOptC(saveName, userWgtsM[MLMname+"_train"]); saveName = "userWeights_valid"; optNames.push_back(saveName); optMap->NewOptC(saveName, userWgtsM[MLMname+"_valid"]); - saveName = "userCuts_train"; optNames.push_back(saveName); optMap->NewOptC(saveName, (TString)cutTrain); - saveName = "userCuts_valid"; optNames.push_back(saveName); optMap->NewOptC(saveName, (TString)cutValid); + saveName = "userCuts_train"; optNames.push_back(saveName); optMap->NewOptC(saveName, (TString)cutM["_train"]); + saveName = "userCuts_valid"; optNames.push_back(saveName); optMap->NewOptC(saveName, (TString)cutM["_valid"]); saveName = "zTrg"; optNames.push_back(saveName); optMap->NewOptC(saveName, glob->GetOptC(saveName)); saveName = "minValZ"; optNames.push_back(saveName); optMap->NewOptF(saveName, glob->GetOptF(saveName)); saveName = "maxValZ"; optNames.push_back(saveName); optMap->NewOptF(saveName, glob->GetOptF(saveName)); diff --git a/src/ANNZ_utils.cpp b/src/ANNZ_utils.cpp index eafff79..4041c47 100644 --- a/src/ANNZ_utils.cpp +++ b/src/ANNZ_utils.cpp @@ -49,9 +49,10 @@ void ANNZ::Init() { glob->NewOptC("aTimeName","ANNZ_aTime"); // internal flag name for time operations glob->NewOptB("hasTruth",!glob->GetOptB("doEval")); // internal flag for adding a cut on the range of values of zTrg - // internal parameters - the condition for defining the test/valid sub-smaples from the _valid tree in case of nSplit==3 - glob->NewOptC("testValidType_train", glob->GetOptC("testValidType")+"<0.5"); // where [ testValidType == 0 ] for train - glob->NewOptC("testValidType_valid", glob->GetOptC("testValidType")+">0.5"); // where [ testValidType == 1 ] for valid + // deprecated + // // internal parameters - the condition for defining the test/valid sub-smaples from the _valid tree in case of nSplit==3 + // glob->NewOptC("testValidType_train", glob->GetOptC("testValidType")+"<0.5"); // where [ testValidType == 0 ] for train + // glob->NewOptC("testValidType_valid", glob->GetOptC("testValidType")+">0.5"); // where [ testValidType == 1 ] for valid if(glob->GetOptC("userWeights_train") == "") glob->SetOptC("userWeights_train","1"); // set proper init value for weights if(glob->GetOptC("userWeights_valid") == "") glob->SetOptC("userWeights_valid","1"); // set proper init value for weights @@ -87,7 +88,7 @@ void ANNZ::Init() { if(glob->GetOptB("doTrain") && !glob->GetOptB("isBatch")) { aLOG(Log::WARNING) <isDirFile(glob->GetOptC("inputTreeDirName")))); vector optNames; - optNames.push_back("nSplit"); optNames.push_back("treeName"); optNames.push_back("indexName"); - optNames.push_back("splitName"); optNames.push_back("testValidType"); optNames.push_back("useWgtKNN"); + optNames.push_back("nSplit"); optNames.push_back("treeName"); optNames.push_back("indexName"); optNames.push_back("useWgtKNN"); + // optNames.push_back("splitName"); optNames.push_back("testValidType"); // deprecated if(glob->GetOptB("storeOrigFileName")) optNames.push_back("origFileName"); @@ -108,12 +109,15 @@ void ANNZ::Init() { optNames.clear(); int nSplit(glob->GetOptI("nSplit")); - VERIFY(LOCATION,(TString)"Cant run with [\"nSplit\" = "+utils->intToStr(nSplit)+"]. Must set to either [2] or [3] when " - +"generating the input trees !",(nSplit == 2 || nSplit == 3)); + if(!glob->GetOptB("doEval")) { + VERIFY(LOCATION,(TString)"Cant run with [\"nSplit\" = "+utils->intToStr(nSplit)+"] ..." + +" Must set to 2 when generating the input trees !",(nSplit == 2)); + } - // flag for splitting the sample into two (train,valid) or three (train,test,valid) - // (must come after reading-in userOptsFile_genInputTrees, which may override nSplit) - glob->NewOptB("separateTestValid",(nSplit == 3)); // internal flag + // deprecated + // // flag for splitting the sample into two (train,valid) or three (train,test,valid) + // // (must come after reading-in userOptsFile_genInputTrees, which may override nSplit) + // glob->NewOptB("separateTestValid",(nSplit == 3)); // internal flag // ----------------------------------------------------------------------------------------------------------- // single regression/classification overrides @@ -295,6 +299,12 @@ void ANNZ::Init() { setInfoBinsZ(); // validate correct setting for the primary condition for optimization, and set the corresponding title + if(glob->GetOptC("optimCondReg") == "fracSig68") { + aLOG(Log::WARNING) <SetOptC("optimCondReg","sig68"); + } if (glob->GetOptC("optimCondReg") == "sig68") glob->NewOptC("optimCondRegtitle", "#sigma_{68}"); else if(glob->GetOptC("optimCondReg") == "bias") glob->NewOptC("optimCondRegtitle", "Bias"); else if(glob->GetOptC("optimCondReg") == "fracSig68") glob->NewOptC("optimCondRegtitle", "f(2,3#sigma_{68})"); @@ -312,6 +322,8 @@ void ANNZ::Init() { glob->SetOptI("nSmearsRnd",max(glob->GetOptI("nSmearsRnd"),50)); // set minimal value if(glob->GetOptI("nSmearsRnd") % 2 == 1) glob->SetOptI("nSmearsRnd",glob->GetOptI("nSmearsRnd")+1); // needs to be an even number + glob->SetOptI("nSmearUnf",max(glob->GetOptI("nSmearUnf"),glob->GetOptI("nSmearsRnd") * 2)); // set minimal value + // a lower acceptance bound to check if too few MLMs are trained or if something went wrong with the optimization procedure // (e.g., not enough trained MLMs have 'good' combinations of scatter, bias and outlier-fraction metrics). if(glob->GetOptI("minAcptMLMsForPDFs") < 5) glob->NewOptI("minAcptMLMsForPDFs",5); @@ -511,7 +523,7 @@ void ANNZ::setTags() { inErrTag.resize(nMLMs,vector(0,"")); // the internal length is set in setNominalParams() for(int nMLMnow=0; nMLMnowGetOptC("baseName_ANNZ")+utils->intToStr(nMLMnow); // format of "ANNZ_0" + TString MLMname = (TString)glob->GetOptC("basePrefix")+utils->intToStr(nMLMnow); // format of "ANNZ_0" mlmTagName [nMLMnow] = (TString)MLMname; mlmTagWeight[nMLMnow] = (TString)MLMname+baseTag_w; @@ -620,7 +632,7 @@ TString ANNZ::getTagBestMLMname(TString type) { // =========================================================================================================== int ANNZ::getTagNow(TString MLMname) { // =================================== - TString MLMnamePost(MLMname); MLMnamePost.ReplaceAll(glob->GetOptC("baseName_ANNZ"),""); + TString MLMnamePost(MLMname); MLMnamePost.ReplaceAll(glob->GetOptC("basePrefix"),""); VERIFY(LOCATION,(TString)"(MLMname = \""+MLMname+"\") has unsupported format",(MLMnamePost.IsDigit())); int nMLMnow = static_cast(MLMnamePost.Atoi()); @@ -733,6 +745,38 @@ TString ANNZ::getKeyWord(TString MLMname, TString sequence, TString key) { return ""; } +// =========================================================================================================== +/** + * @brief - get the weight expression from userWgtsM using either a VarMaps or a TChain + * + * @param var - The associated chain (needed for string variable identification) + * @param aChain - The associated chain (needed for string variable identification) + * @param tagName - The access string from userWgtsM + * + * @return - The requested string + */ +// =========================================================================================================== +TString ANNZ::getRegularStrForm(TString strIn, VarMaps * var, TChain * aChain) { + if(strIn == "") return strIn; + + bool hasVar = (dynamic_cast(var) != NULL); + bool hasChain = (dynamic_cast(aChain) != NULL); + + VERIFY(LOCATION,(TString)"Must provide either a VarMaps or a TChain in order to get a safe weight expression",(hasVar || hasChain)); + + // define a temporary VarMaps so as to regularize the weight expression + if(!hasVar) { + var = new VarMaps(glob,utils,"varRegularStrForm"); + var->connectTreeBranches(aChain); + } + + TString strOut = var->regularizeStringForm(strIn); + + if(!hasVar) DELNULL(var); + + return strOut; +} + // =========================================================================================================== /** * @brief - Load options for different setups. @@ -1051,7 +1095,7 @@ void ANNZ::loadOptsMLM() { void ANNZ::setNominalParams(int nMLMnow, TString inputVariables, TString inputVarErrors) { // ======================================================================================= TString MLMname = getTagName(nMLMnow); - TString baseName_ANNZ = glob->GetOptC("baseName_ANNZ"); + TString basePrefix = glob->GetOptC("basePrefix"); TString baseName_inVarErr = glob->GetOptC("baseName_inVarErr"); aLOG(Log::DEBUG_2) < 0)); for(int nVarNow=0; nVarNowsetTreeCuts(cutName,getTrainTestCuts(cutName,nMLMnow)); @@ -1131,36 +1175,52 @@ void ANNZ::setMethodCuts(VarMaps * var, int nMLMnow, bool verbose) { * e.g., for [split0=2 , split1=3], accept two out of every three objects. */ // =========================================================================================================== -TCut ANNZ::getTrainTestCuts(TString cutType, int nMLMnow, int split0, int split1) { -// ================================================================================ - TString MLMname = getTagName(nMLMnow); - TCut treeCuts = ""; - - if(cutType == "_comn") { - if(glob->GetOptB("hasTruth") && !glob->GetOptB("doClassification")) { - if(glob->GetOptB("useCutsMinMaxZ")) { - treeCuts += glob->GetOptC("zTrg")+TString::Format(" > %f",glob->GetOptF("minValZ")); - treeCuts += glob->GetOptC("zTrg")+TString::Format(" < %f",glob->GetOptF("maxValZ")); +TCut ANNZ::getTrainTestCuts(TString cutType, int nMLMnow, int split0, int split1, VarMaps * var, TChain * aChain) { +// ================================================================================================================ + vector cutTypeV = utils->splitStringByChar(cutType,';'); + int nCuts = (int)cutTypeV.size(); + TString MLMname = getTagName(nMLMnow); + TCut treeCuts = ""; + + VERIFY(LOCATION,(TString)"Trying to use getTrainTestCuts() with no defined cut-type: \""+cutType+"\"",(nCuts > 0 && cutType != "")); + + for(int nCutNow=0; nCutNowGetOptB("hasTruth") && !glob->GetOptB("doClassification")) { + if(glob->GetOptB("useCutsMinMaxZ")) { + treeCuts += (TCut)(glob->GetOptC("zTrg")+TString::Format(" > %f",glob->GetOptF("minValZ"))); + treeCuts += (TCut)(glob->GetOptC("zTrg")+TString::Format(" < %f",glob->GetOptF("maxValZ"))); + } } } - } - else if(cutType == "_train" || cutType == "_valid") { - // if has separate sub-sample for training convergence and for testing -> "testValidType_train" or "testValidType_valid" - if(glob->GetOptB("separateTestValid")) treeCuts += glob->GetOptC((TString)"testValidType"+cutType); - } - else if(cutType == MLMname+"_train" || cutType == MLMname+"_valid" || cutType == "_sig" || cutType == "_bck") { - VERIFY(LOCATION,(TString)"Trying to get from userCutsM element [\""+cutType+"\"] which doesnt exist ..." - ,(userCutsM.find(cutType) != userCutsM.end())); + // deprecated + // else if(cutTypeNow == "_train" || cutTypeNow == "_valid") { + // // if has separate sub-sample for training convergence and for testing -> "testValidType_train" or "testValidType_valid" + // if(glob->GetOptB("separateTestValid")) treeCuts += (TCut)glob->GetOptC((TString)"testValidType"+cutTypeNow); + // } + else if(cutTypeNow == "_train" || cutTypeNow == "_valid") { + aLOG(Log::WARNING) < split0 && split0 > 0)); + treeCuts += (TCut)userCutsM[cutTypeNow]; + } + else if(cutTypeNow == "split") { + VERIFY(LOCATION,(TString)"Must have (0 < split0 < split1)",(split1 > split0 && split0 > 0)); - treeCuts += (TString)glob->GetOptC("splitName")+" % "+TString::Format("%d < %d",split1,split0); + treeCuts += (TCut)((TString)glob->GetOptC("indexName")+" % "+TString::Format("%d < %d",split1,split0)); + } + else VERIFY(LOCATION,(TString)"Trying to use getTrainTestCuts() with unsupported cut-type (\""+cutTypeNow+"\")",false); } - else VERIFY(LOCATION,(TString)"Trying to use getTrainTestCuts() with unsupported cut-type (\""+cutType+"\")",false); + if(var || aChain) treeCuts = (TCut)getRegularStrForm((TString)treeCuts,var,aChain); + + cutTypeV.clear(); return treeCuts; } @@ -1484,9 +1544,11 @@ TString ANNZ::deriveBinClsBins(map < TString,TChain* > & chainM, map < TString,T double maxBinW = glob->GetOptF("binCls_maxBinW"); int nBinDivs = glob->GetOptI("binCls_nBins"); TString MLMname = getTagName(nMLMnow); + TString wgtTrain = getRegularStrForm(userWgtsM[MLMname+"_train"],NULL,chainM["_train"]); Log::LOGtypes binLog(Log::DEBUG_2); + // fill a histogram with the distribution of zTrgName (from the training chain, after cuts) // ----------------------------------------------------------------------------------------------------------- vector < TH1 *> hisQ_orig(2,NULL); @@ -1496,8 +1558,8 @@ TString ANNZ::deriveBinClsBins(map < TString,TChain* > & chainM, map < TString,T TString drawExprs = (TString)zTrgName+">>+"+hisQuantName; TString cutExprs = (TString)+"("+(TString)((TCut)cutM["_comn"]+(TCut)cutM["_train"])+")"; cutExprs.ReplaceAll("()",""); - if(nHisNow == 1 && userWgtsM[MLMname+"_train"] != "") { - cutExprs += (TString)" * ("+userWgtsM[MLMname+"_train"]+")"; + if(nHisNow == 1 && wgtTrain != "") { + cutExprs += (TString)" * ("+wgtTrain+")"; } TCanvas * tmpCnvs = new TCanvas("tmpCnvs","tmpCnvs"); int nEvtPass = chainM["_train"]->Draw(drawExprs,cutExprs); DELNULL(tmpCnvs); @@ -1790,8 +1852,7 @@ void ANNZ::createCutTrainTrees(map < TString,TChain* > & chainM, map < TString,T if(var_0->hasFailedTreeCuts(varCutNameCmn)) continue; var_1->copyVarData(var_0,&varTypeNameV); - - cutTree->Fill(); + var_1->fillTree(); var_0->IncCntr("nObj"); mayWriteObjects = true; } @@ -1837,7 +1898,7 @@ void ANNZ::splitToSigBckTrees(map < TString,TChain* > & chainM, map < TString,TC aLOG(Log::DEBUG) <GetOptI("maxNobj"); - TString splitName = glob->GetOptC("splitName"); + // TString splitName = glob->GetOptC("splitName"); // deprecated TString outDirNameFull = glob->GetOptC("outDirNameFull"); int nObjectsToWrite = glob->GetOptI("nObjectsToWrite"); @@ -1863,8 +1924,10 @@ void ANNZ::splitToSigBckTrees(map < TString,TChain* > & chainM, map < TString,TC // create splitIndex variables for signal and background (countinous counters for each sub-sample that can later be used for cuts) vector < pair > varTypeNameV; - VarMaps * varSig = new VarMaps(glob,utils,varName+"_sig"); varSig->varStruct(var,NULL,NULL,&varTypeNameV); varSig->NewVarI(splitName+"_sigBck"); - VarMaps * varBck = new VarMaps(glob,utils,varName+"_bck"); varBck->varStruct(var); varBck->NewVarI(splitName+"_sigBck"); + VarMaps * varSig = new VarMaps(glob,utils,varName+"_sig"); varSig->varStruct(var,NULL,NULL,&varTypeNameV); + VarMaps * varBck = new VarMaps(glob,utils,varName+"_bck"); varBck->varStruct(var); + + // varSig->NewVarI(splitName+"_sigBck"); varBck->NewVarI(splitName+"_sigBck"); // deprecated // create an output tree with branches according to var TTree * mergedTreeSig = new TTree(inTreeNameSig,inTreeNameSig); mergedTreeSig->SetDirectory(0); outputs->TreeMap[inTreeNameSig] = mergedTreeSig; @@ -1895,19 +1958,19 @@ void ANNZ::splitToSigBckTrees(map < TString,TChain* > & chainM, map < TString,TC if(!var->hasFailedTreeCuts(varCutNameSig)) { varSig->copyVarData(var,&varTypeNameV); - varSig->SetVarI(splitName+"_sigBck",var->GetCntr(nObjNameSig)); + // varSig->SetVarI(splitName+"_sigBck",var->GetCntr(nObjNameSig)); // deprecated var->IncCntr(nObjNameSig); - mergedTreeSig->Fill(); + varSig->fillTree(); } else var->IncCntr((TString)"failedCut: "+var->getFailedCutType()); if(!var->hasFailedTreeCuts(varCutNameBck)) { varBck->copyVarData(var,&varTypeNameV); - varBck->SetVarI(splitName+"_sigBck",var->GetCntr(nObjNameBck)); + // varBck->SetVarI(splitName+"_sigBck",var->GetCntr(nObjNameBck)); // deprecated var->IncCntr(nObjNameBck); - mergedTreeBck->Fill(); + varBck->fillTree(); } else var->IncCntr((TString)"failedCut: "+var->getFailedCutType()); diff --git a/src/CatFormat_asciiToTree.cpp b/src/CatFormat_asciiToTree.cpp index 00a249d..00833a2 100644 --- a/src/CatFormat_asciiToTree.cpp +++ b/src/CatFormat_asciiToTree.cpp @@ -209,7 +209,7 @@ void CatFormat::inputToFullTree(TString inAsciiFiles, TString inAsciiVars, TStri var->SetVarF(weightName,1); // fill the tree with the current variables - treeOut->Fill(); + var->fillTree(); if(inLOG(Log::DEBUG_3)) { int nPrintRow(4), width(14); @@ -254,7 +254,7 @@ void CatFormat::inputToFullTree(TString inAsciiFiles, TString inAsciiVars, TStri var->SetVarF(weightName,1); // fill the tree with the current variables - treeOut->Fill(); + var->fillTree(); if(inLOG(Log::DEBUG_3)) { int nPrintRow(4), width(14); @@ -284,13 +284,13 @@ void CatFormat::inputToFullTree(TString inAsciiFiles, TString inAsciiVars, TStri * * @details - For training and testing/validation the input is divided into two (test,train) or into three (test,train,valid) * sub-samples. - * - The user needs to define the number of sub-samples (e.g., nSplit = 1,2 or 3) and the way to divide the + * - The user needs to define the number of sub-samples (e.g., nSplit = 1 or 2) and the way to divide the * inputs in one of 4 ways (e.g., splitType = "serial", "blocks", "random" or "byInFiles" (default)): * - serial: -> test;train;valid;test;train;valid;test;train;valid;test;train;valid... * - blocks: -> test;test;test;test;train;train;train;train;valid;valid;valid;valid... * - random: -> valid;test;test;train;valid;test;valid;valid;test;train;valid;train... * - separate input files. Must supplay at least one file in splitTypeTrain and one in splitTypeTest. - * In this case, [nSplit = 2]. Optionally can set [nSplit = 3] and provide a list of files in "splitTypeValid" as well. + * In this case, [nSplit = 2]. * * @param inAsciiFiles - semicolon-separated list of input ascii files * @param inAsciiVars - semicolon-separated list of input parameter names, corresponding to columns in the input files @@ -311,8 +311,8 @@ void CatFormat::inputToSplitTree(TString inAsciiFiles, TString inAsciiVars) { int nSplit = glob->GetOptI("nSplit"); TString splitType = glob->GetOptC("splitType"); TString indexName = glob->GetOptC("indexName"); - TString splitName = glob->GetOptC("splitName"); - TString testValidType = glob->GetOptC("testValidType"); + // TString splitName = glob->GetOptC("splitName"); // deprecated + // TString testValidType = glob->GetOptC("testValidType"); // deprecated TString weightName = glob->GetOptC("baseName_wgtKNN"); bool doPlots = glob->GetOptB("doPlots"); TString plotExt = glob->GetOptC("printPlotExtension"); @@ -324,8 +324,8 @@ void CatFormat::inputToSplitTree(TString inAsciiFiles, TString inAsciiVars) { TString inpFiles_bck = glob->GetOptC("inpFiles_bck"); bool addSigBckInp = (inpFiles_sig != "" || inpFiles_bck != ""); - VERIFY(LOCATION,(TString)"found unsupported number of splittings ("+utils->intToStr(nSplit)+"). Allowed values are: 1,2,3" - ,(nSplit >= 1 && nSplit <= 3)); + VERIFY(LOCATION,(TString)"found unsupported number of splittings ("+utils->intToStr(nSplit)+"). Allowed values are: 1 or 2" + ,(nSplit == 1 || nSplit == 2)); // random number generator for the (splitType == "random") option TRandom * rnd = new TRandom3(glob->GetOptI("splitSeed")); @@ -345,18 +345,20 @@ void CatFormat::inputToSplitTree(TString inAsciiFiles, TString inAsciiVars) { // ----------------------------------------------------------------------------------------------------------- // input files - if diffrent for each type // ----------------------------------------------------------------------------------------------------------- - if(splitType == "byInFiles" && nSplit > 1) { + if(splitType == "byInFiles" && nSplit == 2) { vector inFileNameV_now; for(int nSplitNow=0; nSplitNowsplitStringByChar((glob->GetOptC(splitTypeNow)).ReplaceAll(" ",""),';'); nInFiles = (int)inFileNameV_now.size(); - VERIFY(LOCATION,(TString)"found no input files defined in "+splitTypeNow+" (nSplitNow = "+utils->intToStr(nSplitNow)+") ?!?",(nInFiles > 0)); + + VERIFY(LOCATION,(TString)"found no input files defined in \""+splitTypeNow+"\" ... either set \"splitTypeTrain\"" + +" and \"splitTypeTest\" together with [\"splitType\"=byInFiles], or set \"inAsciiFiles\"" + +" together with [\"splitType\"=\"serial\", \"blocks\" or \"random\"]",(nInFiles > 0)); for(int nInFileNow=0; nInFileNowGetOptC("inDirName")+inFileNameV_now[nInFileNow]; @@ -425,8 +427,8 @@ void CatFormat::inputToSplitTree(TString inAsciiFiles, TString inAsciiVars) { // clear tree objects and reserve variables which are not part of the input file VarMaps * var = new VarMaps(glob,utils,"treeVars"); //var->printMapOpt(nPrintRow,width); cout<NewVarI(indexName); var->NewVarI(splitName); - var->NewVarI(testValidType); var->NewVarF(weightName); + // var->NewVarI(testValidType); var->NewVarI(splitName); // deprecated + var->NewVarI(indexName); var->NewVarF(weightName); if(storeOrigFileName) var->NewVarC(origFileName); if(addSigBckInp) var->NewVarI(sigBckInpName); @@ -436,14 +438,13 @@ void CatFormat::inputToSplitTree(TString inAsciiFiles, TString inAsciiVars) { // create the output tree(s) now thah all the variables are defined // ----------------------------------------------------------------------------------------------------------- - int nTrees = min(nSplit,2); - vector treeOut (nTrees); - vector treeNames(nTrees); + vector treeOut (nSplit); + vector treeNames(nSplit); - if(nTrees == 1) { treeNames[0] = (TString)treeName+"_full"; } + if(nSplit == 1) { treeNames[0] = (TString)treeName+"_full"; } else { treeNames[0] = (TString)treeName+"_train"; treeNames[1] = (TString)treeName+"_valid"; } - for(int nTreeNow=0; nTreeNowSetDirectory(0); outputs->TreeMap[treeNameNow] = treeOut[nTreeNow]; @@ -453,13 +454,13 @@ void CatFormat::inputToSplitTree(TString inAsciiFiles, TString inAsciiVars) { // loop control-variables (outside the for() of file-names) // ----------------------------------------------------------------------------------------------------------- - var->NewCntr("nObj",0); var->NewCntr("nLine",0); - var->NewCntr("nTrain",0); var->NewCntr("nTest",0); var->NewCntr("nValid",0); + var->NewCntr("nObj",0); var->NewCntr("nLine",0); var->NewCntr("index",0); + var->NewCntr("nTrain",0); var->NewCntr("nTest",0); // var->NewCntr("nValid",0); // deprecated int nSplitType(0); bool breakLoop(false), mayWriteObjects(false); - if(nSplit > 1) { + if(nSplit == 2) { if (splitType == "serial") nSplitType = 0; else if(splitType == "blocks") nSplitType = 1; else if(splitType == "random") nSplitType = 2; @@ -634,7 +635,7 @@ void CatFormat::inputToSplitTree(TString inAsciiFiles, TString inAsciiVars) { // some histograms of the input branches // ----------------------------------------------------------------------------------------------------------- if(doPlots) { - for(int nTreeNow=0; nTreeNowGetVarType(branchName); - if(branchType == "C" || branchType == "FM") continue; - if(branchName.BeginsWith(glob->GetOptC("baseName_ANNZ"))) continue; + if(branchType == "C" || branchType == "FM") continue; + if(branchName.BeginsWith(glob->GetOptC("basePrefix"))) continue; branchNameV.push_back(branchName); } @@ -689,7 +690,7 @@ void CatFormat::inputToSplitTree(TString inAsciiFiles, TString inAsciiVars) { // cleanup DELNULL(var); DELNULL(rnd); - for(int nTreeNow=0; nTreeNowTreeMap.erase(treeNameNow); @@ -708,8 +709,8 @@ void CatFormat::inputToSplitTree(TString inAsciiFiles, TString inAsciiVars) { saveName = "nSplit"; optNames.push_back(saveName); optMap->NewOptI(saveName, glob->GetOptI(saveName)); saveName = "treeName"; optNames.push_back(saveName); optMap->NewOptC(saveName, glob->GetOptC(saveName)); saveName = "indexName"; optNames.push_back(saveName); optMap->NewOptC(saveName, glob->GetOptC(saveName)); - saveName = "splitName"; optNames.push_back(saveName); optMap->NewOptC(saveName, glob->GetOptC(saveName)); - saveName = "testValidType"; optNames.push_back(saveName); optMap->NewOptC(saveName, glob->GetOptC(saveName)); + // saveName = "splitName"; optNames.push_back(saveName); optMap->NewOptC(saveName, glob->GetOptC(saveName)); // deprecated + // saveName = "testValidType"; optNames.push_back(saveName); optMap->NewOptC(saveName, glob->GetOptC(saveName)); // deprecated saveName = "useWgtKNN"; optNames.push_back(saveName); optMap->NewOptB(saveName, glob->GetOptB(saveName)); if(storeOrigFileName) { saveName = "origFileName"; optNames.push_back(saveName); optMap->NewOptC(saveName, glob->GetOptC(saveName)); } @@ -864,16 +865,17 @@ void CatFormat::setSplitVars(VarMaps * var, TRandom * rnd, map & i int nLine_splitBlocks = (nSplitType == 1) ? intMap["nLine_splitBlocks"] : 0; int inFileSplitIndex = (nSplitType == 3) ? intMap["inFileSplitIndex"] : 0; TString indexName = glob->GetOptC("indexName"); - TString splitName = glob->GetOptC("splitName"); - TString testValidType = glob->GetOptC("testValidType"); + // TString splitName = glob->GetOptC("splitName"); // deprecated + // TString testValidType = glob->GetOptC("testValidType"); // deprecated // ----------------------------------------------------------------------------------------------------------- // no splitting // ----------------------------------------------------------------------------------------------------------- if(nSplit == 1) { intMap["nSplitTree"] = 0; - var->SetVarI(splitName, 0); - var->SetVarI(testValidType, 0); + var->SetVarI(indexName, var->GetCntr("nLine")); + // var->SetVarI(splitName, 0); // deprecated + // var->SetVarI(testValidType, 0); // deprecated } // ----------------------------------------------------------------------------------------------------------- // split into two or three sub-sets within two trees @@ -885,71 +887,71 @@ void CatFormat::setSplitVars(VarMaps * var, TRandom * rnd, map & i if(nSplit == 2) { int resid2(0); - // "serial" - - if (nSplitType == 0) { resid2 = var->GetCntr("nLine") % 2; } - // - "blocks" - - else if(nSplitType == 1) { resid2 = (var->GetCntr("nLine") < nLine_splitBlocks) ? 0 : 1; } - // - "random" - - else if(nSplitType == 2) { resid2 = static_cast(floor(rnd->Rndm() * 2)); } - // - "byInFiles" - - else if(nSplitType == 3) { resid2 = inFileSplitIndex; } + if (nSplitType == 0) { resid2 = var->GetCntr("nLine") % 2; }// "serial" - + else if(nSplitType == 1) { resid2 = (var->GetCntr("nLine") < nLine_splitBlocks) ? 0 : 1; } // "blocks" + else if(nSplitType == 2) { resid2 = static_cast(floor(rnd->Rndm() * 2)); } // "random" + else if(nSplitType == 3) { resid2 = inFileSplitIndex; } // "byInFiles" // ----------------------------------------------------------------------------------------------------------- // now set the variables // ----------------------------------------------------------------------------------------------------------- - if (resid2 == 0) { + if (resid2 == 0) { intMap["nSplitTree"] = 0; - var->SetVarI(splitName, var->GetCntr("nTrain")); - var->IncCntr("nTrain"); + // var->SetVarI(splitName, var->GetCntr("nTrain")); var->IncCntr("nTrain"); // deprecated + var->SetVarI(indexName, var->GetCntr("nTrain")); var->IncCntr("nTrain"); } - else if(resid2 == 1) { + else { intMap["nSplitTree"] = 1; - var->SetVarI(splitName, var->GetCntr("nTest")); - var->IncCntr("nTest"); + // var->SetVarI(splitName, var->GetCntr("nTest")); var->IncCntr("nTest"); // deprecated + var->SetVarI(indexName, var->GetCntr("nTest")); var->IncCntr("nTest"); } - var->SetVarI(testValidType, 0); // this will not be used if there is no three-way splitting - } - else { - int resid3(0); - - // - "serial" - - if (nSplitType == 0) { resid3 = var->GetCntr("nLine") % 3; } - // - "blocks" - - else if(nSplitType == 1) { - if (var->GetCntr("nLine") < nLine_splitBlocks) resid3 = 0; - else if(var->GetCntr("nLine") < 2*nLine_splitBlocks) resid3 = 1; - else resid3 = 2; - } - // - "random" - - else if(nSplitType == 2) { resid3 = static_cast(floor(rnd->Rndm() * 3)); } - // - "byInFiles" - - else if(nSplitType == 3) { resid3 = inFileSplitIndex; } - // ----------------------------------------------------------------------------------------------------------- - // now set the variables - // ----------------------------------------------------------------------------------------------------------- - if (resid3 == 0) { - intMap["nSplitTree"] = 0; - var->SetVarI(splitName,var->GetCntr("nTrain")); - var->SetVarI(testValidType, 0); - var->IncCntr("nTrain"); - } - else if(resid3 == 1) { - intMap["nSplitTree"] = 1; - var->SetVarI(splitName,var->GetCntr("nTest")); - var->SetVarI(testValidType, 0); - var->IncCntr("nTest"); - } - else if(resid3 == 2) { - intMap["nSplitTree"] = 1; - var->SetVarI(splitName,var->GetCntr("nValid")); - var->SetVarI(testValidType, 1); - var->IncCntr("nValid"); - } + // var->SetVarI(testValidType, 0); // deprecated } + // ----------------------------------------------------------------------------------------------------------- + // deprecated ... + // ----------------------------------------------------------------------------------------------------------- + // else { + // int resid3(0); + + // // - "serial" - + // if (nSplitType == 0) { resid3 = var->GetCntr("nLine") % 3; } + // // - "blocks" - + // else if(nSplitType == 1) { + // if (var->GetCntr("nLine") < nLine_splitBlocks) resid3 = 0; + // else if(var->GetCntr("nLine") < 2*nLine_splitBlocks) resid3 = 1; + // else resid3 = 2; + // } + // // - "random" - + // else if(nSplitType == 2) { resid3 = static_cast(floor(rnd->Rndm() * 3)); } + // // - "byInFiles" - + // else if(nSplitType == 3) { resid3 = inFileSplitIndex; } + + // // ----------------------------------------------------------------------------------------------------------- + // // now set the variables + // // ----------------------------------------------------------------------------------------------------------- + // if (resid3 == 0) { + // intMap["nSplitTree"] = 0; + // var->SetVarI(splitName,var->GetCntr("nTrain")); + // var->SetVarI(testValidType, 0); + // var->IncCntr("nTrain"); + // } + // else if(resid3 == 1) { + // intMap["nSplitTree"] = 1; + // var->SetVarI(splitName,var->GetCntr("nTest")); + // var->SetVarI(testValidType, 0); + // var->IncCntr("nTest"); + // } + // else if(resid3 == 2) { + // intMap["nSplitTree"] = 1; + // var->SetVarI(splitName,var->GetCntr("nValid")); + // var->SetVarI(testValidType, 1); + // var->IncCntr("nValid"); + // } + // } } // the main counter - var->SetVarI(indexName, var->GetCntr("nLine")); + // var->SetVarI(indexName, var->GetCntr("nLine")); // deprecated var->IncCntr("nLine"); return; diff --git a/src/CatFormat_wgtKNN.cpp b/src/CatFormat_wgtKNN.cpp index 3498625..210f241 100644 --- a/src/CatFormat_wgtKNN.cpp +++ b/src/CatFormat_wgtKNN.cpp @@ -31,6 +31,7 @@ void CatFormat::inputToSplitTree_wgtKNN(TString inAsciiFiles, TString inAsciiVar aLOG(Log::INFO) <GetOptI("nSplit"); + bool trainTestTogether = glob->GetOptB("trainTestTogether_wgtKNN"); TString treeName = glob->GetOptC("treeName"); TString outDirNameFull = glob->GetOptC("outDirNameFull"); TString inTreeName_wgtKNN = glob->GetOptC("inTreeName_wgtKNN"); @@ -73,6 +74,15 @@ void CatFormat::inputToSplitTree_wgtKNN(TString inAsciiFiles, TString inAsciiVar TString mkdirCmnd = (TString)"mkdir -p "+outDirNameTMP; utils->exeShellCmndOutput(mkdirCmnd,inLOG(Log::DEBUG),true); + TChain * aChainMerge(NULL); + vector < TChain * > aChainV(nChains,NULL); + + TString mergeChainName(""); + for(int nChainNow=0; nChainNowGetEntries()<<")" <<" from "< 1 && trainTestTogether) { + if(!aChainMerge) aChainMerge = new TChain(treeNameNow,treeNameNow); + else aChainMerge->SetName(treeNameNow); + + aChainMerge->SetDirectory(0); aChainMerge->Add(fileNameNow); + aChainMerge->SetName(mergeChainName); + + aLOG(Log::INFO) <GetEntries()<<")"<<" from "<safeRM(outDirNameTMP,inLOG(Log::DEBUG)); - treeNames.clear(); + treeNames.clear(); aChainV.clear(); return; } // =========================================================================================================== /** - * @brief - create root trees from the input ascii files and add a weight branch, which estimates - * if each objects is "near enough" to enough objects in the training dataset. + * @brief - create root trees from the input ascii files and add a weight branch, which estimates + * if each objects is "near enough" to enough objects in the training dataset. * - * @param inAsciiFiles - semicolon-separated list of input ascii files (main dataset) - * @param inAsciiVars - semicolon-separated list of input parameter names, corresponding to columns in the input files (main dataset) - * @param treeNamePostfix - postfix for the final output trees + * @param inAsciiFiles - semicolon-separated list of input ascii files (main dataset) + * @param inAsciiVars - semicolon-separated list of input parameter names, corresponding to columns in the input files (main dataset) + * @param treeNamePostfix - postfix for the final output trees */ // =========================================================================================================== void CatFormat::inputToFullTree_wgtKNN(TString inAsciiFiles, TString inAsciiVars, TString treeNamePostfix) { @@ -146,7 +182,7 @@ void CatFormat::inputToFullTree_wgtKNN(TString inAsciiFiles, TString inAsciiVars // the name of the final tree TString treeFinalName = (TString)treeName+treeNamePostfix; - addWgtKNNtoTree(aChainInp,aChainRef,treeFinalName); + addWgtKNNtoTree(aChainInp,aChainRef,NULL,treeFinalName); // remove the intermidiate trees created by inputToFullTree utils->safeRM(fileNameInp,inLOG(Log::DEBUG)); @@ -194,16 +230,17 @@ void CatFormat::inputToFullTree_wgtKNN(TString inAsciiFiles, TString inAsciiVars * * @param aChainInp - a chain corresponding to the main dataset * @param aChainRef - a chain corresponding to the reference dataset + * @param aChainRef - a chain corresponding to the evaluated dataset * @param outTreeName - optional tree name (or else the name is extracted from aChainInp) */ // =========================================================================================================== -void CatFormat::addWgtKNNtoTree(TChain * aChainInp, TChain * aChainRef, TString outTreeName) { -// =========================================================================================== +void CatFormat::addWgtKNNtoTree(TChain * aChainInp, TChain * aChainRef, TChain * aChainEvl, TString outTreeName) { +// =============================================================================================================== VERIFY(LOCATION,(TString)"Memory leak ?!?",(dynamic_cast(aChainInp) && dynamic_cast(aChainRef))); - if(outTreeName == "") outTreeName = aChainInp->GetName(); - aLOG(Log::INFO) <GetOptC("basePrefix"); TString outDirNameFull = glob->GetOptC("outDirNameFull"); TString indexName = glob->GetOptC("indexName"); TString plotExt = glob->GetOptC("printPlotExtension"); @@ -214,7 +251,10 @@ void CatFormat::addWgtKNNtoTree(TChain * aChainInp, TChain * aChainRef, TString int minObjTrainTest = glob->GetOptI("minObjTrainTest"); double maxRelRatioInRef = glob->GetOptF("maxRelRatioInRef_inTrain"); TString weightName = glob->GetOptC("baseName_wgtKNN"); - TString baseName_ANNZ = glob->GetOptC("baseName_ANNZ"); + // number of KNN modules (with decreasing object fractions) for hierarchical searches, and + // fraction-cut factor for each module-level + int nKnnFracs = glob->GetOptI("nKnnFracs_wgtKNN"); + int knnFracFact = glob->GetOptI("knnFracFact_wgtKNN"); TString typePostfix = (TString)(doRelWgts ? "_wgtKNN" : "_inTrain"); TString wgtKNNname = glob->GetOptC((TString)"baseName" +typePostfix); // e.g., "baseName_wgtKNN" @@ -225,6 +265,12 @@ void CatFormat::addWgtKNNtoTree(TChain * aChainInp, TChain * aChainRef, TString double sampleFracRef = glob->GetOptF((TString)"sampleFracRef" +typePostfix); // e.g., "sampleFracRef_wgtKNN" bool doWidthRescale = glob->GetOptB((TString)"doWidthRescale"+typePostfix); bool debug = inLOG(Log::DEBUG_2); + + bool hasChainEvl = dynamic_cast(aChainEvl); + TChain * aChainInpEvl = hasChainEvl ? aChainEvl : aChainInp; + + if(outTreeName == "") outTreeName = aChainInpEvl->GetName(); + aLOG(Log::INFO) < chainWgtV(2), chainCutV(2); @@ -232,8 +278,16 @@ void CatFormat::addWgtKNNtoTree(TChain * aChainInp, TChain * aChainRef, TString chainCutV[0] = glob->GetOptC((TString)"cutInp" +typePostfix); chainCutV[1] = glob->GetOptC((TString)"cutRef" +typePostfix); // force reasonable min/max values - minNobjInVol = max(minNobjInVol,20); - maxRelRatioInRef = (maxRelRatioInRef > 0) ? max(min(maxRelRatioInRef,0.999),0.001) : -1; + minNobjInVol = max(minNobjInVol,20); + + double maxRelRatioInRef_0(0.001), maxRelRatioInRef_1(0.999); + if(maxRelRatioInRef > 0 && (maxRelRatioInRef < maxRelRatioInRef_0 || maxRelRatioInRef > maxRelRatioInRef_1)) { + aLOG(Log::WARNING) < 0) ? max(min(maxRelRatioInRef,maxRelRatioInRef_1),maxRelRatioInRef_0) : -1; int maxNobj = 0; // maxNobj = glob->GetOptI("maxNobj"); // only allow maxNobj limits for debugging !! TString outBaseName = (TString)outDirNameFull+glob->GetOptC("treeName")+wgtKNNname; @@ -241,8 +295,9 @@ void CatFormat::addWgtKNNtoTree(TChain * aChainInp, TChain * aChainRef, TString // decompose the variable names for the KNN distance calculation vector varNames = utils->splitStringByChar(weightVarNames,';'); - - int nVars = (int)varNames.size(); + int nVars = (int)varNames.size(); + + // sanity checks VERIFY(LOCATION,(TString)"Did not find input variables for KNN weight computation [\"weightVarNames_wgtKNN\"/\"weightVarNames_inTrain\" = " +weightVarNames+"] ... Something is horribly wrong !?!?",(nVars > 0)); @@ -251,22 +306,53 @@ void CatFormat::addWgtKNNtoTree(TChain * aChainInp, TChain * aChainRef, TString VERIFY(LOCATION,(TString)"sampleFracRef_wgtKNN must be a positive number, smaller or equal to 1. Currently is set to "+utils->floatToStr(sampleFracRef) ,(sampleFracRef > 0 && sampleFracRef <= 1)); - aLOG(Log::INFO) <= 2 (recommended > 5), while now set [\"nKnnFracs_wgtKNN\" = " + +utils->intToStr(nKnnFracs)+"] ... Something is horribly wrong !?!?",(nKnnFracs > 1)); + + VERIFY(LOCATION,(TString)"Must set knnFracFact_wgtKNN >= 2 (recommended 3,4 or 5), while now set [\"knnFracFact_wgtKNN\" = " + +utils->intToStr(knnFracFact)+"] ... Something is horribly wrong !?!?",(knnFracFact > 1)); + + for(int nChainNow=0; nChainNow<3; nChainNow++) { + TChain * aChain(NULL); + if (nChainNow == 0) aChain = aChainInp; + else if(nChainNow == 1) aChain = aChainRef; + else if(nChainNow == 2) { + if(hasChainEvl) aChain = aChainEvl; + else continue; + } - vector < vector > minMaxVarVals(2,vector(nVars)); - vector < TString > varNamesScaled(nVars,""); - vector < vector > hisVarV(2,vector(nVars,NULL)); + vector branchNameV; + utils->getTreeBranchNames(aChain,branchNameV); + for(int nVarNow=0; nVarNowGetName(), + (find(branchNameV.begin(),branchNameV.end(), varNames[nVarNow]) != branchNameV.end())); + } + + branchNameV.clear(); + } + + aLOG(Log::INFO) < > minMaxVarVals (2,vector(nVars,0) ); + vector < vector > hisVarV (2,vector (nVars,NULL)); + vector < TString > varNamesScaled(nVars,""); // ----------------------------------------------------------------------------------------------------------- // setup the kd-trees for the two chains // ----------------------------------------------------------------------------------------------------------- - vector aChainV(2); - vector knnErrOutFile(2); vector knnErrFactory(2); - vector knnErrMethod(2); vector knnErrModule(2); + double effEntRatio(1); + vector chainEntV(3,0); - vector outFileDirKnnErrV(2), outFileNameKnnErr(2); - vector chainWgtNormV(2,0), chainEntV(2,0); + vector < TChain * > aChainV (2, NULL); + vector < vector > knnErrOutFile (2, vector (nKnnFracs,NULL)); + vector < vector > knnErrFactory (2, vector (nKnnFracs,NULL)); + vector < vector > knnErrMethod (2, vector (nKnnFracs,NULL)); + vector < vector > knnErrModule (2, vector(nKnnFracs,NULL)); + vector < vector > outFileNameKnnErr(2, vector (nKnnFracs,"") ); + + TString outFileDirKnnErrV = (TString)outBaseName+"_weights"+"/"; + (TMVA::gConfig().GetIONames()).fWeightFileDir = outFileDirKnnErrV; // ----------------------------------------------------------------------------------------------------------- // setup some objects and get histograms for the variable range limits @@ -281,29 +367,41 @@ void CatFormat::addWgtKNNtoTree(TChain * aChainInp, TChain * aChainRef, TString TChain * aChain = (nChainNow == 0) ? aChainInp : aChainRef; aChainV[nChainNow] = (TChain*)aChain->Clone((TString)aChain->GetName()+nChainKNNname); aChainV[nChainNow]->SetDirectory(0); + // make sure that string variables are in the correct access format for cuts and weights + VarMaps * varTMP = new VarMaps(glob,utils,"varTMP"); + varTMP->connectTreeBranches(aChainV[nChainNow]); + + chainWgtV[nChainNow] = varTMP->regularizeStringForm(chainWgtV[nChainNow]); + chainCutV[nChainNow] = varTMP->regularizeStringForm(chainCutV[nChainNow]); + + DELNULL(varTMP); + // add the "baseName_wgtKNN" variable to the weight expression - this is just unity if generating the // initial trees, but may hold non-trivial values for the [doRelWgts==false] mode chainWgtV[nChainNow] = utils->cleanWeightExpr((TString)"("+chainWgtV[nChainNow]+")*"+weightName); - outFileDirKnnErrV[nChainNow] = outBaseName+"_weights"+nChainKNNname+"/"; - outFileNameKnnErr[nChainNow] = outBaseName+nChainKNNname+".root"; - - (TMVA::gConfig().GetIONames()).fWeightFileDir = outFileDirKnnErrV[nChainNow]; TString verbLvlF = (TString)(debug ? ":V:!Silent" : ":!V:Silent"); TString drawProgBarStr = (TString)(debug ? ":Color:DrawProgressBar" : ":Color:!DrawProgressBar"); TString transStr = (TString)":Transformations=I,N"; TString analysType = (TString)":AnalysisType=Regression"; + TString allOpts = (TString)verbLvlF+drawProgBarStr+transStr+analysType; - // setup the factory + // setup the factories // ----------------------------------------------------------------------------------------------------------- - knnErrOutFile[nChainNow] = new TFile(outFileNameKnnErr[nChainNow],"RECREATE"); - knnErrFactory[nChainNow] = new TMVA::Factory(wgtKNNname, knnErrOutFile[nChainNow], (TString)verbLvlF+drawProgBarStr+transStr+analysType); + for(int nFracNow=0; nFracNowregularizeName(varNames[nVarNow])+"_hisVar"; TString drawExprs = (TString)varNames[nVarNow]+">>"+hisName; @@ -319,13 +417,29 @@ void CatFormat::addWgtKNNtoTree(TChain * aChainInp, TChain * aChainRef, TString hisVarV[nChainNow][nVarNow]->SetDirectory(0); hisVarV[nChainNow][nVarNow]->BufferEmpty(); outputs->BaseDir->cd(); if(nVarNow == 0) { - sumWgt = hisVarV[nChainNow][nVarNow]->Integral(); sumEnt = hisVarV[nChainNow][nVarNow]->GetEntries(); - + double sumEnt = hisVarV[nChainNow][nVarNow]->GetEntries(); VERIFY(LOCATION,(TString)"Got sum of entries = "+utils->floatToStr(sumEnt)+" ... Something is horribly wrong ?!?! ",(sumEnt > 0)); - VERIFY(LOCATION,(TString)"Got sum of weights = "+utils->floatToStr(sumWgt)+" ... Something is horribly wrong ?!?! ",(sumWgt > 0)); - chainEntV [nChainNow] = sumEnt; - chainWgtNormV[nChainNow] = 1/sumWgt; + chainEntV[nChainNow] = sumEnt; + + // if there is a separate evaluation chain, get the weighted/cut sum of entries for later normalization + if(nChainNow == 0 && hasChainEvl) { + hisName += (TString)"_TMP"; + drawExprs = (TString)varNames[nVarNow]+">>"+hisName; + + TCanvas * tmpCnvs = new TCanvas("tmpCnvs","tmpCnvs"); + aChainEvl->Draw(drawExprs,wgtCut); DELNULL(tmpCnvs); + + TH1 * hisTMP = (TH1F*)gDirectory->Get(hisName); + VERIFY(LOCATION,(TString)"Could not derive histogram ("+hisName+") from chain ... Something is horribly wrong ?!?!",(dynamic_cast(hisTMP))); + + sumEnt = hisTMP->GetEntries(); + + VERIFY(LOCATION,(TString)"Got sum of entries = "+utils->floatToStr(sumEnt)+" ... Something is horribly wrong ?!?! ",(sumEnt > 0)); + + chainEntV[2] = sumEnt; + DELNULL(hisTMP); + } } } } @@ -398,10 +512,9 @@ void CatFormat::addWgtKNNtoTree(TChain * aChainInp, TChain * aChainRef, TString // ----------------------------------------------------------------------------------------------------------- - // cabook the variables and chains in the factory and initialize the kd-tree + // book the variables and chains in the factory and initialize the kd-tree // ----------------------------------------------------------------------------------------------------------- for(int nChainNow=0; nChainNow<2; nChainNow++) { - double objFracNow = (nChainNow == 0) ? sampleFracInp : sampleFracRef; int nTrainObj = static_cast(floor(aChainV[nChainNow]->GetEntries() * objFracNow)); @@ -409,49 +522,63 @@ void CatFormat::addWgtKNNtoTree(TChain * aChainInp, TChain * aChainRef, TString +utils->lIntToStr(aChainV[nChainNow]->GetEntries())+" objects, now has "+utils->lIntToStr(nTrainObj) +" objects (minimum is \"minObjTrainTest\" = "+utils->lIntToStr(minObjTrainTest)+") ...",(nTrainObj >= minObjTrainTest)); - int split0 = 20; - int split1 = static_cast(floor(split0/objFracNow)); - TString fracCut = (TString)indexName+" % "+TString::Format("%d < %d",split1,split0); - TCut finalCut = ((TCut)fracCut) + ((TCut)chainCutV[nChainNow]); - TString verbLvlM = (TString)(debug ? ":V:H" : ":!V:!H"); TString optKNN = (TString)":nkNN=10:ScaleFrac=0.0"; TString trainValidStr = (TString)"nTrain_Regression=0:nTest_Regression=0:SplitMode=Random:NormMode=NumEvents:!V"; - // define all (scaled) input variables as floats in the factory - for(int nVarNow=0; nVarNowAddVariable(varNamesScaled[nVarNow],varNamesScaled[nVarNow],"",'F'); - } + // setup multiple modules with decreasing object fractions (for search of far neighbours) + // ----------------------------------------------------------------------------------------------------------- + int nKnnFracsIn(0); + for(int nFracNow=0; nFracNow(floor(EPS + split0 * pow(knnFracFact,nFracNow) / objFracNow)); + TString fracCut = (TString)indexName+" % "+TString::Format("%d < %d",split1,split0); + TCut finalCut = ((TCut)fracCut) + ((TCut)chainCutV[nChainNow]); + TString wgtNameNow = (TString)wgtKNNname+TString::Format("_%d",nFracNow); + + // define all (scaled) input variables as floats in the factory + for(int nVarNow=0; nVarNowAddVariable(varNamesScaled[nVarNow],varNamesScaled[nVarNow],"",'F'); + } - knnErrFactory[nChainNow]->AddRegressionTree(aChainV[nChainNow], 1, TMVA::Types::kTesting); - knnErrFactory[nChainNow]->AddRegressionTree(aChainV[nChainNow], 1, TMVA::Types::kTraining); - knnErrFactory[nChainNow]->SetWeightExpression(chainWgtV[nChainNow],"Regression"); + knnErrFactory[nChainNow][nFracNow]->AddRegressionTree(aChainV[nChainNow], 1, TMVA::Types::kTesting); + knnErrFactory[nChainNow][nFracNow]->AddRegressionTree(aChainV[nChainNow], 1, TMVA::Types::kTraining); + knnErrFactory[nChainNow][nFracNow]->SetWeightExpression(chainWgtV[nChainNow],"Regression"); - knnErrFactory[nChainNow]->PrepareTrainingAndTestTree(finalCut,trainValidStr); + knnErrFactory[nChainNow][nFracNow]->PrepareTrainingAndTestTree(finalCut,trainValidStr); - knnErrMethod[nChainNow] = dynamic_cast(knnErrFactory[nChainNow]->BookMethod(TMVA::Types::kKNN, wgtKNNname,(TString)optKNN+verbLvlM)); - knnErrModule[nChainNow] = knnErrMethod[nChainNow]->fModule; + knnErrMethod[nChainNow][nFracNow] = dynamic_cast + (knnErrFactory[nChainNow][nFracNow]->BookMethod(TMVA::Types::kKNN, wgtNameNow,(TString)optKNN+verbLvlM)); + + // fill the module with events made from the tree entries and create the binary tree + knnErrMethod[nChainNow][nFracNow]->Train(); - // fill the module with events made from the tree entries and create the binary tree - // ----------------------------------------------------------------------------------------------------------- - knnErrMethod[nChainNow]->Train(); + // make sure we have enough objects after the fraction-cut, and that we aren't repeating ourselves + int nEffObj = knnErrMethod[nChainNow][nFracNow]->fEvent.size(); + if(nEffObj < minNobjInVol*5) break; + + if(nKnnFracsIn == 0) { + if(nChainNow == 0) effEntRatio *= nEffObj; + else effEntRatio /= nEffObj; + } + + knnErrModule[nChainNow][nFracNow] = knnErrMethod[nChainNow][nFracNow]->fModule; - aLOG(Log::INFO) <GetName()<GetName()<fScaleFrac < EPS)); + // sanity check - if this is not true, the distance calculations will be off + VERIFY(LOCATION,(TString)"Somehow the fScaleFrac for the kd-tree is not zero ... Something is horribly wrong ?!?!?" + ,(knnErrMethod[nChainNow][nFracNow]->fScaleFrac < EPS)); + nKnnFracsIn++; + } + // final check on nKnnFracs - at least two knnErrModule need to be accepted + VERIFY(LOCATION,(TString)"Could not find enough objects for the KNN search." + +" Try to decrease the value of "+"minNobjInVol"+typePostfix+" ...",(nKnnFracsIn > 1)); } - // chainWgtNormV now holds normalization, such that the weighted sum of entries in the main and in the - // reference chains is the same (multiplying by the max(chainEntV[0],chainEntV[1]) is not strictly - // needed, but may reduce numerical errors when dealing with large input samples) - // ----------------------------------------------------------------------------------------------------------- - for(int nChainNow=0; nChainNow<2; nChainNow++) { chainWgtNormV[nChainNow] *= max(chainEntV[0],chainEntV[1]); } // ----------------------------------------------------------------------------------------------------------- // create the vars to read/write trees @@ -467,7 +594,7 @@ void CatFormat::addWgtKNNtoTree(TChain * aChainInp, TChain * aChainRef, TString var_0->NewForm(varFormNames[nVarNow],varNamesScaled[nVarNow]); } - var_0->connectTreeBranches(aChainInp); + var_0->connectTreeBranches(aChainInpEvl); vector < pair > varTypeNameV; var_1->varStruct(var_0,NULL,NULL,&varTypeNameV); @@ -485,12 +612,12 @@ void CatFormat::addWgtKNNtoTree(TChain * aChainInp, TChain * aChainRef, TString // loop on the tree // ----------------------------------------------------------------------------------------------------------- double weightSum(0); - vector distIndexV(2); - vector distV(2), weightSumV(2); + vector distIndexV(2,0); + vector distV(2,0), weightSumV(2,0); TMVA::kNN::VarVec objNowV(nVars,0); - int nObjectsToPrint = min(static_cast(aChainInp->GetEntries()/10.) , glob->GetOptI("nObjectsToPrint")); - bool breakLoop(false), mayWriteObjects(false); + int nObjectsToPrint = min(static_cast(aChainInpEvl->GetEntries()/10.) , glob->GetOptI("nObjectsToPrint")); + bool breakLoop(false), mayWriteObjects(false); var_0->clearCntr(); for(Long64_t loopEntry=0; true; loopEntry++) { if(!var_0->getTreeEntry(loopEntry)) breakLoop = true; @@ -515,87 +642,96 @@ void CatFormat::addWgtKNNtoTree(TChain * aChainInp, TChain * aChainRef, TString if(objNowV[nVarNow] < minMaxVarVals[0][nVarNow] || objNowV[nVarNow] > minMaxVarVals[1][nVarNow]) { isInsideRef = false; - var_0->IncCntr("Found good weight"); var_0->IncCntr(wgtKNNname+" = 0 (outside input parameter range)"); + var_0->IncCntr(wgtKNNname+" = 0 (outside input parameter range)"); break; } } double weightKNN(0); + if(isInsideRef) { + const TMVA::kNN::Event evtNow(objNowV,1,0); + // ----------------------------------------------------------------------------------------------------------- // derive the weights as the ratio between input and reference samples, if needed // ----------------------------------------------------------------------------------------------------------- if(doRelWgts) { - const TMVA::kNN::Event evtNow(objNowV,1,0); - // find the same number of near neighbours for each chain, and derive the distance this requires int nObjKNN(minNobjInVol); for(int nChainNow=0; nChainNow<2; nChainNow++) { - double wgtNorm = chainWgtNormV[nChainNow]; - knnErrModule[nChainNow]->Find(evtNow,nObjKNN); - const TMVA::kNN::List & knnList = knnErrModule[nChainNow]->GetkNNList(); + knnErrModule[nChainNow][0]->Find(evtNow,nObjKNN); + const TMVA::kNN::List & knnList = knnErrModule[nChainNow][0]->GetkNNList(); weightSumV[nChainNow] = 0; for(TMVA::kNN::List::const_iterator lit=knnList.begin(); lit!=knnList.end(); ++lit) { - double wgtNow = (lit->first->GetEvent()).GetWeight() * wgtNorm; + double wgtNow = (lit->first->GetEvent()).GetWeight(); VERIFY(LOCATION,(TString)"Found negative weight in reference sample ... Something is horribly wrong ?!?",(wgtNow > 0)); weightSumV[nChainNow] += wgtNow; } + distV[nChainNow] = evtNow.GetDist(knnList.back().first->GetEvent()); } + // the index of the chain with the shorter distance if(distV[0] < distV[1]) { distIndexV[0] = 0; distIndexV[1] = 1; } else { distIndexV[0] = 1; distIndexV[1] = 0; } bool foundDist(false); - double weightSumNow(0), wgtNorm(chainWgtNormV[distIndexV[0]]); - int preNobj(0); - for(int nSearchNow=0; nSearchNow<10; nSearchNow++) { - preNobj = nObjKNN; nObjKNN *= 5; + for(int nFracNow=1; nFracNowFind(evtNow,nObjKNN); - const TMVA::kNN::List & knnList = knnErrModule[distIndexV[0]]->GetkNNList(); + knnErrModule[distIndexV[0]][nFracNow]->Find(evtNow,nObjKNN); + const TMVA::kNN::List & knnList = knnErrModule[distIndexV[0]][nFracNow]->GetkNNList(); - for(TMVA::kNN::List::const_iterator lit=std::next(knnList.begin(),preNobj); lit!=knnList.end(); ++lit) { - const TMVA::kNN::Event & eventNow = lit->first->GetEvent(); + weightSumV[distIndexV[0]] = 0; + for(TMVA::kNN::List::const_iterator lit=knnList.begin(); lit!=knnList.end(); ++lit) { + const TMVA::kNN::Event & evtLst = lit->first->GetEvent(); - double knnDistNow = evtNow.GetDist(eventNow); - double weightObj = eventNow.GetWeight() * wgtNorm; + double knnDistNow = evtNow.GetDist(evtLst); + double weightObj = evtLst.GetWeight(); VERIFY(LOCATION,(TString)"Found negative weight in reference sample ... Something is horribly wrong ?!?",(weightObj > 0)); - if(knnDistNow < distV[distIndexV[1]]) { weightSumNow += weightObj; } - else { foundDist = true; break; } + if(knnDistNow < distV[distIndexV[1]]) { weightSumV[distIndexV[0]] += weightObj; } + else { foundDist = true; break; } } if(foundDist) { - weightSumV[distIndexV[0]] += weightSumNow; - if(weightSumV[1] * weightSumV[0] > EPS) weightKNN = weightSumV[1]/weightSumV[0]; - break; + if(weightSumV[0] > EPS && weightSumV[1] > EPS) { + // use effEntRatio, a constant normalization - this does not change the result, but may help + // to prevent numerical errors if the sizes of the input/reference samples is very different + weightSumV[1] *= effEntRatio; + + // correct for the hierarchical search-level + weightSumV[distIndexV[0]] *= pow(knnFracFact,nFracNow); + + // finally, calculate the weight for this object + weightKNN = weightSumV[1]/weightSumV[0]; + + break; + } } } - if(foundDist) { var_0->IncCntr("Found good weight"); weightSum += weightKNN; } - else { var_0->IncCntr("Did not fined good weight"); } + if(foundDist) { var_0->IncCntr("Found good weight"); weightSum += weightKNN; } + else { var_0->IncCntr("Did not find good weight"); } } // ----------------------------------------------------------------------------------------------------------- // derive the weight from the approximated density estimation of near objects from the reference sample, if needed // ----------------------------------------------------------------------------------------------------------- else { - const TMVA::kNN::Event evtInp(objNowV,1,0); - // find the closest object in the reference chain - knnErrModule[1]->Find(evtInp,1); - const TMVA::kNN::List & knnListInp = knnErrModule[1]->GetkNNList(); + knnErrModule[1][0]->Find(evtNow,1); + const TMVA::kNN::List & knnListInp = knnErrModule[1][0]->GetkNNList(); // must make a sanith check before using the pointer to GetEvent() VERIFY(LOCATION,(TString)"could not find any near neighbours for objects ... Something is horribly wrong ?!?",(knnListInp.size() > 0)); const TMVA::kNN::Event evtRef(knnListInp.back().first->GetEvent()); - // find the diatnace to the reference object we just found - double dist_Ref_Inp = evtInp.GetDist(evtRef); + // find the distnace to the reference object we just found + double dist_Ref_Inp = evtNow.GetDist(evtRef); double wgt_Ref_Inp = evtRef.GetWeight(); VERIFY(LOCATION,(TString)"Found negative weight in reference sample ... Something is horribly wrong ?!?",(wgt_Ref_Inp > 0)); @@ -605,30 +741,29 @@ void CatFormat::addWgtKNNtoTree(TChain * aChainInp, TChain * aChainRef, TString // are not defined, then the sum of weights will be exactly minNobjInVol. Otherwise, several searches with increasing // NN numbers may be needed... // ----------------------------------------------------------------------------------------------------------- - int minNobjInVolNow(minNobjInVol), minNobjInVolPrev(0); bool foundDist(false); - double wgtSum_Ref0_RefNear(0), dist_Ref0_RefNear(0); + double dist_Ref0_RefNear(0); + for(int nFracNow=0; nFracNowFind(evtRef,minNobjInVolNow+1); - const TMVA::kNN::List & knnListRef = knnErrModule[1]->GetkNNList(); + knnErrModule[1][nFracNow]->Find(evtRef,minNobjInVol); + const TMVA::kNN::List & knnListRef = knnErrModule[1][nFracNow]->GetkNNList(); - for(TMVA::kNN::List::const_iterator lit=std::next(knnListRef.begin(),minNobjInVolPrev); lit!=knnListRef.end(); ++lit) { - const TMVA::kNN::Event & eventRefNow = lit->first->GetEvent(); + for(TMVA::kNN::List::const_iterator lit=knnListRef.begin(); lit!=knnListRef.end(); ++lit) { + const TMVA::kNN::Event & evtLst = lit->first->GetEvent(); - double distNow = evtRef.GetDist(eventRefNow); if(distNow < EPS) continue; // the first element is the initial object (-> zero distance) - double wgtNow = eventRefNow.GetWeight(); + double distNow = evtRef.GetDist(evtLst); if(distNow < EPS) continue; // the first element is the initial object (-> zero distance) + double wgtNow = evtLst.GetWeight(); VERIFY(LOCATION,(TString)"Found negative weight in reference sample ... Something is horribly wrong ?!?",(wgtNow > 0)); dist_Ref0_RefNear = distNow; wgtSum_Ref0_RefNear += wgtNow; - if(wgtSum_Ref0_RefNear >= minNobjInVol * wgt_Ref_Inp) { foundDist = true; break; } + if(wgtSum_Ref0_RefNear >= minNobjInVolWgt) { foundDist = true; break; } } if(foundDist) break; - - minNobjInVolPrev = minNobjInVolNow; minNobjInVolNow *= 5; } if(foundDist) { @@ -656,9 +791,9 @@ void CatFormat::addWgtKNNtoTree(TChain * aChainInp, TChain * aChainRef, TString var_1->SetVarF(wgtKNNname,weightKNN); - outTree->Fill(); + var_1->fillTree(); - mayWriteObjects = true; var_0->IncCntr("nObj"); if(var_0->GetCntr("nObj") == maxNobj) breakLoop = true; + mayWriteObjects = true; var_0->IncCntr("nObj"); /// Cant use this here !!! if(var_0->GetCntr("nObj") == maxNobj) breakLoop = true; } if(!breakLoop) { var_0->printCntr(outTreeName); outputs->WriteOutObjects(false,true); outputs->ResetObjects(); } @@ -675,7 +810,7 @@ void CatFormat::addWgtKNNtoTree(TChain * aChainInp, TChain * aChainRef, TString VERIFY(LOCATION,(TString)"Got sum of weights from the entire main sample = "+utils->floatToStr(weightSum) +" ... Something is horribly wrong ?!?! ",(weightSum > 0)); - double weightNorm = chainEntV[0] / weightSum; + double weightNorm = (hasChainEvl ? chainEntV[2] : chainEntV[0]) / weightSum; // create a temporary sub-dir for the output trees of the above // ----------------------------------------------------------------------------------------------------------- @@ -726,7 +861,7 @@ void CatFormat::addWgtKNNtoTree(TChain * aChainInp, TChain * aChainRef, TString double weightKNN = var_0->GetVarF(wgtKNNname) * weightNorm; var_1->SetVarF(wgtKNNname,weightKNN); - outTree->Fill(); + var_1->fillTree(); mayWriteObjects = true; var_0->IncCntr("nObj"); if(var_0->GetCntr("nObj") == maxNobj) breakLoop = true; } @@ -763,7 +898,7 @@ void CatFormat::addWgtKNNtoTree(TChain * aChainInp, TChain * aChainRef, TString // ----------------------------------------------------------------------------------------------------------- // write out the weights to an ascii file // ----------------------------------------------------------------------------------------------------------- - if(doStoreToAscii ) { + if(doStoreToAscii) { // extract the names of the variables which will be written out to the ascii output // including the derived KNN weights. Add the actuall weight variable if not already included vector outVarNames = utils->splitStringByChar(outAsciiVars,';'); @@ -772,7 +907,7 @@ void CatFormat::addWgtKNNtoTree(TChain * aChainInp, TChain * aChainRef, TString // create a VarMaps, connect it to the tree, and write out the requested variables VarMaps * var_2 = new VarMaps(glob,utils,"treeRegClsVar_2"); var_2->connectTreeBranches(aChainOut); - var_2->storeTreeToAscii((TString)wgtKNNname+aChainInp->GetName(),"",0,nObjectsToWrite,"",&outVarNames,NULL); + var_2->storeTreeToAscii((TString)wgtKNNname+aChainInpEvl->GetName(),"",0,nObjectsToWrite,"",&outVarNames,NULL); DELNULL(var_2); outVarNames.clear(); } @@ -792,16 +927,16 @@ void CatFormat::addWgtKNNtoTree(TChain * aChainInp, TChain * aChainRef, TString TString branchType = var_0->GetVarType(branchName); if(branchType != "F" && branchType != "D" && branchType != "I") continue; - if(branchName.BeginsWith(baseName_ANNZ)) continue; + if(branchName.BeginsWith(basePrefix)) continue; if(find(varNames.begin(),varNames.end(), branchName) != varNames.end()) continue; // only accept branches common to all chains int hasBranch(0); for(int nChainNow=0; nChainNow<3; nChainNow++) { TChain * aChain(NULL); - if (nChainNow == 0) { aChain = aChainRef; } - else if(nChainNow == 1) { aChain = aChainInp; } - else if(nChainNow == 2) { aChain = aChainOut; } + if (nChainNow == 0) { aChain = aChainRef; } + else if(nChainNow == 1) { aChain = aChainInpEvl; } + else if(nChainNow == 2) { aChain = aChainOut; } if(aChain->FindBranch(branchName)) hasBranch++; } @@ -830,9 +965,9 @@ void CatFormat::addWgtKNNtoTree(TChain * aChainInp, TChain * aChainRef, TString TString weightNow("1"), hisTitle(""); TChain * aChain(NULL); - if (nChainNow == 0) { aChain = aChainRef; hisTitle = "Reference"; } - else if(nChainNow == 1) { aChain = aChainInp; hisTitle = "Original"; } - else if(nChainNow == 2) { aChain = aChainOut; hisTitle = "Weighted"; } + if (nChainNow == 0) { aChain = aChainRef; hisTitle = "Reference"; } + else if(nChainNow == 1) { aChain = aChainInpEvl; hisTitle = "Original"; } + else if(nChainNow == 2) { aChain = aChainOut; hisTitle = "Weighted"; } if(nChainNow == 2 && wgtKNNname != weightName) weightNow = wgtKNNname; @@ -853,7 +988,7 @@ void CatFormat::addWgtKNNtoTree(TChain * aChainInp, TChain * aChainRef, TString } weightNow = utils->cleanWeightExpr(weightNow); - TString hisName = (TString)baseName+aChainInp->GetName()+nVarName+nChainKNNname; + TString hisName = (TString)baseName+aChainInpEvl->GetName()+nVarName+nChainKNNname; TString drawExprs = (TString)varNameNow+">>"+hisName; if(nDrawNow == 1) drawExprs += TString::Format("(%d,%f,%f)",nDrawBins,drawLim0,drawLim1); // cout < "<BaseDir->cd(); + for(int nFracNow=0; nFracNowsafeRM(outFileDirKnnErrV[nChainNow],inLOG(Log::DEBUG)); - utils->safeRM(outFileNameKnnErr[nChainNow],inLOG(Log::DEBUG)); + // after closing a TFile, need to return to the correct directory, or else histogram pointers will be affected + outputs->BaseDir->cd(); + + utils->safeRM(outFileNameKnnErr[nChainNow][nFracNow],inLOG(Log::DEBUG)); + } } + utils->safeRM(outFileDirKnnErrV,inLOG(Log::DEBUG)); knnErrOutFile.clear(); knnErrFactory.clear(); knnErrMethod.clear(); knnErrModule.clear(); aChainV.clear(); - outFileDirKnnErrV.clear(); minMaxVarVals.clear(); outFileNameKnnErr.clear(); varNames.clear(); chainWgtV.clear(); chainCutV.clear(); + minMaxVarVals.clear(); outFileNameKnnErr.clear(); varNames.clear(); chainWgtV.clear(); chainCutV.clear(); varFormNames.clear(); objNowV.clear(); distV.clear(); weightSumV.clear(); distIndexV.clear(); - chainWgtNormV.clear(); chainEntV.clear(); varNamesScaled.clear(); + chainEntV.clear(); varNamesScaled.clear(); for(int nChainNow=0; nChainNow<2; nChainNow++) { for(int nVarNow=0; nVarNowOptOrNullB("moreLogX")) { grph->GetXaxis()->SetMoreLogLabels(); } if(draw->OptOrNullB("moreLogY")) { grph->GetYaxis()->SetMoreLogLabels(); } + if(draw->HasOptC("yRangeL") && draw->HasOptC("yRangeH")) { + grph->GetYaxis()->SetRangeUser(draw->GetOptC("yRangeL").Atof(),draw->GetOptC("yRangeH").Atof()); + } + if(dynamic_cast(mGrph)) { mGrph->GetXaxis()->SetTitle(grph->GetXaxis()->GetTitle()); mGrph->GetYaxis()->SetTitle(grph->GetYaxis()->GetTitle()); @@ -99,6 +103,10 @@ void OutMngr::SetHisStyle(TGraph * grph, TMultiGraph * mGrph) { if(draw->OptOrNullB("moreLogX")) { mGrph->GetXaxis()->SetMoreLogLabels(); } if(draw->OptOrNullB("moreLogY")) { mGrph->GetYaxis()->SetMoreLogLabels(); } // grph->GetXaxis()->Copy(*mGrph->GetXaxis()); grph->GetYaxis()->Copy(*mGrph->GetYaxis()); // must come after: mGrph->Draw("a") + + if(draw->HasOptC("yRangeL") && draw->HasOptC("yRangeH")) { + mGrph->GetYaxis()->SetRangeUser(draw->GetOptC("yRangeL").Atof(),draw->GetOptC("yRangeH").Atof()); + } } return; diff --git a/src/Utils.cpp b/src/Utils.cpp index 73c461a..3901f44 100644 --- a/src/Utils.cpp +++ b/src/Utils.cpp @@ -222,11 +222,21 @@ vector Utils::splitStringByChar(TString s, char delim) { } // =========================================================================================================== -void Utils::safeRM(TString cmnd, bool verbose) { -// ============================================= +void Utils::safeRM(TString cmnd, bool verbose, bool checkExitStatus) { +// =================================================================== checkCmndSafety(cmnd); - exeShellCmndOutput((TString)"rm -rf "+cmnd,verbose); + int sysReturn = exeShellCmndOutput((TString)"rm -rf "+cmnd,verbose,false); + + if(checkExitStatus && sysReturn != 0) { + TString junkDirName = (TString)"/tmp/"+regularizeName(glob->basePrefix(),"")+"_junk/" + +((doubleToStr(rnd->Rndm(),"%.20f")).ReplaceAll("0.",""))+"/"; + + aLOG(Log::WARNING) < * outV, bool ve } // =========================================================================================================== -void Utils::exeShellCmndOutput(TString cmnd, bool verbose, bool checkExitStatus) { -// =============================================================================== +int Utils::exeShellCmndOutput(TString cmnd, bool verbose, bool checkExitStatus) { +// ============================================================================== if(glob->OptOrNullB("debugSysCmnd")) verbose = true; int sysReturn = system(cmnd); if(verbose) aCustomLOG(" ")<setLock(true); + rnd = new TRandom3(0); + return; } // =========================================================================================================== Utils::~Utils() { // ============== - DELNULL(param); + DELNULL(param); DELNULL(rnd); colours.clear(); markers.clear(); greens.clear(); blues.clear(); reds.clear(); fillStyles.clear(); return; } diff --git a/src/VarMaps.cpp b/src/VarMaps.cpp index 5f1e6c9..de166b9 100644 --- a/src/VarMaps.cpp +++ b/src/VarMaps.cpp @@ -165,6 +165,8 @@ void VarMaps::NewForm(TString aName, TString input) { <(treeWrite))); + treeWrite->Fill(); + return; +} // =========================================================================================================== @@ -629,7 +639,7 @@ void VarMaps::eraseTreeCutsPattern(TString cutPattern, bool ignorCase) { if(hasEle || cutPattern == "") eraseEle.push_back(itr->first); } - for(int nEraseEleNow=0; nEraseEleNow<(int)eraseEle.size(); nEraseEleNow++) treeCutsM.erase(eraseEle[nEraseEleNow]); + for(int nEraseEleNow=0; nEraseEleNow<(int)eraseEle.size(); nEraseEleNow++) { treeCutsM.erase(eraseEle[nEraseEleNow]); } eraseEle.clear(); return; @@ -639,10 +649,14 @@ int VarMaps::replaceTreeCut(TString oldCut, TString newCut) { // ========================================================== TString cutExpr(""); int foundOldCut(0); + + oldCut = regularizeStringForm(oldCut); + newCut = regularizeStringForm(newCut); + for(Map ::iterator itr=treeCutsM.begin(); itr!=treeCutsM.end(); ++itr) { if(!(((TString)itr->second).Contains(oldCut))) continue; - cutExpr = itr->second; cutExpr.ReplaceAll(oldCut,newCut); + cutExpr = itr->second; cutExpr.ReplaceAll(oldCut,newCut); itr->second = (TCut)cutExpr; foundOldCut++; } @@ -652,10 +666,12 @@ int VarMaps::replaceTreeCut(TString oldCut, TString newCut) { // =========================================================================================================== void VarMaps::addTreeCuts(TString cutType, TCut aCut) { // ==================================================== - TString cutStr0 = treeCutsM[cutType]; cutStr0.ReplaceAll(" ",""); - TString cutStr1 = (TString)aCut; cutStr1.ReplaceAll(" ",""); + TString cutStr0 = treeCutsM[cutType]; cutStr0 = (regularizeStringForm(cutStr0)).ReplaceAll(" ",""); + TString cutStr1 = (TString)aCut; cutStr1 = (regularizeStringForm(cutStr1)).ReplaceAll(" ",""); - if(cutStr1 != "" && !cutStr0.Contains(cutStr1)) treeCutsM[cutType] = (TString)( (TCut)(treeCutsM[cutType]) + aCut ); + if(cutStr1 != "" && !cutStr0.Contains(cutStr1)) { + treeCutsM[cutType] = (TString)( (TCut)(treeCutsM[cutType]) + (TCut)(aCut) ); + } printCut(cutType); return; @@ -663,7 +679,7 @@ void VarMaps::addTreeCuts(TString cutType, TCut aCut) { // =========================================================================================================== void VarMaps::setTreeCuts(TString cutType, TCut aCut) { // ==================================================== - treeCutsM[cutType] = (TString)aCut; + treeCutsM[cutType] = regularizeStringForm((TString)aCut); printCut(cutType); return; } @@ -671,14 +687,18 @@ void VarMaps::setTreeCuts(TString cutType, TCut aCut) { void VarMaps::getTreeCutsM(map & aTreeCutsM) { // ========================================================== aTreeCutsM.clear(); - for(Map ::iterator itr=treeCutsM.begin(); itr!=treeCutsM.end(); ++itr) { aTreeCutsM[itr->first] = (TCut)itr->second; } + for(Map ::iterator itr=treeCutsM.begin(); itr!=treeCutsM.end(); ++itr) { + aTreeCutsM[itr->first] = (TCut)itr->second; + } return; } // =========================================================================================================== void VarMaps::setTreeCutsM(map & aTreeCutsM) { // ========================================================== treeCutsM.clear(); - for(map ::iterator itr=aTreeCutsM.begin(); itr!=aTreeCutsM.end(); ++itr) { treeCutsM[itr->first] = (TString)itr->second; } + for(map ::iterator itr=aTreeCutsM.begin(); itr!=aTreeCutsM.end(); ++itr) { + treeCutsM[itr->first] = regularizeStringForm((TString)itr->second); + } return; } // =========================================================================================================== @@ -719,7 +739,7 @@ bool VarMaps::hasFailedTreeCuts(TString cutType) { if(!areCutsEnabled) return false; vector cutsV = utils->splitStringByChar(cutType,';'); - + for(int nCutTypeNow=0; nCutTypeNow<(int)cutsV.size(); nCutTypeNow++) { TString cutTypeNow = cutsV[nCutTypeNow]; @@ -1279,13 +1299,14 @@ void VarMaps::setTreeForms(bool isFirstEntry) { if(isFirstEntry) { for(Map ::iterator itr=nameMap->begin(); itr!=nameMap->end(); ++itr) { - TString treeForm = itr->second; treeForm.ReplaceAll(" ",""); + TString treeForm = (regularizeStringForm(itr->second)).ReplaceAll(" ",""); TString treeFormName = utils->regularizeName( (TString)treeRead->GetName()+"_"+itr->first ); - TCut aCut = (TCut)((treeForm == "") ? "1" : treeForm); - //cout <<"setTreeForms "<find(itr->first) != formMap->end()) DELNULL((*formMap)[itr->first]); - (*formMap)[itr->first] = new TTreeFormula(treeFormName,aCut,treeRead); + (*formMap)[itr->first] = new TTreeFormula(treeFormName,(TCut)aCut,treeRead); VERIFY(LOCATION,(TString)"TTreeFormula is not valid (\""+(TString)aCut+"\") ...",((*formMap)[itr->first]->GetNdim() != 0)); } @@ -1411,7 +1432,7 @@ void VarMaps::storeTreeToAscii(TString outFilePrefix, TString outFileDir, int ma IncCntr("nObj"); if(GetCntr("nObj") == maxNobj) break; - if(!dynamic_cast(fout) || (nLinesFile > 0 && (GetCntr("nObj") % nLinesFile == 0))) { + if(!dynamic_cast(fout) || (dynamic_cast(fout) && (nLinesFile > 0) && ((GetCntr("nObj")-1) % nLinesFile == 0))) { nOutFileNow += 1; outFileName = (TString)outFileDir+outFilePrefix+"_"+TString::Format("%4.4d",nOutFileNow-1)+csvPostfix; diff --git a/src/myANNZ.cpp b/src/myANNZ.cpp index 93a1da3..ba4f5aa 100644 --- a/src/myANNZ.cpp +++ b/src/myANNZ.cpp @@ -167,14 +167,14 @@ myANNZ::myANNZ() { glob->NewOptC("inpFiles_bck",""); // nSplit - how to split into training/testing(/validation) - - // nSplit = [2,3] -> split into 2 (training,testing) or 3 (training,testing,validting) sub-sets - for trainig/optimization - // nSplit = 1 -> no splitting - for evaluation + // nSplit = 2 -> split into 2 (training,testing) sub-sets - for trainig/optimization + // nSplit = 1 -> no splitting - for evaluation // ----------------------------------------------------------------------------------------------------------- - glob->NewOptI("nSplit" ,3); + glob->NewOptI("nSplit" ,2); glob->NewOptC("splitType" ,"byInFiles"); // [serial,blocks,random,byInFiles] - methods for splitting the input dataset glob->NewOptC("splitTypeTrain",""); // in case of seperate input files - this is the list of training input files glob->NewOptC("splitTypeTest" ,""); // in case of seperate input files - this is the list of testing input files - glob->NewOptC("splitTypeValid",""); // in case of seperate input files - this is the list of validating input files + glob->NewOptC("splitTypeValid",""); // deprecated (kept for backward compatibility only) glob->NewOptI("splitSeed" ,19888687); // seed for random number generator for one of the splitting methods glob->NewOptC("inputVariables",""); // list of input variables as they appear in the input ascii files glob->NewOptC("inputVarErrors",""); // (optional) list of input variable errors @@ -194,8 +194,19 @@ myANNZ::myANNZ() { glob->NewOptF("sampleFracInp_wgtKNN" ,1); // fraction of the input sample to use for the kd-tree (positive number, smaller or equal to 1) glob->NewOptF("sampleFracRef_wgtKNN" ,1); // fraction of the input sample to use for the kd-tree (positive number, smaller or equal to 1) glob->NewOptB("doWidthRescale_wgtKNN",true); // transform the input parameters used for the kd-tree to the range [-1,1] - - // input files (given by splitTypeTrain, splitTypeTest, splitTypeValid and inAsciiFiles) may also be root files, containing + // number of KNN modules to use for hierarchical searches (may limit if consumes too much memory, but must be >= 2 + glob->NewOptI("nKnnFracs_wgtKNN" ,10); + // factor to decrease fraction of accepted objects for each KNN module - e.g., for module 1 all objects are + // in, for module 2, 1/knnFracFact_wgtKNN are in, for module 3 1/(knnFracFact_wgtKNN*knnFracFact_wgtKNN) are in ... + glob->NewOptI("knnFracFact_wgtKNN" ,3); + // by default, the weights are computed for the entire sample. That is, the training and the testing samples + // are used together - we calculate the difference between the distribution of input-variables between [train+test samples] + // and [ref sample]. However, it is possible to decide to comput the weights for each separately. That is, to calculate + // wegiths for [train sample] with regards to [ref sample], and to separately get [test sample] with regards to [ref sample]. The + // latter is only recommended if the training and testing samples have different inpput-variable distributions. + glob->NewOptB("trainTestTogether_wgtKNN",true); + + // input files (given by splitTypeTrain, splitTypeTest and inAsciiFiles) may also be root files, containing // root trees, instead of ascii files. In this case, the name of the tree in the input files is defined in inTreeName glob->NewOptC("inTreeName" ,""); // if root input is given in inAsciiFiles_wgtKNN, the corresponding tree name is defined in treeName_wgtKNN @@ -210,12 +221,13 @@ myANNZ::myANNZ() { // The calculation is performed using a KNN approach, similar to the algorithm used for the "useWgtKNN" calculation. // minNobjInVol_inTrain - The number of reference objects in the reference dataset which are used in the calculation. // maxRelRatioInRef_inTrain - A number in the range, [0,1] - The minimal threshold of the relative difference between distances - // in the inTrainFlag calculation for accepting an object. + // in the inTrainFlag calculation for accepting an object. If set to a negative value, then the value of the + // output parameter will be distributed within the range [0,1]. // ...._inTrain - The rest of the parameters ending with "_inTrain" have a similar role as their "_wgtKNN" counterparts // ----------------------------------------------------------------------------------------------------------- glob->NewOptB("addInTrainFlag" ,false); glob->NewOptI("minNobjInVol_inTrain" ,100); - glob->NewOptF("maxRelRatioInRef_inTrain",0.1); + glob->NewOptF("maxRelRatioInRef_inTrain",-1); glob->NewOptC("weightVarNames_inTrain" ,""); // list of input variables for KNN in/out computation glob->NewOptC("outAsciiVars_inTrain" ,""); // list of output variables to be written to the ascii output of the KNN in/out computation glob->NewOptC("weightInp_inTrain" ,""); // weight expression for input kd-tree (function of the variables used in weightVarNames_inTrain) @@ -365,7 +377,7 @@ myANNZ::myANNZ() { glob->NewOptF("excludeRangePdfModelFit",0.1); // exclude margin for fitting cumulative dist as part of PDF optimization // optimCondReg - - // ["bias", "sig68" or "fracSig68"] - used for deciding how to rank MLM performance. the named criteria represents + // ["sig68" or "bias"] - used for deciding how to rank MLM performance. the named criteria represents // the metric which is more significant in deciding which MLM performs "best". // ----------------------------------------------------------------------------------------------------------- glob->NewOptC("optimCondReg","sig68"); @@ -419,6 +431,10 @@ myANNZ::myANNZ() { glob->NewOptI("minAcptMLMsForPDFs",5); // wether or not to perform a bias-correction on PDFs glob->NewOptB("doBiasCorPDF" ,true); + // number of random smearing to perform for the PDF bias-correction + glob->NewOptI("nSmearUnf" ,100); + // add calculation of maximum of PDF to output + glob->NewOptB("addMaxPDF" ,false); // if max_sigma68_PDF,max_bias_PDF are positive, they put thresholds on the maximal value of the // scatter/bias/outlier-fraction of an MLM which may be included in the PDF created in randomized regression @@ -550,25 +566,26 @@ void myANNZ::Init() { // ----------------------------------------------------------------------------------------------------------- // names of index parameters in trees // ----------------------------------------------------------------------------------------------------------- - glob->NewOptC("baseName_ANNZ" ,"ANNZ_"); // base tag for all MLM names - glob->NewOptC("baseName_inVarErr" ,"ANNZ_inVarErr_"); // base tag for all PDF names - glob->NewOptC("baseName_nPDF" ,"ANNZ_PDF_"); // base tag for all PDF names - glob->NewOptC("baseName_wgtKNN" ,"ANNZ_KNN_w"); // KNN weight variable - glob->NewOptC("treeName" ,"ANNZ_tree"); // internal name prefix for input trees - glob->NewOptC("hisName" ,"ANNZ_his"); // internal name prefix for histograms - glob->NewOptC("indexName" ,"ANNZ_index"); // original index from input file - glob->NewOptC("splitName" ,"ANNZ_split"); // continous index for a given sub-sample (training,testing,validting) - glob->NewOptC("origFileName" ,"ANNZ_inFile"); // name of original source file - glob->NewOptC("testValidType" ,"ANNZ_tvType"); // index to keep track of testing/validation sub-sample withn the _valid trees - glob->NewOptC("baseName_regBest" ,"ANNZ_best"); // the "best"-performing MLM in randomized regression - glob->NewOptC("baseName_regMLM_avg" ,"ANNZ_MLM_avg_"); // base-name for the average MLM solution (and its error) in randomized regression - glob->NewOptC("baseName_regPDF_max" ,"ANNZ_PDF_max_"); // base-name for the peak of the pdf solution (and its error) in randomized regression - glob->NewOptC("baseName_regPDF_avg" ,"ANNZ_PDF_avg_"); // base-name for the average pdf solution (and its error) in randomized regression + TString basePrefix("ANNZ_"); + glob->NewOptC("basePrefix" ,basePrefix); // base tag for all MLM names + glob->NewOptC("baseName_inVarErr" ,basePrefix+"inVarErr_"); // base tag for all PDF names + glob->NewOptC("baseName_nPDF" ,basePrefix+"PDF_"); // base tag for all PDF names + glob->NewOptC("baseName_wgtKNN" ,basePrefix+"KNN_w"); // KNN weight variable + glob->NewOptC("treeName" ,basePrefix+"tree"); // internal name prefix for input trees + glob->NewOptC("hisName" ,basePrefix+"his"); // internal name prefix for histograms + glob->NewOptC("indexName" ,basePrefix+"index"); // original index from input file + glob->NewOptC("origFileName" ,basePrefix+"inFile"); // name of original source file + glob->NewOptC("baseName_regBest" ,basePrefix+"best"); // the "best"-performing MLM in randomized regression + glob->NewOptC("baseName_regMLM_avg" ,basePrefix+"MLM_avg_"); // base-name for the average MLM solution (and its error) in randomized regression + glob->NewOptC("baseName_regPDF_max" ,basePrefix+"PDF_max_"); // base-name for the peak of the pdf solution (and its error) in randomized regression + glob->NewOptC("baseName_regPDF_avg" ,basePrefix+"PDF_avg_"); // base-name for the average pdf solution (and its error) in randomized regression glob->NewOptC("baseName_knnErr" ,"_knnErr"); // name-postfix of error variable derived with the KNN estimator // name of an optional output parameter in evaluation (is it "safe" to use the result for an evaluated object) glob->NewOptC("baseName_inTrain" ,"inTrainFlag"); // optional parameter to mark if an object is of type signal (1), background (0) or undefined (-1), based on the name of the original input file glob->NewOptC("sigBckInpName","sigBckInp"); + // glob->NewOptC("splitName" ,basePrefix+"split"); // deprecated + // glob->NewOptC("testValidType" ,basePrefix+"tvType"); // deprecated // ----------------------------------------------------------------------------------------------------------- // working directory names @@ -628,7 +645,7 @@ void myANNZ::Init() { if(glob->GetOptB("doTrain")) { int nMLMnow = glob->GetOptB("doBinnedCls") ? glob->GetOptI("nBinNow") : glob->GetOptI("nMLMnow"); - glob->SetOptC("trainDirName",(TString)glob->GetOptC("trainDirName")+glob->GetOptC("baseName_ANNZ")+TString::Format("%d",nMLMnow)+"/"); + glob->SetOptC("trainDirName",(TString)glob->GetOptC("trainDirName")+glob->GetOptC("basePrefix")+TString::Format("%d",nMLMnow)+"/"); } glob->NewOptC("outDirNamePath", (TString)glob->baseOutDirName() +glob->GetOptC("outDirName")); @@ -701,8 +718,38 @@ void myANNZ::Init() { +" without setting corresponding input files in \"inAsciiFiles_wgtKNN\"",(glob->GetOptC("inAsciiFiles_wgtKNN") != "")); } + // check the tadaset division (full, or split into training/testing). make sure that the + // deprecated split to 3 samples (training/testing/validation) is not requested by mistake + // ----------------------------------------------------------------------------------------------------------- int nSplit = glob->GetOptI("nSplit"); - VERIFY(LOCATION,(TString)"Currently, only [\"nSplit\" = 2 or 3] is supported ...",(nSplit >= 1 && nSplit <= 3)); + VERIFY(LOCATION,(TString)"Currently, only [\"nSplit\" = 1 or 2] is supported ...",(nSplit == 1 || nSplit == 2)); + + // for backward compatibility, make sure that splitTypeValid and splitTypeTest are not noth set. + // If splitTypeValid is set instead of splitTypeTest, give a warning set splitTypeTest and reset splitTypeValid. + TString sptTrn(glob->GetOptC("splitTypeTrain")), sptTst(glob->GetOptC("splitTypeTest")), sptVld(glob->GetOptC("splitTypeValid")); + if(sptTrn != "" && sptVld != "") { + VERIFY(LOCATION,(TString)"Got [\"splitTypeTrain\"="+sptTrn+"], [\"splitTypeTest\"="+sptTst+"], [\"splitTypeValid\"="+sptVld+"] ..." + +" please set only \"splitTypeTest\", as \"splitTypeValid\" is deprecated.",(sptTst == "")); + + aLOG(Log::WARNING) <SetOptC("splitTypeTest",sptVld); glob->SetOptC("splitTypeValid",""); + } + + TString sptTyp(glob->GetOptC("splitType")), inFiles(glob->GetOptC("inAsciiFiles")); + if(glob->GetOptB("doGenInputTrees") && sptTyp != "byInFiles") { + if(inFiles == "") { + VERIFY(LOCATION,(TString)"Got [\"splitType\"="+sptTyp+"] but \"inAsciiFiles\" is not set ...",false); + } + else if(sptTrn != "" || sptTst != "" || sptVld != "") { + aLOG(Log::WARNING) <GetOptI("initSeedRnd") < 0) glob->SetOptI("initSeedRnd",0); if(glob->GetOptI("maxNobj") < 0) glob->SetOptI("maxNobj" ,0);