diff --git a/RootInteractive/MLpipeline/MIForestErrPDF.py b/RootInteractive/MLpipeline/MIForestErrPDF.py index 57edcb8b..615a586e 100644 --- a/RootInteractive/MLpipeline/MIForestErrPDF.py +++ b/RootInteractive/MLpipeline/MIForestErrPDF.py @@ -37,6 +37,9 @@ def _accumulate_predictionNL(predict, X, out,col): prediction = predict(X, check_input=False) out[col] += prediction +def partitionBlock(allRF, k, begin, end): + allRF[begin:end].partition(k) + def predictRFStat(rf, X, statDictionary,n_jobs): """ inspired by https://github.com/scikit-learn/scikit-learn/blob/37ac6788c/sklearn/ensemble/_forest.py#L1410 @@ -47,25 +50,36 @@ def predictRFStat(rf, X, statDictionary,n_jobs): :param n_jobs: number of parallel jobs for prediction :return: dictionary with requested output statistics """ - allRF = np.zeros((len(rf.estimators_), X.shape[0])) + nEstimators = len(rf.estimators_) + allRF = np.zeros((nEstimators, X.shape[0])) lock = threading.Lock() statOut={} - Parallel(n_jobs=n_jobs, verbose=rf.verbose,**_joblib_parallel_args(require="sharedmem"),)( + Parallel(n_jobs=n_jobs, verbose=rf.verbose,require="sharedmem")( delayed(_accumulate_prediction)(e.predict, X, allRF, col,lock) for col,e in enumerate(rf.estimators_) ) # - if "median" in statDictionary: statOut["median"]=np.median(allRF, 0) - if "mean" in statDictionary: statOut["mean"]=np.mean(allRF, 0) - if "std" in statDictionary: statOut["std"]=np.std(allRF, 0) + allRFTranspose = allRF.T.copy(order='C') + if "median" in statDictionary: + blockSize = X.shape[0] // n_jobs + 1 + block_begin = list(range(0, X.shape[0], blockSize)) + block_end = block_begin[1:] + block_end.append(X.shape[0]) + Parallel(n_jobs=n_jobs, verbose=rf.verbose, require="sharedmem")( + delayed(partitionBlock)(allRFTranspose, nEstimators // 2, first, last) + for first, last in zip(block_begin, block_end) + ) + statOut["median"]= allRFTranspose[:,nEstimators//2] + if "mean" in statDictionary: statOut["mean"]=np.mean(allRFTranspose, -1) + if "std" in statDictionary: statOut["std"]=np.std(allRFTranspose, -1) if "quantile" in statDictionary: - statOut["quantiles"]={} + statOut["quantile"]={} for quant in statDictionary["quantile"]: - statOut["quantiles"][quant]=np.quantile(allRF,quant,axis=0) + statOut["quantile"][quant]=np.quantile(allRF,quant,axis=1) if "trim_mean" in statDictionary: statOut["trim_mean"]={} for quant in statDictionary["trim_mean"]: - statOut["trim_mean"][quant]=stats.trim_mean(allRF,quant,axis=0) + statOut["trim_mean"][quant]=stats.trim_mean(allRF,quant,axis=1) return statOut def predictRFStatNew(rf, X, statDictionary,n_jobs): """ @@ -371,4 +385,4 @@ def getImportance(self): impTree = np.zeros((len(self.trees[0]), len(self.trees[0][0]))) for row,tree in enumerate(self.trees[0]): impTree[row]=tree.feature_importances_ - return impTree.mean(axis=0) \ No newline at end of file + return impTree.mean(axis=0)