Skip to content

Commit

Permalink
Merge branch 'fix/misc' into demo/agent-qna
Browse files Browse the repository at this point in the history
  • Loading branch information
htahir1 committed Jan 25, 2024
2 parents b67e1d5 + f0c7020 commit 4d01350
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 220 deletions.
2 changes: 2 additions & 0 deletions llm-agents/src/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ gcsfs==2023.5.0
faiss-cpu==1.7.3
unstructured==0.5.7
tiktoken
bs4
typing_extensions
252 changes: 46 additions & 206 deletions llm-agents/src/run.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion llm-agents/src/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


def main():
version = "0.54.1"
version = "0.55.0"
docs_url = f"https://docs.zenml.io/v/{version}/"
website_url = "https://zenml.io"
repo_url = f"https://github.com/zenml-io/zenml/tree/{version}/examples"
Expand Down
16 changes: 9 additions & 7 deletions llm-agents/src/steps/agent_creator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Annotated, Dict, cast
from typing import Dict, Optional
from typing_extensions import Annotated

from agent.agent_executor_materializer import AgentExecutorMaterializer
from agent.prompt import PREFIX, SUFFIX
Expand All @@ -8,7 +8,7 @@
from langchain.schema.vectorstore import VectorStore
from langchain.tools.vectorstore.tool import VectorStoreQATool
from langchain.agents import AgentExecutor
from zenml.steps import BaseParameters
from pydantic import BaseModel
from zenml import step, ArtifactConfig, log_artifact_metadata


Expand All @@ -17,7 +17,7 @@
CHARACTER = "technical assistant"


class AgentParameters(BaseParameters):
class AgentParameters(BaseModel):
"""Parameters for the agent."""

llm: Dict = {
Expand All @@ -33,8 +33,10 @@ class Config:

@step(output_materializers=AgentExecutorMaterializer, enable_cache=False)
def agent_creator(
vector_store: VectorStore, config: AgentParameters
) -> Annotated[AgentExecutor, ArtifactConfig(name="agent", is_model_artifact=True)]:
vector_store: VectorStore, config: AgentParameters = AgentParameters()
) -> Annotated[
AgentExecutor, ArtifactConfig(name="agent", is_model_artifact=True)
]:
"""Create an agent from a vector store.
Args:
Expand Down Expand Up @@ -81,7 +83,7 @@ def agent_creator(
"temperature": config.llm["temperature"],
"model_name": config.llm["model_name"],
},
}
},
)

return agent_executor
9 changes: 6 additions & 3 deletions llm-agents/src/steps/index_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
# or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Annotated, List
from typing_extensions import Annotated
from typing import List

from langchain.docstore.document import Document
from langchain.embeddings import OpenAIEmbeddings
Expand All @@ -25,7 +26,9 @@


@step(enable_cache=True)
def index_generator(documents: List[Document]) -> Annotated[VectorStore, "vector_store"]:
def index_generator(
documents: List[Document],
) -> Annotated[VectorStore, "vector_store"]:
embeddings = OpenAIEmbeddings()

text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
Expand All @@ -36,7 +39,7 @@ def index_generator(documents: List[Document]) -> Annotated[VectorStore, "vector
metadata={
"embedding_type": "OpenAIEmbeddings",
"vector_store_type": "FAISS",
}
},
)

return FAISS.from_documents(compiled_texts, embeddings)
6 changes: 3 additions & 3 deletions llm-agents/src/steps/url_scraper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Annotated, List

from typing import List
from typing_extensions import Annotated
from steps.url_scraping_utils import get_all_pages, get_nested_readme_urls
from zenml import step, log_artifact_metadata

Expand Down Expand Up @@ -45,6 +45,6 @@ def url_scraper(
artifact_name="urls",
metadata={
"count": len(all_urls),
}
},
)
return all_urls

0 comments on commit 4d01350

Please sign in to comment.