
Correctness of visualization #10829

Open
Marchlak opened this issue Sep 19, 2024 · 3 comments
@Marchlak

Hello,
I'm developing a library for decision tree visualization, https://github.com/mljar/supertree, and would appreciate feedback on whether my visualization approach for XGBoost is correct. Comparing my library with dtreeviz, it appears (based on my observations) that dtreeviz splits the histogram data at each node according to the feature from the root node. In contrast, my implementation splits the data according to the feature extracted from the respective node in booster.get_dump().
I would greatly appreciate guidance on the correct visualization approach for your library.
Code from my compare notebook:

import xgboost as xgb
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split



iris = load_iris()
X = iris.data  
y = iris.target  
features = iris.feature_names  
target = 'species'  
class_names = iris.target_names  

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)


xgb_classifier = xgb.XGBClassifier(
    objective='multi:softmax',  
    num_class=3,                
    max_depth=5,                
    learning_rate=0.3,          
    n_estimators=50,            
    random_state=42,             
)

xgb_classifier.fit(X_train, y_train)

from xgboost import plot_tree
import matplotlib.pyplot as plt

# plot_tree creates its own axes when none is passed in, so the
# plt.figure() call would be ignored; pass an axes explicitly instead
fig, ax = plt.subplots(figsize=(30, 20))
plot_tree(xgb_classifier, num_trees=20, ax=ax)
plt.show()

from dtreeviz import model

viz_model = model(
    xgb_classifier.get_booster(),
    X_train=X_train,
    y_train=y_train,
    feature_names=features,
    target_name=target,
    class_names=list(class_names),
    tree_index=5  
)

viz_model.view()

from supertree import SuperTree

st = SuperTree(
    xgb_classifier, 
    X_train, 
    y_train, 
    iris.feature_names, 
    iris.target_names
)
# Visualize the tree
st.show_tree(which_tree=2)
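A library-independent way to check which feature each node splits on is to parse the text dump returned by booster.get_dump(): each internal node line records its own split feature and threshold. A minimal sketch, using an invented dump string in that format (the node ids, features, and thresholds below are illustrative, not taken from the model above):

```python
import re

# Illustrative tree dump in the text format produced by
# booster.get_dump() -- values are made up for demonstration.
dump = """0:[f2<2.45] yes=1,no=2,missing=1
\t1:leaf=0.43
\t2:[f3<1.75] yes=3,no=4,missing=3
\t\t3:leaf=0.40
\t\t4:leaf=-0.21
"""

SPLIT_RE = re.compile(r"(\d+):\[(f\d+)<([-\d.eE+]+)\]")

def split_features(dump_text):
    """Map each internal node id to its own (feature, threshold)."""
    out = {}
    for line in dump_text.splitlines():
        m = SPLIT_RE.search(line)
        if m:
            node_id, feat, thr = m.groups()
            out[int(node_id)] = (feat, float(thr))
    return out

print(split_features(dump))
# Each node reports its *own* split feature, not the root's:
# {0: ('f2', 2.45), 2: ('f3', 1.75)}
```

Comparing such a mapping against what each visualization shows at every node makes it easy to see which library attaches the correct feature to which node.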
@trivialfis
Member

Hi, I haven't yet looked into the source code of either project.

the data in the histogram appears to be split at each node according to the feature from the root node

Could you elaborate on what this means? Is there a histogram when plotting a tree? And does this histogram contain data that can be split by a (node split) feature in the root node?

If you want to compare against xgboost's own plot_tree function, you can dump the tree in the DOT format and plot it with graphviz yourself.
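The DOT route suggested above can be sketched without any plotting dependency: a small helper can turn a text-format tree dump into DOT source, which the graphviz `dot` tool then renders. `dump_to_dot` is a hypothetical helper (not part of xgboost's API), and the dump string is an invented example in the format booster.get_dump() emits:

```python
import re

# Illustrative text dump (same format as booster.get_dump();
# node values are invented for the example).
dump = """0:[f2<2.45] yes=1,no=2,missing=1
\t1:leaf=0.43
\t2:[f3<1.75] yes=3,no=4,missing=3
\t\t3:leaf=0.40
\t\t4:leaf=-0.21
"""

def dump_to_dot(dump_text):
    """Convert a text-format tree dump into DOT source that the
    graphviz `dot` tool (or the graphviz Python package) can render."""
    nodes, edges = [], []
    for line in dump_text.splitlines():
        split = re.search(r"(\d+):\[(f\d+)<([-\d.eE+]+)\] yes=(\d+),no=(\d+)", line)
        leaf = re.search(r"(\d+):leaf=([-\d.eE+]+)", line)
        if split:
            nid, feat, thr, yes, no = split.groups()
            nodes.append(f'  {nid} [label="{feat} < {thr}"];')
            edges.append(f'  {nid} -> {yes} [label="yes"];')
            edges.append(f'  {nid} -> {no} [label="no"];')
        elif leaf:
            nid, val = leaf.groups()
            nodes.append(f'  {nid} [label="leaf={val}", shape=box];')
    return "digraph tree {\n" + "\n".join(nodes + edges) + "\n}"

dot = dump_to_dot(dump)
print(dot)
# Write `dot` to a file and render with: dot -Tpng tree.dot -o tree.png
```

Plotting both libraries' trees from the same dump this way gives a neutral baseline to compare against.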

@Marchlak
Author

Marchlak commented Sep 20, 2024

Please help elaborate on what this means. Is there a histogram when plotting a tree? And this histogram has some data in it, and this data can be split by a (node split) feature in the root node?
That's basically what I mean. Here are screenshots of the visualization comparison.
f0 - sepal length (cm), f1 - sepal width (cm), f2 - petal length (cm), f3 - petal width (cm)
graphviz: [screenshot]
dtreeviz: [screenshot]
My library supertree: [screenshot]
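For readability, the f-indices in the screenshots can be translated back to the Iris feature names using the mapping listed above; `rename` here is a hypothetical helper, and it assumes single-digit feature indices (which holds for Iris's four features):

```python
# f-index -> Iris feature name, per the mapping in the comment above.
IRIS_FEATURES = {
    "f0": "sepal length (cm)",
    "f1": "sepal width (cm)",
    "f2": "petal length (cm)",
    "f3": "petal width (cm)",
}

def rename(label: str) -> str:
    """Replace f0..f3 tokens in a node label with feature names.
    Naive substring replacement: only safe for single-digit indices."""
    for idx, name in IRIS_FEATURES.items():
        label = label.replace(idx, name)
    return label

print(rename("f2 < 2.45"))  # -> petal length (cm) < 2.45
```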

@trivialfis
Copy link
Member

Ah, that's much clearer! Thank you for sharing. I will look into it after sorting out some of the ongoing work here.
