From 9f1ebc6d09efe48d44de720dba24db975628bb3a Mon Sep 17 00:00:00 2001 From: Kade Heckel Date: Sat, 4 Nov 2023 22:54:36 +0000 Subject: [PATCH] rsnn borked --- .../nir/rsnn/braille_subtract_spyx.ipynb | 248 ++++++++---------- spyx/nir.py | 4 +- 2 files changed, 108 insertions(+), 144 deletions(-) diff --git a/docs/examples/nir/rsnn/braille_subtract_spyx.ipynb b/docs/examples/nir/rsnn/braille_subtract_spyx.ipynb index 915a7d4..8b46db3 100644 --- a/docs/examples/nir/rsnn/braille_subtract_spyx.ipynb +++ b/docs/examples/nir/rsnn/braille_subtract_spyx.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 19, + "execution_count": 1, "id": "b256e6d3", "metadata": {}, "outputs": [], @@ -12,6 +12,7 @@ "\n", "import jax\n", "import jax.numpy as jnp\n", + "import numpy as np\n", "\n", "import nir\n", "# for loading dataset\n", @@ -22,7 +23,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 2, "id": "6cab3896", "metadata": {}, "outputs": [], @@ -32,57 +33,7 @@ }, { "cell_type": "code", - "execution_count": 21, - "id": "56d6ab66", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "dict_keys(['fc1', 'fc2', 'input', 'lif1.lif', 'lif1.w_rec', 'lif2', 'output'])" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "ng.nodes.keys()" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "0af61bbe", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "dict_keys(['fc1', 'fc2', 'input', 'lif1.lif', 'lif1.w_rec', 'lif2', 'output'])" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "ng.nodes.keys()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "43b91141", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 23, + "execution_count": 3, "id": "e47b4a4c", "metadata": {}, "outputs": [], @@ -92,7 +43,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 4, "id": "28708916", "metadata": {}, "outputs": [], @@ -102,7 +53,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 5, "id": "56fd624a", "metadata": {}, "outputs": [], @@ -112,7 +63,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 6, "id": "aa50f16d", "metadata": {}, "outputs": [], @@ -122,7 +73,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 7, "id": "89122ee8", "metadata": {}, "outputs": [ @@ -132,7 +83,7 @@ "(140, 256, 12)" ] }, - "execution_count": 27, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -143,38 +94,7 @@ }, { "cell_type": "code", - "execution_count": 44, - "id": "03c8d15e", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "id": "b08d25b2", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'output': array([55])}" - ] - }, - "execution_count": 47, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "ng.nodes[\"lif1.lif\"].output_type" - ] - }, - { - "cell_type": "code", - "execution_count": 48, + "execution_count": 25, "id": "65c5c5e4", "metadata": {}, "outputs": [], @@ -188,13 +108,13 @@ " ('lif1.w_rec', 'lif1.lif'),\n", " ('lif1.lif', 'output')\n", "]\n", - "subgraph.nodes[\"output\"].output_type['output'] = np.array([55])\n", - "subgraph = nir.NIRGraph(subgraph_nodes, subgraph_edges)" + "subgraph = nir.NIRGraph(subgraph_nodes, subgraph_edges)\n", + "subgraph.nodes[\"output\"].output_type['output'] = np.array([55])" ] }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 26, "id": "f3b086ce", "metadata": { "scrolled": true @@ -207,16 +127,6 @@ "[INFO] found RNN subgraph, replacing with NIRGraph node\n", "[INFO] subgraph edges: ('lif1.lif', 'lif1.w_rec'), ('lif1.w_rec', 'lif1.lif')\n", "found subgraph, trying to parse as RNN\n", - "HERE: [0.00016667 0.00016667 0.00016667 0.00016667 0.00016667 0.00016667\n", - " 0.00016667 0.00016667 0.00016667 0.00016667 0.00016667 0.00016667\n", - " 0.00016667 0.00016667 0.00016667 0.00016667 0.00016667 0.00016667\n", - " 0.00016667 0.00016667 0.00016667 0.00016667 0.00016667 0.00016667\n", - " 0.00016667 0.00016667 0.00016667 0.00016667 0.00016667 0.00016667\n", - " 0.00016667 0.00016667 0.00016667 0.00016667 0.00016667 0.00016667\n", - " 0.00016667 0.00016667 0.00016667 0.00016667 0.00016667 0.00016667\n", - " 0.00016667 0.00016667 0.00016667 0.00016667 0.00016667 0.00016667\n", - " 0.00016667 0.00016667 0.00016667 0.00016667 0.00016667 0.00016667\n", - " 0.00016667]\n", "found subgraph, trying to parse as RNN\n" ] } @@ -227,7 +137,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 27, "id": "1979a46e", "metadata": { "scrolled": true @@ -239,7 +149,7 @@ "dict_keys(['linear', 'RCuBaLIF'])" ] }, - "execution_count": 50, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } @@ -250,7 +160,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 28, "id": "0701a1b6", "metadata": {}, "outputs": [ @@ -268,25 +178,27 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 30, "id": "ca2e5c15", - "metadata": {}, + "metadata": { + "scrolled": true + }, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 57, + "execution_count": 30, "metadata": {}, "output_type": "execute_result" }, { "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXMAAABqCAYAAABOFLF1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAU60lEQVR4nO3deXAc1Z3A8e+ve3RbsrEtXzosCwtjG4xPSRYhUHZIwBh8kcQQEgNJTAWRcC4FRS3LbpElIWCzIQm2uAKBLMtljiyVDThkvZzBgDHYjiz5tiVLxpJsodOaefvHjGTJmkPHzLQ0+n2qpma6+033r9+8+Wn0uvu1GGNQSik1uFlOB6CUUqr/NJkrpVQM0GSulFIxQJO5UkrFAE3mSikVAzSZK6VUDOhXMheRi0SkVETKReSOcAWllFKqd6Sv55mLiA3sBC4EDgIfAVcYY7aHLzyllFI90Z9f5vlAuTFmtzGmFXgOWBKesJRSSvWGqx/vzQAOdJo+CBQEe0O8JJhEUvqxSaWUGnrqqf3SGJMerEx/krn4mdetz0ZEVgOrARJJpkAW9mOTSik19LxlXtwXqkx/ulkOAlmdpjOBilMLGWNKjDFzjTFz40jox+aUUkoF0p9k/hGQJyKTRCQeWAm8Fp6wlFJK9Uafu1mMMW0icgPwP4ANPGGM2Ra2yJRSSvVYf/rMMca8AbwRpliUUkr1kV4BqpRSMUCTuVJKxQBN5kopFQM0mSulVAzQZB6CKyebnY/kY592mtOhKKVUQJrMQ/AMT+EHRe8iiXrBk1Jq4OrXqYlDgeezHXxwThxw2OlQlFIqIP1lrpRSMUCTuVJKxYDBkcwLZ+C+YDZYNo3LC3BlTHA6oj5pXpyPnZfrnRChcVkBrqxMZ4PyMUXn4DlvVse0+4LZUDijY/rEN+dizZzmRGjO6vQ5ubIyaVxWAOIdMNTOy6X50nwAWi+aR8PlBR0PKzkZmXsWbQvndF+ly0XDigLssWOwJ0/qWEcg9uhRNKwoQOLiAXBlTOgShys3p8s67Gln0LJont91NS3JxzVpIq7x42hcXgCW7fezllnTO6ZbFs3DnpoXuH5830lXTjZNSzvFMWUyLZf4jyPSOuo4veuosa0XzcM660xHYoq0AZfMrcRExNW1K3/fLQbX3VVYiQmUrFlL7XnZDkXXPz9b8xwHlo0DQFxxPLTmYY4szArxrggSwUpOBqDitjba7q7pWJR4TyV7bjpZdNZ9n1D6o9QoB+i8zp/TkYVZPLTmYcQVB8CBZeO4ec0fAbjgl+/y+tq1HQ+ZMJbS4gSyf76z2zqtYSn8Yc2DNOTncPCy8dy65pmOxOxP88wcXlj7INbIEQDUfD2bdWsfwkrwHpQ/eOkE7lz7VMc6dl0xim8/8OfuK7Js/m3NY1RcnEHdeTmUrFmLlZSI6+4q9t1ycvTqGfdtoXT1yfsOrHzgDXav9D+UtpWU1PGdPPytDO5dUwKWDcDe74zh2gc3BNyvSLKGp/Hc2gdpnJvTZf7C+/+PslWxeWZan28b1xdpMtKEGs988bZafvfCJWTf817HPCsxESwLT2MjVmoqnoZG8LgjHW7YWampmOYWzInWk9NNTZi2NkfiaVs4h8ee+A+K85djjh3vqGPAm+Q9HjzNzd7plBRMWxumpcWRWJ3U/jkBSFISnvp67+u4eCQxAU99PVZKCmLbHe9x19d7k22nOu3MTkvD/VUDYtsd6wgcgI09LAX38ePe7bpcXeM4dTouHomPw9PQEHBfjMdgpSR7Yw/xWVspKZjWEx3t1t86PQ2NiCUB68cJ7XXcOVcM1nb8lnnxY2PM3GBlBlwyb1xewLDyY3i2/iNKUQ1ddno6tReezvCXPh10jVupoaQnyXzAnZqY/PKHeJwOYohwHzlC2h+PdL89lFJq0BlwfeZKKaV6T5O5UkrFAE3mSikVA0ImcxHJEpG3RWSHiGwTkRt980eKyJsiUuZ7js3zfZRSahDoyS/zNuBWY8xUoBAoFpFpwB3ARmNMHrDRNz1k2Hm5WCkpoQsqpWKaKzeny6iqdl4uVmoqVkrKyYsEoyBkMjfGVBpjPvG9rgd2ABnAEuApX7GngKURinHAkYQE1m18mqOXzwhdWCkV067/y5/Ze8NUwHvO/0Nv/YHqK8+iZvkMSjY+7b1OJgp6dZ65iOQAm4CzgP3GmBGdltUaY4J2tQQ7z9yeMpkVr7zDS0vOxb1zV49jcoo1cxqy9xDuumNOh6KUcpA140ykqgZ3VbV3euY05MBhcLsxORl4PtsB/byeJ6znmYvIMOAl4CZjzHEJcvnxKe9bDawGSCQ5cMG64/zi9WXk1Zb3NCRHebZsdzoEpdQAcOoFjl1yw5bo/djrUTIXkTi8ifxZY8zLvtlVIjLeGFMpIuOBan/vNcaUACXg/WUeaBvuqmpyb69m8F2kr5RSzuvJ2SwCPA7sMMas6bToNWCV7/Uq4NXwh6eUUqonenI2y7nA94EFIrLF91gE/AK4UETKgAt90+ElwuJttRz98Xy/i+2xY7h91+ddhm4d6Hauy6f1zYlOhxExTUvzWb1zd8dojLFi3/Nns+/5s7vNd03M4q7dW5C5ZwEgs6Zz1+4tuCYF/ozttDR+Vv4PWi7uPjzs/n8pIu+j7rcoFJeLFTuqqb2663fh2BuT2flE0K5UADznzeL2XZ9jjx3jd/nRH8/nsu1HO0ZePHhnETM+Cd2VaiUm8sOde2hYUUDNNfNZtv1Ix6iJnU3ZHMf+u4tCri/c2t7KZue64EMMx4qQ3SzGmHeAQJ9q8FGz+ssYHl9/CeM2H/M7foip/4qb1l1H9t79ODPuYO9l/wkaR04gnn1OhxIRw7Z9yb+uv4oJrX93OpSwGvWi/9NQPTV1FK+7nuyDu2gDrIojFK+7nqyjXwRcl6epmTvWXUt2aUW3dpv5tybePzyb0bzfZb5xu/ntuqVM+KSm69hFz44mqyH0aEbxe49w07rryKzf4nf56M3HeXTdpYw13tFKMzY1srGukPRT4ui2L60nuHfd98j8ohqTGMcj65cwznR/zzuPzSVzR1PIOMPtq6czyK4ZGp23A27URKWUUl0NylETVewQlwt8Y3yblhbvnXKMB9PWhiQk9GnY3c7r6FH5hARMq3ccbomP7xaHUv0lcfFg+em8cLuj2sZ0bBYVMc1vZFJStpGfbvsMLJusd+Io/fVs7Kl5lJRt7NPtu0b/bzKlv5ndo7KSkMA/73if41cUUHN1If+0fTPicjHib8MofWTwHGdRA9vyrQcpKdvY7XHs9egeG9NuFhUxLRfPo2GcC7vVMPzZD2hcVkBSVTOuL/ZwZMV00jds7/VFV01L80k80oq8uyV0YRHqripk9DsVGNuipnAsw5/9kKbL5pFwtBXrnR6sQ6kQ6lcWciKp+y/zYZVtxP/5o7BsY1DeaUgppVRXPUnm2s2ilFIxQJO5UkrFgIGdzEWQv2Z4+6S+MYdR754WtRHITtX2VjZ1P/BesCFx8aRsSqd5ceiLESpvLaLylamRDm9A2lkyj12/8n/BV6S5L5hN+nsjwj5MsZ2Wxtj30/B8bSYArowJZH+Ygj19CodvHrqftXLewE7mQGnZBBJq24g73soHZblEs4+/s93l40hsv/jAePi0fCLxda0h35dUbWjcNTzC0Q1MSfvjSKno2YBs4eaqb+G9slxwh/eCEeN2807ZZFxfeT9709LC2+VnIE0tJFV7aNg9ND9r5Tw9AKqUUgOcHgBVSqkhQpN5EPvvLorpQbGUUrFDL+cPYtR2N1XuDLJidFAspVTsGDS/zO0Rw73DjPbgDkfWOVNxZUwIvr4zTseePClomZQXPyTr5+/5XSZzpne5iWtPuSZNxJ52Rq/fFw3B6tg660xcWZkORDV4uDIzsGZ0HaLASk1F5p3td1hYe2pe0KFyI0FmTccePcr/MpcLmXc2VkoKrvHjsM7xf2aOzJqOnZ7eZZ49ZbL3+2TZHetQ0TVoknndt6byzIb1WMOGhSx71fN/oeynwb8k7kdaqFob16dYJCGBkg3r+fKy3o8tUnrvSDKePNSn7UZazSVTefLldVhJSd2WFT67le13Bf8DOdSV3pzF4ufe7TKv6bwzeXnD49ijRnYrP/bJw5TdF8WzXyyb+zc8TsUVU/wuttNH8/orv+dE/hTKr5/ENS+80b2QCD9/+UkOfS+v66rXN3B4TTyuMaN59ZUnaDlXT9GMtkFzNouVnIxMGIu7fE/Isq6cbMyxety1tYHLZEwAt5u2w1V9iseVm4On6giehoZevc8eOwaJi6Pt4MBL6MHq2JWViWlqwv3lUQciGxzsUSORYSm07TvQMc9KTEQyx+PetbfbTX1dmRmYEyc6bgQcDa7cHDxHjuKpr+++0LKxc7PxHKjwtoXhqbTt3d99HZMmYo7W4j5+/OS89u9TVTX26Tl4DlT0aVRM5Z+OzaKUUjFAT01USqkhosfJXERsEflURP7kmx4pIm+KSJnvufdHA5VSYSUul9+Drd0LivemCkPIqfsba/vfm1/mNwI7Ok3fAWw0xuQBG33TSikHTXwvjvI13W8UfaqGFfn8c+kHQ+asE9ekify6/O2TN96eM51fl78d9bOJIqlHyVxEMoFLgMc6zV4CPOV7/RSwNKyRKaV67YtfzeD0F5pDlhu+uZKb7i3G0xS6bCzwVH/JlffehrWnAgBrbyVX3nsbnuovHY4sfHp0AFREXgTuA1KB24wxi0WkzhgzolOZWmNMt64WEVkNrAZIJHnO12RRuGJXSqkhISwHQEVkMVBtjPm4L0EYY0qMMXONMXPjSOjLKpRSSoXQk8v5zwUuE5FFQCKQJiLPAFUiMt4YUyki44HonSyrlFKqi5C/zI0xdxpjMo0xOcBK4K/GmKuA14BVvmKrgFcjFqVSSqmg+nOe+S+AC0WkDLjQNx3TWi6Zx/ErC50OY9Co/24hTUvysZKTqbylCFfGBMz8c7x3X7q1yO8l7v1VfX0RMmd6t/n29ClU/bQIRKj7/nwqby3iyE+C3wXJlZPN4ZuL/J7CduyqQloWnTxrpOba+bgvmN3/HQihcVkBX32nEHG5OHxTEfbkScis6VQXFwFQe/V82hbOiXgcYWXZHL6xCHvK5KDFjv5oPp7zZ+EaN5bKW4uwUlN7vamjP4zO5+SEXiVzY8zfjDGLfa+PGmMWGmPyfM81kQlx4Ni3BBKuPux0GINGy1U1HLr8BJI6jN/c8Dtac8dScX4KJcUPU1L8MIxLD72S3hBhVfEbVM3vPt7J0TkjuaX4ecS2Oe2a/ZQUP8zlP/lr0NU1ThnDr4ofxUrpPlZN6jWH2HfZyQHJZl63lf3fjPwxocPfbqHxe3VIUhK/vOFxGqamU12YxrXX/zeIMPnH/2DfxX0bc8gpEufi7uJnqJs5Omi5wus+4cCCJNxZY3ik+DdYI0f0eltzrtvCgW/E5rE7vZxfRYfIybFJ2kdljETb67ydcMQQbH1OO3V/BmqcPdHb+Af7/vaSXs4/iLkyJnD+1qaAw5CeqmlpPlM/diEJCez9rxmUry3ENX4c529t8tvt4E/zpfnM+EQic9Pszl88Y/r8Rax8ZSp7/j1I94if9e58cg4718/rWwwDOWGcuj+DWW/jHwD7a087gwWfN+DKyWbPffNZ8HlDx6O9i/Gsjy0kwf9/ApW3FsHG8A0rrTenGKBMYxOPbbqAqXUVeHpQPvlQE69vmkueezPxm4eRWmMwTb511BymrQfrSKps5OVNBeS1be5v+BHj/vA0TjvYuy/y8E8TkJ5UgFK9IMcbWL9pAWc27Oa07bA+bUHHsskHm7Fa2nhlUz55bv/fpxG72tj7XhY5HAxPPNrNopRSA5t2syil1BChyVwppWKAJnOllIoBmsyVUioGaDJXSqkYoMlcKaVigCZzpZTqpdqr57OzxDs2T9nv51D/3UKsmdOo2DANOy3NkZgGVzK3bOq/W4grM8PpSAYUz/mzaFswyAZXigH2lMk0rChwOgzlgLYkiEtrASA5rRl3gmDibDKGHwNLQrw7MDs9neNXFPbp/qSDKplbSYm88sCD1Hw9y+lQBpRjt3+FdZcOJx9t+5aP4d77H3U6DOWA9EfeZ9LKrQBkLN/GiKffx3z0OWbBIdx1x/q83uZzstn0wG+xR3W7aVtIUb0CVESOAA1A7Nx4r29Go3WgdeCl9aB10C5YPUw0xgQdZjSqyRxARDaHuiw11mkdaB2003rQOmjX33oYVN0sSiml/NNkrpRSMcCJZF7iwDYHGq0DrYN2Wg9aB+36VQ9R7zNXSikVftrNopRSMSBqyVxELhKRUhEpF5E7orVdp4nIXhH5XES2iMhm37yRIvKmiJT5nnt/UukAJyJPiEi1iHzRaV7A/RaRO31to1REvuVM1OEVoA7uEZFDvvawRUQWdVoWi3WQJSJvi8gOEdkmIjf65g+1thCoHsLXHowxEX8ANrALyAXigc+AadHYttMPYC8w+pR59wN3+F7fAfzS6TgjsN9fB2YDX4Tab2Car00kAJN8bcV2eh8iVAf3ALf5KRurdTAemO17nQrs9O3rUGsLgeohbO0hWr/M84FyY8xuY0wr8BywJErbHoiWAE/5Xj8FLHUulMgwxmwCak6ZHWi/lwDPGWNajDF7gHK8bWZQC1AHgcRqHVQaYz7xva4HdgAZDL22EKgeAul1PUQrmWcABzpNHyT4jsQSA/xFRD4WkdW+eWONMZXg/ZCBMY5FF12B9nuotY8bRGSrrxumvXsh5utARHKAWcCHDOG2cEo9QJjaQ7SSub+RZ4bKaTTnGmNmAxcDxSLydacDGoCGUvt4BDgdmAlUAg/65sd0HYjIMOAl4CZjzPFgRf3Mi+V6CFt7iFYyPwh0Hh0rE6iI0rYdZYyp8D1XAxvw/qtUJSLjAXzPQ2WUrED7PWTahzGmyhjjNsZ4gEc5+a9zzNaBiMThTWDPGmNe9s0ecm3BXz2Esz1EK5l/BOSJyCQRiQdWAq9FaduOEZEUEUltfw18E/gC776v8hVbBbzqTIRRF2i/XwNWikiCiEwC8oC/OxBfxLUnMJ9leNsDxGgdiIgAjwM7jDFrOi0aUm0hUD2EtT1E8WjuIrxHcHcBdzl9dDlK+5yL94j0Z8C29v0GRgEbgTLf80inY43Avv8n3n8bT+D9lfHDYPsN3OVrG6XAxU7HH8E6+APwObDV94UdH+N18DW83QNbgS2+x6Ih2BYC1UPY2oNeAaqUUjFArwBVSqkYoMlcKaVigCZzpZSKAZrMlVIqBmgyV0qpGKDJXCmlYoAmc6WUigGazJVSKgb8P4mTrmZeCntVAAAAAElFTkSuQmCC\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAr8AAACxCAYAAADNjNW3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAPHklEQVR4nO3db4hld3kH8O/TJCZtNJitZtn8aSNlCxWpa1kSIVDSpjGplW76IpJAZQuBLVSL0pYafaMtFEJppe8KKYZuqf9CNGQp4ppsFfGNZtdu1Rg1waZxu0u2moixL9JEn76YO2S6zp25M/fPzM75fGC59/zuuXOenfuc3z77u79zftXdAQCAIfiZrQ4AAAAWRfELAMBgKH4BABgMxS8AAIOh+AUAYDAUvwAADMZUxW9V3VpV36qqJ6vq7lkFBQAA81Cbvc9vVV2Q5NtJbk5yKsmjSe7s7m/MLjwAAJidC6d473VJnuzu7yRJVX08yYEkY4vfV9TFfUkuneKQAACwvufz3Pe6+7Xntk9T/F6V5Lsrtk8luf7cnarqUJJDSXJJfi7X101THBIAANb3SD/wn6u1TzPnt1Zp+6k5FN19b3fv7+79F+XiKQ4HAADTmab4PZXkmhXbVyc5PV04AAAwP9MUv48m2VtVr6uqVyS5I8mR2YQFAACzt+k5v939UlW9K8nRJBckua+7H5tZZAAAMGPTXPCW7v50kk/PKBYAAJgrK7wBADAYil8AAAZD8QsAwGAofgHYFo6ePpmjp09udRjADqf4BQBgMBS/AAAMxlS3OgOAWbnlyn1bHQIwAEZ+AQAYDMUvAACDofgFAGAwFL8AAAyG4hcAgMFQ/AIAMBiKXwAABkPxCwCww1gufDzFLwAAg6H4BQBgMNZd3riq7kvytiRnu/sNo7ZdST6R5NokTyV5e3c/N78wAQCYlOXCx5tk5Pcfk9x6TtvdSY51994kx0bbAACwra1b/Hb3F5I8e07zgSSHR88PJ7lttmEBAMDsbXbO7+7uPpMko8crxu1YVYeq6nhVHX8xL2zycAAAML25X/DW3fd29/7u3n9RLp734QAAYKzNFr/PVNWeJBk9np1dSAAAMB+bLX6PJDk4en4wyUOzCQcAAOZnkludfSzJjUleU1WnknwgyT1J7q+qu5I8neT2eQYJQ7RyZZ5Jb1mz3ns28zNhp1k+D5wDLJo+eHtYt/jt7jvHvHTTjGMBAIC5ssIbAACDse7I73bkawNY3Xrng/OFaeyUvvd8jp3zm9zbHoz8AgAwGIpfAAAGo7p7YQe7rHb19eU6OQAA5uuRfuBEd+8/t93ILwAAg6H4BQBgMBS/AAAMhuIXAIDBUPwCADAYil8AAAZD8QsAwGAofgEAGAzFLwAAg6H4BQBgMBS/AAAMhuIXAIDBUPzy/xw9fTJHT5/c6jAAAOZi3eK3qq6pqs9V1eNV9VhVvXvUvquqHq6qJ0aPl88/XAAA2LxJRn5fSvKn3f0rSd6c5J1V9fokdyc51t17kxwbbc+VUcn5u+XKfbnlyn1bHQYAwFysW/x295nu/sro+fNJHk9yVZIDSQ6Pdjuc5LY5xQgAADOxoTm/VXVtkjcl+VKS3d19JlkqkJNcMeY9h6rqeFUdfzEvTBkuAABs3oWT7lhVr0zyySTv6e4fVtVE7+vue5PcmySX1a7eTJDLfB0PAMA0Jhr5raqLslT4fqS7PzVqfqaq9oxe35Pk7HxCBACA2Zjkbg+V5MNJHu/uD6146UiSg6PnB5M8NPvwAABgdiaZ9nBDknck+VpVnRy1vT/JPUnur6q7kjyd5Pa5RMiGrLwbxnaZJrIc08p4Vrtrx3aJFxZlrXNjKOfDan3WovqxcceZ9K5Cm/nctuKORfP4HQ4tTxflfPk3fK39Jtl3q61b/Hb3F5OMm+B702zDAQCA+anuqa5B25DLaldfXzunXj6f/pcD40ybx86DnWlRI3uT5s8i83Sef/dZnS+b/Tmb+bvN8xuKefQf27FP2m7n01A80g+c6O7957Zb3hgAgMFQ/AIAMBimPQAAsOOY9gAAwOApfgEAGAzFLwAAg6H4BQBgMCZZ4Q1gQzayateiV4pa5H0wrYIFnO92Yj9m5BcAgMFQ/AIAMBju8wsAwI7jPr8AAAye4hcAgMFQ/AIAMBiKXwAABkPxCwDAYKxb/FbVJVX15ar696p6rKr+YtS+q6oerqonRo+Xzz9cAADYvElGfl9I8pvd/cYk+5LcWlVvTnJ3kmPdvTfJsdE2AABsW+sub9xLNwL+0WjzotGfTnIgyY2j9sNJPp/kvTOPcANWW4JvUcvybWbJ1Fkts7qRn7PW72hl+8q2lWb1e9yJyyWebzbzGWwmf3byZzzp72O72MgS0+P6gHPNI3/mTf+zva2XE7PMGbkwTBPN+a2qC6rqZJKzSR7u7i8l2d3dZ5Jk9HjFmPceqqrjVXX8xbwwo7ABAGDjNrTCW1W9OsmDSf44yRe7+9UrXnuuu9ec92uFNwA2YtKRue084n4+Wm3kf9pRfli0mazw1t0/yNL0hluTPFNVe5Jk9Hh2+jABAGB+Jrnbw2tHI76pqp9N8ltJvpnkSJKDo90OJnloTjECAMBMrDvtoap+NUsXtF2QpWL5/u7+y6r6+ST3J/mFJE8nub27n13rZ5n2AADAIoyb9jDJ3R6+muRNq7R/P4lKFgCA88a6xS/bnws9AAAmY3ljAAAGQ/ELAMBgmPawA5jqAAAwGSO/AAAMhuIXAIDBOO+nPbjTAUMwq6VCnS8ADJ2RXwAABmPdFd5myQpvAAAswrgV3oz8AgAwGIpfAAAGQ/ELAMBgKH4BABgMxS8AAIOh+AUAYDAUvwAADIbiFwCAwZi4+K2qC6rq36rqX0bbu6rq4ap6YvR4+fzCBACA6W1k5PfdSR5fsX13kmPdvTfJsdE2AABsWxdOslNVXZ3kd5L8VZI/GTUfSHLj6PnhJJ9P8t7ZhgcAw3T09MkkyS1X7hv72mZfnyaeWf5M2AqTjvz+XZI/T/KTFW27u/tMkower1jtjVV1qKqOV9XxF/PCNLECAMBUqrvX3qHqbUne2t1/VFU3Jvmz7n5bVf2gu1+9Yr/nunvNeb+X1a6+vm6aPmoAAFjDI/3Aie7ef277JNMebkjyu1X11iSXJLmsqv45yTNVtae7z1TVniRnZxsyAADM1rrTHrr7fd19dXdfm+SOJP/a3b+f5EiSg6PdDiZ5aG5RAgDADExzn997ktxcVU8kuXm0DQAA29ZEd3tY1t2fz9JdHdLd309iAi8AAOcNK7wBADAYil8AAAZD8QsAwGBsaM4vwLlWW/VpI6tPrbXfdrfWClzA9Oa9qtyiz2Gr5G0PRn4BABgMxS8AAIOx7vLGs2R5YwAAFmHc8sZGfgEAGAwXvLGjzerignE/Z1EXbrmwanbmccGJi1iAnWDSf2sW1efN6zhGfgEAGAzFLwAAg+GCNwAAdhwXvAEAMHiKXwAABkPxCwDAYCh+AQAYDMUvAACDMdEiF1X1VJLnk/w4yUvdvb+qdiX5RJJrkzyV5O3d/dx8wgQAgOltZOT3N7p734pbRtyd5Fh3701ybLQNAADb1jTTHg4kOTx6fjjJbVNHAwAAczRp8dtJPltVJ6rq0Khtd3efSZLR4xWrvbGqDlXV8ao6/mJemD5iAADYpInm/Ca5obtPV9UVSR6uqm9OeoDuvjfJvcnSCm+biBEAAGZiopHf7j49ejyb5MEk1yV5pqr2JMno8ey8ggQAgFlYt/itqkur6lXLz5O8JcnXkxxJcnC028EkD80rSAAAmIVJpj3sTvJgVS3v/9Hu/kxVPZrk/qq6K8nTSW6fX5gAADC9dYvf7v5Okjeu0v79JDfNIygAAJgHK7wBADAYk97tgR3i6OmTSZJbrtw38evLbWu9D/hpmzmf1jtHAWZpZZ80jfOpzzLyCwDAYFT34m69e1nt6uvLNGGA7WSt0Wbf/LATyOPz07Sf2yP9wInu3n9uu5FfAAAGQ/ELAMBgLHTaQ1X9d5L/SfK9hR2U7ew1kQsskQsskwsk8oCXTZMLv9jdrz23caHFb5JU1fHV5l8wPHKBZXKBZXKBRB7wsnnkgmkPAAAMhuIXAIDB2Iri994tOCbbk1xgmVxgmVwgkQe8bOa5sPA5vwAAsFVMewAAYDAUvwAADMZCi9+qurWqvlVVT1bV3Ys8Nlurqp6qqq9V1cmqOj5q21VVD1fVE6PHy7c6Tmavqu6rqrNV9fUVbWM/+6p636iP+FZV3bI1UTMPY3Lhg1X1X6O+4WRVvXXFa3Jhh6qqa6rqc1X1eFU9VlXvHrXrGwZkjTyYa7+wsDm/VXVBkm8nuTnJqSSPJrmzu7+xkADYUlX1VJL93f29FW1/neTZ7r5n9J+hy7v7vVsVI/NRVb+e5EdJ/qm73zBqW/Wzr6rXJ/lYkuuSXJnkkSS/3N0/3qLwmaExufDBJD/q7r85Z1+5sINV1Z4ke7r7K1X1qiQnktyW5A+ibxiMNfLg7Zljv7DIkd/rkjzZ3d/p7v9N8vEkBxZ4fLafA0kOj54fzlLCs8N09xeSPHtO87jP/kCSj3f3C939H0mezFLfwQ4wJhfGkQs7WHef6e6vjJ4/n+TxJFdF3zAoa+TBODPJg0UWv1cl+e6K7VNZ+y/IztJJPltVJ6rq0Khtd3efSZZOgCRXbFl0LNq4z14/MUzvqqqvjqZFLH/NLRcGoqquTfKmJF+KvmGwzsmDZI79wiKL31qlzX3WhuOG7v61JL+d5J2jrz/hXPqJ4fn7JL+UZF+SM0n+dtQuFwagql6Z5JNJ3tPdP1xr11Xa5MMOsUoezLVfWGTxeyrJNSu2r05yeoHHZwt19+nR49kkD2bpa4pnRvN9luf9nN26CFmwcZ+9fmJguvuZ7v5xd/8kyT/k5a8w5cIOV1UXZang+Uh3f2rUrG8YmNXyYN79wiKL30eT7K2q11XVK5LckeTIAo/PFqmqS0cT2VNVlyZ5S5KvZ+nzPzja7WCSh7YmQrbAuM/+SJI7quriqnpdkr1JvrwF8bEgy4XOyO9lqW9I5MKOVlWV5MNJHu/uD614Sd8wIOPyYN79woWbD3ljuvulqnpXkqNJLkhyX3c/tqjjs6V2J3lwKcdzYZKPdvdnqurRJPdX1V1Jnk5y+xbGyJxU1ceS3JjkNVV1KskHktyTVT777n6squ5P8o0kLyV5p6u5d44xuXBjVe3L0leXTyX5w0QuDMANSd6R5GtVdXLU9v7oG4ZmXB7cOc9+wfLGAAAMhhXeAAAYDMUvAACDofgFAGAwFL8AAAyG4hcAgMFQ/AIAMBiKXwAABuP/ACpom/wRae/2AAAAAElFTkSuQmCC\n", "text/plain": [ - "
" + "
" ] }, "metadata": { @@ -296,42 +208,96 @@ } ], "source": [ - "plt.imshow(a[0].T)" + "plt.figure(figsize=(12,4))\n", + "plt.imshow(a[0].T, interpolation=\"none\")" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "id": "0afbbb38", + "metadata": {}, + "outputs": [], + "source": [ + "ng = nir.read(\"braille_subtract.nir\")" ] }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 69, + "id": "c52b6549", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[INFO] found RNN subgraph, replacing with NIRGraph node\n", + "[INFO] subgraph edges: ('lif1.lif', 'lif1.w_rec'), ('lif1.w_rec', 'lif1.lif')\n", + "found subgraph, trying to parse as RNN\n", + "found subgraph, trying to parse as RNN\n" + ] + } + ], + "source": [ + "SNN, params = spyx.nir.from_nir(ng, x, dt=1e-4, return_all_states=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "id": "91f7cab8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "found subgraph, trying to parse as RNN\n" + ] + } + ], + "source": [ + "a, b = SNN.apply(params, x)" + ] + }, + { + "cell_type": "code", + "execution_count": 71, "id": "1ef835a1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "Array([12, 46, 12, 46, 35, 12, 12, 46, 46, 35, 46, 35, 35, 35, 35, 22, 46,\n", - " 46, 35, 12, 35, 46, 46, 35, 46, 46, 46, 12, 46, 46, 12, 46, 30, 13,\n", - " 46, 35, 12, 35, 46, 35, 12, 46, 35, 46, 44, 46, 35, 35, 35, 12, 46,\n", - " 35, 35, 46, 46, 46, 35, 35, 46, 35, 35, 35, 35, 35, 35, 46, 12, 46,\n", - " 35, 12, 35, 35, 35, 46, 35, 46, 46, 35, 12, 35, 46, 35, 46, 12, 35,\n", - " 12, 46, 35, 46, 35, 35, 35, 46, 35, 35, 35, 46, 35, 12, 35, 35, 35,\n", - " 35, 35, 46, 12, 35, 12, 46, 46, 46, 46, 35, 46, 35, 12, 46, 35, 46,\n", - " 35, 12, 12, 46, 35, 35, 46, 35, 46, 12, 35, 46, 35, 46, 35, 46, 12,\n", - " 12, 46, 46, 12], dtype=int32)" + "" ] }, - "execution_count": 53, + "execution_count": 71, "metadata": {}, "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAArkAAADWCAYAAADVY5oSAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAO6klEQVR4nO3db4xld1kH8O9jd7u1hQYa/oT+iaABEkJ0SyYlpglBqmxBI5qoaRMJGJL1BRhITBR8I74jRom+MCQrVDECDQEaCUGWihBCoqXbskL/ADa1kXUrhTSmFGNLy+OLvdMuy8zOnZlzZmZ/+/kkm71z7tlznnt+z/3Nd8+cuae6OwAAMJKf2O0CAABgakIuAADDEXIBABiOkAsAwHCEXAAAhiPkAgAwnH1zbPTCOtAX5ZI5Nr1lL/nZ/02SfPOrF+9yJecux3B+Gx1jY8BO03PATlqdc5Ll5p3/y/fzeD9Waz1Xc3xO7qV1Wb+yrpt8u9tx9OTxJMmhyw/uah3nMsdwfhsdY2PATtNzwE5anXOS5ead2/pzeaQfXjPkulwBAIDhCLkAAAxHyAUAYDhCLgAAwxFyAQAYjpALAMBwhFwAAIazVMitquur6htVdV9VvXPuogAAYDs2DLlVdUGSv0ryuiQvS3JjVb1s7sIAAGCrljmTe02S+7r7/u5+PMnNSd4wb1kAALB1y4TcK5J867SvTyyWAQDAnrRviXXWuh9w/9hKVYeTHE6Si3LxNssCAICtW+ZM7okkV5329ZVJTp65Uncf6e6V7l7ZnwNT1QcAAJu2TMi9PcmLq+pFVXVhkhuSfHLesgAAYOs2vFyhu5+oqrclOZrkgiQ3dffds1cGAABbtMw1uenuTyf59My1AADAJNzxDACA4Qi5AAAMR8gFAGA4Qi4AAMMRcgEAGI6QCwDAcIRcAACGU909+UYvrcv6lXXdpNs8evL4U48PXX5w2+vNXcduW61zoxqnfD1n2+d6+1m2TnaHcTu3bDQuaz1/tmVn29Zm6tnMPqeah07f1kavZzP/Zit1LnuMp6p9M6aqYyv73EyfnsvOldxwrrqtP5dH+uFa6zlncgEAGI6QCwDAcIRcAACGI+QCADAcIRcAgOEIuQAADEfIBQBgOEIuAADDEXIBABjOhiG3qm6qqoeq6q6dKAgAALZrmTO5f5vk+pnrAACAyWwYcrv7i0ke3oFaAABgEq7JBQBgOPum2lBVHU5yOEkuysVTbRYAADZtsjO53X2ku1e6e2V/Dky1WQAA2DSXKwAAMJxlPkLsI0n+JclLq+pEVb1l/rIAAGDrNrwmt7tv3IlCAABgKi5XAABgOEIuAADDEXIBABiOkAsAwHCEXAAAhiPkAgAwHCEXAIDhCLkAAAxnw5tB7BWHLj846XpT1HH05PEd2eeZ+1tvn9t9fru2ss2dOnZz2Onx3w3rvbbV5XP31F6z1dd7tl6Z8hhu9O+XHbepxnIz25ljn3O8R3f62Kz1ek5fPkfP7Ob3uXPRTr2f5jbiWDuTCwDAcIRcAACGI+QCADAcIRcAgOEIuQAADEfIBQBgOEIuAADDEXIBABiOkAsAwHA2DLlVdVVVfb6q7q2qu6vq7TtRGAAAbNUyt/V9Isnvd/edVfXMJHdU1a3dfc/MtQEAwJZseCa3ux/s7jsXj7+X5N4kV8xdGAAAbNUyZ3KfUlUvTHJ1ktvWeO5wksNJclEunqI2AADYkqV/8ayqnpHk40ne0d2PnPl8dx/p7pXuXtmfA1PWCAAAm7JUyK2q/TkVcD/U3Z+YtyQAANieZT5doZJ8IMm93f3e+UsCAIDtWeZM7rVJ3pjkNVV1fPHn9TPXBQAAW7bhL55195eS1A7UAgAAk3DHMwAAhiPkAgAwHCEXAIDhCLkAAAxHyAUAYDhCLgAAwxFyAQAYTnX35Bu9tC7rV9Z1k2+X7Tl68vhTjw9dfnDX6hjF6vGc41ju9bE622vfTO079Trn2M9a29xoP3t9XNcyZ5/vhmXH6Hx8XzOPvTzum5nH9upccFt/Lo/0w2vez8GZXAAAhiPkAgAwHCEXAIDhCLkAAAxHyAUAYDhCLgAAwxFyAQAYjpALAMBwhFwAAIazYcitqouq6stV9W9VdXdV/clOFAYAAFu1b4l1Hkvymu5+tKr2J/lSVf1jd//rzLUBAMCWbBhyu7uTPLr4cv/iT89ZFAAAbMdS1+RW1QVVdTzJQ0lu7e7b1ljncFUdq6pjP8hjE5cJAADLWyrkdveT3X0wyZVJrqmql6+xzpHuXunulf05MHGZAACwvE19ukJ3/0+SLyS5fo5iAABgCst8usJzq+pZi8c/meQXk3x95roAAGDLlvl0hRck+WBVXZBTofij3f2pecsCAICtW+bTFb6a5OodqAUAACbhjmcAAAxHyAUAYDhCLgAAwxFyAQAYjpALAMBwhFwAAIYj5AIAMJxlbgbBBo6ePP7U40OXH9yz+5mztt029xisbn+3j+FeqWOq/Z/LYzXVe3Cn5o+t2os1bcdujsHp2152nxutt9f753xgDPYuZ3IBABiOkAsAwHCEXAAAhiPkAgAwHCEXAIDhCLkAAAxHyAUAYDhCLgAAwxFyAQAYztIht6ouqKqvVNWn5iwIAAC2azNnct+e5N65CgEAgKksFXKr6sokv5zk/fOWAwAA27fsmdy/SPIHSX643gpVdbiqjlXVsR/ksSlqAwCALdkw5FbVryR5qLvvONt63X2ku1e6e2V/DkxWIAAAbNYyZ3KvTfKrVfVAkpuTvKaq/n7WqgAAYBs2DLnd/a7uvrK7X5jkhiT/3N2/PXtlAACwRT4nFwCA4ezbzMrd/YUkX5ilEgAAmIgzuQAADEfIBQBgOEIuAADDEXIBABiOkAsAwHCEXAAAhiPkAgAwnE19Ti4/6ujJ40mSQ5cf3JH9rbefZetYXW+Zdc81c7+etbY/5z7X2/ZeGbet9P5OvV/mGKuNat/Oe3Cq2jazrZHngtPN0XPb2eZmjvuy+5lj/Obuj6neT1Pucy/YjfflZvazl4/depzJBQBgOEIuAADDEXIBABiOkAsAwHCEXAAAhiPkAgAwHCEXAIDhCLkAAAxHyAUAYDhL3fGsqh5I8r0kTyZ5ortX5iwKAAC2YzO39f2F7v7ubJUAAMBEXK4AAMBwlg25neSzVXVHVR2esyAAANiuZS9XuLa7T1bV85LcWlVf7+4vnr7CIvweTpKLcvHEZQIAwPKWOpPb3ScXfz+U5JYk16yxzpHuXunulf05MG2VAACwCRuG3Kq6pKqeufo4yWuT3DV3YQAAsFXLXK7w/CS3VNXq+h/u7s/MWhUAAGzDhiG3u+9P8nM7UAsAAEzCR4gBADAcIRcAgOEIuQAADEfIBQBgOEIuAADDEXIBABiOkAsAwHCEXAAAhrPMHc/2hKMnjz/1+NDlB3etjtOda3XslXpHsdqTjuv4RhvjzbyeOfp8N987y34vmeN7znrbWet4rD7eqI4p6zzbuOzk9+A5tn+217Td/c3dHzttr9QxBWdyAQAYjpALAMBwhFwAAIYj5AIAMBwhFwCA4Qi5AAAMR8gFAGA4Qi4AAMNZKuRW1bOq6mNV9fWqureqfn7uwgAAYKuWvePZXyb5THf/RlVdmOTiGWsCAIBt2TDkVtWlSV6V5M1J0t2PJ3l83rIAAGDrlrlc4aeTfCfJ31TVV6rq/VV1ycx1AQDAli0TcvcleUWS93X31Um+n+SdZ65UVYer6lhVHftBHpu4TAAAWN4yIfdEkhPdfdvi64/lVOj9Ed19pLtXuntlfw5MWSMAAGzKhiG3u/87ybeq6qWLRdcluWfWqgAAYBuW/XSF30vyocUnK9yf5HfmKwkAALZnqZDb3ceTrMxbCgAATMMdzwAAGI6QCwDAcIRcAACGI+QCADAcIRcAgOEIuQAADEfIBQBgOEIuAADDqe6efqNV30ny/STfnXzjnIueE72APuBpeoFVeoFVW+2Fn+ru5671xCwhN0mq6lh3u0saeoEk+oCn6QVW6QVWzdELLlcAAGA4Qi4AAMOZM+QemXHbnFv0Aok+4Gl6gVV6gVWT98Js1+QCAMBucbkCAADDmTzkVtX1VfWNqrqvqt459fbZ26rqgar6WlUdr6pji2WXVdWtVfXvi7+fvdt1Mr2quqmqHqqqu05btu7YV9W7FvPEN6rq0O5UzRzW6YV3V9V/LeaG41X1+tOe0wsDqqqrqurzVXVvVd1dVW9fLDcvnGfO0guzzguTXq5QVRck+WaSX0pyIsntSW7s7nsm2wl7WlU9kGSlu7972rI/TfJwd79n8R+fZ3f3H+5Wjcyjql6V5NEkf9fdL18sW3Psq+plST6S5Joklyf5pyQv6e4nd6l8JrROL7w7yaPd/WdnrKsXBlVVL0jygu6+s6qemeSOJL+W5M0xL5xXztILv5UZ54Wpz+Rek+S+7r6/ux9PcnOSN0y8D849b0jywcXjD+ZUYzOY7v5ikofPWLze2L8hyc3d/Vh3/0eS+3Jq/mAA6/TCevTCoLr7we6+c/H4e0nuTXJFzAvnnbP0wnom6YWpQ+4VSb512tcncvYXwXg6yWer6o6qOrxY9vzufjA51ehJnrdr1bHT1ht7c8X56W1V9dXF5QyrP6LWC+eBqnphkquT3BbzwnntjF5IZpwXpg65tcYyH99wfrm2u1+R5HVJ3rr4sSWcyVxx/nlfkp9JcjDJg0n+fLFcLwyuqp6R5ONJ3tHdj5xt1TWW6YWBrNELs84LU4fcE0muOu3rK5OcnHgf7GHdfXLx90NJbsmpHy98e3E9zup1OQ/tXoXssPXG3lxxnunub3f3k939wyR/nad/9KgXBlZV+3Mq1Hyouz+xWGxeOA+t1QtzzwtTh9zbk7y4ql5UVRcmuSHJJyfeB3tUVV2yuKA8VXVJktcmuSuneuBNi9XelOQfdqdCdsF6Y//JJDdU1YGqelGSFyf58i7Uxw5ZDTULv55Tc0OiF4ZVVZXkA0nu7e73nvaUeeE8s14vzD0v7Nt6yT+uu5+oqrclOZrkgiQ3dffdU+6DPe35SW451cvZl+TD3f2Zqro9yUer6i1J/jPJb+5ijcykqj6S5NVJnlNVJ5L8cZL3ZI2x7+67q+qjSe5J8kSSt/oN6nGs0wuvrqqDOfUjxweS/G6iFwZ3bZI3JvlaVR1fLPujmBfOR+v1wo1zzgvueAYAwHDc8QwAgOEIuQAADEfIBQBgOEIuAADDEXIBABiOkAsAwHCEXAAAhiPkAgAwnP8Hlb0TWNB5/6IAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" } ], "source": [ - "preds = jnp.argmax(jnp.sum(a,axis=1), axis=1)\n", - "preds" + "plt.figure(figsize=(12,4))\n", + "plt.imshow(a[0].T, aspect=10, interpolation=\"none\")" ] }, { "cell_type": "code", - "execution_count": 54, + "execution_count": 72, "id": "23e259bc", "metadata": { "scrolled": true @@ -349,7 +315,7 @@ " 3, 6, 3, 1, 6, 4, 3, 1], dtype=int32)" ] }, - "execution_count": 54, + "execution_count": 72, "metadata": {}, "output_type": "execute_result" } @@ -360,26 +326,24 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 73, "id": "464c5cb8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "(Array(0., dtype=float32),\n", - " Array([12, 46, 12, 46, 35, 12, 12, 46, 46, 35, 46, 35, 35, 35, 35, 22, 46,\n", - " 46, 35, 12, 35, 46, 46, 35, 46, 46, 46, 12, 46, 46, 12, 46, 30, 13,\n", - " 46, 35, 12, 35, 46, 35, 12, 46, 35, 46, 44, 46, 35, 35, 35, 12, 46,\n", - " 35, 35, 46, 46, 46, 35, 35, 46, 35, 35, 35, 35, 35, 35, 46, 12, 46,\n", - " 35, 12, 35, 35, 35, 46, 35, 46, 46, 35, 12, 35, 46, 35, 46, 12, 35,\n", - " 12, 46, 35, 46, 35, 35, 35, 46, 35, 35, 35, 46, 35, 12, 35, 35, 35,\n", - " 35, 35, 46, 12, 35, 12, 46, 46, 46, 46, 35, 46, 35, 12, 46, 35, 46,\n", - " 35, 12, 12, 46, 35, 35, 46, 35, 46, 12, 35, 46, 35, 46, 35, 46, 12,\n", - " 12, 46, 46, 12], dtype=int32))" + "(Array(0.47857144, dtype=float32),\n", + " Array([1, 4, 4, 2, 5, 1, 2, 4, 4, 0, 4, 0, 5, 0, 0, 4, 4, 4, 2, 2, 0, 4,\n", + " 4, 0, 4, 2, 4, 1, 4, 4, 4, 4, 3, 0, 4, 0, 1, 0, 4, 0, 1, 4, 0, 4,\n", + " 5, 4, 0, 0, 0, 5, 4, 0, 0, 4, 4, 4, 5, 0, 4, 0, 5, 0, 0, 0, 0, 4,\n", + " 1, 4, 6, 1, 0, 0, 5, 4, 0, 4, 4, 0, 4, 5, 4, 0, 4, 1, 0, 1, 4, 0,\n", + " 4, 0, 0, 0, 4, 0, 5, 4, 2, 0, 1, 6, 0, 0, 5, 0, 4, 1, 6, 6, 4, 4,\n", + " 4, 4, 0, 4, 0, 1, 4, 0, 4, 0, 1, 1, 2, 6, 0, 4, 0, 4, 1, 6, 4, 0,\n", + " 4, 5, 4, 2, 5, 4, 4, 5], dtype=int32))" ] }, - "execution_count": 55, + "execution_count": 73, "metadata": {}, "output_type": "execute_result" } @@ -390,7 +354,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 28, "id": "5de1f45f", "metadata": {}, "outputs": [ @@ -400,7 +364,7 @@ "(280, 256, 55)" ] }, - "execution_count": 18, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } diff --git a/spyx/nir.py b/spyx/nir.py index 13cdce0..7f82c78 100644 --- a/spyx/nir.py +++ b/spyx/nir.py @@ -192,7 +192,7 @@ def _nir_node_to_spyx_node(node_pair: nir.NIRNode): if isinstance(lif_node, nir.IF): return RIF(lif_size, threshold=lif_node.v_leak) elif isinstance(lif_node, nir.LIF): - return RLIF(lif_nod.tau.shape, threshold=lif_node.v_threshold) + return RLIF(lif_node.tau.shape, threshold=lif_node.v_threshold) else: return RCuBaLIF(lif_node.tau_syn.shape, threshold=lif_node.v_threshold) @@ -327,7 +327,7 @@ def _nir_node_to_spyx_params(node_pair: nir.NIRNode, dt: float): return { "w": jnp.array(wrec_node.weight)*w_scale, "b": jnp.array(wrec_node.bias)*w_scale, - "beta": 1 - (dt / lif_node.tau_mem) + "beta": 1 - (dt / lif_node.tau) } else: # RCuBaLIF w_scale = dt / lif_node.tau_syn