diff --git a/docs/notebooks/GraphPipelineAdvanced.html b/docs/notebooks/GraphPipelineAdvanced.html new file mode 100644 index 0000000..d6f40c5 --- /dev/null +++ b/docs/notebooks/GraphPipelineAdvanced.html @@ -0,0 +1,15997 @@ + + + + +GraphPipelineAdvanced + + + + + + + + + + + + + + + + + + + + + + +
+
+ +
+
+
+

GraphPipeline advanced functionnalities

+
+
+
+
+
+
+

In this notebook we will explain some of the GraphPipeline more advanced functionnalities. +Especially how to play with the GraphPipeline to retrieve features at some nodes or create a new graphpipeline with a subset of nodes.

+ +
+
+
+
+
+
+

As always we will play with the titanic dataset

+ +
+
+
+
+
+
In [1]:
+
+
+
from aikit.datasets import load_titanic
+
+Xtrain, ytrain, *_ = load_titanic()
+
+non_text_cols = [c for c in Xtrain.columns if c not in ("ticket","name")]
+text_cols = ["ticket", "name"]
+
+Xtrain.head(10)
+
+ +
+
+
+ +
+
+ + +
+ +
Out[1]:
+ + + +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
pclassnamesexagesibspparchticketfarecabinembarkedboatbodyhome_dest
01McCarthy, Mr. Timothy Jmale54.0001746351.8625E46SNaN175.0Dorchester, MA
11Fortune, Mr. Markmale64.01419950263.0000C23 C25 C27SNaNNaNWinnipeg, MB
21Sagesser, Mlle. Emmafemale24.000PC 1747769.3000B35C9NaNNaN
33Panula, Master. Urho Abrahammale2.041310129539.6875NaNSNaNNaNNaN
41Maioni, Miss. Robertafemale16.00011015286.5000B79S8NaNNaN
53Waelens, Mr. Achillemale22.0003457679.0000NaNSNaNNaNAntwerp, Belgium / Stanton, OH
63Reed, Mr. James GeorgemaleNaN003623167.2500NaNSNaNNaNNaN
71Swift, Mrs. Frederick Joel (Margaret Welles Ba...female48.0001746625.9292D17S8NaNBrooklyn, NY
81Smith, Mrs. Lucien Philip (Mary Eloise Hughes)female18.0101369560.0000C31S6NaNHuntington, WV
91Rowe, Mr. Alfred Gmale33.00011379026.5500NaNSNaN109.0London
+
+
+ +
+ +
+
+ +
+
+
+
In [2]:
+
+
+
from aikit.pipeline import GraphPipeline
+from aikit.transformers import ColumnsSelector, NumImputer, NumericalEncoder, CountVectorizerWrapper
+from sklearn.ensemble import RandomForestClassifier
+
+
+gpipeline = GraphPipeline(models = {
+    "sel":ColumnsSelector(columns_to_use=non_text_cols),
+    "enc":NumericalEncoder(columns_to_use="object"),
+    "imp":NumImputer(),
+    "vect":CountVectorizerWrapper(analyzer="word",columns_to_use=text_cols),
+    "rf":RandomForestClassifier(n_estimators=100, random_state=123)
+                       },
+              edges = [("sel","enc","imp","rf"),("vect","rf")])
+
+gpipeline.graphviz
+
+ +
+
+
+ +
+
+ + +
+ +
+ + +
+
C:\Users\arwen\Anaconda3\lib\site-packages\statsmodels\tools\_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.
+  import pandas.util.testing as tm
+C:\Users\arwen\Anaconda3\lib\site-packages\sklearn\utils\deprecation.py:143: FutureWarning: The sklearn.metrics.scorer module is  deprecated in version 0.22 and will be removed in version 0.24. The corresponding classes / functions should instead be imported from sklearn.metrics. Anything that cannot be imported from sklearn.metrics is now part of the private API.
+  warnings.warn(message, FutureWarning)
+
+
+
+ +
+ +
Out[2]:
+ + + +
+ + + + + + +%3 + + +enc + +enc + + +imp + +imp + + +enc->imp + + + + +rf + +rf + + +imp->rf + + + + +sel + +sel + + +sel->enc + + + + +vect + +vect + + +vect->rf + + + + + + +
+ +
+ +
+
+ +
+
+
+
In [3]:
+
+
+
gpipeline.fit(Xtrain, ytrain)
+
+ +
+
+
+ +
+
+ + +
+ +
Out[3]:
+ + + + +
+
GraphPipeline(edges=[('sel', 'enc', 'imp', 'rf'), ('vect', 'rf')],
+              models={'enc': NumericalEncoder(columns_to_use='object'),
+                      'imp': NumImputer(),
+                      'rf': RandomForestClassifier(random_state=123),
+                      'sel': ColumnsSelector(columns_to_use=['pclass', 'sex',
+                                                             'age', 'sibsp',
+                                                             'parch', 'fare',
+                                                             'cabin',
+                                                             'embarked', 'boat',
+                                                             'body',
+                                                             'home_dest']),
+                      'vect': CountVectorizerWrapper(columns_to_use=['ticket',
+                                                                     'name'])})
+
+ +
+ +
+
+ +
+
+
+
+

Features Importances

Let's see how to retrieve the features importance. +With graphpipeline you can easily retrieve the name of the features at each node

+ +
+
+
+
+
+
+

To retrieve features entering a given node, just use 'get_inut_features_at_node'

+ +
+
+
+
+
+
In [4]:
+
+
+
random_forest_features = gpipeline.get_input_features_at_node("rf")
+random_forest_features
+
+ +
+
+
+ +
+
+ + +
+ +
Out[4]:
+ + + + +
+
['pclass',
+ 'age',
+ 'sibsp',
+ 'parch',
+ 'fare',
+ 'body',
+ 'sex__male',
+ 'sex__female',
+ 'cabin____null__',
+ 'cabin____default__',
+ 'embarked__S',
+ 'embarked__C',
+ 'embarked__Q',
+ 'boat____null__',
+ 'boat__13',
+ 'boat__C',
+ 'boat__15',
+ 'boat__14',
+ 'boat__4',
+ 'boat__10',
+ 'boat__5',
+ 'boat__6',
+ 'boat__11',
+ 'boat__8',
+ 'boat__12',
+ 'boat__16',
+ 'boat__7',
+ 'boat__3',
+ 'boat__D',
+ 'boat__9',
+ 'boat____default__',
+ 'home_dest____null__',
+ 'home_dest__New York, NY',
+ 'home_dest__London',
+ 'home_dest____default__',
+ 'age_isnull',
+ 'fare_isnull',
+ 'body_isnull',
+ 'ticket__BAG__10482',
+ 'ticket__BAG__110152',
+ 'ticket__BAG__110413',
+ 'ticket__BAG__110465',
+ 'ticket__BAG__110469',
+ 'ticket__BAG__110489',
+ 'ticket__BAG__110564',
+ 'ticket__BAG__110813',
+ 'ticket__BAG__111163',
+ 'ticket__BAG__111240',
+ 'ticket__BAG__111320',
+ 'ticket__BAG__111361',
+ 'ticket__BAG__111369',
+ 'ticket__BAG__111426',
+ 'ticket__BAG__111427',
+ 'ticket__BAG__112050',
+ 'ticket__BAG__112052',
+ 'ticket__BAG__112053',
+ 'ticket__BAG__112058',
+ 'ticket__BAG__11206',
+ 'ticket__BAG__112277',
+ 'ticket__BAG__112377',
+ 'ticket__BAG__112378',
+ 'ticket__BAG__112379',
+ 'ticket__BAG__112901',
+ 'ticket__BAG__113028',
+ 'ticket__BAG__113038',
+ 'ticket__BAG__113043',
+ 'ticket__BAG__113044',
+ 'ticket__BAG__113050',
+ 'ticket__BAG__113051',
+ 'ticket__BAG__113054',
+ 'ticket__BAG__113059',
+ 'ticket__BAG__113501',
+ 'ticket__BAG__113503',
+ 'ticket__BAG__113505',
+ 'ticket__BAG__113509',
+ 'ticket__BAG__113510',
+ 'ticket__BAG__113572',
+ 'ticket__BAG__113760',
+ 'ticket__BAG__113767',
+ 'ticket__BAG__113773',
+ 'ticket__BAG__113776',
+ 'ticket__BAG__113778',
+ 'ticket__BAG__113780',
+ 'ticket__BAG__113781',
+ 'ticket__BAG__113783',
+ 'ticket__BAG__113784',
+ 'ticket__BAG__113786',
+ 'ticket__BAG__113788',
+ 'ticket__BAG__113789',
+ 'ticket__BAG__113790',
+ 'ticket__BAG__113791',
+ 'ticket__BAG__113794',
+ 'ticket__BAG__113795',
+ 'ticket__BAG__113796',
+ 'ticket__BAG__113798',
+ 'ticket__BAG__113800',
+ 'ticket__BAG__113801',
+ 'ticket__BAG__113803',
+ 'ticket__BAG__113804',
+ 'ticket__BAG__113806',
+ 'ticket__BAG__113807',
+ 'ticket__BAG__1166',
+ 'ticket__BAG__11668',
+ 'ticket__BAG__11752',
+ 'ticket__BAG__11753',
+ 'ticket__BAG__11755',
+ 'ticket__BAG__11767',
+ 'ticket__BAG__11769',
+ 'ticket__BAG__11778',
+ 'ticket__BAG__11813',
+ 'ticket__BAG__11967',
+ 'ticket__BAG__1222',
+ 'ticket__BAG__12460',
+ 'ticket__BAG__12749',
+ 'ticket__BAG__12750',
+ 'ticket__BAG__13032',
+ 'ticket__BAG__13049',
+ 'ticket__BAG__13050',
+ 'ticket__BAG__13213',
+ 'ticket__BAG__13236',
+ 'ticket__BAG__13502',
+ 'ticket__BAG__13507',
+ 'ticket__BAG__13508',
+ 'ticket__BAG__13509',
+ 'ticket__BAG__13528',
+ 'ticket__BAG__13529',
+ 'ticket__BAG__13531',
+ 'ticket__BAG__13534',
+ 'ticket__BAG__13540',
+ 'ticket__BAG__13567',
+ 'ticket__BAG__13568',
+ 'ticket__BAG__13695',
+ 'ticket__BAG__14208',
+ 'ticket__BAG__14258',
+ 'ticket__BAG__14260',
+ 'ticket__BAG__14263',
+ 'ticket__BAG__14266',
+ 'ticket__BAG__14312',
+ 'ticket__BAG__14313',
+ 'ticket__BAG__1478',
+ 'ticket__BAG__14879',
+ 'ticket__BAG__14888',
+ 'ticket__BAG__14973',
+ 'ticket__BAG__15185',
+ 'ticket__BAG__1585',
+ 'ticket__BAG__1588',
+ 'ticket__BAG__1601',
+ 'ticket__BAG__16966',
+ 'ticket__BAG__17248',
+ 'ticket__BAG__17368',
+ 'ticket__BAG__17369',
+ 'ticket__BAG__17421',
+ 'ticket__BAG__17453',
+ 'ticket__BAG__17463',
+ 'ticket__BAG__17464',
+ 'ticket__BAG__17465',
+ 'ticket__BAG__17466',
+ 'ticket__BAG__17473',
+ 'ticket__BAG__17474',
+ 'ticket__BAG__17475',
+ 'ticket__BAG__17476',
+ 'ticket__BAG__17477',
+ 'ticket__BAG__1748',
+ 'ticket__BAG__17483',
+ 'ticket__BAG__17485',
+ 'ticket__BAG__17531',
+ 'ticket__BAG__17558',
+ 'ticket__BAG__17569',
+ 'ticket__BAG__17572',
+ 'ticket__BAG__17580',
+ 'ticket__BAG__17582',
+ 'ticket__BAG__17585',
+ 'ticket__BAG__17590',
+ 'ticket__BAG__17591',
+ 'ticket__BAG__17592',
+ 'ticket__BAG__17593',
+ 'ticket__BAG__17594',
+ 'ticket__BAG__17595',
+ 'ticket__BAG__17596',
+ 'ticket__BAG__17598',
+ 'ticket__BAG__17599',
+ 'ticket__BAG__17601',
+ 'ticket__BAG__17603',
+ 'ticket__BAG__17604',
+ 'ticket__BAG__17605',
+ 'ticket__BAG__17606',
+ 'ticket__BAG__17608',
+ 'ticket__BAG__17609',
+ 'ticket__BAG__17610',
+ 'ticket__BAG__17611',
+ 'ticket__BAG__17612',
+ 'ticket__BAG__17613',
+ 'ticket__BAG__17754',
+ 'ticket__BAG__17755',
+ 'ticket__BAG__17756',
+ 'ticket__BAG__17757',
+ 'ticket__BAG__17758',
+ 'ticket__BAG__17759',
+ 'ticket__BAG__17760',
+ 'ticket__BAG__17761',
+ 'ticket__BAG__17764',
+ 'ticket__BAG__17765',
+ 'ticket__BAG__17770',
+ 'ticket__BAG__18509',
+ 'ticket__BAG__18723',
+ 'ticket__BAG__19877',
+ 'ticket__BAG__19924',
+ 'ticket__BAG__19928',
+ 'ticket__BAG__19943',
+ 'ticket__BAG__19947',
+ 'ticket__BAG__19950',
+ 'ticket__BAG__19952',
+ 'ticket__BAG__19972',
+ 'ticket__BAG__19988',
+ 'ticket__BAG__19996',
+ 'ticket__BAG__2003',
+ 'ticket__BAG__2079',
+ 'ticket__BAG__211535',
+ 'ticket__BAG__211536',
+ 'ticket__BAG__21172',
+ 'ticket__BAG__21173',
+ 'ticket__BAG__21174',
+ 'ticket__BAG__21175',
+ 'ticket__BAG__21228',
+ 'ticket__BAG__2123',
+ 'ticket__BAG__2131',
+ 'ticket__BAG__2133',
+ 'ticket__BAG__21332',
+ 'ticket__BAG__2144',
+ 'ticket__BAG__21440',
+ 'ticket__BAG__2148',
+ 'ticket__BAG__2149',
+ 'ticket__BAG__2152',
+ 'ticket__BAG__2166',
+ 'ticket__BAG__2167',
+ 'ticket__BAG__218629',
+ 'ticket__BAG__219533',
+ 'ticket__BAG__220367',
+ 'ticket__BAG__220844',
+ 'ticket__BAG__220845',
+ 'ticket__BAG__2223',
+ 'ticket__BAG__223596',
+ 'ticket__BAG__226875',
+ 'ticket__BAG__228414',
+ 'ticket__BAG__229236',
+ 'ticket__BAG__230080',
+ 'ticket__BAG__230136',
+ 'ticket__BAG__230433',
+ 'ticket__BAG__230434',
+ 'ticket__BAG__2314',
+ 'ticket__BAG__2315',
+ 'ticket__BAG__231919',
+ 'ticket__BAG__231945',
+ 'ticket__BAG__233478',
+ 'ticket__BAG__233639',
+ 'ticket__BAG__233866',
+ 'ticket__BAG__2343',
+ 'ticket__BAG__234604',
+ 'ticket__BAG__234686',
+ 'ticket__BAG__235509',
+ 'ticket__BAG__23567',
+ 'ticket__BAG__236852',
+ 'ticket__BAG__236853',
+ 'ticket__BAG__236854',
+ 'ticket__BAG__237216',
+ 'ticket__BAG__237249',
+ 'ticket__BAG__237442',
+ 'ticket__BAG__237668',
+ 'ticket__BAG__237670',
+ 'ticket__BAG__237671',
+ 'ticket__BAG__237735',
+ 'ticket__BAG__237736',
+ 'ticket__BAG__237789',
+ 'ticket__BAG__237798',
+ 'ticket__BAG__239853',
+ 'ticket__BAG__239855',
+ 'ticket__BAG__239856',
+ 'ticket__BAG__239865',
+ 'ticket__BAG__240276',
+ 'ticket__BAG__24065',
+ 'ticket__BAG__240929',
+ 'ticket__BAG__24160',
+ 'ticket__BAG__242963',
+ 'ticket__BAG__243847',
+ 'ticket__BAG__243880',
+ 'ticket__BAG__244252',
+ 'ticket__BAG__244270',
+ 'ticket__BAG__244278',
+ 'ticket__BAG__244346',
+ 'ticket__BAG__244358',
+ 'ticket__BAG__244360',
+ 'ticket__BAG__244361',
+ 'ticket__BAG__244367',
+ 'ticket__BAG__244368',
+ 'ticket__BAG__244373',
+ 'ticket__BAG__24579',
+ 'ticket__BAG__24580',
+ 'ticket__BAG__2466',
+ 'ticket__BAG__248659',
+ 'ticket__BAG__248698',
+ 'ticket__BAG__248706',
+ 'ticket__BAG__248723',
+ 'ticket__BAG__248726',
+ 'ticket__BAG__248727',
+ 'ticket__BAG__248731',
+ 'ticket__BAG__248733',
+ 'ticket__BAG__248734',
+ 'ticket__BAG__248738',
+ 'ticket__BAG__248740',
+ 'ticket__BAG__248744',
+ 'ticket__BAG__248746',
+ 'ticket__BAG__248747',
+ 'ticket__BAG__250644',
+ 'ticket__BAG__250646',
+ 'ticket__BAG__250647',
+ 'ticket__BAG__250648',
+ 'ticket__BAG__250649',
+ 'ticket__BAG__250650',
+ 'ticket__BAG__250651',
+ 'ticket__BAG__250652',
+ 'ticket__BAG__250653',
+ 'ticket__BAG__250655',
+ 'ticket__BAG__2543',
+ 'ticket__BAG__2620',
+ 'ticket__BAG__2621',
+ 'ticket__BAG__2622',
+ 'ticket__BAG__2623',
+ 'ticket__BAG__2624',
+ 'ticket__BAG__2625',
+ 'ticket__BAG__2626',
+ 'ticket__BAG__2627',
+ 'ticket__BAG__2628',
+ 'ticket__BAG__2631',
+ 'ticket__BAG__26360',
+ 'ticket__BAG__2641',
+ 'ticket__BAG__2647',
+ 'ticket__BAG__2648',
+ 'ticket__BAG__2649',
+ 'ticket__BAG__2650',
+ 'ticket__BAG__2651',
+ 'ticket__BAG__2652',
+ 'ticket__BAG__2653',
+ 'ticket__BAG__265302',
+ 'ticket__BAG__2654',
+ 'ticket__BAG__2655',
+ 'ticket__BAG__2656',
+ 'ticket__BAG__2657',
+ 'ticket__BAG__2658',
+ 'ticket__BAG__2659',
+ 'ticket__BAG__2660',
+ 'ticket__BAG__2661',
+ 'ticket__BAG__2662',
+ 'ticket__BAG__2663',
+ 'ticket__BAG__2664',
+ 'ticket__BAG__2665',
+ 'ticket__BAG__2666',
+ 'ticket__BAG__2667',
+ 'ticket__BAG__2668',
+ 'ticket__BAG__2669',
+ 'ticket__BAG__26707',
+ 'ticket__BAG__2672',
+ 'ticket__BAG__2673',
+ 'ticket__BAG__2674',
+ 'ticket__BAG__2675',
+ 'ticket__BAG__2677',
+ 'ticket__BAG__2678',
+ 'ticket__BAG__2679',
+ 'ticket__BAG__2680',
+ 'ticket__BAG__2681',
+ 'ticket__BAG__2682',
+ 'ticket__BAG__2685',
+ 'ticket__BAG__2686',
+ 'ticket__BAG__2687',
+ 'ticket__BAG__2688',
+ 'ticket__BAG__2689',
+ 'ticket__BAG__2690',
+ 'ticket__BAG__2691',
+ 'ticket__BAG__2694',
+ 'ticket__BAG__2696',
+ 'ticket__BAG__2697',
+ 'ticket__BAG__2698',
+ 'ticket__BAG__2699',
+ 'ticket__BAG__2700',
+ 'ticket__BAG__27042',
+ 'ticket__BAG__27267',
+ 'ticket__BAG__27849',
+ 'ticket__BAG__28004',
+ 'ticket__BAG__28133',
+ 'ticket__BAG__2816',
+ 'ticket__BAG__2817',
+ 'ticket__BAG__28206',
+ 'ticket__BAG__28213',
+ 'ticket__BAG__28220',
+ 'ticket__BAG__28221',
+ 'ticket__BAG__28228',
+ 'ticket__BAG__28403',
+ 'ticket__BAG__28404',
+ 'ticket__BAG__28424',
+ 'ticket__BAG__28425',
+ 'ticket__BAG__28551',
+ 'ticket__BAG__2861',
+ 'ticket__BAG__28664',
+ 'ticket__BAG__28665',
+ 'ticket__BAG__28666',
+ 'ticket__BAG__29037',
+ 'ticket__BAG__2908',
+ 'ticket__BAG__29103',
+ 'ticket__BAG__29104',
+ 'ticket__BAG__29105',
+ 'ticket__BAG__29106',
+ 'ticket__BAG__29107',
+ 'ticket__BAG__29108',
+ 'ticket__BAG__29178',
+ 'ticket__BAG__2926',
+ 'ticket__BAG__29395',
+ 'ticket__BAG__29566',
+ 'ticket__BAG__29750',
+ 'ticket__BAG__29751',
+ 'ticket__BAG__30631',
+ 'ticket__BAG__30769',
+ 'ticket__BAG__3085',
+ 'ticket__BAG__3101262',
+ 'ticket__BAG__3101263',
+ 'ticket__BAG__3101264',
+ 'ticket__BAG__3101265',
+ 'ticket__BAG__3101266',
+ 'ticket__BAG__3101268',
+ 'ticket__BAG__3101269',
+ 'ticket__BAG__3101270',
+ 'ticket__BAG__3101272',
+ 'ticket__BAG__3101273',
+ 'ticket__BAG__3101275',
+ 'ticket__BAG__3101277',
+ 'ticket__BAG__3101278',
+ 'ticket__BAG__3101279',
+ 'ticket__BAG__3101281',
+ 'ticket__BAG__3101282',
+ 'ticket__BAG__3101283',
+ 'ticket__BAG__3101285',
+ 'ticket__BAG__3101286',
+ 'ticket__BAG__3101288',
+ 'ticket__BAG__3101289',
+ 'ticket__BAG__3101290',
+ 'ticket__BAG__3101291',
+ 'ticket__BAG__3101292',
+ 'ticket__BAG__3101293',
+ 'ticket__BAG__3101294',
+ 'ticket__BAG__3101295',
+ 'ticket__BAG__3101296',
+ 'ticket__BAG__3101297',
+ 'ticket__BAG__3101298',
+ 'ticket__BAG__3101306',
+ 'ticket__BAG__3101307',
+ 'ticket__BAG__3101308',
+ 'ticket__BAG__3101309',
+ 'ticket__BAG__3101310',
+ 'ticket__BAG__3101311',
+ 'ticket__BAG__3101314',
+ 'ticket__BAG__3101315',
+ 'ticket__BAG__3101316',
+ 'ticket__BAG__3101317',
+ 'ticket__BAG__31026',
+ 'ticket__BAG__31027',
+ 'ticket__BAG__31029',
+ 'ticket__BAG__31030',
+ 'ticket__BAG__312991',
+ 'ticket__BAG__312992',
+ 'ticket__BAG__3130',
+ 'ticket__BAG__31352',
+ 'ticket__BAG__31416',
+ 'ticket__BAG__315084',
+ 'ticket__BAG__315085',
+ 'ticket__BAG__315086',
+ 'ticket__BAG__315088',
+ 'ticket__BAG__315089',
+ 'ticket__BAG__315091',
+ 'ticket__BAG__315092',
+ 'ticket__BAG__315093',
+ 'ticket__BAG__315094',
+ 'ticket__BAG__315095',
+ 'ticket__BAG__315096',
+ 'ticket__BAG__315097',
+ 'ticket__BAG__315098',
+ 'ticket__BAG__315152',
+ 'ticket__BAG__315153',
+ 'ticket__BAG__315154',
+ 'ticket__BAG__31921',
+ 'ticket__BAG__32302',
+ 'ticket__BAG__3235',
+ 'ticket__BAG__323592',
+ 'ticket__BAG__3236',
+ 'ticket__BAG__323951',
+ 'ticket__BAG__329944',
+ 'ticket__BAG__330844',
+ 'ticket__BAG__330877',
+ 'ticket__BAG__330910',
+ 'ticket__BAG__330911',
+ 'ticket__BAG__330919',
+ 'ticket__BAG__330920',
+ 'ticket__BAG__330923',
+ 'ticket__BAG__330924',
+ 'ticket__BAG__330931',
+ 'ticket__BAG__330932',
+ 'ticket__BAG__330935',
+ 'ticket__BAG__330958',
+ 'ticket__BAG__330959',
+ 'ticket__BAG__330968',
+ 'ticket__BAG__330971',
+ 'ticket__BAG__330972',
+ 'ticket__BAG__330979',
+ 'ticket__BAG__330980',
+ 'ticket__BAG__33111',
+ 'ticket__BAG__33112',
+ 'ticket__BAG__3337',
+ 'ticket__BAG__3338',
+ 'ticket__BAG__334914',
+ 'ticket__BAG__334915',
+ 'ticket__BAG__335097',
+ 'ticket__BAG__335677',
+ 'ticket__BAG__33595',
+ 'ticket__BAG__33638',
+ 'ticket__BAG__336439',
+ 'ticket__BAG__3381',
+ 'ticket__BAG__34068',
+ 'ticket__BAG__3410',
+ 'ticket__BAG__3411',
+ 'ticket__BAG__34218',
+ 'ticket__BAG__34244',
+ 'ticket__BAG__342441',
+ 'ticket__BAG__34260',
+ 'ticket__BAG__342684',
+ 'ticket__BAG__342712',
+ 'ticket__BAG__342826',
+ 'ticket__BAG__343095',
+ 'ticket__BAG__343120',
+ 'ticket__BAG__343271',
+ 'ticket__BAG__345364',
+ 'ticket__BAG__345498',
+ 'ticket__BAG__345501',
+ 'ticket__BAG__345572',
+ 'ticket__BAG__345763',
+ 'ticket__BAG__345764',
+ 'ticket__BAG__345765',
+ 'ticket__BAG__345767',
+ 'ticket__BAG__345769',
+ 'ticket__BAG__345771',
+ 'ticket__BAG__345773',
+ 'ticket__BAG__345775',
+ 'ticket__BAG__345777',
+ 'ticket__BAG__345778',
+ 'ticket__BAG__345779',
+ 'ticket__BAG__345780',
+ 'ticket__BAG__345783',
+ 'ticket__BAG__3460',
+ 'ticket__BAG__3464',
+ 'ticket__BAG__34644',
+ 'ticket__BAG__34651',
+ 'ticket__BAG__3470',
+ 'ticket__BAG__347054',
+ 'ticket__BAG__347060',
+ 'ticket__BAG__347062',
+ 'ticket__BAG__347063',
+ 'ticket__BAG__347064',
+ 'ticket__BAG__347065',
+ 'ticket__BAG__347066',
+ 'ticket__BAG__347067',
+ 'ticket__BAG__347068',
+ 'ticket__BAG__347069',
+ 'ticket__BAG__347071',
+ 'ticket__BAG__347072',
+ 'ticket__BAG__347074',
+ 'ticket__BAG__347075',
+ 'ticket__BAG__347076',
+ 'ticket__BAG__347077',
+ 'ticket__BAG__347078',
+ 'ticket__BAG__347080',
+ 'ticket__BAG__347082',
+ 'ticket__BAG__347085',
+ 'ticket__BAG__347086',
+ 'ticket__BAG__347087',
+ 'ticket__BAG__347088',
+ 'ticket__BAG__347089',
+ 'ticket__BAG__347090',
+ 'ticket__BAG__3474',
+ 'ticket__BAG__347464',
+ 'ticket__BAG__347467',
+ 'ticket__BAG__347469',
+ 'ticket__BAG__347742',
+ 'ticket__BAG__347743',
+ 'ticket__BAG__348121',
+ 'ticket__BAG__348122',
+ 'ticket__BAG__348123',
+ 'ticket__BAG__348125',
+ 'ticket__BAG__349202',
+ 'ticket__BAG__349204',
+ 'ticket__BAG__349205',
+ 'ticket__BAG__349207',
+ 'ticket__BAG__349210',
+ 'ticket__BAG__349211',
+ 'ticket__BAG__349212',
+ 'ticket__BAG__349213',
+ 'ticket__BAG__349214',
+ 'ticket__BAG__349215',
+ 'ticket__BAG__349216',
+ 'ticket__BAG__349217',
+ 'ticket__BAG__349219',
+ 'ticket__BAG__349220',
+ 'ticket__BAG__349221',
+ 'ticket__BAG__349223',
+ 'ticket__BAG__349224',
+ 'ticket__BAG__349225',
+ 'ticket__BAG__349226',
+ 'ticket__BAG__349227',
+ 'ticket__BAG__349228',
+ 'ticket__BAG__349229',
+ 'ticket__BAG__349230',
+ 'ticket__BAG__349231',
+ 'ticket__BAG__349232',
+ 'ticket__BAG__349234',
+ 'ticket__BAG__349235',
+ 'ticket__BAG__349236',
+ 'ticket__BAG__349237',
+ 'ticket__BAG__349238',
+ 'ticket__BAG__349239',
+ 'ticket__BAG__349240',
+ 'ticket__BAG__349241',
+ 'ticket__BAG__349243',
+ 'ticket__BAG__349244',
+ 'ticket__BAG__349245',
+ 'ticket__BAG__349246',
+ 'ticket__BAG__349247',
+ 'ticket__BAG__349248',
+ 'ticket__BAG__349249',
+ 'ticket__BAG__349250',
+ 'ticket__BAG__349251',
+ 'ticket__BAG__349252',
+ 'ticket__BAG__349253',
+ 'ticket__BAG__349254',
+ 'ticket__BAG__349255',
+ 'ticket__BAG__349256',
+ 'ticket__BAG__349257',
+ 'ticket__BAG__349909',
+ 'ticket__BAG__349910',
+ 'ticket__BAG__349912',
+ 'ticket__BAG__350026',
+ 'ticket__BAG__350029',
+ 'ticket__BAG__350033',
+ 'ticket__BAG__350034',
+ 'ticket__BAG__350035',
+ 'ticket__BAG__350036',
+ 'ticket__BAG__350042',
+ 'ticket__BAG__350048',
+ 'ticket__BAG__350052',
+ 'ticket__BAG__350053',
+ 'ticket__BAG__350054',
+ 'ticket__BAG__350060',
+ 'ticket__BAG__350404',
+ 'ticket__BAG__350405',
+ 'ticket__BAG__350406',
+ 'ticket__BAG__350407',
+ 'ticket__BAG__350408',
+ 'ticket__BAG__350409',
+ 'ticket__BAG__350410',
+ 'ticket__BAG__35273',
+ 'ticket__BAG__35281',
+ 'ticket__BAG__3540',
+ 'ticket__BAG__35851',
+ 'ticket__BAG__358585',
+ 'ticket__BAG__359306',
+ 'ticket__BAG__3594',
+ 'ticket__BAG__36209',
+ 'ticket__BAG__362316',
+ 'ticket__BAG__363291',
+ 'ticket__BAG__363592',
+ 'ticket__BAG__363611',
+ 'ticket__BAG__364498',
+ 'ticket__BAG__364499',
+ 'ticket__BAG__364500',
+ 'ticket__BAG__364511',
+ 'ticket__BAG__364512',
+ 'ticket__BAG__364516',
+ 'ticket__BAG__364846',
+ 'ticket__BAG__364848',
+ 'ticket__BAG__364849',
+ 'ticket__BAG__364850',
+ 'ticket__BAG__364851',
+ 'ticket__BAG__364856',
+ 'ticket__BAG__364858',
+ 'ticket__BAG__364859',
+ 'ticket__BAG__365222',
+ 'ticket__BAG__365226',
+ 'ticket__BAG__365235',
+ 'ticket__BAG__365237',
+ 'ticket__BAG__36568',
+ 'ticket__BAG__366713',
+ 'ticket__BAG__367226',
+ 'ticket__BAG__367227',
+ 'ticket__BAG__367228',
+ 'ticket__BAG__367229',
+ 'ticket__BAG__367230',
+ 'ticket__BAG__367231',
+ 'ticket__BAG__367232',
+ 'ticket__BAG__367655',
+ 'ticket__BAG__368323',
+ 'ticket__BAG__368364',
+ 'ticket__BAG__368402',
+ 'ticket__BAG__368573',
+ 'ticket__BAG__36864',
+ 'ticket__BAG__36865',
+ 'ticket__BAG__36866',
+ 'ticket__BAG__368702',
+ 'ticket__BAG__368703',
+ 'ticket__BAG__368783',
+ 'ticket__BAG__36928',
+ 'ticket__BAG__36947',
+ 'ticket__BAG__36963',
+ 'ticket__BAG__36967',
+ 'ticket__BAG__36973',
+ 'ticket__BAG__369943',
+ 'ticket__BAG__3701',
+ 'ticket__BAG__370129',
+ 'ticket__BAG__370365',
+ 'ticket__BAG__370368',
+ 'ticket__BAG__370370',
+ 'ticket__BAG__370371',
+ 'ticket__BAG__370372',
+ 'ticket__BAG__370373',
+ 'ticket__BAG__370374',
+ 'ticket__BAG__370375',
+ 'ticket__BAG__370376',
+ 'ticket__BAG__371060',
+ 'ticket__BAG__371109',
+ 'ticket__BAG__371110',
+ 'ticket__BAG__371362',
+ 'ticket__BAG__372622',
+ 'ticket__BAG__373450',
+ 'ticket__BAG__374746',
+ 'ticket__BAG__374887',
+ 'ticket__BAG__374910',
+ 'ticket__BAG__376563',
+ 'ticket__BAG__376564',
+ 'ticket__BAG__376566',
+ 'ticket__BAG__37671',
+ 'ticket__BAG__382650',
+ 'ticket__BAG__382651',
+ 'ticket__BAG__382652',
+ 'ticket__BAG__382653',
+ 'ticket__BAG__383121',
+ 'ticket__BAG__383162',
+ 'ticket__BAG__384461',
+ 'ticket__BAG__386525',
+ 'ticket__BAG__3902',
+ 'ticket__BAG__39186',
+ 'ticket__BAG__392076',
+ 'ticket__BAG__392078',
+ 'ticket__BAG__392083',
+ 'ticket__BAG__392087',
+ 'ticket__BAG__392091',
+ 'ticket__BAG__392095',
+ 'ticket__BAG__392096',
+ 'ticket__BAG__394140',
+ 'ticket__BAG__39886',
+ 'ticket__BAG__4001',
+ 'ticket__BAG__4133',
+ 'ticket__BAG__4135',
+ 'ticket__BAG__4137',
+ 'ticket__BAG__4138',
+ 'ticket__BAG__42795',
+ 'ticket__BAG__4348',
+ 'ticket__BAG__45380',
+ 'ticket__BAG__4579',
+ 'ticket__BAG__48871',
+ 'ticket__BAG__48873',
+ 'ticket__BAG__49867',
+ 'ticket__BAG__54510',
+ 'ticket__BAG__54636',
+ 'ticket__BAG__5547',
+ 'ticket__BAG__5727',
+ 'ticket__BAG__5734',
+ 'ticket__BAG__5735',
+ 'ticket__BAG__6212',
+ 'ticket__BAG__65303',
+ 'ticket__BAG__65305',
+ 'ticket__BAG__65306',
+ 'ticket__BAG__6563',
+ 'ticket__BAG__6607',
+ 'ticket__BAG__6608',
+ 'ticket__BAG__6609',
+ 'ticket__BAG__680',
+ 'ticket__BAG__693',
+ 'ticket__BAG__695',
+ 'ticket__BAG__7075',
+ 'ticket__BAG__7076',
+ 'ticket__BAG__7077',
+ 'ticket__BAG__751',
+ 'ticket__BAG__752',
+ 'ticket__BAG__7534',
+ 'ticket__BAG__7538',
+ 'ticket__BAG__7540',
+ 'ticket__BAG__7545',
+ 'ticket__BAG__7548',
+ 'ticket__BAG__7552',
+ 'ticket__BAG__7598',
+ 'ticket__BAG__7935',
+ 'ticket__BAG__8471',
+ 'ticket__BAG__851',
+ 'ticket__BAG__9232',
+ 'ticket__BAG__9234',
+ 'ticket__BAG__9549',
+ 'ticket__BAG__a4',
+ 'ticket__BAG__ah',
+ 'ticket__BAG__aq',
+ 'ticket__BAG__ca',
+ 'ticket__BAG__fa',
+ 'ticket__BAG__line',
+ 'ticket__BAG__lp',
+ 'ticket__BAG__o2',
+ 'ticket__BAG__oq',
+ 'ticket__BAG__paris',
+ 'ticket__BAG__pc',
+ 'ticket__BAG__pp',
+ 'ticket__BAG__sc',
+ 'ticket__BAG__sco',
+ 'ticket__BAG__soton',
+ 'ticket__BAG__ston',
+ 'ticket__BAG__sw',
+ 'ticket__BAG__we',
+ 'name__BAG__aaron',
+ 'name__BAG__abbing',
+ 'name__BAG__abbott',
+ 'name__BAG__abelseth',
+ 'name__BAG__abelson',
+ 'name__BAG__abi',
+ 'name__BAG__abraham',
+ 'name__BAG__abrahim',
+ 'name__BAG__achille',
+ 'name__BAG__ada',
+ 'name__BAG__adahl',
+ 'name__BAG__addie',
+ 'name__BAG__adele',
+ 'name__BAG__adelia',
+ 'name__BAG__adola',
+ 'name__BAG__adolf',
+ 'name__BAG__adolfina',
+ 'name__BAG__adolphe',
+ 'name__BAG__adrian',
+ 'name__BAG__agatha',
+ 'name__BAG__agda',
+ 'name__BAG__agnes',
+ 'name__BAG__ahmed',
+ 'name__BAG__aijo',
+ 'name__BAG__aina',
+ 'name__BAG__akar',
+ 'name__BAG__aks',
+ 'name__BAG__albert',
+ 'name__BAG__albimona',
+ 'name__BAG__albin',
+ 'name__BAG__albina',
+ 'name__BAG__alden',
+ 'name__BAG__aldworth',
+ 'name__BAG__alexander',
+ 'name__BAG__alexandra',
+ 'name__BAG__alexanteri',
+ 'name__BAG__alexenia',
+ 'name__BAG__alfons',
+ 'name__BAG__alfonzo',
+ 'name__BAG__alfred',
+ 'name__BAG__alfrida',
+ 'name__BAG__algernon',
+ 'name__BAG__ali',
+ 'name__BAG__alice',
+ 'name__BAG__aline',
+ 'name__BAG__allen',
+ 'name__BAG__allis',
+ 'name__BAG__allison',
+ 'name__BAG__allum',
+ 'name__BAG__alma',
+ 'name__BAG__aloisia',
+ 'name__BAG__amanda',
+ 'name__BAG__ambrose',
+ 'name__BAG__amelia',
+ 'name__BAG__amelie',
+ 'name__BAG__amenia',
+ 'name__BAG__amy',
+ 'name__BAG__anders',
+ 'name__BAG__andersen',
+ 'name__BAG__anderson',
+ 'name__BAG__andersson',
+ 'name__BAG__andre',
+ 'name__BAG__andreas',
+ 'name__BAG__andree',
+ 'name__BAG__andrew',
+ 'name__BAG__andrews',
+ 'name__BAG__andy',
+ 'name__BAG__angheloff',
+ 'name__BAG__angle',
+ 'name__BAG__ann',
+ 'name__BAG__anna',
+ 'name__BAG__anne',
+ 'name__BAG__annie',
+ 'name__BAG__anthony',
+ 'name__BAG__antino',
+ 'name__BAG__antoinette',
+ 'name__BAG__anton',
+ 'name__BAG__antoni',
+ 'name__BAG__antti',
+ 'name__BAG__apostolos',
+ 'name__BAG__appleton',
+ 'name__BAG__archibald',
+ 'name__BAG__argenia',
+ 'name__BAG__arne',
+ 'name__BAG__arnold',
+ 'name__BAG__artagaveytia',
+ 'name__BAG__arthur',
+ 'name__BAG__artur',
+ 'name__BAG__arvid',
+ 'name__BAG__ashby',
+ 'name__BAG__asim',
+ 'name__BAG__asplund',
+ 'name__BAG__assad',
+ 'name__BAG__assaf',
+ 'name__BAG__assam',
+ 'name__BAG__assi',
+ 'name__BAG__astor',
+ 'name__BAG__asuncion',
+ 'name__BAG__atkinson',
+ 'name__BAG__attalah',
+ 'name__BAG__aubart',
+ 'name__BAG__august',
+ 'name__BAG__augusta',
+ 'name__BAG__augustus',
+ 'name__BAG__aurora',
+ 'name__BAG__austen',
+ 'name__BAG__ayoub',
+ 'name__BAG__baccos',
+ 'name__BAG__backstrom',
+ 'name__BAG__baclini',
+ 'name__BAG__badman',
+ 'name__BAG__badt',
+ 'name__BAG__bailey',
+ 'name__BAG__baimbrigge',
+ 'name__BAG__baird',
+ 'name__BAG__balkic',
+ 'name__BAG__ball',
+ 'name__BAG__banfield',
+ 'name__BAG__banoura',
+ 'name__BAG__baptist',
+ 'name__BAG__barah',
+ 'name__BAG__barbara',
+ 'name__BAG__barber',
+ 'name__BAG__barkworth',
+ 'name__BAG__baron',
+ 'name__BAG__barrett',
+ 'name__BAG__barron',
+ 'name__BAG__barry',
+ 'name__BAG__bartol',
+ 'name__BAG__bateman',
+ 'name__BAG__baumgardner',
+ 'name__BAG__baxter',
+ 'name__BAG__bazzani',
+ 'name__BAG__beane',
+ 'name__BAG__beatrice',
+ 'name__BAG__beattie',
+ 'name__BAG__beavan',
+ 'name__BAG__bechstein',
+ 'name__BAG__becker',
+ 'name__BAG__beesley',
+ 'name__BAG__behr',
+ 'name__BAG__beila',
+ 'name__BAG__bengt',
+ 'name__BAG__bengtsson',
+ 'name__BAG__benjamin',
+ 'name__BAG__bentham',
+ 'name__BAG__berg',
+ 'name__BAG__berglund',
+ 'name__BAG__berk',
+ 'name__BAG__bernard',
+ 'name__BAG__bernhardina',
+ 'name__BAG__bernt',
+ 'name__BAG__berriman',
+ 'name__BAG__berta',
+ 'name__BAG__bertha',
+ 'name__BAG__bertram',
+ 'name__BAG__bessie',
+ 'name__BAG__betros',
+ 'name__BAG__bidois',
+ 'name__BAG__billiard',
+ 'name__BAG__bing',
+ 'name__BAG__birger',
+ 'name__BAG__birkeland',
+ 'name__BAG__birkhardt',
+ 'name__BAG__bishop',
+ 'name__BAG__bissette',
+ 'name__BAG__bjorklund',
+ 'name__BAG__bjornstrom',
+ 'name__BAG__blackwell',
+ 'name__BAG__blank',
+ 'name__BAG__bloomfield',
+ 'name__BAG__blumer',
+ 'name__BAG__blun',
+ ...]
+
+ +
+ +
+
+ +
+
+
+
In [5]:
+
+
+
import pandas as pd
+importance = pd.Series(gpipeline.models["rf"].feature_importances_, index=random_forest_features)
+importance.sort_values(ascending=False, inplace=True)
+importance.head(20)
+
+ +
+
+
+ +
+
+ + +
+ +
Out[5]:
+ + + + +
+
boat____null__         0.206115
+name__BAG__mr          0.062559
+sex__male              0.051571
+sex__female            0.046927
+name__BAG__mrs         0.022630
+boat__15               0.021434
+fare                   0.018856
+age                    0.017897
+boat__13               0.016285
+name__BAG__miss        0.015394
+cabin____null__        0.014628
+pclass                 0.013752
+cabin____default__     0.013450
+boat__C                0.011137
+body_isnull            0.010993
+parch                  0.010555
+boat____default__      0.010102
+home_dest____null__    0.008842
+sibsp                  0.008263
+boat__14               0.007434
+dtype: float64
+
+ +
+ +
+
+ +
+
+
+
+

So we can retrieve the importance given by the RandomForest and use the name of the features to see what it corresponds to

+ +
+
+
+
+
+
+

Subpipeline

Sometime it can be helpful to create another pipeline containing only the preprocessing. We have used that technics for two reasons :

+
    +
  • combining GraphPipeline and shapley values to have a good explanation of a model
  • +
  • in a clustering framework, remove the clustering algorithm and keep the pre-processing and use that to create a distance
  • +
+ +
+
+
+
+
+
+

We'll include a 'PassThrough' node (transformers that does nothing) in the pipeline so that we can use it to stop our preprocessing.

+

In more recent version of aikit you won't have to manually include that model, you'll be able to stop before or after any node.

+ +
+
+
+
+
+
In [6]:
+
+
+
from aikit.transformers import PassThrough
+
+gpipeline = GraphPipeline(models = {
+    "sel":ColumnsSelector(columns_to_use=non_text_cols),
+    "enc":NumericalEncoder(columns_to_use="object"),
+    "imp":NumImputer(),
+    "vect":CountVectorizerWrapper(analyzer="word",columns_to_use=text_cols),
+    "pt": PassThrough(),
+    "rf":RandomForestClassifier(n_estimators=100, random_state=123)
+                       },
+              edges = [("sel","enc","imp","pt", "rf"),("vect","pt", "rf")])
+
+gpipeline.graphviz
+
+ +
+
+
+ +
+
+ + +
+ +
Out[6]:
+ + + +
+ + + + + + +%3 + + +enc + +enc + + +imp + +imp + + +enc->imp + + + + +pt + +pt + + +imp->pt + + + + +sel + +sel + + +sel->enc + + + + +rf + +rf + + +pt->rf + + + + +vect + +vect + + +vect->pt + + + + + + +
+ +
+ +
+
+ +
+
+
+
In [7]:
+
+
+
gpipeline.fit(Xtrain, ytrain)
+
+ +
+
+
+ +
+
+ + +
+ +
Out[7]:
+ + + + +
+
GraphPipeline(edges=[('sel', 'enc', 'imp', 'pt', 'rf'), ('vect', 'pt', 'rf')],
+              models={'enc': NumericalEncoder(columns_to_use='object'),
+                      'imp': NumImputer(), 'pt': PassThrough(),
+                      'rf': RandomForestClassifier(random_state=123),
+                      'sel': ColumnsSelector(columns_to_use=['pclass', 'sex',
+                                                             'age', 'sibsp',
+                                                             'parch', 'fare',
+                                                             'cabin',
+                                                             'embarked', 'boat',
+                                                             'body',
+                                                             'home_dest']),
+                      'vect': CountVectorizerWrapper(columns_to_use=['ticket',
+                                                                     'name'])})
+
+ +
+ +
+
+ +
+
+
+
In [8]:
+
+
+
sub_pipeline = gpipeline.get_subpipeline(end_node="pt")
+sub_pipeline.graphviz
+
+ +
+
+
+ +
+
+ + +
+ +
Out[8]:
+ + + +
+ + + + + + +%3 + + +enc + +enc + + +imp + +imp + + +enc->imp + + + + +pt + +pt + + +imp->pt + + + + +sel + +sel + + +sel->enc + + + + +vect + +vect + + +vect->pt + + + + + + +
+ +
+ +
+
+ +
+
+
+
In [9]:
+
+
+
Xtrain_before_rf = sub_pipeline.transform(Xtrain)
+Xtrain_before_rf.head(10)
+
+ +
+
+
+ +
+
+ + +
+ +
Out[9]:
+ + + +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
pclassagesibspparchfarebodysex__malesex__femalecabin____null__cabin____default__...name__BAG__youseffname__BAG__yousifname__BAG__youssefname__BAG__yousseffname__BAG__yroisname__BAG__zabourname__BAG__zakarianname__BAG__zebleyname__BAG__zenniname__BAG__zillah
0154.0000000051.8625175.0000001001...0000000000
1164.00000014263.0000159.0918371001...0000000000
2124.0000000069.3000159.0918370101...0000000000
332.0000004139.6875159.0918371010...0000000000
4116.0000000086.5000159.0918370101...0000000000
5322.000000009.0000159.0918371010...0000000000
6329.841488007.2500159.0918371010...0000000000
7148.0000000025.9292159.0918370101...0000000000
8118.0000001060.0000159.0918370101...0000000000
9133.0000000026.5500109.0000001010...0000000000
+

10 rows × 2478 columns

+
+
+ +
+ +
+
+ +
+
+
+
+

So by doing that we retrieve the features entering the random forest algorithm.

+ +
+
+
+
+
+
+

We can check that the predictions are the same

+ +
+
+
+
+
+
In [10]:
+
+
+
import numpy as np
+rf = gpipeline.models["rf"]
+
+probas1 = rf.predict_proba(Xtrain_before_rf)
+probas2 = gpipeline.predict_proba(Xtrain)
+
+np.abs(probas1 - probas2).max()
+
+ +
+
+
+ +
+
+ + +
+ +
Out[10]:
+ + + + +
+
0.0
+
+ +
+ +
+
+ +
+
+
+
+

Example using Shap

+
+
+
+
+
+
In [11]:
+
+
+
import shap
+rf = gpipeline.models["rf"]
+explainer = shap.TreeExplainer(rf)
+explainer
+
+choosen_instance = Xtrain_before_rf.loc[[0]]
+shap_values = explainer.shap_values(choosen_instance, check_additivity=False)
+shap_values
+
+shap.initjs()
+shap.force_plot(explainer.expected_value[1], shap_values[1], choosen_instance)
+
+ +
+
+
+ +
+
+ + +
+ +
+ + + +
+
+
+ +
+ +
+ +
Out[11]:
+ + + +
+ +
+
+ Visualization omitted, Javascript library not loaded!
+ Have you run `initjs()` in this notebook? If this notebook was from another + user you must also trust this notebook (File -> Trust notebook). If you are viewing + this notebook on github the Javascript has been stripped for security. If you are using + JupyterLab this error is because a JupyterLab extension has not yet been written. +
+ +
+ +
+ +
+
+ +
+
+
+
+

Cutting a clustering

Here we'll do something a little different to illustrate another possible use-case, we'll do a clustering of the text column. By doing the following steps :

+
    +
  • bag of char encoding
  • +
  • truncated svd
  • +
  • kmeans
  • +
+

And in a second step we'll retrieve the features entering the KMeans algorithm. +If our clustering pipeline gives good result, an euclidian distance on the features entering the KMeans should be a good distance

+ +
+
+
+
+
+
In [12]:
+
+
+
from aikit.transformers import TruncatedSVDWrapper
+from sklearn.cluster import KMeans
+
+clustering_pipeline = GraphPipeline(models={"vect":CountVectorizerWrapper(analyzer="char",
+                                                                          ngram_range=(1, 4),
+                                                                          columns_to_use=["name"]),
+                                            "svd":TruncatedSVDWrapper(n_components=200, random_state=123),
+                                            "km":KMeans(n_clusters=5, random_state=123)
+}, edges=[("vect","svd","km")]
+)
+
+clustering_pipeline.graphviz
+
+ +
+
+
+ +
+
+ + +
+ +
Out[12]:
+ + + +
+ + + + + + +%3 + + +svd + +svd + + +km + +km + + +svd->km + + + + +vect + +vect + + +vect->svd + + + + + + +
+ +
+ +
+
+ +
+
+
+
In [13]:
+
+
+
clustering_pipeline.fit(Xtrain)
+
+ +
+
+
+ +
+
+ + +
+ +
Out[13]:
+ + + + +
+
GraphPipeline(edges=[('vect', 'svd', 'km')],
+              models={'km': KMeans(n_clusters=5, random_state=123),
+                      'svd': TruncatedSVDWrapper(n_components=200,
+                                                 random_state=123),
+                      'vect': CountVectorizerWrapper(analyzer='char',
+                                                     columns_to_use=['name'],
+                                                     ngram_range=(1, 4))})
+
+ +
+ +
+
+ +
+
+
+
+

The goal here is not to do a tutorial on clustering, and clustering names doesn't make much sense...

+

But we can still look at the clusters (for fun)

+ +
+
+
+
+
+
In [14]:
+
+
+
clusters = clustering_pipeline.predict(Xtrain)
+
+for cl in range(5):
+    print(f"cluster {cl}, nb of observations {np.sum(clusters == cl)}")
+    for name in Xtrain.loc[clusters == cl, "name"].head(10):
+        print(name)
+        
+    print("")
+
+ +
+
+
+ +
+
+ + +
+ +
+ + +
+
cluster 0, nb of observations 73
+Cavendish, Mr. Tyrell William
+Ware, Mr. William Jeffery
+Gill, Mr. John William
+Frauenthal, Dr. Henry William
+Gilbert, Mr. William
+Campbell, Mr. William
+Warren, Mr. Charles William
+Tornquist, Mr. William Henry
+Dulles, Mr. William Crothers
+Carter, Mr. William Ernest
+
+cluster 1, nb of observations 349
+McCarthy, Mr. Timothy J
+Fortune, Mr. Mark
+Meo, Mr. Alfonzo
+Elias, Mr. Dibo
+Reynaldo, Ms. Encarnacion
+Khalil, Mr. Betros
+Lennon, Mr. Denis
+Johansson, Mr. Gustaf Joel
+O'Connor, Mr. Patrick
+McMahon, Mr. Martin
+
+cluster 2, nb of observations 269
+Sagesser, Mlle. Emma
+Panula, Master. Urho Abraham
+Waelens, Mr. Achille
+Reed, Mr. James George
+Rowe, Mr. Alfred G
+Abbott, Mr. Rossmore Edward
+de Pelsmaeker, Mr. Alfons
+Asplund, Master. Carl Edgar
+Lockyer, Mr. Edward
+Davies, Mr. Charles Henry
+
+cluster 3, nb of observations 141
+Swift, Mrs. Frederick Joel (Margaret Welles Barron)
+Smith, Mrs. Lucien Philip (Mary Eloise Hughes)
+Thorneycroft, Mrs. Percival (Florence Kate White)
+Chambers, Mrs. Norman Campbell (Bertha Griggs)
+O'Brien, Mrs. Thomas (Johanna 'Hannah' Godfrey)
+Marvin, Mrs. Daniel Warner (Mary Graham Carmichael Farquarson)
+Lemore, Mrs. (Amelia Milley)
+Warren, Mrs. Frank Manley (Anna Sophia Atkinson)
+McNamee, Mrs. Neal (Eileen O'Leary)
+Lindell, Mrs. Edvard Bengtsson (Elin Gerda Persson)
+
+cluster 4, nb of observations 216
+Maioni, Miss. Roberta
+Daniels, Miss. Sarah
+Ford, Miss. Robina Maggie 'Ruby'
+Harper, Miss. Annie Jessie 'Nina'
+Fleming, Miss. Margaret
+Riihivouri, Miss. Susanna Juhantytar 'Sanni'
+Cacic, Miss. Marija
+Dowdell, Miss. Elizabeth
+Wick, Miss. Mary Natalie
+Fortune, Miss. Alice Elizabeth
+
+
+
+
+ +
+
+ +
+
+
+
+
    +
  • So clusters 0 corresponds to 'William'
  • +
  • clusters 1 and 2 : more 'Mr'
  • +
  • cluster 3 seems to have found the name containing 'Mrs.' (and maybe longer name)
  • +
  • cluster 4 seem to have more 'miss'
  • +
+ +
+
+
+
+
+
+

Now let's stop before the clustering :

+ +
+
+
+
+
+
In [15]:
+
+
+
subpipeline = clustering_pipeline.get_subpipeline("svd")
+subpipeline.graphviz
+
+ +
+
+
+ +
+
+ + +
+ +
Out[15]:
+ + + +
+ + + + + + +%3 + + +vect + +vect + + +svd + +svd + + +vect->svd + + + + + + +
+ +
+ +
+
+ +
+
+
+
In [16]:
+
+
+
Xbefore_svd = subpipeline.transform(Xtrain)
+Xbefore_svd.shape
+
+ +
+
+
+ +
+
+ + +
+ +
Out[16]:
+ + + + +
+
(1048, 200)
+
+ +
+ +
+
+ +
+
+
+
+

This matrix can be used to compute euclidian distance on. Since KMeans uses Euclidian distance we can expect that :

+
    +
  • if your KMeans make sense ...
  • +
  • ... then our distance would also make sense
  • +
+

Remark: in our experience this type of clustering (Bag Of Char, SVD, KMeans) work relatively well on 'dirty' text (where the notion of words doesn't make much sense). It can be a relatively simple way to test some clustering algorithms

+ +
+
+
+
+
+
+

other functionnalities

    +
  • add nodes to an existing pipeline
  • +
  • take a sklearn pipeline and convert it to an aikit one
  • +
+ +
+
+
+
+
+
+

sklearn to aikit

You can take an existing sklearn pipeline and convert it to an aikit one (this work on an un-fitted pipeline and a fitted one)

+ +
+
+
+
+
+
+

... un-fitted pipeline

+
+
+
+
+
+
In [17]:
+
+
+
from sklearn.pipeline import Pipeline
+
+clustering_pipeline = Pipeline([("vect", CountVectorizerWrapper(analyzer="char",
+                                                                          ngram_range=(1, 4),
+                                                                          columns_to_use=["name"])),
+                                 ("svd", TruncatedSVDWrapper(n_components=200, random_state=123)),
+                                 ("km", KMeans(n_clusters=5, random_state=123))])
+clustering_pipeline
+
+ +
+
+
+ +
+
+ + +
+ +
Out[17]:
+ + + + +
+
Pipeline(steps=[('vect',
+                 CountVectorizerWrapper(analyzer='char',
+                                        columns_to_use=['name'],
+                                        ngram_range=(1, 4))),
+                ('svd',
+                 TruncatedSVDWrapper(n_components=200, random_state=123)),
+                ('km', KMeans(n_clusters=5, random_state=123))])
+
+ +
+ +
+
+ +
+
+
+
In [18]:
+
+
+
aikit_pipeline = GraphPipeline.from_sklearn(clustering_pipeline)
+aikit_pipeline
+
+ +
+
+
+ +
+
+ + +
+ +
Out[18]:
+ + + + +
+
GraphPipeline(edges=[('vect', 'svd', 'km')],
+              models={'km': KMeans(n_clusters=5, random_state=123),
+                      'svd': TruncatedSVDWrapper(n_components=200,
+                                                 random_state=123),
+                      'vect': CountVectorizerWrapper(analyzer='char',
+                                                     columns_to_use=['name'],
+                                                     ngram_range=(1, 4))})
+
+ +
+ +
+
+ +
+
+
+
+

... fitted pipeline

+
+
+
+
+
+
In [19]:
+
+
+
clustering_pipeline = Pipeline([("vect", CountVectorizerWrapper(analyzer="char",
+                                                                          ngram_range=(1, 4),
+                                                                          columns_to_use=["name"])),
+                                 ("svd", TruncatedSVDWrapper(n_components=200, random_state=123)),
+                                 ("km", KMeans(n_clusters=5, random_state=123))])
+clustering_pipeline.fit(Xtrain)
+
+ +
+
+
+ +
+
+ + +
+ +
Out[19]:
+ + + + +
+
Pipeline(steps=[('vect',
+                 CountVectorizerWrapper(analyzer='char',
+                                        columns_to_use=['name'],
+                                        ngram_range=(1, 4))),
+                ('svd',
+                 TruncatedSVDWrapper(n_components=200, random_state=123)),
+                ('km', KMeans(n_clusters=5, random_state=123))])
+
+ +
+ +
+
+ +
+
+
+
In [20]:
+
+
+
aikit_pipeline = GraphPipeline.from_sklearn(clustering_pipeline)
+aikit_pipeline
+
+ +
+
+
+ +
+
+ + +
+ +
Out[20]:
+ + + + +
+
GraphPipeline(edges=[('vect', 'svd', 'km')],
+              models={'km': KMeans(n_clusters=5, random_state=123),
+                      'svd': TruncatedSVDWrapper(n_components=200,
+                                                 random_state=123),
+                      'vect': CountVectorizerWrapper(analyzer='char',
+                                                     columns_to_use=['name'],
+                                                     ngram_range=(1, 4))})
+
+ +
+ +
+
+ +
+
+
+
+

prediction are the sames:

+ +
+
+
+
+
+
In [21]:
+
+
+
pred_from_aikit=aikit_pipeline.predict(Xtrain)
+pred_from_sklearn=clustering_pipeline.predict(Xtrain)
+(pred_from_aikit == pred_from_sklearn).all()
+
+ +
+
+
+ +
+
+ + +
+ +
Out[21]:
+ + + + +
+
True
+
+ +
+ +
+
+ +
+
+
+
In [ ]:
+
+
+
 
+
+ +
+
+
+ +
+
+
+ + + + + + diff --git a/docs/notebooks/GraphPipelineAdvanced.ipynb b/docs/notebooks/GraphPipelineAdvanced.ipynb index 230a669..10c1c87 100644 --- a/docs/notebooks/GraphPipelineAdvanced.ipynb +++ b/docs/notebooks/GraphPipelineAdvanced.ipynb @@ -297,31 +297,70 @@ " warnings.warn(message, FutureWarning)\n" ] }, - { - "ename": "ExecutableNotFound", - "evalue": "failed to execute ['dot', '-Tsvg'], make sure the Graphviz executables are on your systems' PATH", - "output_type": "error", - "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m~\\Anaconda3\\lib\\site-packages\\graphviz\\backend.py\u001b[0m in \u001b[0;36mrun\u001b[1;34m(cmd, input, capture_output, check, quiet, **kwargs)\u001b[0m\n\u001b[0;32m 158\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 159\u001b[1;33m \u001b[0mproc\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msubprocess\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mPopen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcmd\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mstartupinfo\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mget_startupinfo\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 160\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0mOSError\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32m~\\Anaconda3\\lib\\subprocess.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, args, bufsize, executable, stdin, stdout, stderr, preexec_fn, close_fds, shell, cwd, env, universal_newlines, startupinfo, creationflags, restore_signals, start_new_session, pass_fds, encoding, errors, text)\u001b[0m\n\u001b[0;32m 774\u001b[0m \u001b[0merrread\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0merrwrite\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 775\u001b[1;33m restore_signals, start_new_session)\n\u001b[0m\u001b[0;32m 776\u001b[0m \u001b[1;32mexcept\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32m~\\Anaconda3\\lib\\subprocess.py\u001b[0m in \u001b[0;36m_execute_child\u001b[1;34m(self, args, executable, preexec_fn, close_fds, pass_fds, cwd, env, startupinfo, creationflags, shell, p2cread, p2cwrite, c2pread, c2pwrite, errread, errwrite, unused_restore_signals, unused_start_new_session)\u001b[0m\n\u001b[0;32m 1177\u001b[0m \u001b[0mos\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfspath\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcwd\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mcwd\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m \u001b[1;32melse\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1178\u001b[1;33m startupinfo)\n\u001b[0m\u001b[0;32m 1179\u001b[0m \u001b[1;32mfinally\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;31mFileNotFoundError\u001b[0m: [WinError 2] Le fichier spécifié est introuvable", - "\nDuring handling of the above exception, another exception occurred:\n", - "\u001b[1;31mExecutableNotFound\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m~\\Anaconda3\\lib\\site-packages\\IPython\\core\\formatters.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, obj)\u001b[0m\n\u001b[0;32m 343\u001b[0m \u001b[0mmethod\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mget_real_method\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mprint_method\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 344\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mmethod\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 345\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mmethod\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 346\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 347\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32m~\\Anaconda3\\lib\\site-packages\\graphviz\\files.py\u001b[0m in \u001b[0;36m_repr_svg_\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 111\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 112\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_repr_svg_\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 113\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpipe\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mformat\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'svg'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdecode\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_encoding\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 114\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 115\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mpipe\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mformat\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mNone\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mrenderer\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mNone\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mformatter\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mNone\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquiet\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32m~\\Anaconda3\\lib\\site-packages\\graphviz\\files.py\u001b[0m in \u001b[0;36mpipe\u001b[1;34m(self, format, renderer, formatter, quiet)\u001b[0m\n\u001b[0;32m 136\u001b[0m out = backend.pipe(self._engine, format, data,\n\u001b[0;32m 137\u001b[0m \u001b[0mrenderer\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mrenderer\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mformatter\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mformatter\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 138\u001b[1;33m quiet=quiet)\n\u001b[0m\u001b[0;32m 139\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 140\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mout\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32m~\\Anaconda3\\lib\\site-packages\\graphviz\\backend.py\u001b[0m in \u001b[0;36mpipe\u001b[1;34m(engine, format, data, renderer, formatter, quiet)\u001b[0m\n\u001b[0;32m 227\u001b[0m \"\"\"\n\u001b[0;32m 228\u001b[0m \u001b[0mcmd\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0m_\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcommand\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mengine\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mformat\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mrenderer\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mformatter\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 229\u001b[1;33m \u001b[0mout\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0m_\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mrun\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcmd\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcapture_output\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcheck\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mquiet\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mquiet\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 230\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mout\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 231\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32m~\\Anaconda3\\lib\\site-packages\\graphviz\\backend.py\u001b[0m in \u001b[0;36mrun\u001b[1;34m(cmd, input, capture_output, check, quiet, **kwargs)\u001b[0m\n\u001b[0;32m 160\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0mOSError\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 161\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0merrno\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0merrno\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mENOENT\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 162\u001b[1;33m \u001b[1;32mraise\u001b[0m \u001b[0mExecutableNotFound\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcmd\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 163\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 164\u001b[0m \u001b[1;32mraise\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;31mExecutableNotFound\u001b[0m: failed to execute ['dot', '-Tsvg'], make sure the Graphviz executables are on your systems' PATH" - ] - }, { "data": { + "image/svg+xml": [ + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "%3\r\n", + "\r\n", + "\r\n", + "enc\r\n", + "\r\n", + "enc\r\n", + "\r\n", + "\r\n", + "imp\r\n", + "\r\n", + "imp\r\n", + "\r\n", + "\r\n", + "enc->imp\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "rf\r\n", + "\r\n", + "rf\r\n", + "\r\n", + "\r\n", + "imp->rf\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "sel\r\n", + "\r\n", + "sel\r\n", + "\r\n", + "\r\n", + "sel->enc\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "vect\r\n", + "\r\n", + "vect\r\n", + "\r\n", + "\r\n", + "vect->rf\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n" + ], "text/plain": [ - "" + "" ] }, "execution_count": 2, @@ -418,21 +457,21 @@ " 'embarked__Q',\n", " 'boat____null__',\n", " 'boat__13',\n", - " 'boat__15',\n", " 'boat__C',\n", + " 'boat__15',\n", " 'boat__14',\n", " 'boat__4',\n", " 'boat__10',\n", " 'boat__5',\n", " 'boat__6',\n", + " 'boat__11',\n", " 'boat__8',\n", + " 'boat__12',\n", " 'boat__16',\n", " 'boat__7',\n", - " 'boat__11',\n", - " 'boat__12',\n", " 'boat__3',\n", - " 'boat__9',\n", " 'boat__D',\n", + " 'boat__9',\n", " 'boat____default__',\n", " 'home_dest____null__',\n", " 'home_dest__New York, NY',\n", @@ -1425,26 +1464,26 @@ { "data": { "text/plain": [ - "boat____null__ 0.189853\n", - "sex__male 0.058451\n", - "sex__female 0.054584\n", - "name__BAG__mr 0.045457\n", - "name__BAG__mrs 0.023080\n", - "fare 0.021042\n", - "name__BAG__miss 0.018376\n", - "age 0.017365\n", - "boat__15 0.017300\n", - "boat__13 0.016186\n", - "cabin____default__ 0.015086\n", - "pclass 0.014003\n", - "boat__C 0.013554\n", - "boat____default__ 0.012829\n", - "cabin____null__ 0.011726\n", - "body_isnull 0.010226\n", - "boat__14 0.009290\n", - "parch 0.008864\n", - "sibsp 0.008804\n", - "home_dest____null__ 0.008692\n", + "boat____null__ 0.206115\n", + "name__BAG__mr 0.062559\n", + "sex__male 0.051571\n", + "sex__female 0.046927\n", + "name__BAG__mrs 0.022630\n", + "boat__15 0.021434\n", + "fare 0.018856\n", + "age 0.017897\n", + "boat__13 0.016285\n", + "name__BAG__miss 0.015394\n", + "cabin____null__ 0.014628\n", + "pclass 0.013752\n", + "cabin____default__ 0.013450\n", + "boat__C 0.011137\n", + "body_isnull 0.010993\n", + "parch 0.010555\n", + "boat____default__ 0.010102\n", + "home_dest____null__ 0.008842\n", + "sibsp 0.008263\n", + "boat__14 0.007434\n", "dtype: float64" ] }, @@ -1481,7 +1520,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We'll include a 'PassThrough' node (transformers that does nothing) in the pipeline so that we can use it to stop our preprocessing" + "We'll include a 'PassThrough' node (transformers that does nothing) in the pipeline so that we can use it to stop our preprocessing.\n", + "\n", + "In more recent version of aikit you won't have to manually include that model, you'll be able to stop before or after any node.\n" ] }, { @@ -1491,20 +1532,78 @@ "outputs": [ { "data": { + "image/svg+xml": [ + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "%3\r\n", + "\r\n", + "\r\n", + "enc\r\n", + "\r\n", + "enc\r\n", + "\r\n", + "\r\n", + "imp\r\n", + "\r\n", + "imp\r\n", + "\r\n", + "\r\n", + "enc->imp\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "pt\r\n", + "\r\n", + "pt\r\n", + "\r\n", + "\r\n", + "imp->pt\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "sel\r\n", + "\r\n", + "sel\r\n", + "\r\n", + "\r\n", + "sel->enc\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "rf\r\n", + "\r\n", + "rf\r\n", + "\r\n", + "\r\n", + "pt->rf\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "vect\r\n", + "\r\n", + "vect\r\n", + "\r\n", + "\r\n", + "vect->pt\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n" + ], "text/plain": [ - "GraphPipeline(edges=[('sel', 'enc', 'imp', 'pt', 'rf'), ('vect', 'pt', 'rf')],\n", - " models={'enc': NumericalEncoder(columns_to_use='object'),\n", - " 'imp': NumImputer(), 'pt': PassThrough(),\n", - " 'rf': RandomForestClassifier(random_state=123),\n", - " 'sel': ColumnsSelector(columns_to_use=['pclass', 'sex',\n", - " 'age', 'sibsp',\n", - " 'parch', 'fare',\n", - " 'cabin',\n", - " 'embarked', 'boat',\n", - " 'body',\n", - " 'home_dest']),\n", - " 'vect': CountVectorizerWrapper(columns_to_use=['ticket',\n", - " 'name'])})" + "" ] }, "execution_count": 6, @@ -1524,9 +1623,8 @@ " \"rf\":RandomForestClassifier(n_estimators=100, random_state=123)\n", " },\n", " edges = [(\"sel\",\"enc\",\"imp\",\"pt\", \"rf\"),(\"vect\",\"pt\", \"rf\")])\n", - "#gpipeline.graphviz\n", "\n", - "gpipeline.fit(Xtrain, ytrain)" + "gpipeline.graphviz" ] }, { @@ -1537,10 +1635,10 @@ { "data": { "text/plain": [ - "GraphPipeline(edges=[('sel', 'enc'), ('enc', 'imp'), ('imp', 'pt'),\n", - " ('vect', 'pt')],\n", + "GraphPipeline(edges=[('sel', 'enc', 'imp', 'pt', 'rf'), ('vect', 'pt', 'rf')],\n", " models={'enc': NumericalEncoder(columns_to_use='object'),\n", " 'imp': NumImputer(), 'pt': PassThrough(),\n", + " 'rf': RandomForestClassifier(random_state=123),\n", " 'sel': ColumnsSelector(columns_to_use=['pclass', 'sex',\n", " 'age', 'sibsp',\n", " 'parch', 'fare',\n", @@ -1558,14 +1656,94 @@ } ], "source": [ - "sub_pipeline = gpipeline.get_subpipeline(end_node=\"pt\")\n", - "sub_pipeline" + "gpipeline.fit(Xtrain, ytrain)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "%3\r\n", + "\r\n", + "\r\n", + "enc\r\n", + "\r\n", + "enc\r\n", + "\r\n", + "\r\n", + "imp\r\n", + "\r\n", + "imp\r\n", + "\r\n", + "\r\n", + "enc->imp\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "pt\r\n", + "\r\n", + "pt\r\n", + "\r\n", + "\r\n", + "imp->pt\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "sel\r\n", + "\r\n", + "sel\r\n", + "\r\n", + "\r\n", + "sel->enc\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "vect\r\n", + "\r\n", + "vect\r\n", + "\r\n", + "\r\n", + "vect->pt\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sub_pipeline = gpipeline.get_subpipeline(end_node=\"pt\")\n", + "sub_pipeline.graphviz" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, "outputs": [ { "data": { @@ -1921,7 +2099,7 @@ "[10 rows x 2478 columns]" ] }, - "execution_count": 8, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -1947,7 +2125,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -1956,7 +2134,7 @@ "0.0" ] }, - "execution_count": 9, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -1980,16 +2158,9 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 11, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(1, 2478)\n" - ] - }, { "data": { "text/html": [ @@ -2036,7 +2207,7 @@ "data": { "text/html": [ "\n", - "
\n", + "
\n", "
\n", " Visualization omitted, Javascript library not loaded!
\n", " Have you run `initjs()` in this notebook? If this notebook was from another\n", @@ -2046,16 +2217,16 @@ "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 27, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -2066,9 +2237,7 @@ "explainer = shap.TreeExplainer(rf)\n", "explainer\n", "\n", - "print(choosen_instance.shape)\n", - "\n", - "choosen_instance = Xtrain_before_rf.loc[[421]]\n", + "choosen_instance = Xtrain_before_rf.loc[[0]]\n", "shap_values = explainer.shap_values(choosen_instance, check_additivity=False)\n", "shap_values\n", "\n", @@ -2092,22 +2261,56 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 12, "metadata": {}, "outputs": [ { "data": { + "image/svg+xml": [ + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "%3\r\n", + "\r\n", + "\r\n", + "svd\r\n", + "\r\n", + "svd\r\n", + "\r\n", + "\r\n", + "km\r\n", + "\r\n", + "km\r\n", + "\r\n", + "\r\n", + "svd->km\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "vect\r\n", + "\r\n", + "vect\r\n", + "\r\n", + "\r\n", + "vect->svd\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n" + ], "text/plain": [ - "GraphPipeline(edges=[('vect', 'svd', 'km')],\n", - " models={'km': KMeans(n_clusters=5, random_state=123),\n", - " 'svd': TruncatedSVDWrapper(n_components=200,\n", - " random_state=123),\n", - " 'vect': CountVectorizerWrapper(analyzer='char',\n", - " columns_to_use=['name'],\n", - " ngram_range=(1, 4))})" + "" ] }, - "execution_count": 36, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -2123,12 +2326,13 @@ " \"km\":KMeans(n_clusters=5, random_state=123)\n", "}, edges=[(\"vect\",\"svd\",\"km\")]\n", ")\n", - "clustering_pipeline" + "\n", + "clustering_pipeline.graphviz" ] }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -2143,7 +2347,7 @@ " ngram_range=(1, 4))})" ] }, - "execution_count": 37, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -2163,23 +2367,14 @@ }, { "cell_type": "code", - "execution_count": 43, - "metadata": {}, - "outputs": [], - "source": [ - "clusters = clustering_pipeline.predict(Xtrain)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 45, + "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "cluster 0\n", + "cluster 0, nb of observations 73\n", "Cavendish, Mr. Tyrell William\n", "Ware, Mr. William Jeffery\n", "Gill, Mr. John William\n", @@ -2191,7 +2386,7 @@ "Dulles, Mr. William Crothers\n", "Carter, Mr. William Ernest\n", "\n", - "cluster 1\n", + "cluster 1, nb of observations 349\n", "McCarthy, Mr. Timothy J\n", "Fortune, Mr. Mark\n", "Meo, Mr. Alfonzo\n", @@ -2203,7 +2398,7 @@ "O'Connor, Mr. Patrick\n", "McMahon, Mr. Martin\n", "\n", - "cluster 2\n", + "cluster 2, nb of observations 269\n", "Sagesser, Mlle. Emma\n", "Panula, Master. Urho Abraham\n", "Waelens, Mr. Achille\n", @@ -2215,7 +2410,7 @@ "Lockyer, Mr. Edward\n", "Davies, Mr. Charles Henry\n", "\n", - "cluster 3\n", + "cluster 3, nb of observations 141\n", "Swift, Mrs. Frederick Joel (Margaret Welles Barron)\n", "Smith, Mrs. Lucien Philip (Mary Eloise Hughes)\n", "Thorneycroft, Mrs. Percival (Florence Kate White)\n", @@ -2227,7 +2422,7 @@ "McNamee, Mrs. Neal (Eileen O'Leary)\n", "Lindell, Mrs. Edvard Bengtsson (Elin Gerda Persson)\n", "\n", - "cluster 4\n", + "cluster 4, nb of observations 216\n", "Maioni, Miss. Roberta\n", "Daniels, Miss. Sarah\n", "Ford, Miss. Robina Maggie 'Ruby'\n", @@ -2243,8 +2438,10 @@ } ], "source": [ + "clusters = clustering_pipeline.predict(Xtrain)\n", + "\n", "for cl in range(5):\n", - " print(f\"cluster {cl}\")\n", + " print(f\"cluster {cl}, nb of observations {np.sum(clusters == cl)}\")\n", " for name in Xtrain.loc[clusters == cl, \"name\"].head(10):\n", " print(name)\n", " \n", @@ -2255,27 +2452,87 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Now let's stop before the clustering\n" + "* So clusters 0 corresponds to 'William' \n", + "* clusters 1 and 2 : more 'Mr'\n", + "* cluster 3 seems to have found the name containing 'Mrs.' (and maybe longer name)\n", + "* cluster 4 seem to have more 'miss'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now let's stop before the clustering :" ] }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 15, "metadata": {}, "outputs": [ { "data": { + "image/svg+xml": [ + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "%3\r\n", + "\r\n", + "\r\n", + "vect\r\n", + "\r\n", + "vect\r\n", + "\r\n", + "\r\n", + "svd\r\n", + "\r\n", + "svd\r\n", + "\r\n", + "\r\n", + "vect->svd\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n", + "\r\n" + ], "text/plain": [ - "(1048, 200)" + "" ] }, - "execution_count": 46, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "subpipeline = clustering_pipeline.get_subpipeline(\"svd\")\n", + "subpipeline.graphviz" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1048, 200)" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ "Xbefore_svd = subpipeline.transform(Xtrain)\n", "Xbefore_svd.shape" ] @@ -2284,11 +2541,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "This matrix can be used to compute euclidian distance on. This KMeans uses Euclidian distance we can expect that :\n", + "This matrix can be used to compute euclidian distance on. Since KMeans uses Euclidian distance we can expect that :\n", " * if your KMeans make sense ...\n", " * ... then our distance would also make sense\n", "\n", - "Remark : in our experience this type of clustering (Bag Of Char, SVD, KMeans) work relatively well on 'dirty' text (where the notion of words doesn't make much sense). It can be a relatively simple way to test some clustering algorithms" + "Remark: in our experience this type of clustering (Bag Of Char, SVD, KMeans) work relatively well on 'dirty' text (where the notion of words doesn't make much sense). It can be a relatively simple way to test some clustering algorithms" ] }, { @@ -2303,7 +2560,171 @@ { "cell_type": "markdown", "metadata": {}, - "source": [] + "source": [ + "## sklearn to aikit\n", + "You can take an existing sklearn pipeline and convert it to an aikit one (this work on an un-fitted pipeline and a fitted one)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### ... un-fitted pipeline" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Pipeline(steps=[('vect',\n", + " CountVectorizerWrapper(analyzer='char',\n", + " columns_to_use=['name'],\n", + " ngram_range=(1, 4))),\n", + " ('svd',\n", + " TruncatedSVDWrapper(n_components=200, random_state=123)),\n", + " ('km', KMeans(n_clusters=5, random_state=123))])" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from sklearn.pipeline import Pipeline\n", + "\n", + "clustering_pipeline = Pipeline([(\"vect\", CountVectorizerWrapper(analyzer=\"char\",\n", + " ngram_range=(1, 4),\n", + " columns_to_use=[\"name\"])),\n", + " (\"svd\", TruncatedSVDWrapper(n_components=200, random_state=123)),\n", + " (\"km\", KMeans(n_clusters=5, random_state=123))])\n", + "clustering_pipeline" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "GraphPipeline(edges=[('vect', 'svd', 'km')],\n", + " models={'km': KMeans(n_clusters=5, random_state=123),\n", + " 'svd': TruncatedSVDWrapper(n_components=200,\n", + " random_state=123),\n", + " 'vect': CountVectorizerWrapper(analyzer='char',\n", + " columns_to_use=['name'],\n", + " ngram_range=(1, 4))})" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "aikit_pipeline = GraphPipeline.from_sklearn(clustering_pipeline)\n", + "aikit_pipeline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### ... fitted pipeline" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Pipeline(steps=[('vect',\n", + " CountVectorizerWrapper(analyzer='char',\n", + " columns_to_use=['name'],\n", + " ngram_range=(1, 4))),\n", + " ('svd',\n", + " TruncatedSVDWrapper(n_components=200, random_state=123)),\n", + " ('km', KMeans(n_clusters=5, random_state=123))])" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "clustering_pipeline = Pipeline([(\"vect\", CountVectorizerWrapper(analyzer=\"char\",\n", + " ngram_range=(1, 4),\n", + " columns_to_use=[\"name\"])),\n", + " (\"svd\", TruncatedSVDWrapper(n_components=200, random_state=123)),\n", + " (\"km\", KMeans(n_clusters=5, random_state=123))])\n", + "clustering_pipeline.fit(Xtrain)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "GraphPipeline(edges=[('vect', 'svd', 'km')],\n", + " models={'km': KMeans(n_clusters=5, random_state=123),\n", + " 'svd': TruncatedSVDWrapper(n_components=200,\n", + " random_state=123),\n", + " 'vect': CountVectorizerWrapper(analyzer='char',\n", + " columns_to_use=['name'],\n", + " ngram_range=(1, 4))})" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "aikit_pipeline = GraphPipeline.from_sklearn(clustering_pipeline)\n", + "aikit_pipeline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "prediction are the sames:" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pred_from_aikit=aikit_pipeline.predict(Xtrain)\n", + "pred_from_sklearn=clustering_pipeline.predict(Xtrain)\n", + "(pred_from_aikit == pred_from_sklearn).all()" + ] }, { "cell_type": "code",