forked from handspeaker/RandomForests
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Tree.h
68 lines (66 loc) · 2.53 KB
/
Tree.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
/************************************************
*Random Forest Program
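*Function: implementation of CART (Classification and Regression Tree) without pruning.
 For classification, the Gini index is used as the split criterion:
   Gini(t) = 1 - \sum_{j=1}^{J} p(j|t)^2, assuming there are J classes
 For regression, the sum of squared residuals is used as the split criterion:
   I(t) = \sum_{i \in left} (y_i - \bar{y}_{left})^2 + \sum_{i \in right} (y_i - \bar{y}_{right})^2,
   where \bar{y}_{left} and \bar{y}_{right} are the mean targets of the left and right children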
*Author: [email protected]
*CreateTime: 2014.7.10
*Version: V0.1
*************************************************/
#ifndef CARTREE_H
#define CARTREE_H
#include<stdio.h>
#include<math.h>
#include"Sample.h"
#include"Node.h"
/*************************************************************
*Tree: abstract base class of one CART tree.
*The tree is stored as a flat array of Node pointers
*(_cartreeArray); subclasses implement training and node creation.
*The class is non-copyable: it holds a raw Node** array, and a
*compiler-generated copy would only duplicate the pointer.
**************************************************************/
class Tree
{
public:
    /*************************************************************
    *MaxDepth:the max depth of one single tree
    *trainFeatureNumPerNode:the number of features used at every node while training
    *minLeafSample:terminate criterion,the min number of samples in a leaf
    *minInfoGain:terminate criterion,the min information gain a node
    *must reach to be split
    *isRegression:whether the problem is regression(true) or classification(false)
    **************************************************************/
    Tree(int MaxDepth,int trainFeatureNumPerNode,int minLeafSample,float minInfoGain,bool isRegression);
    // Virtual so that deleting a ClasTree/RegrTree through a Tree* is well-defined.
    virtual ~Tree();
    // Grows the tree from the given training samples; implemented by subclasses.
    virtual void train(Sample*sample)=0;
    // Returns this tree's prediction for one feature vector.
    // NOTE(review): the required length of `data` is not visible here —
    // presumably the training feature count; confirm in Tree.cpp.
    Result predict(float*data);
    // Returns the internal node array itself (no copy is made).
    // NOTE(review): presumably still owned by the tree — callers should not
    // delete it; confirm against ~Tree() in Tree.cpp.
    Node**getTreeArray(){return _cartreeArray;}
    // Creates a split node at array index `id`; implemented by subclasses.
    virtual void createNode(int id,int featureIndex,float threshold)=0;
protected:
    bool _isRegression;            //the type of this tree
    int _MaxDepth;
    int _nodeNum;                  //the number of nodes,=2^_MaxDepth-1
    int _minLeafSample;
    int _trainFeatureNumPerNode;
    float _minInfoGain;
    // Sample*_samples;//all samples used while training the tree
    Node** _cartreeArray;          //utilize a node array to store the tree,
                                   //every node is a split or leaf node
    // NOTE(review): names beginning with '_' + uppercase (e.g. _MaxDepth) are
    // reserved in C++; renaming requires touching Tree.cpp as well.
private:
    // Non-copyable (C++98 idiom: declared private and intentionally never
    // defined). A memberwise copy would share _cartreeArray between two trees.
    Tree(const Tree&);
    Tree& operator=(const Tree&);
};
// ClasTree: CART tree for classification (Gini-index split criterion).
class ClasTree:public Tree
{
public:
    // NOTE(review): isRegression is presumably always false for this class —
    // confirm at the construction sites.
    ClasTree(int MaxDepth,int trainFeatureNumPerNode,int minLeafSample,float minInfoGain,bool isRegression);
    virtual ~ClasTree();

    // Overrides Tree::train: grows the classification tree from `sample`.
    virtual void train(Sample*sample);
    // Overrides Tree::createNode: creates a split node at array index `id`.
    virtual void createNode(int id,int featureIndex,float threshold);

    // Creates a leaf at array index `id`; presumably `clas` is the predicted
    // class label and `prob` its probability — confirm in Tree.cpp.
    void createLeaf(int id,float clas,float prob);
};
// RegrTree: CART tree for regression (sum-of-squared-residuals split criterion).
class RegrTree:public Tree
{
public:
    // NOTE(review): isRegression is presumably always true for this class —
    // confirm at the construction sites.
    RegrTree(int MaxDepth,int trainFeatureNumPerNode,int minLeafSample,float minInfoGain,bool isRegression);
    virtual ~RegrTree();

    // Overrides Tree::train: grows the regression tree from `sample`.
    virtual void train(Sample*sample);
    // Overrides Tree::createNode: creates a split node at array index `id`.
    virtual void createNode(int id,int featureIndex,float threshold);

    // Creates a leaf at array index `id`; presumably `value` is the
    // regression output stored in that leaf — confirm in Tree.cpp.
    void createLeaf(int id,float value);
};
#endif//CARTREE_H