From 931a7540bfcf19d9eb655ab2ee5ea9292447caf9 Mon Sep 17 00:00:00 2001 From: Qiu Xiaoqi Date: Wed, 20 Sep 2023 20:13:59 +0800 Subject: [PATCH 1/3] fix: fix the 'AttributeError: 'str' object has no attribute 'columns'' --- .../DiCE_with_advanced_options.ipynb | 178 ++++++++++++++++-- 1 file changed, 158 insertions(+), 20 deletions(-) diff --git a/docs/source/notebooks/DiCE_with_advanced_options.ipynb b/docs/source/notebooks/DiCE_with_advanced_options.ipynb index a2db0632..2e6ccd97 100644 --- a/docs/source/notebooks/DiCE_with_advanced_options.ipynb +++ b/docs/source/notebooks/DiCE_with_advanced_options.ipynb @@ -16,9 +16,18 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "D:\\xiaoqi\\anaconda3\\envs\\pytorch-py38\\lib\\site-packages\\dice_ml\\utils\\exception.py:12: UserWarning: UserConfigValidationException will be deprecated from dice_ml.utils. Please import UserConfigValidationException from raiutils.exceptions.\n", + " warnings.warn(\"UserConfigValidationException will be deprecated from dice_ml.utils. \"\n" + ] + } + ], "source": [ "from numpy.random import seed\n", "\n", @@ -35,7 +44,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -59,9 +68,127 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ageworkclasseducationmarital_statusoccupationracegenderhours_per_weekincome
028PrivateBachelorsSingleWhite-CollarWhiteFemale600
130Self-EmployedAssocMarriedProfessionalWhiteMale651
232PrivateSome-collegeMarriedWhite-CollarWhiteMale500
320PrivateSome-collegeSingleServiceWhiteFemale350
441Self-EmployedSome-collegeMarriedWhite-CollarWhiteMale500
\n", + "
" + ], + "text/plain": [ + " age workclass education marital_status occupation race \\\n", + "0 28 Private Bachelors Single White-Collar White \n", + "1 30 Self-Employed Assoc Married Professional White \n", + "2 32 Private Some-college Married White-Collar White \n", + "3 20 Private Some-college Single Service White \n", + "4 41 Self-Employed Some-college Married White-Collar White \n", + "\n", + " gender hours_per_week income \n", + "0 Female 60 0 \n", + "1 Male 65 1 \n", + "2 Male 50 0 \n", + "3 Female 35 0 \n", + "4 Male 50 0 " + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "dataset = helpers.load_adult_income_dataset()\n", "dataset.head()" @@ -69,7 +196,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -92,7 +219,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -104,15 +231,24 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "ename": "SyntaxError", + "evalue": "invalid character in identifier (2337364001.py, line 5)", + "output_type": "error", + "traceback": [ + "\u001b[1;36m Cell \u001b[1;32mIn[6], line 5\u001b[1;36m\u001b[0m\n\u001b[1;33m m = dice_ml.Model(model_path=ML_modelpath, backend=backend,func=\"ohe-min-max\" )\u001b[0m\n\u001b[1;37m ^\u001b[0m\n\u001b[1;31mSyntaxError\u001b[0m\u001b[1;31m:\u001b[0m invalid character in identifier\n" + ] + } + ], "source": [ "backend = 'TF'+tf.__version__[0] # TF1\n", "# provide the trained ML model to DiCE's model object\n", "ML_modelpath = helpers.get_adult_income_modelpath(backend=backend)\n", "# Step 2: dice_ml.Model\n", - "m = dice_ml.Model(model_path=ML_modelpath, backend=backend)" + "m = dice_ml.Model(model_path=ML_modelpath, backend=backend,func=\"ohe-min-max\" )" ] }, { @@ -139,14 +275,16 @@ "outputs": [], "source": [ "# query instance in the form of a dictionary; keys: feature name, values: feature value\n", - "query_instance = {'age': 22,\n", + "query_instance_dict = {'age': 22,\n", " 'workclass': 'Private',\n", " 'education': 'HS-grad',\n", " 'marital_status': 'Single',\n", " 'occupation': 'Service',\n", " 'race': 'White',\n", " 'gender': 'Female',\n", - " 'hours_per_week': 45}" + " 'hours_per_week': 45}\n", + "import pandas as pd\n", + "query_instance = pd.DataFrame([query_instance_dict])" ] }, { @@ -314,11 +452,8 @@ } ], "metadata": { - "nbsphinx": { - "execute": "never" - }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -332,7 +467,10 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.12" + "version": "3.8.18" + }, + "nbsphinx": { + "execute": "never" }, "toc": { "base_numbering": 1, From bb849747ee4e4ef6837e4a4ab9c6b1a3a75beda2 Mon Sep 17 00:00:00 2001 From: Qiu Xiaoqi Date: Wed, 20 Sep 2023 20:40:29 +0800 Subject: [PATCH 2/3] Stashing local changes before merge --- .../DiCE_with_advanced_options.ipynb | 213 ++++++++++++++++-- 1 file changed, 195 insertions(+), 18 deletions(-) diff --git a/docs/source/notebooks/DiCE_with_advanced_options.ipynb b/docs/source/notebooks/DiCE_with_advanced_options.ipynb index 2e6ccd97..7000e1cc 100644 --- a/docs/source/notebooks/DiCE_with_advanced_options.ipynb +++ b/docs/source/notebooks/DiCE_with_advanced_options.ipynb @@ -231,24 +231,15 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, - "outputs": [ - { - "ename": "SyntaxError", - "evalue": "invalid character in identifier (2337364001.py, line 5)", - "output_type": "error", - "traceback": [ - "\u001b[1;36m Cell \u001b[1;32mIn[6], line 5\u001b[1;36m\u001b[0m\n\u001b[1;33m m = dice_ml.Model(model_path=ML_modelpath, backend=backend,func=\"ohe-min-max\" )\u001b[0m\n\u001b[1;37m ^\u001b[0m\n\u001b[1;31mSyntaxError\u001b[0m\u001b[1;31m:\u001b[0m invalid character in identifier\n" - ] - } - ], + "outputs": [], "source": [ "backend = 'TF'+tf.__version__[0] # TF1\n", "# provide the trained ML model to DiCE's model object\n", "ML_modelpath = helpers.get_adult_income_modelpath(backend=backend)\n", "# Step 2: dice_ml.Model\n", - "m = dice_ml.Model(model_path=ML_modelpath, backend=backend,func=\"ohe-min-max\" )" + "m = dice_ml.Model(model_path=ML_modelpath, backend=backend,func=\"ohe-min-max\" )" ] }, { @@ -260,7 +251,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -270,7 +261,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -296,9 +287,17 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00, 2.61s/it]\n" + ] + } + ], "source": [ "# generate counterfactuals\n", "dice_exp = exp.generate_counterfactuals(query_instance, total_CFs=4, desired_class=\"opposite\")" @@ -306,9 +305,187 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Query instance (original outcome : 0)\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ageworkclasseducationmarital_statusoccupationracegenderhours_per_weekincome
022PrivateHS-gradSingleServiceWhiteFemale450
\n", + "
" + ], + "text/plain": [ + " age workclass education marital_status occupation race gender \\\n", + "0 22 Private HS-grad Single Service White Female \n", + "\n", + " hours_per_week income \n", + "0 45 0 " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Diverse Counterfactual set (new outcome: 1.0)\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ageworkclasseducationmarital_statusoccupationracegenderhours_per_weekincome
0--AssocMarried----1
177.0--Married----1
264.0--Married----1
350.0--Married----1
\n", + "
" + ], + "text/plain": [ + " age workclass education marital_status occupation race gender \\\n", + "0 - - Assoc Married - - - \n", + "1 77.0 - - Married - - - \n", + "2 64.0 - - Married - - - \n", + "3 50.0 - - Married - - - \n", + "\n", + " hours_per_week income \n", + "0 - 1 \n", + "1 - 1 \n", + "2 - 1 \n", + "3 - 1 " + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "# visualize the resutls\n", "dice_exp.visualize_as_dataframe(show_only_changes=True)" From d894f839c9e7d43c6351eaddcbc5944e0c63c8c7 Mon Sep 17 00:00:00 2001 From: Qiu Xiaoqi Date: Thu, 21 Sep 2023 11:05:09 +0800 Subject: [PATCH 3/3] clear output and fix the ERR of unexpected feature_weight --- .../DiCE_with_advanced_options.ipynb | 361 ++---------------- 1 file changed, 30 insertions(+), 331 deletions(-) diff --git a/docs/source/notebooks/DiCE_with_advanced_options.ipynb b/docs/source/notebooks/DiCE_with_advanced_options.ipynb index 7000e1cc..f27d3b68 100644 --- a/docs/source/notebooks/DiCE_with_advanced_options.ipynb +++ b/docs/source/notebooks/DiCE_with_advanced_options.ipynb @@ -16,18 +16,9 @@ }, { "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "D:\\xiaoqi\\anaconda3\\envs\\pytorch-py38\\lib\\site-packages\\dice_ml\\utils\\exception.py:12: UserWarning: UserConfigValidationException will be deprecated from dice_ml.utils. Please import UserConfigValidationException from raiutils.exceptions.\n", - " warnings.warn(\"UserConfigValidationException will be deprecated from dice_ml.utils. \"\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "from numpy.random import seed\n", "\n", @@ -44,7 +35,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -68,127 +59,11 @@ }, { "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
ageworkclasseducationmarital_statusoccupationracegenderhours_per_weekincome
028PrivateBachelorsSingleWhite-CollarWhiteFemale600
130Self-EmployedAssocMarriedProfessionalWhiteMale651
232PrivateSome-collegeMarriedWhite-CollarWhiteMale500
320PrivateSome-collegeSingleServiceWhiteFemale350
441Self-EmployedSome-collegeMarriedWhite-CollarWhiteMale500
\n", - "
" - ], - "text/plain": [ - " age workclass education marital_status occupation race \\\n", - "0 28 Private Bachelors Single White-Collar White \n", - "1 30 Self-Employed Assoc Married Professional White \n", - "2 32 Private Some-college Married White-Collar White \n", - "3 20 Private Some-college Single Service White \n", - "4 41 Self-Employed Some-college Married White-Collar White \n", - "\n", - " gender hours_per_week income \n", - "0 Female 60 0 \n", - "1 Male 65 1 \n", - "2 Male 50 0 \n", - "3 Female 35 0 \n", - "4 Male 50 0 " - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], "source": [ "dataset = helpers.load_adult_income_dataset()\n", "dataset.head()" @@ -196,7 +71,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -219,7 +94,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -231,7 +106,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -251,7 +126,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -261,7 +136,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -287,17 +162,9 @@ }, { "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00, 2.61s/it]\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "# generate counterfactuals\n", "dice_exp = exp.generate_counterfactuals(query_instance, total_CFs=4, desired_class=\"opposite\")" @@ -305,187 +172,9 @@ }, { "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Query instance (original outcome : 0)\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
ageworkclasseducationmarital_statusoccupationracegenderhours_per_weekincome
022PrivateHS-gradSingleServiceWhiteFemale450
\n", - "
" - ], - "text/plain": [ - " age workclass education marital_status occupation race gender \\\n", - "0 22 Private HS-grad Single Service White Female \n", - "\n", - " hours_per_week income \n", - "0 45 0 " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Diverse Counterfactual set (new outcome: 1.0)\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
ageworkclasseducationmarital_statusoccupationracegenderhours_per_weekincome
0--AssocMarried----1
177.0--Married----1
264.0--Married----1
350.0--Married----1
\n", - "
" - ], - "text/plain": [ - " age workclass education marital_status occupation race gender \\\n", - "0 - - Assoc Married - - - \n", - "1 77.0 - - Married - - - \n", - "2 64.0 - - Married - - - \n", - "3 50.0 - - Married - - - \n", - "\n", - " hours_per_week income \n", - "0 - 1 \n", - "1 - 1 \n", - "2 - 1 \n", - "3 - 1 " - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "# visualize the resutls\n", "dice_exp.visualize_as_dataframe(show_only_changes=True)" @@ -547,6 +236,16 @@ "feature_weights = {'age': 1, 'hours_per_week': 1}" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# initiate gradient-based DiCE for user-defined feature_weight\n", + "exp = dice_ml.Dice(d, m, method=\"gradient\")" + ] + }, { "cell_type": "code", "execution_count": null,