Skip to content

Commit

Permalink
Move MS and MemoryEncoder to module
Browse files Browse the repository at this point in the history
  • Loading branch information
EltonCN committed Sep 27, 2024
1 parent 1f52cf6 commit b565114
Show file tree
Hide file tree
Showing 4 changed files with 258 additions and 261 deletions.
274 changes: 13 additions & 261 deletions dev/memory_storage_codelet.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,12 @@
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import weakref\n",
"import json\n",
"import threading\n",
"import time\n",
"from concurrent.futures import ThreadPoolExecutor\n",
"from typing import Optional, cast, List\n",
"\n",
"import redis\n",
"\n",
"import cst_python as cst\n",
"from cst_python.core.entities import Memory, Mind"
"from cst_python.memory_storage import MemoryStorageCodelet"
]
},
{
Expand All @@ -41,247 +35,6 @@
"client.flushall()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class MemoryEncoder(json.JSONEncoder):\n",
" def default(self, memory:cst.core.entities.Memory):\n",
" return MemoryEncoder.to_dict(memory)\n",
" \n",
" @staticmethod\n",
" def to_dict(memory:cst.core.entities.Memory):\n",
" data = {\n",
" \"timestamp\": memory.get_timestamp(),\n",
" \"evaluation\": memory.get_evaluation(),\n",
" \"I\": memory.get_info(),\n",
" \"name\": memory.get_name(),\n",
" \"id\": memory.get_id()\n",
" }\n",
"\n",
" return data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class MemoryStorageCodelet(cst.Codelet):\n",
" def __init__(self, mind:Mind, node_name:Optional[str]=None, mind_name:Optional[str]=None, request_timeout:float=500e-3) -> None:\n",
" super().__init__()\n",
" \n",
" self._mind = mind\n",
" self._request_timeout = request_timeout\n",
" \n",
" if mind_name is None:\n",
" mind_name = \"default_mind\"\n",
" self._mind_name = cast(str, mind_name)\n",
" \n",
" self._memories : weakref.WeakValueDictionary[str, Memory] = weakref.WeakValueDictionary()\n",
" \n",
" self._client = redis.Redis(decode_responses=True)\n",
" self._pubsub = self._client.pubsub()\n",
" self._pubsub_thread : redis.client.PubSubWorkerThread = self._pubsub.run_in_thread()\n",
"\n",
" base_name = node_name\n",
" if base_name is None:\n",
" base_name = \"node\"\n",
"\n",
" \n",
" if self._client.sismember(f\"{mind_name}:nodes\", node_name):\n",
" node_number = self._client.scard(f\"{mind_name}:nodes\")\n",
" node_name = base_name+str(node_number)\n",
" while self._client.sismember(f\"{mind_name}:nodes\", node_name):\n",
" node_number += 1\n",
" node_name = base_name+str(node_number)\n",
" \n",
"\n",
" self._node_name = cast(str, node_name)\n",
"\n",
" self._client.sadd(f\"{mind_name}:nodes\", node_name)\n",
"\n",
" transfer_service_addr = f\"{self._mind_name}:nodes:{node_name}:transfer_memory\"\n",
" self._pubsub.subscribe(**{transfer_service_addr:self._handler_transfer_memory})\n",
"\n",
" transfer_done_addr = f\"{self._mind_name}:nodes:{node_name}:transfer_done\"\n",
" self._pubsub.subscribe(**{transfer_done_addr:self._handler_notify_transfer})\n",
"\n",
" self._last_update : dict[str, int] = {}\n",
" self._waiting_retrieve : set[str] = set()\n",
" \n",
" self._retrieve_executor = ThreadPoolExecutor(3)\n",
"\n",
" self._waiting_request_events : dict[str, threading.Event] = {}\n",
"\n",
" self._request = None\n",
"\n",
" def calculate_activation(self) -> None:\n",
" pass\n",
"\n",
" def access_memory_objects(self) -> None:\n",
" pass\n",
"\n",
" def proc(self) -> None:\n",
" \n",
" #Check new memories\n",
"\n",
" mind_memories = {}\n",
" for memory in self._mind.raw_memory.all_memories:\n",
" if memory.get_name() == \"\": #No name -> No MS\n",
" continue\n",
"\n",
" mind_memories[memory.get_name()] = memory\n",
"\n",
" mind_memories_names = set(mind_memories.keys())\n",
" memories_names = set(self._memories.keys())\n",
"\n",
" #Check only not here (memories_names not in mind should be garbage collected)\n",
" difference = mind_memories_names - memories_names\n",
" for memory_name in difference:\n",
" memory : Memory = mind_memories[memory_name]\n",
" self._memories[memory_name] = memory\n",
"\n",
" if self._client.exists(f\"{self._mind_name}:memories:{memory_name}\"):\n",
" self._retrieve_executor.submit(self._retrieve_memory, memory)\n",
" \n",
" else: #Send impostor with owner\n",
" memory_impostor = {\"name\":memory.get_name(),\n",
" \"evaluation\" : 0.0,\n",
" \"I\": \"\",\n",
" \"id\" : 0,\n",
" \"owner\": self._node_name}\n",
" \n",
" self._client.hset(f\"{self._mind_name}:memories:{memory_name}\", mapping=memory_impostor)\n",
"\n",
" subscribe_func = lambda message : self.update_memory(memory_name)\n",
" self._pubsub.subscribe(**{f\"{self._mind_name}:memories:{memory_name}:update\":subscribe_func})\n",
"\n",
" #Update memories\n",
" to_update = self._last_update.keys()\n",
" for memory_name in to_update:\n",
" if memory_name not in self._memories:\n",
" del self._last_update[memory_name]\n",
" continue\n",
"\n",
" memory = self._memories[memory_name]\n",
" if memory.get_timestamp() > self._last_update[memory_name]:\n",
" self.update_memory(memory_name)\n",
"\n",
" def update_memory(self, memory_name:str) -> None:\n",
" print(self._node_name, \"Updating memory\", memory_name)\n",
"\n",
" if memory_name not in self._memories:\n",
" self._pubsub.unsubscribe(f\"{self._mind_name}:memories:{memory_name}:update\")\n",
"\n",
" timestamp = float(self._client.hget(f\"{self._mind_name}:memories:{memory_name}\", \"timestamp\"))\n",
" memory = self._memories[memory_name]\n",
" memory_timestamp = memory.get_timestamp()\n",
" \n",
" if memory_timestamp < timestamp:\n",
" self._retrieve_executor.submit(self._retrieve_memory, memory)\n",
"\n",
" elif memory_timestamp> timestamp:\n",
" self._send_memory(memory)\n",
"\n",
" self._last_update[memory_name] = memory.get_timestamp()\n",
"\n",
" def _send_memory(self, memory:Memory) -> None:\n",
" memory_name = memory.get_name()\n",
" print(self._node_name, \"Send memory\", memory_name)\n",
" \n",
" memory_dict = MemoryEncoder.to_dict(memory)\n",
" memory_dict[\"I\"] = json.dumps(memory_dict[\"I\"])\n",
" memory_dict[\"owner\"] = \"\"\n",
"\n",
"\n",
" self._client.hset(f\"{self._mind_name}:memories:{memory_name}\", mapping=memory_dict)\n",
" self._client.publish(f\"{self._mind_name}:memories:{memory_name}:update\", \"\")\n",
"\n",
" self._last_update[memory_name] = memory.get_timestamp()\n",
" \n",
"\n",
" def _retrieve_memory(self, memory:Memory) -> None:\n",
" memory_name = memory.get_name()\n",
"\n",
" print(self._node_name, \"Retrieve\", memory_name)\n",
"\n",
" if memory_name in self._waiting_retrieve:\n",
" return\n",
" self._waiting_retrieve.add(memory_name)\n",
"\n",
" memory_dict = self._client.hgetall(f\"{self._mind_name}:memories:{memory_name}\")\n",
"\n",
" if memory_dict[\"owner\"] != \"\":\n",
" event = threading.Event()\n",
" self._waiting_request_events[memory_name] = event\n",
" self._request_memory(memory_name, memory_dict[\"owner\"])\n",
"\n",
" if not event.wait(timeout=self._request_timeout):\n",
" print(self._node_name, \"Request failed\", memory_name)\n",
" #Request failed\n",
" self._send_memory(memory)\n",
" return \n",
" \n",
" memory_dict = self._client.hgetall(f\"{self._mind_name}:memories:{memory_name}\")\n",
"\n",
" memory.set_evaluation(float(memory_dict[\"evaluation\"]))\n",
" memory.set_id(int(memory_dict[\"id\"]))\n",
"\n",
" info_json = memory_dict[\"I\"]\n",
" info = json.loads(info_json)\n",
"\n",
" print(self._node_name, \"INFO\", info, info_json)\n",
"\n",
" memory.set_info(info)\n",
"\n",
" self._last_update[memory_name] = memory.get_timestamp()\n",
"\n",
" self._waiting_retrieve.remove(memory_name)\n",
"\n",
" def _request_memory(self, memory_name:str, owner_name:str) -> None:\n",
" print(self._node_name, \"Requesting\", memory_name)\n",
"\n",
" request_addr = f\"{self._mind_name}:nodes:{owner_name}:transfer_memory\"\n",
" \n",
" request_dict = {\"memory_name\":memory_name, \"node\":self._node_name}\n",
" request = json.dumps(request_dict)\n",
" self._client.publish(request_addr, request)\n",
"\n",
" def _handler_notify_transfer(self, message:str) -> None:\n",
" memory_name = message[\"data\"]\n",
" if memory_name in self._waiting_request_events:\n",
" event = self._waiting_request_events[memory_name]\n",
" event.set()\n",
" del self._waiting_request_events[memory_name]\n",
"\n",
" def _handler_transfer_memory(self, message) -> None:\n",
" request = json.loads(message[\"data\"])\n",
" \n",
" memory_name = request[\"memory_name\"]\n",
" requesting_node = request[\"node\"]\n",
"\n",
" print(self._node_name, \"Tranfering\", memory_name)\n",
"\n",
" if memory_name in self._memories:\n",
" memory = self._memories[memory_name]\n",
" else:\n",
" memory = cst.MemoryObject()\n",
" memory.set_name(memory_name)\n",
" \n",
" self._send_memory(memory)\n",
"\n",
" response_addr = f\"{self._mind_name}:nodes:{requesting_node}:transfer_done\"\n",
" self._client.publish(response_addr, memory_name)\n",
"\n",
" def __del__(self) -> None:\n",
" self._pubsub_thread.stop()\n",
" self._retrieve_executor.shutdown(cancel_futures=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -318,7 +71,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -328,15 +81,14 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"node03 Retrieve Memory1\n",
"node03 INFO INFO \"INFO\"\n"
"node03 Retrieve Memory1\n"
]
}
],
Expand All @@ -350,16 +102,16 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"MemoryObject [idmemoryobject=1, timestamp=1726858916840, evaluation=0.0, I=INFO, name=Memory1]"
"MemoryObject [idmemoryobject=1, timestamp=1727456263799, evaluation=0.0, I=[1, 1, 1], name=Memory1]"
]
},
"execution_count": 6,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -370,7 +122,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 6,
"metadata": {},
"outputs": [
{
Expand All @@ -379,22 +131,22 @@
"-1"
]
},
"execution_count": 9,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"node03 Updating memory Memory1\n",
"node03 Send memory Memory1\n",
"node03 Updating memory Memory1\n"
"node02 Updating memory Memory1\n",
"node02 Send memory Memory1\n",
"node02 Updating memory Memory1\n"
]
}
],
"source": [
"memory1.set_info(1.0)"
"memory1.set_info([1,1,1])"
]
},
{
Expand Down
1 change: 1 addition & 0 deletions src/cst_python/memory_storage/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .memory_storage import MemoryStorageCodelet
32 changes: 32 additions & 0 deletions src/cst_python/memory_storage/memory_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import json
from typing import Any

from cst_python.core.entities import Memory

class MemoryEncoder(json.JSONEncoder):
def default(self, memory:Memory):
return MemoryEncoder.to_dict(memory)

@staticmethod
def to_dict(memory:Memory, jsonify_info:bool=False):
data = {
"timestamp": memory.get_timestamp(),
"evaluation": memory.get_evaluation(),
"I": memory.get_info(),
"name": memory.get_name(),
"id": memory.get_id()
}

if jsonify_info:
data["I"] = json.dumps(data["I"])

return data

def load_memory(memory:Memory, memory_dict:dict[str,Any], load_json:bool=True):
memory.set_evaluation(float(memory_dict["evaluation"]))
memory.set_id(int(memory_dict["id"]))

info_json = memory_dict["I"]
info = json.loads(info_json)

memory.set_info(info)
Loading

0 comments on commit b565114

Please sign in to comment.