Merge pull request #24 from crypdick/feature/pre-commit-hooks
Add pre-commit hooks and lint
eugeneyan authored Mar 1, 2024
2 parents 9a6d256 + 76bf501 commit 851bcdb
Showing 8 changed files with 305 additions and 190 deletions.
13 changes: 13 additions & 0 deletions .github/workflows/python_app.yml
@@ -0,0 +1,13 @@
name: Python application

on: [push, pull_request]

permissions:
  contents: read

jobs:
  ruff:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: chartboost/ruff-action@v1
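For reference, the workflow above only lints: chartboost/ruff-action runs Ruff's linter over the repository on every push and pull request, while auto-fixes are applied locally by the pre-commit hook below via `--fix`. A small, hypothetical snippet of the kind of issue Ruff's default rule set reports (names invented for illustration):

```
# Hypothetical example of code the CI lint step would flag.
import json  # F401: `json` imported but unused; removable with `ruff check --fix`


def is_missing(value) -> bool:
    # Writing `value == None` here would trigger E711; the identity
    # comparison below is the form Ruff suggests.
    return value is None


print(is_missing(None))  # True
```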
37 changes: 37 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,37 @@
# note: to install, run `pre-commit install`
# note: if you update this file, run `pre-commit autoupdate`
default_language_version:
  python: python3.9
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.2.1
    hooks:
      # Run the linter.
      - id: ruff
        # run on ipynb, too
        types_or: [ python, pyi, jupyter ]
        args: ["--fix", "--show-source"]
      # Run the formatter.
      - id: ruff-format
        types_or: [ python, pyi, jupyter ]
  - repo: https://github.com/myint/docformatter
    rev: v1.7.5
    hooks:
      - id: docformatter
        args: ['--in-place', '--pre-summary-newline', '--make-summary-multi-line', '--wrap-descriptions', '120', '--wrap-summaries', '120']
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
        exclude: ".ipynb|.pub"
      - id: detect-private-key
      - id: check-added-large-files
      - id: check-yaml
      - id: check-toml
      - id: check-ast
        language: python
        types: [python]
      - id: check-merge-conflict
        exclude: \.rst$|\.pot?$
      - id: requirements-txt-fixer
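The docformatter arguments above (`--pre-summary-newline`, `--make-summary-multi-line`) are what produce the docstring churn in src/app.py further down: one-line summaries move onto their own line under the opening quotes. A minimal before/after sketch, with made-up function names:

```
def parse_response_before(response: dict) -> list:
    """Parse response from opensearch index.

    Args:
        response: Response from opensearch query.
    """


def parse_response_after(response: dict) -> list:
    """
    Parse response from opensearch index.

    Args:
        response: Response from opensearch query.
    """
```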
8 changes: 6 additions & 2 deletions README.md
@@ -38,7 +38,7 @@ Build the OpenSearch and semantic indices
# Build the docker image
make build
# Start the opensearch container and wait for it to start.
# You should see something like this: [c6587bf83572] Node 'c6587bf83572' initialized
make opensearch
@@ -48,7 +48,7 @@ make build-artifacts

Running the retrieval app
```
# First, stop the opensearch container (CTRL + C). Then, start the retrieval app.
# You should see this: Uvicorn running on http://0.0.0.0:8000
make run
```
@@ -71,6 +71,10 @@ At a high level, when you type a section header, it'll:
- The retrieved context is then used to generate paragraphs for the section
- It is also displayed in a new tab for reference

## Contributing

To install the pre-commit hooks, run `pip install pre-commit && pre-commit install` in the root of the repository.

## TODOs

- [ ] Add support for using Anthropic Claude (100k context)
91 changes: 54 additions & 37 deletions src/app.py
@@ -1,6 +1,7 @@
"""
Simple FastAPI app that queries opensearch and a semantic index for retrieval-augmented generation.
"""

import os
import pickle
from typing import Dict, List
@@ -13,29 +14,30 @@
 from transformers import AutoModel, AutoTokenizer

 from src.logger import logger
-from src.prep.build_opensearch_index import (INDEX_NAME, get_opensearch,
-                                             query_opensearch)
+from src.prep.build_opensearch_index import INDEX_NAME, get_opensearch, query_opensearch
 from src.prep.build_semantic_index import query_semantic

 # Load vault dictionary
-vault = pickle.load(open('data/vault_dict.pickle', 'rb'))
-logger.info(f'Vault loaded with {len(vault)} documents')
+vault = pickle.load(open("data/vault_dict.pickle", "rb"))
+logger.info(f"Vault loaded with {len(vault)} documents")

 # Create opensearch client
 try:
-    os_client = get_opensearch('opensearch')
+    os_client = get_opensearch("opensearch")
 except ConnectionRefusedError:
-    os_client = get_opensearch('localhost')  # Change to 'localhost' if running locally
-logger.info(f'OS client initialized: {os_client.info()}')
+    os_client = get_opensearch("localhost")  # Change to 'localhost' if running locally
+logger.info(f"OS client initialized: {os_client.info()}")

 # Load semantic index
-doc_embeddings_array = np.load('data/doc_embeddings_array.npy')
-with open('data/embedding_index.pickle', 'rb') as f:
+doc_embeddings_array = np.load("data/doc_embeddings_array.npy")
+with open("data/embedding_index.pickle", "rb") as f:
     embedding_index = pickle.load(f)
-tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-small-v2')  # Max token length is 512
-os.environ['TOKENIZERS_PARALLELISM'] = 'false'
-model = AutoModel.from_pretrained('intfloat/e5-small-v2')
-logger.info(f'Semantic index loaded with {len(embedding_index)} documents')
+tokenizer = AutoTokenizer.from_pretrained(
+    "intfloat/e5-small-v2"
+)  # Max token length is 512
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+model = AutoModel.from_pretrained("intfloat/e5-small-v2")
+logger.info(f"Semantic index loaded with {len(embedding_index)} documents")


 # Create app
@@ -57,7 +59,8 @@


 def parse_os_response(response: dict) -> List[dict]:
-    """Parse response from opensearch index.
+    """
+    Parse response from opensearch index.

     Args:
@@ -67,14 +70,17 @@ def parse_semantic_response(indices: np.ndarray, embedding_index: Dict[int, str]
"""
hits = []

for rank, hit in enumerate(response['hits']['hits']):
hits.append({'id': hit['_id'], 'rank': rank})
for rank, hit in enumerate(response["hits"]["hits"]):
hits.append({"id": hit["_id"], "rank": rank})

return hits


def parse_semantic_response(indices: np.ndarray, embedding_index: Dict[int, str]) -> List[dict]:
"""Parse response from semantic index.
def parse_semantic_response(
indices: np.ndarray, embedding_index: Dict[int, str]
) -> List[dict]:
"""
Parse response from semantic index.
Args:
indices: Response from semantic query, an array of ints.
@@ -85,13 +91,14 @@ def parse_semantic_response(indices: np.ndarray, embedding_index: Dict[int, str]
     hits = []

     for rank, idx in enumerate(indices):
-        hits.append({'id': embedding_index[idx], 'rank': rank})
+        hits.append({"id": embedding_index[idx], "rank": rank})

     return hits


 def num_tokens_from_string(string: str, model_name: str) -> int:
-    """Returns the number of tokens in a string based on tiktoken encoding.
+    """
+    Returns the number of tokens in a string based on tiktoken encoding.

     Args:
         string: String to count tokens for
@@ -105,8 +112,11 @@ def num_tokens_from_string(string: str, model_name: str) -> int:
     return num_tokens


-def get_chunks_from_hits(hits: List[dict], model_name: str = 'gpt-3.5-turbo', max_tokens: int = 3200) -> List[dict]:
-    """Deduplicates and scores a list of chunks. (There may be duplicate chunks as we query multiple indices.)
+def get_chunks_from_hits(
+    hits: List[dict], model_name: str = "gpt-3.5-turbo", max_tokens: int = 3200
+) -> List[dict]:
+    """
+    Deduplicates and scores a list of chunks. (There may be duplicate chunks as we query multiple indices.)

     Args:
         hits: List of hits from opensearch, semantic index, etc.
@@ -119,29 +129,34 @@ def get_chunks_from_hits(hits: List[dict], model_name: str = 'gpt-3.5-turbo', ma
"""
# Combine os and semantic hits and rank them
df = pd.DataFrame(hits)
df['score'] = df['rank'].apply(lambda x: 10 - x)
df["score"] = df["rank"].apply(lambda x: 10 - x)
# deduplicate chunks by ID, summing their OS and semantic scores
ranked = df.groupby('id').agg({'score': 'sum'}).sort_values('score', ascending=False).reset_index()
ranked = (
df.groupby("id")
.agg({"score": "sum"})
.sort_values("score", ascending=False)
.reset_index()
)

# Get context based on ranked IDs
chunks = []
token_count = 0

for id in ranked['id'].tolist():
chunk = vault[id]['chunk']
title = vault[id]['title']
for id in ranked["id"].tolist():
chunk = vault[id]["chunk"]
title = vault[id]["title"]

# Check if token count exceeds max_tokens
token_count += num_tokens_from_string(chunk, model_name)
if token_count > max_tokens:
break

chunks.append({'title': title, 'chunk': chunk})
chunks.append({"title": title, "chunk": chunk})

return chunks


@app.get('/get_chunks')
@app.get("/get_chunks")
def get_chunks(query: str):
if not query:
raise ValueError(
@@ -150,28 +165,30 @@ def get_chunks(query: str):
     # Get hits from opensearch
     os_response = query_opensearch(query, os_client, INDEX_NAME)
     os_hits = parse_os_response(os_response)
-    logger.debug(f'OS hits: {os_hits}')
+    logger.debug(f"OS hits: {os_hits}")

     # Get hits from semantic index
     semantic_response = query_semantic(query, tokenizer, model, doc_embeddings_array)
     semantic_hits = parse_semantic_response(semantic_response, embedding_index)
-    logger.debug(f'Semantic hits: {semantic_hits}')
+    logger.debug(f"Semantic hits: {semantic_hits}")

     # Get context
     context = get_chunks_from_hits(os_hits + semantic_hits)
     return context


-if __name__ == '__main__':
+if __name__ == "__main__":
     logger.info(f'Environment variables loaded: {os.getenv("OPENAI_API_KEY")}')
-    test_query = 'Examples of bandits in industry'
+    test_query = "Examples of bandits in industry"
     os_response = query_opensearch(test_query, os_client, INDEX_NAME)
     os_hits = parse_os_response(os_response)
-    logger.debug(f'OS hits: {os_hits}')
-    semantic_response = query_semantic(f'query: {test_query}', tokenizer, model, doc_embeddings_array)
+    logger.debug(f"OS hits: {os_hits}")
+    semantic_response = query_semantic(
+        f"query: {test_query}", tokenizer, model, doc_embeddings_array
+    )
     semantic_hits = parse_semantic_response(semantic_response, embedding_index)
-    logger.debug(f'Semantic hits: {semantic_hits}')
+    logger.debug(f"Semantic hits: {semantic_hits}")

     # Combine os and semantic hits and rank them
     context = get_chunks_from_hits(os_hits + semantic_hits)
-    logger.info(f'Context: {context}')
+    logger.info(f"Context: {context}")
2 changes: 1 addition & 1 deletion src/logger.py
@@ -7,7 +7,7 @@
 # Add the log message handler to the logger
 handler = logging.StreamHandler()  # Writes logging output to streams like sys.stdout, sys.stderr or any file-like object

-formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
 handler.setFormatter(formatter)

 logger.addHandler(handler)
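None of the changes above alter runtime behaviour, so the retrieval app is exercised the same way as before. A minimal sketch of calling the `/get_chunks` route from src/app.py, assuming the app is running locally on port 8000 via `make run` as described in the README:

```
import requests

resp = requests.get(
    "http://localhost:8000/get_chunks",
    params={"query": "Examples of bandits in industry"},
)
resp.raise_for_status()
for hit in resp.json():
    # Each item holds the source note's title and the retrieved chunk text.
    print(hit["title"])
```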