diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000..84a569b5af --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,51 @@ +# Contributing to mem0 + +Let's make contributing easy, collaborative, and fun. + +## Submit your Contribution through a PR + +To make a contribution, follow these steps: + +1. Fork and clone this repository +2. Make your changes on your fork in a dedicated feature branch, e.g. `feature/f1` +3. If you modify the code (new feature or bug fix), please add tests for it +4. Include proper documentation/docstrings and examples showing how to run the feature +5. Ensure that all tests pass +6. Submit a pull request + +For more details about pull requests, please read [GitHub's guides](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request). + + +### 📦 Package manager + +We use `poetry` as our package manager. You can install poetry by following the instructions [here](https://python-poetry.org/docs/#installation). + +Please DO NOT use pip or conda to install the dependencies. Instead, use poetry: + +```bash +make install_all + +# Activate the virtual environment + +poetry shell +``` + +### 📌 Pre-commit + +To enforce our coding standards, install the pre-commit hooks before you start contributing. + +```bash +pre-commit install +``` + +### 🧪 Testing + +We use `pytest` to test our code. You can run the tests with the following command: + +```bash +poetry run pytest +``` + +Several packages have been removed from the Poetry dependencies to make the package lighter. It is therefore recommended to run `make install_all` to install those extra packages and ensure that all tests pass before submitting a pull request. + +We look forward to your pull requests and can't wait to see your contributions! \ No newline at end of file diff --git a/Makefile b/Makefile index 032be34648..965a719301 100644 --- a/Makefile +++ b/Makefile @@ -12,19 +12,19 @@ install: install_all: poetry install - poetry run pip install groq together boto3 litellm ollama + poetry run pip install groq together boto3 litellm ollama chromadb sentence_transformers # Format code with ruff format: - poetry run ruff check . --fix $(RUFF_OPTIONS) + poetry run ruff format mem0/ # Sort imports with isort sort: - poetry run isort . $(ISORT_OPTIONS) + poetry run isort mem0/ # Lint code with ruff lint: - poetry run ruff . + poetry run ruff check mem0/ docs: cd docs && mintlify dev diff --git a/README.md b/README.md index 0364e1ec52..13f27727d1 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,9 @@ Mem0 - The Memory Layer for Personalized AI +

[Launch badge — HTML markup stripped]

Learn more · @@ -157,7 +160,7 @@ history = m.history(memory_id=<memory_id>) ### Graph Memory To initialize Graph Memory you'll need to set up your configuration with graph store providers. -Currently, we support Neo4j as a graph store provider. You can setup [Neo4j](https://neo4j.com/) locally or use the hosted [Neo4j AuraDB](https://neo4j.com/product/auradb/). +Currently, we support FalkorDB and Neo4j as graph store providers. You can set up [FalkorDB](https://www.falkordb.com/) or [Neo4j](https://neo4j.com/) locally or use the hosted [FalkorDB Cloud](https://app.falkordb.cloud/) or [Neo4j AuraDB](https://neo4j.com/product/auradb/). Moreover, you also need to set the version to `v1.1` (*prior versions are not supported*). Here's how you can do it: @@ -166,11 +169,12 @@ from mem0 import Memory config = { "graph_store": { - "provider": "neo4j", + "provider": "falkordb", "config": { - "url": "neo4j+s://xxx", - "username": "neo4j", - "password": "xxx" + "host": "---", + "username": "---", + "password": "---", + "port": "---" } }, "version": "v1.1" @@ -208,4 +212,4 @@ We value and appreciate the contributions of our community. Special thanks to ou ## License -This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details. \ No newline at end of file +This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details. diff --git a/cookbooks/add_memory_using_qdrant_cloud.py b/cookbooks/add_memory_using_qdrant_cloud.py index d714275224..0ca02e52df 100644 --- a/cookbooks/add_memory_using_qdrant_cloud.py +++ b/cookbooks/add_memory_using_qdrant_cloud.py @@ -7,27 +7,21 @@ # Loading OpenAI API Key load_dotenv() -OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY') +OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") USER_ID = "test" -quadrant_host="xx.gcp.cloud.qdrant.io" +qdrant_host = "xx.gcp.cloud.qdrant.io" # creating the config attributes -collection_name="memory" # this is the collection I created in QDRANT cloud -api_key=os.environ.get("QDRANT_API_KEY") # Getting the QDRANT api KEY -host=quadrant_host -port=6333 #Default port for QDRANT cloud +collection_name = "memory" # the collection created in Qdrant Cloud +api_key = os.environ.get("QDRANT_API_KEY") # Getting the Qdrant API key +host = qdrant_host +port = 6333 # Default port for Qdrant Cloud # Creating the config dict config = { "vector_store": { "provider": "qdrant", - "config": { - "collection_name": collection_name, - "host": host, - "port": port, - "path": None, - "api_key":api_key - } + "config": {"collection_name": collection_name, "host": host, "port": port, "path": None, "api_key": api_key}, } } diff --git a/cookbooks/mem0-multion.ipynb b/cookbooks/mem0-multion.ipynb index 3cd3fc97d8..98e304568f 100644 --- a/cookbooks/mem0-multion.ipynb +++ b/cookbooks/mem0-multion.ipynb @@ -1,189 +1,189 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "y4bKPPa7DXNs" - }, - "outputs": [], - "source": [ - "%pip install mem0ai multion" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pe4htqUmDdmS" - }, - "source": [ - "## Setup and Configuration\n", - "\n", - "First, we'll import the necessary libraries and set up our configurations.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "fsZwK7eLDh3I" - }, - "outputs": [], - "source": [ - "import os\n", - "from mem0 import Memory\n", - "from multion.client import MultiOn\n", - "\n", - "# Configuration\n", - "OPENAI_API_KEY = 
'sk-xxx' # Replace with your actual OpenAI API key\n", - "MULTION_API_KEY = 'your-multion-key' # Replace with your actual MultiOn API key\n", - "USER_ID = \"deshraj\"\n", - "\n", - "# Set up OpenAI API key\n", - "os.environ['OPENAI_API_KEY'] = OPENAI_API_KEY\n", - "\n", - "# Initialize Mem0 and MultiOn\n", - "memory = Memory()\n", - "multion = MultiOn(api_key=MULTION_API_KEY)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HTGVhGwaDl-1" - }, - "source": [ - "## Add memories to Mem0\n", - "\n", - "Next, we'll define our user data and add it to Mem0." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "xB3tm0_pDm6e", - "outputId": "aeab370c-8679-4d39-faaa-f702146d2fc4" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "User data added to memory.\n" - ] - } - ], - "source": [ - "# Define user data\n", - "USER_DATA = \"\"\"\n", - "About me\n", - "- I'm Deshraj Yadav, Co-founder and CTO at Mem0 (f.k.a Embedchain). I am broadly interested in the field of Artificial Intelligence and Machine Learning Infrastructure.\n", - "- Previously, I was Senior Autopilot Engineer at Tesla Autopilot where I led the Autopilot's AI Platform which helped the Tesla Autopilot team to track large scale training and model evaluation experiments, provide monitoring and observability into jobs and training cluster issues.\n", - "- I had built EvalAI as my masters thesis at Georgia Tech, which is an open-source platform for evaluating and comparing machine learning and artificial intelligence algorithms at scale.\n", - "- Outside of work, I am very much into cricket and play in two leagues (Cricbay and NACL) in San Francisco Bay Area.\n", - "\"\"\"\n", - "\n", - "# Add user data to memory\n", - "memory.add(USER_DATA, user_id=USER_ID)\n", - "print(\"User data added to memory.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZCPUJf0TDqUK" - }, - "source": [ - "## Retrieving Relevant Memories\n", - "\n", - "Now, we'll define our search command and retrieve relevant memories from Mem0." 
- ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "s0PwAhNVDrIv", - "outputId": "59cbb767-b468-4139-8d0c-fa763918dbb0" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Relevant memories:\n", - "Name: Deshraj Yadav - Co-founder and CTO at Mem0 (formerly known as Embedchain) - Interested in Artificial Intelligence and Machine Learning Infrastructure - Previous role: Senior Autopilot Engineer at Tesla Autopilot - Led the Autopilot's AI Platform at Tesla, focusing on large scale training, model evaluation, monitoring, and observability - Built EvalAI as a master's thesis at Georgia Tech, an open-source platform for evaluating and comparing machine learning algorithms - Enjoys cricket - Plays in two cricket leagues: Cricbay and NACL in the San Francisco Bay Area\n" - ] - } - ], - "source": [ - "# Define search command and retrieve relevant memories\n", - "command = \"Find papers on arxiv that I should read based on my interests.\"\n", - "\n", - "relevant_memories = memory.search(command, user_id=USER_ID, limit=3)\n", - "relevant_memories_text = '\\n'.join(mem['memory'] for mem in relevant_memories)\n", - "print(f\"Relevant memories:\")\n", - "print(relevant_memories_text)" - ] + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "y4bKPPa7DXNs" + }, + "outputs": [], + "source": [ + "%pip install mem0ai multion" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pe4htqUmDdmS" + }, + "source": [ + "## Setup and Configuration\n", + "\n", + "First, we'll import the necessary libraries and set up our configurations.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "fsZwK7eLDh3I" + }, + "outputs": [], + "source": [ + "import os\n", + "from mem0 import Memory\n", + "from multion.client import MultiOn\n", + "\n", + "# Configuration\n", + "OPENAI_API_KEY = \"sk-xxx\" # Replace with your actual OpenAI API key\n", + "MULTION_API_KEY = \"your-multion-key\" # Replace with your actual MultiOn API key\n", + "USER_ID = \"deshraj\"\n", + "\n", + "# Set up OpenAI API key\n", + "os.environ[\"OPENAI_API_KEY\"] = OPENAI_API_KEY\n", + "\n", + "# Initialize Mem0 and MultiOn\n", + "memory = Memory()\n", + "multion = MultiOn(api_key=MULTION_API_KEY)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HTGVhGwaDl-1" + }, + "source": [ + "## Add memories to Mem0\n", + "\n", + "Next, we'll define our user data and add it to Mem0." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "xB3tm0_pDm6e", + "outputId": "aeab370c-8679-4d39-faaa-f702146d2fc4" + }, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "jdge78_VDtgv" - }, - "source": [ - "## Browsing arXiv\n", - "\n", - "Finally, we'll use MultiOn to browse arXiv based on our command and relevant memories." - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "User data added to memory.\n" + ] + } + ], + "source": [ + "# Define user data\n", + "USER_DATA = \"\"\"\n", + "About me\n", + "- I'm Deshraj Yadav, Co-founder and CTO at Mem0 (f.k.a Embedchain). 
I am broadly interested in the field of Artificial Intelligence and Machine Learning Infrastructure.\n", + "- Previously, I was Senior Autopilot Engineer at Tesla Autopilot where I led the Autopilot's AI Platform which helped the Tesla Autopilot team to track large scale training and model evaluation experiments, provide monitoring and observability into jobs and training cluster issues.\n", + "- I had built EvalAI as my masters thesis at Georgia Tech, which is an open-source platform for evaluating and comparing machine learning and artificial intelligence algorithms at scale.\n", + "- Outside of work, I am very much into cricket and play in two leagues (Cricbay and NACL) in San Francisco Bay Area.\n", + "\"\"\"\n", + "\n", + "# Add user data to memory\n", + "memory.add(USER_DATA, user_id=USER_ID)\n", + "print(\"User data added to memory.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZCPUJf0TDqUK" + }, + "source": [ + "## Retrieving Relevant Memories\n", + "\n", + "Now, we'll define our search command and retrieve relevant memories from Mem0." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "s0PwAhNVDrIv", + "outputId": "59cbb767-b468-4139-8d0c-fa763918dbb0" + }, + "outputs": [ { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "4T_tLURTDvS-", - "outputId": "259ff32f-5d42-44e6-f2ef-c3557a8e9da6" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "message=\"Summarizing the relevant papers found so far that align with Deshraj Yadav's interests in Artificial Intelligence and Machine Learning Infrastructure.\\n\\n1. **Urban Waterlogging Detection: A Challenging Benchmark and Large-Small Model Co-Adapter**\\n - Authors: Suqi Song, Chenxu Zhang, Peng Zhang, Pengkun Li, Fenglong Song, Lei Zhang\\n - Abstract: Urban waterlogging poses a major risk to public safety. Conventional methods using water-level sensors need high-maintenance to hardly achieve full coverage. Recent advances employ surveillance camera imagery and deep learning for detection, yet these struggle amidst scarce data and adverse environments.\\n - Date: 10 July, 2024\\n\\n2. **Intercepting Unauthorized Aerial Robots in Controlled Airspace Using Reinforcement Learning**\\n - Authors: Francisco Giral, Ignacio Gómez, Soledad Le Clainche\\n - Abstract: Ensuring the safe and efficient operation of airspace, particularly in urban environments and near critical infrastructure, necessitates effective methods to intercept unauthorized or non-cooperative UAVs. This work addresses the critical need for robust, adaptive systems capable of managing such scenarios.\\n - Date: 9 July, 2024\\n\\n3. **Efficient Materials Informatics between Rockets and Electrons**\\n - Authors: Adam M. Krajewski\\n - Abstract: This paper discusses the distinct efforts existing at three general scales of abstractions of what a material is - atomistic, physical, and design. At each, an efficient materials informatics is being built from the ground up based on the fundamental understanding of the underlying prior knowledge, including the data.\\n - Date: 5 July, 2024\\n\\n4. 
**ObfuscaTune: Obfuscated Offsite Fine-tuning and Inference of Proprietary LLMs on Private Datasets**\\n - Authors: Ahmed Frikha, Nassim Walha, Ricardo Mendes, Krishna Kanth Nakka, Xue Jiang, Xuebing Zhou\\n - Abstract: This paper proposes ObfuscaTune, a novel, efficient, and fully utility-preserving approach that combines a simple yet effective method to ensure the confidentiality of both the model and the data during offsite fine-tuning on a third-party cloud provider.\\n - Date: 3 July, 2024\\n\\n5. **MG-Verilog: Multi-grained Dataset Towards Enhanced LLM-assisted Verilog Generation**\\n - Authors: Yongan Zhang, Zhongzhi Yu, Yonggan Fu, Cheng Wan, Yingyan Celine Lin\\n - Abstract: This paper discusses the necessity of providing domain-specific data during inference, fine-tuning, or pre-training to effectively leverage LLMs in hardware design. Existing publicly available hardware datasets are often limited in size, complexity, or detail, which hinders the effectiveness of LLMs in this domain.\\n - Date: 1 July, 2024\\n\\n6. **The Future of Aerial Communications: A Survey of IRS-Enhanced UAV Communication Technologies**\\n - Authors: Zina Chkirbene, Ala Gouissem, Ridha Hamila, Devrim Unal\\n - Abstract: The advent of Reflecting Surfaces (IRS) and Unmanned Aerial Vehicles (UAVs) is setting a new benchmark in the field of wireless communications. IRS, with their groundbreaking ability to manipulate electromagnetic waves, have opened avenues for substantial enhancements in signal quality, network efficiency, and spectral usage.\\n - Date: 2 June, 2024\\n\\n7. **Scalable and RISC-V Programmable Near-Memory Computing Architectures for Edge Nodes**\\n - Authors: Michele Caon, Clément Choné, Pasquale Davide Schiavone, Alexandre Levisse, Guido Masera, Maurizio Martina, David Atienza\\n - Abstract: The widespread adoption of data-centric algorithms, particularly AI and ML, has exposed the limitations of centralized processing, driving the need for scalable and programmable near-memory computing architectures for edge nodes.\\n - Date: 20 June, 2024\\n\\n8. **Enhancing robustness of data-driven SHM models: adversarial training with circle loss**\\n - Authors: Xiangli Yang, Xijie Deng, Hanwei Zhang, Yang Zou, Jianxi Yang\\n - Abstract: Structural health monitoring (SHM) is critical to safeguarding the safety and reliability of aerospace, civil, and mechanical infrastructures. This paper discusses the use of adversarial training with circle loss to enhance the robustness of data-driven SHM models.\\n - Date: 20 June, 2024\\n\\n9. **Understanding Pedestrian Movement Using Urban Sensing Technologies: The Promise of Audio-based Sensors**\\n - Authors: Chaeyeon Han, Pavan Seshadri, Yiwei Ding, Noah Posner, Bon Woo Koo, Animesh Agrawal, Alexander Lerch, Subhrajit Guhathakurta\\n - Abstract: Understanding pedestrian volumes and flows is essential for designing safer and more attractive pedestrian infrastructures. This study discusses a new approach to scale up urban sensing of people with the help of novel audio-based technology.\\n - Date: 14 June, 2024\\n\\nASK_USER_HELP: Deshraj, I have found several papers that might be of interest to you. 
Would you like to proceed with any specific papers from the list above, or should I refine the search further?\\n\" status='NOT_SURE' url='https://arxiv.org/search/?query=Artificial+Intelligence+Machine+Learning+Infrastructure&searchtype=all&source=header' screenshot='' session_id='ff2ee9ef-60d4-4436-bc36-a81d94e0f410' metadata=Metadata(step_count=9, processing_time=66, temperature=0.2)\n" - ] - } - ], - "source": [ - "# Create prompt and browse arXiv\n", - "prompt = f\"{command}\\n My past memories: {relevant_memories_text}\"\n", - "browse_result = multion.browse(cmd=prompt, url=\"https://arxiv.org/\")\n", - "print(browse_result)" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "Relevant memories:\n", + "Name: Deshraj Yadav - Co-founder and CTO at Mem0 (formerly known as Embedchain) - Interested in Artificial Intelligence and Machine Learning Infrastructure - Previous role: Senior Autopilot Engineer at Tesla Autopilot - Led the Autopilot's AI Platform at Tesla, focusing on large scale training, model evaluation, monitoring, and observability - Built EvalAI as a master's thesis at Georgia Tech, an open-source platform for evaluating and comparing machine learning algorithms - Enjoys cricket - Plays in two cricket leagues: Cricbay and NACL in the San Francisco Bay Area\n" + ] } - ], - "metadata": { + ], + "source": [ + "# Define search command and retrieve relevant memories\n", + "command = \"Find papers on arxiv that I should read based on my interests.\"\n", + "\n", + "relevant_memories = memory.search(command, user_id=USER_ID, limit=3)\n", + "relevant_memories_text = \"\\n\".join(mem[\"memory\"] for mem in relevant_memories)\n", + "print(f\"Relevant memories:\")\n", + "print(relevant_memories_text)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jdge78_VDtgv" + }, + "source": [ + "## Browsing arXiv\n", + "\n", + "Finally, we'll use MultiOn to browse arXiv based on our command and relevant memories." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { "colab": { - "provenance": [] + "base_uri": "https://localhost:8080/" }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" + "id": "4T_tLURTDvS-", + "outputId": "259ff32f-5d42-44e6-f2ef-c3557a8e9da6" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "message=\"Summarizing the relevant papers found so far that align with Deshraj Yadav's interests in Artificial Intelligence and Machine Learning Infrastructure.\\n\\n1. **Urban Waterlogging Detection: A Challenging Benchmark and Large-Small Model Co-Adapter**\\n - Authors: Suqi Song, Chenxu Zhang, Peng Zhang, Pengkun Li, Fenglong Song, Lei Zhang\\n - Abstract: Urban waterlogging poses a major risk to public safety. Conventional methods using water-level sensors need high-maintenance to hardly achieve full coverage. Recent advances employ surveillance camera imagery and deep learning for detection, yet these struggle amidst scarce data and adverse environments.\\n - Date: 10 July, 2024\\n\\n2. **Intercepting Unauthorized Aerial Robots in Controlled Airspace Using Reinforcement Learning**\\n - Authors: Francisco Giral, Ignacio Gómez, Soledad Le Clainche\\n - Abstract: Ensuring the safe and efficient operation of airspace, particularly in urban environments and near critical infrastructure, necessitates effective methods to intercept unauthorized or non-cooperative UAVs. 
This work addresses the critical need for robust, adaptive systems capable of managing such scenarios.\\n - Date: 9 July, 2024\\n\\n3. **Efficient Materials Informatics between Rockets and Electrons**\\n - Authors: Adam M. Krajewski\\n - Abstract: This paper discusses the distinct efforts existing at three general scales of abstractions of what a material is - atomistic, physical, and design. At each, an efficient materials informatics is being built from the ground up based on the fundamental understanding of the underlying prior knowledge, including the data.\\n - Date: 5 July, 2024\\n\\n4. **ObfuscaTune: Obfuscated Offsite Fine-tuning and Inference of Proprietary LLMs on Private Datasets**\\n - Authors: Ahmed Frikha, Nassim Walha, Ricardo Mendes, Krishna Kanth Nakka, Xue Jiang, Xuebing Zhou\\n - Abstract: This paper proposes ObfuscaTune, a novel, efficient, and fully utility-preserving approach that combines a simple yet effective method to ensure the confidentiality of both the model and the data during offsite fine-tuning on a third-party cloud provider.\\n - Date: 3 July, 2024\\n\\n5. **MG-Verilog: Multi-grained Dataset Towards Enhanced LLM-assisted Verilog Generation**\\n - Authors: Yongan Zhang, Zhongzhi Yu, Yonggan Fu, Cheng Wan, Yingyan Celine Lin\\n - Abstract: This paper discusses the necessity of providing domain-specific data during inference, fine-tuning, or pre-training to effectively leverage LLMs in hardware design. Existing publicly available hardware datasets are often limited in size, complexity, or detail, which hinders the effectiveness of LLMs in this domain.\\n - Date: 1 July, 2024\\n\\n6. **The Future of Aerial Communications: A Survey of IRS-Enhanced UAV Communication Technologies**\\n - Authors: Zina Chkirbene, Ala Gouissem, Ridha Hamila, Devrim Unal\\n - Abstract: The advent of Reflecting Surfaces (IRS) and Unmanned Aerial Vehicles (UAVs) is setting a new benchmark in the field of wireless communications. IRS, with their groundbreaking ability to manipulate electromagnetic waves, have opened avenues for substantial enhancements in signal quality, network efficiency, and spectral usage.\\n - Date: 2 June, 2024\\n\\n7. **Scalable and RISC-V Programmable Near-Memory Computing Architectures for Edge Nodes**\\n - Authors: Michele Caon, Clément Choné, Pasquale Davide Schiavone, Alexandre Levisse, Guido Masera, Maurizio Martina, David Atienza\\n - Abstract: The widespread adoption of data-centric algorithms, particularly AI and ML, has exposed the limitations of centralized processing, driving the need for scalable and programmable near-memory computing architectures for edge nodes.\\n - Date: 20 June, 2024\\n\\n8. **Enhancing robustness of data-driven SHM models: adversarial training with circle loss**\\n - Authors: Xiangli Yang, Xijie Deng, Hanwei Zhang, Yang Zou, Jianxi Yang\\n - Abstract: Structural health monitoring (SHM) is critical to safeguarding the safety and reliability of aerospace, civil, and mechanical infrastructures. This paper discusses the use of adversarial training with circle loss to enhance the robustness of data-driven SHM models.\\n - Date: 20 June, 2024\\n\\n9. **Understanding Pedestrian Movement Using Urban Sensing Technologies: The Promise of Audio-based Sensors**\\n - Authors: Chaeyeon Han, Pavan Seshadri, Yiwei Ding, Noah Posner, Bon Woo Koo, Animesh Agrawal, Alexander Lerch, Subhrajit Guhathakurta\\n - Abstract: Understanding pedestrian volumes and flows is essential for designing safer and more attractive pedestrian infrastructures. 
This study discusses a new approach to scale up urban sensing of people with the help of novel audio-based technology.\n - Date: 14 June, 2024\\n\\nASK_USER_HELP: Deshraj, I have found several papers that might be of interest to you. Would you like to proceed with any specific papers from the list above, or should I refine the search further?\\n\" status='NOT_SURE' url='https://arxiv.org/search/?query=Artificial+Intelligence+Machine+Learning+Infrastructure&searchtype=all&source=header' screenshot='' session_id='ff2ee9ef-60d4-4436-bc36-a81d94e0f410' metadata=Metadata(step_count=9, processing_time=66, temperature=0.2)\n" ] } ], "source": [ "# Create prompt and browse arXiv\n", "prompt = f\"{command}\\n My past memories: {relevant_memories_text}\"\n", "browse_result = multion.browse(cmd=prompt, url=\"https://arxiv.org/\")\n", "print(browse_result)" ] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 } diff --git a/cookbooks/mem0_graph_memory.py b/cookbooks/mem0_graph_memory.py new file mode 100644 index 0000000000..b29547d47a --- /dev/null +++ b/cookbooks/mem0_graph_memory.py @@ -0,0 +1,40 @@ +# This example shows how to use the graph config with the FalkorDB graph database +import os +from mem0 import Memory +from dotenv import load_dotenv + +# Loading OpenAI API Key +load_dotenv() +OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY') +USER_ID = "test" + +# Creating the config dict from the environment variables +config = { + "llm": { # This is the language model configuration, use your credentials + "provider": "openai", + "config": { + "model": "gpt-4o-mini", + "temperature": 0 + } + }, + "graph_store": { # See https://app.falkordb.cloud/ for the credentials + "provider": "falkordb", + "config": { + "host": os.environ['HOST'], + "username": os.environ['USERNAME'], + "password": os.environ['PASSWORD'], + "port": os.environ['PORT'] + } + }, + "version": "v1.1" +} + +# Create the Memory instance from the config +memory = Memory.from_config(config_dict=config) + +# Use Mem0 to add and search memories +memory.add("I like painting", user_id=USER_ID) +memory.add("I hate playing badminton", user_id=USER_ID) +print(memory.get_all(user_id=USER_ID)) +memory.add("My friend's name is John and John has a dog named Tommy", user_id=USER_ID) +print(memory.search("What do I like to do?", user_id=USER_ID)) diff --git a/cookbooks/multion_travel_agent.ipynb b/cookbooks/multion_travel_agent.ipynb index 196337077f..f9211da1b5 100644 --- a/cookbooks/multion_travel_agent.ipynb +++ b/cookbooks/multion_travel_agent.ipynb @@ -1,306 +1,296 @@ { - "cells": [ - { - "cell_type": "code", - "source": [ - "!pip install mem0ai" - ], - "metadata": { - "id": "fu3euPKZsbaC" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "U2VC_0FElQid" - }, - "outputs": [], - "source": [ - "import os\n", - "from openai import OpenAI\n", - "from mem0 import MemoryClient\n", - "from multion.client import MultiOn\n", - "\n", - "# Configuration\n", - "OPENAI_API_KEY = 'sk-xxx' # Replace with your actual OpenAI API key\n", - "MULTION_API_KEY = 'xx' # Replace with your actual MultiOn API key\n", - "MEM0_API_KEY = 'xx' # Replace with your actual Mem0 API key\n", - "USER_ID = \"test_travel_agent\"\n", - "\n", - "# Set up OpenAI API key\n", - 
"os.environ['OPENAI_API_KEY'] = OPENAI_API_KEY\n", - "\n", - "# Initialize Mem0 and MultiOn\n", - "memory = MemoryClient(api_key=MEM0_API_KEY)\n", - "multion = MultiOn(api_key=MULTION_API_KEY)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "sq-OdPHKlQie", - "outputId": "1d605222-0bf5-4ac9-99b9-6059b502c20b" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "{'message': 'Memory added successfully!'}" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Add conversation to Mem0\n", - "conversation = [\n", - " {\n", - " \"role\": \"user\",\n", - " \"content\": \"What are the best travel destinations in the world?\"\n", - " },\n", - " {\n", - " \"role\": \"assistant\",\n", - " \"content\": \"Could you please specify your interests or the type of travel information you are looking for? This will help me find the most relevant information for you.\"\n", - " },\n", - " {\n", - " \"role\": \"user\",\n", - " \"content\": \"Sure, I want to travel to San Francisco.\"\n", - " },\n", - " {\n", - " \"role\": \"assistant\",\n", - " \"content\": \"\"\"\n", - " Based on the information gathered from TripAdvisor, here are some popular attractions, activities, and travel tips for San Francisco:\n", - "\n", - " 1. **Golden Gate Bridge**: A must-see iconic landmark.\n", - " 2. **Alcatraz Island**: Famous former prison offering tours.\n", - " 3. **Fisherman's Wharf**: Popular tourist area with shops, restaurants, and sea lions.\n", - " 4. **Chinatown**: The largest Chinatown outside of Asia.\n", - " 5. **Golden Gate Park**: Large urban park with gardens, museums, and recreational activities.\n", - " 6. **Cable Cars**: Historic streetcars offering a unique way to see the city.\n", - " 7. **Exploratorium**: Interactive science museum.\n", - " 8. **San Francisco Museum of Modern Art (SFMOMA)**: Modern and contemporary art museum.\n", - " 9. **Lombard Street**: Known for its steep, one-block section with eight hairpin turns.\n", - " 10. **Union Square**: Major shopping and cultural hub.\n", - "\n", - " Travel Tips:\n", - " - **Weather**: San Francisco has a mild climate, but it can be foggy and windy. Dress in layers.\n", - " - **Transportation**: Use public transportation like BART, Muni, and cable cars to get around.\n", - " - **Safety**: Be aware of your surroundings, especially in crowded tourist areas.\n", - " - **Dining**: Try local specialties like sourdough bread, seafood, and Mission-style burritos.\n", - " \"\"\"\n", - " },\n", - " {\n", - " \"role\": \"user\",\n", - " \"content\": \"Show me hotels around Golden Gate Bridge.\"\n", - " },\n", - " {\n", - " \"role\": \"assistant\",\n", - " \"content\": \"\"\"\n", - " The search results for hotels around Golden Gate Bridge in San Francisco include:\n", - "\n", - " 1. Hilton Hotels In San Francisco - Hotel Near Fishermans Wharf (hilton.com)\n", - " 2. The 10 Closest Hotels to Golden Gate Bridge (tripadvisor.com)\n", - " 3. Hotels near Golden Gate Bridge (expedia.com)\n", - " 4. Hotels near Golden Gate Bridge (hotels.com)\n", - " 5. Holiday Inn Express & Suites San Francisco Fishermans Wharf, an IHG Hotel $146 (1.8K) 3-star hotel Golden Gate Bridge • 3.5 mi DEAL 19% less than usual\n", - " 6. Holiday Inn San Francisco-Golden Gateway, an IHG Hotel $151 (3.5K) 3-star hotel Golden Gate Bridge • 3.7 mi Casual hotel with dining, a bar & a pool\n", - " 7. 
Hotel Zephyr San Francisco $159 (3.8K) 4-star hotel Golden Gate Bridge • 3.7 mi Nautical-themed lodging with bay views\n", - " 8. Lodge at the Presidio\n", - " 9. The Inn Above Tide\n", - " 10. Cavallo Point\n", - " 11. Casa Madrona Hotel and Spa\n", - " 12. Cow Hollow Inn and Suites\n", - " 13. Samesun San Francisco\n", - " 14. Inn on Broadway\n", - " 15. Coventry Motor Inn\n", - " 16. HI San Francisco Fisherman's Wharf Hostel\n", - " 17. Loews Regency San Francisco Hotel\n", - " 18. Fairmont Heritage Place Ghirardelli Square\n", - " 19. Hotel Drisco Pacific Heights\n", - " 20. Travelodge by Wyndham Presidio San Francisco\n", - " \"\"\"\n", - " }\n", - "]\n", - "\n", - "memory.add(conversation, user_id=USER_ID)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "hO8z9aNTlQif" - }, - "outputs": [], - "source": [ - "def get_travel_info(question, use_memory=True):\n", - " \"\"\"\n", - " Get travel information based on user's question and optionally their preferences from memory.\n", - "\n", - " \"\"\"\n", - " if use_memory:\n", - " previous_memories = memory.search(question, user_id=USER_ID)\n", - " relevant_memories_text = \"\"\n", - " if previous_memories:\n", - " print(\"Using previous memories to enhance the search...\")\n", - " relevant_memories_text = '\\n'.join(mem[\"memory\"] for mem in previous_memories)\n", - "\n", - " command = \"Find travel information based on my interests:\"\n", - " prompt = f\"{command}\\n Question: {question} \\n My preferences: {relevant_memories_text}\"\n", - " else:\n", - " command = \"Find travel information based on my interests:\"\n", - " prompt = f\"{command}\\n Question: {question}\"\n", - "\n", - "\n", - " print(\"Searching for travel information...\")\n", - " browse_result = multion.browse(cmd=prompt)\n", - " return browse_result.message" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Wp2xpzMrlQig" - }, - "source": [ - "## Example 1" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "bPRPwqsplQig" - }, - "outputs": [], - "source": [ - "question = \"Show me flight details for it.\"\n", - "answer_without_memory = get_travel_info(question, use_memory=False)\n", - "answer_with_memory = get_travel_info(question, use_memory=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "a76ifa2HlQig" - }, - "source": [ - "| Without Memory | With Memory |\n", - "|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", - "| I have performed a Google search for \"flight details\" and reviewed the search results. Here are some relevant links and information: | Memorizing the following information: Flight details for San Francisco: |\n", - "| 1. **FlightStats Global Flight Tracker** - Track the real-time flight status of your flight. See if your flight has been delayed or cancelled and track the live status.
[Flight Tracker - FlightStats](https://www.flightstats.com/flight-tracker/search) | 1. Prices from $232. Depart Thursday, August 22. Return Thursday, August 29.
2. Prices from $216. Depart Friday, August 23. Return Friday, August 30.
3. Prices from $236. Depart Saturday, August 24. Return Saturday, August 31.
4. Prices from $215. Depart Sunday, August 25. Return Sunday, September 1. |\n", - "| 2. **FlightAware - Flight Tracker** - Track live flights worldwide, see flight cancellations, and browse by airport.
[FlightAware - Flight Tracker](https://www.flightaware.com) | 5. Prices from $218. Depart Monday, August 26. Return Monday, September 2.
6. Prices from $211. Depart Tuesday, August 27. Return Tuesday, September 3.
7. Prices from $198. Depart Wednesday, August 28. Return Wednesday, September 4.
8. Prices from $218. Depart Thursday, August 29. Return Thursday, September 5. |\n", - "| 3. **Google Flights** - Show flights based on your search.
[Google Flights](https://www.google.com/flights) | 9. Prices from $194. Depart Friday, August 30. Return Friday, September 6.
10. Prices from $218. Depart Saturday, August 31. Return Saturday, September 7.
11. Prices from $212. Depart Sunday, September 1. Return Sunday, September 8.
12. Prices from $247. Depart Monday, September 2. Return Monday, September 9. |\n", - "| | 13. Prices from $212. Depart Tuesday, September 3. Return Tuesday, September 10.
14. Prices from $203. Depart Wednesday, September 4. Return Wednesday, September 11.
15. Prices from $242. Depart Thursday, September 5. Return Thursday, September 12.
16. Prices from $191. Depart Friday, September 6. Return Friday, September 13. |\n", - "| | 17. Prices from $215. Depart Saturday, September 7. Return Saturday, September 14.
18. Prices from $229. Depart Sunday, September 8. Return Sunday, September 15.
19. Prices from $183. Depart Monday, September 9. Return Monday, September 16.
65. Prices from $194. Depart Friday, October 25. Return Friday, November 1. |\n", - "| | 66. Prices from $205. Depart Saturday, October 26. Return Saturday, November 2.
67. Prices from $241. Depart Sunday, October 27. Return Sunday, November 3. |\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0cXpiAwMlQig" - }, - "source": [ - "## Example 2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "LpprKfpslQih" - }, - "outputs": [], - "source": [ - "question = \"What places to visit there?\"\n", - "answer_without_memory = get_travel_info(question, use_memory=False)\n", - "answer_with_memory = get_travel_info(question, use_memory=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kpfjeY1_lQih" - }, - "source": [ - "| Without Memory | With Memory |\n", - "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", - "| Based on the information gathered, here are some top travel destinations to consider visiting: | Based on the information gathered, here are some top places to visit in San Francisco: |\n", - "| 1. **Paris**: Known for its iconic attractions like the Eiffel Tower and the Louvre, Paris offers quaint cafes, trendy shopping districts, and beautiful Haussmann architecture. It's a city where you can always discover something new with each visit. | 1. **Golden Gate Bridge** - An iconic symbol of San Francisco, perfect for walking, biking, or simply enjoying the view.
2. **Alcatraz Island** - The historic former prison offers tours and insights into its storied past.
3. **Fisherman's Wharf** - A bustling waterfront area known for its seafood, shopping, and attractions like Pier 39.
4. **Golden Gate Park** - A large urban park with gardens, museums, and recreational activities.
5. **Chinatown San Francisco** - One of the oldest and most famous Chinatowns in North America, offering unique shops and delicious food.
6. **Coit Tower** - Offers panoramic views of the city and murals depicting San Francisco's history.
7. **Lands End** - A beautiful coastal trail with stunning views of the Pacific Ocean and the Golden Gate Bridge.
8. **Palace of Fine Arts** - A picturesque structure and park, perfect for a leisurely stroll or photo opportunities.
9. **Crissy Field & The Presidio Tunnel Tops** - Great for outdoor activities and scenic views of the bay. |\n", - "| 2. **Bora Bora**: This small island in French Polynesia is famous for its stunning turquoise waters, luxurious overwater bungalows, and vibrant coral reefs. It's a popular destination for honeymooners and those seeking a tropical paradise. | |\n", - "| 3. **Glacier National Park**: Located in Montana, USA, this park is known for its breathtaking landscapes, including rugged mountains, pristine lakes, and diverse wildlife. It's a haven for outdoor enthusiasts and hikers. | |\n", - "| 4. **Rome**: The capital of Italy, Rome is rich in history and culture, featuring landmarks such as the Colosseum, the Vatican, and the Pantheon. It's a city where ancient history meets modern life. | |\n", - "| 5. **Swiss Alps**: Renowned for their stunning natural beauty, the Swiss Alps offer opportunities for skiing, hiking, and enjoying picturesque mountain villages. | |\n", - "| 6. **Maui**: One of Hawaii's most popular islands, Maui is known for its beautiful beaches, lush rainforests, and the scenic Hana Highway. It's a great destination for both relaxation and adventure. | |\n", - "| 7. **London, England**: A vibrant city with a mix of historical landmarks like the Tower of London and modern attractions such as the London Eye. London offers diverse cultural experiences, world-class museums, and a bustling nightlife. | |\n", - "| 8. **Maldives**: This tropical paradise in the Indian Ocean is famous for its crystal-clear waters, luxurious resorts, and abundant marine life. It's an ideal destination for snorkeling, diving, and relaxation. | |\n", - "| 9. **Turks & Caicos**: Known for its pristine beaches and turquoise waters, this Caribbean destination is perfect for water sports, beach lounging, and exploring coral reefs. | |\n", - "| 10. **Tokyo**: Japan's bustling capital offers a unique blend of traditional and modern attractions, from ancient temples to futuristic skyscrapers. Tokyo is also known for its vibrant food scene and shopping districts. 
| |\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XdpkcMrclQih" - }, - "source": [ - "## Example 3" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Nntl2FxulQih" - }, - "outputs": [], - "source": [ - "question = \"What the weather there?\"\n", - "answer_without_memory = get_travel_info(question, use_memory=False)\n", - "answer_with_memory = get_travel_info(question, use_memory=True)" - ] - }, + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "fu3euPKZsbaC" + }, + "outputs": [], + "source": [ + "!pip install mem0ai" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "U2VC_0FElQid" + }, + "outputs": [], + "source": [ + "import os\n", + "from openai import OpenAI\n", + "from mem0 import MemoryClient\n", + "from multion.client import MultiOn\n", + "\n", + "# Configuration\n", + "OPENAI_API_KEY = \"sk-xxx\" # Replace with your actual OpenAI API key\n", + "MULTION_API_KEY = \"xx\" # Replace with your actual MultiOn API key\n", + "MEM0_API_KEY = \"xx\" # Replace with your actual Mem0 API key\n", + "USER_ID = \"test_travel_agent\"\n", + "\n", + "# Set up OpenAI API key\n", + "os.environ[\"OPENAI_API_KEY\"] = OPENAI_API_KEY\n", + "\n", + "# Initialize Mem0 and MultiOn\n", + "memory = MemoryClient(api_key=MEM0_API_KEY)\n", + "multion = MultiOn(api_key=MULTION_API_KEY)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "sq-OdPHKlQie", + "outputId": "1d605222-0bf5-4ac9-99b9-6059b502c20b" + }, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "yt2pj1irlQih" - }, - "source": [ - "| Without Memory | With Memory |\n", - "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", - "| The current weather in Paris is light rain with a temperature of 67°F. The precipitation is at 50%, humidity is 95%, and the wind speed is 5 mph. | The current weather in San Francisco is as follows:
- **Temperature**: 59°F
- **Condition**: Clear with periodic clouds
- **Precipitation**: 3%
- **Humidity**: 87%
- **Wind**: 12 mph |\n" + "data": { + "text/plain": [ + "{'message': 'Memory added successfully!'}" ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.3" - }, - "colab": { - "provenance": [] - } + ], + "source": [ + "# Add conversation to Mem0\n", + "conversation = [\n", + " {\"role\": \"user\", \"content\": \"What are the best travel destinations in the world?\"},\n", + " {\n", + " \"role\": \"assistant\",\n", + " \"content\": \"Could you please specify your interests or the type of travel information you are looking for? This will help me find the most relevant information for you.\",\n", + " },\n", + " {\"role\": \"user\", \"content\": \"Sure, I want to travel to San Francisco.\"},\n", + " {\n", + " \"role\": \"assistant\",\n", + " \"content\": \"\"\"\n", + " Based on the information gathered from TripAdvisor, here are some popular attractions, activities, and travel tips for San Francisco:\n", + "\n", + " 1. **Golden Gate Bridge**: A must-see iconic landmark.\n", + " 2. **Alcatraz Island**: Famous former prison offering tours.\n", + " 3. **Fisherman's Wharf**: Popular tourist area with shops, restaurants, and sea lions.\n", + " 4. **Chinatown**: The largest Chinatown outside of Asia.\n", + " 5. **Golden Gate Park**: Large urban park with gardens, museums, and recreational activities.\n", + " 6. **Cable Cars**: Historic streetcars offering a unique way to see the city.\n", + " 7. **Exploratorium**: Interactive science museum.\n", + " 8. **San Francisco Museum of Modern Art (SFMOMA)**: Modern and contemporary art museum.\n", + " 9. **Lombard Street**: Known for its steep, one-block section with eight hairpin turns.\n", + " 10. **Union Square**: Major shopping and cultural hub.\n", + "\n", + " Travel Tips:\n", + " - **Weather**: San Francisco has a mild climate, but it can be foggy and windy. Dress in layers.\n", + " - **Transportation**: Use public transportation like BART, Muni, and cable cars to get around.\n", + " - **Safety**: Be aware of your surroundings, especially in crowded tourist areas.\n", + " - **Dining**: Try local specialties like sourdough bread, seafood, and Mission-style burritos.\n", + " \"\"\",\n", + " },\n", + " {\"role\": \"user\", \"content\": \"Show me hotels around Golden Gate Bridge.\"},\n", + " {\n", + " \"role\": \"assistant\",\n", + " \"content\": \"\"\"\n", + " The search results for hotels around Golden Gate Bridge in San Francisco include:\n", + "\n", + " 1. Hilton Hotels In San Francisco - Hotel Near Fishermans Wharf (hilton.com)\n", + " 2. The 10 Closest Hotels to Golden Gate Bridge (tripadvisor.com)\n", + " 3. Hotels near Golden Gate Bridge (expedia.com)\n", + " 4. Hotels near Golden Gate Bridge (hotels.com)\n", + " 5. Holiday Inn Express & Suites San Francisco Fishermans Wharf, an IHG Hotel $146 (1.8K) 3-star hotel Golden Gate Bridge • 3.5 mi DEAL 19% less than usual\n", + " 6. Holiday Inn San Francisco-Golden Gateway, an IHG Hotel $151 (3.5K) 3-star hotel Golden Gate Bridge • 3.7 mi Casual hotel with dining, a bar & a pool\n", + " 7. Hotel Zephyr San Francisco $159 (3.8K) 4-star hotel Golden Gate Bridge • 3.7 mi Nautical-themed lodging with bay views\n", + " 8. 
Lodge at the Presidio\n", + " 9. The Inn Above Tide\n", + " 10. Cavallo Point\n", + " 11. Casa Madrona Hotel and Spa\n", + " 12. Cow Hollow Inn and Suites\n", + " 13. Samesun San Francisco\n", + " 14. Inn on Broadway\n", + " 15. Coventry Motor Inn\n", + " 16. HI San Francisco Fisherman's Wharf Hostel\n", + " 17. Loews Regency San Francisco Hotel\n", + " 18. Fairmont Heritage Place Ghirardelli Square\n", + " 19. Hotel Drisco Pacific Heights\n", + " 20. Travelodge by Wyndham Presidio San Francisco\n", + " \"\"\",\n", + " },\n", + "]\n", + "\n", + "memory.add(conversation, user_id=USER_ID)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "hO8z9aNTlQif" + }, + "outputs": [], + "source": [ + "def get_travel_info(question, use_memory=True):\n", + " \"\"\"\n", + " Get travel information based on user's question and optionally their preferences from memory.\n", + "\n", + " \"\"\"\n", + " if use_memory:\n", + " previous_memories = memory.search(question, user_id=USER_ID)\n", + " relevant_memories_text = \"\"\n", + " if previous_memories:\n", + " print(\"Using previous memories to enhance the search...\")\n", + " relevant_memories_text = \"\\n\".join(mem[\"memory\"] for mem in previous_memories)\n", + "\n", + " command = \"Find travel information based on my interests:\"\n", + " prompt = f\"{command}\\n Question: {question} \\n My preferences: {relevant_memories_text}\"\n", + " else:\n", + " command = \"Find travel information based on my interests:\"\n", + " prompt = f\"{command}\\n Question: {question}\"\n", + "\n", + " print(\"Searching for travel information...\")\n", + " browse_result = multion.browse(cmd=prompt)\n", + " return browse_result.message" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Wp2xpzMrlQig" + }, + "source": [ + "## Example 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bPRPwqsplQig" + }, + "outputs": [], + "source": [ + "question = \"Show me flight details for it.\"\n", + "answer_without_memory = get_travel_info(question, use_memory=False)\n", + "answer_with_memory = get_travel_info(question, use_memory=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "a76ifa2HlQig" + }, + "source": [ + "| Without Memory | With Memory |\n", + "|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| I have performed a Google search for \"flight details\" and reviewed the search results. Here are some relevant links and information: | Memorizing the following information: Flight details for San Francisco: |\n", + "| 1. **FlightStats Global Flight Tracker** - Track the real-time flight status of your flight. See if your flight has been delayed or cancelled and track the live status.
[Flight Tracker - FlightStats](https://www.flightstats.com/flight-tracker/search) | 1. Prices from $232. Depart Thursday, August 22. Return Thursday, August 29.
2. Prices from $216. Depart Friday, August 23. Return Friday, August 30.
3. Prices from $236. Depart Saturday, August 24. Return Saturday, August 31.
4. Prices from $215. Depart Sunday, August 25. Return Sunday, September 1. |\n", + "| 2. **FlightAware - Flight Tracker** - Track live flights worldwide, see flight cancellations, and browse by airport.
[FlightAware - Flight Tracker](https://www.flightaware.com) | 5. Prices from $218. Depart Monday, August 26. Return Monday, September 2.
6. Prices from $211. Depart Tuesday, August 27. Return Tuesday, September 3.
7. Prices from $198. Depart Wednesday, August 28. Return Wednesday, September 4.
8. Prices from $218. Depart Thursday, August 29. Return Thursday, September 5. |\n", + "| 3. **Google Flights** - Show flights based on your search.
[Google Flights](https://www.google.com/flights) | 9. Prices from $194. Depart Friday, August 30. Return Friday, September 6.
10. Prices from $218. Depart Saturday, August 31. Return Saturday, September 7.
11. Prices from $212. Depart Sunday, September 1. Return Sunday, September 8.
12. Prices from $247. Depart Monday, September 2. Return Monday, September 9. |\n", + "| | 13. Prices from $212. Depart Tuesday, September 3. Return Tuesday, September 10.
14. Prices from $203. Depart Wednesday, September 4. Return Wednesday, September 11.
15. Prices from $242. Depart Thursday, September 5. Return Thursday, September 12.
16. Prices from $191. Depart Friday, September 6. Return Friday, September 13. |\n", + "| | 17. Prices from $215. Depart Saturday, September 7. Return Saturday, September 14.
18. Prices from $229. Depart Sunday, September 8. Return Sunday, September 15.
19. Prices from $183. Depart Monday, September 9. Return Monday, September 16.
65. Prices from $194. Depart Friday, October 25. Return Friday, November 1. |\n", + "| | 66. Prices from $205. Depart Saturday, October 26. Return Saturday, November 2.
67. Prices from $241. Depart Sunday, October 27. Return Sunday, November 3. |\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0cXpiAwMlQig" + }, + "source": [ + "## Example 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LpprKfpslQih" + }, + "outputs": [], + "source": [ + "question = \"What places to visit there?\"\n", + "answer_without_memory = get_travel_info(question, use_memory=False)\n", + "answer_with_memory = get_travel_info(question, use_memory=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kpfjeY1_lQih" + }, + "source": [ + "| Without Memory | With Memory |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| Based on the information gathered, here are some top travel destinations to consider visiting: | Based on the information gathered, here are some top places to visit in San Francisco: |\n", + "| 1. **Paris**: Known for its iconic attractions like the Eiffel Tower and the Louvre, Paris offers quaint cafes, trendy shopping districts, and beautiful Haussmann architecture. It's a city where you can always discover something new with each visit. | 1. **Golden Gate Bridge** - An iconic symbol of San Francisco, perfect for walking, biking, or simply enjoying the view.
2. **Alcatraz Island** - The historic former prison offers tours and insights into its storied past.
3. **Fisherman's Wharf** - A bustling waterfront area known for its seafood, shopping, and attractions like Pier 39.
4. **Golden Gate Park** - A large urban park with gardens, museums, and recreational activities.
5. **Chinatown San Francisco** - One of the oldest and most famous Chinatowns in North America, offering unique shops and delicious food.
6. **Coit Tower** - Offers panoramic views of the city and murals depicting San Francisco's history.
7. **Lands End** - A beautiful coastal trail with stunning views of the Pacific Ocean and the Golden Gate Bridge.
8. **Palace of Fine Arts** - A picturesque structure and park, perfect for a leisurely stroll or photo opportunities.
9. **Crissy Field & The Presidio Tunnel Tops** - Great for outdoor activities and scenic views of the bay. |\n", + "| 2. **Bora Bora**: This small island in French Polynesia is famous for its stunning turquoise waters, luxurious overwater bungalows, and vibrant coral reefs. It's a popular destination for honeymooners and those seeking a tropical paradise. | |\n", + "| 3. **Glacier National Park**: Located in Montana, USA, this park is known for its breathtaking landscapes, including rugged mountains, pristine lakes, and diverse wildlife. It's a haven for outdoor enthusiasts and hikers. | |\n", + "| 4. **Rome**: The capital of Italy, Rome is rich in history and culture, featuring landmarks such as the Colosseum, the Vatican, and the Pantheon. It's a city where ancient history meets modern life. | |\n", + "| 5. **Swiss Alps**: Renowned for their stunning natural beauty, the Swiss Alps offer opportunities for skiing, hiking, and enjoying picturesque mountain villages. | |\n", + "| 6. **Maui**: One of Hawaii's most popular islands, Maui is known for its beautiful beaches, lush rainforests, and the scenic Hana Highway. It's a great destination for both relaxation and adventure. | |\n", + "| 7. **London, England**: A vibrant city with a mix of historical landmarks like the Tower of London and modern attractions such as the London Eye. London offers diverse cultural experiences, world-class museums, and a bustling nightlife. | |\n", + "| 8. **Maldives**: This tropical paradise in the Indian Ocean is famous for its crystal-clear waters, luxurious resorts, and abundant marine life. It's an ideal destination for snorkeling, diving, and relaxation. | |\n", + "| 9. **Turks & Caicos**: Known for its pristine beaches and turquoise waters, this Caribbean destination is perfect for water sports, beach lounging, and exploring coral reefs. | |\n", + "| 10. **Tokyo**: Japan's bustling capital offers a unique blend of traditional and modern attractions, from ancient temples to futuristic skyscrapers. Tokyo is also known for its vibrant food scene and shopping districts. | |\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XdpkcMrclQih" + }, + "source": [ + "## Example 3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Nntl2FxulQih" + }, + "outputs": [], + "source": [ + "question = \"What the weather there?\"\n", + "answer_without_memory = get_travel_info(question, use_memory=False)\n", + "answer_with_memory = get_travel_info(question, use_memory=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yt2pj1irlQih" + }, + "source": [ + "| Without Memory | With Memory |\n", + "|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n", + "| The current weather in Paris is light rain with a temperature of 67°F. The precipitation is at 50%, humidity is 95%, and the wind speed is 5 mph. | The current weather in San Francisco is as follows:
- **Temperature**: 59°F
- **Condition**: Clear with periodic clouds
- **Precipitation**: 3%
- **Humidity**: 87%
- **Wind**: 12 mph |\n" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" }, - "nbformat": 4, - "nbformat_minor": 0 + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } \ No newline at end of file diff --git a/docs/api-reference/organization/add-org-member.mdx b/docs/api-reference/organization/add-org-member.mdx new file mode 100644 index 0000000000..53056658e5 --- /dev/null +++ b/docs/api-reference/organization/add-org-member.mdx @@ -0,0 +1,4 @@ +--- +title: 'Add Member' +openapi: post /api/v1/orgs/organizations/{org_id}/members/ +--- \ No newline at end of file diff --git a/docs/api-reference/organization/create-org.mdx b/docs/api-reference/organization/create-org.mdx new file mode 100644 index 0000000000..48f38b4b02 --- /dev/null +++ b/docs/api-reference/organization/create-org.mdx @@ -0,0 +1,4 @@ +--- +title: 'Create Organization' +openapi: post /api/v1/orgs/organizations/ +--- \ No newline at end of file diff --git a/docs/api-reference/organization/delete-org-member.mdx b/docs/api-reference/organization/delete-org-member.mdx new file mode 100644 index 0000000000..4e4c45e45f --- /dev/null +++ b/docs/api-reference/organization/delete-org-member.mdx @@ -0,0 +1,4 @@ +--- +title: 'Delete Member' +openapi: delete /api/v1/orgs/organizations/{org_id}/members/ +--- \ No newline at end of file diff --git a/docs/api-reference/organization/delete-org.mdx b/docs/api-reference/organization/delete-org.mdx new file mode 100644 index 0000000000..bcc0c3d1da --- /dev/null +++ b/docs/api-reference/organization/delete-org.mdx @@ -0,0 +1,4 @@ +--- +title: 'Delete Organization' +openapi: delete /api/v1/orgs/organizations/{org_id}/ +--- \ No newline at end of file diff --git a/docs/api-reference/organization/get-org-members.mdx b/docs/api-reference/organization/get-org-members.mdx new file mode 100644 index 0000000000..003f093494 --- /dev/null +++ b/docs/api-reference/organization/get-org-members.mdx @@ -0,0 +1,4 @@ +--- +title: 'Get Members' +openapi: get /api/v1/orgs/organizations/{org_id}/members/ +--- \ No newline at end of file diff --git a/docs/api-reference/organization/get-org.mdx b/docs/api-reference/organization/get-org.mdx new file mode 100644 index 0000000000..232312365e --- /dev/null +++ b/docs/api-reference/organization/get-org.mdx @@ -0,0 +1,4 @@ +--- +title: 'Get Organization' +openapi: get /api/v1/orgs/organizations/{org_id}/ +--- \ No newline at end of file diff --git a/docs/api-reference/organization/get-orgs.mdx b/docs/api-reference/organization/get-orgs.mdx new file mode 100644 index 0000000000..ddb6594fea --- /dev/null +++ b/docs/api-reference/organization/get-orgs.mdx @@ -0,0 +1,4 @@ +--- +title: 'Get Organizations' +openapi: get /api/v1/orgs/organizations/ +--- \ No newline at end of file diff --git a/docs/api-reference/organization/update-org-member.mdx b/docs/api-reference/organization/update-org-member.mdx new file mode 100644 index 0000000000..126787fa9d --- /dev/null +++ b/docs/api-reference/organization/update-org-member.mdx @@ -0,0 +1,4 @@ +--- +title: 'Update Member' +openapi: put /api/v1/orgs/organizations/{org_id}/members/ +--- \ No newline at end of file diff --git a/docs/api-reference/overview.mdx 
b/docs/api-reference/overview.mdx index 44305f71f7..99b393fce7 100644 --- a/docs/api-reference/overview.mdx +++ b/docs/api-reference/overview.mdx @@ -23,6 +23,32 @@ Our API is organized into several main categories: All API requests require authentication using HTTP Basic Auth. Ensure you include your API key in the Authorization header of each request. +## Organizations and projects (optional) + +For users who belong to multiple organizations or are working on multiple projects, you can specify the organization and project for an API request. This is done by initializing the Mem0 client with the appropriate parameters. Usage from these API requests will be attributed to the specified organization and project. + +Example with the mem0 Python package: + +```python +from mem0 import MemoryClient + +client = MemoryClient( + organization_name='YOUR_ORG_NAME', + project_name='YOUR_PROJECT_NAME', +) +``` + +Example with the mem0 Node.js package: + +```javascript +import { MemoryClient } from "mem0ai"; + +const client = new MemoryClient({ + organization: "YOUR_ORG_NAME", + project: "YOUR_PROJECT_NAME" +}); +``` + ## Getting Started To begin using the Mem0 API, you'll need to: diff --git a/docs/api-reference/project/add-project-member.mdx b/docs/api-reference/project/add-project-member.mdx new file mode 100644 index 0000000000..721fd45f32 --- /dev/null +++ b/docs/api-reference/project/add-project-member.mdx @@ -0,0 +1,4 @@ +--- +title: 'Add Member' +openapi: post /api/v1/orgs/organizations/{org_id}/projects/{project_id}/members/ +--- \ No newline at end of file diff --git a/docs/api-reference/project/create-project.mdx b/docs/api-reference/project/create-project.mdx new file mode 100644 index 0000000000..24f18f5586 --- /dev/null +++ b/docs/api-reference/project/create-project.mdx @@ -0,0 +1,4 @@ +--- +title: 'Create Project' +openapi: post /api/v1/orgs/organizations/{org_id}/projects/ +--- \ No newline at end of file diff --git a/docs/api-reference/project/delete-project-member.mdx b/docs/api-reference/project/delete-project-member.mdx new file mode 100644 index 0000000000..3099cae726 --- /dev/null +++ b/docs/api-reference/project/delete-project-member.mdx @@ -0,0 +1,4 @@ +--- +title: 'Delete Member' +openapi: delete /api/v1/orgs/organizations/{org_id}/projects/{project_id}/members/ +--- \ No newline at end of file diff --git a/docs/api-reference/project/delete-project.mdx b/docs/api-reference/project/delete-project.mdx new file mode 100644 index 0000000000..96fb20da34 --- /dev/null +++ b/docs/api-reference/project/delete-project.mdx @@ -0,0 +1,4 @@ +--- +title: 'Delete Project' +openapi: delete /api/v1/orgs/organizations/{org_id}/projects/{project_id}/ +--- \ No newline at end of file diff --git a/docs/api-reference/project/get-project-members.mdx b/docs/api-reference/project/get-project-members.mdx new file mode 100644 index 0000000000..42171dce83 --- /dev/null +++ b/docs/api-reference/project/get-project-members.mdx @@ -0,0 +1,4 @@ +--- +title: 'Get Members' +openapi: get /api/v1/orgs/organizations/{org_id}/projects/{project_id}/members/ +--- \ No newline at end of file diff --git a/docs/api-reference/project/get-project.mdx b/docs/api-reference/project/get-project.mdx new file mode 100644 index 0000000000..219f2a215b --- /dev/null +++ b/docs/api-reference/project/get-project.mdx @@ -0,0 +1,4 @@ +--- +title: 'Get Project' +openapi: get /api/v1/orgs/organizations/{org_id}/projects/{project_id}/ +--- \ No newline at end of file diff --git a/docs/api-reference/project/get-projects.mdx 
b/docs/api-reference/project/get-projects.mdx new file mode 100644 index 0000000000..f484adf4bb --- /dev/null +++ b/docs/api-reference/project/get-projects.mdx @@ -0,0 +1,4 @@ +--- +title: 'Get Projects' +openapi: get /api/v1/orgs/organizations/{org_id}/projects/ +--- \ No newline at end of file diff --git a/docs/api-reference/project/update-project-member.mdx b/docs/api-reference/project/update-project-member.mdx new file mode 100644 index 0000000000..de438d92e2 --- /dev/null +++ b/docs/api-reference/project/update-project-member.mdx @@ -0,0 +1,4 @@ +--- +title: 'Update Member' +openapi: put /api/v1/orgs/organizations/{org_id}/projects/{project_id}/members/ +--- \ No newline at end of file diff --git a/docs/components/embedders/config.mdx b/docs/components/embedders/config.mdx index 91731be882..95860b8078 100644 --- a/docs/components/embedders/config.mdx +++ b/docs/components/embedders/config.mdx @@ -53,6 +53,7 @@ Here's a comprehensive list of all parameters that can be used across different | `model_kwargs` | Key-Value arguments for the Huggingface embedding model | | `azure_kwargs` | Key-Value arguments for the AzureOpenAI embedding model | | `openai_base_url` | Base URL for OpenAI API | OpenAI | +| `vertex_credentials_json` | Path to the Google Cloud credentials JSON file for VertexAI | ## Supported Embedding Models diff --git a/docs/components/embedders/models/vertexai.mdx b/docs/components/embedders/models/vertexai.mdx new file mode 100644 index 0000000000..1fe8b95eef --- /dev/null +++ b/docs/components/embedders/models/vertexai.mdx @@ -0,0 +1,35 @@ +### Vertex AI + +To use Google Cloud's Vertex AI for text embedding models, set the `GOOGLE_APPLICATION_CREDENTIALS` environment variable to point to the path of your service account's credentials JSON file. These credentials can be created in the [Google Cloud Console](https://console.cloud.google.com/). + +### Usage + +```python +import os +from mem0 import Memory + +# Set the path to your Google Cloud credentials JSON file +os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/path/to/your/credentials.json" + +config = { + "embedder": { + "provider": "vertexai", + "config": { + "model": "text-embedding-004" + } + } +} + +m = Memory.from_config(config) +m.add("I'm visiting Paris", user_id="john") +``` + +### Config + +Here are the parameters available for configuring the Vertex AI embedder: + +| Parameter | Description | Default Value | +| ------------------------- | ------------------------------------------------ | -------------------- | +| `model` | The name of the Vertex AI embedding model to use | `text-embedding-004` | +| `vertex_credentials_json` | Path to the Google Cloud credentials JSON file | `None` | +| `embedding_dims` | Dimensions of the embedding model | `256` | diff --git a/docs/components/embedders/overview.mdx b/docs/components/embedders/overview.mdx index 2cd78c1282..f1d7d7e584 100644 --- a/docs/components/embedders/overview.mdx +++ b/docs/components/embedders/overview.mdx @@ -13,6 +13,7 @@ See the list of supported embedders below. 
+ ## Usage diff --git a/docs/components/vectordbs/config.mdx b/docs/components/vectordbs/config.mdx index fe0f1fcdda..7ddff7c9c0 100644 --- a/docs/components/vectordbs/config.mdx +++ b/docs/components/vectordbs/config.mdx @@ -6,7 +6,7 @@ Config in mem0 is a dictionary that specifies the settings for your vector datab The config is defined as a Python dictionary with two main keys: - `vector_store`: Specifies the vector database provider and its configuration - - `provider`: The name of the vector database (e.g., "chroma", "pgvector", "qdrant") + - `provider`: The name of the vector database (e.g., "chroma", "pgvector", "qdrant", "milvus") - `config`: A nested dictionary containing provider-specific settings ## How to Use Config diff --git a/docs/components/vectordbs/dbs/milvus.mdx b/docs/components/vectordbs/dbs/milvus.mdx new file mode 100644 index 0000000000..12193f462b --- /dev/null +++ b/docs/components/vectordbs/dbs/milvus.mdx @@ -0,0 +1,35 @@ +[Milvus](https://milvus.io/) is an open-source vector database that suits AI applications of every size, from a demo chatbot in a Jupyter notebook to web-scale search serving billions of users. + +### Usage + +```python +import os +from mem0 import Memory + +config = { + "vector_store": { + "provider": "milvus", + "config": { + "collection_name": "test", + "embedding_model_dims": 1536, + "url": "127.0.0.1", + "token": "8e4b8ca8cf2c67", + } + } +} + +m = Memory.from_config(config) +m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"}) +``` + +### Config + +Here are the parameters available for configuring the Milvus database: + +| Parameter | Description | Default Value | +| --- | --- | --- | +| `url` | Full URL/URI for the Milvus/Zilliz server | `http://localhost:19530` | +| `token` | Token for the Zilliz server; defaults to `None` for local setups | `None` | +| `collection_name` | The name of the collection | `mem0` | +| `embedding_model_dims` | Dimensions of the embedding model | `1536` | +| `metric_type` | Metric type for similarity search | `L2` | diff --git a/docs/features/custom-prompts.mdx b/docs/features/custom-prompts.mdx new file mode 100644 index 0000000000..ff6fe0e67f --- /dev/null +++ b/docs/features/custom-prompts.mdx @@ -0,0 +1,109 @@ +--- +title: Custom Prompts +description: 'Enhance your product experience by adding custom prompts tailored to your needs' +--- + +## Introduction to Custom Prompts + +Custom prompts allow you to tailor the behavior of your Mem0 instance to specific use cases or domains. +By defining a custom prompt, you can control how information is extracted, processed, and stored in your memory system. + +To create an effective custom prompt: +1. Be specific about the information to extract. +2. Provide few-shot examples to guide the LLM. +3. Ensure examples follow the format shown below. + +Example of a custom prompt: + +```python +custom_prompt = """ +Please only extract entities containing customer support information, order details, and user information. +Here are some few-shot examples: + +Input: Hi. +Output: {{"facts" : []}} + +Input: The weather is nice today. +Output: {{"facts" : []}} + +Input: My order #12345 hasn't arrived yet. +Output: {{"facts" : ["Order #12345 not received"]}} + +Input: I'm John Doe, and I'd like to return the shoes I bought last week. +Output: {{"facts" : ["Customer name: John Doe", "Wants to return shoes", "Purchase made last week"]}} + +Input: I ordered a red shirt, size medium, but received a blue one instead.
+Output: {{"facts" : ["Ordered red shirt, size medium", "Received blue shirt instead"]}} + +Return the facts and customer information in a json format as shown above. +""" + +``` + +Here we initialize the custom prompt in the config. + +```python +from mem0 import Memory + +config = { + "llm": { + "provider": "openai", + "config": { + "model": "gpt-4o", + "temperature": 0.2, + "max_tokens": 1500, + } + }, + "custom_prompt": custom_prompt, + "version": "v1.1" +} + +m = Memory.from_config(config_dict=config, user_id="alice") +``` + +### Example 1 + +In this example, we are adding a memory of a user ordering a laptop. As seen in the output, the custom prompt is used to extract the relevant information from the user's message. + + +```python Code +m.add("Yesterday, I ordered a laptop, the order id is 12345", user_id="alice") +``` + +```json Output +{ + "results": [ + { + "memory": "Ordered a laptop", + "event": "ADD" + }, + { + "memory": "Order ID: 12345", + "event": "ADD" + }, + { + "memory": "Order placed yesterday", + "event": "ADD" + } + ], + "relations": [] +} +``` + + +### Example 2 + +In this example, we are adding a memory of a user liking to go on hikes. This add message is not specific to the use-case mentioned in the custom prompt. +Hence, the memory is not added. + +```python Code +m.add("I like going to hikes", user_id="alice") +``` + +```json Output +{ + "results": [], + "relations": [] +} +``` + diff --git a/docs/mint.json b/docs/mint.json index 481b79286a..4832bf604b 100644 --- a/docs/mint.json +++ b/docs/mint.json @@ -108,7 +108,8 @@ "pages": [ "components/vectordbs/dbs/chroma", "components/vectordbs/dbs/pgvector", - "components/vectordbs/dbs/qdrant" + "components/vectordbs/dbs/qdrant", + "components/vectordbs/dbs/milvus" ] } ] @@ -131,7 +132,7 @@ }, { "group": "Features", - "pages": ["features/openai_compatibility"] + "pages": ["features/openai_compatibility", "features/custom-prompts"] } ] }, @@ -159,6 +160,42 @@ "api-reference/entities/get-users", "api-reference/entities/delete-user" ] + }, + { + "group": "Organizations APIs", + "pages": [ + "api-reference/organization/get-orgs", + "api-reference/organization/get-org", + "api-reference/organization/create-org", + "api-reference/organization/delete-org", + { + "group": "Members APIs", + "pages": [ + "api-reference/organization/get-org-members", + "api-reference/organization/add-org-member", + "api-reference/organization/update-org-member", + "api-reference/organization/delete-org-member" + ] + } + ] + }, + { + "group": "Projects APIs", + "pages": [ + "api-reference/project/get-projects", + "api-reference/project/get-project", + "api-reference/project/create-project", + "api-reference/project/delete-project", + { + "group": "Members APIs", + "pages":[ + "api-reference/project/get-project-members", + "api-reference/project/add-project-member", + "api-reference/project/update-project-member", + "api-reference/project/delete-project-member" + ] + } + ] } ] }, diff --git a/docs/open-source/graph_memory/features.mdx b/docs/open-source/graph_memory/features.mdx index dd2cc3e352..8cc35d3c30 100644 --- a/docs/open-source/graph_memory/features.mdx +++ b/docs/open-source/graph_memory/features.mdx @@ -19,11 +19,13 @@ from mem0 import Memory config = { "graph_store": { - "provider": "neo4j", + "provider": "falkordb", "config": { - "url": "neo4j+s://xxx", - "username": "neo4j", - "password": "xxx" + "Database": "falkordb", + "host": "---" + "username": "---", + "password": "---", + "port": "---" }, "custom_prompt": "Please only extract 
entities containing sports-related relationships and nothing else.", }, diff --git a/docs/open-source/graph_memory/overview.mdx b/docs/open-source/graph_memory/overview.mdx index 75bf94bf43..127acc8146 100644 --- a/docs/open-source/graph_memory/overview.mdx +++ b/docs/open-source/graph_memory/overview.mdx @@ -3,11 +3,11 @@ title: Overview description: 'Enhance your memory system with graph-based knowledge representation and retrieval' --- -Mem0 now supports **Graph Memory**. +Mem0 now supports **Graph Memory**. With Graph Memory, users can now create and utilize complex relationships between pieces of information, allowing for more nuanced and context-aware responses. This integration enables users to leverage the strengths of both vector-based and graph-based approaches, resulting in more accurate and comprehensive information retrieval and generation. -Try Graph Memory on Google Colab. +Try Graph Memory on Google Colab. Open In Colab @@ -26,9 +26,11 @@ allowfullscreen ## Initialize Graph Memory To initialize Graph Memory you'll need to set up your configuration with graph store providers. -Currently, we support Neo4j as a graph store provider. You can setup [Neo4j](https://neo4j.com/) locally or use the hosted [Neo4j AuraDB](https://neo4j.com/product/auradb/). +Currently, we support FalkorDB and Neo4j as graph store providers. You can set up [FalkorDB](https://www.falkordb.com/) or [Neo4j](https://neo4j.com/) locally or use the hosted [FalkorDB Cloud](https://app.falkordb.cloud/) or [Neo4j AuraDB](https://neo4j.com/product/auradb/). Moreover, you also need to set the version to `v1.1` (*prior versions are not supported*). +If you are using Neo4j locally, you need to install the [APOC plugins](https://neo4j.com/labs/apoc/4.1/installation/). + User can also customize the LLM for Graph Memory from the [Supported LLM list](https://docs.mem0.ai/components/llms/overview) with three levels of configuration: 1. **Main Configuration**: If `llm` is set in the main config, it will be used for all graph operations.
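The FalkorDB snippets in this diff use `---` placeholders, so as a sanity check on the shape of the new `v1.1` config, here is a minimal runnable sketch. The host, port, and empty credentials are assumed local-development values rather than defaults documented by mem0 (a local FalkorDB typically speaks the Redis protocol on port 6379 without authentication), and the nested `llm` key illustrates the graph-store-level override described in the configuration levels above:

```python
from mem0 import Memory

# Sketch only: host/port/credentials are assumed local-dev values,
# not defaults guaranteed by mem0.
config = {
    "graph_store": {
        "provider": "falkordb",
        "config": {
            "host": "localhost",
            "username": None,   # local FalkorDB usually runs without auth
            "password": None,
            "port": 6379,       # assumed default Redis-protocol port
        },
        # An "llm" set at this level is used only for graph operations.
        "llm": {
            "provider": "openai",
            "config": {"model": "gpt-4o", "temperature": 0},
        },
    },
    "version": "v1.1",  # prior versions do not support Graph Memory
}

m = Memory.from_config(config)
m.add("Alice plays cricket with John on weekends", user_id="alice")
```

For a hosted setup, the same shape should apply with FalkorDB Cloud credentials substituted for the local placeholders.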
@@ -44,11 +46,13 @@ from mem0 import Memory config = { "graph_store": { - "provider": "neo4j", + "provider": "falkordb", "config": { - "url": "neo4j+s://xxx", - "username": "neo4j", - "password": "xxx" + "Database": "falkordb", + "host": "---", + "username": "---", + "password": "---", + "port": "---" } }, "version": "v1.1" @@ -70,11 +74,13 @@ config = { } }, "graph_store": { - "provider": "neo4j", + "provider": "falkordb", "config": { - "url": "neo4j+s://xxx", - "username": "neo4j", - "password": "xxx" + "Database": "falkordb", + "host": "---", + "username": "---", + "password": "---", + "port": "---" }, "llm" : { "provider": "openai", diff --git a/docs/open-source/quickstart.mdx b/docs/open-source/quickstart.mdx index f62060f3cc..d7cb014ee2 100644 --- a/docs/open-source/quickstart.mdx +++ b/docs/open-source/quickstart.mdx @@ -63,11 +63,13 @@ from mem0 import Memory config = { "graph_store": { - "provider": "neo4j", + "provider": "falkordb", "config": { - "url": "neo4j+s://---", - "username": "neo4j", - "password": "---" + "Database": "falkordb", + "host": "---", + "username": "---", + "password": "---", + "port": "---" } }, "version": "v1.1" diff --git a/docs/openapi.json b/docs/openapi.json index 038cf3dae7..abdd0015f8 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -94,6 +94,24 @@ "entities" ], "operationId": "entities_list", + "parameters": [ + { + "name": "org_name", + "in": "query", + "schema": { + "type": "string" + }, + "description": "Filter entities by organization name" + }, + { + "name": "project_name", + "in": "query", + "schema": { + "type": "string" + }, + "description": "Filter entities by project name" + } + ], "responses": { "200": { "description": "", @@ -183,6 +201,24 @@ "entities" ], "operationId": "entities_read", + "parameters": [ + { + "name": "org_name", + "in": "query", + "schema": { + "type": "string" + }, + "description": "Filter entities by organization name" + }, + { + "name": "project_name", + "in": "query", + "schema": { + "type": "string" + }, + "description": "Filter entities by project name" + } + ], "responses": { "200": { "description": "", @@ -239,6 +275,22 @@ } } } + }, + "400": { + "description": "Invalid entity type", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "Invalid entity type" + } + } + } + } + } } } } @@ -310,6 +362,22 @@ "description": "Filter memories by metadata (JSON string)", "style": "deepObject", "explode": true + }, + { + "name": "org_name", + "in": "query", + "schema": { + "type": "string" + }, + "description": "Filter memories by organization name" + }, + { + "name": "project_name", + "in": "query", + "schema": { + "type": "string" + }, + "description": "Filter memories by project name" } ], "responses": { @@ -493,6 +561,22 @@ "description": "Filter memories by metadata (JSON string)", "style": "deepObject", "explode": true + }, + { + "name": "org_name", + "in": "query", + "schema": { + "type": "string" + }, + "description": "Filter memories by organization name" + }, + { + "name": "project_name", + "in": "query", + "schema": { + "type": "string" + }, + "description": "Filter memories by project name" } ], "responses": { @@ -1203,138 +1287,1524 @@ }, "x-codegen-request-body-name": "data" } - } - }, - "components": { - "schemas": { - "CreateAgent": { - "required": [ - "agent_id" ], - "type": "object", - "properties": { - "agent_id": { - "title": "Agent
id", - "minLength": 1, - "type": "string" - }, - "name": { - "title": "Name", - "minLength": 1, - "type": "string" - }, - "metadata": { - "title": "Metadata", - "type": "object", - "properties": { - + "operationId": "organizations_read", + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": { + "type": "integer", + "description": "Unique identifier for the organization" + }, + "org_id": { + "type": "string", + "description": "Organization's unique string identifier" + }, + "name": { + "type": "string", + "description": "Name of the organization" + }, + "description": { + "type": "string", + "description": "Brief description of the organization" + }, + "address": { + "type": "string", + "description": "Physical address of the organization" + }, + "contact_email": { + "type": "string", + "description": "Primary contact email for the organization" + }, + "phone_number": { + "type": "string", + "description": "Contact phone number for the organization" + }, + "website": { + "type": "string", + "description": "Official website URL of the organization" + }, + "on_paid_plan": { + "type": "boolean", + "description": "Indicates whether the organization is on a paid plan" + }, + "created_at": { + "type": "string", + "format": "date-time", + "description": "Timestamp of when the organization was created" + }, + "updated_at": { + "type": "string", + "format": "date-time", + "description": "Timestamp of when the organization was last updated" + }, + "owner": { + "type": "integer", + "description": "Identifier of the organization's owner" + }, + "members": { + "type": "array", + "items": { + "type": "integer" + }, + "description": "List of member identifiers belonging to the organization" + } + } + } + } + } } } } }, - "CreateApp": { - "required": [ - "app_id" + "post": { + "tags": [ + "organizations" ], - "type": "object", - "properties": { - "app_id": { - "title": "App id", - "minLength": 1, - "type": "string" - }, - "name": { - "title": "Name", - "minLength": 1, - "type": "string" - }, - "metadata": { - "title": "Metadata", - "type": "object", - "properties": { - + "description": "Create a new organization.", + "operationId": "create_organization", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Name of the new organization" + } + }, + "required": ["name"] + } } } - } - }, - "MemoryInput": { - "type": "object", - "properties": { - "messages": { - "description": "An array of message objects representing the content of the memory. Each message object typically contains 'role' and 'content' fields, where 'role' indicates the sender (e.g., 'user', 'assistant', 'system') and 'content' contains the actual message text. This structure allows for the representation of conversations or multi-part memories.", - "type": "array", - "items": { - "type": "object", - "additionalProperties": { - "type": "string", - "nullable": true + }, + "responses": { + "201": { + "description": "Successfully created a new organization", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "Organization created successfully." 
+ } + } + } + } + } + }, + "400": { + "description": "Bad request", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "errors": { + "type": "object", + "description": "Errors found in the payload", + "additionalProperties": { + "type": "array", + "items": { + "type": "string" + } + } + } + } + } + } + } + } + } + } + }, + "/api/v1/orgs/organizations/{org_id}/": { + "get": { + "tags": [ + "organizations" + ], + "description": "Get an organization.", + "operationId": "get_organization", + "parameters": [ + { + "name": "org_id", + "in": "path", + "required": true, + "description": "The unique identifier of the organization", + "schema": { + "type": "string", + "format": "uuid" + } + } + ], + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "id": { + "type": "integer", + "description": "Unique identifier for the organization" + }, + "org_id": { + "type": "string", + "description": "Unique organization ID" + }, + "name": { + "type": "string", + "description": "Name of the organization" + }, + "description": { + "type": "string", + "description": "Description of the organization" + }, + "address": { + "type": "string", + "description": "Address of the organization" + }, + "contact_email": { + "type": "string", + "format": "email", + "description": "Contact email for the organization" + }, + "phone_number": { + "type": "string", + "description": "Phone number of the organization" + }, + "website": { + "type": "string", + "format": "uri", + "description": "Website of the organization" + }, + "on_paid_plan": { + "type": "boolean", + "description": "Indicates if the organization is on a paid plan" + }, + "created_at": { + "type": "string", + "format": "date-time", + "description": "Timestamp of when the organization was created" + }, + "updated_at":
{ + "type": "string", + "format": "date-time", + "description": "Timestamp of when the organization was last updated" + }, + "owner": { + "type": "integer", + "description": "Identifier of the organization's owner" + }, + "members": { + "type": "array", + "items": { + "type": "integer" + }, + "description": "List of member identifiers belonging to the organization" + } + } + } + } + } + }, + "404": { + "description": "Organization not found", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "Organization not found" + } + } + } + } + } + } + } + }, + "delete": { + "tags": [ + "organizations" + ], + "summary": "Delete an organization", + "description": "Delete an organization by its ID.", + "operationId": "delete_organization", + "parameters": [ + { + "name": "org_id", + "in": "path", + "required": true, + "description": "Unique identifier of the organization to delete", + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Organization deleted successfully!", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "Organization deleted successfully!" + } + } + } + } + } + }, + "403": { + "description": "Unauthorized", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "Unauthorized" + } + } + } + } + } + }, + "404": { + "description": "Organization not found", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "Organization not found" + } + } + } + } + } + } + } + } + }, + "/api/v1/orgs/organizations/{org_id}/members/": { + "get": { + "tags": [ + "organizations" + ], + "summary": "Get organization members", + "description": "Retrieve a list of members for a specific organization.", + "operationId": "get_organization_members", + "parameters": [ + { + "name": "org_id", + "in": "path", + "required": true, + "description": "Unique identifier of the organization", + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "members": { + "type": "array", + "items": { + "type": "object", + "properties": { + "user_id": { + "type": "string", + "description": "Unique identifier of the member" + }, + "role": { + "type": "string", + "description": "Role of the member in the organization" + } + } + }, + "description": "List of members belonging to the organization" + } + } + } + } + } + }, + "404": { + "description": "Organization not found", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "Organization not found" + } + } + } + } + } + } + } + }, + "put": { + "tags": [ + "organizations" + ], + "summary": "Update organization member role", + "description": "Update the role of an existing member in a specific organization.", + "operationId": "update_organization_member_role", + "parameters": [ + { + "name": "org_id", + "in": "path", + "required": true, + "description": "Unique identifier of the organization", + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "required": [ + 
"username", + "role" + ], + "properties": { + "username": { + "type": "string", + "description": "Username of the member whose role is to be updated" + }, + "role": { + "type": "string", + "description": "New role of the member in the organization" + } + } + } + } + } + }, + "responses": { + "200": { + "description": "User role updated successfully", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "User role updated successfully" + } + } + } + } + } + }, + "400": { + "description": "Bad request", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "errors": { + "type": "object", + "description": "Errors found in the payload", + "additionalProperties": { + "type": "array", + "items": { + "type": "string" + } + } + } + } + } + } + } + }, + "404": { + "description": "Organization not found", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "Organization not found" + } + } + } + } + } + } + } + }, + "post": { + "tags": [ + "organizations" + ], + "summary": "Add organization member", + "description": "Add a new member to a specific organization.", + "operationId": "add_organization_member", + "parameters": [ + { + "name": "org_id", + "in": "path", + "required": true, + "description": "Unique identifier of the organization", + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "required": [ + "username", + "role" + ], + "properties": { + "username": { + "type": "string", + "description": "Username of the member to be added" + }, + "role": { + "type": "string", + "description": "Role of the member in the organization" + } + } + } + } + } + }, + "responses": { + "201": { + "description": "Member added successfully", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "User added to the organization." 
+ } + } + } + } + } + }, + "400": { + "description": "Bad request", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "errors": { + "type": "object", + "description": "Errors found in the payload", + "additionalProperties": { + "type": "array", + "items": { + "type": "string" + } + } + } + } + } + } + } + }, + "404": { + "description": "Organization not found", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "Organization not found" + } + } + } + } + } + } + } + }, + "delete": { + "tags": [ + "organizations" + ], + "summary": "Remove a member from the organization", + "operationId": "remove_organization_member", + "parameters": [ + { + "name": "org_id", + "in": "path", + "required": true, + "description": "Unique identifier of the organization", + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "required": [ + "username" + ], + "properties": { + "username": { + "type": "string", + "description": "Username of the member to be removed" + } + } + } + } + } + }, + "responses": { + "200": { + "description": "Member removed successfully", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "User removed from organization." + } + } + } + } + } + }, + "404": { + "description": "Organization not found", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "Organization not found" + } + } + } + } + } + } + } + } + }, + "/api/v1/orgs/organizations/{org_id}/projects/": { + "get": { + "tags": [ + "projects" + ], + "summary": "Get projects", + "description": "Retrieve a list of projects for a specific organization.", + "operationId": "get_projects", + "parameters": [ + { + "name": "org_id", + "in": "path", + "required": true, + "description": "Unique identifier of the organization", + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": { + "type": "integer", + "description": "Unique numeric identifier of the project" + }, + "project_id": { + "type": "string", + "description": "Unique string identifier of the project" + }, + "name": { + "type": "string", + "description": "Name of the project" + }, + "description": { + "type": "string", + "description": "Description of the project" + }, + "created_at": { + "type": "string", + "format": "date-time", + "description": "Timestamp of when the project was created" + }, + "updated_at": { + "type": "string", + "format": "date-time", + "description": "Timestamp of when the project was last updated" + }, + "members": { + "type": "array", + "items": { + "type": "object", + "properties": { + "username": { + "type": "string", + "description": "Username of the project member" + }, + "role": { + "type": "string", + "description": "Role of the member in the project" + } + } + }, + "description": "List of members belonging to the project" + } + } + } + } + } + } + } + } + }, + "post": { + "tags": [ + "projects" + ], + "summary": "Create project", + "description": "Create a new project within an organization.", + "operationId": "create_project", + "parameters": [ + { + "name": 
"org_id", + "in": "path", + "required": true, + "description": "Unique identifier of the organization", + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "required": [ + "name" + ], + "properties": { + "name": { + "type": "string", + "description": "Name of the project to be created" + } + } + } + } + } + }, + "responses": { + "200": { + "description": "Project created successfully", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "Project created successfully." + } + } + } + } + } + }, + "403": { + "description": "Unauthorized", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "Unauthorized to create projects in this organization." + } + } + } + } + } + }, + "400": { + "description": "Bad request", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "Project could not be created." + } + } + } + } + } + } + } + } + }, + "/api/v1/orgs/organizations/{org_id}/projects/{project_id}/": { + "get": { + "tags": [ + "projects" + ], + "summary": "Get project details", + "description": "Retrieve details of a specific project within an organization.", + "operationId": "get_project", + "parameters": [ + { + "name": "org_id", + "in": "path", + "required": true, + "description": "Unique identifier of the organization", + "schema": { + "type": "string" + } + }, + { + "name": "project_id", + "in": "path", + "required": true, + "description": "Unique identifier of the project", + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "id": { + "type": "integer", + "description": "Unique numeric identifier of the project" + }, + "project_id": { + "type": "string", + "description": "Unique string identifier of the project" + }, + "name": { + "type": "string", + "description": "Name of the project" + }, + "description": { + "type": "string", + "description": "Description of the project" + }, + "created_at": { + "type": "string", + "format": "date-time", + "description": "Timestamp of when the project was created" + }, + "updated_at": { + "type": "string", + "format": "date-time", + "description": "Timestamp of when the project was last updated" + }, + "members": { + "type": "array", + "items": { + "type": "object", + "properties": { + "username": { + "type": "string", + "description": "Username of the project member" + }, + "role": { + "type": "string", + "description": "Role of the member in the project" + } + } + }, + "description": "List of members belonging to the project" + } + } + } + } + } + }, + "404": { + "description": "Organization or project not found", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "Organization or project not found" + } + } + } + } + } + } + } + }, + "delete": { + "tags": [ + "projects" + ], + "summary": "Delete Project", + "description": "Delete a specific project and its related data.", + "operationId": "delete_project", + "parameters": [ + { + "name": "org_id", + "in": "path", + "required": true, + "description": "Unique identifier of the organization", + "schema": { + 
"type": "string" + } + }, + { + "name": "project_id", + "in": "path", + "required": true, + "description": "Unique identifier of the project to be deleted", + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Project and related data deleted successfully.", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "Project and related data deleted successfully." + } + } + } + } + } + }, + "404": { + "description": "Organization or project not found", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "Organization or project not found" + } + } + } + } + } + }, + "403": { + "description": "Unauthorized to modify this project", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "Unauthorized to modify this project." + } + } + } + } + } + } + } + } + }, + "/api/v1/orgs/organizations/{org_id}/projects/{project_id}/members/": { + "get": { + "tags": [ + "projects" + ], + "summary": "Get Project Members", + "description": "Retrieve a list of members for a specific project.", + "operationId": "get_project_members", + "parameters": [ + { + "name": "org_id", + "in": "path", + "required": true, + "description": "Unique identifier of the organization", + "schema": { + "type": "string" + } + }, + { + "name": "project_id", + "in": "path", + "required": true, + "description": "Unique identifier of the project", + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Successfully retrieved project members", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "members": { + "type": "array", + "items": { + "type": "object", + "properties": { + "username": { + "type": "string" + }, + "role": { + "type": "string" + } + } + } + } + } + } + } + } + }, + "404": { + "description": "Organization or project not found", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "Organization or project not found" + } + } + } + } + } + } + } + }, + "post": { + "tags": [ + "projects" + ], + "summary": "Add member to project", + "description": "Add a new member to a specific project within an organization.", + "operationId": "add_project_member", + "parameters": [ + { + "name": "org_id", + "in": "path", + "required": true, + "description": "Unique identifier of the organization", + "schema": { + "type": "string" + } + }, + { + "name": "project_id", + "in": "path", + "required": true, + "description": "Unique identifier of the project", + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "required": [ + "username", + "role" + ], + "properties": { + "username": { + "type": "string", + "description": "Username of the member to be added" + }, + "role": { + "type": "string", + "description": "Role of the member in the project" + } + } + } + } + } + }, + "responses": { + "200": { + "description": "User added to the project successfully", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "User added to the project successfully." 
+ } + } + } + } + } + }, + "403": { + "description": "Unauthorized to modify project members", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "Unauthorized to modify project members." + } + } + } + } + } + }, + "404": { + "description": "Organization or project not found", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "Organization or project not found" + } + } + } + } + } + } + } + }, + "put": { + "tags": [ + "projects" + ], + "summary": "Update project member role", + "description": "Update the role of a member in a specific project within an organization.", + "operationId": "update_project_member", + "parameters": [ + { + "name": "org_id", + "in": "path", + "required": true, + "description": "Unique identifier of the organization", + "schema": { + "type": "string" + } + }, + { + "name": "project_id", + "in": "path", + "required": true, + "description": "Unique identifier of the project", + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "required": [ + "username", + "role" + ], + "properties": { + "username": { + "type": "string", + "description": "Username of the member to be updated" + }, + "role": { + "type": "string", + "description": "New role of the member in the project" + } + } + } + } + } + }, + "responses": { + "200": { + "description": "User role updated successfully", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "User role updated successfully." + } + } + } + } + } + }, + "403": { + "description": "Unauthorized to modify project members", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "Unauthorized to modify project members." + } + } + } + } + } + }, + "404": { + "description": "Organization or project not found", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "Organization or project not found" + } + } + } + } + } + } + } + }, + "delete": { + "summary": "Delete Project Member", + "operationId": "deleteProjectMember", + "tags": ["Project"], + "parameters": [ + { + "name": "org_id", + "in": "path", + "required": true, + "description": "Unique identifier of the organization", + "schema": { + "type": "string" + } + }, + { + "name": "project_id", + "in": "path", + "required": true, + "description": "Unique identifier of the project", + "schema": { + "type": "string" + } + }, + { + "name": "username", + "in": "query", + "required": true, + "description": "Username of the member to be removed", + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Member removed from the project successfully", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "Member removed from the project" + } + } + } + } + } + }, + "403": { + "description": "Unauthorized to modify project members", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "Unauthorized to modify project members." 
+ } + } + } + } + } + }, + "404": { + "description": "Organization or project not found", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "Organization or project not found" + } + } + } + } + } + } + } + } + } + }, + "components": { + "schemas": { + "CreateAgent": { + "required": [ + "agent_id" + ], + "type": "object", + "properties": { + "agent_id": { + "title": "Agent id", + "minLength": 1, + "type": "string" + }, + "name": { + "title": "Name", + "minLength": 1, + "type": "string" + }, + "metadata": { + "title": "Metadata", + "type": "object", + "properties": { + + } + } + } + }, + "CreateApp": { + "required": [ + "app_id" + ], + "type": "object", + "properties": { + "app_id": { + "title": "App id", + "minLength": 1, + "type": "string" + }, + "name": { + "title": "Name", + "minLength": 1, + "type": "string" + }, + "metadata": { + "title": "Metadata", + "type": "object", + "properties": { + + } } } }, + "MemoryInput": { + "type": "object", + "properties": { + "messages": { + "description": "An array of message objects representing the content of the memory. Each message object typically contains 'role' and 'content' fields, where 'role' indicates the sender (e.g., 'user', 'assistant', 'system') and 'content' contains the actual message text. This structure allows for the representation of conversations or multi-part memories.", + "type": "array", + "items": { + "type": "object", + "additionalProperties": { + "type": "string", + "nullable": true + } + } + }, + "agent_id": { + "description": "The unique identifier of the agent associated with this memory.", + "title": "Agent id", + "type": "string", + "nullable": true + }, + "user_id": { + "description": "The unique identifier of the user associated with this memory.", + "title": "User id", + "type": "string", + "nullable": true + }, + "app_id": { + "description": "The unique identifier of the application associated with this memory.", + "title": "App id", + "type": "string", + "nullable": true + }, + "run_id": { + "description": "The unique identifier of the run associated with this memory.", + "title": "Run id", + "type": "string", + "nullable": true + }, + "metadata": { + "description": "Additional metadata associated with the memory, which can be used to store any additional information or context about the memory.", + "title": "Metadata", + "type": "object", + "properties": { + + }, + "nullable": true + }, + "includes": { + "description": "String to include the specific preferences in the memory.", + "title": "Includes", + "minLength": 1, + "type": "string", + "nullable": true + }, + "excludes": { + "description": "String to exclude the specific preferences in the memory.", + "title": "Excludes", + "minLength": 1, + "type": "string", + "nullable": true + }, + "infer": { + "description": "Whether to infer the memories or directly store the messages.", + "title": "Infer", + "type": "boolean", + "default": true + }, + "custom_categories": { + "description": "A list of categories with category name and its description.", + "title": "Custom categories", + "type": "object", + "properties": { + + }, + "nullable": true + }, + "org_name": { + "description": "The name of the organization associated with this memory.", + "title": "Organization name", + "type": "string", + "nullable": true + }, + "project_name": { + "description": "The name of the project associated with this memory.", + "title": "Project name", + "type": "string", + "nullable": true + } + }
+ }, "MemorySearchInput": { "required": [ "query" @@ -1399,6 +2869,18 @@ "type": "boolean", "default": false, "description": "Whether to rerank the memories." + }, + "org_name": { + "title": "Organization Name", + "type": "string", + "nullable": true, + "description": "The name of the organization associated with the memory." + }, + "project_name": { + "title": "Project Name", + "type": "string", + "nullable": true, + "description": "The name of the project associated with the memory." } } }, @@ -1441,6 +2923,18 @@ "type": "boolean", "default": false, "description": "Whether to rerank the memories." + }, + "org_name": { + "title": "Organization Name", + "type": "string", + "nullable": true, + "description": "The name of the organization associated with the memory." + }, + "project_name": { + "title": "Project Name", + "type": "string", + "nullable": true, + "description": "The name of the project associated with the memory." } } }, diff --git a/docs/platform/quickstart.mdx b/docs/platform/quickstart.mdx index 95d34605af..6732253f68 100644 --- a/docs/platform/quickstart.mdx +++ b/docs/platform/quickstart.mdx @@ -130,10 +130,10 @@ messages = [ ] # The default output_format is v1.0 -client.add(messages, user_id="alex123", session_id="trip-planning-2024", output_format="v1.0") +client.add(messages, user_id="alex123", run_id="trip-planning-2024", output_format="v1.0") # To use the latest output_format, set the output_format parameter to "v1.1" -client.add(messages, user_id="alex123", session_id="trip-planning-2024", output_format="v1.1") +client.add(messages, user_id="alex123", run_id="trip-planning-2024", output_format="v1.1") ``` ```javascript JavaScript @@ -143,7 +143,7 @@ const messages = [ {"role": "user", "content": "Yes, please! Especially in Tokyo."}, {"role": "assistant", "content": "Great! I'll remember that you're interested in vegetarian restaurants in Tokyo for your upcoming trip. I'll prepare a list for you in our next interaction."} ]; -client.add(messages, { user_id: "alex123", session_id: "trip-planning-2024", output_format: "v1.1" }) +client.add(messages, { user_id: "alex123", run_id: "trip-planning-2024", output_format: "v1.1" }) .then(response => console.log(response)) .catch(error => console.error(error)); ``` @@ -160,7 +160,7 @@ curl -X POST "https://api.mem0.ai/v1/memories/" \ {"role": "assistant", "content": "Great! I'll remember that you're interested in vegetarian restaurants in Tokyo for your upcoming trip. I'll prepare a list for you in our next interaction."} ], "user_id": "alex123", - "session_id": "trip-planning-2024", + "run_id": "trip-planning-2024", "output_format": "v1.1" }' ``` @@ -186,7 +186,9 @@ curl -X POST "https://api.mem0.ai/v1/memories/" \ - + + Please use `run_id` instead of `session_id`. The `session_id` parameter is deprecated and will be removed in version 0.1.20. + #### Long-term memory for agents @@ -527,7 +529,7 @@ curl -X POST "https://api.mem0.ai/v1/memories/search/?version=v2" \ ### 4.3 Get All Users -Get all users, agents, and sessions which have memories associated with them. +Get all users, agents, and runs which have memories associated with them. @@ -579,7 +581,7 @@ curl -X GET "https://api.mem0.ai/v1/entities/" \ ### 4.4 Get All Memories -Fetch all memories for a user, agent, or session using the getAll() method. +Fetch all memories for a user, agent, or run using the getAll() method. The `get_all` method supports two output formats: `v1.0` (default) and `v1.1`. 
To use the latest format, which provides more detailed information about each memory operation, set the `output_format` parameter to `v1.1`: @@ -735,20 +737,20 @@ ```python Python # The default output_format is v1.0 -short_term_memories = client.get_all(user_id="alex123", session_id="trip-planning-2024", output_format="v1.0") +short_term_memories = client.get_all(user_id="alex123", run_id="trip-planning-2024", output_format="v1.0") # To use the latest output_format (v1.1), set the output_format parameter to "v1.1" -short_term_memories = client.get_all(user_id="alex123", session_id="trip-planning-2024", output_format="v1.1") +short_term_memories = client.get_all(user_id="alex123", run_id="trip-planning-2024", output_format="v1.1") ``` ```javascript JavaScript -client.getAll({ user_id: "alex123", session_id: "trip-planning-2024", output_format: "v1.1" }) +client.getAll({ user_id: "alex123", run_id: "trip-planning-2024", output_format: "v1.1" }) .then(memories => console.log(memories)) .catch(error => console.error(error)); ``` ```bash cURL -curl -X GET "https://api.mem0.ai/v1/memories/?user_id=alex123&session_id=trip-planning-2024&output_format=v1.1" \ +curl -X GET "https://api.mem0.ai/v1/memories/?user_id=alex123&run_id=trip-planning-2024&output_format=v1.1" \ -H "Authorization: Token your-api-key" ``` @@ -1050,7 +1052,7 @@ client.delete_users() ``` ```json Output -{'message': 'All users, agents, and sessions deleted.'} +{'message': 'All users, agents, and runs deleted.'} ``` diff --git a/mem0/client/main.py b/mem0/client/main.py index 737c96e809..e94c6f836e 100644 --- a/mem0/client/main.py +++ b/mem0/client/main.py @@ -1,5 +1,6 @@ import logging import os +import warnings from functools import wraps from typing import Any, Dict, List, Optional, Union @@ -9,6 +10,11 @@ from mem0.memory.telemetry import capture_client_event logger = logging.getLogger(__name__) +warnings.filterwarnings( + "always", + category=DeprecationWarning, + message="The 'session_id' parameter is deprecated", +) # Setup user config setup_config() @@ -80,14 +86,10 @@ def _validate_api_key(self): response = self.client.get("/v1/memories/", params={"user_id": "test"}) response.raise_for_status() except httpx.HTTPStatusError: - raise ValueError( - "Invalid API Key. Please get a valid API Key from https://app.mem0.ai" - ) + raise ValueError("Invalid API Key. Please get a valid API Key from https://app.mem0.ai") @api_error_handler - def add( - self, messages: Union[str, List[Dict[str, str]]], **kwargs - ) -> Dict[str, Any]: + def add(self, messages: Union[str, List[Dict[str, str]]], **kwargs) -> Dict[str, Any]: """Add a new memory. Args: @@ -251,9 +253,7 @@ def delete_users(self) -> Dict[str, str]: """Delete all users, agents, or sessions.""" entities = self.users() for entity in entities["results"]: - response = self.client.delete( - f"/v1/entities/{entity['type']}/{entity['id']}/" - ) + response = self.client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/") response.raise_for_status() capture_client_event("client.delete_users", self) @@ -303,6 +303,17 @@ def _prepare_payload( payload["messages"] = [{"role": "user", "content": messages}] elif isinstance(messages, list): payload["messages"] = messages + + # Handle session_id deprecation + if "session_id" in kwargs: + warnings.warn( + "The 'session_id' parameter is deprecated and will be removed in version 0.1.20. 
" + "Use 'run_id' instead.", + DeprecationWarning, + stacklevel=2, + ) + kwargs["run_id"] = kwargs.pop("session_id") + payload.update({k: v for k, v in kwargs.items() if v is not None}) return payload @@ -315,4 +326,15 @@ def _prepare_params(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: Returns: A dictionary containing the prepared parameters. """ + + # Handle session_id deprecation + if "session_id" in kwargs: + warnings.warn( + "The 'session_id' parameter is deprecated and will be removed in version 0.1.20. " + "Use 'run_id' instead.", + DeprecationWarning, + stacklevel=2, + ) + kwargs["run_id"] = kwargs.pop("session_id") + return {k: v for k, v in kwargs.items() if v is not None} diff --git a/mem0/configs/base.py b/mem0/configs/base.py index 42a0a2b248..55e09f2757 100644 --- a/mem0/configs/base.py +++ b/mem0/configs/base.py @@ -17,18 +17,10 @@ class MemoryItem(BaseModel): ) # TODO After prompt changes from platform, update this hash: Optional[str] = Field(None, description="The hash of the memory") # The metadata value can be anything and not just string. Fix it - metadata: Optional[Dict[str, Any]] = Field( - None, description="Additional metadata for the text data" - ) - score: Optional[float] = Field( - None, description="The score associated with the text data" - ) - created_at: Optional[str] = Field( - None, description="The timestamp when the memory was created" - ) - updated_at: Optional[str] = Field( - None, description="The timestamp when the memory was updated" - ) + metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata for the text data") + score: Optional[float] = Field(None, description="The score associated with the text data") + created_at: Optional[str] = Field(None, description="The timestamp when the memory was created") + updated_at: Optional[str] = Field(None, description="The timestamp when the memory was updated") class MemoryConfig(BaseModel): @@ -56,7 +48,11 @@ class MemoryConfig(BaseModel): description="The version of the API", default="v1.0", ) - + custom_prompt: Optional[str] = Field( + description="Custom prompt for the memory", + default=None, + ) + class AzureConfig(BaseModel): """ @@ -69,7 +65,10 @@ class AzureConfig(BaseModel): api_version (str): The version of the Azure API being used. 
""" - api_key: str = Field(description="The API key used for authenticating with the Azure service.", default=None) - azure_deployment : str = Field(description="The name of the Azure deployment.", default=None) - azure_endpoint : str = Field(description="The endpoint URL for the Azure service.", default=None) - api_version : str = Field(description="The version of the Azure API being used.", default=None) + api_key: str = Field( + description="The API key used for authenticating with the Azure service.", + default=None, + ) + azure_deployment: str = Field(description="The name of the Azure deployment.", default=None) + azure_endpoint: str = Field(description="The endpoint URL for the Azure service.", default=None) + api_version: str = Field(description="The version of the Azure API being used.", default=None) diff --git a/mem0/configs/embeddings/base.py b/mem0/configs/embeddings/base.py index f4659dce3d..6324587297 100644 --- a/mem0/configs/embeddings/base.py +++ b/mem0/configs/embeddings/base.py @@ -60,6 +60,6 @@ def __init__( # Huggingface specific self.model_kwargs = model_kwargs or {} - + # AzureOpenAI specific self.azure_kwargs = AzureConfig(**azure_kwargs) or {} diff --git a/mem0/configs/prompts.py b/mem0/configs/prompts.py index d0e07863e8..d9192129a4 100644 --- a/mem0/configs/prompts.py +++ b/mem0/configs/prompts.py @@ -32,16 +32,16 @@ Output: {{"facts" : []}} Input: Hi, I am looking for a restaurant in San Francisco. -Output: {{"facts" : ['Looking for a restaurant in San Francisco']}} +Output: {{"facts" : ["Looking for a restaurant in San Francisco"]}} Input: Yesterday, I had a meeting with John at 3pm. We discussed the new project. -Output: {{"facts" : ['Had a meeting with John at 3pm', 'Discussed the new project']}} +Output: {{"facts" : ["Had a meeting with John at 3pm", "Discussed the new project"]}} Input: Hi, my name is John. I am a software engineer. -Output: {{"facts" : ['Name is John', 'Is a Software engineer']}} +Output: {{"facts" : ["Name is John", "Is a Software engineer"]}} Input: Me favourite movies are Inception and Interstellar. -Output: {{"facts" : ['Favourite movies are Inception and Interstellar']}} +Output: {{"facts" : ["Favourite movies are Inception and Interstellar"]}} Return the facts and preferences in a json format as shown above. @@ -59,6 +59,7 @@ If you do not find anything relevant facts, user memories, and preferences in the below conversation, you can return an empty list corresponding to the "facts" key. """ + def get_update_memory_messages(retrieved_old_memory_dict, response_content): return f"""You are a smart memory manager which controls the memory of a system. You can perform four operations: (1) add into the memory, (2) update the memory, (3) delete from the memory, and (4) no change. 
@@ -82,7 +83,7 @@ def get_update_memory_messages(retrieved_old_memory_dict, response_content): "text" : "User is a software engineer" }} ] - - Retrieved facts: ['Name is John'] + - Retrieved facts: ["Name is John"] - New Memory: {{ "memory" : [ @@ -123,7 +124,7 @@ def get_update_memory_messages(retrieved_old_memory_dict, response_content): "text" : "User likes to play cricket" }} ] - - Retrieved facts: ['Loves chicken pizza', 'Loves to play cricket with friends'] + - Retrieved facts: ["Loves chicken pizza", "Loves to play cricket with friends"] - New Memory: {{ "memory" : [ @@ -161,7 +162,7 @@ def get_update_memory_messages(retrieved_old_memory_dict, response_content): "text" : "Loves cheese pizza" }} ] - - Retrieved facts: ['Dislikes cheese pizza'] + - Retrieved facts: ["Dislikes cheese pizza"] - New Memory: {{ "memory" : [ @@ -191,7 +192,7 @@ def get_update_memory_messages(retrieved_old_memory_dict, response_content): "text" : "Loves cheese pizza" }} ] - - Retrieved facts: ['Name is John'] + - Retrieved facts: ["Name is John"] - New Memory: {{ "memory" : [ diff --git a/mem0/configs/vector_stores/chroma.py b/mem0/configs/vector_stores/chroma.py index ba58b3ac45..afff8a8f7b 100644 --- a/mem0/configs/vector_stores/chroma.py +++ b/mem0/configs/vector_stores/chroma.py @@ -9,23 +9,11 @@ class ChromaDbConfig(BaseModel): try: from chromadb.api.client import Client except ImportError: - user_input: Any = input("The 'chromadb' library is required. Install it now? [y/N]: ") - if user_input.lower() == 'y': - try: - subprocess.check_call([sys.executable, "-m", "pip", "install", "chromadb"]) - from chromadb.api.client import Client - except subprocess.CalledProcessError: - print("Failed to install 'chromadb'. Please install it manually using 'pip install chromadb'.") - sys.exit(1) - else: - print("The required 'chromadb' library is not installed.") - sys.exit(1) + raise ImportError("The 'chromadb' library is required. Please install it using 'pip install chromadb'.") Client: ClassVar[type] = Client collection_name: str = Field("mem0", description="Default name for the collection") - client: Optional[Client] = Field( - None, description="Existing ChromaDB client instance" - ) + client: Optional[Client] = Field(None, description="Existing ChromaDB client instance") path: Optional[str] = Field(None, description="Path to the database directory") host: Optional[str] = Field(None, description="Database connection remote host") port: Optional[int] = Field(None, description="Database connection remote port") diff --git a/mem0/configs/vector_stores/milvus.py b/mem0/configs/vector_stores/milvus.py new file mode 100644 index 0000000000..7578c6fcef --- /dev/null +++ b/mem0/configs/vector_stores/milvus.py @@ -0,0 +1,43 @@ +from enum import Enum +from typing import Any, Dict + +from pydantic import BaseModel, Field, model_validator + + +class MetricType(str, Enum): + """ + Metric Constant for milvus/ zilliz server. 
+ """ + + def __str__(self) -> str: + return str(self.value) + + L2 = "L2" + IP = "IP" + COSINE = "COSINE" + HAMMING = "HAMMING" + JACCARD = "JACCARD" + + +class MilvusDBConfig(BaseModel): + url: str = Field("http://localhost:19530", description="Full URL for Milvus/Zilliz server") + token: str = Field(None, description="Token for Zilliz server / local setup defaults to None.") + collection_name: str = Field("mem0", description="Name of the collection") + embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model") + metric_type: str = Field("L2", description="Metric type for similarity search") + + @model_validator(mode="before") + @classmethod + def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]: + allowed_fields = set(cls.model_fields.keys()) + input_fields = set(values.keys()) + extra_fields = input_fields - allowed_fields + if extra_fields: + raise ValueError( + f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}" + ) + return values + + model_config = { + "arbitrary_types_allowed": True, + } diff --git a/mem0/configs/vector_stores/pgvector.py b/mem0/configs/vector_stores/pgvector.py index df8dabf4c4..b81ed9859d 100644 --- a/mem0/configs/vector_stores/pgvector.py +++ b/mem0/configs/vector_stores/pgvector.py @@ -4,12 +4,9 @@ class PGVectorConfig(BaseModel): - dbname: str = Field("postgres", description="Default name for the database") collection_name: str = Field("mem0", description="Default name for the collection") - embedding_model_dims: Optional[int] = Field( - 1536, description="Dimensions of the embedding model" - ) + embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model") user: Optional[str] = Field(None, description="Database user") password: Optional[str] = Field(None, description="Database password") host: Optional[str] = Field(None, description="Database host. 
Default is localhost") diff --git a/mem0/configs/vector_stores/qdrant.py b/mem0/configs/vector_stores/qdrant.py index 10951db8b1..f8628d332e 100644 --- a/mem0/configs/vector_stores/qdrant.py +++ b/mem0/configs/vector_stores/qdrant.py @@ -9,17 +9,11 @@ class QdrantConfig(BaseModel): QdrantClient: ClassVar[type] = QdrantClient collection_name: str = Field("mem0", description="Name of the collection") - embedding_model_dims: Optional[int] = Field( - 1536, description="Dimensions of the embedding model" - ) - client: Optional[QdrantClient] = Field( - None, description="Existing Qdrant client instance" - ) + embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model") + client: Optional[QdrantClient] = Field(None, description="Existing Qdrant client instance") host: Optional[str] = Field(None, description="Host address for Qdrant server") port: Optional[int] = Field(None, description="Port for Qdrant server") - path: Optional[str] = Field( - "/tmp/qdrant", description="Path for local Qdrant database" - ) + path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database") url: Optional[str] = Field(None, description="Full URL for Qdrant server") api_key: Optional[str] = Field(None, description="API key for Qdrant server") on_disk: Optional[bool] = Field(False, description="Enables persistent storage") @@ -35,9 +29,7 @@ def check_host_port_or_path(cls, values: Dict[str, Any]) -> Dict[str, Any]: values.get("api_key"), ) if not path and not (host and port) and not (url and api_key): - raise ValueError( - "Either 'host' and 'port' or 'url' and 'api_key' or 'path' must be provided." - ) + raise ValueError("Either 'host' and 'port' or 'url' and 'api_key' or 'path' must be provided.") return values @model_validator(mode="before") diff --git a/mem0/embeddings/azure_openai.py b/mem0/embeddings/azure_openai.py index 8e801ccd8e..d25cc00e45 100644 --- a/mem0/embeddings/azure_openai.py +++ b/mem0/embeddings/azure_openai.py @@ -15,14 +15,14 @@ def __init__(self, config: Optional[BaseEmbedderConfig] = None): azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv("EMBEDDING_AZURE_DEPLOYMENT") azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv("EMBEDDING_AZURE_ENDPOINT") api_version = self.config.azure_kwargs.api_version or os.getenv("EMBEDDING_AZURE_API_VERSION") - + self.client = AzureOpenAI( - azure_deployment=azure_deployment, + azure_deployment=azure_deployment, azure_endpoint=azure_endpoint, api_version=api_version, api_key=api_key, - http_client=self.config.http_client - ) + http_client=self.config.http_client, + ) def embed(self, text): """ @@ -35,8 +35,4 @@ def embed(self, text): list: The embedding vector. 
""" text = text.replace("\n", " ") - return ( - self.client.embeddings.create(input=[text], model=self.config.model) - .data[0] - .embedding - ) + return self.client.embeddings.create(input=[text], model=self.config.model).data[0].embedding diff --git a/mem0/embeddings/configs.py b/mem0/embeddings/configs.py index 73aa9b30ad..213493440b 100644 --- a/mem0/embeddings/configs.py +++ b/mem0/embeddings/configs.py @@ -8,14 +8,12 @@ class EmbedderConfig(BaseModel): description="Provider of the embedding model (e.g., 'ollama', 'openai')", default="openai", ) - config: Optional[dict] = Field( - description="Configuration for the specific embedding model", default={} - ) + config: Optional[dict] = Field(description="Configuration for the specific embedding model", default={}) @field_validator("config") def validate_config(cls, v, values): provider = values.data.get("provider") - if provider in ["openai", "ollama", "huggingface", "azure_openai"]: + if provider in ["openai", "ollama", "huggingface", "azure_openai", "vertexai"]: return v else: raise ValueError(f"Unsupported embedding provider: {provider}") diff --git a/mem0/embeddings/ollama.py b/mem0/embeddings/ollama.py index 2e7f375879..ae00368e01 100644 --- a/mem0/embeddings/ollama.py +++ b/mem0/embeddings/ollama.py @@ -9,7 +9,7 @@ from ollama import Client except ImportError: user_input = input("The 'ollama' library is required. Install it now? [y/N]: ") - if user_input.lower() == 'y': + if user_input.lower() == "y": try: subprocess.check_call([sys.executable, "-m", "pip", "install", "ollama"]) from ollama import Client diff --git a/mem0/embeddings/openai.py b/mem0/embeddings/openai.py index be9195bf6b..b68b8ffc09 100644 --- a/mem0/embeddings/openai.py +++ b/mem0/embeddings/openai.py @@ -29,8 +29,4 @@ def embed(self, text): list: The embedding vector. """ text = text.replace("\n", " ") - return ( - self.client.embeddings.create(input=[text], model=self.config.model) - .data[0] - .embedding - ) + return self.client.embeddings.create(input=[text], model=self.config.model).data[0].embedding diff --git a/mem0/embeddings/vertexai.py b/mem0/embeddings/vertexai.py new file mode 100644 index 0000000000..bcdaaab284 --- /dev/null +++ b/mem0/embeddings/vertexai.py @@ -0,0 +1,40 @@ +import os +from typing import Optional + +from vertexai.language_models import TextEmbeddingModel + +from mem0.configs.embeddings.base import BaseEmbedderConfig +from mem0.embeddings.base import EmbeddingBase + + +class VertexAI(EmbeddingBase): + def __init__(self, config: Optional[BaseEmbedderConfig] = None): + super().__init__(config) + + self.config.model = self.config.model or "text-embedding-004" + self.config.embedding_dims = self.config.embedding_dims or 256 + + credentials_path = self.config.vertex_credentials_json + + if credentials_path: + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = credentials_path + elif not os.getenv("GOOGLE_APPLICATION_CREDENTIALS"): + raise ValueError( + "Google application credentials JSON is not provided. Please provide a valid JSON path or set the 'GOOGLE_APPLICATION_CREDENTIALS' environment variable." + ) + + self.model = TextEmbeddingModel.from_pretrained(self.config.model) + + def embed(self, text): + """ + Get the embedding for the given text using Vertex AI. + + Args: + text (str): The text to embed. + + Returns: + list: The embedding vector. 
+ """ + embeddings = self.model.get_embeddings(texts=[text], output_dimensionality=self.config.embedding_dims) + + return embeddings[0].values diff --git a/mem0/graphs/configs.py b/mem0/graphs/configs.py index 033637c32e..be5a2bae8d 100644 --- a/mem0/graphs/configs.py +++ b/mem0/graphs/configs.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union from pydantic import BaseModel, Field, field_validator, model_validator @@ -18,28 +18,36 @@ def check_host_port_or_path(cls, values): values.get("password"), ) if not url or not username or not password: + raise ValueError("Please provide 'url', 'username' and 'password'.") + return values + +class FalkorDBConfig(BaseModel): + host: Optional[str] = Field(None, description="Host address for the graph database") + username: Optional[str] = Field(None, description="Username for the graph database") + password: Optional[str] = Field(None, description="Password for the graph database") + port: Optional[int] = Field(None, description="Port for the graph database") + # Default database name is mandatory in langchain + database: str = "_default_" + + @model_validator(mode="before") + def check_host_port_or_path(cls, values): + host, port = ( + values.get("host"), + values.get("port"), + ) + if not host or not port: raise ValueError( - "Please provide 'url', 'username' and 'password'." + "Please provide 'host' and 'port'." ) return values class GraphStoreConfig(BaseModel): - provider: str = Field( - description="Provider of the data store (e.g., 'neo4j')", - default="neo4j" - ) - config: Neo4jConfig = Field( - description="Configuration for the specific data store", - default=None - ) - llm: Optional[LlmConfig] = Field( - description="LLM configuration for querying the graph store", - default=None - ) + provider: str = Field(description="Provider of the data store (e.g., 'falkordb', 'neo4j')", default="falkordb") + config: Union[FalkorDBConfig, Neo4jConfig] = Field(description="Configuration for the specific data store", default=None) + llm: Optional[LlmConfig] = Field(description="LLM configuration for querying the graph store", default=None) custom_prompt: Optional[str] = Field( - description="Custom prompt to fetch entities from the given text", - default=None + description="Custom prompt to fetch entities from the given text", default=None ) @field_validator("config") @@ -47,6 +55,11 @@ def validate_config(cls, v, values): provider = values.data.get("provider") if provider == "neo4j": return Neo4jConfig(**v.model_dump()) + elif provider == "falkordb": + config = v.model_dump() + # In case the user try to use diffrent database name + config["database"] = "_default_" + + return FalkorDBConfig(**config) else: raise ValueError(f"Unsupported graph store provider: {provider}") - diff --git a/mem0/graphs/tools.py b/mem0/graphs/tools.py index d727924298..1fdbe91faf 100644 --- a/mem0/graphs/tools.py +++ b/mem0/graphs/tools.py @@ -1,4 +1,3 @@ - UPDATE_MEMORY_TOOL_GRAPH = { "type": "function", "function": { @@ -9,21 +8,21 @@ "properties": { "source": { "type": "string", - "description": "The identifier of the source node in the relationship to be updated. This should match an existing node in the graph." + "description": "The identifier of the source node in the relationship to be updated. This should match an existing node in the graph.", }, "destination": { "type": "string", - "description": "The identifier of the destination node in the relationship to be updated. This should match an existing node in the graph." 
+ "description": "The identifier of the destination node in the relationship to be updated. This should match an existing node in the graph.", }, "relationship": { "type": "string", - "description": "The new or updated relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected." - } + "description": "The new or updated relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected.", + }, }, "required": ["source", "destination", "relationship"], - "additionalProperties": False - } - } + "additionalProperties": False, + }, + }, } ADD_MEMORY_TOOL_GRAPH = { @@ -36,29 +35,35 @@ "properties": { "source": { "type": "string", - "description": "The identifier of the source node in the new relationship. This can be an existing node or a new node to be created." + "description": "The identifier of the source node in the new relationship. This can be an existing node or a new node to be created.", }, "destination": { "type": "string", - "description": "The identifier of the destination node in the new relationship. This can be an existing node or a new node to be created." + "description": "The identifier of the destination node in the new relationship. This can be an existing node or a new node to be created.", }, "relationship": { "type": "string", - "description": "The type of relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected." + "description": "The type of relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected.", }, "source_type": { "type": "string", - "description": "The type or category of the source node. This helps in classifying and organizing nodes in the graph." + "description": "The type or category of the source node. This helps in classifying and organizing nodes in the graph.", }, "destination_type": { "type": "string", - "description": "The type or category of the destination node. This helps in classifying and organizing nodes in the graph." - } + "description": "The type or category of the destination node. This helps in classifying and organizing nodes in the graph.", + }, }, - "required": ["source", "destination", "relationship", "source_type", "destination_type"], - "additionalProperties": False - } - } + "required": [ + "source", + "destination", + "relationship", + "source_type", + "destination_type", + ], + "additionalProperties": False, + }, + }, } @@ -71,9 +76,9 @@ "type": "object", "properties": {}, "required": [], - "additionalProperties": False - } - } + "additionalProperties": False, + }, + }, } @@ -94,17 +99,23 @@ "source_type": {"type": "string"}, "relation": {"type": "string"}, "destination_node": {"type": "string"}, - "destination_type": {"type": "string"} + "destination_type": {"type": "string"}, }, - "required": ["source_node", "source_type", "relation", "destination_node", "destination_type"], - "additionalProperties": False - } + "required": [ + "source_node", + "source_type", + "relation", + "destination_node", + "destination_type", + ], + "additionalProperties": False, + }, } }, "required": ["entities"], - "additionalProperties": False - } - } + "additionalProperties": False, + }, + }, } @@ -118,23 +129,19 @@ "properties": { "nodes": { "type": "array", - "items": { - "type": "string" - }, - "description": "List of nodes to search for." 
+ "items": {"type": "string"}, + "description": "List of nodes to search for.", }, "relations": { "type": "array", - "items": { - "type": "string" - }, - "description": "List of relations to search for." - } + "items": {"type": "string"}, + "description": "List of relations to search for.", + }, }, "required": ["nodes", "relations"], - "additionalProperties": False - } - } + "additionalProperties": False, + }, + }, } UPDATE_MEMORY_STRUCT_TOOL_GRAPH = { @@ -148,21 +155,21 @@ "properties": { "source": { "type": "string", - "description": "The identifier of the source node in the relationship to be updated. This should match an existing node in the graph." + "description": "The identifier of the source node in the relationship to be updated. This should match an existing node in the graph.", }, "destination": { "type": "string", - "description": "The identifier of the destination node in the relationship to be updated. This should match an existing node in the graph." + "description": "The identifier of the destination node in the relationship to be updated. This should match an existing node in the graph.", }, "relationship": { "type": "string", - "description": "The new or updated relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected." - } + "description": "The new or updated relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected.", + }, }, "required": ["source", "destination", "relationship"], - "additionalProperties": False - } - } + "additionalProperties": False, + }, + }, } ADD_MEMORY_STRUCT_TOOL_GRAPH = { @@ -176,29 +183,35 @@ "properties": { "source": { "type": "string", - "description": "The identifier of the source node in the new relationship. This can be an existing node or a new node to be created." + "description": "The identifier of the source node in the new relationship. This can be an existing node or a new node to be created.", }, "destination": { "type": "string", - "description": "The identifier of the destination node in the new relationship. This can be an existing node or a new node to be created." + "description": "The identifier of the destination node in the new relationship. This can be an existing node or a new node to be created.", }, "relationship": { "type": "string", - "description": "The type of relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected." + "description": "The type of relationship between the source and destination nodes. This should be a concise, clear description of how the two nodes are connected.", }, "source_type": { "type": "string", - "description": "The type or category of the source node. This helps in classifying and organizing nodes in the graph." + "description": "The type or category of the source node. This helps in classifying and organizing nodes in the graph.", }, "destination_type": { "type": "string", - "description": "The type or category of the destination node. This helps in classifying and organizing nodes in the graph." - } + "description": "The type or category of the destination node. 
This helps in classifying and organizing nodes in the graph.", + }, }, - "required": ["source", "destination", "relationship", "source_type", "destination_type"], - "additionalProperties": False - } - } + "required": [ + "source", + "destination", + "relationship", + "source_type", + "destination_type", + ], + "additionalProperties": False, + }, + }, } @@ -212,9 +225,9 @@ "type": "object", "properties": {}, "required": [], - "additionalProperties": False - } - } + "additionalProperties": False, + }, + }, } @@ -236,17 +249,23 @@ "source_type": {"type": "string"}, "relation": {"type": "string"}, "destination_node": {"type": "string"}, - "destination_type": {"type": "string"} + "destination_type": {"type": "string"}, }, - "required": ["source_node", "source_type", "relation", "destination_node", "destination_type"], - "additionalProperties": False - } + "required": [ + "source_node", + "source_type", + "relation", + "destination_node", + "destination_type", + ], + "additionalProperties": False, + }, } }, "required": ["entities"], - "additionalProperties": False - } - } + "additionalProperties": False, + }, + }, } @@ -261,21 +280,17 @@ "properties": { "nodes": { "type": "array", - "items": { - "type": "string" - }, - "description": "List of nodes to search for." + "items": {"type": "string"}, + "description": "List of nodes to search for.", }, "relations": { "type": "array", - "items": { - "type": "string" - }, - "description": "List of relations to search for." - } + "items": {"type": "string"}, + "description": "List of relations to search for.", + }, }, "required": ["nodes", "relations"], - "additionalProperties": False - } - } + "additionalProperties": False, + }, + }, } diff --git a/mem0/graphs/utils.py b/mem0/graphs/utils.py index e9ed827eea..db637fe6d6 100644 --- a/mem0/graphs/utils.py +++ b/mem0/graphs/utils.py @@ -1,4 +1,3 @@ - UPDATE_GRAPH_PROMPT = """ You are an AI expert specializing in graph memory management and optimization. Your task is to analyze existing graph memories alongside new information, and update the relationships in the memory list to ensure the most accurate, current, and coherent representation of knowledge. 
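The next hunk extracts the node-similarity Cypher out of `graph_memory.py` into two named queries, `FALKORDB_QUERY` and `NEO4J_QUERY`. Both rank candidate nodes by cosine similarity between a stored `n.embedding` and the query embedding `$n_embedding`, computed with `reduce()` as a dot product divided by the two L2 norms. The same arithmetic in plain Python, as a reference sketch:

```python
import math

def cosine_similarity(a: list[float], b: list[float]) -> float:
    # dot(a, b) / (||a|| * ||b||), mirroring the reduce() expressions
    # in FALKORDB_QUERY / NEO4J_QUERY below
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)

# Only nodes whose similarity clears the $threshold parameter are returned.
assert abs(cosine_similarity([1.0, 0.0], [1.0, 0.0]) - 1.0) < 1e-9
```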
@@ -54,11 +53,57 @@ Adhere strictly to these guidelines to ensure high-quality knowledge graph extraction.""" - +FALKORDB_QUERY = """ +MATCH (n) +WHERE n.embedding IS NOT NULL AND n.user_id = $user_id +WITH n, + reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) / + (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) * + sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))) AS similarity +WHERE similarity >= $threshold +MATCH (n)-[r]->(m) +RETURN n.name AS source, Id(n) AS source_id, type(r) AS relation, Id(r) AS relation_id, m.name AS destination, Id(m) AS destination_id, similarity +UNION +MATCH (n) +WHERE n.embedding IS NOT NULL AND n.user_id = $user_id +WITH n, + reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) / + (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) * + sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))) AS similarity +WHERE similarity >= $threshold +MATCH (m)-[r]->(n) +RETURN m.name AS source, Id(m) AS source_id, type(r) AS relation, Id(r) AS relation_id, n.name AS destination, Id(n) AS destination_id, similarity +ORDER BY similarity DESC +""" + + +NEO4J_QUERY = """ +MATCH (n) +WHERE n.embedding IS NOT NULL AND n.user_id = $user_id +WITH n, + round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) / + (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) * + sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity +WHERE similarity >= $threshold +MATCH (n)-[r]->(m) +RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relation, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id, similarity +UNION +MATCH (n) +WHERE n.embedding IS NOT NULL AND n.user_id = $user_id +WITH n, + round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) / + (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) * + sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity +WHERE similarity >= $threshold +MATCH (m)-[r]->(n) +RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relation, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id, similarity +ORDER BY similarity DESC +""" def get_update_memory_prompt(existing_memories, memory, template): return template.format(existing_memories=existing_memories, memory=memory) + def get_update_memory_messages(existing_memories, memory): return [ { diff --git a/mem0/llms/anthropic.py b/mem0/llms/anthropic.py index fb390348cd..5f004ae8b6 100644 --- a/mem0/llms/anthropic.py +++ b/mem0/llms/anthropic.py @@ -4,7 +4,7 @@ try: import anthropic except ImportError: - raise ImportError("The 'anthropic' library is required. Please install it using 'pip install anthropic'.") + raise ImportError("The 'anthropic' library is required. 
Please install it using 'pip install anthropic'.") from mem0.configs.llms.base import BaseLlmConfig from mem0.llms.base import LLMBase @@ -43,8 +43,8 @@ def generate_response( system_message = "" filtered_messages = [] for message in messages: - if message['role'] == 'system': - system_message = message['content'] + if message["role"] == "system": + system_message = message["content"] else: filtered_messages.append(message) @@ -56,7 +56,7 @@ def generate_response( "max_tokens": self.config.max_tokens, "top_p": self.config.top_p, } - if tools: # TODO: Remove tools if no issues found with new memory addition logic + if tools: # TODO: Remove tools if no issues found with new memory addition logic params["tools"] = tools params["tool_choice"] = tool_choice diff --git a/mem0/llms/aws_bedrock.py b/mem0/llms/aws_bedrock.py index 5e7969c13e..2bc963c2b2 100644 --- a/mem0/llms/aws_bedrock.py +++ b/mem0/llms/aws_bedrock.py @@ -125,9 +125,7 @@ def _prepare_input( }, } input_body["textGenerationConfig"] = { - k: v - for k, v in input_body["textGenerationConfig"].items() - if v is not None + k: v for k, v in input_body["textGenerationConfig"].items() if v is not None } return input_body @@ -161,9 +159,7 @@ def _convert_tool_format(self, original_tools): } } - for prop, details in ( - function["parameters"].get("properties", {}).items() - ): + for prop, details in function["parameters"].get("properties", {}).items(): new_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop] = { "type": details.get("type", "string"), "description": details.get("description", ""), @@ -216,9 +212,7 @@ def generate_response( # Use invoke_model method when no tools are provided prompt = self._format_messages(messages) provider = self.model.split(".")[0] - input_body = self._prepare_input( - provider, self.config.model, prompt, **self.model_kwargs - ) + input_body = self._prepare_input(provider, self.config.model, prompt, **self.model_kwargs) body = json.dumps(input_body) response = self.client.invoke_model( diff --git a/mem0/llms/azure_openai.py b/mem0/llms/azure_openai.py index f093284b47..f1fe6863a7 100644 --- a/mem0/llms/azure_openai.py +++ b/mem0/llms/azure_openai.py @@ -15,20 +15,20 @@ def __init__(self, config: Optional[BaseLlmConfig] = None): # Model name should match the custom deployment name chosen for it. if not self.config.model: self.config.model = "gpt-4o" - + api_key = self.config.azure_kwargs.api_key or os.getenv("LLM_AZURE_OPENAI_API_KEY") azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv("LLM_AZURE_DEPLOYMENT") azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv("LLM_AZURE_ENDPOINT") api_version = self.config.azure_kwargs.api_version or os.getenv("LLM_AZURE_API_VERSION") self.client = AzureOpenAI( - azure_deployment=azure_deployment, + azure_deployment=azure_deployment, azure_endpoint=azure_endpoint, api_version=api_version, api_key=api_key, - http_client=self.config.http_client - ) - + http_client=self.config.http_client, + ) + def _parse_response(self, response, tools): """ Process the response based on whether tools are used or not. 
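For reference, each Azure setting above resolves from `azure_kwargs` first and falls back to an environment variable. A configuration sketch under those assumptions (placeholder values; the nesting follows `LlmConfig` and `AzureConfig` as shown elsewhere in this diff):

```python
import os

# Option 1: environment variables, read when azure_kwargs is not set
os.environ["LLM_AZURE_OPENAI_API_KEY"] = "..."
os.environ["LLM_AZURE_DEPLOYMENT"] = "my-gpt-4o-deployment"
os.environ["LLM_AZURE_ENDPOINT"] = "https://my-resource.openai.azure.com/"
os.environ["LLM_AZURE_API_VERSION"] = "2024-02-01"  # placeholder version

# Option 2: explicit config; the model name must match the Azure deployment name
config = {
    "llm": {
        "provider": "azure_openai",
        "config": {
            "model": "gpt-4o",
            "azure_kwargs": {
                "api_key": "...",
                "azure_deployment": "my-gpt-4o-deployment",
                "azure_endpoint": "https://my-resource.openai.azure.com/",
                "api_version": "2024-02-01",
            },
        },
    }
}
```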
@@ -87,7 +87,7 @@ def generate_response( } if response_format: params["response_format"] = response_format - if tools: # TODO: Remove tools if no issues found with new memory addition logic + if tools: # TODO: Remove tools if no issues found with new memory addition logic params["tools"] = tools params["tool_choice"] = tool_choice diff --git a/mem0/llms/azure_openai_structured.py b/mem0/llms/azure_openai_structured.py index 091f92e316..729523d85d 100644 --- a/mem0/llms/azure_openai_structured.py +++ b/mem0/llms/azure_openai_structured.py @@ -1,11 +1,11 @@ -import os import json +import os from typing import Dict, List, Optional from openai import AzureOpenAI -from mem0.llms.base import LLMBase from mem0.configs.llms.base import BaseLlmConfig +from mem0.llms.base import LLMBase class AzureOpenAIStructuredLLM(LLMBase): @@ -15,21 +15,21 @@ def __init__(self, config: Optional[BaseLlmConfig] = None): # Model name should match the custom deployment name chosen for it. if not self.config.model: self.config.model = "gpt-4o-2024-08-06" - + api_key = os.getenv("LLM_AZURE_OPENAI_API_KEY") or self.config.azure_kwargs.api_key azure_deployment = os.getenv("LLM_AZURE_DEPLOYMENT") or self.config.azure_kwargs.azure_deployment azure_endpoint = os.getenv("LLM_AZURE_ENDPOINT") or self.config.azure_kwargs.azure_endpoint api_version = os.getenv("LLM_AZURE_API_VERSION") or self.config.azure_kwargs.api_version # Can display a warning if API version is of model and api-version - + self.client = AzureOpenAI( - azure_deployment=azure_deployment, + azure_deployment=azure_deployment, azure_endpoint=azure_endpoint, api_version=api_version, api_key=api_key, - http_client=self.config.http_client - ) - + http_client=self.config.http_client, + ) + def _parse_response(self, response, tools): """ Process the response based on whether tools are used or not. diff --git a/mem0/llms/configs.py b/mem0/llms/configs.py index fb6dccbfd1..dcd5b8c7ac 100644 --- a/mem0/llms/configs.py +++ b/mem0/llms/configs.py @@ -4,12 +4,8 @@ class LlmConfig(BaseModel): - provider: str = Field( - description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai" - ) - config: Optional[dict] = Field( - description="Configuration for the specific LLM", default={} - ) + provider: str = Field(description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai") + config: Optional[dict] = Field(description="Configuration for the specific LLM", default={}) @field_validator("config") def validate_config(cls, v, values): @@ -23,7 +19,7 @@ def validate_config(cls, v, values): "litellm", "azure_openai", "openai_structured", - "azure_openai_structured" + "azure_openai_structured", ): return v else: diff --git a/mem0/llms/litellm.py b/mem0/llms/litellm.py index bfe951303b..d5896ff80b 100644 --- a/mem0/llms/litellm.py +++ b/mem0/llms/litellm.py @@ -67,9 +67,7 @@ def generate_response( str: The generated response. """ if not litellm.supports_function_calling(self.config.model): - raise ValueError( - f"Model '{self.config.model}' in litellm does not support function calling." 
- ) + raise ValueError(f"Model '{self.config.model}' in litellm does not support function calling.") params = { "model": self.config.model, @@ -80,7 +78,7 @@ def generate_response( } if response_format: params["response_format"] = response_format - if tools: # TODO: Remove tools if no issues found with new memory addition logic + if tools: # TODO: Remove tools if no issues found with new memory addition logic params["tools"] = tools params["tool_choice"] = tool_choice diff --git a/mem0/llms/openai.py b/mem0/llms/openai.py index b988c33bb3..89bef986d4 100644 --- a/mem0/llms/openai.py +++ b/mem0/llms/openai.py @@ -22,7 +22,7 @@ def __init__(self, config: Optional[BaseLlmConfig] = None): ) else: api_key = self.config.api_key or os.getenv("OPENAI_API_KEY") - base_url = self.config.openai_base_url or os.getenv("OPENAI_API_BASE") + base_url = os.getenv("OPENAI_API_BASE") or self.config.openai_base_url self.client = OpenAI(api_key=api_key, base_url=base_url) def _parse_response(self, response, tools): @@ -100,7 +100,7 @@ def generate_response( if response_format: params["response_format"] = response_format - if tools: # TODO: Remove tools if no issues found with new memory addition logic + if tools: # TODO: Remove tools if no issues found with new memory addition logic params["tools"] = tools params["tool_choice"] = tool_choice diff --git a/mem0/llms/openai_structured.py b/mem0/llms/openai_structured.py index 0625c1e890..4060afb8b6 100644 --- a/mem0/llms/openai_structured.py +++ b/mem0/llms/openai_structured.py @@ -1,6 +1,5 @@ -import os import json - +import os from typing import Dict, List, Optional from openai import OpenAI @@ -20,7 +19,6 @@ def __init__(self, config: Optional[BaseLlmConfig] = None): base_url = self.config.openai_base_url or os.getenv("OPENAI_API_BASE") self.client = OpenAI(api_key=api_key, base_url=base_url) - def _parse_response(self, response, tools): """ Process the response based on whether tools are used or not. @@ -31,8 +29,8 @@ def _parse_response(self, response, tools): Returns: str or dict: The processed response. - """ - + """ + if tools: processed_response = { "content": response.choices[0].message.content, @@ -52,7 +50,6 @@ def _parse_response(self, response, tools): else: return response.choices[0].message.content - def generate_response( self, @@ -87,4 +84,4 @@ def generate_response( response = self.client.beta.chat.completions.parse(**params) - return self._parse_response(response, tools) \ No newline at end of file + return self._parse_response(response, tools) diff --git a/mem0/llms/together.py b/mem0/llms/together.py index 51ebac660c..922a30d224 100644 --- a/mem0/llms/together.py +++ b/mem0/llms/together.py @@ -20,7 +20,7 @@ def __init__(self, config: Optional[BaseLlmConfig] = None): api_key = self.config.api_key or os.getenv("TOGETHER_API_KEY") self.client = Together(api_key=api_key) - + def _parse_response(self, response, tools): """ Process the response based on whether tools are used or not. 
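One behavioral change worth calling out from `mem0/llms/openai.py` above: the `OPENAI_API_BASE` environment variable now takes precedence over the configured base URL, where previously the config value won. A sketch of the resulting resolution order:

```python
import os

# After this change, an exported OPENAI_API_BASE wins over config:
os.environ["OPENAI_API_BASE"] = "https://proxy.example.com/v1"  # placeholder URL

configured_base_url = "https://api.openai.com/v1"  # e.g. config.openai_base_url
base_url = os.getenv("OPENAI_API_BASE") or configured_base_url
print(base_url)  # -> https://proxy.example.com/v1
```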
@@ -79,7 +79,7 @@ def generate_response(
         }
         if response_format:
             params["response_format"] = response_format
-        if tools: # TODO: Remove tools if no issues found with new memory addition logic
+        if tools:  # TODO: Remove tools if no issues found with new memory addition logic
             params["tools"] = tools
             params["tool_choice"] = tool_choice
diff --git a/mem0/llms/utils/tools.py b/mem0/llms/utils/tools.py
index fb4ff4a2a5..6857294f40 100644
--- a/mem0/llms/utils/tools.py
+++ b/mem0/llms/utils/tools.py
@@ -5,14 +5,11 @@
     "function": {
         "name": "add_memory",
         "description": "Add a memory",
-        "strict": True,
         "parameters": {
             "type": "object",
-            "properties": {
-                "data": {"type": "string", "description": "Data to add to memory"}
-            },
+            "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
             "required": ["data"],
-            "additionalProperties": False
+            "additionalProperties": False,
         },
     },
 }
@@ -22,7 +19,6 @@
     "function": {
         "name": "update_memory",
         "description": "Update memory provided ID and data",
-        "strict": True,
         "parameters": {
             "type": "object",
             "properties": {
@@ -36,7 +32,7 @@
                 },
             },
             "required": ["memory_id", "data"],
-            "additionalProperties": False
+            "additionalProperties": False,
         },
     },
 }
@@ -46,7 +42,6 @@
     "function": {
         "name": "delete_memory",
         "description": "Delete memory by memory_id",
-        "strict": True,
         "parameters": {
             "type": "object",
             "properties": {
@@ -56,7 +51,7 @@
                 }
             },
             "required": ["memory_id"],
-            "additionalProperties": False
+            "additionalProperties": False,
         },
     },
 }
diff --git a/mem0/memory/graph_memory.py b/mem0/memory/graph_memory.py
index 7cdeb025bb..f3c4b198b0 100644
--- a/mem0/memory/graph_memory.py
+++ b/mem0/memory/graph_memory.py
@@ -1,32 +1,41 @@
 import logging

-from langchain_community.graphs import Neo4jGraph
 from rank_bm25 import BM25Okapi

 from mem0.graphs.tools import (
+    ADD_MEMORY_STRUCT_TOOL_GRAPH,
     ADD_MEMORY_TOOL_GRAPH,
+    ADD_MESSAGE_STRUCT_TOOL,
     ADD_MESSAGE_TOOL,
+    NOOP_STRUCT_TOOL,
     NOOP_TOOL,
+    SEARCH_STRUCT_TOOL,
     SEARCH_TOOL,
+    UPDATE_MEMORY_STRUCT_TOOL_GRAPH,
     UPDATE_MEMORY_TOOL_GRAPH,
-    UPDATE_MEMORY_STRUCT_TOOL_GRAPH,
-    ADD_MEMORY_STRUCT_TOOL_GRAPH,
-    NOOP_STRUCT_TOOL,
-    ADD_MESSAGE_STRUCT_TOOL,
-    SEARCH_STRUCT_TOOL
 )
-from mem0.graphs.utils import EXTRACT_ENTITIES_PROMPT, get_update_memory_messages
-from mem0.utils.factory import EmbedderFactory, LlmFactory
+from mem0.graphs.utils import (
+    EXTRACT_ENTITIES_PROMPT,
+    FALKORDB_QUERY,
+    NEO4J_QUERY,
+    get_update_memory_messages,
+)
+from mem0.utils.factory import EmbedderFactory, GraphFactory, LlmFactory
+

 logger = logging.getLogger(__name__)

+
 class MemoryGraph:
     def __init__(self, config):
         self.config = config
-        self.graph = Neo4jGraph(self.config.graph_store.config.url, self.config.graph_store.config.username, self.config.graph_store.config.password)
+        self.graph = GraphFactory.create(
+            self.config.graph_store.provider, self.config.graph_store.config
+        )
-        self.embedding_model = EmbedderFactory.create(
-            self.config.embedder.provider, self.config.embedder.config
-        )
+        self.embedding_model = EmbedderFactory.create(self.config.embedder.provider, self.config.embedder.config)

         self.llm_provider = "openai_structured"
         if self.config.llm.provider:
@@ -51,15 +60,23 @@ def add(self, data, filters):
         search_output = self._search(data, filters)

         if self.config.graph_store.custom_prompt:
-            messages=[
-                {"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id).replace("CUSTOM_PROMPT", f"4. 
{self.config.graph_store.custom_prompt}")}, + messages = [ + { + "role": "system", + "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id).replace( + "CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}" + ), + }, {"role": "user", "content": data}, ] else: - messages=[ - {"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id)}, + messages = [ + { + "role": "system", + "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id), + }, {"role": "user", "content": data}, - ] + ] _tools = [ADD_MESSAGE_TOOL] if self.llm_provider in ["azure_openai_structured", "openai_structured"]: @@ -67,11 +84,11 @@ def add(self, data, filters): extracted_entities = self.llm.generate_response( messages=messages, - tools = _tools, + tools=_tools, ) - if extracted_entities['tool_calls']: - extracted_entities = extracted_entities['tool_calls'][0]['arguments']['entities'] + if extracted_entities["tool_calls"]: + extracted_entities = extracted_entities["tool_calls"][0]["arguments"]["entities"] else: extracted_entities = [] @@ -79,9 +96,13 @@ def add(self, data, filters): update_memory_prompt = get_update_memory_messages(search_output, extracted_entities) - _tools=[UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL] - if self.llm_provider in ["azure_openai_structured","openai_structured"]: - _tools = [UPDATE_MEMORY_STRUCT_TOOL_GRAPH, ADD_MEMORY_STRUCT_TOOL_GRAPH, NOOP_STRUCT_TOOL] + _tools = [UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL] + if self.llm_provider in ["azure_openai_structured", "openai_structured"]: + _tools = [ + UPDATE_MEMORY_STRUCT_TOOL_GRAPH, + ADD_MEMORY_STRUCT_TOOL_GRAPH, + NOOP_STRUCT_TOOL, + ] memory_updates = self.llm.generate_response( messages=update_memory_prompt, @@ -90,28 +111,29 @@ def add(self, data, filters): to_be_added = [] - for item in memory_updates['tool_calls']: - if item['name'] == "add_graph_memory": - to_be_added.append(item['arguments']) - elif item['name'] == "update_graph_memory": - self._update_relationship(item['arguments']['source'], item['arguments']['destination'], item['arguments']['relationship'], filters) - elif item['name'] == "noop": + for item in memory_updates["tool_calls"]: + if item["name"] == "add_graph_memory": + to_be_added.append(item["arguments"]) + elif item["name"] == "update_graph_memory": + self._update_relationship( + item["arguments"]["source"], + item["arguments"]["destination"], + item["arguments"]["relationship"], + filters, + ) + elif item["name"] == "noop": continue returned_entities = [] for item in to_be_added: - source = item['source'].lower().replace(" ", "_") - source_type = item['source_type'].lower().replace(" ", "_") - relation = item['relationship'].lower().replace(" ", "_") - destination = item['destination'].lower().replace(" ", "_") - destination_type = item['destination_type'].lower().replace(" ", "_") - - returned_entities.append({ - "source" : source, - "relationship" : relation, - "target" : destination - }) + source = item["source"].lower().replace(" ", "_") + source_type = item["source_type"].lower().replace(" ", "_") + relation = item["relationship"].lower().replace(" ", "_") + destination = item["destination"].lower().replace(" ", "_") + destination_type = item["destination_type"].lower().replace(" ", "_") + + returned_entities.append({"source": source, "relationship": relation, "target": destination}) # Create embeddings source_embedding = self.embedding_model.embed(source) @@ -135,10 +157,10 @@ def add(self, data, filters): "dest_name": 
destination, "source_embedding": source_embedding, "dest_embedding": dest_embedding, - "user_id": filters["user_id"] + "user_id": filters["user_id"], } - _ = self.graph.query(cypher, params=params) + _ = self.graph_query(cypher, params=params) logger.info(f"Added {len(to_be_added)} new memories to the graph") @@ -150,19 +172,22 @@ def _search(self, query, filters): _tools = [SEARCH_STRUCT_TOOL] search_results = self.llm.generate_response( messages=[ - {"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source node. Extract the entities."}, + { + "role": "system", + "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source node. Extract the entities.", + }, {"role": "user", "content": query}, ], - tools = _tools + tools=_tools, ) node_list = [] relation_list = [] - for item in search_results['tool_calls']: - if item['name'] == "search": + for item in search_results["tool_calls"]: + if item["name"] == "search": try: - node_list.extend(item['arguments']['nodes']) + node_list.extend(item["arguments"]["nodes"]) except Exception as e: logger.error(f"Error in search tool: {e}") @@ -179,35 +204,23 @@ def _search(self, query, filters): for node in node_list: n_embedding = self.embedding_model.embed(node) - cypher_query = """ - MATCH (n) - WHERE n.embedding IS NOT NULL AND n.user_id = $user_id - WITH n, - round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) / - (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) * - sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity - WHERE similarity >= $threshold - MATCH (n)-[r]->(m) - RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relation, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id, similarity - UNION - MATCH (n) - WHERE n.embedding IS NOT NULL AND n.user_id = $user_id - WITH n, - round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) / - (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) * - sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity - WHERE similarity >= $threshold - MATCH (m)-[r]->(n) - RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relation, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id, similarity - ORDER BY similarity DESC - """ - params = {"n_embedding": n_embedding, "threshold": self.threshold, "user_id": filters["user_id"]} - ans = self.graph.query(cypher_query, params=params) + if self.config.graph_store.provider == "falkordb": + cypher_query = FALKORDB_QUERY + elif self.config.graph_store.provider == "neo4j": + cypher_query = NEO4J_QUERY + else: + raise ValueError("Unsupported graph database provider for querying") + + params = { + "n_embedding": n_embedding, + "threshold": self.threshold, + "user_id": filters["user_id"], + } + ans = self.graph_query(cypher_query, params=params) result_relations.extend(ans) return result_relations - def search(self, query, filters): 
""" Search for memories and related graph data. @@ -227,7 +240,7 @@ def search(self, query, filters): if not search_output: return [] - search_outputs_sequence = [[item["source"], item["relation"], item["destination"]] for item in search_output] + search_outputs_sequence = [[item[0], item[2], item[4]] for item in search_output] bm25 = BM25Okapi(search_outputs_sequence) tokenized_query = query.split(" ") @@ -235,26 +248,20 @@ def search(self, query, filters): search_results = [] for item in reranked_results: - search_results.append({ - "source": item[0], - "relationship": item[1], - "target": item[2] - }) + search_results.append({"source": item[0], "relationship": item[1], "target": item[2]}) logger.info(f"Returned {len(search_results)} search results") return search_results - def delete_all(self, filters): cypher = """ MATCH (n {user_id: $user_id}) DETACH DELETE n """ params = {"user_id": filters["user_id"]} - self.graph.query(cypher, params=params) + self.graph_query(cypher, params=params) - def get_all(self, filters): """ Retrieves all nodes and relationships from the graph database based on optional filtering criteria. @@ -272,21 +279,20 @@ def get_all(self, filters): MATCH (n {user_id: $user_id})-[r]->(m {user_id: $user_id}) RETURN n.name AS source, type(r) AS relationship, m.name AS target """ - results = self.graph.query(query, params={"user_id": filters["user_id"]}) + results = self.graph_query(query, params={"user_id": filters["user_id"]}) final_results = [] for result in results: final_results.append({ - "source": result['source'], - "relationship": result['relationship'], - "target": result['target'] + "source": result[0], + "relationship": result[1], + "target": result[2] }) logger.info(f"Retrieved {len(final_results)} relationships") return final_results - - + def _update_relationship(self, source, target, relationship, filters): """ Update or create a relationship between two nodes in the graph. @@ -309,14 +315,20 @@ def _update_relationship(self, source, target, relationship, filters): MERGE (n1 {name: $source, user_id: $user_id}) MERGE (n2 {name: $target, user_id: $user_id}) """ - self.graph.query(check_and_create_query, params={"source": source, "target": target, "user_id": filters["user_id"]}) + self.graph_query( + check_and_create_query, + params={"source": source, "target": target, "user_id": filters["user_id"]}, + ) # Delete any existing relationship between the nodes delete_query = """ MATCH (n1 {name: $source, user_id: $user_id})-[r]->(n2 {name: $target, user_id: $user_id}) DELETE r """ - self.graph.query(delete_query, params={"source": source, "target": target, "user_id": filters["user_id"]}) + self.graph_query( + delete_query, + params={"source": source, "target": target, "user_id": filters["user_id"]}, + ) # Create the new relationship create_query = f""" @@ -324,7 +336,34 @@ def _update_relationship(self, source, target, relationship, filters): CREATE (n1)-[r:{relationship}]->(n2) RETURN n1, r, n2 """ - result = self.graph.query(create_query, params={"source": source, "target": target, "user_id": filters["user_id"]}) + result = self.graph_query( + create_query, + params={"source": source, "target": target, "user_id": filters["user_id"]}, + ) if not result: raise Exception(f"Failed to update or create relationship between {source} and {target}") + + def graph_query(self, query, params): + """ + Execute a Cypher query on the graph database. + FalkorDB supported multi-graph usage, the graphs is switched based on the user_id. 
+ + Args: + query (str): The Cypher query to execute. + params (dict): A dictionary containing params to be applied during the query. + + Returns: + list: A list of dictionaries containing the results of the query. + """ + if self.config.graph_store.provider == "falkordb": + # TODO: Use langchain to switch graphs after the multi-graph feature is released + self.graph._graph = self.graph._driver.select_graph(params["user_id"]) + + query_output = self.graph.query(query, params=params) + + if self.config.graph_store.provider == "neo4j": + query_output = [list(d.values()) for d in query_output] + + + return query_output \ No newline at end of file diff --git a/mem0/memory/main.py b/mem0/memory/main.py index a3bb502477..8a0cc1ac11 100644 --- a/mem0/memory/main.py +++ b/mem0/memory/main.py @@ -10,14 +10,14 @@ import pytz from pydantic import ValidationError +from mem0.configs.base import MemoryConfig, MemoryItem from mem0.configs.prompts import get_update_memory_messages from mem0.memory.base import MemoryBase from mem0.memory.setup import setup_config from mem0.memory.storage import SQLiteManager from mem0.memory.telemetry import capture_event from mem0.memory.utils import get_fact_retrieval_messages, parse_messages -from mem0.utils.factory import LlmFactory, EmbedderFactory, VectorStoreFactory -from mem0.configs.base import MemoryItem, MemoryConfig +from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory # Setup user config setup_config() @@ -28,9 +28,9 @@ class Memory(MemoryBase): def __init__(self, config: MemoryConfig = MemoryConfig()): self.config = config - self.embedding_model = EmbedderFactory.create( - self.config.embedder.provider, self.config.embedder.config - ) + + self.custom_prompt = self.config.custom_prompt + self.embedding_model = EmbedderFactory.create(self.config.embedder.provider, self.config.embedder.config) self.vector_store = VectorStoreFactory.create( self.config.vector_store.provider, self.config.vector_store.config ) @@ -43,12 +43,12 @@ def __init__(self, config: MemoryConfig = MemoryConfig()): if self.version == "v1.1" and self.config.graph_store.config: from mem0.memory.graph_memory import MemoryGraph + self.graph = MemoryGraph(self.config) self.enable_graph = True capture_event("mem0.init", self) - @classmethod def from_config(cls, config_dict: Dict[str, Any]): try: @@ -58,7 +58,6 @@ def from_config(cls, config_dict: Dict[str, Any]): raise return cls(config) - def add( self, messages, @@ -96,9 +95,7 @@ def add( filters["run_id"] = metadata["run_id"] = run_id if not any(key in filters for key in ("user_id", "agent_id", "run_id")): - raise ValueError( - "One of the filters: user_id, agent_id or run_id is required!" - ) + raise ValueError("One of the filters: user_id, agent_id or run_id is required!") if isinstance(messages, str): messages = [{"role": "user", "content": messages}] @@ -114,8 +111,8 @@ def add( if self.version == "v1.1": return { - "results" : vector_store_result, - "relations" : graph_result, + "results": vector_store_result, + "relations": graph_result, } else: warnings.warn( @@ -123,25 +120,29 @@ def add( "To use the latest format, set `api_version='v1.1'`. 
" "The current format will be removed in mem0ai 1.1.0 and later versions.", category=DeprecationWarning, - stacklevel=2 + stacklevel=2, ) return {"message": "ok"} - def _add_to_vector_store(self, messages, metadata, filters): parsed_messages = parse_messages(messages) - system_prompt, user_prompt = get_fact_retrieval_messages(parsed_messages) + if self.custom_prompt: + system_prompt = self.custom_prompt + user_prompt = f"Input: {parsed_messages}" + else: + system_prompt, user_prompt = get_fact_retrieval_messages(parsed_messages) response = self.llm.generate_response( - messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}], + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], response_format={"type": "json_object"}, ) try: - new_retrieved_facts = json.loads(response)[ - "facts" - ] + new_retrieved_facts = json.loads(response)["facts"] except Exception as e: logging.error(f"Error in new_retrieved_facts: {e}") new_retrieved_facts = [] @@ -172,24 +173,30 @@ def _add_to_vector_store(self, messages, metadata, filters): logging.info(resp) try: if resp["event"] == "ADD": - memory_id = self._create_memory(data=resp["text"], metadata=metadata) - returned_memories.append({ - "memory" : resp["text"], - "event" : resp["event"], - }) + _ = self._create_memory(data=resp["text"], metadata=metadata) + returned_memories.append( + { + "memory": resp["text"], + "event": resp["event"], + } + ) elif resp["event"] == "UPDATE": self._update_memory(memory_id=resp["id"], data=resp["text"], metadata=metadata) - returned_memories.append({ - "memory" : resp["text"], - "event" : resp["event"], - "previous_memory" : resp["old_memory"], - }) + returned_memories.append( + { + "memory": resp["text"], + "event": resp["event"], + "previous_memory": resp["old_memory"], + } + ) elif resp["event"] == "DELETE": self._delete_memory(memory_id=resp["id"]) - returned_memories.append({ - "memory" : resp["text"], - "event" : resp["event"], - }) + returned_memories.append( + { + "memory": resp["text"], + "event": resp["event"], + } + ) elif resp["event"] == "NONE": logging.info("NOOP for Memory.") except Exception as e: @@ -200,7 +207,6 @@ def _add_to_vector_store(self, messages, metadata, filters): capture_event("mem0.add", self) return returned_memories - def _add_to_graph(self, messages, filters): added_entities = [] @@ -214,7 +220,6 @@ def _add_to_graph(self, messages, filters): return added_entities - def get(self, memory_id): """ Retrieve a memory by ID. @@ -230,11 +235,7 @@ def get(self, memory_id): if not memory: return None - filters = { - key: memory.payload[key] - for key in ["user_id", "agent_id", "run_id"] - if memory.payload.get(key) - } + filters = {key: memory.payload[key] for key in ["user_id", "agent_id", "run_id"] if memory.payload.get(key)} # Prepare base memory item memory_item = MemoryItem( @@ -255,9 +256,7 @@ def get(self, memory_id): "created_at", "updated_at", } - additional_metadata = { - k: v for k, v in memory.payload.items() if k not in excluded_keys - } + additional_metadata = {k: v for k, v in memory.payload.items() if k not in excluded_keys} if additional_metadata: memory_item["metadata"] = additional_metadata @@ -265,7 +264,6 @@ def get(self, memory_id): return result - def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100): """ List all memories. 
@@ -282,10 +280,12 @@ def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100): filters["run_id"] = run_id capture_event("mem0.get_all", self, {"filters": len(filters), "limit": limit}) - + with concurrent.futures.ThreadPoolExecutor() as executor: future_memories = executor.submit(self._get_all_from_vector_store, filters, limit) - future_graph_entities = executor.submit(self.graph.get_all, filters) if self.version == "v1.1" and self.enable_graph else None + future_graph_entities = ( + executor.submit(self.graph.get_all, filters) if self.version == "v1.1" and self.enable_graph else None + ) all_memories = future_memories.result() graph_entities = future_graph_entities.result() if future_graph_entities else None @@ -301,15 +301,22 @@ def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100): "To use the latest format, set `api_version='v1.1'`. " "The current format will be removed in mem0ai 1.1.0 and later versions.", category=DeprecationWarning, - stacklevel=2 + stacklevel=2, ) return all_memories - def _get_all_from_vector_store(self, filters, limit): memories = self.vector_store.list(filters=filters, limit=limit) - excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"} + excluded_keys = { + "user_id", + "agent_id", + "run_id", + "hash", + "data", + "created_at", + "updated_at", + } all_memories = [ { **MemoryItem( @@ -319,19 +326,9 @@ def _get_all_from_vector_store(self, filters, limit): created_at=mem.payload.get("created_at"), updated_at=mem.payload.get("updated_at"), ).model_dump(exclude={"score"}), - **{ - key: mem.payload[key] - for key in ["user_id", "agent_id", "run_id"] - if key in mem.payload - }, + **{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload}, **( - { - "metadata": { - k: v - for k, v in mem.payload.items() - if k not in excluded_keys - } - } + {"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}} if any(k for k in mem.payload if k not in excluded_keys) else {} ), @@ -340,10 +337,7 @@ def _get_all_from_vector_store(self, filters, limit): ] return all_memories - - def search( - self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None - ): + def search(self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None): """ Search for memories. @@ -367,15 +361,21 @@ def search( filters["run_id"] = run_id if not any(key in filters for key in ("user_id", "agent_id", "run_id")): - raise ValueError( - "One of the filters: user_id, agent_id or run_id is required!" 
- ) + raise ValueError("One of the filters: user_id, agent_id or run_id is required!") - capture_event("mem0.search", self, {"filters": len(filters), "limit": limit, "version": self.version}) + capture_event( + "mem0.search", + self, + {"filters": len(filters), "limit": limit, "version": self.version}, + ) with concurrent.futures.ThreadPoolExecutor() as executor: future_memories = executor.submit(self._search_vector_store, query, filters, limit) - future_graph_entities = executor.submit(self.graph.search, query, filters) if self.version == "v1.1" and self.enable_graph else None + future_graph_entities = ( + executor.submit(self.graph.search, query, filters) + if self.version == "v1.1" and self.enable_graph + else None + ) original_memories = future_memories.result() graph_entities = future_graph_entities.result() if future_graph_entities else None @@ -384,23 +384,20 @@ def search( if self.enable_graph: return {"results": original_memories, "relations": graph_entities} else: - return {"results" : original_memories} + return {"results": original_memories} else: warnings.warn( "The current get_all API output format is deprecated. " "To use the latest format, set `api_version='v1.1'`. " "The current format will be removed in mem0ai 1.1.0 and later versions.", category=DeprecationWarning, - stacklevel=2 + stacklevel=2, ) return original_memories - def _search_vector_store(self, query, filters, limit): embeddings = self.embedding_model.embed(query) - memories = self.vector_store.search( - query=embeddings, limit=limit, filters=filters - ) + memories = self.vector_store.search(query=embeddings, limit=limit, filters=filters) excluded_keys = { "user_id", @@ -422,19 +419,9 @@ def _search_vector_store(self, query, filters, limit): updated_at=mem.payload.get("updated_at"), score=mem.score, ).model_dump(), - **{ - key: mem.payload[key] - for key in ["user_id", "agent_id", "run_id"] - if key in mem.payload - }, + **{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload}, **( - { - "metadata": { - k: v - for k, v in mem.payload.items() - if k not in excluded_keys - } - } + {"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}} if any(k for k in mem.payload if k not in excluded_keys) else {} ), @@ -444,7 +431,6 @@ def _search_vector_store(self, query, filters, limit): return original_memories - def update(self, memory_id, data): """ Update a memory by ID. @@ -460,7 +446,6 @@ def update(self, memory_id, data): self._update_memory(memory_id, data) return {"message": "Memory updated successfully!"} - def delete(self, memory_id): """ Delete a memory by ID. @@ -472,7 +457,6 @@ def delete(self, memory_id): self._delete_memory(memory_id) return {"message": "Memory deleted successfully!"} - def delete_all(self, user_id=None, agent_id=None, run_id=None): """ Delete all memories. 
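# A shape sketch for the three search() return formats handled above, assuming an
# initialized Memory instance `m`:
out = m.search("what does alice like?", user_id="alice")
# version "v1.1", graph enabled  -> {"results": [...], "relations": [...]}
# version "v1.1", graph disabled -> {"results": [...]}
# default version                -> a bare list of memories, plus the DeprecationWarning above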
@@ -505,8 +489,7 @@ def delete_all(self, user_id=None, agent_id=None, run_id=None): if self.version == "v1.1" and self.enable_graph: self.graph.delete_all(filters) - return {'message': 'Memories deleted successfully!'} - + return {"message": "Memories deleted successfully!"} def history(self, memory_id): """ @@ -521,7 +504,6 @@ def history(self, memory_id): capture_event("mem0.history", self, {"memory_id": memory_id}) return self.db.get_history(memory_id) - def _create_memory(self, data, metadata=None): logging.info(f"Creating memory with {data=}") embeddings = self.embedding_model.embed(data) @@ -536,12 +518,9 @@ def _create_memory(self, data, metadata=None): ids=[memory_id], payloads=[metadata], ) - self.db.add_history( - memory_id, None, data, "ADD", created_at=metadata["created_at"] - ) + self.db.add_history(memory_id, None, data, "ADD", created_at=metadata["created_at"]) return memory_id - def _update_memory(self, memory_id, data, metadata=None): logger.info(f"Updating memory with {data=}") existing_memory = self.vector_store.get(vector_id=memory_id) @@ -551,9 +530,7 @@ def _update_memory(self, memory_id, data, metadata=None): new_metadata["data"] = data new_metadata["hash"] = existing_memory.payload.get("hash") new_metadata["created_at"] = existing_memory.payload.get("created_at") - new_metadata["updated_at"] = datetime.now( - pytz.timezone("US/Pacific") - ).isoformat() + new_metadata["updated_at"] = datetime.now(pytz.timezone("US/Pacific")).isoformat() if "user_id" in existing_memory.payload: new_metadata["user_id"] = existing_memory.payload["user_id"] @@ -578,7 +555,6 @@ def _update_memory(self, memory_id, data, metadata=None): updated_at=new_metadata["updated_at"], ) - def _delete_memory(self, memory_id): logging.info(f"Deleting memory with {memory_id=}") existing_memory = self.vector_store.get(vector_id=memory_id) @@ -586,7 +562,6 @@ def _delete_memory(self, memory_id): self.vector_store.delete(vector_id=memory_id) self.db.add_history(memory_id, prev_value, None, "DELETE", is_deleted=1) - def reset(self): """ Reset the memory store. 
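# A round-trip sketch for the update/history path consolidated above, assuming `m` is a
# Memory instance and `memory_id` refers to an existing memory:
m.update(memory_id, "Prefers green tea over coffee")
for row in m.history(memory_id):  # SQLite-backed history rows (ADD/UPDATE/DELETE events)
    print(row)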
@@ -596,6 +571,5 @@ def reset(self): self.db.reset() capture_event("mem0.reset", self) - def chat(self, query): raise NotImplementedError("Chat function not implemented yet.") diff --git a/mem0/memory/storage.py b/mem0/memory/storage.py index 126df85db4..87a256dc25 100644 --- a/mem0/memory/storage.py +++ b/mem0/memory/storage.py @@ -12,9 +12,7 @@ def _migrate_history_table(self): with self.connection: cursor = self.connection.cursor() - cursor.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name='history'" - ) + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='history'") table_exists = cursor.fetchone() is not None if table_exists: @@ -62,7 +60,7 @@ def _migrate_history_table(self): INSERT INTO history (id, memory_id, old_memory, new_memory, new_value, event, created_at, updated_at, is_deleted) SELECT id, memory_id, prev_value, new_value, new_value, event, timestamp, timestamp, is_deleted FROM old_history - """ + """ # noqa: E501 ) cursor.execute("DROP TABLE old_history") diff --git a/mem0/memory/telemetry.py b/mem0/memory/telemetry.py index 9b78d775be..6865b2feeb 100644 --- a/mem0/memory/telemetry.py +++ b/mem0/memory/telemetry.py @@ -1,7 +1,7 @@ import logging +import os import platform import sys -import os from posthog import Posthog @@ -15,8 +15,9 @@ if not isinstance(MEM0_TELEMETRY, bool): raise ValueError("MEM0_TELEMETRY must be a boolean value.") -logging.getLogger('posthog').setLevel(logging.CRITICAL + 1) -logging.getLogger('urllib3').setLevel(logging.CRITICAL + 1) +logging.getLogger("posthog").setLevel(logging.CRITICAL + 1) +logging.getLogger("urllib3").setLevel(logging.CRITICAL + 1) + class AnonymousTelemetry: def __init__(self, project_api_key, host): @@ -24,9 +25,8 @@ def __init__(self, project_api_key, host): # Call setup config to ensure that the user_id is generated setup_config() self.user_id = get_user_id() - # Optional - if not MEM0_TELEMETRY: - self.posthog.disabled = True + if not MEM0_TELEMETRY: + self.posthog.disabled = True def capture_event(self, event_name, properties=None): if properties is None: @@ -40,9 +40,7 @@ def capture_event(self, event_name, properties=None): "machine": platform.machine(), **properties, } - self.posthog.capture( - distinct_id=self.user_id, event=event_name, properties=properties - ) + self.posthog.capture(distinct_id=self.user_id, event=event_name, properties=properties) def identify_user(self, user_id, properties=None): if properties is None: @@ -65,6 +63,9 @@ def capture_event(event_name, memory_instance, additional_data=None): "collection": memory_instance.collection_name, "vector_size": memory_instance.embedding_model.config.embedding_dims, "history_store": "sqlite", + "graph_store": f"{memory_instance.graph.__class__.__module__}.{memory_instance.graph.__class__.__name__}" + if memory_instance.config.graph_store.config + else None, "vector_store": f"{memory_instance.vector_store.__class__.__module__}.{memory_instance.vector_store.__class__.__name__}", "llm": f"{memory_instance.llm.__class__.__module__}.{memory_instance.llm.__class__.__name__}", "embedding_model": f"{memory_instance.embedding_model.__class__.__module__}.{memory_instance.embedding_model.__class__.__name__}", @@ -76,7 +77,6 @@ def capture_event(event_name, memory_instance, additional_data=None): telemetry.capture_event(event_name, event_data) - def capture_client_event(event_name, instance, additional_data=None): event_data = { "function": f"{instance.__class__.__module__}.{instance.__class__.__name__}", diff --git 
a/mem0/memory/utils.py b/mem0/memory/utils.py index a0c82fedb5..a7e7bc3588 100644 --- a/mem0/memory/utils.py +++ b/mem0/memory/utils.py @@ -4,13 +4,14 @@ def get_fact_retrieval_messages(message): return FACT_RETRIEVAL_PROMPT, f"Input: {message}" + def parse_messages(messages): - response = "" - for msg in messages: - if msg["role"] == "system": - response += f"system: {msg['content']}\n" - if msg["role"] == "user": - response += f"user: {msg['content']}\n" - if msg["role"] == "assistant": - response += f"assistant: {msg['content']}\n" - return response + response = "" + for msg in messages: + if msg["role"] == "system": + response += f"system: {msg['content']}\n" + if msg["role"] == "user": + response += f"user: {msg['content']}\n" + if msg["role"] == "assistant": + response += f"assistant: {msg['content']}\n" + return response diff --git a/mem0/proxy/main.py b/mem0/proxy/main.py index bb614f4f37..b13c681ea8 100644 --- a/mem0/proxy/main.py +++ b/mem0/proxy/main.py @@ -10,7 +10,7 @@ import litellm except ImportError: user_input = input("The 'litellm' library is required. Install it now? [y/N]: ") - if user_input.lower() == 'y': + if user_input.lower() == "y": try: subprocess.check_call([sys.executable, "-m", "pip", "install", "litellm"]) import litellm @@ -105,16 +105,10 @@ def create( prepared_messages = self._prepare_messages(messages) if prepared_messages[-1]["role"] == "user": - self._async_add_to_memory( - messages, user_id, agent_id, run_id, metadata, filters - ) - relevant_memories = self._fetch_relevant_memories( - messages, user_id, agent_id, run_id, filters, limit - ) + self._async_add_to_memory(messages, user_id, agent_id, run_id, metadata, filters) + relevant_memories = self._fetch_relevant_memories(messages, user_id, agent_id, run_id, filters, limit) logger.debug(f"Retrieved {len(relevant_memories)} relevant memories") - prepared_messages[-1]["content"] = self._format_query_with_memories( - messages, relevant_memories - ) + prepared_messages[-1]["content"] = self._format_query_with_memories(messages, relevant_memories) response = litellm.completion( model=model, @@ -156,9 +150,7 @@ def _prepare_messages(self, messages: List[dict]) -> List[dict]: messages[0]["content"] = MEMORY_ANSWER_PROMPT return messages - def _async_add_to_memory( - self, messages, user_id, agent_id, run_id, metadata, filters - ): + def _async_add_to_memory(self, messages, user_id, agent_id, run_id, metadata, filters): def add_task(): logger.debug("Adding to memory asynchronously") self.mem0_client.add( @@ -172,13 +164,9 @@ def add_task(): threading.Thread(target=add_task, daemon=True).start() - def _fetch_relevant_memories( - self, messages, user_id, agent_id, run_id, filters, limit - ): + def _fetch_relevant_memories(self, messages, user_id, agent_id, run_id, filters, limit): # Currently, only pass the last 6 messages to the search API to prevent long query - message_input = [ - f"{message['role']}: {message['content']}" for message in messages - ][-6:] + message_input = [f"{message['role']}: {message['content']}" for message in messages][-6:] # TODO: Make it better by summarizing the past conversation return self.mem0_client.search( query="\n".join(message_input), diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py index bdcc180678..1b6547886c 100644 --- a/mem0/utils/factory.py +++ b/mem0/utils/factory.py @@ -21,7 +21,7 @@ class LlmFactory: "azure_openai": "mem0.llms.azure_openai.AzureOpenAILLM", "openai_structured": "mem0.llms.openai_structured.OpenAIStructuredLLM", "anthropic": 
"mem0.llms.anthropic.AnthropicLLM", - "azure_openai_structured": "mem0.llms.azure_openai_structured.AzureOpenAIStructuredLLM" + "azure_openai_structured": "mem0.llms.azure_openai_structured.AzureOpenAIStructuredLLM", } @classmethod @@ -59,6 +59,7 @@ class VectorStoreFactory: "qdrant": "mem0.vector_stores.qdrant.Qdrant", "chroma": "mem0.vector_stores.chroma.ChromaDB", "pgvector": "mem0.vector_stores.pgvector.PGVector", + "milvus": "mem0.vector_stores.milvus.MilvusDB", } @classmethod @@ -71,3 +72,20 @@ def create(cls, provider_name, config): return vector_store_instance(**config) else: raise ValueError(f"Unsupported VectorStore provider: {provider_name}") + +class GraphFactory: + provider_to_class = { + "falkordb": "langchain_community.graphs.FalkorDBGraph", + "neo4j": "langchain_community.graphs.Neo4jGraph", + } + + @classmethod + def create(cls, provider_name, config): + class_type = cls.provider_to_class.get(provider_name) + if class_type: + if not isinstance(config, dict): + config = config.model_dump() + graph_instance = load_class(class_type) + return graph_instance(**config) + else: + raise ValueError(f"Unsupported graph provider: {provider_name}") \ No newline at end of file diff --git a/mem0/vector_stores/chroma.py b/mem0/vector_stores/chroma.py index 0dc97a3fcd..efb9fddb98 100644 --- a/mem0/vector_stores/chroma.py +++ b/mem0/vector_stores/chroma.py @@ -80,24 +80,14 @@ def _parse_output(self, data: Dict) -> List[OutputData]: values.append(value) ids, distances, metadatas = values - max_length = max( - len(v) for v in values if isinstance(v, list) and v is not None - ) + max_length = max(len(v) for v in values if isinstance(v, list) and v is not None) result = [] for i in range(max_length): entry = OutputData( id=ids[i] if isinstance(ids, list) and ids and i < len(ids) else None, - score=( - distances[i] - if isinstance(distances, list) and distances and i < len(distances) - else None - ), - payload=( - metadatas[i] - if isinstance(metadatas, list) and metadatas and i < len(metadatas) - else None - ), + score=(distances[i] if isinstance(distances, list) and distances and i < len(distances) else None), + payload=(metadatas[i] if isinstance(metadatas, list) and metadatas and i < len(metadatas) else None), ) result.append(entry) @@ -143,9 +133,7 @@ def insert( logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}") self.collection.add(ids=ids, embeddings=vectors, metadatas=payloads) - def search( - self, query: List[list], limit: int = 5, filters: Optional[Dict] = None - ) -> List[OutputData]: + def search(self, query: List[list], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]: """ Search for similar vectors. @@ -157,9 +145,7 @@ def search( Returns: List[OutputData]: Search results. """ - results = self.collection.query( - query_embeddings=query, where=filters, n_results=limit - ) + results = self.collection.query(query_embeddings=query, where=filters, n_results=limit) final_results = self._parse_output(results) return final_results @@ -225,9 +211,7 @@ def col_info(self) -> Dict: """ return self.client.get_collection(name=self.collection_name) - def list( - self, filters: Optional[Dict] = None, limit: int = 100 - ) -> List[OutputData]: + def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[OutputData]: """ List all vectors in a collection. 
diff --git a/mem0/vector_stores/configs.py b/mem0/vector_stores/configs.py index 2f052da5b5..65e55a5394 100644 --- a/mem0/vector_stores/configs.py +++ b/mem0/vector_stores/configs.py @@ -8,14 +8,13 @@ class VectorStoreConfig(BaseModel): description="Provider of the vector store (e.g., 'qdrant', 'chroma')", default="qdrant", ) - config: Optional[Dict] = Field( - description="Configuration for the specific vector store", default=None - ) + config: Optional[Dict] = Field(description="Configuration for the specific vector store", default=None) _provider_configs: Dict[str, str] = { "qdrant": "QdrantConfig", "chroma": "ChromaDbConfig", "pgvector": "PGVectorConfig", + "milvus": "MilvusDBConfig", } @model_validator(mode="after") diff --git a/mem0/vector_stores/milvus.py b/mem0/vector_stores/milvus.py new file mode 100644 index 0000000000..e1df3458e2 --- /dev/null +++ b/mem0/vector_stores/milvus.py @@ -0,0 +1,242 @@ +import logging +from typing import Any, Dict, Optional + +from pydantic import BaseModel + +from mem0.configs.vector_stores.milvus import MetricType +from mem0.vector_stores.base import VectorStoreBase + +try: + import pymilvus  # noqa: F401 +except ImportError: + raise ImportError("The 'pymilvus' library is required. Please install it using 'pip install pymilvus'.") + +from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient + +logger = logging.getLogger(__name__) + + +class OutputData(BaseModel): + id: Optional[str]  # memory id + score: Optional[float]  # distance + payload: Optional[Dict]  # metadata + + +class MilvusDB(VectorStoreBase): + def __init__( + self, + url: str, + token: str, + collection_name: str, + embedding_model_dims: int, + metric_type: MetricType, + ) -> None: + """Initialize the MilvusDB database. + + Args: + url (str): Full URL for Milvus/Zilliz server. + token (str): Token/API key for the Zilliz server; defaults to None for a local setup. + collection_name (str): Name of the collection (defaults to mem0). + embedding_model_dims (int): Dimensions of the embedding model (defaults to 1536). + metric_type (MetricType): Metric type for similarity search. + """ + self.collection_name = collection_name + self.embedding_model_dims = embedding_model_dims + self.metric_type = metric_type + self.client = MilvusClient(uri=url, token=token) + self.create_col( + collection_name=self.collection_name, + vector_size=self.embedding_model_dims, + metric_type=self.metric_type, + ) + + def create_col( + self, + collection_name: str, + vector_size: int, + metric_type: MetricType = MetricType.COSINE, + ) -> None: + """Create a new collection with index_type AUTOINDEX. + + Args: + collection_name (str): Name of the collection (defaults to mem0). + vector_size (int): Dimensions of the embedding model (defaults to 1536). + metric_type (MetricType, optional): Metric type for similarity search. Defaults to MetricType.COSINE. + """ + + if self.client.has_collection(collection_name): + logger.info(f"Collection {collection_name} already exists.
Skipping creation.") + else: + fields = [ + FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=512), + FieldSchema(name="vectors", dtype=DataType.FLOAT_VECTOR, dim=vector_size), + FieldSchema(name="metadata", dtype=DataType.JSON), + ] + + schema = CollectionSchema(fields, enable_dynamic_field=True) + + index = self.client.prepare_index_params( + field_name="vectors", + metric_type=metric_type, + index_type="AUTOINDEX", + index_name="vector_index", + params={"nlist": 128}, + ) + self.client.create_collection(collection_name=collection_name, schema=schema, index_params=index) + + def insert(self, ids, vectors, payloads, **kwargs: Optional[dict[str, Any]]): + """Insert vectors into a collection. + + Args: + ids (List[str], optional): List of IDs corresponding to vectors. + vectors (List[List[float]]): List of vectors to insert. + payloads (List[Dict], optional): List of payloads corresponding to vectors. + """ + for idx, embedding, metadata in zip(ids, vectors, payloads): + data = {"id": idx, "vectors": embedding, "metadata": metadata} + self.client.insert(collection_name=self.collection_name, data=data, **kwargs) + + def _create_filter(self, filters: dict): + """Prepare filters for efficient query. + + Args: + filters (dict): filters [user_id, agent_id, run_id] + + Returns: + str: Formatted filter expression. + """ + operands = [] + for key, value in filters.items(): + if isinstance(value, str): + operands.append(f'(metadata["{key}"] == "{value}")') + else: + operands.append(f'(metadata["{key}"] == {value})') + + return " and ".join(operands) + + def _parse_output(self, data: list): + """ + Parse the output data. + + Args: + data (list): Output data. + + Returns: + List[OutputData]: Parsed output data. + """ + memory = [] + + for value in data: + uid, score, metadata = ( + value.get("id"), + value.get("distance"), + value.get("entity", {}).get("metadata"), + ) + + memory_obj = OutputData(id=uid, score=score, payload=metadata) + memory.append(memory_obj) + + return memory + + def search(self, query: list, limit: int = 5, filters: dict = None) -> list: + """ + Search for similar vectors. + + Args: + query (List[float]): Query vector. + limit (int, optional): Number of results to return. Defaults to 5. + filters (Dict, optional): Filters to apply to the search. Defaults to None. + + Returns: + list: Search results. + """ + query_filter = self._create_filter(filters) if filters else None + hits = self.client.search( + collection_name=self.collection_name, + data=[query], + limit=limit, + filter=query_filter, + output_fields=["*"], + ) + result = self._parse_output(data=hits[0]) + return result + + def delete(self, vector_id): + """ + Delete a vector by ID. + + Args: + vector_id (str): ID of the vector to delete. + """ + self.client.delete(collection_name=self.collection_name, ids=vector_id) + + def update(self, vector_id=None, vector=None, payload=None): + """ + Update a vector and its payload. + + Args: + vector_id (str): ID of the vector to update. + vector (List[float], optional): Updated vector. + payload (Dict, optional): Updated payload. + """ + schema = {"id": vector_id, "vectors": vector, "metadata": payload} + self.client.upsert(collection_name=self.collection_name, data=schema) + + def get(self, vector_id): + """ + Retrieve a vector by ID. + + Args: + vector_id (str): ID of the vector to retrieve. + + Returns: + OutputData: Retrieved vector.
+ """ + result = self.client.get(collection_name=self.collection_name, ids=vector_id) + output = OutputData( + id=result[0].get("id", None), + score=None, + payload=result[0].get("metadata", None), + ) + return output + + def list_cols(self): + """ + List all collections. + + Returns: + List[str]: List of collection names. + """ + return self.client.list_collections() + + def delete_col(self): + """Delete a collection.""" + return self.client.drop_collection(collection_name=self.collection_name) + + def col_info(self): + """ + Get information about a collection. + + Returns: + Dict[str, Any]: Collection information. + """ + return self.client.get_collection_stats(collection_name=self.collection_name) + + def list(self, filters: dict = None, limit: int = 100) -> list: + """ + List all vectors in a collection. + + Args: + filters (Dict, optional): Filters to apply to the list. + limit (int, optional): Number of vectors to return. Defaults to 100. + + Returns: + List[OutputData]: List of vectors. + """ + query_filter = self._create_filter(filters) if filters else None + result = self.client.query(collection_name=self.collection_name, filter=query_filter, limit=limit) + memories = [] + for data in result: + obj = OutputData(id=data.get("id"), score=None, payload=data.get("metadata")) + memories.append(obj) + return [memories] diff --git a/mem0/vector_stores/pgvector.py b/mem0/vector_stores/pgvector.py index f9ec3f9770..c8893e377b 100644 --- a/mem0/vector_stores/pgvector.py +++ b/mem0/vector_stores/pgvector.py @@ -14,6 +14,7 @@ logger = logging.getLogger(__name__) + class OutputData(BaseModel): id: Optional[str] score: Optional[float] @@ -22,7 +23,15 @@ class OutputData(BaseModel): class PGVector(VectorStoreBase): def __init__( - self, dbname, collection_name, embedding_model_dims, user, password, host, port, diskann + self, + dbname, + collection_name, + embedding_model_dims, + user, + password, + host, + port, + diskann, ): """ Initialize the PGVector database. 
@@ -40,9 +49,7 @@ def __init__( self.collection_name = collection_name self.use_diskann = diskann - self.conn = psycopg2.connect( - dbname=dbname, user=user, password=password, host=host, port=port - ) + self.conn = psycopg2.connect(dbname=dbname, user=user, password=password, host=host, port=port) self.cur = self.conn.cursor() collections = self.list_cols() @@ -73,7 +80,8 @@ def create_col(self, embedding_model_dims): self.cur.execute("SELECT * FROM pg_extension WHERE extname = 'vectorscale'") if self.cur.fetchone(): # Create DiskANN index if extension is installed for faster search - self.cur.execute(f""" + self.cur.execute( + f""" CREATE INDEX IF NOT EXISTS {self.collection_name}_vector_idx ON {self.collection_name} USING diskann (vector); @@ -94,10 +102,7 @@ def insert(self, vectors, payloads=None, ids=None): logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}") json_payloads = [json.dumps(payload) for payload in payloads] - data = [ - (id, vector, payload) - for id, vector, payload in zip(ids, vectors, json_payloads) - ] + data = [(id, vector, payload) for id, vector, payload in zip(ids, vectors, json_payloads)] execute_values( self.cur, f"INSERT INTO {self.collection_name} (id, vector, payload) VALUES %s", @@ -125,9 +130,7 @@ def search(self, query, limit=5, filters=None): filter_conditions.append("payload->>%s = %s") filter_params.extend([k, str(v)]) - filter_clause = ( - "WHERE " + " AND ".join(filter_conditions) if filter_conditions else "" - ) + filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else "" self.cur.execute( f""" @@ -137,13 +140,11 @@ def search(self, query, limit=5, filters=None): ORDER BY distance LIMIT %s """, - (query, *filter_params, limit), + (query, *filter_params, limit), ) results = self.cur.fetchall() - return [ - OutputData(id=str(r[0]), score=float(r[1]), payload=r[2]) for r in results - ] + return [OutputData(id=str(r[0]), score=float(r[1]), payload=r[2]) for r in results] def delete(self, vector_id): """ @@ -152,9 +153,7 @@ def delete(self, vector_id): Args: vector_id (str): ID of the vector to delete. """ - self.cur.execute( - f"DELETE FROM {self.collection_name} WHERE id = %s", (vector_id,) - ) + self.cur.execute(f"DELETE FROM {self.collection_name} WHERE id = %s", (vector_id,)) self.conn.commit() def update(self, vector_id, vector=None, payload=None): @@ -204,9 +203,7 @@ def list_cols(self) -> List[str]: Returns: List[str]: List of collection names. 
""" - self.cur.execute( - "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'" - ) + self.cur.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'") return [row[0] for row in self.cur.fetchall()] def delete_col(self): @@ -254,9 +251,7 @@ def list(self, filters=None, limit=100): filter_conditions.append("payload->>%s = %s") filter_params.extend([k, str(v)]) - filter_clause = ( - "WHERE " + " AND ".join(filter_conditions) if filter_conditions else "" - ) + filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else "" query = f""" SELECT id, vector, payload diff --git a/mem0/vector_stores/qdrant.py b/mem0/vector_stores/qdrant.py index 3ecb93f9f3..708bf4fb4b 100644 --- a/mem0/vector_stores/qdrant.py +++ b/mem0/vector_stores/qdrant.py @@ -68,9 +68,7 @@ def __init__( self.collection_name = collection_name self.create_col(embedding_model_dims, on_disk) - def create_col( - self, vector_size: int, on_disk: bool, distance: Distance = Distance.COSINE - ): + def create_col(self, vector_size: int, on_disk: bool, distance: Distance = Distance.COSINE): """ Create a new collection. @@ -83,16 +81,12 @@ def create_col( response = self.list_cols() for collection in response.collections: if collection.name == self.collection_name: - logging.debug( - f"Collection {self.collection_name} already exists. Skipping creation." - ) + logging.debug(f"Collection {self.collection_name} already exists. Skipping creation.") return self.client.create_collection( collection_name=self.collection_name, - vectors_config=VectorParams( - size=vector_size, distance=distance, on_disk=on_disk - ), + vectors_config=VectorParams(size=vector_size, distance=distance, on_disk=on_disk), ) def insert(self, vectors: list, payloads: list = None, ids: list = None): @@ -128,15 +122,9 @@ def _create_filter(self, filters: dict) -> Filter: conditions = [] for key, value in filters.items(): if isinstance(value, dict) and "gte" in value and "lte" in value: - conditions.append( - FieldCondition( - key=key, range=Range(gte=value["gte"], lte=value["lte"]) - ) - ) + conditions.append(FieldCondition(key=key, range=Range(gte=value["gte"], lte=value["lte"]))) else: - conditions.append( - FieldCondition(key=key, match=MatchValue(value=value)) - ) + conditions.append(FieldCondition(key=key, match=MatchValue(value=value))) return Filter(must=conditions) if conditions else None def search(self, query: list, limit: int = 5, filters: dict = None) -> list: @@ -196,9 +184,7 @@ def get(self, vector_id: int) -> dict: Returns: dict: Retrieved vector. """ - result = self.client.retrieve( - collection_name=self.collection_name, ids=[vector_id], with_payload=True - ) + result = self.client.retrieve(collection_name=self.collection_name, ids=[vector_id], with_payload=True) return result[0] if result else None def list_cols(self) -> list: diff --git a/poetry.lock b/poetry.lock index eab715d4eb..a0610ff687 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. 
[[package]] name = "aiohappyeyeballs" @@ -372,6 +372,19 @@ files = [ [package.extras] test = ["pytest (>=6)"] +[[package]] +name = "falkordb" +version = "1.0.8" +description = "Python client for interacting with FalkorDB database" +optional = false +python-versions = "<4.0,>=3.8" +files = [ + {file = "falkordb-1.0.8.tar.gz", hash = "sha256:14a68ab9d684553caf8302602c18c8148c403a0d124a8a5f45de9ea43529b2c6"}, +] + +[package.dependencies] +redis = ">=5.0.1,<6.0.0" + [[package]] name = "frozenlist" version = "1.4.1" @@ -1600,6 +1613,24 @@ numpy = "*" [package.extras] dev = ["pytest"] +[[package]] +name = "redis" +version = "5.0.8" +description = "Python client for Redis database and key-value store" +optional = false +python-versions = ">=3.7" +files = [ + {file = "redis-5.0.8-py3-none-any.whl", hash = "sha256:56134ee08ea909106090934adc36f65c9bcbbaecea5b21ba704ba6fb561f8eb4"}, + {file = "redis-5.0.8.tar.gz", hash = "sha256:0c5b10d387568dfe0698c6fad6615750c24170e548ca2deac10c649d463e9870"}, +] + +[package.dependencies] +async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\""} + +[package.extras] +hiredis = ["hiredis (>1.0.0)"] +ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"] + [[package]] name = "requests" version = "2.32.3" @@ -1623,28 +1654,29 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [[package]] name = "ruff" -version = "0.4.10" +version = "0.6.5" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.4.10-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:5c2c4d0859305ac5a16310eec40e4e9a9dec5dcdfbe92697acd99624e8638dac"}, - {file = "ruff-0.4.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:a79489607d1495685cdd911a323a35871abfb7a95d4f98fc6f85e799227ac46e"}, - {file = "ruff-0.4.10-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1dd1681dfa90a41b8376a61af05cc4dc5ff32c8f14f5fe20dba9ff5deb80cd6"}, - {file = "ruff-0.4.10-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c75c53bb79d71310dc79fb69eb4902fba804a81f374bc86a9b117a8d077a1784"}, - {file = "ruff-0.4.10-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18238c80ee3d9100d3535d8eb15a59c4a0753b45cc55f8bf38f38d6a597b9739"}, - {file = "ruff-0.4.10-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:d8f71885bce242da344989cae08e263de29752f094233f932d4f5cfb4ef36a81"}, - {file = "ruff-0.4.10-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:330421543bd3222cdfec481e8ff3460e8702ed1e58b494cf9d9e4bf90db52b9d"}, - {file = "ruff-0.4.10-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9e9b6fb3a37b772628415b00c4fc892f97954275394ed611056a4b8a2631365e"}, - {file = "ruff-0.4.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f54c481b39a762d48f64d97351048e842861c6662d63ec599f67d515cb417f6"}, - {file = "ruff-0.4.10-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:67fe086b433b965c22de0b4259ddfe6fa541c95bf418499bedb9ad5fb8d1c631"}, - {file = "ruff-0.4.10-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:acfaaab59543382085f9eb51f8e87bac26bf96b164839955f244d07125a982ef"}, - {file = "ruff-0.4.10-py3-none-musllinux_1_2_i686.whl", hash = "sha256:3cea07079962b2941244191569cf3a05541477286f5cafea638cd3aa94b56815"}, - {file = "ruff-0.4.10-py3-none-musllinux_1_2_x86_64.whl", hash = 
"sha256:338a64ef0748f8c3a80d7f05785930f7965d71ca260904a9321d13be24b79695"}, - {file = "ruff-0.4.10-py3-none-win32.whl", hash = "sha256:ffe3cd2f89cb54561c62e5fa20e8f182c0a444934bf430515a4b422f1ab7b7ca"}, - {file = "ruff-0.4.10-py3-none-win_amd64.whl", hash = "sha256:67f67cef43c55ffc8cc59e8e0b97e9e60b4837c8f21e8ab5ffd5d66e196e25f7"}, - {file = "ruff-0.4.10-py3-none-win_arm64.whl", hash = "sha256:dd1fcee327c20addac7916ca4e2653fbbf2e8388d8a6477ce5b4e986b68ae6c0"}, - {file = "ruff-0.4.10.tar.gz", hash = "sha256:3aa4f2bc388a30d346c56524f7cacca85945ba124945fe489952aadb6b5cd804"}, + {file = "ruff-0.6.5-py3-none-linux_armv6l.whl", hash = "sha256:7e4e308f16e07c95fc7753fc1aaac690a323b2bb9f4ec5e844a97bb7fbebd748"}, + {file = "ruff-0.6.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:932cd69eefe4daf8c7d92bd6689f7e8182571cb934ea720af218929da7bd7d69"}, + {file = "ruff-0.6.5-py3-none-macosx_11_0_arm64.whl", hash = "sha256:3a8d42d11fff8d3143ff4da41742a98f8f233bf8890e9fe23077826818f8d680"}, + {file = "ruff-0.6.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a50af6e828ee692fb10ff2dfe53f05caecf077f4210fae9677e06a808275754f"}, + {file = "ruff-0.6.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:794ada3400a0d0b89e3015f1a7e01f4c97320ac665b7bc3ade24b50b54cb2972"}, + {file = "ruff-0.6.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:381413ec47f71ce1d1c614f7779d88886f406f1fd53d289c77e4e533dc6ea200"}, + {file = "ruff-0.6.5-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:52e75a82bbc9b42e63c08d22ad0ac525117e72aee9729a069d7c4f235fc4d276"}, + {file = "ruff-0.6.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:09c72a833fd3551135ceddcba5ebdb68ff89225d30758027280968c9acdc7810"}, + {file = "ruff-0.6.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:800c50371bdcb99b3c1551d5691e14d16d6f07063a518770254227f7f6e8c178"}, + {file = "ruff-0.6.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e25ddd9cd63ba1f3bd51c1f09903904a6adf8429df34f17d728a8fa11174253"}, + {file = "ruff-0.6.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:7291e64d7129f24d1b0c947ec3ec4c0076e958d1475c61202497c6aced35dd19"}, + {file = "ruff-0.6.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:9ad7dfbd138d09d9a7e6931e6a7e797651ce29becd688be8a0d4d5f8177b4b0c"}, + {file = "ruff-0.6.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:005256d977021790cc52aa23d78f06bb5090dc0bfbd42de46d49c201533982ae"}, + {file = "ruff-0.6.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:482c1e6bfeb615eafc5899127b805d28e387bd87db38b2c0c41d271f5e58d8cc"}, + {file = "ruff-0.6.5-py3-none-win32.whl", hash = "sha256:cf4d3fa53644137f6a4a27a2b397381d16454a1566ae5335855c187fbf67e4f5"}, + {file = "ruff-0.6.5-py3-none-win_amd64.whl", hash = "sha256:3e42a57b58e3612051a636bc1ac4e6b838679530235520e8f095f7c44f706ff9"}, + {file = "ruff-0.6.5-py3-none-win_arm64.whl", hash = "sha256:51935067740773afdf97493ba9b8231279e9beef0f2a8079188c4776c25688e0"}, + {file = "ruff-0.6.5.tar.gz", hash = "sha256:4d32d87fab433c0cf285c3683dd4dae63be05fd7a1d65b3f5bf7cdd05a6b96fb"}, ] [[package]] @@ -1743,7 +1775,7 @@ files = [ ] [package.dependencies] -greenlet = {version = "!=0.4.17", markers = "python_version < \"3.13\" and (platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or 
platform_machine == \"WIN32\")"} +greenlet = {version = "!=0.4.17", markers = "python_version < \"3.13\" and (platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\")"} typing-extensions = ">=4.6.0" [package.extras] @@ -1966,4 +1998,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "458055aee51b5e75c8f189fc1b0fbd238b9bb0d8a8becced0bd62a6a59d8d428" +content-hash = "840670092b9935ccd81c72a7a4285fc72e3f4926d0c135c108f54054366c637f" diff --git a/pyproject.toml b/pyproject.toml index c5ab005423..e3d94da6af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "mem0ai" -version = "0.1.11" +version = "0.1.14" description = "Long-term memory for AI Agents" authors = ["Mem0 "] exclude = [ @@ -25,18 +25,21 @@ sqlalchemy = "^2.0.31" langchain-community = "^0.2.12" neo4j = "^5.23.1" rank-bm25 = "^0.2.2" +falkordb = "^1.0.8" [tool.poetry.group.test.dependencies] pytest = "^8.2.2" [tool.poetry.group.dev.dependencies] -ruff = "^0.4.8" +ruff = "^0.6.5" isort = "^5.13.2" pytest = "^8.2.2" -[tool.poetry.group.optional.dependencies] - [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + +[tool.ruff] +line-length = 120 +exclude = ["embedchain/"] diff --git a/tests/embeddings/test_azure_openai_embeddings.py b/tests/embeddings/test_azure_openai_embeddings.py new file mode 100644 index 0000000000..3425ea48a2 --- /dev/null +++ b/tests/embeddings/test_azure_openai_embeddings.py @@ -0,0 +1,46 @@ +import pytest +from unittest.mock import Mock, patch +from mem0.embeddings.azure_openai import AzureOpenAIEmbedding +from mem0.configs.embeddings.base import BaseEmbedderConfig + + +@pytest.fixture +def mock_openai_client(): + with patch("mem0.embeddings.azure_openai.AzureOpenAI") as mock_openai: + mock_client = Mock() + mock_openai.return_value = mock_client + yield mock_client + + +def test_embed_text(mock_openai_client): + config = BaseEmbedderConfig(model="text-embedding-ada-002") + embedder = AzureOpenAIEmbedding(config) + + mock_embedding_response = Mock() + mock_embedding_response.data = [Mock(embedding=[0.1, 0.2, 0.3])] + mock_openai_client.embeddings.create.return_value = mock_embedding_response + + text = "Hello, this is a test." + embedding = embedder.embed(text) + + mock_openai_client.embeddings.create.assert_called_once_with( + input=["Hello, this is a test."], model="text-embedding-ada-002" + ) + assert embedding == [0.1, 0.2, 0.3] + + +def test_embed_text_with_newlines(mock_openai_client): + config = BaseEmbedderConfig(model="text-embedding-ada-002") + embedder = AzureOpenAIEmbedding(config) + + mock_embedding_response = Mock() + mock_embedding_response.data = [Mock(embedding=[0.4, 0.5, 0.6])] + mock_openai_client.embeddings.create.return_value = mock_embedding_response + + text = "Hello,\nthis is a test\nwith newlines." 
+ embedding = embedder.embed(text) + + mock_openai_client.embeddings.create.assert_called_once_with( + input=["Hello, this is a test with newlines."], model="text-embedding-ada-002" + ) + assert embedding == [0.4, 0.5, 0.6] diff --git a/tests/embeddings/test_huggingface_embeddings.py b/tests/embeddings/test_huggingface_embeddings.py new file mode 100644 index 0000000000..de6f5852e0 --- /dev/null +++ b/tests/embeddings/test_huggingface_embeddings.py @@ -0,0 +1,72 @@ +import pytest +from unittest.mock import Mock, patch +from mem0.embeddings.huggingface import HuggingFaceEmbedding +from mem0.configs.embeddings.base import BaseEmbedderConfig + + +@pytest.fixture +def mock_sentence_transformer(): + with patch("mem0.embeddings.huggingface.SentenceTransformer") as mock_transformer: + mock_model = Mock() + mock_transformer.return_value = mock_model + yield mock_model + + +def test_embed_default_model(mock_sentence_transformer): + config = BaseEmbedderConfig() + embedder = HuggingFaceEmbedding(config) + + mock_sentence_transformer.encode.return_value = [0.1, 0.2, 0.3] + result = embedder.embed("Hello world") + + mock_sentence_transformer.encode.assert_called_once_with("Hello world") + + assert result == [0.1, 0.2, 0.3] + + +def test_embed_custom_model(mock_sentence_transformer): + config = BaseEmbedderConfig(model="paraphrase-MiniLM-L6-v2") + embedder = HuggingFaceEmbedding(config) + + mock_sentence_transformer.encode.return_value = [0.4, 0.5, 0.6] + result = embedder.embed("Custom model test") + + mock_sentence_transformer.encode.assert_called_once_with("Custom model test") + + assert result == [0.4, 0.5, 0.6] + + +def test_embed_with_model_kwargs(mock_sentence_transformer): + config = BaseEmbedderConfig(model="all-MiniLM-L6-v2", model_kwargs={"device": "cuda"}) + embedder = HuggingFaceEmbedding(config) + + mock_sentence_transformer.encode.return_value = [0.7, 0.8, 0.9] + result = embedder.embed("Test with device") + + mock_sentence_transformer.encode.assert_called_once_with("Test with device") + + assert result == [0.7, 0.8, 0.9] + + +def test_embed_sets_embedding_dims(mock_sentence_transformer): + config = BaseEmbedderConfig() + + mock_sentence_transformer.get_sentence_embedding_dimension.return_value = 384 + embedder = HuggingFaceEmbedding(config) + + assert embedder.config.embedding_dims == 384 + mock_sentence_transformer.get_sentence_embedding_dimension.assert_called_once() + + +def test_embed_with_custom_embedding_dims(mock_sentence_transformer): + config = BaseEmbedderConfig(model="all-mpnet-base-v2", embedding_dims=768) + embedder = HuggingFaceEmbedding(config) + + mock_sentence_transformer.encode.return_value = [1.0, 1.1, 1.2] + result = embedder.embed("Custom embedding dims") + + mock_sentence_transformer.encode.assert_called_once_with("Custom embedding dims") + + assert embedder.config.embedding_dims == 768 + + assert result == [1.0, 1.1, 1.2] diff --git a/tests/embeddings/test_ollama_embeddings.py b/tests/embeddings/test_ollama_embeddings.py new file mode 100644 index 0000000000..0aa428b742 --- /dev/null +++ b/tests/embeddings/test_ollama_embeddings.py @@ -0,0 +1,41 @@ +import pytest +from unittest.mock import Mock, patch +from mem0.embeddings.ollama import OllamaEmbedding +from mem0.configs.embeddings.base import BaseEmbedderConfig + + +@pytest.fixture +def mock_ollama_client(): + with patch("mem0.embeddings.ollama.Client") as mock_ollama: + mock_client = Mock() + mock_client.list.return_value = {"models": [{"name": "nomic-embed-text"}]} + mock_ollama.return_value = mock_client + 
yield mock_client + + +def test_embed_text(mock_ollama_client): + config = BaseEmbedderConfig(model="nomic-embed-text", embedding_dims=512) + embedder = OllamaEmbedding(config) + + mock_response = {"embedding": [0.1, 0.2, 0.3, 0.4, 0.5]} + mock_ollama_client.embeddings.return_value = mock_response + + text = "Sample text to embed." + embedding = embedder.embed(text) + + mock_ollama_client.embeddings.assert_called_once_with(model="nomic-embed-text", prompt=text) + + assert embedding == [0.1, 0.2, 0.3, 0.4, 0.5] + + +def test_ensure_model_exists(mock_ollama_client): + config = BaseEmbedderConfig(model="nomic-embed-text", embedding_dims=512) + embedder = OllamaEmbedding(config) + + mock_ollama_client.pull.assert_not_called() + + mock_ollama_client.list.return_value = {"models": []} + + embedder._ensure_model_exists() + + mock_ollama_client.pull.assert_called_once_with("nomic-embed-text") diff --git a/tests/embeddings/test_openai_embeddings.py b/tests/embeddings/test_openai_embeddings.py new file mode 100644 index 0000000000..113a8e64c9 --- /dev/null +++ b/tests/embeddings/test_openai_embeddings.py @@ -0,0 +1,84 @@ +import pytest +from unittest.mock import Mock, patch +from mem0.embeddings.openai import OpenAIEmbedding +from mem0.configs.embeddings.base import BaseEmbedderConfig + + +@pytest.fixture +def mock_openai_client(): + with patch("mem0.embeddings.openai.OpenAI") as mock_openai: + mock_client = Mock() + mock_openai.return_value = mock_client + yield mock_client + + +def test_embed_default_model(mock_openai_client): + config = BaseEmbedderConfig() + embedder = OpenAIEmbedding(config) + mock_response = Mock() + mock_response.data = [Mock(embedding=[0.1, 0.2, 0.3])] + mock_openai_client.embeddings.create.return_value = mock_response + + result = embedder.embed("Hello world") + + mock_openai_client.embeddings.create.assert_called_once_with(input=["Hello world"], model="text-embedding-3-small") + assert result == [0.1, 0.2, 0.3] + + +def test_embed_custom_model(mock_openai_client): + config = BaseEmbedderConfig(model="text-embedding-2-medium", embedding_dims=1024) + embedder = OpenAIEmbedding(config) + mock_response = Mock() + mock_response.data = [Mock(embedding=[0.4, 0.5, 0.6])] + mock_openai_client.embeddings.create.return_value = mock_response + + result = embedder.embed("Test embedding") + + mock_openai_client.embeddings.create.assert_called_once_with( + input=["Test embedding"], model="text-embedding-2-medium" + ) + assert result == [0.4, 0.5, 0.6] + + +def test_embed_removes_newlines(mock_openai_client): + config = BaseEmbedderConfig() + embedder = OpenAIEmbedding(config) + mock_response = Mock() + mock_response.data = [Mock(embedding=[0.7, 0.8, 0.9])] + mock_openai_client.embeddings.create.return_value = mock_response + + result = embedder.embed("Hello\nworld") + + mock_openai_client.embeddings.create.assert_called_once_with(input=["Hello world"], model="text-embedding-3-small") + assert result == [0.7, 0.8, 0.9] + + +def test_embed_without_api_key_env_var(mock_openai_client): + config = BaseEmbedderConfig(api_key="test_key") + embedder = OpenAIEmbedding(config) + mock_response = Mock() + mock_response.data = [Mock(embedding=[1.0, 1.1, 1.2])] + mock_openai_client.embeddings.create.return_value = mock_response + + result = embedder.embed("Testing API key") + + mock_openai_client.embeddings.create.assert_called_once_with( + input=["Testing API key"], model="text-embedding-3-small" + ) + assert result == [1.0, 1.1, 1.2] + + +def test_embed_uses_environment_api_key(mock_openai_client, 
monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "env_key") + config = BaseEmbedderConfig() + embedder = OpenAIEmbedding(config) + mock_response = Mock() + mock_response.data = [Mock(embedding=[1.3, 1.4, 1.5])] + mock_openai_client.embeddings.create.return_value = mock_response + + result = embedder.embed("Environment key test") + + mock_openai_client.embeddings.create.assert_called_once_with( + input=["Environment key test"], model="text-embedding-3-small" + ) + assert result == [1.3, 1.4, 1.5] diff --git a/tests/llms/test_azure_openai.py b/tests/llms/test_azure_openai.py index 63eb91b01d..e54d244fbd 100644 --- a/tests/llms/test_azure_openai.py +++ b/tests/llms/test_azure_openai.py @@ -1,4 +1,3 @@ - from unittest.mock import Mock, patch import httpx @@ -7,26 +6,28 @@ from mem0.configs.llms.base import BaseLlmConfig from mem0.llms.azure_openai import AzureOpenAILLM -MODEL = "gpt-4o" # or your custom deployment name +MODEL = "gpt-4o" # or your custom deployment name TEMPERATURE = 0.7 MAX_TOKENS = 100 TOP_P = 1.0 + @pytest.fixture def mock_openai_client(): - with patch('mem0.llms.azure_openai.AzureOpenAI') as mock_openai: + with patch("mem0.llms.azure_openai.AzureOpenAI") as mock_openai: mock_client = Mock() mock_openai.return_value = mock_client yield mock_client + def test_generate_response_without_tools(mock_openai_client): config = BaseLlmConfig(model=MODEL, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P) llm = AzureOpenAILLM(config) messages = [ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello, how are you?"} + {"role": "user", "content": "Hello, how are you?"}, ] - + mock_response = Mock() mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))] mock_openai_client.chat.completions.create.return_value = mock_response @@ -34,11 +35,7 @@ def test_generate_response_without_tools(mock_openai_client): response = llm.generate_response(messages) mock_openai_client.chat.completions.create.assert_called_once_with( - model=MODEL, - messages=messages, - temperature=TEMPERATURE, - max_tokens=MAX_TOKENS, - top_p=TOP_P + model=MODEL, messages=messages, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P ) assert response == "I'm doing well, thank you for asking!" @@ -48,7 +45,7 @@ def test_generate_response_with_tools(mock_openai_client): llm = AzureOpenAILLM(config) messages = [ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Add a new memory: Today is a sunny day."} + {"role": "user", "content": "Add a new memory: Today is a sunny day."}, ] tools = [ { @@ -58,23 +55,21 @@ def test_generate_response_with_tools(mock_openai_client): "description": "Add a memory", "parameters": { "type": "object", - "properties": { - "data": {"type": "string", "description": "Data to add to memory"} - }, + "properties": {"data": {"type": "string", "description": "Data to add to memory"}}, "required": ["data"], }, }, } ] - + mock_response = Mock() mock_message = Mock() mock_message.content = "I've added the memory for you." 
- + mock_tool_call = Mock() mock_tool_call.function.name = "add_memory" mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}' - + mock_message.tool_calls = [mock_tool_call] mock_response.choices = [Mock(message=mock_message)] mock_openai_client.chat.completions.create.return_value = mock_response @@ -88,24 +83,33 @@ def test_generate_response_with_tools(mock_openai_client): max_tokens=MAX_TOKENS, top_p=TOP_P, tools=tools, - tool_choice="auto" + tool_choice="auto", ) - + assert response["content"] == "I've added the memory for you." assert len(response["tool_calls"]) == 1 assert response["tool_calls"][0]["name"] == "add_memory" - assert response["tool_calls"][0]["arguments"] == {'data': 'Today is a sunny day.'} + assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."} + def test_generate_with_http_proxies(): mock_http_client = Mock(spec=httpx.Client) mock_http_client_instance = Mock(spec=httpx.Client) mock_http_client.return_value = mock_http_client_instance - with (patch("mem0.llms.azure_openai.AzureOpenAI") as mock_azure_openai, - patch("httpx.Client", new=mock_http_client) as mock_http_client): - config = BaseLlmConfig(model=MODEL, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P, - api_key="test", http_client_proxies="http://testproxy.mem0.net:8000", - azure_kwargs= {"api_key" : "test"}) + with ( + patch("mem0.llms.azure_openai.AzureOpenAI") as mock_azure_openai, + patch("httpx.Client", new=mock_http_client) as mock_http_client, + ): + config = BaseLlmConfig( + model=MODEL, + temperature=TEMPERATURE, + max_tokens=MAX_TOKENS, + top_p=TOP_P, + api_key="test", + http_client_proxies="http://testproxy.mem0.net:8000", + azure_kwargs={"api_key": "test"}, + ) _ = AzureOpenAILLM(config) @@ -114,6 +118,6 @@ def test_generate_with_http_proxies(): http_client=mock_http_client_instance, azure_deployment=None, azure_endpoint=None, - api_version=None + api_version=None, ) mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000") diff --git a/tests/llms/test_groq.py b/tests/llms/test_groq.py index e7d1f51c13..288b37f80c 100644 --- a/tests/llms/test_groq.py +++ b/tests/llms/test_groq.py @@ -8,7 +8,7 @@ @pytest.fixture def mock_groq_client(): - with patch('mem0.llms.groq.Groq') as mock_groq: + with patch("mem0.llms.groq.Groq") as mock_groq: mock_client = Mock() mock_groq.return_value = mock_client yield mock_client @@ -19,9 +19,9 @@ def test_generate_response_without_tools(mock_groq_client): llm = GroqLLM(config) messages = [ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello, how are you?"} + {"role": "user", "content": "Hello, how are you?"}, ] - + mock_response = Mock() mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))] mock_groq_client.chat.completions.create.return_value = mock_response @@ -29,11 +29,7 @@ def test_generate_response_without_tools(mock_groq_client): response = llm.generate_response(messages) mock_groq_client.chat.completions.create.assert_called_once_with( - model="llama3-70b-8192", - messages=messages, - temperature=0.7, - max_tokens=100, - top_p=1.0 + model="llama3-70b-8192", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0 ) assert response == "I'm doing well, thank you for asking!" 
@@ -43,7 +39,7 @@ def test_generate_response_with_tools(mock_groq_client): llm = GroqLLM(config) messages = [ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Add a new memory: Today is a sunny day."} + {"role": "user", "content": "Add a new memory: Today is a sunny day."}, ] tools = [ { @@ -53,23 +49,21 @@ def test_generate_response_with_tools(mock_groq_client): "description": "Add a memory", "parameters": { "type": "object", - "properties": { - "data": {"type": "string", "description": "Data to add to memory"} - }, + "properties": {"data": {"type": "string", "description": "Data to add to memory"}}, "required": ["data"], }, }, } ] - + mock_response = Mock() mock_message = Mock() mock_message.content = "I've added the memory for you." - + mock_tool_call = Mock() mock_tool_call.function.name = "add_memory" mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}' - + mock_message.tool_calls = [mock_tool_call] mock_response.choices = [Mock(message=mock_message)] mock_groq_client.chat.completions.create.return_value = mock_response @@ -83,11 +77,10 @@ def test_generate_response_with_tools(mock_groq_client): max_tokens=100, top_p=1.0, tools=tools, - tool_choice="auto" + tool_choice="auto", ) - + assert response["content"] == "I've added the memory for you." assert len(response["tool_calls"]) == 1 assert response["tool_calls"][0]["name"] == "add_memory" - assert response["tool_calls"][0]["arguments"] == {'data': 'Today is a sunny day.'} - \ No newline at end of file + assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."} diff --git a/tests/llms/test_litellm.py b/tests/llms/test_litellm.py index f4b265aacd..d7be93c9fe 100644 --- a/tests/llms/test_litellm.py +++ b/tests/llms/test_litellm.py @@ -8,14 +8,15 @@ @pytest.fixture def mock_litellm(): - with patch('mem0.llms.litellm.litellm') as mock_litellm: + with patch("mem0.llms.litellm.litellm") as mock_litellm: yield mock_litellm + def test_generate_response_with_unsupported_model(mock_litellm): config = BaseLlmConfig(model="unsupported-model", temperature=0.7, max_tokens=100, top_p=1) llm = litellm.LiteLLM(config) messages = [{"role": "user", "content": "Hello"}] - + mock_litellm.supports_function_calling.return_value = False with pytest.raises(ValueError, match="Model 'unsupported-model' in litellm does not support function calling."): @@ -27,9 +28,9 @@ def test_generate_response_without_tools(mock_litellm): llm = litellm.LiteLLM(config) messages = [ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello, how are you?"} + {"role": "user", "content": "Hello, how are you?"}, ] - + mock_response = Mock() mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))] mock_litellm.completion.return_value = mock_response @@ -38,11 +39,7 @@ def test_generate_response_without_tools(mock_litellm): response = llm.generate_response(messages) mock_litellm.completion.assert_called_once_with( - model="gpt-4o", - messages=messages, - temperature=0.7, - max_tokens=100, - top_p=1.0 + model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0 ) assert response == "I'm doing well, thank you for asking!" 
@@ -52,7 +49,7 @@ def test_generate_response_with_tools(mock_litellm): llm = litellm.LiteLLM(config) messages = [ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Add a new memory: Today is a sunny day."} + {"role": "user", "content": "Add a new memory: Today is a sunny day."}, ] tools = [ { @@ -62,23 +59,21 @@ def test_generate_response_with_tools(mock_litellm): "description": "Add a memory", "parameters": { "type": "object", - "properties": { - "data": {"type": "string", "description": "Data to add to memory"} - }, + "properties": {"data": {"type": "string", "description": "Data to add to memory"}}, "required": ["data"], }, }, } ] - + mock_response = Mock() mock_message = Mock() mock_message.content = "I've added the memory for you." - + mock_tool_call = Mock() mock_tool_call.function.name = "add_memory" mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}' - + mock_message.tool_calls = [mock_tool_call] mock_response.choices = [Mock(message=mock_message)] mock_litellm.completion.return_value = mock_response @@ -87,16 +82,10 @@ def test_generate_response_with_tools(mock_litellm): response = llm.generate_response(messages, tools=tools) mock_litellm.completion.assert_called_once_with( - model="gpt-4o", - messages=messages, - temperature=0.7, - max_tokens=100, - top_p=1, - tools=tools, - tool_choice="auto" + model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1, tools=tools, tool_choice="auto" ) - + assert response["content"] == "I've added the memory for you." assert len(response["tool_calls"]) == 1 assert response["tool_calls"][0]["name"] == "add_memory" - assert response["tool_calls"][0]["arguments"] == {'data': 'Today is a sunny day.'} + assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."} diff --git a/tests/llms/test_ollama.py b/tests/llms/test_ollama.py index d99fd2bcc3..f815833515 100644 --- a/tests/llms/test_ollama.py +++ b/tests/llms/test_ollama.py @@ -9,61 +9,48 @@ @pytest.fixture def mock_ollama_client(): - with patch('mem0.llms.ollama.Client') as mock_ollama: + with patch("mem0.llms.ollama.Client") as mock_ollama: mock_client = Mock() mock_client.list.return_value = {"models": [{"name": "llama3.1:70b"}]} mock_ollama.return_value = mock_client yield mock_client + def test_generate_response_without_tools(mock_ollama_client): config = BaseLlmConfig(model="llama3.1:70b", temperature=0.7, max_tokens=100, top_p=1.0) llm = OllamaLLM(config) messages = [ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello, how are you?"} + {"role": "user", "content": "Hello, how are you?"}, ] - - mock_response = { - 'message': {"content": "I'm doing well, thank you for asking!"} - } + + mock_response = {"message": {"content": "I'm doing well, thank you for asking!"}} mock_ollama_client.chat.return_value = mock_response response = llm.generate_response(messages) mock_ollama_client.chat.assert_called_once_with( - model="llama3.1:70b", - messages=messages, - options={ - "temperature": 0.7, - "num_predict": 100, - "top_p": 1.0 - } + model="llama3.1:70b", messages=messages, options={"temperature": 0.7, "num_predict": 100, "top_p": 1.0} ) assert response == "I'm doing well, thank you for asking!" 
+ def test_generate_response_with_tools(mock_ollama_client): config = BaseLlmConfig(model="llama3.1:70b", temperature=0.7, max_tokens=100, top_p=1.0) llm = OllamaLLM(config) messages = [ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Add a new memory: Today is a sunny day."} + {"role": "user", "content": "Add a new memory: Today is a sunny day."}, ] tools = [ADD_MEMORY_TOOL] - + mock_response = { - 'message': { + "message": { "content": "I've added the memory for you.", - "tool_calls": [ - { - "function": { - "name": "add_memory", - "arguments": {"data": "Today is a sunny day."} - } - } - ] + "tool_calls": [{"function": {"name": "add_memory", "arguments": {"data": "Today is a sunny day."}}}], } } - + mock_ollama_client.chat.return_value = mock_response response = llm.generate_response(messages, tools=tools) @@ -71,16 +58,11 @@ def test_generate_response_with_tools(mock_ollama_client): mock_ollama_client.chat.assert_called_once_with( model="llama3.1:70b", messages=messages, - options={ - "temperature": 0.7, - "num_predict": 100, - "top_p": 1.0 - }, - tools=tools + options={"temperature": 0.7, "num_predict": 100, "top_p": 1.0}, + tools=tools, ) - + assert response["content"] == "I've added the memory for you." assert len(response["tool_calls"]) == 1 assert response["tool_calls"][0]["name"] == "add_memory" - assert response["tool_calls"][0]["arguments"] == {'data': 'Today is a sunny day.'} - \ No newline at end of file + assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."} diff --git a/tests/llms/test_openai.py b/tests/llms/test_openai.py index 204487c432..be2f6f954e 100644 --- a/tests/llms/test_openai.py +++ b/tests/llms/test_openai.py @@ -8,7 +8,7 @@ @pytest.fixture def mock_openai_client(): - with patch('mem0.llms.openai.OpenAI') as mock_openai: + with patch("mem0.llms.openai.OpenAI") as mock_openai: mock_client = Mock() mock_openai.return_value = mock_client yield mock_client @@ -19,9 +19,9 @@ def test_generate_response_without_tools(mock_openai_client): llm = OpenAILLM(config) messages = [ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello, how are you?"} + {"role": "user", "content": "Hello, how are you?"}, ] - + mock_response = Mock() mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))] mock_openai_client.chat.completions.create.return_value = mock_response @@ -29,11 +29,7 @@ def test_generate_response_without_tools(mock_openai_client): response = llm.generate_response(messages) mock_openai_client.chat.completions.create.assert_called_once_with( - model="gpt-4o", - messages=messages, - temperature=0.7, - max_tokens=100, - top_p=1.0 + model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0 ) assert response == "I'm doing well, thank you for asking!" 
@@ -43,7 +39,7 @@ def test_generate_response_with_tools(mock_openai_client): llm = OpenAILLM(config) messages = [ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Add a new memory: Today is a sunny day."} + {"role": "user", "content": "Add a new memory: Today is a sunny day."}, ] tools = [ { @@ -53,23 +49,21 @@ def test_generate_response_with_tools(mock_openai_client): "description": "Add a memory", "parameters": { "type": "object", - "properties": { - "data": {"type": "string", "description": "Data to add to memory"} - }, + "properties": {"data": {"type": "string", "description": "Data to add to memory"}}, "required": ["data"], }, }, } ] - + mock_response = Mock() mock_message = Mock() mock_message.content = "I've added the memory for you." - + mock_tool_call = Mock() mock_tool_call.function.name = "add_memory" mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}' - + mock_message.tool_calls = [mock_tool_call] mock_response.choices = [Mock(message=mock_message)] mock_openai_client.chat.completions.create.return_value = mock_response @@ -77,17 +71,10 @@ def test_generate_response_with_tools(mock_openai_client): response = llm.generate_response(messages, tools=tools) mock_openai_client.chat.completions.create.assert_called_once_with( - model="gpt-4o", - messages=messages, - temperature=0.7, - max_tokens=100, - top_p=1.0, - tools=tools, - tool_choice="auto" + model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0, tools=tools, tool_choice="auto" ) - + assert response["content"] == "I've added the memory for you." assert len(response["tool_calls"]) == 1 assert response["tool_calls"][0]["name"] == "add_memory" - assert response["tool_calls"][0]["arguments"] == {'data': 'Today is a sunny day.'} - \ No newline at end of file + assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."} diff --git a/tests/llms/test_together.py b/tests/llms/test_together.py index f317d106ce..7c59ee4195 100644 --- a/tests/llms/test_together.py +++ b/tests/llms/test_together.py @@ -8,7 +8,7 @@ @pytest.fixture def mock_together_client(): - with patch('mem0.llms.together.Together') as mock_together: + with patch("mem0.llms.together.Together") as mock_together: mock_client = Mock() mock_together.return_value = mock_client yield mock_client @@ -19,9 +19,9 @@ def test_generate_response_without_tools(mock_together_client): llm = TogetherLLM(config) messages = [ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello, how are you?"} + {"role": "user", "content": "Hello, how are you?"}, ] - + mock_response = Mock() mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))] mock_together_client.chat.completions.create.return_value = mock_response @@ -29,11 +29,7 @@ def test_generate_response_without_tools(mock_together_client): response = llm.generate_response(messages) mock_together_client.chat.completions.create.assert_called_once_with( - model="mistralai/Mixtral-8x7B-Instruct-v0.1", - messages=messages, - temperature=0.7, - max_tokens=100, - top_p=1.0 + model="mistralai/Mixtral-8x7B-Instruct-v0.1", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0 ) assert response == "I'm doing well, thank you for asking!" 
@@ -43,7 +39,7 @@ def test_generate_response_with_tools(mock_together_client): llm = TogetherLLM(config) messages = [ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Add a new memory: Today is a sunny day."} + {"role": "user", "content": "Add a new memory: Today is a sunny day."}, ] tools = [ { @@ -53,23 +49,21 @@ def test_generate_response_with_tools(mock_together_client): "description": "Add a memory", "parameters": { "type": "object", - "properties": { - "data": {"type": "string", "description": "Data to add to memory"} - }, + "properties": {"data": {"type": "string", "description": "Data to add to memory"}}, "required": ["data"], }, }, } ] - + mock_response = Mock() mock_message = Mock() mock_message.content = "I've added the memory for you." - + mock_tool_call = Mock() mock_tool_call.function.name = "add_memory" mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}' - + mock_message.tool_calls = [mock_tool_call] mock_response.choices = [Mock(message=mock_message)] mock_together_client.chat.completions.create.return_value = mock_response @@ -83,11 +77,10 @@ def test_generate_response_with_tools(mock_together_client): max_tokens=100, top_p=1.0, tools=tools, - tool_choice="auto" + tool_choice="auto", ) - + assert response["content"] == "I've added the memory for you." assert len(response["tool_calls"]) == 1 assert response["tool_calls"][0]["name"] == "add_memory" - assert response["tool_calls"][0]["arguments"] == {'data': 'Today is a sunny day.'} - \ No newline at end of file + assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."} diff --git a/tests/test_main.py b/tests/test_main.py index 16a672e395..8ed2224586 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -4,42 +4,39 @@ from mem0.memory.main import Memory from mem0.configs.base import MemoryConfig + @pytest.fixture(autouse=True) def mock_openai(): - os.environ['OPENAI_API_KEY'] = "123" - with patch('openai.OpenAI') as mock: + os.environ["OPENAI_API_KEY"] = "123" + with patch("openai.OpenAI") as mock: mock.return_value = Mock() yield mock + @pytest.fixture def memory_instance(): - with patch('mem0.utils.factory.EmbedderFactory') as mock_embedder, \ - patch('mem0.utils.factory.VectorStoreFactory') as mock_vector_store, \ - patch('mem0.utils.factory.LlmFactory') as mock_llm, \ - patch('mem0.memory.telemetry.capture_event'), \ - patch('mem0.memory.graph_memory.MemoryGraph'): + with patch("mem0.utils.factory.EmbedderFactory") as mock_embedder, patch( + "mem0.utils.factory.VectorStoreFactory" + ) as mock_vector_store, patch("mem0.utils.factory.LlmFactory") as mock_llm, patch( + "mem0.memory.telemetry.capture_event" + ), patch("mem0.memory.graph_memory.MemoryGraph"): mock_embedder.create.return_value = Mock() mock_vector_store.create.return_value = Mock() mock_llm.create.return_value = Mock() - + config = MemoryConfig(version="v1.1") config.graph_store.config = {"some_config": "value"} return Memory(config) -@pytest.mark.parametrize("version, enable_graph", [ - ("v1.0", False), - ("v1.1", True) -]) + +@pytest.mark.parametrize("version, enable_graph", [("v1.0", False), ("v1.1", True)]) def test_add(memory_instance, version, enable_graph): memory_instance.config.version = version memory_instance.enable_graph = enable_graph memory_instance._add_to_vector_store = Mock(return_value=[{"memory": "Test memory", "event": "ADD"}]) memory_instance._add_to_graph = Mock(return_value=[]) - result = memory_instance.add( - messages=[{"role": "user", 
"content": "Test message"}], - user_id="test_user" - ) + result = memory_instance.add(messages=[{"role": "user", "content": "Test message"}], user_id="test_user") assert "results" in result assert result["results"] == [{"memory": "Test memory", "event": "ADD"}] @@ -47,26 +44,27 @@ def test_add(memory_instance, version, enable_graph): assert result["relations"] == [] memory_instance._add_to_vector_store.assert_called_once_with( - [{"role": "user", "content": "Test message"}], - {"user_id": "test_user"}, - {"user_id": "test_user"} + [{"role": "user", "content": "Test message"}], {"user_id": "test_user"}, {"user_id": "test_user"} ) - + # Remove the conditional assertion for _add_to_graph memory_instance._add_to_graph.assert_called_once_with( - [{"role": "user", "content": "Test message"}], - {"user_id": "test_user"} + [{"role": "user", "content": "Test message"}], {"user_id": "test_user"} ) + def test_get(memory_instance): - mock_memory = Mock(id="test_id", payload={ - "data": "Test memory", - "user_id": "test_user", - "hash": "test_hash", - "created_at": "2023-01-01T00:00:00", - "updated_at": "2023-01-02T00:00:00", - "extra_field": "extra_value" - }) + mock_memory = Mock( + id="test_id", + payload={ + "data": "Test memory", + "user_id": "test_user", + "hash": "test_hash", + "created_at": "2023-01-01T00:00:00", + "updated_at": "2023-01-02T00:00:00", + "extra_field": "extra_value", + }, + ) memory_instance.vector_store.get = Mock(return_value=mock_memory) result = memory_instance.get("test_id") @@ -79,16 +77,14 @@ def test_get(memory_instance): assert result["updated_at"] == "2023-01-02T00:00:00" assert result["metadata"] == {"extra_field": "extra_value"} -@pytest.mark.parametrize("version, enable_graph", [ - ("v1.0", False), - ("v1.1", True) -]) + +@pytest.mark.parametrize("version, enable_graph", [("v1.0", False), ("v1.1", True)]) def test_search(memory_instance, version, enable_graph): memory_instance.config.version = version memory_instance.enable_graph = enable_graph mock_memories = [ Mock(id="1", payload={"data": "Memory 1", "user_id": "test_user"}, score=0.9), - Mock(id="2", payload={"data": "Memory 2", "user_id": "test_user"}, score=0.8) + Mock(id="2", payload={"data": "Memory 2", "user_id": "test_user"}, score=0.8), ] memory_instance.vector_store.search = Mock(return_value=mock_memories) memory_instance.embedding_model.embed = Mock(return_value=[0.1, 0.2, 0.3]) @@ -118,17 +114,16 @@ def test_search(memory_instance, version, enable_graph): assert result["results"][0]["score"] == 0.9 memory_instance.vector_store.search.assert_called_once_with( - query=[0.1, 0.2, 0.3], - limit=100, - filters={"user_id": "test_user"} + query=[0.1, 0.2, 0.3], limit=100, filters={"user_id": "test_user"} ) memory_instance.embedding_model.embed.assert_called_once_with("test query") - + if enable_graph: memory_instance.graph.search.assert_called_once_with("test query", {"user_id": "test_user"}) else: memory_instance.graph.search.assert_not_called() + def test_update(memory_instance): memory_instance._update_memory = Mock() @@ -137,6 +132,7 @@ def test_update(memory_instance): memory_instance._update_memory.assert_called_once_with("test_id", "Updated memory") assert result["message"] == "Memory updated successfully!" + def test_delete(memory_instance): memory_instance._delete_memory = Mock() @@ -145,10 +141,8 @@ def test_delete(memory_instance): memory_instance._delete_memory.assert_called_once_with("test_id") assert result["message"] == "Memory deleted successfully!" 
-@pytest.mark.parametrize("version, enable_graph", [ - ("v1.0", False), - ("v1.1", True) -]) + +@pytest.mark.parametrize("version, enable_graph", [("v1.0", False), ("v1.1", True)]) def test_delete_all(memory_instance, version, enable_graph): memory_instance.config.version = version memory_instance.enable_graph = enable_graph @@ -160,14 +154,15 @@ def test_delete_all(memory_instance, version, enable_graph): result = memory_instance.delete_all(user_id="test_user") assert memory_instance._delete_memory.call_count == 2 - + if enable_graph: memory_instance.graph.delete_all.assert_called_once_with({"user_id": "test_user"}) else: memory_instance.graph.delete_all.assert_not_called() - + assert result["message"] == "Memories deleted successfully!" + def test_reset(memory_instance): memory_instance.vector_store.delete_col = Mock() memory_instance.db.reset = Mock() @@ -177,22 +172,30 @@ def test_reset(memory_instance): memory_instance.vector_store.delete_col.assert_called_once() memory_instance.db.reset.assert_called_once() -@pytest.mark.parametrize("version, enable_graph, expected_result", [ - ("v1.0", False, {"results": [{"id": "1", "memory": "Memory 1", "user_id": "test_user"}]}), - ("v1.1", False, {"results": [{"id": "1", "memory": "Memory 1", "user_id": "test_user"}]}), - ("v1.1", True, { - "results": [{"id": "1", "memory": "Memory 1", "user_id": "test_user"}], - "relations": [{"source": "entity1", "relationship": "rel", "target": "entity2"}] - }) -]) + +@pytest.mark.parametrize( + "version, enable_graph, expected_result", + [ + ("v1.0", False, {"results": [{"id": "1", "memory": "Memory 1", "user_id": "test_user"}]}), + ("v1.1", False, {"results": [{"id": "1", "memory": "Memory 1", "user_id": "test_user"}]}), + ( + "v1.1", + True, + { + "results": [{"id": "1", "memory": "Memory 1", "user_id": "test_user"}], + "relations": [{"source": "entity1", "relationship": "rel", "target": "entity2"}], + }, + ), + ], +) def test_get_all(memory_instance, version, enable_graph, expected_result): memory_instance.config.version = version memory_instance.enable_graph = enable_graph mock_memories = [Mock(id="1", payload={"data": "Memory 1", "user_id": "test_user"})] memory_instance.vector_store.list = Mock(return_value=(mock_memories, None)) - memory_instance.graph.get_all = Mock(return_value=[ - {"source": "entity1", "relationship": "rel", "target": "entity2"} - ]) + memory_instance.graph.get_all = Mock( + return_value=[{"source": "entity1", "relationship": "rel", "target": "entity2"}] + ) result = memory_instance.get_all(user_id="test_user") @@ -204,7 +207,7 @@ def test_get_all(memory_instance, version, enable_graph, expected_result): assert result_item["id"] == expected_item["id"] assert result_item["memory"] == expected_item["memory"] assert result_item["user_id"] == expected_item["user_id"] - + if enable_graph: assert "relations" in result assert result["relations"] == expected_result["relations"] @@ -212,7 +215,7 @@ def test_get_all(memory_instance, version, enable_graph, expected_result): assert "relations" not in result memory_instance.vector_store.list.assert_called_once_with(filters={"user_id": "test_user"}, limit=100) - + if enable_graph: memory_instance.graph.get_all.assert_called_once_with({"user_id": "test_user"}) else: diff --git a/tests/test_memory.py b/tests/test_memory.py index 9c1c60039c..2659d06c92 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -7,6 +7,7 @@ def memory_store(): return Memory() + @pytest.mark.skip(reason="Not implemented") def test_create_memory(memory_store): 
data = "Name is John Doe." diff --git a/tests/test_proxy.py b/tests/test_proxy.py index 8e7e58ec3e..8088f380ed 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -11,23 +11,26 @@ def mock_memory_client(): return Mock(spec=MemoryClient) + @pytest.fixture def mock_openai_embedding_client(): - with patch('mem0.embeddings.openai.OpenAI') as mock_openai: + with patch("mem0.embeddings.openai.OpenAI") as mock_openai: mock_client = Mock() mock_openai.return_value = mock_client yield mock_client + @pytest.fixture def mock_openai_llm_client(): - with patch('mem0.llms.openai.OpenAI') as mock_openai: + with patch("mem0.llms.openai.OpenAI") as mock_openai: mock_client = Mock() mock_openai.return_value = mock_client yield mock_client + @pytest.fixture def mock_litellm(): - with patch('mem0.proxy.main.litellm') as mock: + with patch("mem0.proxy.main.litellm") as mock: yield mock @@ -39,16 +42,16 @@ def test_mem0_initialization_with_api_key(mock_openai_embedding_client, mock_ope def test_mem0_initialization_with_config(): config = {"some_config": "value"} - with patch('mem0.Memory.from_config') as mock_from_config: + with patch("mem0.Memory.from_config") as mock_from_config: mem0 = Mem0(config=config) mock_from_config.assert_called_once_with(config) assert isinstance(mem0.chat, Chat) def test_mem0_initialization_without_params(mock_openai_embedding_client, mock_openai_llm_client): - mem0 = Mem0() - assert isinstance(mem0.mem0_client, Memory) - assert isinstance(mem0.chat, Chat) + mem0 = Mem0() + assert isinstance(mem0.mem0_client, Memory) + assert isinstance(mem0.chat, Chat) def test_chat_initialization(mock_memory_client): @@ -58,48 +61,37 @@ def test_chat_initialization(mock_memory_client): def test_completions_create(mock_memory_client, mock_litellm): completions = Completions(mock_memory_client) - - messages = [ - {"role": "user", "content": "Hello, how are you?"} - ] + + messages = [{"role": "user", "content": "Hello, how are you?"}] mock_memory_client.search.return_value = [{"memory": "Some relevant memory"}] mock_litellm.completion.return_value = {"choices": [{"message": {"content": "I'm doing well, thank you!"}}]} - - response = completions.create( - model="gpt-4o-mini", - messages=messages, - user_id="test_user", - temperature=0.7 - ) - + + response = completions.create(model="gpt-4o-mini", messages=messages, user_id="test_user", temperature=0.7) + mock_memory_client.add.assert_called_once() mock_memory_client.search.assert_called_once() - + mock_litellm.completion.assert_called_once() call_args = mock_litellm.completion.call_args[1] - assert call_args['model'] == "gpt-4o-mini" - assert len(call_args['messages']) == 2 - assert call_args['temperature'] == 0.7 - + assert call_args["model"] == "gpt-4o-mini" + assert len(call_args["messages"]) == 2 + assert call_args["temperature"] == 0.7 + assert response == {"choices": [{"message": {"content": "I'm doing well, thank you!"}}]} def test_completions_create_with_system_message(mock_memory_client, mock_litellm): completions = Completions(mock_memory_client) - + messages = [ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello, how are you?"} + {"role": "user", "content": "Hello, how are you?"}, ] mock_memory_client.search.return_value = [{"memory": "Some relevant memory"}] mock_litellm.completion.return_value = {"choices": [{"message": {"content": "I'm doing well, thank you!"}}]} - - completions.create( - model="gpt-4o-mini", - messages=messages, - user_id="test_user" - ) - + + 
completions.create(model="gpt-4o-mini", messages=messages, user_id="test_user")
+
     call_args = mock_litellm.completion.call_args[1]
-    assert call_args['messages'][0]['role'] == "system"
-    assert call_args['messages'][0]['content'] == MEMORY_ANSWER_PROMPT
+    assert call_args["messages"][0]["role"] == "system"
+    assert call_args["messages"][0]["content"] == MEMORY_ANSWER_PROMPT
diff --git a/tests/test_telemetry.py b/tests/test_telemetry.py
index aa36f2f056..1d0b100f21 100644
--- a/tests/test_telemetry.py
+++ b/tests/test_telemetry.py
@@ -7,23 +7,28 @@
 if isinstance(MEM0_TELEMETRY, str):
     MEM0_TELEMETRY = MEM0_TELEMETRY.lower() in ("true", "1", "yes")
 
+
 def use_telemetry():
-    if os.getenv('MEM0_TELEMETRY', "true").lower() == "true":
+    if os.getenv("MEM0_TELEMETRY", "true").lower() == "true":
         return True
     return False
 
+
 @pytest.fixture(autouse=True)
 def reset_env():
     with patch.dict(os.environ, {}, clear=True):
         yield
 
+
 def test_telemetry_enabled():
-    with patch.dict(os.environ, {'MEM0_TELEMETRY': "true"}):
+    with patch.dict(os.environ, {"MEM0_TELEMETRY": "true"}):
         assert use_telemetry() is True
 
+
 def test_telemetry_disabled():
-    with patch.dict(os.environ, {'MEM0_TELEMETRY': "false"}):
+    with patch.dict(os.environ, {"MEM0_TELEMETRY": "false"}):
         assert use_telemetry() is False
 
+
 def test_telemetry_default_enabled():
     assert use_telemetry() is True
diff --git a/tests/vector_stores/test_chroma.py b/tests/vector_stores/test_chroma.py
new file mode 100644
index 0000000000..3d0c20b3dc
--- /dev/null
+++ b/tests/vector_stores/test_chroma.py
@@ -0,0 +1,112 @@
+from unittest.mock import Mock, patch
+import pytest
+from mem0.vector_stores.chroma import ChromaDB, OutputData
+
+
+@pytest.fixture
+def mock_chromadb_client():
+    with patch("chromadb.Client") as mock_client:
+        yield mock_client
+
+
+@pytest.fixture
+def chromadb_instance(mock_chromadb_client):
+    mock_collection = Mock()
+    mock_chromadb_client.return_value.get_or_create_collection.return_value = (
+        mock_collection
+    )
+
+    return ChromaDB(
+        collection_name="test_collection", client=mock_chromadb_client.return_value
+    )
+
+
+def test_insert_vectors(chromadb_instance, mock_chromadb_client):
+    vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
+    payloads = [{"name": "vector1"}, {"name": "vector2"}]
+    ids = ["id1", "id2"]
+
+    chromadb_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
+
+    chromadb_instance.collection.add.assert_called_once_with(
+        ids=ids, embeddings=vectors, metadatas=payloads
+    )
+
+
+def test_search_vectors(chromadb_instance, mock_chromadb_client):
+    # Chroma nests query results one list per input query, hence the double lists
+    mock_result = {
+        "ids": [["id1", "id2"]],
+        "distances": [[0.1, 0.2]],
+        "metadatas": [[{"name": "vector1"}, {"name": "vector2"}]],
+    }
+    chromadb_instance.collection.query.return_value = mock_result
+
+    query = [[0.1, 0.2, 0.3]]
+    results = chromadb_instance.search(query=query, limit=2)
+
+    chromadb_instance.collection.query.assert_called_once_with(
+        query_embeddings=query, where=None, n_results=2
+    )
+
+    assert len(results) == 2
+    assert isinstance(results[0], OutputData)
+    assert results[0].id == "id1"
+    assert results[0].score == 0.1
+    assert results[0].payload == {"name": "vector1"}
+
+
+def test_delete_vector(chromadb_instance):
+    vector_id = "id1"
+
+    chromadb_instance.delete(vector_id=vector_id)
+
+    chromadb_instance.collection.delete.assert_called_once_with(ids=vector_id)
+
+
+def test_update_vector(chromadb_instance):
+    vector_id = "id1"
+    new_vector = [0.7, 0.8, 0.9]
+    new_payload = {"name": "updated_vector"}
+
+    chromadb_instance.update(
+        vector_id=vector_id, vector=new_vector,
payload=new_payload + ) + + chromadb_instance.collection.update.assert_called_once_with( + ids=vector_id, embeddings=new_vector, metadatas=new_payload + ) + + +def test_get_vector(chromadb_instance): + mock_result = { + "ids": [["id1"]], + "distances": [[0.1]], + "metadatas": [[{"name": "vector1"}]], + } + chromadb_instance.collection.get.return_value = mock_result + + result = chromadb_instance.get(vector_id="id1") + + chromadb_instance.collection.get.assert_called_once_with(ids=["id1"]) + + assert result.id == "id1" + assert result.score == 0.1 + assert result.payload == {"name": "vector1"} + + +def test_list_vectors(chromadb_instance): + mock_result = { + "ids": [["id1", "id2"]], + "distances": [[0.1, 0.2]], + "metadatas": [[{"name": "vector1"}, {"name": "vector2"}]], + } + chromadb_instance.collection.get.return_value = mock_result + + results = chromadb_instance.list(limit=2) + + chromadb_instance.collection.get.assert_called_once_with(where=None, limit=2) + + assert len(results[0]) == 2 + assert results[0][0].id == "id1" + assert results[0][1].id == "id2" diff --git a/tests/vector_stores/test_qdrant.py b/tests/vector_stores/test_qdrant.py new file mode 100644 index 0000000000..b398335fad --- /dev/null +++ b/tests/vector_stores/test_qdrant.py @@ -0,0 +1,130 @@ +import unittest +from unittest.mock import MagicMock, patch +import uuid +from qdrant_client import QdrantClient +from qdrant_client.models import ( + Distance, + PointStruct, + VectorParams, + PointIdsList, +) +from mem0.vector_stores.qdrant import Qdrant + + +class TestQdrant(unittest.TestCase): + def setUp(self): + self.client_mock = MagicMock(spec=QdrantClient) + self.qdrant = Qdrant( + collection_name="test_collection", + embedding_model_dims=128, + client=self.client_mock, + path="test_path", + on_disk=True, + ) + + def test_create_col(self): + self.client_mock.get_collections.return_value = MagicMock(collections=[]) + + self.qdrant.create_col(vector_size=128, on_disk=True) + + expected_config = VectorParams(size=128, distance=Distance.COSINE, on_disk=True) + + self.client_mock.create_collection.assert_called_with( + collection_name="test_collection", vectors_config=expected_config + ) + + def test_insert(self): + vectors = [[0.1, 0.2], [0.3, 0.4]] + payloads = [{"key": "value1"}, {"key": "value2"}] + ids = [str(uuid.uuid4()), str(uuid.uuid4())] + + self.qdrant.insert(vectors=vectors, payloads=payloads, ids=ids) + + self.client_mock.upsert.assert_called_once() + points = self.client_mock.upsert.call_args[1]["points"] + + self.assertEqual(len(points), 2) + for point in points: + self.assertIsInstance(point, PointStruct) + + self.assertEqual(points[0].payload, payloads[0]) + + def test_search(self): + query_vector = [0.1, 0.2] + self.client_mock.search.return_value = [ + {"id": str(uuid.uuid4()), "score": 0.95, "payload": {"key": "value"}} + ] + + results = self.qdrant.search(query=query_vector, limit=1) + + self.client_mock.search.assert_called_once_with( + collection_name="test_collection", + query_vector=query_vector, + query_filter=None, + limit=1, + ) + + self.assertEqual(len(results), 1) + self.assertIn("id", results[0]) + self.assertIn("score", results[0]) + self.assertIn("payload", results[0]) + + def test_delete(self): + vector_id = str(uuid.uuid4()) + self.qdrant.delete(vector_id=vector_id) + + self.client_mock.delete.assert_called_once_with( + collection_name="test_collection", + points_selector=PointIdsList(points=[vector_id]), + ) + + def test_update(self): + vector_id = str(uuid.uuid4()) + updated_vector = 
[0.2, 0.3] + updated_payload = {"key": "updated_value"} + + self.qdrant.update( + vector_id=vector_id, vector=updated_vector, payload=updated_payload + ) + + self.client_mock.upsert.assert_called_once() + point = self.client_mock.upsert.call_args[1]["points"][0] + self.assertEqual(point.id, vector_id) + self.assertEqual(point.vector, updated_vector) + self.assertEqual(point.payload, updated_payload) + + def test_get(self): + vector_id = str(uuid.uuid4()) + self.client_mock.retrieve.return_value = [ + {"id": vector_id, "payload": {"key": "value"}} + ] + + result = self.qdrant.get(vector_id=vector_id) + + self.client_mock.retrieve.assert_called_once_with( + collection_name="test_collection", ids=[vector_id], with_payload=True + ) + self.assertEqual(result["id"], vector_id) + self.assertEqual(result["payload"], {"key": "value"}) + + def test_list_cols(self): + self.client_mock.get_collections.return_value = MagicMock( + collections=[{"name": "test_collection"}] + ) + result = self.qdrant.list_cols() + self.assertEqual(result.collections[0]["name"], "test_collection") + + def test_delete_col(self): + self.qdrant.delete_col() + self.client_mock.delete_collection.assert_called_once_with( + collection_name="test_collection" + ) + + def test_col_info(self): + self.qdrant.col_info() + self.client_mock.get_collection.assert_called_once_with( + collection_name="test_collection" + ) + + def tearDown(self): + del self.qdrant