diff --git a/dist/strtree-0.1.0-py3-none-any.whl b/dist/strtree-0.1.0-py3-none-any.whl deleted file mode 100644 index 559e11a..0000000 Binary files a/dist/strtree-0.1.0-py3-none-any.whl and /dev/null differ diff --git a/dist/strtree-0.1.0.tar.gz b/dist/strtree-0.1.0.tar.gz deleted file mode 100644 index 76fbd8d..0000000 Binary files a/dist/strtree-0.1.0.tar.gz and /dev/null differ diff --git a/dist/strtree-0.2.0-py3-none-any.whl b/dist/strtree-0.2.0-py3-none-any.whl new file mode 100644 index 0000000..a650d90 Binary files /dev/null and b/dist/strtree-0.2.0-py3-none-any.whl differ diff --git a/dist/strtree-0.2.0.tar.gz b/dist/strtree-0.2.0.tar.gz new file mode 100644 index 0000000..533cc7b Binary files /dev/null and b/dist/strtree-0.2.0.tar.gz differ diff --git a/dummy_example.ipynb b/dummy_example.ipynb index da73f78..1c7244f 100644 --- a/dummy_example.ipynb +++ b/dummy_example.ipynb @@ -10,6 +10,14 @@ "from src import strtree" ] }, + { + "cell_type": "markdown", + "id": "b2705c25-c6c5-4a22-87e7-e32951dbd225", + "metadata": {}, + "source": [ + "# One-class classification" + ] + }, { "cell_type": "code", "execution_count": 2, @@ -17,7 +25,7 @@ "metadata": {}, "outputs": [], "source": [ - "STRINGS = ['Samsung X-500', 'Samsung SM-10', 'Samsung X-1100', 'Samsung F-10', 'Samsung X-2200',\n", + "strings = ['Samsung X-500', 'Samsung SM-10', 'Samsung X-1100', 'Samsung F-10', 'Samsung X-2200',\n", " 'AB Nokia 1', 'DG Nokia 2', 'THGF Nokia 3', 'SFSD Nokia 4', 'Nokia XG', 'Nokia YO']" ] }, @@ -28,7 +36,7 @@ "metadata": {}, "outputs": [], "source": [ - "TARGET = [1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0]" + "labels = [1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0]" ] }, { @@ -61,7 +69,15 @@ ], "source": [ "tree = strtree.StringTree()\n", - "tree.build(STRINGS, TARGET, min_precision=0.9, min_token_length=1, verbose=True)" + "tree.build(strings, labels, min_precision=0.9, min_token_length=1, verbose=True)" + ] + }, + { + "cell_type": "markdown", + "id": "e4f21130-9628-4549-a752-7a2d260684bf", + "metadata": {}, + "source": [ + "All found patterns:" ] }, { @@ -73,8 +89,8 @@ { "data": { "text/plain": [ - "[PatternNode(\".+ .+a.+\\d$\", right=None, left=PatternNode(.+0.+), n_strings=11, precision=1.0, recall=0.5714285714285714),\n", - " PatternNode(\".+0.+\", right=None, left=None, n_strings=7, precision=1.0, recall=1.0)]" + "[PatternNode(\".+ .+a.+\\d$\", right=None, left=PatternNode(.+0.+), n_strings=11, n_matches=4, precision=1.0, recall=0.5714285714285714),\n", + " PatternNode(\".+0.+\", right=None, left=None, n_strings=7, n_matches=3, precision=1.0, recall=1.0)]" ] }, "execution_count": 5, @@ -86,16 +102,24 @@ "tree.leaves" ] }, + { + "cell_type": "markdown", + "id": "6fc92f86-1ada-4832-8329-a245db593626", + "metadata": {}, + "source": [ + "Filter out strings not matching the tree:" + ] + }, { "cell_type": "code", "execution_count": 6, - "id": "32990a3a", + "id": "281681b2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "PatternNode(\".+0.+\", right=None, left=None, n_strings=7, precision=1.0, recall=1.0)" + "['Nokia A-100']" ] }, "execution_count": 6, @@ -104,19 +128,27 @@ } ], "source": [ - "tree.root.left" + "tree.filter(['Nokia A-100', 'String Outside Of Dataset'])" + ] + }, + { + "cell_type": "markdown", + "id": "8ce2654f-c781-4597-8353-fde8949a2777", + "metadata": {}, + "source": [ + "Get the matching flags for each string:" ] }, { "cell_type": "code", "execution_count": 7, - "id": "281681b2", + "id": "e93d04e3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "['Nokia A-100']" + "[1, 0]" ] }, "execution_count": 7, @@ -125,19 +157,27 @@ } ], "source": [ - "tree.filter(['Nokia A-100'])" + "tree.match(['Nokia A-100', 'String Outside Of Dataset'])" + ] + }, + { + "cell_type": "markdown", + "id": "74f2f146-19fa-46ff-9b6f-16dbdd97d8e7", + "metadata": {}, + "source": [ + "Get the precision score for given strings and labels:" ] }, { "cell_type": "code", "execution_count": 8, - "id": "e93d04e3", + "id": "a336107b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[1]" + "1.0" ] }, "execution_count": 8, @@ -146,13 +186,21 @@ } ], "source": [ - "tree.match(['Nokia A-100'])" + "tree.precision_score(strings, labels)" + ] + }, + { + "cell_type": "markdown", + "id": "ddbff48a-9deb-41d2-812d-6f59fba47e50", + "metadata": {}, + "source": [ + "Get the recall score for given strings and labels:" ] }, { "cell_type": "code", "execution_count": 9, - "id": "a336107b", + "id": "3c0f67c2", "metadata": {}, "outputs": [ { @@ -167,19 +215,27 @@ } ], "source": [ - "tree.precision_score(STRINGS, TARGET)" + "tree.recall_score(strings, labels)" + ] + }, + { + "cell_type": "markdown", + "id": "93d48b32-5240-4f1b-8eb9-a5ef153594e8", + "metadata": {}, + "source": [ + "Predict labels for given strings:" ] }, { "cell_type": "code", "execution_count": 10, - "id": "3c0f67c2", + "id": "091441b8-de27-4564-84c7-3638e27e2f3b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "1.0" + "[1, None]" ] }, "execution_count": 10, @@ -188,7 +244,249 @@ } ], "source": [ - "tree.recall_score(STRINGS, TARGET)" + "tree.predict_label(['Nokia A-100', 'String Outside Of Dataset'])" + ] + }, + { + "cell_type": "markdown", + "id": "c4fb1679-4746-40c9-aa7a-96059894d1c5", + "metadata": {}, + "source": [ + "Find all regular expressions for a given label:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "30089819-1327-4f24-8085-13ba1b16c65f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([PatternNode(\".+ .+a.+\\d$\", right=None, left=PatternNode(.+0.+), n_strings=11, n_matches=4, precision=1.0, recall=0.5714285714285714),\n", + " PatternNode(\".+0.+\", right=None, left=None, n_strings=7, n_matches=3, precision=1.0, recall=1.0)],\n", + " dtype=object)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tree.get_nodes_by_label(1)" + ] + }, + { + "cell_type": "markdown", + "id": "4c31f778-c390-4aa5-8e79-ac7b8e033d27", + "metadata": {}, + "source": [ + "# Multi-class classification" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "8f510b21-b65e-4ae9-b34b-e5622c5a2d23", + "metadata": {}, + "outputs": [], + "source": [ + "strings = ['Admiral', 'Apple', 'Age',\n", + " 'Bee', 'Bubble', 'Butter',\n", + " 'Color', 'Climate', 'CPU']\n", + "\n", + "labels = [0, 0, 0,\n", + " 1, 1, 1,\n", + " 2, 2, 2]" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "5583620a-e3b6-432e-98a6-ff741b707e0d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total: 9 strings with 9 positive labels\n", + "\n", + "Start processing another 9 of strings with 3 classes.\n", + "Current pattern=\"\". N matches: 9, Precision=[0.3333333333333333, 0.3333333333333333, 0.3333333333333333], Recall=[1.0, 1.0, 1.0]\n", + "Best pattern=\"^A.+\". N matches: 3, Precision=[1.0, 0.0, 0.0], Recall=[1.0, 0.0, 0.0]\n", + "Last pattern was saved\n", + "\n", + "Start processing another 6 of strings with 3 classes.\n", + "Current pattern=\"\". N matches: 6, Precision=[0.0, 0.5, 0.5], Recall=[0.0, 1.0, 1.0]\n", + "Best pattern=\"^B.+\". N matches: 3, Precision=[0.0, 1.0, 0.0], Recall=[0.0, 1.0, 0.0]\n", + "Last pattern was saved\n", + "\n", + "Start processing another 3 of strings with 3 classes.\n", + "Current pattern=\"\". N matches: 3, Precision=[0.0, 0.0, 1.0], Recall=[0.0, 0.0, 1.0]\n", + "Best pattern=\"^C.+\". N matches: 3, Precision=[0.0, 0.0, 1.0], Recall=[0.0, 0.0, 1.0]\n", + "Last pattern was saved\n", + "\n", + "Finished\n" + ] + } + ], + "source": [ + "tree = strtree.StringTree()\n", + "tree.build(strings, labels, min_precision=0.9, min_token_length=1, verbose=True)" + ] + }, + { + "cell_type": "markdown", + "id": "b9f42c12-c4cd-4a06-96a0-29255f64e302", + "metadata": {}, + "source": [ + "All found patterns:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "5aada727-5d23-460f-9eec-016289fe8d83", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[PatternNode(\"^A.+\", right=None, left=PatternNode(^B.+), n_strings=9, n_matches=3, precision=[1.0, 0.0, 0.0], recall=[1.0, 0.0, 0.0]),\n", + " PatternNode(\"^B.+\", right=None, left=PatternNode(^C.+), n_strings=6, n_matches=3, precision=[0.0, 1.0, 0.0], recall=[0.0, 1.0, 0.0]),\n", + " PatternNode(\"^C.+\", right=None, left=None, n_strings=3, n_matches=3, precision=[0.0, 0.0, 1.0], recall=[0.0, 0.0, 1.0])]" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tree.leaves" + ] + }, + { + "cell_type": "markdown", + "id": "37b778a0-3366-4388-903a-509512c4a5ff", + "metadata": {}, + "source": [ + "Filter out strings not matching the tree:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "39c86229-50d0-4fb9-ab5d-15b3ef6644af", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['Ananas']" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tree.filter(['Ananas', 'Zeta'])" + ] + }, + { + "cell_type": "markdown", + "id": "5ec33c8c-4eef-4cba-b930-451133d43a58", + "metadata": {}, + "source": [ + "Get the matching flags for each string (with nodes where a match was found):" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "dad75f40-d4c5-4351-a211-cae9b1ac1902", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "([1, 0],\n", + " [PatternNode(\"^A.+\", right=None, left=PatternNode(^B.+), n_strings=9, n_matches=3, precision=[1.0, 0.0, 0.0], recall=[1.0, 0.0, 0.0]),\n", + " None])" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tree.match(['Ananas', 'Zeta'], return_nodes=True)" + ] + }, + { + "cell_type": "markdown", + "id": "9b0dcb3d-28a9-4ede-8e7a-b6d2cfcb2ed3", + "metadata": {}, + "source": [ + "Predict labels for given strings:" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "b1637f1b-f74a-451f-b39b-d858f4d4244e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "([0],\n", + " [PatternNode(\"^A.+\", right=None, left=PatternNode(^B.+), n_strings=9, n_matches=3, precision=[1.0, 0.0, 0.0], recall=[1.0, 0.0, 0.0])])" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tree.predict_label(['Ananas'], return_nodes=True)" + ] + }, + { + "cell_type": "markdown", + "id": "97cfc0ac-4dd2-4ae1-a716-8e72a9396967", + "metadata": {}, + "source": [ + "Find all regular expressions for a given label:" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "2501b421-6438-4635-8a66-31507b68e689", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([PatternNode(\"^A.+\", right=None, left=PatternNode(^B.+), n_strings=9, n_matches=3, precision=[1.0, 0.0, 0.0], recall=[1.0, 0.0, 0.0])],\n", + " dtype=object)" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tree.get_nodes_by_label(0)" ] } ], diff --git a/pyproject.toml b/pyproject.toml index 4c4ea2d..fe55f1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "strtree" -version = "0.1.2" +version = "0.2.0" authors = [ { name="Anton Saroka", email="anton.soroka.1313@gmail.com" }, ] diff --git a/site/assets/_mkdocstrings.css b/site/assets/_mkdocstrings.css index 049a254..4b7d98b 100644 --- a/site/assets/_mkdocstrings.css +++ b/site/assets/_mkdocstrings.css @@ -26,39 +26,84 @@ float: right; } -/* Keep headings consistent. */ -h1.doc-heading, -h2.doc-heading, -h3.doc-heading, -h4.doc-heading, -h5.doc-heading, -h6.doc-heading { - font-weight: 400; - line-height: 1.5; - color: inherit; - text-transform: none; +/* Symbols in Navigation and ToC. */ +:root, +[data-md-color-scheme="default"] { + --doc-symbol-attribute-fg-color: #953800; + --doc-symbol-function-fg-color: #8250df; + --doc-symbol-method-fg-color: #8250df; + --doc-symbol-class-fg-color: #0550ae; + --doc-symbol-module-fg-color: #5cad0f; + + --doc-symbol-attribute-bg-color: #9538001a; + --doc-symbol-function-bg-color: #8250df1a; + --doc-symbol-method-bg-color: #8250df1a; + --doc-symbol-class-bg-color: #0550ae1a; + --doc-symbol-module-bg-color: #5cad0f1a; +} + +[data-md-color-scheme="slate"] { + --doc-symbol-attribute-fg-color: #ffa657; + --doc-symbol-function-fg-color: #d2a8ff; + --doc-symbol-method-fg-color: #d2a8ff; + --doc-symbol-class-fg-color: #79c0ff; + --doc-symbol-module-fg-color: #baff79; + + --doc-symbol-attribute-bg-color: #ffa6571a; + --doc-symbol-function-bg-color: #d2a8ff1a; + --doc-symbol-method-bg-color: #d2a8ff1a; + --doc-symbol-class-bg-color: #79c0ff1a; + --doc-symbol-module-bg-color: #baff791a; +} + +code.doc-symbol { + border-radius: .1rem; + font-size: .85em; + padding: 0 .3em; + font-weight: bold; +} + +code.doc-symbol-attribute { + color: var(--doc-symbol-attribute-fg-color); + background-color: var(--doc-symbol-attribute-bg-color); +} + +code.doc-symbol-attribute::after { + content: "attr"; +} + +code.doc-symbol-function { + color: var(--doc-symbol-function-fg-color); + background-color: var(--doc-symbol-function-bg-color); +} + +code.doc-symbol-function::after { + content: "func"; } -h1.doc-heading { - font-size: 1.6rem; +code.doc-symbol-method { + color: var(--doc-symbol-method-fg-color); + background-color: var(--doc-symbol-method-bg-color); } -h2.doc-heading { - font-size: 1.2rem; +code.doc-symbol-method::after { + content: "meth"; } -h3.doc-heading { - font-size: 1.15rem; +code.doc-symbol-class { + color: var(--doc-symbol-class-fg-color); + background-color: var(--doc-symbol-class-bg-color); } -h4.doc-heading { - font-size: 1.10rem; +code.doc-symbol-class::after { + content: "class"; } -h5.doc-heading { - font-size: 1.05rem; +code.doc-symbol-module { + color: var(--doc-symbol-module-fg-color); + background-color: var(--doc-symbol-module-bg-color); } -h6.doc-heading { - font-size: 1rem; +code.doc-symbol-module::after { + content: "mod"; } \ No newline at end of file diff --git a/site/index.html b/site/index.html index 74333ce..e153b1c 100644 --- a/site/index.html +++ b/site/index.html @@ -242,5 +242,5 @@ diff --git a/site/objects.inv b/site/objects.inv index 89461c1..4f3d862 100644 Binary files a/site/objects.inv and b/site/objects.inv differ diff --git a/site/pattern_reference/index.html b/site/pattern_reference/index.html index 5d9d56c..3bb0010 100644 --- a/site/pattern_reference/index.html +++ b/site/pattern_reference/index.html @@ -89,15 +89,15 @@