diff --git a/colabs/scikit/wandb_decision_tree.ipynb b/colabs/scikit/wandb_decision_tree.ipynb new file mode 100644 index 00000000..d2d34873 --- /dev/null +++ b/colabs/scikit/wandb_decision_tree.ipynb @@ -0,0 +1,247 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "e6qvOYJV8BQw" + }, + "source": [ + "## Author: [@SauravMaheshkar](https://twitter.com/MaheshkarSaurav)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BL5XadIr0MbM" + }, + "source": [ + "# Packages 📦 and Basic Setup\n", + "---" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fZH0PUfw0OBV" + }, + "source": [ + "## Install Packages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_uk-bdosSzwk" + }, + "outputs": [], + "source": [ + "%%capture\n", + "## Install Sklearn\n", + "!pip install -U scikit-learn\n", + "## Install the latest version of wandb client 🔥🔥\n", + "!pip install -q --upgrade wandb" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UgQLiLtz0STo" + }, + "source": [ + "## Project Configuration using **`wandb.config`**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OvABqBlB0VhB" + }, + "outputs": [], + "source": [ + "import wandb\n", + "\n", + "## Importing Libraries\n", + "from sklearn.datasets import load_iris\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "wandb.login()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize the run\n", + "run = wandb.init(project='simple-scikit')\n", + "\n", + "# Feel free to change these and experiment !!\n", + "config = wandb.config\n", + "config.max_depth = 5\n", + "config.min_samples_split = 2\n", + "config.clf_criterion = \"gini\"\n", + "config.reg_criterion = \"mse\"\n", + "config.splitter = \"best\"\n", + "config.dataset = \"iris\"\n", + "config.test_size = 0.2\n", + "config.random_state = 42\n", + "config.labels =['setosa', 'versicolor', 'virginica']\n", + "\n", + "# Update the config\n", + "wandb.config.update(config)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OGV1c4VYJMGa" + }, + "source": [ + "# 💿 Dataset\n", + "---" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "lWnhwnPI2NWD" + }, + "outputs": [], + "source": [ + "## Loading the Dataset\n", + "iris = load_iris(return_X_y = True, as_frame= True)\n", + "dataset = iris[0]\n", + "target = iris[1]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kycxKC7ER7lW" + }, + "source": [ + "# ✍️ Model Architecture\n", + "---" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4wWuB4NZl3h_" + }, + "source": [ + "## Classification" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "vElZpyEQtXzX" + }, + "outputs": [], + "source": [ + "X, y = load_iris(return_X_y=True)\n", + "x_train, x_test, y_train, y_test = train_test_split(X, y, test_size = config.test_size, random_state = config.random_state)\n", + "\n", + "clf = DecisionTreeClassifier(\n", + " max_depth=config.max_depth,\n", + " min_samples_split=config.min_samples_split,\n", + " criterion=config.clf_criterion,\n", + " splitter=config.splitter\n", + ")\n", + "clf = clf.fit(x_train,y_train)\n", + "\n", + "y_pred = clf.predict(x_test)\n", + "\n", + "# Visualize Confustion Matrix\n", + "wandb.sklearn.plot_confusion_matrix(y_test, y_pred, config.labels)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BX31c0s8MVu5" + }, + "source": [ + "## Regression" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1kBLTtH4MbsY" + }, + "outputs": [], + "source": [ + "X, y = load_iris(return_X_y=True)\n", + "\n", + "x_train, x_test, y_train, y_test = train_test_split(X, y, test_size = config.test_size, random_state = config.random_state)\n", + "\n", + "reg = DecisionTreeRegressor(\n", + " max_depth=config.max_depth,\n", + " min_samples_split=config.min_samples_split,\n", + " criterion=config.reg_criterion,\n", + " splitter=config.splitter\n", + ")\n", + "\n", + "reg = reg.fit(x_train,y_train)\n", + "\n", + "# All regression plots\n", + "wandb.sklearn.plot_regressor(reg, x_train, x_test, y_train, y_test, model_name='DecisionTreeRegressor')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Finish the W&B Process\n", + "wandb.finish()" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [ + "bE1OZxeIwPqg", + "_N6qEF3cvIpN", + "_vdIezY_zb_9", + "OGV1c4VYJMGa", + "H8IIKXd0PFL8", + "kycxKC7ER7lW", + "4wWuB4NZl3h_", + "BX31c0s8MVu5", + "c6PoNscLNuog", + "3Cf152r2NK1M", + "a_HO0jHMyyyP", + "ahE27LgSzUVx", + "FaENg6O44dso", + "t3K1L5UWSZCL" + ], + "name": "Decision Trees", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}