diff --git a/docs/source/notebooks/DiCE_with_advanced_options.ipynb b/docs/source/notebooks/DiCE_with_advanced_options.ipynb index a2db0632..f27d3b68 100644 --- a/docs/source/notebooks/DiCE_with_advanced_options.ipynb +++ b/docs/source/notebooks/DiCE_with_advanced_options.ipynb @@ -60,7 +60,9 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "scrolled": false + }, "outputs": [], "source": [ "dataset = helpers.load_adult_income_dataset()\n", @@ -112,7 +114,7 @@ "# provide the trained ML model to DiCE's model object\n", "ML_modelpath = helpers.get_adult_income_modelpath(backend=backend)\n", "# Step 2: dice_ml.Model\n", - "m = dice_ml.Model(model_path=ML_modelpath, backend=backend)" + "m = dice_ml.Model(model_path=ML_modelpath, backend=backend,func=\"ohe-min-max\" )" ] }, { @@ -139,14 +141,16 @@ "outputs": [], "source": [ "# query instance in the form of a dictionary; keys: feature name, values: feature value\n", - "query_instance = {'age': 22,\n", + "query_instance_dict = {'age': 22,\n", " 'workclass': 'Private',\n", " 'education': 'HS-grad',\n", " 'marital_status': 'Single',\n", " 'occupation': 'Service',\n", " 'race': 'White',\n", " 'gender': 'Female',\n", - " 'hours_per_week': 45}" + " 'hours_per_week': 45}\n", + "import pandas as pd\n", + "query_instance = pd.DataFrame([query_instance_dict])" ] }, { @@ -232,6 +236,16 @@ "feature_weights = {'age': 1, 'hours_per_week': 1}" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# initiate gradient-based DiCE for user-defined feature_weight\n", + "exp = dice_ml.Dice(d, m, method=\"gradient\")" + ] + }, { "cell_type": "code", "execution_count": null, @@ -314,11 +328,8 @@ } ], "metadata": { - "nbsphinx": { - "execute": "never" - }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -332,7 +343,10 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.12" + "version": "3.8.18" + }, + "nbsphinx": { + "execute": "never" }, "toc": { "base_numbering": 1,