Skip to content

Commit

Permalink
feat: add decision trees example
Browse files Browse the repository at this point in the history
  • Loading branch information
SauravMaheshkar authored Dec 25, 2023
1 parent 86c5834 commit 82bce6e
Showing 1 changed file with 228 additions and 0 deletions.
228 changes: 228 additions & 0 deletions colabs/scikit/wandb_decision_tree.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Decision Trees",
"provenance": [],
"collapsed_sections": [
"bE1OZxeIwPqg",
"_N6qEF3cvIpN",
"_vdIezY_zb_9",
"OGV1c4VYJMGa",
"H8IIKXd0PFL8",
"kycxKC7ER7lW",
"4wWuB4NZl3h_",
"BX31c0s8MVu5",
"c6PoNscLNuog",
"3Cf152r2NK1M",
"a_HO0jHMyyyP",
"ahE27LgSzUVx",
"FaENg6O44dso",
"t3K1L5UWSZCL"
],
"toc_visible": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "e6qvOYJV8BQw"
},
"source": [
"## Author: [@SauravMaheshkar](https://twitter.com/MaheshkarSaurav)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BL5XadIr0MbM"
},
"source": [
"# Packages 📦 and Basic Setup\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fZH0PUfw0OBV"
},
"source": [
"## Install Packages"
]
},
{
"cell_type": "code",
"metadata": {
"id": "_uk-bdosSzwk"
},
"source": [
"%%capture\n",
"## Install Sklearn\n",
"!pip install -U scikit-learn\n",
"## Install the latest version of wandb client 🔥🔥\n",
"!pip install -q --upgrade wandb"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "UgQLiLtz0STo"
},
"source": [
"## Project Configuration using **`wandb.config`**"
]
},
{
"cell_type": "code",
"metadata": {
"id": "OvABqBlB0VhB"
},
"source": [
"import os\n",
"import wandb\n",
"\n",
"## Importing Libraries\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor\n",
"\n",
"# Paste your api key here\n",
"os.environ[\"WANDB_API_KEY\"] = '...'\n",
"\n",
"# Feel free to change these and experiment !!\n",
"config = wandb.config\n",
"config.max_depth = 5\n",
"config.min_samples_split = 2\n",
"config.clf_criterion = \"gini\"\n",
"config.reg_criterion = \"mse\"\n",
"config.splitter = \"best\"\n",
"config.dataset = \"iris\"\n",
"config.test_size = 0.2\n",
"config.random_state = 42\n",
"config.labels =['setosa', 'versicolor', 'virginica']"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "OGV1c4VYJMGa"
},
"source": [
"# 💿 Dataset\n",
"---"
]
},
{
"cell_type": "code",
"metadata": {
"id": "lWnhwnPI2NWD"
},
"source": [
"## Loading the Dataset\n",
"iris = load_iris(return_X_y = True, as_frame= True)\n",
"dataset = iris[0]\n",
"target = iris[1]"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "kycxKC7ER7lW"
},
"source": [
"# ✍️ Model Architecture\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4wWuB4NZl3h_"
},
"source": [
"## Classification"
]
},
{
"cell_type": "code",
"metadata": {
"id": "vElZpyEQtXzX"
},
"source": [
"run = wandb.init(project='...', entity='...', config = config)\n",
"\n",
"X, y = load_iris(return_X_y=True)\n",
"x_train, x_test, y_train, y_test = train_test_split(X, y, test_size = config.test_size, random_state = config.random_state)\n",
"\n",
"clf = DecisionTreeClassifier(\n",
" max_depth=config.max_depth,\n",
" min_samples_split=config.min_samples_split,\n",
" criterion=config.clf_criterion,\n",
" splitter=config.splitter\n",
")\n",
"clf = clf.fit(x_train,y_train)\n",
"\n",
"y_pred = clf.predict(x_test)\n",
"\n",
"# Visualize Confustion Matrix\n",
"wandb.sklearn.plot_confusion_matrix(y_test, y_pred, config.labels)\n",
"\n",
"run.finish()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "BX31c0s8MVu5"
},
"source": [
"## Regression"
]
},
{
"cell_type": "code",
"metadata": {
"id": "1kBLTtH4MbsY"
},
"source": [
"run = wandb.init(project='...', entity='...', config = config)\n",
"\n",
"X, y = load_iris(return_X_y=True)\n",
"\n",
"x_train, x_test, y_train, y_test = train_test_split(X, y, test_size = config.test_size, random_state = config.random_state)\n",
"\n",
"reg = DecisionTreeRegressor(\n",
" max_depth=config.max_depth,\n",
" min_samples_split=config.min_samples_split,\n",
" criterion=config.reg_criterion,\n",
" splitter=config.splitter\n",
")\n",
"\n",
"reg = reg.fit(x_train,y_train)\n",
"\n",
"# All regression plots\n",
"wandb.sklearn.plot_regressor(reg, x_train, x_test, y_train, y_test, model_name='DecisionTreeRegressor')\n",
"\n",
"run.finish()"
],
"execution_count": null,
"outputs": []
}
]
}

0 comments on commit 82bce6e

Please sign in to comment.