sklearn PR
The sklearn_splitter branch of the repository is an under-construction implementation of SPORF and MORF that builds on sklearn's existing code base. As it stands, the current implementation of SPORF and MORF in the main branch is slow, likely because most of the code is written in Python rather than Cython. Furthermore, we hope to eventually PR SPORF into scikit-learn to allow others to build oblique forests. Therefore, the next step for this repository is to leverage the existing code in sklearn to build SPORF and MORF in a highly optimized manner. The changes required to modify the sklearn codebase to implement SPORF are listed below.
- Cython's documentation is a great resource for beginners learning Cython.
- Sklearn's implementation of decision trees is written almost entirely in Cython, and more specifically in a no-gil context. To learn more about Python's GIL, visit this link. This implementation relies on parallelism and C-level speed to build forests quickly, but it is difficult to work with. No Python objects may be used in no-gil blocks, which means no numpy or Python libraries! Arrays are all C-style, and need to be malloc'd and freed. If you hit any "no-gil" errors at compile time, it generally means you are trying to use a Python object in a no-gil block, where it is not allowed.
Sklearn's splitter class must be modified in order to perform oblique splits. These modifications live in _oblique_splitter.pyx.
- Modify the SplitRecord struct in _oblique_splitter.pxd to hold a projection vector (DTYPE_t*). The projection vector is the best linear combination of features to split on, and it needs to be returned to the tree.
- Add a sample_proj_mat function to the splitter (and modify the API in the pxd file). This function will sample the projection matrix for any tree that applies a transformation to the data at each node, such as SPORF, MORF, etc. This function needs to be written in a no-gil context, so no external Python libraries can be used.
- Modify the init and node_reset functions to initialize and reset the projection matrix for each split. This will need to be done using C-style arrays (DTYPE_t**) that are malloc'd and freed, as the rest of the splitter functions run in a no-gil context.
- Modify the node_split function to perform an oblique split. This is done by sampling the projection matrix by calling sample_proj_mat, and then searching for the projection vector to split on. This can reuse much of the code in sklearn's BestSplitter node_split function.
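The splitter steps above can be sketched in plain Python to make the control flow easier to follow before porting it to no-gil Cython. This is only an illustration under stated assumptions: the function signatures, the `density` parameter, the choice of weights in {-1, 0, +1}, and the Gini criterion are all hypothetical and are not taken from the branch. The real code must use malloc'd C arrays instead of Python lists.

```python
import random


def sample_proj_mat(n_features, n_proj, density, rng):
    """Sample n_proj sparse projection vectors (assumed weights: -1, 0, +1)."""
    proj_mat = []
    for _ in range(n_proj):
        vec = [0.0] * n_features
        for j in range(n_features):
            if rng.random() < density:
                vec[j] = rng.choice((-1.0, 1.0))
        # Guarantee at least one nonzero weight so the projection is usable.
        if not any(vec):
            vec[rng.randrange(n_features)] = rng.choice((-1.0, 1.0))
        proj_mat.append(vec)
    return proj_mat


def gini(labels):
    """Gini impurity of a list of class labels."""
    n = len(labels)
    if n == 0:
        return 0.0
    counts = {}
    for label in labels:
        counts[label] = counts.get(label, 0) + 1
    return 1.0 - sum((c / n) ** 2 for c in counts.values())


def best_oblique_split(X, y, proj_mat):
    """Return (impurity, proj_vec, threshold) for the best split, or None."""
    best = None
    n = len(y)
    for vec in proj_mat:
        # Project every sample onto the candidate vector.
        projected = [sum(w * v for w, v in zip(vec, row)) for row in X]
        order = sorted(range(n), key=lambda i: projected[i])
        for k in range(1, n):
            lo, hi = projected[order[k - 1]], projected[order[k]]
            if lo == hi:
                continue  # no threshold can separate equal projected values
            left = [y[i] for i in order[:k]]
            right = [y[i] for i in order[k:]]
            impurity = (len(left) * gini(left) + len(right) * gini(right)) / n
            if best is None or impurity < best[0]:
                best = (impurity, vec, (lo + hi) / 2.0)
    return best
```

The returned projection vector plays the role of the extra field in the split record: it is what the tree needs in order to repeat the same projection at prediction time.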
Currently, an implementation of SPORF's splitter exists that builds successfully. Unit tests need to be added to make sure it works correctly.
Sklearn's builder class must be modified to build an oblique tree. Fortunately, most of the existing code can be leveraged to do this.
- Inherit from the TreeBuilder class and follow DepthFirstTreeBuilder's code structure. The only differences are the use of an ObliqueSplitter, an ObliqueTree, and an ObliqueSplitRecord. The add_node function needs an extra parameter for the projection vector.
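As a rough picture of the builder's structure, the sketch below grows a tree depth-first with an explicit stack and passes the projection vector to `add_node`. It is a plain-Python illustration, not the Cython builder: `build_depth_first`, `split_fn`, and the dictionary-based node layout are hypothetical names introduced here, and `split_fn` stands in for the oblique splitter by returning either `(proj_vec, threshold)` or `None`. It assumes a non-empty training set.

```python
def dot(w, x):
    """Inner product of a projection vector and a sample."""
    return sum(wi * xi for wi, xi in zip(w, x))


def build_depth_first(X, y, split_fn, max_depth=3):
    """Grow a tree with an explicit stack; returns a list of node dicts."""
    nodes = []

    def add_node(parent, is_left, proj_vec, threshold, value):
        # The extra proj_vec argument is the only change from an
        # axis-aligned add_node in this sketch.
        node_id = len(nodes)
        nodes.append({"proj_vec": proj_vec, "threshold": threshold,
                      "left": -1, "right": -1, "value": value})
        if parent >= 0:
            nodes[parent]["left" if is_left else "right"] = node_id
        return node_id

    # Each stack entry: (sample indices, depth, parent id, is_left flag).
    stack = [(list(range(len(y))), 0, -1, False)]
    while stack:
        idx, depth, parent, is_left = stack.pop()
        labels = [y[i] for i in idx]
        value = max(sorted(set(labels)), key=labels.count)  # majority class
        split = None
        if depth < max_depth and len(set(labels)) > 1:
            split = split_fn([X[i] for i in idx], labels)
        left, right = [], []
        if split is not None:
            proj_vec, threshold = split
            left = [i for i in idx if dot(proj_vec, X[i]) <= threshold]
            right = [i for i in idx if dot(proj_vec, X[i]) > threshold]
        if not left or not right:
            add_node(parent, is_left, None, None, value)  # leaf node
            continue
        node_id = add_node(parent, is_left, list(proj_vec), threshold, value)
        stack.append((right, depth + 1, node_id, False))
        stack.append((left, depth + 1, node_id, True))
    return nodes
```

For example, `build_depth_first(X, y, lambda Xs, ys: ([1.0, -1.0], 0.5))` splits every impure node on the projection `x[0] - x[1]` at 0.5.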
An implementation that builds exists, but it needs to be tested.
Sklearn's base tree must be modified to make it oblique. Most of the code can again be reused, but changes are needed in the add_node and predict functions. Memory management is also a major issue to deal with.
- The oblique tree will need to store a projection vector for each node. Therefore, memory must be dynamically allocated for a 2D DTYPE_t array (indexed by node id) to store the projection vectors. This means that the resize function needs to be modified to allocate memory, and the dealloc function needs to be modified to free memory.
- The add_node function needs to be modified to copy over the contents of the projection vector from the ObliqueSplitRecord to the new array.
- The predict function needs to be modified to apply the appropriate projection vector at each node to the data.
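The three tree changes above can be sketched together in plain Python. This is a minimal illustration assuming an array-backed layout similar in spirit to sklearn's tree; the class name `ObliqueTreeSketch`, its attributes, and the `predict_one` helper are hypothetical. Python lists stand in for the malloc'd 2D DTYPE_t array, so the matching free in the dealloc function has no counterpart here.

```python
class ObliqueTreeSketch:
    """Array-backed tree that stores one projection vector per node."""

    def __init__(self, n_features, capacity=1):
        self.n_features = n_features
        self.node_count = 0
        self.capacity = 0
        self.left = []
        self.right = []
        self.threshold = []
        self.value = []
        self.proj_vecs = []  # row i holds the projection vector of node i
        self.resize(capacity)

    def resize(self, capacity):
        """Grow every per-node array, including the 2D projection storage."""
        extra = capacity - self.capacity
        if extra <= 0:
            return
        self.left.extend([-1] * extra)
        self.right.extend([-1] * extra)
        self.threshold.extend([0.0] * extra)
        self.value.extend([None] * extra)
        self.proj_vecs.extend([[0.0] * self.n_features for _ in range(extra)])
        self.capacity = capacity

    def add_node(self, parent, is_left, is_leaf, proj_vec, threshold, value):
        """Append a node, copying proj_vec into storage owned by the tree."""
        if self.node_count >= self.capacity:
            self.resize(max(1, 2 * self.capacity))
        node_id = self.node_count
        if not is_leaf:
            # Copy element by element: the split record's buffer may be
            # reused or freed by the splitter after this call.
            for j in range(self.n_features):
                self.proj_vecs[node_id][j] = proj_vec[j]
            self.threshold[node_id] = threshold
        self.value[node_id] = value
        if parent >= 0:
            if is_left:
                self.left[parent] = node_id
            else:
                self.right[parent] = node_id
        self.node_count += 1
        return node_id

    def predict_one(self, x):
        """Walk from the root, applying each node's projection to x."""
        node = 0
        while self.left[node] != -1:
            proj = sum(w * v for w, v in zip(self.proj_vecs[node], x))
            if proj <= self.threshold[node]:
                node = self.left[node]
            else:
                node = self.right[node]
        return self.value[node]
```

Copying the vector in `add_node`, rather than storing a reference to the splitter's buffer, gives each allocation a single owner. In the C version, that kind of single ownership is one way to make double-free bugs easier to rule out.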
The above changes have been implemented, but they still need to be debugged. Memory issues such as double-frees do exist, so we need unit tests to ensure we are doing things correctly.