diff --git a/classifier_performance_comparison.ipynb b/classifier_performance_comparison.ipynb index a571085..c2e0de3 100644 --- a/classifier_performance_comparison.ipynb +++ b/classifier_performance_comparison.ipynb @@ -2,18 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "worth-attendance", - "metadata": {}, - "source": [ - "# TODO:\n", - "- produce ROC for main publication (svc, lr, rf on test set).\n", - "- produce roc for SI (validation performance for RF).\n", - "- do word freq comparison with chemprot (and test v train)" - ] - }, - { - "cell_type": "markdown", - "id": "resident-moment", + "id": "close-vatican", "metadata": {}, "source": [ "## Evaulating classifier performance. \n", @@ -24,7 +13,7 @@ { "cell_type": "code", "execution_count": 1, - "id": "prescription-jacob", + "id": "impressed-tobacco", "metadata": {}, "outputs": [], "source": [ @@ -41,7 +30,7 @@ { "cell_type": "code", "execution_count": 2, - "id": "young-norfolk", + "id": "victorian-emission", "metadata": {}, "outputs": [], "source": [ @@ -80,7 +69,7 @@ { "cell_type": "code", "execution_count": 3, - "id": "focused-tracker", + "id": "mighty-paintball", "metadata": {}, "outputs": [ { @@ -110,7 +99,7 @@ { "cell_type": "code", "execution_count": 4, - "id": "personalized-objective", + "id": "brown-sound", "metadata": {}, "outputs": [], "source": [ @@ -128,7 +117,7 @@ }, { "cell_type": "markdown", - "id": "vulnerable-remedy", + "id": "broadband-dependence", "metadata": {}, "source": [ "We extract the same folds that were used in the GridsearchCV optimisation the classifiers (see fitting_classifiers.ipynb), to compute the validation performance for each fold:" @@ -136,19 +125,10 @@ }, { "cell_type": "code", - "execution_count": 5, - "id": "valued-conducting", + "execution_count": null, + "id": "weird-antenna", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/rustybilges/anaconda3/envs/CAP/lib/python3.8/site-packages/sklearn/model_selection/_split.py:292: FutureWarning: Setting a random_state has no effect since shuffle is False. This will raise an error in 0.24. You should leave random_state to its default (None), or set shuffle=True.\n", - " warnings.warn(\n" - ] - } - ], + "outputs": [], "source": [ "X_dict = {'test': X_test}\n", "y_dict = {'test': y_test}\n", @@ -171,7 +151,7 @@ { "cell_type": "code", "execution_count": 6, - "id": "entertaining-means", + "id": "resistant-correspondence", "metadata": {}, "outputs": [], "source": [ @@ -191,7 +171,7 @@ { "cell_type": "code", "execution_count": 7, - "id": "afraid-gross", + "id": "suitable-management", "metadata": {}, "outputs": [], "source": [ @@ -202,7 +182,7 @@ { "cell_type": "code", "execution_count": 8, - "id": "statutory-lebanon", + "id": "supported-huntington", "metadata": {}, "outputs": [], "source": [ @@ -213,7 +193,7 @@ { "cell_type": "code", "execution_count": 9, - "id": "aggregate-boxing", + "id": "indoor-gardening", "metadata": {}, "outputs": [], "source": [ @@ -224,7 +204,7 @@ { "cell_type": "code", "execution_count": 10, - "id": "micro-bristol", + "id": "authentic-tennessee", "metadata": {}, "outputs": [ { @@ -323,7 +303,7 @@ }, { "cell_type": "markdown", - "id": "absent-speaker", + "id": "similar-accommodation", "metadata": {}, "source": [ "#### We now produce the ROC curve for the random forest for each training fold and the test set:" @@ -332,7 +312,7 @@ { "cell_type": "code", "execution_count": 11, - "id": "contemporary-death", + "id": "suspended-quest", "metadata": {}, "outputs": [], "source": [ @@ -347,7 +327,7 @@ { "cell_type": "code", "execution_count": 12, - "id": "periodic-powder", + "id": "celtic-simon", "metadata": {}, "outputs": [ { @@ -369,7 +349,7 @@ }, { "cell_type": "markdown", - "id": "linear-particle", + "id": "directed-uniform", "metadata": {}, "source": [ "#### We now produce and ROC curve for the three classifiers to compare performance on the training set:" @@ -378,7 +358,7 @@ { "cell_type": "code", "execution_count": 13, - "id": "compatible-quantity", + "id": "fancy-collins", "metadata": {}, "outputs": [], "source": [ @@ -400,7 +380,7 @@ { "cell_type": "code", "execution_count": 14, - "id": "elder-bundle", + "id": "complimentary-retrieval", "metadata": {}, "outputs": [ { @@ -422,16 +402,27 @@ }, { "cell_type": "markdown", - "id": "affiliated-pharmaceutical", + "id": "dominant-guess", "metadata": {}, "source": [ "#### We now compare the word frequencies in our test and training set. And compare these with a benchmark datset from this domain (CHEMPROT https://pubmed.ncbi.nlm.nih.gov/20935044/)." ] }, + { + "cell_type": "code", + "execution_count": 39, + "id": "verified-beatles", + "metadata": {}, + "outputs": [], + "source": [ + "from matplotlib_venn import venn2, venn2_circles\n", + "import matplotlib.pyplot as plt" + ] + }, { "cell_type": "code", "execution_count": 15, - "id": "adjustable-cream", + "id": "beautiful-motivation", "metadata": {}, "outputs": [ { @@ -458,7 +449,7 @@ { "cell_type": "code", "execution_count": 16, - "id": "conservative-practice", + "id": "perfect-conflict", "metadata": {}, "outputs": [], "source": [ @@ -475,7 +466,7 @@ { "cell_type": "code", "execution_count": 17, - "id": "satisfied-filling", + "id": "pacific-identification", "metadata": {}, "outputs": [ { @@ -555,7 +546,7 @@ { "cell_type": "code", "execution_count": 18, - "id": "complimentary-browse", + "id": "intelligent-adjustment", "metadata": {}, "outputs": [], "source": [ @@ -565,8 +556,8 @@ }, { "cell_type": "code", - "execution_count": 19, - "id": "signal-wright", + "execution_count": 30, + "id": "further-instrumentation", "metadata": {}, "outputs": [ { @@ -585,10 +576,38 @@ "print(len(x_train_counts.index))" ] }, + { + "cell_type": "code", + "execution_count": 42, + "id": "finnish-scanner", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAATIAAAEWCAYAAADl+xvlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deXycV33v8c9vZjQa7ZIt74r32M5inAQvWSAbBEIIISUXKKVAeknDDpfbUrbeO+hVAiXtvaVtoNDe0FKSF1Cg5JUmEBJDErLZcRLHSxw7dhxblldZkiVLs2iWc/94pMRyZFuy53nOs/zer5decWRrnt9Imu+cc56ziDEGpZQKspjtApRS6kxpkCmlAk+DTCkVeBpkSqnA0yBTSgWeBplSKvA0yJRSgadBppQKPA0ypVTgaZAppQJPg0wpFXgaZEqpwNMgU0oFngaZUirwNMiUUoGnQaaUCjwNMqVU4GmQKaUCT4NMKRV4GmRKqcDTIFNKBZ4GmVIq8DTIlFKBp0GmlAo8DTKlVOAlbBegfEokDlQNfwCUx/zQo+qVD2iQRYlICmg47qMeqMYJrCSvhdf4WusiZSAHZIc/MsMfI/8/CPRjzGAFn4lSo4i+oYaQSA0wBWgd/mjCCS2bb1wFoA84AvQMf3RrwKlK0CALOpFqYBpOYI2EV53VmiYmBxwE9g9/HNbuqpooDbKgEYkBU4GzgDac4BKrNVVWgdHBdghjynZLUn6nQRYEIo04odUGzMQZy4qKArAHeAXowJiC5XqUD2mQ+ZVILbBw+KPVcjV+UQL24oTabozJWa5H+YQGmZ+IJIF5wNnADMLVZaw0g9P1fAnYiTFFy/UoizTIbBMRYDawaPi/cbsFBdIQTqC9iDG9totR3tMgs8W527gEOA9nLpeqjAPAizittJLtYpQ3NMi8JtIMLMXpPuqEZPfkgW3ARozJ2C5GuUuDzCsi04FlwBzbpURMCaeFtkEn34aXBpnbRFqBlThTJ5Q9ZZwW2vMYc9R2MaqyNMjc4nQhlwPzbZeiRikD24H1GNNvuxhVGRpklSZSD7wR5y6kTp/wrzKwGXgOY4ZsF6POjAZZpThzwC7CuQupUyiCIwusA7bpGs/g0iCrBJH5wKVAre1S1Gk7DDyBMQdtF6ImToPsTIg0AJfhTGRV4bADWKt3OINFg+x0ODtQLMUZC9O5YOEzBDyJMS/ZLkSNjwbZRIlMBd4MTLZdinLdLuAxjMnaLkSdnAbZeDlrIi/EaYXp3cjoyOGE2Su2C1EnpkE2Hs6WOlfj7AWmomkHzs2AvO1C1OtpkJ2KyGzgSiBluRJlXwb4Lcbst12IGk2D7EScAf2LgfNtl6J8pQyswZjNtgtRr9EgG4szreIadGdWdWI7gN/rho7+oEF2PJFpwNvRrqQ6tcPAQ7oI3T4NsmOJLASuQJcYqfHL44ybddouJMo0yEaILMdZK6nURBmccbNNtguJKg0ykTjOXckFlitRLhlKUsrUUxyspzzYQDnTgORTYATKgiDOn40Ax3wOA8khyqkM1GQwqQxSkyGWyhBLZUkkh4gdd6mNGLPGwlOMvGgHmUgKuBbnwFsVYAZMfwuF7qkU+1ogU0csV0ssnyJeTrgzVCAlyskhSnVHKTV3U27pJlbXz+7GPh7TQ4W9Fd0gE6kBrgdabJeiJq4Yp9wzlaFDMyn1tBIfaCJZjr+uhWTFniRH/nw23Zk4XTiHoXSZdFRfaN6IZpA5M/WvB5ptl6LGr6+Zob1zKRyeTuJoE1Um5o/gGsvOanJfnE0y59SYAzqA3cAek9YpG5UWvSATqcMJsSbbpahTy9ZQ7FhIvnMuVdl6krbrmYg9SfJ/MZvEQHxU13bktPTdwG6T1hOeKiFaQeZsQ3090Gi7FHVihSpKnXPJ75lPrL+F6uGh90A6mGDoz+cQO5IYc7snA+wBtuC01CL0Yqys6ASZM1v/XehhuL61v43s7oXQPZVq45Pxrko4mGDos3OJZ+InvekwgHPK01aT1k0dJyoaQeZ0J9+NhpjvGDB75pN76XziubpgdR0nYns1uS/Mobp06talwRlP22TSZp8HpYVC+IPMORTkBmCS7VLUawyYjgVkt59PIlcb3gA71tN1ZP6qbULnOuwHnjFp3W3jVMIdZM4OFteh+4j5RlkwHQvJbT+PRL6GKtv1eO3+ZjLfmzbhQ2r24QTaATdqCoOwB9nVwELbZSgnwHafTXb7eVQNpaIXYMf6wRQyv5x0Widu7QOeNmlzqNI1BV14g0xkJXCB7TIU9E4m/9ylSNCmT7ilDOb2meSeaKDmNB9iK7DWpHW32hHhDDKRc4E32S4j6opxypuXk+2cR22Qp1C4oQDlr55F4cVaqk/zIXI4YbatknUFVfiCTKQNeAd6QIhVB2aR3biSRNS7kSeTEUr/Yy7l/ckz+h4dAB4zadNbqbqCKFxB5kyzuAndFNGafDWl5y8m3zVTT10fj944hU/OIzZw8jlmp1IGngeeM+loLlYPT5A5dyjfBUyzXUpU7V5AZsuFVJeqdGPKiVhfS+Z/n1WR4D8A/M6kzUAFHitQwhRkl+Cc/q08VoxTfu4ycodmaSvsdH1vKpn7Wyry/csDj5q02VWBxwqMcASZyDycw0KUxwbrKay5CqN3JM9MAcqfmUtxb3XFvo8vAGtM2pQq9Hi+FvwgE2kE3gP6QvLa/jay6y8h6dbGhVGzt4r8p+aRHMcypvHqBh406fAfjhLshbnONtXXoCHmuW1LGXz2TaQ0xCpnVoHqWw9RyW19JgM3SrtMqeBj+lKwgwzeiPPDUh4pC+aZN5HZfj51Ojes8q49Qu2Fg+Qq+JA1wLukXeZU8DF9J7hdS5GpODta6IvJI0NJSmuuptDfotNb3HQ0RvHW+cgZTsk4ngGeMGmzpYKP6RvBbJG9dvKRhphHClWUHn8bRQ0x9zWUSXxpH0MVflgB3iTtsrLCj+sLwQwy5/xJ3W/fIyMhlmk47eU0aoKWZah5V29Fx8tGXCDtcqULj2tV8IJMZBKwzHYZUVGoovTENRQHGzXEvPbhLqobi7gxfWKRtEuo1iIHK8hEBLicoNUdUMUE5SeuoTDQpCFmQ8oQv6ULt3a4OFfa5RKXHttzQQuEc9DDdD0xHGJDA006JmbT5f3UzMlXfLxsxFJplxUuPbanghNkzpbVy22XEQXFuBNiR5s1xGyLg3z6gCvdyxEXSrsEft++4ASZs0mivrBcZsCseQt5DTH/WJKjZtVRsi5eYqW0y7kuPr7rghFkznmUuiDcA5tWkD0y+bR3LlUu+fgh4nGDm5M+L5V2meXi47sqGEEGK0CXwritcy6ZjoW6g4UftRZJ3tjjaqssBlwj7RLIaU3+DzKRVuBs22WEXX8TQxtXanfSz97XTXV9ydXxsiTwNmmXwK1d9n+QwcW2Cwi7QhWltVdCOUSne4dRrSH+0UOuTccY0YyzaiZQ/P2L6+y/r2dSumzd5QzlI3JIbtBd1U/NLPemY4yYG7Q7mf4OMrjQdgFht+UCBnum6uB+UMRBbj5M0YNLrQjS9j/+DTJnd4sZtssIs4Mzye5cooP7QbNigJrWguthJsBV0i6BuMnm3yDT9ZSuKsYpb1hFQvcUC544yAe6XR8rA2e8LBAz//0ZZCJNwFzbZYTZlovI6pmTwXVFP6lad+9gjlgq7TLdg+ucEX8GmdMa05aCS/qaGeqYr13KIKs2xN/T40mrTIArpV0SHlzrtPkvyERq0Xljrlp/KYaYvlEE3XVHSIq7s/1HNAK+3pDRf0EG56Gz+F3z8hIyui1PODSUSVx+tKL7+5/MedIukzy61oT5K8ic08IX2y4jrHI1FLct1RALkz/o8axlLfh4crq/ggzaQMdu3LJhFUN6fFu4LMiTcnG/suO1Sbuc5dG1JsRvQaatMZd0TSfXNUPfJMLoph5PJsiOuFjaxXfjq/4JMpEUEOqz92zaomskQuvSo1RXlSl7dLkWYIlH1xo3/wQZLMRf9YTGwZnkdKPE8Ko2xFcMeta9BFgu7eKrOYh+Cg7tVrpkq66RCL0r+z2ZhjGiBvDVjrL+CDKRycBk22WE0aEZ2hqLgmWDJF3eQfZ450u7+CM/8EuQwQLbBYTVNt0gPBJqDfFlg57M9B9RhzMc5At+CTId5HfBkRbyfZO1NRYVVx71bMB/xBs8vt4J2Q8ykQacOyGqwl5a6vkvtrJo+YDnm2NOknZp8/iaY7IfZNoac0W2huKhGdoai5KGMolzMp52L8EnrTI/BNls2wWE0a5F5HVhePRc3e/p5FhwZvtbP3nJbpCJVKF78rti32zdayyKVnnfvQQf7FZju0XW5oMaQqe/iaFsvR4mEkUtJarm5TydHAs+uHtpO0S0W+mCjoUUbNeg7HnTUc9//g3SLtM8vuYotoNMDxdxwf6ztDUWZUuyVl7XVltl9oLMWSTeaO36IdXTylC+RsfHomxe3srPf4HNmf42W2RWm6Jh1bFAu5VR11AmMano+d3LFDDL42u+SoMsRAyYA226A6yCczOeBxlYPPnMZpBNtXjtUOqeSr6YxNen3ShvnJf15Ki440WsRSYiQGCOYw+KQzOt/PIqH1qUs/LabpR2qbdwXWstskmgA9KV1jPF+l1o5ROz89buXFtpldn6xddupQv6m3XahXKkDPFZ3h1KcqxIBZn1tVlh09+kJySp0SyNk1lZcmgryJosXTe0uqdZuUulfOzcrKc7xo6olXbxfFsuDbKQODxNd7pQoy2wM+APFm7kef9EnTuWDZ5fN+T6Jum0CzVac8naUMMkry9oI7EbLF03tPLVlHK1ehdYjVZvL8gi0bXUbmWF9UyxcndK+VwCYo1FKwP+kWiRaZBV2ECjlUFdFQCT7QRZnbSLp1OBrMz+tXDNUMvWaZCpsU0tWFvt4WmrzEaQ6YEYFZat1TFHNbapRWtvchpkamJytTr1Qo2ttWAtyDxdc2kjyHSbmQrLp3RGvxrblKK1NzlPGywaZAFnwAzp1j3qBCYXrAVZjZcX065lwOVrKOn5lepEmkvWxk9DHGQiMdAdGiopU6d7kKkTa7I3KTbUXUsNsQrL1lG2XYPyryoTjRaZ12Mr1sbH7oFpt8KtI/9/BFr/AO79Gjx1Hdx6BCY3Q/cD8M+LITPy7x6BSdfA126A//oFPHTsYy6CT/VA62Fo9+6ZjFaoisYcsn1Hqfn67/lwb87ZJuaPl/LDN0zjwNce5dbBISbXJeluv5J/bmsk871nWPl4B28f+dq+PLM+fzFfv3oenf/yLMsf2c11xhCb28ymb7yFX9h7Vu4Tc9yww/f5CF0spYqjfHH49/ZH3MB+LgAMSfq5iX/jLPoA2MgsHuKPKVKDUObTfINaijzHbFZzMyWSTGUTf8JPj4vMKmmXhEkbT3Zl8TrIrM13uhEO3gh/BZADmQy3fwzWfxLesQy23gMP3AjXfgKu/R3858jXfQLetwg2H/94X4ALU5Dz8jmMxURkBtm3Huf9i1t54XOr+P7gEPG+PMk7nua6ec1s/cvLeeDrv+faf1rHtbe9hf/8+HKe/vhyngZ4dBezvvsMn7x6Hp0dfdQ9uJObbn8rt81rYeBzv+bmn21hyXvPZavt5+eW1/16vIEnSfEwD/Inr37ueh6khXsB+DlX8wDX86fcTYEYD/BR3s4PWEYnh6gjOTyU8Ts+yJXcxXJ28o98lkc5j6t44birJcCb7aW8fhn4YlD6b+GcFui6Gno2wLIvw1MAX4annocLRv7dl+CC6dA1G/Yf+/V7oPrH8Nb/Bb/yuvbjlSMQZAcHSO0fYNFnVvI4QF2S0swGsjt7Wfbec52f3XvP5amXe1/72Y14aCcrFk1iHcCLXbQ2VnNoXgsDAItbeXFtJxd5+Vy8Fj/+NXcJ22licNTnWo55Qy6QhOFW/uOcSwOdLKMTgKkMksDQSRNFUqxkJzHgbJ5iBxeOcXnPfjsj0yI71i9hxRU479iD0LgKpxm9Cvoyw1sM7YPkj+Dta+Dbt8Lbjv36D8K7/wgeasH+Yu0oBNkLXbSmEhz9/G+4uTtD25Q6dv/lm/lpvkTj4lbnZ7e4lb586fXbQ23vYfmnVvBdgPOm0tW/nunPH2DyOa30bjrEhWUT/jl4cQOlUzUh/p0b6eRi4mT5CP8HgMPDRzZ+m89RoJ6zWMcf8iCHaCbFkVe/toVedoy567NnDZcIvAxG64P4FnjD/4RnT/bv/ghueB+sPgvyx37+36HtAEy5HZ53t9LxMRL+MbJimfiRHLOvXcCjd72Hr1fFGPr7tVx7qq/79XbmJYShy+ewD6CtkcwNi7j7/z7Fn956H19oquawEP6bJTEzjuf4Ye7hK3yJ2axlNVcBUCbGERbyAe7kE/wNe7mQx1hygt+4sT7rWZBFbiLlX8P5M6DjjXAUoA7610LTKuhbC021w59/GeY9BxfdCTfloVbAfBAKMTB7YU4jfKMM8Sw0zIY/62D4XcxjrxvMDaF5zfSmEvS+42xeAbjsLJ69fzvXVsfp33aYpsWt9G07TFN13PnZjXhkNyuWTHG6lSM+tIyNH1rGRoBvr+HNsQi8EZRlAg2Wi3ma/+AzwH/RSC8tvMQ0pyvOdDazj9lcxlpyx7TAemmhZvjmwGiefW+9bpFZf/e7D1a+hdd+uZfBhm/CJQDfhEuWwQaAPfA3/fCVfvjK1fDbd8Kv7oZHfgSPDsJf9MNX7oHbW+CgrRADiFn/jrrv7Mn011bRu6bT6eqsP8A5k2vZP6+FDT/b4vzsfraFS+a3OD87gGIZ2dnLG9959ugge7nX6X529lP7zD6ufPdiHvPyudhwym7ltmNONXuWZdRxAIAVbOEosxggSYEYXSxiCvtpo48EedYxjzKwnUtYOGYPxbMg87pFZvXd7wAkt8M598BdI5+7Ax54J9zaApc1Qc+v4fs2a5yoKAQZwIeX8ePvrOOj//g0iYYkXX95OT8sG6T9UW79w59zWV2Snq9d8drP7t5tnF2ToPeiGRw+9nH+7ine352lDeCqudy3qo1DHj8VT5Wc19xrUfZdbqGXRRSp5za+xfncyx6Wcg/TEAwpurmRuwFoJcO5rOY7fAUwTGMzV7MJgKu5m4e4mdVUMYUXuOL1d/bxsOEixniYLSJNwPu9u2D47VpIZvMKam3XofypAOY9i60NP/zAq3lkXnct86f+J2oiqobCP0amTp/Fm0EFr0IMNMgCr3Ywenee1fgNibVx6ayXF/P2ReD0Y63PvQqTmkz450Gp09cft7apgKerXmy8m1tf1hMmqSwJyuGfQqBOT2/C2u9GiFtkDu1eVlhyyJv1bCp4ehLatXSLBlmFVed0TzI1ti5tkblGu5YVlsrYn2is/KmrytrNoAEvL6ZBFgI1GR0jU2PrsrcIsdfLi9kIsn4L1wy12gGdS6bGdqjK2l3tHi8vZiPIxlpcqs5Afb8GmRpbd8LKxhADJm08nWalQRYCLYf1LAT1ekUo9yesvMY97VaCnSA7ig92wQiT6jzxVEYnGqvRjtqbDOtptxJsBJkzu1/HySqsuVvnkqnRehPWgiwSLTLQIKu4yYd0nEyNtrPaWs/H862RbAXZkVP/EzURrQd0zaUabYunJ0u+KmPSxvPXtwZZSDT0k4wXdYa/es3mWqosXHavhWtaC7JQ78ppS2OvDvgrR04o7U9qkLmtFyhYunZoTerSu8HK0VFt7U1tn42L2gky585ll5Vrh9jUfTpOphxbU1be1PpM2ni6xnKEzd1FD1q8dihN7iKVzGlLV8ELtVbe1Kx0K0GDLHSm7dVxMgVbaqwsTdpl4ZqA3SDTAX8XzNkRvUOX1Wh9cQpHvF9jmSWSLTJjcui6y4pr7qG6WpcrRdor1VZWebxs0l6eLTma7RN49lu+fijN2KPjZFG2tcbKQP8OC9d8le0g67B8/VCas8PK/CHlE0/We74bSr9JG6tDRbaDrBN0NnqlNfSTrB3QsxGiqDvO0Cspz9/IrLbGwHaQGVNEu5eumLlbd8OIorUNVoYVtlu45ii2W2QAu20XEEZzt1Ot511GzyONnrfG9pi0sX7TToMspFJZEtP2eXskl7KrL07hxRrPx8c2eHy9MdkPMmMGsLCjZBQs2qRzyqLkmTrPu5WHTdpYWVt5PPtB5tC7ly5oOkKy+bC2yqLikUbPlyVt9Ph6J+SXILN+1yOsFm/SnWOjICOUNtR62q0cAHZ6eL2T8keQGdMDHLZdRhhNOUCqsVcPRQ679XXkjXj6prXZpI1vto3yR5A5ttkuIKzOed52BcptjzR6+lrOAFs8vN4p+SnIdqCTY12hrbJwywuldfVUe3jJdSZtfDVP0T9BZkwenYrhmnPW265AueXxBvIl77qV3cBLHl1r3PwTZA7tXrpkykFSU/eSsV2HqrxfTPJ0ms0am7tcnIjfgqwT9MXmlmVrqdaTlsLlpRS5PdWe3a3sMGljbc+xk/FXkDl7+W+1XUZYVeeJL9mgi8nD5J4Wz5ahGWCtR9eaMH8FmeMFdNDfNfNeorbhiA78h0FfnMLjDaQ8utwmkza9Hl1rwvwXZMZk0bEyV134JDFdUB589zcz5NHcsSPAOg+uc9r8F2SOjaAvNLc09pGcu0PHIoMsL5R+OYkaDy5lgEdM2vi6l+TPIDOmH3jFdhlhtuR5aqqzuiV2UP22iVwu5snrd4Pt3V/Hw59B5vDF9iBhlSgRu2ANJYy2fIOmCOWfTvZkAmwP8IwH1zlj/g0yY7qwdPx6VEw5QGrhFu1iBs3aenI97h/3VsbpUvpmPeXJ+DfIHDof3WVLNlI3+YBu9RMUJTA/nOLJLrBPm7QJzEYO/g4yY/biTJJVLlrxGNUpPQszEH7bRGZ/0vUg22nSxjd7jY2Hv4PMsRa9g+mqRJHYqochprP+fS0jlO6c4vq8sV7gUZevUXH+DzJjuvHBKS1h19BPctlabZX52U9ayWfiru4CmwN+Y9ImcHez/R9kjnWgx5u5bVYHNXNe0sF/PzqYYOieFlfnjZWB1SZt+l28hmuCEWTGDOKj/cHD7PxnqWnp0sF/v/mnaZRdnsX/uF8OEjkdwQgyxwbQF5jbBGTVw1Q39uh6TL/YkiL7bL2rY2NrTNoEerOG4ASZMQV8vt4rLBIlYpeuJqlhZl8JzB3TXR0XeyZodyjHEpwgAzBmK3DAdhlR8GqY6RbZVj3cSNbF/caeN2nznEuP7algBZnj9zgDk8plI2Gm2/7YkRFKd051bSnSZpM2T7v02J4LXpAZcwSd8e+ZRJHYZQ+RrO/TMPPav00hP+DOdIsXTdo86cLjWhO8IHOsx1nQqjwwHGZV9X26u6xXnqsl8+sWal146GdN2jzmwuNaFcwgM6YMPIJ2MT1TVSB+2UMkmrq1Zea2vjiFb82seJeyDDxq0ubZCj+uLwQzyACMOQzo0bMeGg6z6pm7dNKsW8pgvjWTUoVn8BeAB0zahHbn5eAGmeM5oMt2EVESM8hFT1G7eAMZ3cus8u5rJrOptqJzxjLAvSZtQr35ghj/HVE3MSINwHvA05OWFXBwJtnnLiVZqnJ1nlNk7EmS/8xckhU8bPcQzrKjgQo9nm8FP8gARGYD19ouI4oytRTWXkV5sFHfSM7EkFD+9FxKFdyiZyPOnmKRGEcOR5ABiKwELrBdRhSVYpTXX0LuwGxX7rJFwh3TyPymuSLfvxzwsEmbPRV4rMAIU5AJcD0ww3YpUdU5l8zmN5IsJl3fhjlUnqkj095WkRDbB/zOpE3kbsaEJ8gARGqBm8CTY7LUGApVlDasJK+ts/E5nGDoU3OJn+FdyhLO3Mr1Jh2mF/T4hSvIAERmAtcR/DuygdY1ndzzFxPP13iyv3wgDcYofm4O5uCZjYvtxdmCp69SdQVR+IIMQGQJcLntMqKuFKO85SKyuxdQS8yTE7EDY0gof+ksCttrTvsmSQZn+50dlawrqMIZZAAiy4GLbJeh4EgL+fWXgt7ZdJTAfHMmubUNpzUEYoAtwDqTNro1+bDwBhmAyBXAYttlKDBgOhaQ3X4+iVyta9vSBML3ppK5//TWUXbg7B8WmGPavBL2IIvhzC9rs12Kchgwe+aTe+l84rm66AXaPS0M3jmVugl+WQfOYm9dxXIC4Q4yAJEq4F1Aq+1S1Gidc8lsW0oiWx+NQHuqnsw3Zk2oJbYHJ8AOuVVTWIQ/yGBkWsYNQKPtUtTr7Z1NdtsbiGUawjuGtjVF9kuzSY1j+VEZ2AVsMmlz0P3KwiEaQQYgUo8zYVbDzKcOzSC3eyHlrhmkyvHwTJ/ZV0X+c3OpysVO+pz6ga3ANpM2esjOBEUnyGCkZXY90Gy7FHVixTjlfXPI7ZlPrHcy1UGeurGvivwXZpPoT4w54XWk9fWiSZu93lYWLtEKMhgJs3cCLbZLUaeWT1HsmM9Q5zziQZu+sStJ7ouzqTpu1n4B6AR2Ax0mbXSjygqIXpABiNTghNkk26Wo8TvayNC+ORS6phPvayFpfNz93F5N7suzSead7mQGJ7h2A3tN2pTsVhc+0QwyAJEUTphNtl2KmriyYHqmkD80k1JvK/H+FqpKY3ffPNdZRffn59KTi3EQOGjSptt2TWEX3SADEKkG3obumBEKRxspdE+l0DcJk61DsrXE8ykSxaQ7ARcrUkrmKdX3U2ruwbQcJtbQxyu1gzwxfK6E8ki0gwxGJs1eDiyyXYpyRzFOOVtHcbCB0mADJlOPydUgRsAImBgYnP/nmM9hoDqHSWUxqQxSk4FUhlhNhnh1lnii9Lqu7XqMWWfhKUaeBtkIkQuAlbbLUIFUBp7EmC22C4kqDbJjicwHrgTdGFCNWxZYjTH7bRcSZRpkxxOZArwddGNAdUpdwIMYM2i7kKjTIBuLswrgrcBU26Uo39oKPIHRqRR+oEF2Is5NgOXogSZqtDJOgL1ouxD1Gg2yUxFpA65CzwFQMAD8FqOLuf1Gg2w8nJUAV6H7mkXZNuApjO7K6kcaZBMhsgxYgR5sEiUZ4DGM2W27EHViGmQTJTIZZwLtFNulKNftBB7H6MJuv9MgOx3OYcDn4bTO9Liz8MnhDOi/bLsQNT4aZGdCpA64FJhnuxRVMduApzG6uWGQ6FjPmTBmEMsi/1cAAASrSURBVGMeAn6Dc0dLBddB4JcC9wt0iciAiJRFJDv85wER+eBEH1REHhGRW1yoVx1Dl+JUgjG7EdkLXAgsRb+vQTIIrMU4B90aqB/5CxHZBdxijFltqTY1TtoiqxRjisM7H/wYeAFn4qTyrxKwHvjpSIidjIjERORLIvKyiHSLyH+IyKThv0uJyF3Dnz8iIutEZJqI3Aa8GbhjuEV3h7tPKbq05VBpztjKE4hsxFkZsBCCu+d8CJVwxsGex5iJDAd8FrgRuAJnjeU/AN8BPgB8BGgCzgLyOKtBssaYr4rIZcBdxpj/V7mnoI6nQeYWY44CDyOyAefu5hzLFUVdEWd95IbTXOT9MeDTxphOABH5GtAhIh/C2Yd/MrDQGLMReLYyJavx0iBzmzE9wG+Gd9VYhnOHU1to3ikAW4CNZ3gncg7wSxE5dsigBEwDfoTTGvuJiDQDdwFfNcYUzuB6agI0yLxiTBewGpEGnBsCi9E5aG7KAi8Cmys0oXUP8N+NMU+c4O/bgXYRmQv8Cqf7eieg85s8oIP9XjPmKMY8CdwNPAn0Wa4obPYCq4G7MeaZCs7K/x5wm4jMARCRKSLy7uE/XyUiS0UkjnPQbgGntQbOtI75FapBnYAGmS3GDGHMZoz5Kc47+A6ccRw1cTlgA/ATjLkfY3a6cPjH3wP3Ag+KyFFgDbBq+O+mAz/HCbEXgUdxupcjX/ffRKRXRP6hwjWpYTqz309EEsBc4GxgFvpGczJlnNbXS8ArempRtGmQ+ZVz7uYCnOkb0yxX4xdFoAPYBXToljpqhAZZEDhrOtuGP2YBKbsFeSqPc0L3LmCPbi2txqJBFjTOzhutOLf723DOFQhTFzQPHAD2D390a7dRnYoGWdCJVOGEWSvOHmmtQKPVmiZmEOfOnhNczrw7pSZEgyyMRJKMDrYmoAGotlhVDjgy/NELdAM9ummhqgQNsihxWm8Nx3zUD3+kcCbnVgHJY/58KmWc+VI5nAmomTH+mwH6NbCUmzTI1Ik5wTcSaOVRHzpupXxEg0wpFXhhutullIooDTKlVOBpkCmlAk+DTCkVeBpkSqnA0yBTSgWeBplSKvA0yJRSgadBppQKPA2yiBORX4vIR2zXodSZ0CVKASQixx4sW4uzh9fIhoMfM8bc7X1VStmjQRZwIrILuMUYs3qMv0sYY/RAExV62rUMERG5UkQ6ReSLInIA+FcRaRGR+0Ska/gkn/tEpO2Yr3lERG4Z/vPNIvK4iPzt8L99RUTeYe0JKTVOGmThMx2YhHMy9q04P+N/Hf7/2Tj7hN1xkq9fhXO4bCtwO3CnONtrK+VbGmThUwbSxpi8MSZrjOk2xvzCGJMxxhwFbgOuOMnX7zbG/ItxDvn4ITADPcVJ+VzCdgGq4rrMMbuxikgt8HfAtUDL8KcbRCRuxj6R6MDIH4wxmeHGWL2L9Sp1xrRFFj7H3735M2AxsMoY0whcPvx57S6q0NAgC78GnHGxIyIyCUhbrkepitMgC79vAzXAYWAN8IDdcpSqPJ1HppQKPG2RKaUCT4NMKRV4GmRKqcDTIFNKBZ4GmVIq8DTIlFKBp0GmlAo8DTKlVOD9f+XwT87zF3j/AAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "venn2(subsets=[\n", + " len(set(x_train_counts.index).difference(x_test_counts.index)),\n", + " len(set(x_test_counts.index).difference(x_train_counts.index)),\n", + " len(set(x_test_counts.index).intersection(x_train_counts.index))],\n", + " set_labels=('Train', 'Test') \n", + ");\n", + "plt.tight_layout()\n", + "plt.savefig('venn_train_test.jpg', dpi=300)" + ] + }, { "cell_type": "code", "execution_count": 20, - "id": "waiting-currency", + "id": "sporting-information", "metadata": {}, "outputs": [ { @@ -609,29 +628,36 @@ }, { "cell_type": "code", - "execution_count": 21, - "id": "divine-affiliation", + "execution_count": 45, + "id": "noted-attack", "metadata": {}, "outputs": [ { "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEWCAYAAABrDZDcAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3de3zcdZ3v8dd3ksmtSZqm6b20BUoptKUU5WZB7mrxeFtU1GV1Vz2uZ8+q56wr+kDdcVjlILsre3H1oa4uXtBdWEAUwQsuCAKlXFouLS3Q+71JkzT3STLzPX98JxJK2ibpZL6/33zfz8djHpTc5jPJzO8937ux1iIiIuFK+C5ARET8UhCIiAROQSAiEjgFgYhI4BQEIiKBUxCIiAROQSAiEjgFgYhI4BQEIiKBUxCIiAROQSAiEjgFgYhI4BQEIiKBUxCIiAROQSAiEjgFgYhI4BQEIiKBUxCIiAROQSAiEjgFgYhI4BQEIiKBUxCIiAROQSAiEjgFgYhI4BQEIiKBUxCIiAROQSAiEjgFgYhI4BQEIiKBK/ddgIgcm0mbGmBS/lY97JbEvY7LDruVAwbIAoNHuPUC3cNvNmUHivagJDKMtdZ3DSICmLQpBxqAxvxtSv7/J1G81ns/LhQ6gdZht3absrki1SBFpiAQ8cCkTRkwHZgFNOEu+vW4d/FRlAPagYP5236gWeFQGhQEIkWQf7c/HZiNu/hPx3XhxNkALhD25G8tCoZ4UhCITBCTNnXAicAC3IW/1CdnDAD7gB3AVpuyPZ7rkVFSEIgUkEmbKbiL/4nAVM/l+LYf2ApssSnb5bsYOTIFgchxMmkzGViEu/g3eC4nqpqBLcBmhUL0KAhExiE/2HsisBjX7y+jY4GdwAvADpvSBSgKFAQiY2DSpgF38V8EVHkuJ+66gI3ARo0n+KUgKFXGlAG1QB1QA1QCFfnb0L+TuAHM4VMWc/nbIG5OeR+Qyd/68rcuoAsbzgwRkzbzgDPQu/+JkAO2A8/blN3ru5gQKQjizpha3Dz0qcBk3MW/Hnfxn0g5XCB0DLu1A81Y2zvB9100Jm1OBFbgfscy8fYD62zKbvddSEgUBHFiTA0wE5iGu/A3Ec3uiS7c4GAz0AIcwNp+vyWNnkkbAywEzsQt9JLiawWesim71XchIVAQRJkxFbjFR3Pyt7helCwuEHblb/uj2K2UD4BFuBZAvedyxDmIC4RtvgspZQqCqDFmMnASMA/3zr8UFyEN4Fai7gS2Yf0PFJq0mQ2cj+b+R9U+4FGbsi2+CylFCoIoMKYeOBkXAKFdiCywF9gMbMHaTDHv3KRNPXAebvWvRJsFNgFP2FTpjENFgYLAF2OqcN0QC9FA5JAcrutoM7AVawcn6o5M2lQAZwFLKc1WVynrB57GzTKKXBdjHCkIis2YGcDpuHf/cd90bCJlcO/+NmBtRyF/sEmbhbhuoOpC/lwpukO47qKdvguJOwVBMRiTBE4BTiO8rp9C2Amsx9odx/ND8oe7XAjML0hVEhUv4gIhNjPTokZBMJHcrJ+lwDLcIi45Ph3AM8Cmsc46yrcCVqK/Q6nqAn5nU3a370LiSEEwEYypxF38l+JW8EphdQFrGUUgmLSpxrUCFhShLvFvA7DapiZufKkUKQgKyQ0AnwEswW3fIBPrqIFg0uYkXAioFRCWDuBBm7L7fBcSFwqCQnD7+izFLURSC6D4OoE1WLsZ/rAz6Pm4QXkJkwWetCm71nchcaAgOF7GnASci9vcTfza98OzWPvBt3M2mpIrznbgAQ0kH52CYLyMmYIbfNRulBGxfza9T64k+fAUBr45nYqOck3PFcBNM/21Tdk234VElYJgrIxJAK8DlqOFSJGxaRndLy2hBuO21O4x5H4wjb5fTJnwXVglHgZxs4o2+y4kihQEY2FME3Ax0Oi5EskbLCP31IX0Nc8a+YK/oYrer84h2VpOebFrk0h6DjerSBe+YRQEo+EGg1+HmxGkVkBE9NQwsPoycj21R58V1GfI/tt0Mr9qUOtAANgK/LdN2azvQqJCQXAsrhVwCfHdArokddcy8MgV0F81+mm6z1XTe9Nsku1qHYjb6PBXGkR2FARHY8zpuGmIGnSMkK46+h+5gsRA5dgv6D2G7E2zGXiqNpIH+khxtQL36rxkBcHIjCkH3ojbGVQipLPehcBgxfjf1efA3tlIz/enMamQtUksdeLC4JDvQnxSEBzOmAbgTUCD71Lk1Q410P/YZccXAsOtr6b3+jlU9JSpxRe4PuA+m7LNvgvxRUEwnDEnAxeB+pCjpr2RzGOXUp5NFvai3VbGQHouuc1V2oYicP3APaGegKYgGGLMCuBs32XIa7VNJbP6UsqzE7RAbAByN88i83C9zicIXB/w8xAXnikI3AKxC4DFvkuR1zo4jb7HLyaZm+BVwjmwt0yj965GTTENXA8uDIIaMwg7CNyBMZcDJ/guRV6rZTp9ay6mIldWvLUb9zbQ/c0ZGkQOXBcuDDp9F1Is4QaBMTXAW9DmZJHU2kTmsUtJ2iKGwJA1k+i5YQ7V2fx2FRKkTuBnNmW7fRdSDGGukjVmEvB2FAKR1FfN4JqLKPMRAgDndFNz0w76kjl0MHq46oArTdoEsa18eEFgTC3wNqDedynyWtkEudWXkC3UFNHxWtRH9Y07ySgMgjYFuMykTcm3DMMKAhcC/wOFQGQ9dQF9XZOjMZVTYSC48cPzfRcx0cIJArUEIm/TUroPzInWrJ1FfVTfoDAI3VKTNiV92l0YQeAGht+GThGLrP2z6X1pabRCYMjiPqq/vJNMmSXQmRUCvMGkzRzfRUyU0g8CN0V0FQqByOqqo/+plVQQ4Vk6p/dR/YXd9PquQ7xJAJebtJnsu5CJUNpB4BaLXQFM9V2KjGwgSXb1pTDRC8YK4fXd1Pyv/QQxnVBGVAlcYdIm8s/VsSrtIHA7iM71XYSMzIJdcxH9fTXEZorele1MelsbwW9bHLBG4DzfRRRa6QaBMa8HFvkuQ47s+dfT0zYtfvv7fOQA1ed2qpsoYEtM2sz3XUQhlWYQGLMIOMt3GXJkzTPo235KPLdyKANz7V4qT+4j47sW8eYikzaRnNwwHqUXBMZMxW0iJxGVTZBbd168n3sVlsSXdpGoyaJzb8NUBVxaKovNYv1ifA1jKnCDwzpPIMI2rKAvE6NxgSNpyJL8m93ozNtwzQbO9F1EIZRWEMDFaMFYpB1qoH/7wviNCxzJkl6q/7hFM4kC9jqTNo2+izhepRMExiwHFvguQ47Mgn16JZZEdNcLjMd7DlKzrIc+33WIFwnceEGsn9OlEQTGzESni0Xei0vp6a6Pxj5ChVQG5rrdlDcMMui7FvFiGrDUdxHHI/5BYEw5rkso/o+lhHVPYuDl00unS+hwtTnKP7eHAd91iDdnm7Sp9V3EeJXCxfNcNC4QeWtXMujrfIFiWdJL9RXtWmwWqHLgDb6LGK94vzCNmQ0s8V2GHN22hfS0Ty3d1sBwHz1ApbqIgrXApE0sj72NbxC4zeQu8l2GHF2misEXVpTeuMCR1FjKrlUXUchWmrSJ3XU1dgUPcy7aUTTynj2b/mwMNpQrpGW9VF92SF1EgaoHTvNdxFjFMwiMmQ6U9EERpaCznv79c8LoEjrcx/ZTWatVx6E6y6RNrBa1xi8IjDFoC4lYWH8W2SifMTCRaixlH9+vvYgCVQ0s813EWMQvCGAx0OS7CDm6Qw30t8wKszUw5MJOqrUxXbCWm7SJzdhYvILA7SWkhWMxsP4sdYskwHxyn463DFQFsMJ3EaMVryBwW0tX+S5Cjq5tKpnWGWG3BoaclKHqwg6dXRCoJSZtYrHVenyCwJg6Yr6MOxQvLCfnu4Yo+cgBynTwfZDKgOW+ixiN+ASBa2bFqd4gddbTr9bAq03NUnFlu1oFgVps0ibyW67H48LqWgM6djIGNi7X2MBI3nOQpFoFQSonBusK4hEEag3EQm81g/tnawxnJFOyJN+kVkGolkZ9tXGkiwPAmFrUGoiFTWeQKbWzBgrp6oOUG7UKQjQJOMl3EUcT/SBwR8HFoc6gDZaT2zNfYwNHMzVLxeWH1CoIVKQXmEX7AmtMDW4BmUTcrgX05Up8m+lCeN9BnacdqGkmbWb6LuJIov7CPY3o1yjArpPUJTQa0wepOLtLrYJARXZ/tOheZI1JEIPRdoG+agbbGzVIPFrvavVdgXiywKRN0ncRI4luELiD6Gt8FyHHtuMkMqFuLjceS3qpmjagMwsCVA6c6LuIkUQ5CCLbjJJX271A/d5jkQDzzlb6fdchXpziu4CRRDMIjGkAZvsuQ46ts57+7vpwTiArlEs7qNQCsyDNjuL+Q9EMArUGYmP7QnVxjEdtjvI3ajO6EBlgoe8iDhe9IHAHz0R68YW8Yu88Ir+PSlRdcUjjKoGKXPdQ9IIAZqFB4lhobSKTqSaSsyDi4LReKmt0nGWIGk3aNPguYrgoBoFaAzGx/RRdxI5HOSQu7NQJZoE6wXcBw0UrCNQtFBs5g903R4PEx+uSjoi9BqVYFARHMQedQBYLrdPIZJOU+a4j7hareyhUs0zaRGbaddSCQK2BmGiZqYtXIZSBuahD3UMBKsO98Y2EqAVBpJpLcmQHp6k1UCjnd/muQDyJzPUuOkHgFpFFbqGFjKyjUbOFCuW0Xip1TkGQFAQjmOu7ABmdznr6s+VqERRKlaVsUZ8W5gWozqTNZN9FQLSCIDL9ZXJ0LTMZ9F1DqTmnS0EQqOm+C4CoBIHbclp7C8VEywzfFZSeFd1qYQVKQTDMdFCfc1y0T9XfqtBOzFCRzJHzXYcUnYJgmMge4SavlqliUNtKFF45JE7v1dbUAZpq0sb7ddh7AXlNvguQ0WmZob7sibK4V2szApQgAtc/BYGMScsMTXOcKIv6tBtpoLx3D/kPAmMqgXrfZcjotDVpUHOinJhRl1ugpvkuwH8QqDUQK911ulhNlKmDlFdqwDhEU3wXEIUg8J6GMjqZSrI2EYnnTElKgFnUpwHjAHnvEYnCi7rRdwEyOr2TtJBsoi3qVYsgQBUmbbzuuhyFIPCehjI63bW6SE20Of0ajA+U1+uggkBGradWF6mJNnMgEq9JKT6vew75fdIZk0QH0cRGr/aGnXDTBzQrK1BBtwjUGogRBcHEm5IlMqdWSVEF3CKAOs/3L2PQW+P9+VLyKiyJ+kGtMA5Qrc879/3CVhDESKZK71aLYU6/ZmcFqNLnnfsOgmrP9y+jZMEOVKj/uhgas5qdFaCgg8Drg5fR661hEKO9cIqhTh1DIVIQSPT1aA1B0dRlNU03QGUmbbx1vfoOggrP9y+jlKlWEBRLvYIgVN7eGPsOArUIYiJbpotTsdRm1QUXKG/XQ9+zQGITBO+Fy34LFxqws2D3A3BLE252x1VwxZ3w7mfh08ugC+BWmHMtXNMH1QZyG+GGJhj8JJz9Y1hlgHpo/yl8b+h7oizn+y3DYfqzmA/fzedrkrR/+218HeDvHuWSp/ZwScKQW9DAczdcxh3r9jE19SDpugr2A8yoZcs/vIlbAf7sbj7dO8Dk8oQ7bCd9Cf948hQ6/T0qp7YU2l7f4kM0s4wknXyWNADN1HArH6OPqVRxkGv4Nk300EcZ3+caDjEfsKzkP1nJiwD0UcYPeD+tnArkOIuf8ibWso1G7uJPGaQGS4KzuZNLeN7fAy6IYIMgFlsaPwINv4LLXoLUdBhYAh+7Ds7+Njz2EExZC6fXQevQ1/dA4v/AR26G710Du9bDpFrI9kDi3+Dqx+FLy6DrErjqWrjkPvi5z8c3GjZiQfC1x7isoYq9/Vk38+yOFzh1Ywtn/tvbub62gsHNba9MTZ6UpPlHf8TfjvRzPryC777pZLYXq+7RmFQKLYIzeJQqHuDX/NkfPvYLVjGDjbyfX/IT3sI9vIU/5U5+wYUAXMv17KWOH/FJzuMGyrD8F1dSRSef44tkMbTgljX+miuZz1P8Eb9jPbO4m09wCdd5eayF421Wnu+Xd2ye8DlIHIRkDyT6oWIeHAL4OLz3ergDXuk6+SqcPht2XQO7AJZAdxXYLBgLphkqskAPVM2Adj+PaGyi1CJY30zDS60su3gBvx/62H9v4aI3n8x9tRWulRaFd/bjFaFf9fidz0tMpvtVH9vHci7kMQAu5DH2cSYAbcxiLhsBmEUn5fSwlvkA7GAl7+I+AMqwzMi3ng3Qn9+epptqKt3rMea8/el9twhi8ZxfCe1vhV+fATeWw8ApsOELsOE6OKMR2q+BXX8x7Os3wgwDnASf6obalfDEnfDrOsj+Bdz6FkglIdMIB34HP/b1uMYiSkHwr2u4+n1LuKOz/5V9qg5lmLG+mVM+eBfvLEsw+L4l3P7mhe6dfs8ATdfcyReSZfS981R++o7FvDz0fd9fx4d+8Az2lKk8/cU38otEBN6amFIdjRmgnrn5C/ZcDjGQb7U1sYttLGeAJ9jBFLqYTytTOOi68/gv3kELi6ihmXfxE2bTySp+zo/5FF/hUnJU8FZu9va4Csfbsy9CL+/o2gQ1q+HMJ+G6A3BtP1R8HM67Ba78Pvzs8K8fhMRWWHg3fPcZ+LvHYcX/g8WdUHYHXHQ3fLkDrj0Bdr8bVnl4SGMWla6hHz7LspoknVeczI7hH89ZEr0D1NzyTm68egn/9d21/HnOwsJGDv3zKj73oz/iyx9Yym0/eo6P7u9yAXLtG/jurVdx/dfezE07D7Hwm09ynp9H9WoR+VUXz1t5hGrauZnPcy9XU89mEuQYpIx+pjCHl/kMX2EaW7ib9wDwKGezgMf4PJ/lzfwLv+HDxL9LTUEQZd+C05qgZTl01UH2Ylh7H7yhHZpWwBfr4YYumPIG+PwTUD8X2k6GF5dB10zoXwHPPwHz/gPmAqyC5jLg3fDkRjjZ76MbHRORAcxNLSzcfojl772dG378HP9zfzenfuI+Pjypgrbz5rI2YeAtC9lmDLnt7dTWVjA4L99FccXJ7KitoPmZ/cwAOG2a65abPonMmTNZs62dE30+tiER+VUXXpIOduU3V9vFZJL57rskOT7MbVzL3/IJvsEgNcziANPoIkE/l7EOgHN4ig7mAbCVCziHJ/Mf30KOJAf87tdTAN7agr6DIBaN4EXQuhVO2pfv238cFq+EtT3w1x1wXQdcVwttj8JXzoaOv4ANu2HOPqjogcR6WLQU9i6F9maY9Ux+g6lfwWlzYK/nhzcqiYhcnb58KXfd/h4+e9t7uO4Dy/jOjEls+pdVfO+0Jtat28digMd3MT1nKZ/fQNfWNmr78+8Un95LU1eG6ac10ZwZJLGt3f0degco29DMGbNq2e3zsQ2xcX9feyQzeYaHOR+AhzmfmTwDQBcVdObXFD3EaRiyLGEvCaCJZ3mMRQA8x2Jq2QNAFa085/7ebGAmOZLMiO+4UJ6366HvMYJYBMHHYetP4anT4fMJyM2Fnf8KDx/p60+Fnqvg/iVwnQG7DJ6/Hp4DuAruuRQ+k4DsFDh4F9xStAdyHKISBEfy56/jkc/9lg994A5SCcPg1Uv494SBB7Zxyv1beIcxZA3k3n4qt54wmZ6DPVR88QE+lbOUWUtibj0v/OU5R/6bFlPEf9Wj8w0+ShuLGKSWr/BVlvIzruSX3MrHuJGVVNLKNXwLgAPUcRufwmCppJ138b0//JxV3MFdfJhHuZoknbyL7wNwObfzS/6Em7gcgAu4xfvb2uPn7U9vrPV4LTbm/WgH0ljYtpCe58+mxncdIVgziZ6/navfdYDusSm7x8cd+87QjOf7l1GKeouglHRqj9dQebseKghkVBQExdOdiEeXqRScgkCirbLX+3MlGB1qEYQq2CDo93z/Mko1XTqUplg6ymI/H17GLmdTdsDXnfsOArUIYqK6h3KsuiyKoUtBECKv10LfQdDn+f5llBIWUz6gQ9WLobVcQRCgoIMg7gtAglLVqyAohr0V3tf3SPEFHQQdnu9fxqCqtzTWOkXZAORayxUEAfJ6JonvIFCLIEaqehQEE629XK2uQHl9U+w3CKzNoJlDsVET+XPU4q+53J2nIMHxep6C7xYBqHsoNmq6I/F8KWn7KjQzK1ABtwgcBUFM1HRF4vlS0vYkFQSBCj4IWo/9JRIFNd1aVDbRXqzW7zhAAzZle3wWEIUgaPFdgIxOVS/l5PSOdSJtqiLpuwYpOu+9IlEIgmbfBcjo1fRocH+itJUx0FOmFkGA2n0X4D8IrO0Fd5SgRF9Di6Y3TpRtlXjba0a88v5m2H8QON5/ETI6Tft9V1C6XqpSt1ugDvguICpBoHGCmGjapz7sibJBA8UhskTg+heVIND7zJio6SGZzGjRU6Flwa6vyR/gLiFpsynr/fUUlSDYB+p7jouGVg0YF9rOCvr7EpF5PUrxeO8WgqgEgbVZXBhIDExV+63gnpmkN0KBUhAcZpfvAmR0mvZrd8xCe7w2Uq9FKR4FwWF2+y5ARmdyK8lEVjuRFkq/IbehmkrfdUjR9dqUjcTOCtEJAmtb0IllsWDA1LfrmNFCebmSTNboVLIA7fRdwJDoBIGj7qGYaDygOe+FsrpOv8tA7fBdwJCoBcEW3wXI6DTtj9xzJ5ZyYB+o17TRAFki9MY3ai/mneigmliYup+KxKBmuhyvrZVk2nU0ZYj225SNzLUuWkHgppFu912GHFtZjsT0vRonOF4P1mvQPVCRGR+AqAWBs9l3ATI68zZH8vkTGzmw/12v2UKBUhAcwy7QO804mLaXyvJ+bTcxXpsryXSUa3+hAHXYlPW+v9Bw0QsCa3PAVt9lyLEZMLN2KrTH67eT1S0UqJd8F3C46AWBs9F3ATI681/WQOd4ZAzZ+ydT5bsO8UJBMCrWHkBnFMRCQyuVVTq1bMweriOT0SZzITpgU9b70ZSHi/ITcYPvAmR0Zm/XyVpjdWejWlKBetF3ASOJchBsRoPGsTD/ZS2IGovNlfTtrNTvLEA5IrpoNrpBYO0gsMl3GXJsk7pI1h7SPlGjdfcUDRIHaodN2Ui+TqIbBM4G0D4scTB3qy5uo3GojIGH6qn2XYd4EdlJMNEOAms70AKzWDhhC5XkFNrHckcjA9ppNEjtNmUjs8nc4aIdBM5a1CqIvMoMZTN30+u7jijrSjB4T4OmjAbqOd8FHE30g8DaNrTALBZOfZZyrEL7SO5spH9AU0ZD1EcE1w4MF5cn5dO+C5Bjq+ugYuoBDRqPpDvB4M+mqDUQqBdsykZ6K5Z4BIG1rahVEAunPhuT51SR3T1FC8gClQPW+y7iWOL0xHwKjRVEXmMLlfWtahUM15Vg8M5GzRQK1Ms2ZXt8F3Es8QkC1yqI5Ko8ebXFz/quIFp+1ES/WgNBssA630WMRtyenE+AtjOIuul7qapvU6sAYF+SzL0Nag0EapNN2XbfRYxGvILA2h5ikrChO13D+wB8fQbWat1AiLK47uxYiFcQOM8Ckdu9T16t6QBVDQfDXlfwdA29z0zSTKFArbcp2+27iNGKXxC4c40f9V2GHNvpT8fw+VUgA5D7+kztMBqofmLWcxHPF6q1O4joLn7yisYWKhsPhNkq+EkTfc1Jkr7rEC+ejermckcSzyBwfg8akIy65Y9TbrJhbUi3pZK+26dS47sO8aKXiG8nMZL4BoG1fbgwkAib1EVy4YZwWgUDkLtxtg6kD9hqm7Kxm9kY3yAAsHYL2p008hY9T82kjjAOGfqPJvr2VqhLKFB7bcpGek+hI4l3EDiPQDjvOOPIgFnxKJT6NtXbKui7XSuIQ5Ujxj0U8Q8C10X0sO8y5Oga2qicv5nIL7Ufr4whe8McyrRmIFjP2pRt813EeMU/CACs3UYMB2hCc/rTVFf20O+7jonwTzPpV5dQsA4Ro8VjIymNIHAeB/b5LkKOrCxH4szHS28G0a8n0/Owjp8M2UM2ZbO+izgepRME1uaA+9F4QaRN20fVzB2l00W0o4K+b8xQCATseZuye30XcbxKJwhgaC+i36LtqiNt+Roqy/uJ9EEdo9FjyKbmUq4ziIN1ENcTEXulFQQA1u7B7VIqEZUcoGzZk/HeRTYL9qbZDLQktY1EoAaB38a9S2hI6QUBgLXrgE2+y5Ajm7Od6hm74ttF9K3p9D5Vqw3lAvZoXLaYHo3SDALnYWCX7yLkyM56hKraQ/FbaPazBrrvm6ItJAK2xabsRt9FFFLpBoEbPP4N0OK7FBlZWY7EeQ9QFqfxgtW19HxnBpN81yHedAEP+S6i0Eo3CACsHQB+ifvjSQRV9VJ+zu/Imlz0p5VurqTvxtmaIRSwLHC/TdmSWwtT2kEAQzOJ7kU7lUZWYwuVS5+M9t9nd5LMdSeQ1AyhoP3OpuwB30VMhNIPAgBr24FfoDCIrPmbqTlhM5E80WlPksxfzae8p0y7igbsaZuyL/suYqKEEQQA1h5EYRBpZ6yhJmrHW+5Jkvm0QiB0W2zKPum7iIkUThDAUBjcg8IgkgyYcx+gIir7Ee3Lh0CXQiBkLcCDvouYaGEFAYC1rSgMIis5QNl5D0BiEK8LdfYlyXx6nkIgcD3Ar2zKxmZW23iFFwQwFAY/R7OJIqmug4oVj9GP9bNVyOZK+j41n/KOcoVAwPqAe23KRnLcqtCMtQFvy2NMDbAKmOq7FHmtXQvoWXcu1SSKN1PnyUn0fHkO1ZodFLR+4B6bssGsQQo7CACMSQKXAyf4LkVea/c8eteeT1UxwuDeBrq/qcVioRvAtQT2+y6kmBQEAMYkgAuAxb5LkdfaO5fep1dSaRMT05WZA/uDJnrvmKptIwI3CNxXCttKj5WCYDhjlgPngLoFomb/bHqfvIBKW1bYMOgxZP9+Nv1P1GrFcOCyuIHhIPcnUxAczpjZwGWgC0PUNM+kb80bqShUGOxOkvniCSSakzpiMnCDuK0jdvguxBcFwUjcIPJlwCzfpcirtUynb83FVOSOMwwerqPn5plUDUxQd5PERgbXEgj6mFsFwZEYY4CzgTN9lyKvdnAafY9fTDI3jumdg5D79nT6tI20AN24geE234X4piA4FmPmARehrqJIaZtKZvWllGfHEAa7k2S+Mgezs5KKiaxNYqENNzCstUQoCEbHmCrgDcBC36XIK1DfqmoAAAiPSURBVNqnkFlzMYn+qqP38WfB/nQKvT+cpvUBAsB+4Jc2ZWN3KNJEURCMhTELcNNM1a0QEZkqBldfwmBnw8jHRu5L0n/jbOzmKiqLXZtE0lbggRC2jRgLBcFYGVOJax2c4rsUcXIGu+48evcseCWgs2B/0UDP96ZTo1aAABZYY1P2Gd+FRJGCYLyMmYsLhAbfpYizdRE961dQ/WINfV+bSdlujQWI0wf81qbsbt+FRJWC4Hi4FclLgbNAF50I6Lz9XNa9dxUrgFrfxUgkNAO/0aDw0SkICsENJr8et0WF5qUX3wCwFngOa7MmbaqAi4F5XqsS3zYCj9iU9bqleRwoCArJmAbgdcBJaJuKYhgA1gPPYu1rzpcwabMctxZE4RyWDC4ASvZoyUJTEEwEFwgrcNNNFQiFd9QAGM6kTSOuddBUhLrEvx3AQzZle3wXEicKgolkTD1uZfIi9K60EPqBDYwiAIYzaZMAluNaa/o7lKZ+YLVN2Y2+C4kjBUExGFMLnA6cilYoj0cbrgXwEtYOjPeHqHVQsvYAD2pAePwUBMXkZhmdiAsFbWh3dDlgO7Aea/cU6ocOax2cBTqKMub6gCdsyr7gu5C4UxD4YswU4DTcOMKIq2ID1QZsBjZhJ+68WJM2dcC5uIF9iZccrovwKW0TURgKAt/cLqezgZOBBYQZCh24i/9mrG0t5h2btJmJWxio7qJ42A08qh1DC0tBECWu62g27l3qPEp7T6ODwC5gK9Ye8F2MSZtFuNPpSvl3HmcduMHgbb4LKUUKgihz3UdzcOEwm3ivXu7FXfh3Abux0ZveZ9ImiRs/WEq8f9elpAdYB7yghWETR0EQF64LaRpukLkpf5vstaYjywGtuOX9zcCBYnf5HA+TNhXAEmAZYXbVRUEv8AywQTuFTjwFQZwZkwSm8kow1AN1uO6NYixkywGduGZ7B26gtxloxcb/3ZtJm3LcDK8zUJdRsXThAmCTAqB4FASlyI011A271QCVuO6Ow29DC6wMbqtei7vAZ3HT8zL52/B/D138uwjgCWTSpgy3BuQMXNhK4bUAzwMv25TN+S4mNAoCkTEwaTMXN+13PlqlfLwGgZdx/f/NvosJmYJAZBxM2tTgtg5ZjFoJY9WGWwfwkk3Zft/FiIJA5LjlWwmLcK2Eo56fHLAeYBuu62ef51rkMAoCkQLJjyXMwW0jMh/NOOrGnRG8BdhvU7rYRJWCQGQCmLQxuKm+Q6EQyolpbcBOYKtN2f2+i5HRURCIFIFJm3pcMAzd6vxWVDDtuN0/9wB7bcr2eq5HxkFBIOKBSZtaXgmFJqABKPda1LH14xYKHgT2A3t0AExpUBCIREC+K6keaASm5P/bmP9YsaepDuLWirQOux3Ufv+lS0EgEmH5gKgBJuHGGWryt+r8LYk7V2HoVj7s3wncRf1Itz7cgO7wW5emdIZHQSAiEjitjBQRCZyCoMQZY+4zxnzIdx0iEl3qGoogY8zwQbka3EZvQ7t5/rm19tbiVyUipUpBEHHGmG3AR62194/wuXJrtVWviBwfdQ3FiDHmYmPMLmPMZ40x+4B/N8ZMMcbcY4xpNsa05f89d9j3PGiM+Wj+339qjPm9Mebv81+71RizytsDEpFIUBDEz0zc/PL5wMdwf8N/z///PNzJTl8/yvefC2zCLWK6CfiucaefiUigFATxkwNS1tqMtbbXWnvQWnuHtbbHWtsJfAW46Cjfv91a+x3rThD7Pm5l64wi1C0iERX1Je3yWs3W2r6h/zHG1AA3A2/BrUgFqDPGlNmRj4v8wxbA1tqefGMglA3RRGQEahHEz+Gj+5/GHaN4rrW2Hnhj/uPq7hGRUVEQxF8dblyg3RjTCKQ81yMiMaMgiL9/xO050wKsBn7ptxwRiRutIxAJgDHmA8Bf4c5Y7gTW4SYWXA4stNZec9jXW+AUa+3LxpgvAZ/HLWwcMmitbRj2tQeAOUPrWowx5bgzCqZZa03+Yw8C5/HKhncPAf/bWrvXGHML8AHcVtf9wFPAJ6y1G4fVNBf4Km48rApYD1xvrb3HGDMPdw7ykEm44zGHLnCrrLUPj/03Fwa1CERKnDHmr3AtxxtwM8TmAd8A3jGGH/Of1traYbeGwz7fDgxfk3Il7rSyw/2ltbYWd8ZzA26iw5Cb8p+bA+wGvjvsMTQCv8eFxBLc9OebgR8bY95trd0xvL78ty0f9jGFwFEoCERKmDFmMnA97p33ndbabmvtgLX259bazxTwrn4IfHDY/38Q+MGRvtha2wrcASwd4XO9wG3AmcM+/H+BLuAj1tp9+anTP8G1av5Ba2GOj4JApLSdj+tGuWuC7+enwBuNMQ3GmAbgQuDuI32xMaYJuApYO8LnJgHvB14e9uErgDustbnDvvw2XAtn0fGVHzYFgUhpmwq0HGNPqvcaY9qH30bxNQ8c9vk+4OfA1cD7gJ/lP3a4f87//GeAvbhxiyF/nf9cJ3AB8CfDPteU//rD7R32eRknBYFIaTsINOUHb4/kNmttw/DbKL7mkhG+5ge4LqGjdQt9Mv/9c6y1f2ytbR72ub/P3/cC3JToU4d9rgW3Cv5ws4Z9XsZJQSBS2h7DvTN/ZxHu62Fe2bLk9+P9IdbaHcCngH8yxlTnP3w/cJUx5vBr1nuBncCL470/URCIlDRr7SHgb4B/Nca80xhTY4xJGmNWGWNuKvB9WeBtwNvtcc5Lt9b+Bjf99GP5D90M1OM2SZxpjKkyxrwfN631M8d7f6FTEIiUOGvt13B98V8AmnHvoP8SN8A7WlcbY7oOu00f4b7WW2vXF6Rw+DvgWmNMpbX2IG7coAq3XuAg7jH9ibX2Pwt0f8HSgjIRkcCpRSAiEjgFgYhI4BQEIiKBUxCIiAROQSAiEjgFgYhI4BQEIiKBUxCIiAROQSAiErj/D2S4TzcHF2sYAAAAAElFTkSuQmCC\n", "text/plain": [ - "1360" + "
" ] }, - "execution_count": 21, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ - "len(set(x_test_counts.index).difference(x_train_counts.index))" + "venn2(subsets=[\n", + " len(set(x_train_counts.index).difference(chemprot_counts.index)),\n", + " len(set(chemprot_counts.index).difference(x_train_counts.index)),\n", + " len(set(chemprot_counts.index).intersection(x_train_counts.index))],\n", + " set_labels=('Train', 'CHEMPROT') \n", + ");\n", + "plt.tight_layout()\n", + "plt.savefig('venn_train_chemprot.jpg', dpi=300)" ] }, { "cell_type": "code", "execution_count": 22, - "id": "colored-casino", + "id": "difficult-traffic", "metadata": {}, "outputs": [ { @@ -661,27 +687,27 @@ " \n", " \n", " \n", - " cytarabine\n", + " metastic\n", " 4\n", " 0.000023\n", " \n", " \n", - " downstairs\n", + " myasthenia\n", " 4\n", " 0.000023\n", " \n", " \n", - " myasthenia\n", + " downstairs\n", " 4\n", " 0.000023\n", " \n", " \n", - " metastic\n", + " cytarabine\n", " 4\n", " 0.000023\n", " \n", " \n", - " ne\n", + " assault\n", " 3\n", " 0.000017\n", " \n", @@ -691,27 +717,27 @@ " ...\n", " \n", " \n", - " complied\n", + " msec\n", " 1\n", " 0.000006\n", " \n", " \n", - " t2bn1\n", + " hugely\n", " 1\n", " 0.000006\n", " \n", " \n", - " wheter\n", + " gencitabine\n", " 1\n", " 0.000006\n", " \n", " \n", - " pneumotised\n", + " upto\n", " 1\n", " 0.000006\n", " \n", " \n", - " nucleocytoplasmic\n", + " 182ng\n", " 1\n", " 0.000006\n", " \n", @@ -721,18 +747,18 @@ "" ], "text/plain": [ - " counts freq\n", - "cytarabine 4 0.000023\n", - "downstairs 4 0.000023\n", - "myasthenia 4 0.000023\n", - "metastic 4 0.000023\n", - "ne 3 0.000017\n", - "... ... ...\n", - "complied 1 0.000006\n", - "t2bn1 1 0.000006\n", - "wheter 1 0.000006\n", - "pneumotised 1 0.000006\n", - "nucleocytoplasmic 1 0.000006\n", + " counts freq\n", + "metastic 4 0.000023\n", + "myasthenia 4 0.000023\n", + "downstairs 4 0.000023\n", + "cytarabine 4 0.000023\n", + "assault 3 0.000017\n", + "... ... ...\n", + "msec 1 0.000006\n", + "hugely 1 0.000006\n", + "gencitabine 1 0.000006\n", + "upto 1 0.000006\n", + "182ng 1 0.000006\n", "\n", "[1360 rows x 2 columns]" ] @@ -749,7 +775,7 @@ { "cell_type": "code", "execution_count": null, - "id": "standard-pixel", + "id": "artistic-buying", "metadata": {}, "outputs": [], "source": [] diff --git a/interpretable_pdf.py b/interpretable_pdf.py new file mode 100644 index 0000000..2e062b4 --- /dev/null +++ b/interpretable_pdf.py @@ -0,0 +1,350 @@ +from fpdf import FPDF +import numpy as np +from nltk.stem import WordNetLemmatizer +import re + +from explainability import (get_ti_feature_contributions_for_instance_i, + run_tree_interpreter) + + +class InterpretablePDF: + """ + Class to produce formatted vignettes that indicate which text elements are contributing to + any given classification. + + Currently this class only works with the CAP Prostate Cancer data, because the formatting + corresponds to the unique structure of this text data. However, it could easily be adapted to + work with other textual data (such as LeDeR). + """ + + def __init__(self, + classifier, + x_data, + y_data, + feature_columns, + base_font_size=12, line_height=8, + header_col_width=100, legend_offset=47.5, + legend_offset_2=63, + top_n_features=None, + contributions=None): + + self.font_size = base_font_size + self.line_height = line_height + self.header_col_width = header_col_width + self.legend_offset = legend_offset + self.legend_offset_2 = legend_offset_2 + + self.top_n = top_n_features + + self.pdf = None + self.original_data = None + + self.feature_columns = feature_columns + self.X = x_data + self.y = y_data + self.clf = classifier + _, _, self.contributions = (run_tree_interpreter(self.clf, + self.X) + if contributions is None else ( + None, None, contributions)) + + self.stemmer = WordNetLemmatizer() + + self.section_headers = ['Clinical features at diagnosis', + 'Treatments received', + 'Prostate cancer progression', + 'Progression of co-morbidities', + 'End of life'] + + self.vignette_column_names = [ + 'Gleason Score at diagnosis (with dates)', + 'Clinical stage (TNM)', + 'Pathological stage (TNM)', + 'Co-morbidities with dates of diagnosis', + 'Other primary cancers with dates of diagnosis', + 'PSA level at diagnosis with dates', + 'Radiological evidence of local spread at diagnosis', + 'Radiological evidence of metastases at diagnosis', + 'Initial treatments (dates)', + 'Hormone therapy (start date)', + 'Maximum androgen blockade (start date)', + 'Orchidectomy (date)', + 'Chemotherapy (start date)', + 'Treatment for complications of treating prostate cancer with dates (if available)', + 'Serial PSA levels (dates)', + 'Serum testosterone', + 'Radiological evidence of metastases', + 'Other indications or complications of disease progression', + 'Date of recurrence following radical surgery or radiotherapy', + 'Palliative care referrals and treatments', + 'Treatment/ admission for co-morbidity with dates (if available)', + 'Symptoms in last 3-6 months (i.e. bone pain, weight loss, cachexia,\ + loss of appetite, obstructive uraemia)', + 'Last consultation: speciality & date', + 'Was a DS1500 report issued?', + 'Post mortem findings'] + + def create_pdf(self, case_id, original_data, filename): + + self.original_data = original_data + + self.pdf = FPDF() + self.pdf.add_page() + self.pdf.set_font('Arial', 'B', self.font_size) + + self.pdf.cell(w=0, + h=self.line_height, + txt='Interpretable Vignette Classification for Cause of Death Review', + border=0, + ln=0, + align='C', fill=False, link='') + + self.pdf.set_font('') + self.pdf.ln() + self.pdf.ln() + + y = self.pdf.get_y() + self.pdf.multi_cell(w=self.header_col_width, + h=self.line_height, + txt='Study ID number : %s\nDate of death : %s\nDate of diagnosis : %s' % ( + original_data['cp1random_id_5_char'], + original_data['cnr19datedeath'].date(), + original_data['cnr_date_pca_diag'].date()), + border=0, + align='L', + fill=False) + self.pdf.y = y + self.pdf.x = self.header_col_width + self.pdf.multi_cell(w=self.header_col_width, + h=self.line_height, + txt='Predicted death code: %d (%.2f)\nActual death code: %d\nCOD route: %d' % ( + self.clf.best_estimator_.predict(self.X)[case_id], + self.clf.best_estimator_.predict_proba(self.X)[case_id][1], + original_data.pca_death_code, + original_data.cp1do_cod_route), + border=0, + align='R', fill=False) + + self.pdf.set_line_width(0.5) + self.pdf.line(10, self.pdf.get_y(), 210 - 10, self.pdf.get_y()) + self.pdf.set_line_width(0.2) + + self.write_legend(case_id) + + for ci, col in enumerate(self.feature_columns): + + self.pdf.set_text_color(0, 0, 0) + self.pdf.set_font_size(self.font_size) + + if ci in [0, 8, 14, 20, 21]: + self.print_section_header(0) + + self.pdf.line(10, self.pdf.get_y(), 210 - 10, self.pdf.get_y()) + self.pdf.write(self.line_height, self.vignette_column_names[ci] + ': ') + + if 'palliative' not in col: + text = str(original_data[col]) + self.print_paragraph(text, case_id) + + self.pdf.output(filename) + + def get_font_size_and_color(self, + contribution, + min_contribution, + max_contribution, + shrink=0.5): + c = (255, 160, 0) if contribution < 0 else (0, 0, 255) + s = ( + (self.font_size * shrink) + 1.5 * self.font_size + * (np.absolute(contribution) - min_contribution) + / float(max_contribution - min_contribution) + ) + + return s, c + + def legend_entry(self, fimps, fimp, align): + + size, color = self.get_font_size_and_color(fimp.contribution, + fimps.magnitude.min(), + fimps.magnitude.max()) + self.pdf.set_text_color(*color) + self.pdf.set_font_size(size) + self.pdf.cell(w=self.legend_offset, + h=self.line_height, + txt=fimp.feature, + border=0, ln=0, + align=align, fill=False, link='') + + def legend_label(self, text, align): + + self.pdf.cell(w=self.legend_offset_2, + h=self.line_height, + txt=text, border=0, ln=0, + align=align, fill=False, link='') + + def write_legend(self, case_id): + + self.pdf.cell(w=0, + h=self.line_height, + txt='Feature contribution legend', + border=0, ln=0, + align='C', fill=False, link='') + self.pdf.ln() + + fimps = get_ti_feature_contributions_for_instance_i(case_id, + self.contributions, + self.clf).sort_values(by='magnitude', + ascending=False) + fimps = fimps.head(self.top_n) if self.top_n is not None else fimps + + fimp = fimps.loc[fimps.contribution.idxmin()] + self.legend_entry(fimps, fimp, align='L') + + fimp = fimps.loc[fimps.contribution < 0] + fimp = fimp.loc[fimp.contribution.idxmax()] + self.legend_entry(fimps, fimp, align='R') + + fimp = fimps.loc[fimps.contribution > 0] + fimp = fimp.loc[fimp.contribution.idxmin()] + self.legend_entry(fimps, fimp, align='L') + + fimp = fimps.loc[fimps.contribution.idxmax()] + self.legend_entry(fimps, fimp, align='R') + + self.pdf.ln() + self.pdf.set_text_color(0, 0, 0) + self.pdf.set_font('Arial', '', self.font_size * .6) + + self.legend_label('Largest negative contribution', 'L') + self.legend_label('Smallest contributions', 'C') + self.legend_label('Largest positive contribution', 'R') + + self.pdf.set_font('Arial', '', self.font_size) + + def print_section_header(self, section): + + self.pdf.ln() + self.pdf.set_line_width(0.5) + self.pdf.line(10, self.pdf.get_y(), 210 - 10, self.pdf.get_y()) + self.pdf.set_line_width(0.2) + self.pdf.set_font('Arial', 'B', self.font_size) + self.pdf.write(self.line_height, self.section_headers[section] + '\n') + self.pdf.set_font('') + + # REFACTOR! + def print_paragraph(self, text, i, + base_color=(128, 128, 128)): + + fimps = get_ti_feature_contributions_for_instance_i(i, + self.contributions, + self.clf) + fimps.sort_values(by='magnitude', inplace=True, ascending=False) + fimps = fimps.head(self.top_n) + + old_word = '' + old_tr_word = '' + old_bigram = '' + old_bigram_contribution = None + + old_color = base_color + old_size = self.font_size + + words = text.split(' ') + words.append(' .') + + for word in words: + tr_word = self.transform_text(word) + + feat_tr = old_tr_word + ' ' + tr_word + contribution_bi = (fimps.loc[fimps.feature == feat_tr] + .iloc[0].contribution + if feat_tr in list(fimps.feature) + else None) + magnitude_bi = np.absolute(contribution_bi) if contribution_bi is not None else 0 + + feat_tr = old_tr_word + contribution_uni = (fimps.loc[fimps.feature == feat_tr] + .iloc[0].contribution + if feat_tr in list(fimps.feature) + else None) + magnitude_uni = np.absolute(contribution_uni) if contribution_uni is not None else 0 + + if contribution_bi and magnitude_bi > magnitude_uni: + # print('bigram: ', old_tr_word) + feat_tr = old_tr_word + ' ' + tr_word + feat = old_word + ' ' + word + + contribution = fimps.loc[fimps.feature == feat_tr].iloc[0].contribution + size, color = self.get_font_size_and_color(contribution, + fimps.magnitude.min(), + fimps.magnitude.max()) + self.pdf.set_text_color(*color) + self.pdf.set_font_size(size) + feat = feat.encode('latin-1', 'replace').decode('latin-1') + self.pdf.write(self.line_height, feat + ' ') + + old_word = '' + old_tr_word = '' + old_color = base_color + old_size = self.font_size + + elif contribution_uni and magnitude_uni > magnitude_bi: + # print('unigram: ', old_tr_word) + feat_tr = old_tr_word + feat = old_word + + contribution = fimps.loc[fimps.feature == feat_tr].iloc[0].contribution + size, color = self.get_font_size_and_color(contribution, + fimps.magnitude.min(), + fimps.magnitude.max()) + self.pdf.set_text_color(*color) + self.pdf.set_font_size(size) + feat = feat.encode('latin-1', 'replace').decode('latin-1') + self.pdf.write(self.line_height, feat + ' ') + + old_word = word + old_tr_word = tr_word + old_color = base_color + old_size = self.font_size + + else: + self.pdf.set_text_color(*old_color) + self.pdf.set_font_size(old_size) + w = old_word.encode('latin-1', 'replace').decode('latin-1') + self.pdf.write(self.line_height, w + ' ') + + old_word = word + old_tr_word = tr_word + old_color = base_color + old_size = self.font_size + + self.pdf.ln() + + def transform_text(self, text): + + # Remove all the special characters + document = re.sub(r'\W', ' ', str(text)) + + # remove all single characters + document = re.sub(r'\s+[a-zA-Z]\s+', ' ', document) + # document = re.sub(r'\s+[a-zA-Z]\s+', ' ', str(X[sen])) + + # Remove single characters from the start + document = re.sub(r'\^[a-zA-Z]\s+', ' ', document) + + # Substituting multiple spaces with single space + document = re.sub(r'\s+', ' ', document, flags=re.I) + + # Removing prefixed 'b' + document = re.sub(r'^b\s+', '', document) + + # Converting to Lowercase + document = document.lower() + + # Lemmatization + document = document.split() + + document = [self.stemmer.lemmatize(word) for word in document] + document = ' '.join(document) + + return document diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f8b5417 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,102 @@ +async-generator==1.10 +attrs @ file:///tmp/build/80754af9/attrs_1604765588209/work +Babel @ file:///tmp/build/80754af9/babel_1607110387436/work +backcall==0.2.0 +bleach @ file:///tmp/build/80754af9/bleach_1600439572647/work +brotlipy==0.7.0 +certifi==2020.6.20 +cffi @ file:///tmp/build/80754af9/cffi_1606255081583/work +chardet @ file:///tmp/build/80754af9/chardet_1607706746162/work +click @ file:///home/linux1/recipes/ci/click_1610990599742/work +cloudpickle @ file:///home/conda/feedstock_root/build_artifacts/cloudpickle_1598400192773/work +cryptography @ file:///tmp/build/80754af9/cryptography_1607635341180/work +cycler==0.10.0 +cytoolz==0.11.0 +dask @ file:///home/conda/feedstock_root/build_artifacts/dask-core_1611349541186/work +decorator==4.4.2 +defusedxml==0.6.0 +entrypoints==0.3 +et-xmlfile==1.0.1 +FAT-Forensics==0.1.0 +fpdf==1.7.2 +idna @ file:///home/linux1/recipes/ci/idna_1610986105248/work +imageio @ file:///home/conda/feedstock_root/build_artifacts/imageio_1594044661732/work +importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1602276842396/work +ipykernel @ file:///tmp/build/80754af9/ipykernel_1596207638929/work/dist/ipykernel-5.3.4-py3-none-any.whl +ipython @ file:///tmp/build/80754af9/ipython_1604101197014/work +ipython-genutils @ file:///tmp/build/80754af9/ipython_genutils_1606773439826/work +jdcal==1.4.1 +jedi @ file:///tmp/build/80754af9/jedi_1608920709268/work +Jinja2 @ file:///home/linux1/recipes/ci/jinja2_1610990516718/work +joblib @ file:///tmp/build/80754af9/joblib_1607970656719/work +json5==0.9.5 +jsonschema @ file:///tmp/build/80754af9/jsonschema_1602607155483/work +jupyter-client @ file:///tmp/build/80754af9/jupyter_client_1601311786391/work +jupyter-core @ file:///tmp/build/80754af9/jupyter_core_1606148996965/work +jupyterlab==2.2.6 +jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work +jupyterlab-server @ file:///tmp/build/80754af9/jupyterlab_server_1594164409481/work +kiwisolver @ file:///home/conda/feedstock_root/build_artifacts/kiwisolver_1604322295622/work +llvmlite==0.34.0 +MarkupSafe==1.1.1 +matplotlib @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-base_1594091694890/work +matplotlib-venn==0.11.6 +mistune==0.8.4 +mkl-fft==1.2.0 +mkl-random==1.1.1 +mkl-service==2.3.0 +nbclient @ file:///tmp/build/80754af9/nbclient_1602783176460/work +nbconvert==5.5.0 +nbformat @ file:///tmp/build/80754af9/nbformat_1610738111109/work +nest-asyncio @ file:///tmp/build/80754af9/nest-asyncio_1606153767164/work +networkx @ file:///home/conda/feedstock_root/build_artifacts/networkx_1598210780226/work +nltk @ file:///tmp/build/80754af9/nltk_1592496090529/work +notebook @ file:///tmp/build/80754af9/notebook_1595951624445/work +numba @ file:///home/conda/feedstock_root/build_artifacts/numba_1599084802945/work +numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1603570489231/work +olefile @ file:///home/conda/feedstock_root/build_artifacts/olefile_1602866521163/work +openpyxl @ file:///tmp/build/80754af9/openpyxl_1610651698508/work +packaging @ file:///tmp/build/80754af9/packaging_1607971725249/work +pandas==1.2.1 +pandocfilters @ file:///tmp/build/80754af9/pandocfilters_1605120460739/work +parso==0.7.0 +pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work +pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work +Pillow==6.2.1 +prometheus-client @ file:///tmp/build/80754af9/prometheus_client_1606344362066/work +prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1602688806899/work +ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl +pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work +Pygments @ file:///tmp/build/80754af9/pygments_1610565767015/work +pyOpenSSL @ file:///tmp/build/80754af9/pyopenssl_1608057966937/work +pyparsing @ file:///home/linux1/recipes/ci/pyparsing_1610983426697/work +pyrsistent @ file:///tmp/build/80754af9/pyrsistent_1600141720057/work +PySocks @ file:///tmp/build/80754af9/pysocks_1605305779399/work +python-dateutil==2.8.1 +pytz @ file:///tmp/build/80754af9/pytz_1608922264688/work +PyWavelets @ file:///home/conda/feedstock_root/build_artifacts/pywavelets_1602504439440/work +PyYAML==5.3.1 +pyzmq==20.0.0 +regex @ file:///tmp/build/80754af9/regex_1606772724491/work +requests @ file:///tmp/build/80754af9/requests_1608241421344/work +scikit-image==0.16.2 +scikit-learn==0.22.1 +scipy @ file:///tmp/build/80754af9/scipy_1597686649129/work +Send2Trash @ file:///tmp/build/80754af9/send2trash_1607525499227/work +shap @ file:///home/conda/feedstock_root/build_artifacts/shap_1608143397482/work +six @ file:///tmp/build/80754af9/six_1605205327372/work +slicer @ file:///home/conda/feedstock_root/build_artifacts/slicer_1608146800664/work +terminado==0.9.2 +testpath==0.4.4 +threadpoolctl @ file:///tmp/tmp79xdzxkt/threadpoolctl-2.1.0-py3-none-any.whl +toolz @ file:///home/conda/feedstock_root/build_artifacts/toolz_1600973991856/work +tornado @ file:///tmp/build/80754af9/tornado_1606942300299/work +tqdm @ file:///tmp/build/80754af9/tqdm_1609788246169/work +traitlets @ file:///tmp/build/80754af9/traitlets_1602787416690/work +treeinterpreter==0.2.3 +urllib3 @ file:///tmp/build/80754af9/urllib3_1606938623459/work +wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work +webencodings==0.5.1 +wordcloud==1.8.1 +xlrd==1.2.0 +zipp @ file:///tmp/build/80754af9/zipp_1604001098328/work diff --git a/rf_interpretability_outputs.ipynb b/rf_interpretability_outputs.ipynb index 0d42c35..0ddae33 100644 --- a/rf_interpretability_outputs.ipynb +++ b/rf_interpretability_outputs.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "adjustable-situation", + "id": "respiratory-fancy", "metadata": {}, "source": [ "## Interpretability outputs\n", @@ -15,7 +15,7 @@ { "cell_type": "code", "execution_count": 1, - "id": "arabic-grove", + "id": "dominant-fabric", "metadata": {}, "outputs": [], "source": [ @@ -32,20 +32,10 @@ }, { "cell_type": "code", - "execution_count": 2, - "id": "square-sending", + "execution_count": null, + "id": "electronic-bernard", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "21-Jan-30 09:20:31 fatf.utils.array.tools INFO Using numpy's numpy.lib.recfunctions.structured_to_unstructured as fatf.utils.array.tools.structured_to_unstructured and fatf.utils.array.tools.structured_to_unstructured_row.\n", - "21-Jan-30 09:20:31 fatf INFO Seeding RNGs using the input parameter.\n", - "21-Jan-30 09:20:31 fatf INFO Seeding RNGs with 42.\n" - ] - } - ], + "outputs": [], "source": [ "import shap\n", "import numpy as np\n", @@ -95,7 +85,7 @@ { "cell_type": "code", "execution_count": 3, - "id": "pretty-conservation", + "id": "bottom-heart", "metadata": {}, "outputs": [ { @@ -125,7 +115,7 @@ { "cell_type": "code", "execution_count": 4, - "id": "quick-management", + "id": "capable-immune", "metadata": {}, "outputs": [], "source": [ @@ -144,7 +134,7 @@ { "cell_type": "code", "execution_count": 5, - "id": "portable-chest", + "id": "disturbed-camcorder", "metadata": {}, "outputs": [], "source": [ @@ -154,7 +144,7 @@ { "cell_type": "code", "execution_count": 6, - "id": "liberal-organization", + "id": "korean-trouble", "metadata": {}, "outputs": [ { @@ -327,7 +317,7 @@ { "cell_type": "code", "execution_count": 7, - "id": "upset-mechanics", + "id": "spare-pottery", "metadata": {}, "outputs": [ { @@ -348,7 +338,7 @@ { "cell_type": "code", "execution_count": 8, - "id": "current-childhood", + "id": "olive-discretion", "metadata": {}, "outputs": [], "source": [ @@ -359,7 +349,7 @@ { "cell_type": "code", "execution_count": 9, - "id": "synthetic-completion", + "id": "available-center", "metadata": {}, "outputs": [ { @@ -380,7 +370,7 @@ { "cell_type": "code", "execution_count": 10, - "id": "remarkable-tennessee", + "id": "stuffed-ancient", "metadata": {}, "outputs": [], "source": [ @@ -390,7 +380,7 @@ { "cell_type": "code", "execution_count": 11, - "id": "powered-seven", + "id": "infrared-baptist", "metadata": {}, "outputs": [ { @@ -410,7 +400,7 @@ }, { "cell_type": "markdown", - "id": "imported-shock", + "id": "vanilla-broad", "metadata": {}, "source": [ "#### Using LIME:" @@ -419,7 +409,7 @@ { "cell_type": "code", "execution_count": null, - "id": "advanced-discretion", + "id": "opponent-publication", "metadata": { "scrolled": true }, @@ -431,7 +421,7 @@ { "cell_type": "code", "execution_count": null, - "id": "rational-mailman", + "id": "synthetic-pierre", "metadata": {}, "outputs": [], "source": [ @@ -442,7 +432,7 @@ { "cell_type": "code", "execution_count": 14, - "id": "suburban-great", + "id": "exempt-monroe", "metadata": {}, "outputs": [], "source": [ @@ -452,7 +442,7 @@ }, { "cell_type": "markdown", - "id": "virtual-newspaper", + "id": "handled-intellectual", "metadata": {}, "source": [ "#### Now we use SHAP to evaluate feature importances:" @@ -461,7 +451,7 @@ { "cell_type": "code", "execution_count": 15, - "id": "golden-northwest", + "id": "nutritional-madness", "metadata": {}, "outputs": [], "source": [ @@ -471,7 +461,7 @@ { "cell_type": "code", "execution_count": 16, - "id": "underlying-qatar", + "id": "immune-birthday", "metadata": {}, "outputs": [ { @@ -491,7 +481,7 @@ }, { "cell_type": "markdown", - "id": "rough-nitrogen", + "id": "piano-bishop", "metadata": {}, "source": [ "#### We now compare the feature rankings obtained by the different interpretability metrics:" @@ -500,7 +490,7 @@ { "cell_type": "code", "execution_count": 35, - "id": "mounted-salmon", + "id": "confident-specification", "metadata": {}, "outputs": [ { @@ -522,7 +512,7 @@ { "cell_type": "code", "execution_count": 18, - "id": "dietary-feeling", + "id": "outer-barrel", "metadata": {}, "outputs": [ { @@ -544,7 +534,7 @@ { "cell_type": "code", "execution_count": 19, - "id": "technical-biography", + "id": "fantastic-express", "metadata": {}, "outputs": [ { @@ -566,7 +556,7 @@ { "cell_type": "code", "execution_count": 20, - "id": "grateful-support", + "id": "prescribed-amount", "metadata": {}, "outputs": [ { @@ -588,7 +578,7 @@ { "cell_type": "code", "execution_count": 21, - "id": "prostate-blocking", + "id": "metallic-chapter", "metadata": {}, "outputs": [ { @@ -610,7 +600,7 @@ { "cell_type": "code", "execution_count": 34, - "id": "hired-machinery", + "id": "subject-research", "metadata": {}, "outputs": [ { @@ -632,7 +622,7 @@ { "cell_type": "code", "execution_count": 45, - "id": "excited-raising", + "id": "aerial-navigation", "metadata": {}, "outputs": [ { @@ -823,7 +813,7 @@ }, { "cell_type": "markdown", - "id": "useful-carrier", + "id": "binary-times", "metadata": {}, "source": [ "#### Shap plots:" @@ -832,7 +822,7 @@ { "cell_type": "code", "execution_count": 26, - "id": "asian-strength", + "id": "noted-consortium", "metadata": {}, "outputs": [], "source": [ @@ -842,7 +832,7 @@ { "cell_type": "code", "execution_count": 27, - "id": "compound-japan", + "id": "private-coordinator", "metadata": {}, "outputs": [ { @@ -863,7 +853,7 @@ { "cell_type": "code", "execution_count": 28, - "id": "formed-shaft", + "id": "higher-cherry", "metadata": {}, "outputs": [ { @@ -883,7 +873,7 @@ }, { "cell_type": "markdown", - "id": "musical-darwin", + "id": "artificial-powell", "metadata": {}, "source": [ "#### Now we produce output for the main text of the publication:" @@ -892,7 +882,7 @@ { "cell_type": "code", "execution_count": 29, - "id": "encouraging-blanket", + "id": "ethical-prompt", "metadata": {}, "outputs": [ { @@ -934,7 +924,7 @@ { "cell_type": "code", "execution_count": 30, - "id": "simple-knowing", + "id": "registered-valuable", "metadata": {}, "outputs": [], "source": [