Skip to content

Commit

Permalink
MS: Fix type safety warning
Browse files Browse the repository at this point in the history
  • Loading branch information
EltonCN committed Sep 20, 2024
1 parent 054b59a commit 1f52cf6
Showing 1 changed file with 149 additions and 52 deletions.
201 changes: 149 additions & 52 deletions dev/memory_storage_codelet.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -66,7 +66,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -87,25 +87,30 @@
" self._pubsub = self._client.pubsub()\n",
" self._pubsub_thread : redis.client.PubSubWorkerThread = self._pubsub.run_in_thread()\n",
"\n",
" if node_name is None:\n",
" node_number = self._client.scard(f\"{mind_name}:nodes\")\n",
" base_name = node_name\n",
" if base_name is None:\n",
" base_name = \"node\"\n",
"\n",
" node_name = f\"node{node_number}\"\n",
" \n",
" if self._client.sismember(f\"{mind_name}:nodes\", node_name):\n",
" node_number = self._client.scard(f\"{mind_name}:nodes\")\n",
" node_name = base_name+str(node_number)\n",
" while self._client.sismember(f\"{mind_name}:nodes\", node_name):\n",
" node_number += 1\n",
" node_name = f\"node{node_number}\"\n",
" node_name = base_name+str(node_number)\n",
" \n",
"\n",
" self._node_name = cast(str, node_name)\n",
"\n",
" self._client.sadd(f\"{mind_name}:nodes\", node_name)\n",
"\n",
" transfer_service_addr = f\"{self._mind_name}:nodes:{node_name}:transfer_memory\"\n",
" self._pubsub.subscribe(**{transfer_service_addr:self.transfer_memory})\n",
" self._pubsub.subscribe(**{transfer_service_addr:self._handler_transfer_memory})\n",
"\n",
" transfer_done_addr = f\"{self._mind_name}:nodes:{node_name}:transfer_done\"\n",
" self._pubsub.subscribe(**{transfer_done_addr:self.notify_transfer})\n",
" self._pubsub.subscribe(**{transfer_done_addr:self._handler_notify_transfer})\n",
"\n",
" self._last_update : dict[str, float] = {}\n",
" self._last_update : dict[str, int] = {}\n",
" self._waiting_retrieve : set[str] = set()\n",
" \n",
" self._retrieve_executor = ThreadPoolExecutor(3)\n",
Expand Down Expand Up @@ -141,13 +146,13 @@
" self._memories[memory_name] = memory\n",
"\n",
" if self._client.exists(f\"{self._mind_name}:memories:{memory_name}\"):\n",
" self._retrieve_executor.submit(self.retrieve_memory, memory)\n",
" self._retrieve_executor.submit(self._retrieve_memory, memory)\n",
" \n",
" else: #Send impostor with owner\n",
" memory_impostor = {\"name\":memory.get_name(),\n",
" \"evaluation\" : 0.0,\n",
" \"I\": \"\",\n",
" \"id\" : \"0.0\",\n",
" \"id\" : 0,\n",
" \"owner\": self._node_name}\n",
" \n",
" self._client.hset(f\"{self._mind_name}:memories:{memory_name}\", mapping=memory_impostor)\n",
Expand All @@ -166,27 +171,25 @@
" if memory.get_timestamp() > self._last_update[memory_name]:\n",
" self.update_memory(memory_name)\n",
"\n",
" def transfer_memory(self, message) -> None:\n",
" request = json.loads(message[\"data\"])\n",
" \n",
" memory_name = request[\"memory_name\"]\n",
" requesting_node = request[\"node\"]\n",
" def update_memory(self, memory_name:str) -> None:\n",
" print(self._node_name, \"Updating memory\", memory_name)\n",
"\n",
" print(self._node_name, \"Tranfering\", memory_name)\n",
" if memory_name not in self._memories:\n",
" self._pubsub.unsubscribe(f\"{self._mind_name}:memories:{memory_name}:update\")\n",
"\n",
" if memory_name in self._memories:\n",
" memory = self._memories[memory_name]\n",
" else:\n",
" memory = cst.MemoryObject()\n",
" memory.set_name(memory_name)\n",
" timestamp = float(self._client.hget(f\"{self._mind_name}:memories:{memory_name}\", \"timestamp\"))\n",
" memory = self._memories[memory_name]\n",
" memory_timestamp = memory.get_timestamp()\n",
" \n",
" self.send_memory(memory)\n",
" if memory_timestamp < timestamp:\n",
" self._retrieve_executor.submit(self._retrieve_memory, memory)\n",
"\n",
" response_addr = f\"{self._mind_name}:nodes:{requesting_node}:transfer_done\"\n",
" self._client.publish(response_addr, memory_name)\n",
" elif memory_timestamp> timestamp:\n",
" self._send_memory(memory)\n",
"\n",
" self._last_update[memory_name] = memory.get_timestamp()\n",
"\n",
" def send_memory(self, memory:Memory) -> None:\n",
" def _send_memory(self, memory:Memory) -> None:\n",
" memory_name = memory.get_name()\n",
" print(self._node_name, \"Send memory\", memory_name)\n",
" \n",
Expand All @@ -200,25 +203,8 @@
"\n",
" self._last_update[memory_name] = memory.get_timestamp()\n",
" \n",
" def update_memory(self, memory_name:str) -> None:\n",
" print(self._node_name, \"Updating memory\", memory_name)\n",
"\n",
" if memory_name not in self._memories:\n",
" self._pubsub.unsubscribe(f\"{self._mind_name}:memories:{memory_name}:update\")\n",
"\n",
" timestamp = float(self._client.hget(f\"{self._mind_name}:memories:{memory_name}\", \"timestamp\"))\n",
" memory = self._memories[memory_name]\n",
" memory_timestamp = memory.get_timestamp()\n",
" \n",
" if memory_timestamp < timestamp:\n",
" self._retrieve_executor.submit(self.retrieve_memory, memory)\n",
"\n",
" elif memory_timestamp> timestamp:\n",
" self.send_memory(memory)\n",
"\n",
" self._last_update[memory_name] = memory.get_timestamp()\n",
"\n",
" def retrieve_memory(self, memory:Memory) -> None:\n",
" def _retrieve_memory(self, memory:Memory) -> None:\n",
" memory_name = memory.get_name()\n",
"\n",
" print(self._node_name, \"Retrieve\", memory_name)\n",
Expand All @@ -232,18 +218,18 @@
" if memory_dict[\"owner\"] != \"\":\n",
" event = threading.Event()\n",
" self._waiting_request_events[memory_name] = event\n",
" self.request_memory(memory_name, memory_dict[\"owner\"])\n",
" self._request_memory(memory_name, memory_dict[\"owner\"])\n",
"\n",
" if not event.wait(timeout=self._request_timeout):\n",
" print(self._node_name, \"Request failed\", memory_name)\n",
" #Request failed\n",
" self.send_memory(memory)\n",
" self._send_memory(memory)\n",
" return \n",
" \n",
" memory_dict = self._client.hgetall(f\"{self._mind_name}:memories:{memory_name}\")\n",
"\n",
" memory.set_evaluation(float(memory_dict[\"evaluation\"]))\n",
" memory.set_id(float(memory_dict[\"id\"]))\n",
" memory.set_id(int(memory_dict[\"id\"]))\n",
"\n",
" info_json = memory_dict[\"I\"]\n",
" info = json.loads(info_json)\n",
Expand All @@ -256,7 +242,7 @@
"\n",
" self._waiting_retrieve.remove(memory_name)\n",
"\n",
" def request_memory(self, memory_name:str, owner_name:str):\n",
" def _request_memory(self, memory_name:str, owner_name:str) -> None:\n",
" print(self._node_name, \"Requesting\", memory_name)\n",
"\n",
" request_addr = f\"{self._mind_name}:nodes:{owner_name}:transfer_memory\"\n",
Expand All @@ -265,21 +251,74 @@
" request = json.dumps(request_dict)\n",
" self._client.publish(request_addr, request)\n",
"\n",
" def notify_transfer(self, message:str) -> None:\n",
" def _handler_notify_transfer(self, message:str) -> None:\n",
" memory_name = message[\"data\"]\n",
" if memory_name in self._waiting_request_events:\n",
" event = self._waiting_request_events[memory_name]\n",
" event.set()\n",
" del self._waiting_request_events[memory_name]\n",
"\n",
" def _handler_transfer_memory(self, message) -> None:\n",
" request = json.loads(message[\"data\"])\n",
" \n",
" memory_name = request[\"memory_name\"]\n",
" requesting_node = request[\"node\"]\n",
"\n",
" print(self._node_name, \"Tranfering\", memory_name)\n",
"\n",
" if memory_name in self._memories:\n",
" memory = self._memories[memory_name]\n",
" else:\n",
" memory = cst.MemoryObject()\n",
" memory.set_name(memory_name)\n",
" \n",
" self._send_memory(memory)\n",
"\n",
" response_addr = f\"{self._mind_name}:nodes:{requesting_node}:transfer_done\"\n",
" self._client.publish(response_addr, memory_name)\n",
"\n",
" def __del__(self) -> None:\n",
" self._pubsub_thread.stop()\n",
" self._retrieve_executor.shutdown(cancel_futures=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```mermaid\n",
"flowchart LR\n",
"\n",
"update[Update Memory]\n",
"send[Send Memory]\n",
"retrieve[Retrieve Memory]\n",
"request[Request Memory]\n",
"handler_notify_transfer[Handler: Notify Transfer]\n",
"handler_transfer_memory[Handler: Transfer Memory]\n",
"\n",
"\n",
"update --> |\"timestamp(MS) < timestamp(local)\"| send\n",
"update --> |\"timestamp(MS) > timestamp(local)\"| retrieve\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"handler_transfer_memory --> send\n",
"\n",
"subgraph retrieveContext\n",
"retrieve --> |owner is not empty| request\n",
"\n",
"request -.->|Wait transfer event| handler_notify_transfer\n",
"\n",
"end\n",
"\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -289,9 +328,18 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"node03 Retrieve Memory1\n",
"node03 INFO INFO \"INFO\"\n"
]
}
],
"source": [
"ms_codelet = MemoryStorageCodelet(mind, \"node0\")\n",
"ms_codelet.time_step = 100\n",
Expand All @@ -300,6 +348,55 @@
"mind.start()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"MemoryObject [idmemoryobject=1, timestamp=1726858916840, evaluation=0.0, I=INFO, name=Memory1]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"memory1"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"-1"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"node03 Updating memory Memory1\n",
"node03 Send memory Memory1\n",
"node03 Updating memory Memory1\n"
]
}
],
"source": [
"memory1.set_info(1.0)"
]
},
{
"cell_type": "code",
"execution_count": 7,
Expand Down

0 comments on commit 1f52cf6

Please sign in to comment.