diff --git a/bingchat.py b/bingchat.py deleted file mode 100644 index 3ba508d10..000000000 --- a/bingchat.py +++ /dev/null @@ -1,19 +0,0 @@ -from swarms.models.bing_chat import BingChat -from swarms.workers.worker import Worker -from swarms.tools.autogpt import EdgeGPTTool, tool - - -# Initialize the language model, -# This model can be swapped out with Anthropic, ETC, Huggingface Models like Mistral, ETC -llm = BingChat(cookies_path="./cookies.json") - -# Initialize the Worker with the custom tool -worker = Worker( - llm=llm, - ai_name="EdgeGPT Worker", -) - -# Use the worker to process a task -task = "Hello, my name is ChatGPT" -response = worker.run(task) -print(response) diff --git a/cookies.json b/cookies.json deleted file mode 100644 index bd49de753..000000000 --- a/cookies.json +++ /dev/null @@ -1,6 +0,0 @@ -[ - { - "name": "cookie1", - "value": "1GJjj1-tM6Jlo4HFtnbocQ3r0QbQ9Aq_R65dqbcSWKzKxnN8oEMW1xa4RlsJ_nGyNjFlXQRzMWRR2GK11bve8-6n_bjF0zTczYcQQ8oDB8W66jgpIWSL7Hr4hneB0R9dIt-OQ4cVPs4eehL2lcRCObWQr0zkG14MHlH5EMwAKthv_NNIQSfThq4Ey2Hmzhq9sRuyS04JveHdLC9gfthJ8xk3J12yr7j4HsynpzmvFUcA" - } -] diff --git a/playground/agents/revgpt_agent.py b/playground/agents/revgpt_agent.py index e8667e90c..42d95359a 100644 --- a/playground/agents/revgpt_agent.py +++ b/playground/agents/revgpt_agent.py @@ -10,15 +10,12 @@ "plugin_ids": [os.getenv("REVGPT_PLUGIN_IDS")], "disable_history": os.getenv("REVGPT_DISABLE_HISTORY") == "True", "PUID": os.getenv("REVGPT_PUID"), - "unverified_plugin_domains": [os.getenv("REVGPT_UNVERIFIED_PLUGIN_DOMAINS")] + "unverified_plugin_domains": [os.getenv("REVGPT_UNVERIFIED_PLUGIN_DOMAINS")], } llm = RevChatGPTModel(access_token=os.getenv("ACCESS_TOKEN"), **config) -worker = Worker( - ai_name="Optimus Prime", - llm=llm -) +worker = Worker(ai_name="Optimus Prime", llm=llm) task = "What were the winning boston marathon times for the past 5 years (ending in 2022)? Generate a table of the year, name, country of origin, and times." response = worker.run(task) diff --git a/playground/models/bingchat.py b/playground/models/bingchat.py index bd2589b8f..bf06ecc66 100644 --- a/playground/models/bingchat.py +++ b/playground/models/bingchat.py @@ -2,18 +2,20 @@ from swarms.workers.worker import Worker from swarms.tools.autogpt import EdgeGPTTool, tool from swarms.models import OpenAIChat -import os +import os api_key = os.getenv("OPENAI_API_KEY") # Initialize the EdgeGPTModel edgegpt = BingChat(cookies_path="./cookies.txt") + @tool def edgegpt(task: str = None): """A tool to run infrence on the EdgeGPT Model""" return EdgeGPTTool.run(task) + # Initialize the language model, # This model can be swapped out with Anthropic, ETC, Huggingface Models like Mistral, ETC llm = OpenAIChat( @@ -22,11 +24,7 @@ def edgegpt(task: str = None): ) # Initialize the Worker with the custom tool -worker = Worker( - llm=llm, - ai_name="EdgeGPT Worker", - external_tools=[edgegpt] -) +worker = Worker(llm=llm, ai_name="EdgeGPT Worker", external_tools=[edgegpt]) # Use the worker to process a task task = "Hello, my name is ChatGPT" diff --git a/revgpt.py b/playground/models/revgpt.py similarity index 98% rename from revgpt.py rename to playground/models/revgpt.py index 4bae77293..cd5bd2d67 100644 --- a/revgpt.py +++ b/playground/models/revgpt.py @@ -14,7 +14,7 @@ "plugin_ids": [os.getenv("REVGPT_PLUGIN_IDS")], "disable_history": os.getenv("REVGPT_DISABLE_HISTORY") == "True", "PUID": os.getenv("REVGPT_PUID"), - "unverified_plugin_domains": [os.getenv("REVGPT_UNVERIFIED_PLUGIN_DOMAINS")] + "unverified_plugin_domains": [os.getenv("REVGPT_UNVERIFIED_PLUGIN_DOMAINS")], } # For v1 model diff --git a/swarms/agents/agent.py b/swarms/agents/agent.py new file mode 100644 index 000000000..109501f9a --- /dev/null +++ b/swarms/agents/agent.py @@ -0,0 +1,437 @@ +from __future__ import annotations + +import json +import time +from typing import Any, Callable, List, Optional + +from langchain.chains.llm import LLMChain +from langchain.chat_models.base import BaseChatModel +from langchain.memory import ChatMessageHistory +from langchain.prompts.chat import ( + BaseChatPromptTemplate, +) +from langchain.schema import ( + BaseChatMessageHistory, + Document, +) +from langchain.schema.messages import ( + AIMessage, + BaseMessage, + HumanMessage, + SystemMessage, +) +from langchain.schema.vectorstore import VectorStoreRetriever +from langchain.tools.base import BaseTool +from langchain.tools.human.tool import HumanInputRun +from langchain_experimental.autonomous_agents.autogpt.output_parser import ( + AutoGPTOutputParser, + BaseAutoGPTOutputParser, +) +from langchain_experimental.autonomous_agents.autogpt.prompt import AutoGPTPrompt +from langchain_experimental.autonomous_agents.autogpt.prompt_generator import ( + FINISH_NAME, + get_prompt, +) +from langchain_experimental.pydantic_v1 import BaseModel, ValidationError + + +# PROMPT +FINISH_NAME = "finish" + + +# This class has a metaclass conflict: both `BaseChatPromptTemplate` and `BaseModel` +# define a metaclass to use, and the two metaclasses attempt to define +# the same functions but in mutually-incompatible ways. +# It isn't clear how to resolve this, and this code predates mypy +# beginning to perform that check. +# +# Mypy errors: +# ``` +# Definition of "__private_attributes__" in base class "BaseModel" is +# incompatible with definition in base class "BaseModel" [misc] +# Definition of "__repr_name__" in base class "Representation" is +# incompatible with definition in base class "BaseModel" [misc] +# Definition of "__pretty__" in base class "Representation" is +# incompatible with definition in base class "BaseModel" [misc] +# Definition of "__repr_str__" in base class "Representation" is +# incompatible with definition in base class "BaseModel" [misc] +# Definition of "__rich_repr__" in base class "Representation" is +# incompatible with definition in base class "BaseModel" [misc] +# Metaclass conflict: the metaclass of a derived class must be +# a (non-strict) subclass of the metaclasses of all its bases [misc] +# ``` +# +# TODO: look into refactoring this class in a way that avoids the mypy type errors +class AutoGPTPrompt(BaseChatPromptTemplate, BaseModel): # type: ignore[misc] + """Prompt for AutoGPT.""" + + ai_name: str + ai_role: str + tools: List[BaseTool] + token_counter: Callable[[str], int] + send_token_limit: int = 4196 + + def construct_full_prompt(self, goals: List[str]) -> str: + prompt_start = ( + "Your decisions must always be made independently " + "without seeking user assistance.\n" + "Play to your strengths as an LLM and pursue simple " + "strategies with no legal complications.\n" + "If you have completed all your tasks, make sure to " + 'use the "finish" command.' + ) + # Construct full prompt + full_prompt = ( + f"You are {self.ai_name}, {self.ai_role}\n{prompt_start}\n\nGOALS:\n\n" + ) + for i, goal in enumerate(goals): + full_prompt += f"{i+1}. {goal}\n" + + full_prompt += f"\n\n{get_prompt(self.tools)}" + return full_prompt + + def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + base_prompt = SystemMessage(content=self.construct_full_prompt(kwargs["goals"])) + time_prompt = SystemMessage( + content=f"The current time and date is {time.strftime('%c')}" + ) + used_tokens = self.token_counter(base_prompt.content) + self.token_counter( + time_prompt.content + ) + memory: VectorStoreRetriever = kwargs["memory"] + previous_messages = kwargs["messages"] + relevant_docs = memory.get_relevant_documents(str(previous_messages[-10:])) + relevant_memory = [d.page_content for d in relevant_docs] + relevant_memory_tokens = sum( + [self.token_counter(doc) for doc in relevant_memory] + ) + while used_tokens + relevant_memory_tokens > 2500: + relevant_memory = relevant_memory[:-1] + relevant_memory_tokens = sum( + [self.token_counter(doc) for doc in relevant_memory] + ) + content_format = ( + f"This reminds you of these events " + f"from your past:\n{relevant_memory}\n\n" + ) + memory_message = SystemMessage(content=content_format) + used_tokens += self.token_counter(memory_message.content) + historical_messages: List[BaseMessage] = [] + for message in previous_messages[-10:][::-1]: + message_tokens = self.token_counter(message.content) + if used_tokens + message_tokens > self.send_token_limit - 1000: + break + historical_messages = [message] + historical_messages + used_tokens += message_tokens + input_message = HumanMessage(content=kwargs["user_input"]) + messages: List[BaseMessage] = [base_prompt, time_prompt, memory_message] + messages += historical_messages + messages.append(input_message) + return messages + + +class PromptGenerator: + """A class for generating custom prompt strings. + + Does this based on constraints, commands, resources, and performance evaluations. + """ + + def __init__(self) -> None: + """Initialize the PromptGenerator object. + + Starts with empty lists of constraints, commands, resources, + and performance evaluations. + """ + self.constraints: List[str] = [] + self.commands: List[BaseTool] = [] + self.resources: List[str] = [] + self.performance_evaluation: List[str] = [] + self.response_format = { + "thoughts": { + "text": "thought", + "reasoning": "reasoning", + "plan": "- short bulleted\n- list that conveys\n- long-term plan", + "criticism": "constructive self-criticism", + "speak": "thoughts summary to say to user", + }, + "command": {"name": "command name", "args": {"arg name": "value"}}, + } + + def add_constraint(self, constraint: str) -> None: + """ + Add a constraint to the constraints list. + + Args: + constraint (str): The constraint to be added. + """ + self.constraints.append(constraint) + + def add_tool(self, tool: BaseTool) -> None: + self.commands.append(tool) + + def _generate_command_string(self, tool: BaseTool) -> str: + output = f"{tool.name}: {tool.description}" + output += f", args json schema: {json.dumps(tool.args)}" + return output + + def add_resource(self, resource: str) -> None: + """ + Add a resource to the resources list. + + Args: + resource (str): The resource to be added. + """ + self.resources.append(resource) + + def add_performance_evaluation(self, evaluation: str) -> None: + """ + Add a performance evaluation item to the performance_evaluation list. + + Args: + evaluation (str): The evaluation item to be added. + """ + self.performance_evaluation.append(evaluation) + + def _generate_numbered_list(self, items: list, item_type: str = "list") -> str: + """ + Generate a numbered list from given items based on the item_type. + + Args: + items (list): A list of items to be numbered. + item_type (str, optional): The type of items in the list. + Defaults to 'list'. + + Returns: + str: The formatted numbered list. + """ + if item_type == "command": + command_strings = [ + f"{i + 1}. {self._generate_command_string(item)}" + for i, item in enumerate(items) + ] + finish_description = ( + "use this to signal that you have finished all your objectives" + ) + finish_args = ( + '"response": "final response to let ' + 'people know you have finished your objectives"' + ) + finish_string = ( + f"{len(items) + 1}. {FINISH_NAME}: " + f"{finish_description}, args: {finish_args}" + ) + return "\n".join(command_strings + [finish_string]) + else: + return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items)) + + def generate_prompt_string(self) -> str: + """Generate a prompt string. + + Returns: + str: The generated prompt string. + """ + formatted_response_format = json.dumps(self.response_format, indent=4) + prompt_string = ( + f"Constraints:\n{self._generate_numbered_list(self.constraints)}\n\n" + f"Commands:\n" + f"{self._generate_numbered_list(self.commands, item_type='command')}\n\n" + f"Resources:\n{self._generate_numbered_list(self.resources)}\n\n" + f"Performance Evaluation:\n" + f"{self._generate_numbered_list(self.performance_evaluation)}\n\n" + f"You should only respond in JSON format as described below " + f"\nResponse Format: \n{formatted_response_format} " + f"\nEnsure the response can be parsed by Python json.loads" + ) + + return prompt_string + + +def get_prompt(tools: List[BaseTool]) -> str: + """Generates a prompt string. + + It includes various constraints, commands, resources, and performance evaluations. + + Returns: + str: The generated prompt string. + """ + + # Initialize the PromptGenerator object + prompt_generator = PromptGenerator() + + # Add constraints to the PromptGenerator object + prompt_generator.add_constraint( + "~16000 word limit for short term memory. " + "Your short term memory is short, " + "so immediately save important information to files." + ) + prompt_generator.add_constraint( + "If you are unsure how you previously did something " + "or want to recall past events, " + "thinking about similar events will help you remember." + ) + prompt_generator.add_constraint("No user assistance") + prompt_generator.add_constraint( + 'Exclusively use the commands listed in double quotes e.g. "command name"' + ) + + # Add commands to the PromptGenerator object + for tool in tools: + prompt_generator.add_tool(tool) + + # Add resources to the PromptGenerator object + prompt_generator.add_resource( + "Internet access for searches and information gathering." + ) + prompt_generator.add_resource("Long Term memory management.") + prompt_generator.add_resource( + "GPT-3.5 powered Agents for delegation of simple tasks." + ) + prompt_generator.add_resource("File output.") + + # Add performance evaluations to the PromptGenerator object + prompt_generator.add_performance_evaluation( + "Continuously review and analyze your actions " + "to ensure you are performing to the best of your abilities." + ) + prompt_generator.add_performance_evaluation( + "Constructively self-criticize your big-picture behavior constantly." + ) + prompt_generator.add_performance_evaluation( + "Reflect on past decisions and strategies to refine your approach." + ) + prompt_generator.add_performance_evaluation( + "Every command has a cost, so be smart and efficient. " + "Aim to complete tasks in the least number of steps." + ) + + # Generate the prompt string + prompt_string = prompt_generator.generate_prompt_string() + + return prompt_string + + +class AutoGPT: + """ + AutoAgent: + + + Args: + + + + + """ + + def __init__( + self, + ai_name: str, + memory: VectorStoreRetriever, + chain: LLMChain, + output_parser: BaseAutoGPTOutputParser, + tools: List[BaseTool], + feedback_tool: Optional[HumanInputRun] = None, + chat_history_memory: Optional[BaseChatMessageHistory] = None, + ): + self.ai_name = ai_name + self.memory = memory + self.next_action_count = 0 + self.chain = chain + self.output_parser = output_parser + self.tools = tools + self.feedback_tool = feedback_tool + self.chat_history_memory = chat_history_memory or ChatMessageHistory() + + @classmethod + def from_llm_and_tools( + cls, + ai_name: str, + ai_role: str, + memory: VectorStoreRetriever, + tools: List[BaseTool], + llm: BaseChatModel, + human_in_the_loop: bool = False, + output_parser: Optional[BaseAutoGPTOutputParser] = None, + chat_history_memory: Optional[BaseChatMessageHistory] = None, + ) -> AutoGPT: + prompt = AutoGPTPrompt( + ai_name=ai_name, + ai_role=ai_role, + tools=tools, + input_variables=["memory", "messages", "goals", "user_input"], + token_counter=llm.get_num_tokens, + ) + human_feedback_tool = HumanInputRun() if human_in_the_loop else None + chain = LLMChain(llm=llm, prompt=prompt) + return cls( + ai_name, + memory, + chain, + output_parser or AutoGPTOutputParser(), + tools, + feedback_tool=human_feedback_tool, + chat_history_memory=chat_history_memory, + ) + + def run(self, goals: List[str]) -> str: + user_input = ( + "Determine which next command to use, " + "and respond using the format specified above:" + ) + # Interaction Loop + loop_count = 0 + while True: + # Discontinue if continuous limit is reached + loop_count += 1 + + # Send message to AI, get response + assistant_reply = self.chain.run( + goals=goals, + messages=self.chat_history_memory.messages, + memory=self.memory, + user_input=user_input, + ) + + # Print Assistant thoughts + print(assistant_reply) + self.chat_history_memory.add_message(HumanMessage(content=user_input)) + self.chat_history_memory.add_message(AIMessage(content=assistant_reply)) + + # Get command name and arguments + action = self.output_parser.parse(assistant_reply) + tools = {t.name: t for t in self.tools} + if action.name == FINISH_NAME: + return action.args["response"] + if action.name in tools: + tool = tools[action.name] + try: + observation = tool.run(action.args) + except ValidationError as e: + observation = ( + f"Validation Error in args: {str(e)}, args: {action.args}" + ) + except Exception as e: + observation = ( + f"Error: {str(e)}, {type(e).__name__}, args: {action.args}" + ) + result = f"Command {tool.name} returned: {observation}" + elif action.name == "ERROR": + result = f"Error: {action.args}. " + else: + result = ( + f"Unknown command '{action.name}'. " + f"Please refer to the 'COMMANDS' list for available " + f"commands and only respond in the specified JSON format." + ) + + memory_to_add = ( + f"Assistant Reply: {assistant_reply} " f"\nResult: {result} " + ) + if self.feedback_tool is not None: + feedback = f"\n{self.feedback_tool.run('Input: ')}" + if feedback in {"q", "stop"}: + print("EXITING") + return "EXITING" + memory_to_add += feedback + + self.memory.add_documents([Document(page_content=memory_to_add)]) + self.chat_history_memory.add_message(SystemMessage(content=result)) diff --git a/swarms/loaders/__init__.py b/swarms/loaders/__init__.py new file mode 100644 index 000000000..78bef3094 --- /dev/null +++ b/swarms/loaders/__init__.py @@ -0,0 +1,7 @@ +""" +Data Loaders for APPS + + +TODO: Clean up all the llama index stuff, remake the logic from scratch + +""" diff --git a/swarms/loaders/asana.py b/swarms/loaders/asana.py new file mode 100644 index 000000000..dd14cff41 --- /dev/null +++ b/swarms/loaders/asana.py @@ -0,0 +1,103 @@ +from typing import List, Optional + +from llama_index.readers.base import BaseReader +from llama_index.readers.schema.base import Document + + +class AsanaReader(BaseReader): + """Asana reader. Reads data from an Asana workspace. + + Args: + asana_token (str): Asana token. + + """ + + def __init__(self, asana_token: str) -> None: + """Initialize Asana reader.""" + import asana + + self.client = asana.Client.access_token(asana_token) + + def load_data( + self, workspace_id: Optional[str] = None, project_id: Optional[str] = None + ) -> List[Document]: + """Load data from the workspace. + + Args: + workspace_id (Optional[str], optional): Workspace ID. Defaults to None. + project_id (Optional[str], optional): Project ID. Defaults to None. + Returns: + List[Document]: List of documents. + """ + + if workspace_id is None and project_id is None: + raise ValueError("Either workspace_id or project_id must be provided") + + if workspace_id is not None and project_id is not None: + raise ValueError( + "Only one of workspace_id or project_id should be provided" + ) + + results = [] + + if workspace_id is not None: + workspace_name = self.client.workspaces.find_by_id(workspace_id)["name"] + projects = self.client.projects.find_all({"workspace": workspace_id}) + + # Case: Only project_id is provided + else: # since we've handled the other cases, this means project_id is not None + projects = [self.client.projects.find_by_id(project_id)] + workspace_name = projects[0]["workspace"]["name"] + + for project in projects: + tasks = self.client.tasks.find_all( + { + "project": project["gid"], + "opt_fields": "name,notes,completed,completed_at,completed_by,assignee,followers,custom_fields", + } + ) + for task in tasks: + stories = self.client.tasks.stories(task["gid"], opt_fields="type,text") + comments = "\n".join( + [ + story["text"] + for story in stories + if story.get("type") == "comment" and "text" in story + ] + ) + + task_metadata = { + "task_id": task.get("gid", ""), + "name": task.get("name", ""), + "assignee": (task.get("assignee") or {}).get("name", ""), + "completed_on": task.get("completed_at", ""), + "completed_by": (task.get("completed_by") or {}).get("name", ""), + "project_name": project.get("name", ""), + "custom_fields": [ + i["display_value"] + for i in task.get("custom_fields") + if task.get("custom_fields") is not None + ], + "workspace_name": workspace_name, + "url": f"https://app.asana.com/0/{project['gid']}/{task['gid']}", + } + + if task.get("followers") is not None: + task_metadata["followers"] = [ + i.get("name") for i in task.get("followers") if "name" in i + ] + else: + task_metadata["followers"] = [] + + results.append( + Document( + text=task.get("name", "") + + " " + + task.get("notes", "") + + " " + + comments, + extra_info=task_metadata, + ) + ) + + return results diff --git a/swarms/loaders/base.py b/swarms/loaders/base.py new file mode 100644 index 000000000..940492b26 --- /dev/null +++ b/swarms/loaders/base.py @@ -0,0 +1,622 @@ +"""Base schema for data structures.""" +import json +import textwrap +import uuid +from abc import abstractmethod +from enum import Enum, auto +from hashlib import sha256 +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from pydantic import BaseModel, Field, root_validator +from llama_index.utils import SAMPLE_TEXT, truncate_text +from typing_extensions import Self + +if TYPE_CHECKING: + from haystack.schema import Document as HaystackDocument + from semantic_kernel.memory.memory_record import MemoryRecord + + +#### +DEFAULT_TEXT_NODE_TMPL = "{metadata_str}\n\n{content}" +DEFAULT_METADATA_TMPL = "{key}: {value}" +# NOTE: for pretty printing +TRUNCATE_LENGTH = 350 +WRAP_WIDTH = 70 + + +class BaseComponent(BaseModel): + """Base component object to capture class names.""" + + @classmethod + @abstractmethod + def class_name(cls) -> str: + """ + Get the class name, used as a unique ID in serialization. + + This provides a key that makes serialization robust against actual class + name changes. + """ + + def to_dict(self, **kwargs: Any) -> Dict[str, Any]: + data = self.dict(**kwargs) + data["class_name"] = self.class_name() + return data + + def to_json(self, **kwargs: Any) -> str: + data = self.to_dict(**kwargs) + return json.dumps(data) + + # TODO: return type here not supported by current mypy version + @classmethod + def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self: # type: ignore + if isinstance(kwargs, dict): + data.update(kwargs) + + data.pop("class_name", None) + return cls(**data) + + @classmethod + def from_json(cls, data_str: str, **kwargs: Any) -> Self: # type: ignore + data = json.loads(data_str) + return cls.from_dict(data, **kwargs) + + +class NodeRelationship(str, Enum): + """Node relationships used in `BaseNode` class. + + Attributes: + SOURCE: The node is the source document. + PREVIOUS: The node is the previous node in the document. + NEXT: The node is the next node in the document. + PARENT: The node is the parent node in the document. + CHILD: The node is a child node in the document. + + """ + + SOURCE = auto() + PREVIOUS = auto() + NEXT = auto() + PARENT = auto() + CHILD = auto() + + +class ObjectType(str, Enum): + TEXT = auto() + IMAGE = auto() + INDEX = auto() + DOCUMENT = auto() + + +class MetadataMode(str, Enum): + ALL = auto() + EMBED = auto() + LLM = auto() + NONE = auto() + + +class RelatedNodeInfo(BaseComponent): + node_id: str + node_type: Optional[ObjectType] = None + metadata: Dict[str, Any] = Field(default_factory=dict) + hash: Optional[str] = None + + @classmethod + def class_name(cls) -> str: + return "RelatedNodeInfo" + + +RelatedNodeType = Union[RelatedNodeInfo, List[RelatedNodeInfo]] + + +# Node classes for indexes +class BaseNode(BaseComponent): + """Base node Object. + + Generic abstract interface for retrievable nodes + + """ + + class Config: + allow_population_by_field_name = True + + id_: str = Field( + default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the node." + ) + embedding: Optional[List[float]] = Field( + default=None, description="Embedding of the node." + ) + + """" + metadata fields + - injected as part of the text shown to LLMs as context + - injected as part of the text for generating embeddings + - used by vector DBs for metadata filtering + + """ + metadata: Dict[str, Any] = Field( + default_factory=dict, + description="A flat dictionary of metadata fields", + alias="extra_info", + ) + excluded_embed_metadata_keys: List[str] = Field( + default_factory=list, + description="Metadata keys that are excluded from text for the embed model.", + ) + excluded_llm_metadata_keys: List[str] = Field( + default_factory=list, + description="Metadata keys that are excluded from text for the LLM.", + ) + relationships: Dict[NodeRelationship, RelatedNodeType] = Field( + default_factory=dict, + description="A mapping of relationships to other node information.", + ) + hash: str = Field(default="", description="Hash of the node content.") + + @classmethod + @abstractmethod + def get_type(cls) -> str: + """Get Object type.""" + + @abstractmethod + def get_content(self, metadata_mode: MetadataMode = MetadataMode.ALL) -> str: + """Get object content.""" + + @abstractmethod + def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str: + """Metadata string.""" + + @abstractmethod + def set_content(self, value: Any) -> None: + """Set the content of the node.""" + + @property + def node_id(self) -> str: + return self.id_ + + @node_id.setter + def node_id(self, value: str) -> None: + self.id_ = value + + @property + def source_node(self) -> Optional[RelatedNodeInfo]: + """Source object node. + + Extracted from the relationships field. + + """ + if NodeRelationship.SOURCE not in self.relationships: + return None + + relation = self.relationships[NodeRelationship.SOURCE] + if isinstance(relation, list): + raise ValueError("Source object must be a single RelatedNodeInfo object") + return relation + + @property + def prev_node(self) -> Optional[RelatedNodeInfo]: + """Prev node.""" + if NodeRelationship.PREVIOUS not in self.relationships: + return None + + relation = self.relationships[NodeRelationship.PREVIOUS] + if not isinstance(relation, RelatedNodeInfo): + raise ValueError("Previous object must be a single RelatedNodeInfo object") + return relation + + @property + def next_node(self) -> Optional[RelatedNodeInfo]: + """Next node.""" + if NodeRelationship.NEXT not in self.relationships: + return None + + relation = self.relationships[NodeRelationship.NEXT] + if not isinstance(relation, RelatedNodeInfo): + raise ValueError("Next object must be a single RelatedNodeInfo object") + return relation + + @property + def parent_node(self) -> Optional[RelatedNodeInfo]: + """Parent node.""" + if NodeRelationship.PARENT not in self.relationships: + return None + + relation = self.relationships[NodeRelationship.PARENT] + if not isinstance(relation, RelatedNodeInfo): + raise ValueError("Parent object must be a single RelatedNodeInfo object") + return relation + + @property + def child_nodes(self) -> Optional[List[RelatedNodeInfo]]: + """Child nodes.""" + if NodeRelationship.CHILD not in self.relationships: + return None + + relation = self.relationships[NodeRelationship.CHILD] + if not isinstance(relation, list): + raise ValueError("Child objects must be a list of RelatedNodeInfo objects.") + return relation + + @property + def ref_doc_id(self) -> Optional[str]: + """Deprecated: Get ref doc id.""" + source_node = self.source_node + if source_node is None: + return None + return source_node.node_id + + @property + def extra_info(self) -> Dict[str, Any]: + """TODO: DEPRECATED: Extra info.""" + return self.metadata + + def __str__(self) -> str: + source_text_truncated = truncate_text( + self.get_content().strip(), TRUNCATE_LENGTH + ) + source_text_wrapped = textwrap.fill( + f"Text: {source_text_truncated}\n", width=WRAP_WIDTH + ) + return f"Node ID: {self.node_id}\n{source_text_wrapped}" + + def get_embedding(self) -> List[float]: + """Get embedding. + + Errors if embedding is None. + + """ + if self.embedding is None: + raise ValueError("embedding not set.") + return self.embedding + + def as_related_node_info(self) -> RelatedNodeInfo: + """Get node as RelatedNodeInfo.""" + return RelatedNodeInfo( + node_id=self.node_id, + node_type=self.get_type(), + metadata=self.metadata, + hash=self.hash, + ) + + +class TextNode(BaseNode): + text: str = Field(default="", description="Text content of the node.") + start_char_idx: Optional[int] = Field( + default=None, description="Start char index of the node." + ) + end_char_idx: Optional[int] = Field( + default=None, description="End char index of the node." + ) + text_template: str = Field( + default=DEFAULT_TEXT_NODE_TMPL, + description=( + "Template for how text is formatted, with {content} and " + "{metadata_str} placeholders." + ), + ) + metadata_template: str = Field( + default=DEFAULT_METADATA_TMPL, + description=( + "Template for how metadata is formatted, with {key} and " + "{value} placeholders." + ), + ) + metadata_seperator: str = Field( + default="\n", + description="Separator between metadata fields when converting to string.", + ) + + @classmethod + def class_name(cls) -> str: + return "TextNode" + + @root_validator + def _check_hash(cls, values: dict) -> dict: + """Generate a hash to represent the node.""" + text = values.get("text", "") + metadata = values.get("metadata", {}) + doc_identity = str(text) + str(metadata) + values["hash"] = str( + sha256(doc_identity.encode("utf-8", "surrogatepass")).hexdigest() + ) + return values + + @classmethod + def get_type(cls) -> str: + """Get Object type.""" + return ObjectType.TEXT + + def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str: + """Get object content.""" + metadata_str = self.get_metadata_str(mode=metadata_mode).strip() + if not metadata_str: + return self.text + + return self.text_template.format( + content=self.text, metadata_str=metadata_str + ).strip() + + def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str: + """Metadata info string.""" + if mode == MetadataMode.NONE: + return "" + + usable_metadata_keys = set(self.metadata.keys()) + if mode == MetadataMode.LLM: + for key in self.excluded_llm_metadata_keys: + if key in usable_metadata_keys: + usable_metadata_keys.remove(key) + elif mode == MetadataMode.EMBED: + for key in self.excluded_embed_metadata_keys: + if key in usable_metadata_keys: + usable_metadata_keys.remove(key) + + return self.metadata_seperator.join( + [ + self.metadata_template.format(key=key, value=str(value)) + for key, value in self.metadata.items() + if key in usable_metadata_keys + ] + ) + + def set_content(self, value: str) -> None: + """Set the content of the node.""" + self.text = value + + def get_node_info(self) -> Dict[str, Any]: + """Get node info.""" + return {"start": self.start_char_idx, "end": self.end_char_idx} + + def get_text(self) -> str: + return self.get_content(metadata_mode=MetadataMode.NONE) + + @property + def node_info(self) -> Dict[str, Any]: + """Deprecated: Get node info.""" + return self.get_node_info() + + +# TODO: legacy backport of old Node class +Node = TextNode + + +class ImageNode(TextNode): + """Node with image.""" + + # TODO: store reference instead of actual image + # base64 encoded image str + image: Optional[str] = None + + @classmethod + def get_type(cls) -> str: + return ObjectType.IMAGE + + @classmethod + def class_name(cls) -> str: + return "ImageNode" + + +class IndexNode(TextNode): + """Node with reference to any object. + + This can include other indices, query engines, retrievers. + + This can also include other nodes (though this is overlapping with `relationships` + on the Node class). + + """ + + index_id: str + + @classmethod + def from_text_node( + cls, + node: TextNode, + index_id: str, + ) -> "IndexNode": + """Create index node from text node.""" + # copy all attributes from text node, add index id + return cls( + **node.dict(), + index_id=index_id, + ) + + @classmethod + def get_type(cls) -> str: + return ObjectType.INDEX + + @classmethod + def class_name(cls) -> str: + return "IndexNode" + + +class NodeWithScore(BaseComponent): + node: BaseNode + score: Optional[float] = None + + def __str__(self) -> str: + return f"{self.node}\nScore: {self.score: 0.3f}\n" + + def get_score(self, raise_error: bool = False) -> float: + """Get score.""" + if self.score is None: + if raise_error: + raise ValueError("Score not set.") + else: + return 0.0 + else: + return self.score + + @classmethod + def class_name(cls) -> str: + return "NodeWithScore" + + ##### pass through methods to BaseNode ##### + @property + def node_id(self) -> str: + return self.node.node_id + + @property + def id_(self) -> str: + return self.node.id_ + + @property + def text(self) -> str: + if isinstance(self.node, TextNode): + return self.node.text + else: + raise ValueError("Node must be a TextNode to get text.") + + @property + def metadata(self) -> Dict[str, Any]: + return self.node.metadata + + @property + def embedding(self) -> Optional[List[float]]: + return self.node.embedding + + def get_text(self) -> str: + if isinstance(self.node, TextNode): + return self.node.get_text() + else: + raise ValueError("Node must be a TextNode to get text.") + + def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str: + return self.node.get_content(metadata_mode=metadata_mode) + + def get_embedding(self) -> List[float]: + return self.node.get_embedding() + + +# Document Classes for Readers + + +class Document(TextNode): + """Generic interface for a data document. + + This document connects to data sources. + + """ + + # TODO: A lot of backwards compatibility logic here, clean up + id_: str = Field( + default_factory=lambda: str(uuid.uuid4()), + description="Unique ID of the node.", + alias="doc_id", + ) + + _compat_fields = {"doc_id": "id_", "extra_info": "metadata"} + + @classmethod + def get_type(cls) -> str: + """Get Document type.""" + return ObjectType.DOCUMENT + + @property + def doc_id(self) -> str: + """Get document ID.""" + return self.id_ + + def __str__(self) -> str: + source_text_truncated = truncate_text( + self.get_content().strip(), TRUNCATE_LENGTH + ) + source_text_wrapped = textwrap.fill( + f"Text: {source_text_truncated}\n", width=WRAP_WIDTH + ) + return f"Doc ID: {self.doc_id}\n{source_text_wrapped}" + + def get_doc_id(self) -> str: + """TODO: Deprecated: Get document ID.""" + return self.id_ + + def __setattr__(self, name: str, value: object) -> None: + if name in self._compat_fields: + name = self._compat_fields[name] + super().__setattr__(name, value) + + def to_langchain_format(self) -> "LCDocument": + """Convert struct to LangChain document format.""" + from llama_index.bridge.langchain import Document as LCDocument + + metadata = self.metadata or {} + return LCDocument(page_content=self.text, metadata=metadata) + + @classmethod + def from_langchain_format(cls, doc: "LCDocument") -> "Document": + """Convert struct from LangChain document format.""" + return cls(text=doc.page_content, metadata=doc.metadata) + + def to_haystack_format(self) -> "HaystackDocument": + """Convert struct to Haystack document format.""" + from haystack.schema import Document as HaystackDocument + + return HaystackDocument( + content=self.text, meta=self.metadata, embedding=self.embedding, id=self.id_ + ) + + @classmethod + def from_haystack_format(cls, doc: "HaystackDocument") -> "Document": + """Convert struct from Haystack document format.""" + return cls( + text=doc.content, metadata=doc.meta, embedding=doc.embedding, id_=doc.id + ) + + def to_embedchain_format(self) -> Dict[str, Any]: + """Convert struct to EmbedChain document format.""" + return { + "doc_id": self.id_, + "data": {"content": self.text, "meta_data": self.metadata}, + } + + @classmethod + def from_embedchain_format(cls, doc: Dict[str, Any]) -> "Document": + """Convert struct from EmbedChain document format.""" + return cls( + text=doc["data"]["content"], + metadata=doc["data"]["meta_data"], + id_=doc["doc_id"], + ) + + def to_semantic_kernel_format(self) -> "MemoryRecord": + """Convert struct to Semantic Kernel document format.""" + import numpy as np + from semantic_kernel.memory.memory_record import MemoryRecord + + return MemoryRecord( + id=self.id_, + text=self.text, + additional_metadata=self.get_metadata_str(), + embedding=np.array(self.embedding) if self.embedding else None, + ) + + @classmethod + def from_semantic_kernel_format(cls, doc: "MemoryRecord") -> "Document": + """Convert struct from Semantic Kernel document format.""" + return cls( + text=doc._text, + metadata={"additional_metadata": doc._additional_metadata}, + embedding=doc._embedding.tolist() if doc._embedding is not None else None, + id_=doc._id, + ) + + @classmethod + def example(cls) -> "Document": + return Document( + text=SAMPLE_TEXT, + metadata={"filename": "README.md", "category": "codebase"}, + ) + + @classmethod + def class_name(cls) -> str: + return "Document" + + +class ImageDocument(Document): + """Data document containing an image.""" + + # base64 encoded image str + image: Optional[str] = None + + @classmethod + def class_name(cls) -> str: + return "ImageDocument" diff --git a/swarms/models/revgptV1.py b/swarms/models/revgptV1.py index 4aaf2cf35..a7327d23c 100644 --- a/swarms/models/revgptV1.py +++ b/swarms/models/revgptV1.py @@ -46,6 +46,7 @@ bcolors = t.Colors() + def generate_random_hex(length: int = 17) -> str: """Generate a random hex string @@ -121,7 +122,6 @@ def wrapper(*args, **kwargs): BASE_URL = environ.get("CHATGPT_BASE_URL", "http://bypass.bzff.cn:9090/") - def captcha_solver(images: list[str], challenge_details: dict) -> int: # Create tempfile with tempfile.TemporaryDirectory() as tempdir: @@ -197,40 +197,40 @@ def get_arkose_token( raise Exception("Failed to verify captcha") return resp_json.get("token") # else: - # working_endpoints: list[str] = [] - # # Check uptime for different endpoints via gatus - # resp2: list[dict] = requests.get( - # "https://stats.churchless.tech/api/v1/endpoints/statuses?page=1" - # ).json() - # for endpoint in resp2: - # # print(endpoint.get("name")) - # if endpoint.get("group") != "Arkose Labs": - # continue - # # Check the last 5 results - # results: list[dict] = endpoint.get("results", [])[-5:-1] - # # print(results) - # if not results: - # print(f"Endpoint {endpoint.get('name')} has no results") - # continue - # # Check if all the results are up - # if all(result.get("success") == True for result in results): - # working_endpoints.append(endpoint.get("name")) - # if not working_endpoints: - # print("No working endpoints found. Please solve the captcha manually.\n找不到工作终结点。请手动解决captcha") - # return get_arkose_token(download_images=True, captcha_supported=False) - # # Choose a random endpoint - # endpoint = random.choice(working_endpoints) - # resp: requests.Response = requests.get(endpoint) - # if resp.status_code != 200: - # if resp.status_code != 511: - # raise Exception("Failed to get captcha token") - # else: - # print("需要验证码,请手动解决captcha.") - # return get_arkose_token(download_images=True, captcha_supported=True) - # try: - # return resp.json().get("token") - # except Exception: - # return resp.text + # working_endpoints: list[str] = [] + # # Check uptime for different endpoints via gatus + # resp2: list[dict] = requests.get( + # "https://stats.churchless.tech/api/v1/endpoints/statuses?page=1" + # ).json() + # for endpoint in resp2: + # # print(endpoint.get("name")) + # if endpoint.get("group") != "Arkose Labs": + # continue + # # Check the last 5 results + # results: list[dict] = endpoint.get("results", [])[-5:-1] + # # print(results) + # if not results: + # print(f"Endpoint {endpoint.get('name')} has no results") + # continue + # # Check if all the results are up + # if all(result.get("success") == True for result in results): + # working_endpoints.append(endpoint.get("name")) + # if not working_endpoints: + # print("No working endpoints found. Please solve the captcha manually.\n找不到工作终结点。请手动解决captcha") + # return get_arkose_token(download_images=True, captcha_supported=False) + # # Choose a random endpoint + # endpoint = random.choice(working_endpoints) + # resp: requests.Response = requests.get(endpoint) + # if resp.status_code != 200: + # if resp.status_code != 511: + # raise Exception("Failed to get captcha token") + # else: + # print("需要验证码,请手动解决captcha.") + # return get_arkose_token(download_images=True, captcha_supported=True) + # try: + # return resp.json().get("token") + # except Exception: + # return resp.text class Chatbot: @@ -1751,6 +1751,7 @@ def handle_commands(command: str) -> bool: ) main(configure()) + class RevChatGPTModelv1: def __init__(self, access_token=None, **kwargs): super().__init__() @@ -1764,7 +1765,7 @@ def run(self, task: str) -> str: self.start_time = time.time() prev_text = "" for data in self.chatbot.ask(task, fileinfo=None): - message = data["message"][len(prev_text):] + message = data["message"][len(prev_text) :] prev_text = data["message"] self.end_time = time.time() return prev_text @@ -1779,11 +1780,16 @@ def enable_plugin(self, plugin_id: str): def list_plugins(self): return self.chatbot.get_plugins() + if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Manage RevChatGPT plugins.') - parser.add_argument('--enable', metavar='plugin_id', help='the plugin to enable') - parser.add_argument('--list', action='store_true', help='list all available plugins') - parser.add_argument('--access_token', required=True, help='access token for RevChatGPT') + parser = argparse.ArgumentParser(description="Manage RevChatGPT plugins.") + parser.add_argument("--enable", metavar="plugin_id", help="the plugin to enable") + parser.add_argument( + "--list", action="store_true", help="list all available plugins" + ) + parser.add_argument( + "--access_token", required=True, help="access token for RevChatGPT" + ) args = parser.parse_args() @@ -1795,4 +1801,3 @@ def list_plugins(self): plugins = model.list_plugins() for plugin in plugins: print(f"Plugin ID: {plugin['id']}, Name: {plugin['name']}") - diff --git a/swarms/models/revgptV4.py b/swarms/models/revgptV4.py index c3e2f78c7..c57182f13 100644 --- a/swarms/models/revgptV4.py +++ b/swarms/models/revgptV4.py @@ -1,4 +1,4 @@ -#4v image recognition +# 4v image recognition """ Standard ChatGPT """ @@ -26,6 +26,7 @@ import tempfile import random import os + # Import function type import httpx @@ -46,7 +47,7 @@ from prompt_toolkit.completion import WordCompleter from prompt_toolkit.history import InMemoryHistory from prompt_toolkit.key_binding import KeyBindings -from schemas.typings import Colors +from schemas.typings import Colors bindings = KeyBindings() @@ -56,6 +57,7 @@ bcolors = t.Colors() + def create_keybindings(key: str = "c-@") -> KeyBindings: """ Create keybindings for prompt_toolkit. Default key is ctrl+space. @@ -136,6 +138,7 @@ def get_filtered_keys_from_object(obj: object, *keys: str) -> any: # Only return specified keys that are in class_keys return {key for key in keys if key in class_keys} + def generate_random_hex(length: int = 17) -> str: """Generate a random hex string Args: @@ -202,8 +205,6 @@ def wrapper(*args, **kwargs): return decorator - - bcolors = Colors() @@ -284,7 +285,6 @@ def get_arkose_token( return resp_json.get("token") - class Chatbot: """ Chatbot class for ChatGPT @@ -636,7 +636,7 @@ def __send_request( yield { "author": author, "message": message, - "conversation_id": cid+'***************************', + "conversation_id": cid + "***************************", "parent_id": pid, "model": model, "finish_details": finish_details, @@ -711,7 +711,6 @@ def post_messages( if not conversation_id and not parent_id: parent_id = str(uuid.uuid4()) - if conversation_id and not parent_id: if conversation_id not in self.conversation_mapping: print(conversation_id) @@ -735,8 +734,8 @@ def post_messages( print( "Warning: Invalid conversation_id provided, treat as a new conversation", ) - #conversation_id = None - conversation_id =str(uuid.uuid4()) + # conversation_id = None + conversation_id = str(uuid.uuid4()) print(conversation_id) parent_id = str(uuid.uuid4()) model = model or self.config.get("model") or "text-davinci-002-render-sha" @@ -762,7 +761,7 @@ def post_messages( def ask( self, prompt: str, - fileinfo: dict , + fileinfo: dict, conversation_id: str | None = None, parent_id: str = "", model: str = "", @@ -795,7 +794,10 @@ def ask( "id": str(uuid.uuid4()), "role": "user", "author": {"role": "user"}, - "content": {"content_type": "multimodal_text", "parts": [prompt, fileinfo]}, + "content": { + "content_type": "multimodal_text", + "parts": [prompt, fileinfo], + }, }, ] @@ -871,7 +873,7 @@ def continue_write( parent_id = self.conversation_mapping[conversation_id] else: # invalid conversation_id provided, treat as a new conversation conversation_id = None - conversation_id=str(uuid.uuid4()) + conversation_id = str(uuid.uuid4()) parent_id = str(uuid.uuid4()) model = model or self.config.get("model") or "text-davinci-002-render-sha" data = { @@ -1304,7 +1306,7 @@ async def post_messages( print( "Warning: Invalid conversation_id provided, treat as a new conversation", ) - #conversation_id = None + # conversation_id = None conversation_id = str(uuid.uuid4()) print(conversation_id) parent_id = str(uuid.uuid4()) @@ -1363,12 +1365,18 @@ async def ask( { "id": str(uuid.uuid4()), "author": {"role": "user"}, - "content": {"content_type": "multimodal_text", "parts": [prompt, { - "asset_pointer": "file-service://file-V9IZRkWQnnk1HdHsBKAdoiGf", - "size_bytes": 239505, - "width": 1706, - "height": 1280 - }]}, + "content": { + "content_type": "multimodal_text", + "parts": [ + prompt, + { + "asset_pointer": "file-service://file-V9IZRkWQnnk1HdHsBKAdoiGf", + "size_bytes": 239505, + "width": 1706, + "height": 1280, + }, + ], + }, }, ] @@ -1763,6 +1771,7 @@ def handle_commands(command: str) -> bool: ) main(configure()) + class RevChatGPTModelv4: def __init__(self, access_token=None, **kwargs): super().__init__() @@ -1776,7 +1785,7 @@ def run(self, task: str) -> str: self.start_time = time.time() prev_text = "" for data in self.chatbot.ask(task, fileinfo=None): - message = data["message"][len(prev_text):] + message = data["message"][len(prev_text) :] prev_text = data["message"] self.end_time = time.time() return prev_text @@ -1791,11 +1800,16 @@ def enable_plugin(self, plugin_id: str): def list_plugins(self): return self.chatbot.get_plugins() + if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Manage RevChatGPT plugins.') - parser.add_argument('--enable', metavar='plugin_id', help='the plugin to enable') - parser.add_argument('--list', action='store_true', help='list all available plugins') - parser.add_argument('--access_token', required=True, help='access token for RevChatGPT') + parser = argparse.ArgumentParser(description="Manage RevChatGPT plugins.") + parser.add_argument("--enable", metavar="plugin_id", help="the plugin to enable") + parser.add_argument( + "--list", action="store_true", help="list all available plugins" + ) + parser.add_argument( + "--access_token", required=True, help="access token for RevChatGPT" + ) args = parser.parse_args() diff --git a/swarms/schemas/__init__.py b/swarms/schemas/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/schemas/typings.py b/swarms/schemas/typings.py similarity index 100% rename from schemas/typings.py rename to swarms/schemas/typings.py diff --git a/swarms/structs/__init__.py b/swarms/structs/__init__.py index 045f4c922..c74544745 100644 --- a/swarms/structs/__init__.py +++ b/swarms/structs/__init__.py @@ -1,2 +1,2 @@ -# from swarms.structs.workflow import Workflow -# from swarms.structs.task import Task +from swarms.structs.workflow import Workflow +from swarms.structs.task import Task diff --git a/swarms/structs/flow.py b/swarms/structs/flow.py new file mode 100644 index 000000000..e3c4b725f --- /dev/null +++ b/swarms/structs/flow.py @@ -0,0 +1,225 @@ +import time +from typing import Any, Dict, List, Optional, Union, Callable +from swarms.models import OpenAIChat +from typing import Any, Dict, List, Optional, Callable +import logging +import time + + +# Custome stopping condition +def stop_when_repeats(response: str) -> bool: + # Stop if the word stop appears in the response + return "Stop" in response.lower() + + +# class Flow: +# def __init__( +# self, +# llm: Any, +# template: str, +# max_loops: int = 1, +# stopping_condition: Optional[Callable[[str], bool]] = None, +# **kwargs: Any +# ): +# self.llm = llm +# self.template = template +# self.max_loops = max_loops +# self.stopping_condition = stopping_condition +# self.feedback = [] +# self.history = [] + +# def __call__( +# self, +# prompt, +# **kwargs +# ) -> str: +# """Invoke the flow by providing a template and it's variables""" +# response = self.llm(prompt, **kwargs) +# return response + +# def _check_stopping_condition(self, response: str) -> bool: +# """Check if the stopping condition is met""" +# if self.stopping_condition: +# return self.stopping_condition(response) +# return False + + +# def provide_feedback(self, feedback: str) -> None: +# """Allow users to to provide feedback on the responses""" +# feedback = self.feedback.append(feedback) +# return feedback + +# def format_prompt(self, **kwargs: Any) -> str: +# """Format the template with the provided kwargs using f string interpolation""" +# return self.template.format(**kwargs) + +# def _generate(self, formatted_prompts: str) -> str: +# """ +# Generate a result using the lm + +# """ +# return self.llm(formatted_prompts) + +# def run(self, **kwargs: Any) -> str: +# """Generate a result using the provided keyword args""" +# prompt = self.format_prompt(**kwargs) +# response = self._generate(prompt) +# return response + +# def bulk_run( +# self, +# inputs: List[Dict[str, Any]] +# ) -> List[str]: +# """Generate responses for multiple input sets""" +# return [self.run(**input_data) for input_data in inputs] + +# @staticmethod +# def from_llm_and_template(llm: Any, template: str) -> "Flow": +# """Create FlowStream from LLM and a string template""" +# return Flow(llm=llm, template=template) + +# @staticmethod +# def from_llm_and_template_file(llm: Any, template_file: str) -> "Flow": +# """Create FlowStream from LLM and a template file""" +# with open(template_file, "r") as f: +# template = f.read() + +# return Flow(llm=llm, template=template) + + + +class Flow: + def __init__( + self, + llm: Any, + template: str, + max_loops: int = 1, + stopping_condition: Optional[Callable[[str], bool]] = None, + loop_interval: int = 1, + retry_attempts: int = 3, + retry_interval: int = 1, + **kwargs: Any, + ): + self.llm = llm + self.template = template + self.max_loops = max_loops + self.stopping_condition = stopping_condition + self.loop_interval = loop_interval + self.retry_attempts = retry_attempts + self.retry_interval = retry_interval + self.feedback = [] + + def provide_feedback(self, feedback: str) -> None: + """Allow users to provide feedback on the responses.""" + self.feedback.append(feedback) + logging.info(f"Feedback received: {feedback}") + + def _check_stopping_condition(self, response: str) -> bool: + """Check if the stopping condition is met.""" + if self.stopping_condition: + return self.stopping_condition(response) + return False + + def __call__(self, prompt, **kwargs) -> str: + """Invoke the flow by providing a template and its variables.""" + response = self.llm(prompt, **kwargs) + return response + + def format_prompt(self, **kwargs: Any) -> str: + """Format the template with the provided kwargs using f-string interpolation.""" + return self.template.format(**kwargs) + + def _generate(self, task: str, formatted_prompts: str) -> str: + """ + Generate a result using the lm with optional query loops and stopping conditions. + """ + response = formatted_prompts + history = [task] + for _ in range(self.max_loops): + if self._check_stopping_condition(response): + break + attempt = 0 + while attempt < self.retry_attempts: + try: + response = self.llm(response) + break + except Exception as e: + logging.error(f"Error generating response: {e}") + attempt += 1 + time.sleep(self.retry_interval) + logging.info(f"Generated response: {response}") + history.append(response) + time.sleep(self.loop_interval) + return response, history + + def run(self, **kwargs: Any) -> str: + """Generate a result using the provided keyword args.""" + task = self.format_prompt(**kwargs) + response, history = self._generate(task, task) + logging.info(f"Message history: {history}") + return response + + def bulk_run(self, inputs: List[Dict[str, Any]]) -> List[str]: + """Generate responses for multiple input sets.""" + return [self.run(**input_data) for input_data in inputs] + + @staticmethod + def from_llm_and_template(llm: Any, template: str) -> "Flow": + """Create FlowStream from LLM and a string template.""" + return Flow(llm=llm, template=template) + + @staticmethod + def from_llm_and_template_file(llm: Any, template_file: str) -> "Flow": + """Create FlowStream from LLM and a template file.""" + with open(template_file, "r") as f: + template = f.read() + return Flow(llm=llm, template=template) + + +# # Configure logging +# logging.basicConfig(level=logging.INFO) + +# llm = OpenAIChat( +# api_key="YOUR_API_KEY", +# max_tokens=1000, +# temperature=0.9, +# ) + + +# def main(): +# # Initialize the Flow class with parameters +# flow = Flow( +# llm=llm, +# template="Translate this to backwards: {sentence}", +# max_loops=3, +# stopping_condition=stop_when_repeats, +# loop_interval=2, # Wait 2 seconds between loops +# retry_attempts=2, +# retry_interval=1, # Wait 1 second between retries +# ) + +# # Predict using the Flow +# response = flow.run(sentence="Hello, World!") +# print("Response:", response) +# time.sleep(1) # Pause for demonstration purposes + +# # Provide feedback on the result +# flow.provide_feedback("The translation was interesting!") +# time.sleep(1) # Pause for demonstration purposes + +# # Bulk run +# inputs = [ +# {"sentence": "This is a test."}, +# {"sentence": "OpenAI is great."}, +# {"sentence": "GPT models are powerful."}, +# {"sentence": "stop and check if our stopping condition works."}, +# ] + +# responses = flow.bulk_run(inputs=inputs) +# for idx, res in enumerate(responses): +# print(f"Input: {inputs[idx]['sentence']}, Response: {res}") +# time.sleep(1) # Pause for demonstration purposes + + +# if __name__ == "__main__": +# main() diff --git a/swarms/tools/autogpt.py b/swarms/tools/autogpt.py index a0e26491e..c2f56db63 100644 --- a/swarms/tools/autogpt.py +++ b/swarms/tools/autogpt.py @@ -142,8 +142,9 @@ def _run(self, url: str, question: str) -> str: async def _arun(self, url: str, question: str) -> str: raise NotImplementedError + class EdgeGPTTool: -# Initialize the custom tool + # Initialize the custom tool def __init__( self, model, @@ -152,10 +153,11 @@ def __init__( ): super().__init__(name=name, description=description) self.model = model - + def _run(self, prompt): return self.model.__call__(prompt) + @tool def VQAinference(self, inputs): """ diff --git a/swarms/utils/revutils.py b/swarms/utils/revutils.py index 8e7e0b755..2d1b431c7 100644 --- a/swarms/utils/revutils.py +++ b/swarms/utils/revutils.py @@ -11,7 +11,7 @@ from prompt_toolkit.completion import WordCompleter from prompt_toolkit.history import InMemoryHistory from prompt_toolkit.key_binding import KeyBindings -from schemas.typings import Colors +from schemas.typings import Colors bindings = KeyBindings() @@ -19,6 +19,7 @@ BASE_URL = os.environ.get("CHATGPT_BASE_URL", "https://ai.fakeopen.com/api/") # BASE_URL = environ.get("CHATGPT_BASE_URL", "https://bypass.churchless.tech/") + def create_keybindings(key: str = "c-@") -> KeyBindings: """ Create keybindings for prompt_toolkit. Default key is ctrl+space. @@ -99,6 +100,7 @@ def get_filtered_keys_from_object(obj: object, *keys: str) -> any: # Only return specified keys that are in class_keys return {key for key in keys if key in class_keys} + def generate_random_hex(length: int = 17) -> str: """Generate a random hex string Args: @@ -163,4 +165,3 @@ def wrapper(*args, **kwargs): return wrapper return decorator - diff --git a/tests/models/bingchat.py b/tests/models/bingchat.py index 4e6a82710..5ed2c6efb 100644 --- a/tests/models/bingchat.py +++ b/tests/models/bingchat.py @@ -5,12 +5,12 @@ # Assuming the BingChat class is in a file named "bing_chat.py" from bing_chat import BingChat, ConversationStyle -class TestBingChat(unittest.TestCase): +class TestBingChat(unittest.TestCase): def setUp(self): # Path to a mock cookies file for testing self.mock_cookies_path = "./mock_cookies.json" - with open(self.mock_cookies_path, 'w') as file: + with open(self.mock_cookies_path, "w") as file: json.dump({"mock_cookie": "mock_value"}, file) self.chat = BingChat(cookies_path=self.mock_cookies_path) @@ -33,10 +33,10 @@ def test_create_img(self): class MockImageGen: def __init__(self, *args, **kwargs): pass - + def get_images(self, *args, **kwargs): return [{"path": "mock_image.png"}] - + @staticmethod def save_images(*args, **kwargs): pass @@ -54,5 +54,6 @@ def test_set_cookie_dir_path(self): BingChat.set_cookie_dir_path(test_path) self.assertEqual(BingChat.Cookie.dir_path, test_path) + if __name__ == "__main__": unittest.main() diff --git a/tests/models/biogpt.py b/tests/models/biogpt.py index 29cbe86cd..f420292be 100644 --- a/tests/models/biogpt.py +++ b/tests/models/biogpt.py @@ -6,12 +6,11 @@ from transformers import BioGptForCausalLM, BioGptTokenizer - # Fixture for BioGPT instance @pytest.fixture def biogpt_instance(): from swarms.models import ( - BioGPT, + BioGPT, ) return BioGPT() diff --git a/tests/models/kosmos.py b/tests/models/kosmos.py index 975a80b78..cffa41e6e 100644 --- a/tests/models/kosmos.py +++ b/tests/models/kosmos.py @@ -9,6 +9,7 @@ # A placeholder image URL for testing TEST_IMAGE_URL = "https://images.unsplash.com/photo-1673267569891-ca4246caafd7?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDM1fEpwZzZLaWRsLUhrfHxlbnwwfHx8fHw%3D" + # Mock the response for the test image @pytest.fixture def mock_image_request(): @@ -18,12 +19,14 @@ def mock_image_request(): with patch.object(requests, "get", return_value=mock_resp) as _fixture: yield _fixture + # Test utility function def test_is_overlapping(): - assert is_overlapping((1,1,3,3), (2,2,4,4)) == True - assert is_overlapping((1,1,2,2), (3,3,4,4)) == False - assert is_overlapping((0,0,1,1), (1,1,2,2)) == False - assert is_overlapping((0,0,2,2), (1,1,2,2)) == True + assert is_overlapping((1, 1, 3, 3), (2, 2, 4, 4)) == True + assert is_overlapping((1, 1, 2, 2), (3, 3, 4, 4)) == False + assert is_overlapping((0, 0, 1, 1), (1, 1, 2, 2)) == False + assert is_overlapping((0, 0, 2, 2), (1, 1, 2, 2)) == True + # Test model initialization def test_kosmos_init(): @@ -31,38 +34,49 @@ def test_kosmos_init(): assert kosmos.model is not None assert kosmos.processor is not None + # Test image fetching functionality def test_get_image(mock_image_request): kosmos = Kosmos() image = kosmos.get_image(TEST_IMAGE_URL) assert image is not None + # Test multimodal grounding def test_multimodal_grounding(mock_image_request): kosmos = Kosmos() kosmos.multimodal_grounding("Find the red apple in the image.", TEST_IMAGE_URL) # TODO: Validate the result if possible + # Test referring expression comprehension def test_referring_expression_comprehension(mock_image_request): kosmos = Kosmos() - kosmos.referring_expression_comprehension("Show me the green bottle.", TEST_IMAGE_URL) + kosmos.referring_expression_comprehension( + "Show me the green bottle.", TEST_IMAGE_URL + ) # TODO: Validate the result if possible + # ... (continue with other functions in the same manner) ... + # Test error scenarios - Example -@pytest.mark.parametrize("phrase, image_url", [ - (None, TEST_IMAGE_URL), - ("Find the red apple in the image.", None), - ("", TEST_IMAGE_URL), - ("Find the red apple in the image.", ""), -]) +@pytest.mark.parametrize( + "phrase, image_url", + [ + (None, TEST_IMAGE_URL), + ("Find the red apple in the image.", None), + ("", TEST_IMAGE_URL), + ("Find the red apple in the image.", ""), + ], +) def test_kosmos_error_scenarios(phrase, image_url): kosmos = Kosmos() with pytest.raises(Exception): kosmos.multimodal_grounding(phrase, image_url) + # ... (Add more tests for different edge cases and functionalities) ... # Sample test image URLs @@ -72,6 +86,7 @@ def test_kosmos_error_scenarios(phrase, image_url): IMG_URL4 = "https://images.unsplash.com/photo-1676156340083-fd49e4e53a21?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDc4fEpwZzZLaWRsLUhrfHxlbnwwfHx8fHw%3D" IMG_URL5 = "https://images.unsplash.com/photo-1696862761045-0a65acbede8f?auto=format&fit=crop&q=80&w=1287&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" + # Mock response for requests.get() class MockResponse: @staticmethod @@ -82,57 +97,69 @@ def json(): def raw(self): return open("tests/sample_image.jpg", "rb") + # Test the Kosmos class @pytest.fixture def kosmos(): return Kosmos() + # Mocking the requests.get() method @pytest.fixture def mock_request_get(monkeypatch): - monkeypatch.setattr(requests, 'get', lambda url, **kwargs: MockResponse()) + monkeypatch.setattr(requests, "get", lambda url, **kwargs: MockResponse()) + @pytest.mark.usefixtures("mock_request_get") def test_multimodal_grounding(kosmos): kosmos.multimodal_grounding("Find the red apple in the image.", IMG_URL1) + @pytest.mark.usefixtures("mock_request_get") def test_referring_expression_comprehension(kosmos): kosmos.referring_expression_comprehension("Show me the green bottle.", IMG_URL2) + @pytest.mark.usefixtures("mock_request_get") def test_referring_expression_generation(kosmos): kosmos.referring_expression_generation("It is on the table.", IMG_URL3) + @pytest.mark.usefixtures("mock_request_get") def test_grounded_vqa(kosmos): kosmos.grounded_vqa("What is the color of the car?", IMG_URL4) + @pytest.mark.usefixtures("mock_request_get") def test_grounded_image_captioning(kosmos): kosmos.grounded_image_captioning(IMG_URL5) + @pytest.mark.usefixtures("mock_request_get") def test_grounded_image_captioning_detailed(kosmos): kosmos.grounded_image_captioning_detailed(IMG_URL1) + @pytest.mark.usefixtures("mock_request_get") def test_multimodal_grounding_2(kosmos): kosmos.multimodal_grounding("Find the yellow fruit in the image.", IMG_URL2) + @pytest.mark.usefixtures("mock_request_get") def test_referring_expression_comprehension_2(kosmos): kosmos.referring_expression_comprehension("Where is the water bottle?", IMG_URL3) + @pytest.mark.usefixtures("mock_request_get") def test_grounded_vqa_2(kosmos): kosmos.grounded_vqa("How many cars are in the image?", IMG_URL4) + @pytest.mark.usefixtures("mock_request_get") def test_grounded_image_captioning_2(kosmos): kosmos.grounded_image_captioning(IMG_URL2) + @pytest.mark.usefixtures("mock_request_get") def test_grounded_image_captioning_detailed_2(kosmos): kosmos.grounded_image_captioning_detailed(IMG_URL3) - diff --git a/tests/models/revgptv1.py b/tests/models/revgptv1.py index 8f3722821..95dbb3c6a 100644 --- a/tests/models/revgptv1.py +++ b/tests/models/revgptv1.py @@ -2,12 +2,12 @@ from unittest.mock import patch from Sswarms.models.revgptv1 import RevChatGPTModelv1 -class TestRevChatGPT(unittest.TestCase): +class TestRevChatGPT(unittest.TestCase): def setUp(self): self.access_token = "" self.model = RevChatGPTModelv1(access_token=self.access_token) - + def test_run(self): prompt = "What is the capital of France?" response = self.model.run(prompt) @@ -21,7 +21,7 @@ def test_run_time(self): def test_generate_summary(self): text = "This is a sample text to summarize. It has multiple sentences and details. The summary should be concise." summary = self.model.generate_summary(text) - self.assertLess(len(summary), len(text)/2) + self.assertLess(len(summary), len(text) / 2) def test_enable_plugin(self): plugin_id = "some_plugin_id" @@ -39,9 +39,9 @@ def test_get_conversations(self): conversations = self.model.chatbot.get_conversations() self.assertIsInstance(conversations, list) - @patch("RevChatGPTModelv1.Chatbot.get_msg_history") + @patch("RevChatGPTModelv1.Chatbot.get_msg_history") def test_get_msg_history(self, mock_get_msg_history): - conversation_id = "convo_id" + conversation_id = "convo_id" self.model.chatbot.get_msg_history(conversation_id) mock_get_msg_history.assert_called_with(conversation_id) @@ -78,5 +78,6 @@ def test_rollback_conversation(self): self.model.chatbot.rollback_conversation(1) self.assertNotEqual(original_convo_id, self.model.chatbot.conversation_id) + if __name__ == "__main__": unittest.main() diff --git a/tests/models/revgptv4.py b/tests/models/revgptv4.py index fa7701c4f..7a40ab304 100644 --- a/tests/models/revgptv4.py +++ b/tests/models/revgptv4.py @@ -2,16 +2,16 @@ from unittest.mock import patch from RevChatGPTModelv4 import RevChatGPTModelv4 -class TestRevChatGPT(unittest.TestCase): +class TestRevChatGPT(unittest.TestCase): def setUp(self): self.access_token = "123" self.model = RevChatGPTModelv4(access_token=self.access_token) - + def test_run(self): prompt = "What is the capital of France?" self.model.start_time = 10 - self.model.end_time = 20 + self.model.end_time = 20 response = self.model.run(prompt) self.assertEqual(response, "The capital of France is Paris.") self.assertEqual(self.model.start_time, 10) @@ -44,7 +44,7 @@ def test_get_conversations(self, mock_get_conversations): @patch("RevChatGPTModelv4.Chatbot.get_msg_history") def test_get_msg_history(self, mock_get_msg_history): convo_id = "123" - self.model.chatbot.get_msg_history(convo_id) + self.model.chatbot.get_msg_history(convo_id) mock_get_msg_history.assert_called_with(convo_id) @patch("RevChatGPTModelv4.Chatbot.share_conversation") @@ -52,7 +52,7 @@ def test_share_conversation(self, mock_share_conversation): self.model.chatbot.share_conversation() mock_share_conversation.assert_called() - @patch("RevChatGPTModelv4.Chatbot.gen_title") + @patch("RevChatGPTModelv4.Chatbot.gen_title") def test_gen_title(self, mock_gen_title): convo_id = "123" message_id = "456" @@ -77,7 +77,7 @@ def test_clear_conversations(self, mock_clear_conversations): self.model.chatbot.clear_conversations() mock_clear_conversations.assert_called() - @patch("RevChatGPTModelv4.Chatbot.rollback_conversation") + @patch("RevChatGPTModelv4.Chatbot.rollback_conversation") def test_rollback_conversation(self, mock_rollback_conversation): num = 2 self.model.chatbot.rollback_conversation(num) @@ -88,5 +88,6 @@ def test_reset_chat(self, mock_reset_chat): self.model.chatbot.reset_chat() mock_reset_chat.assert_called() + if __name__ == "__main__": unittest.main()