diff --git a/.env b/.env
index bbbf0ef..b043103 100644
--- a/.env
+++ b/.env
@@ -8,6 +8,7 @@ DEFAULT_MAX_COMPLETION_TOKENS=1000
DEFAULT_NUMBER_OF_COMPLETIONS_TO_GENERATE=1
DEFAULT_COMPLETION_TEMPERATURE=0.7
LLAMA_EMBEDDING_SERVER_LISTEN_PORT=8089
+UVICORN_NUMBER_OF_WORKERS=2
MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING=15
MAX_RETRIES=10
DB_WRITE_BATCH_SIZE=25
diff --git a/README.md b/README.md
index 0064201..5ea7912 100644
--- a/README.md
+++ b/README.md
@@ -109,6 +109,7 @@ fast_vector_similarity
faster-whisper
textract
pytz
+uvloop
```
## Running the Application
@@ -137,6 +138,7 @@ You can configure the service easily by editing the included `.env` file. Here's
- `DEFAULT_MODEL_NAME`: Default model name to use. (e.g., `yarn-llama-2-13b-128k`)
- `LLM_CONTEXT_SIZE_IN_TOKENS`: Context size in tokens for LLM. (e.g., `512`)
- `SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT`: Port number for the service. (e.g., `8089`)
+- `UVICORN_NUMBER_OF_WORKERS`: Number of workers for Uvicorn. (e.g., `2`)
- `MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING`: Minimum string length for document embedding. (e.g., `15`)
- `MAX_RETRIES`: Maximum retries for locked database. (e.g., `10`)
- `DB_WRITE_BATCH_SIZE`: Database write batch size. (e.g., `25`)
diff --git a/database_functions.py b/database_functions.py
new file mode 100644
index 0000000..2b5a52f
--- /dev/null
+++ b/database_functions.py
@@ -0,0 +1,162 @@
+from embeddings_data_models import Base, TextEmbedding, DocumentEmbedding, Document, TokenLevelEmbedding, TokenLevelEmbeddingBundle, TokenLevelEmbeddingBundleCombinedFeatureVector, AudioTranscript
+from logger_config import setup_logger
+import traceback
+import asyncio
+import random
+from sqlalchemy import select
+from sqlalchemy import text as sql_text
+from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError
+from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
+from sqlalchemy.orm import sessionmaker
+from datetime import datetime
+from decouple import config
+
+logger = setup_logger()
+db_writer = None
+DATABASE_URL = "sqlite+aiosqlite:///swiss_army_llama.sqlite"
+MAX_RETRIES = config("MAX_RETRIES", default=3, cast=int)
+DB_WRITE_BATCH_SIZE = config("DB_WRITE_BATCH_SIZE", default=25, cast=int)
+RETRY_DELAY_BASE_SECONDS = config("RETRY_DELAY_BASE_SECONDS", default=1, cast=int)
+JITTER_FACTOR = config("JITTER_FACTOR", default=0.1, cast=float)
+
+engine = create_async_engine(DATABASE_URL, echo=False, connect_args={"check_same_thread": False})
+AsyncSessionLocal = sessionmaker(
+ bind=engine,
+ class_=AsyncSession,
+ expire_on_commit=False,
+ autoflush=False
+)
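+# expire_on_commit=False keeps ORM objects usable after the writer task commits them (no implicit
+# refresh on an async session), and autoflush=False avoids surprise flushes while a batch is
+# still being assembled.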
+class DatabaseWriter:
+ def __init__(self, queue):
+ self.queue = queue
+        self.processing_hashes = set()  # Set to store the hashes of everything that is currently being processed in the queue (to avoid duplicates of the same task being added to the queue)
+
+ def _get_hash_from_operation(self, operation):
+ attr_name = {
+ TextEmbedding: 'text_hash',
+ DocumentEmbedding: 'file_hash',
+ Document: 'document_hash',
+ TokenLevelEmbedding: 'token_hash',
+ TokenLevelEmbeddingBundle: 'input_text_hash',
+ TokenLevelEmbeddingBundleCombinedFeatureVector: 'combined_feature_vector_hash',
+ AudioTranscript: 'audio_file_hash'
+ }.get(type(operation))
+ hash_value = getattr(operation, attr_name, None)
+ llm_model_name = getattr(operation, 'llm_model_name', None)
+ return f"{hash_value}_{llm_model_name}" if hash_value and llm_model_name else None
+
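+    # Seeds the dedup set at startup by paging through every hash column in fixed-size chunks
+    # (1000 rows per query by default), so memory stays bounded even for very large tables.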
+ async def initialize_processing_hashes(self, chunk_size=1000):
+ start_time = datetime.utcnow()
+ async with AsyncSessionLocal() as session:
+ queries = [
+ (select(TextEmbedding.text_hash, TextEmbedding.llm_model_name), True),
+ (select(DocumentEmbedding.file_hash, DocumentEmbedding.llm_model_name), True),
+ (select(Document.document_hash, Document.llm_model_name), True),
+ (select(TokenLevelEmbedding.token_hash, TokenLevelEmbedding.llm_model_name), True),
+ (select(TokenLevelEmbeddingBundle.input_text_hash, TokenLevelEmbeddingBundle.llm_model_name), True),
+ (select(TokenLevelEmbeddingBundleCombinedFeatureVector.combined_feature_vector_hash, TokenLevelEmbeddingBundleCombinedFeatureVector.llm_model_name), True),
+ (select(AudioTranscript.audio_file_hash), False)
+ ]
+ for query, has_llm in queries:
+ offset = 0
+ while True:
+ result = await session.execute(query.limit(chunk_size).offset(offset))
+ rows = result.fetchall()
+ if not rows:
+ break
+ for row in rows:
+ if has_llm:
+ hash_with_model = f"{row[0]}_{row[1]}"
+ else:
+ hash_with_model = row[0]
+ self.processing_hashes.add(hash_with_model)
+ offset += chunk_size
+ end_time = datetime.utcnow()
+ total_time = (end_time - start_time).total_seconds()
+ if len(self.processing_hashes) > 0:
+ logger.info(f"Finished initializing set of input hash/llm_model_name combinations that are either currently being processed or have already been processed. Set size: {len(self.processing_hashes)}; Took {total_time} seconds, for an average of {total_time / len(self.processing_hashes)} seconds per hash.")
+
+ async def _handle_integrity_error(self, e, write_operation, session):
+ unique_constraint_msg = {
+ TextEmbedding: "token_embeddings.token_hash, token_embeddings.llm_model_name",
+ DocumentEmbedding: "document_embeddings.file_hash, document_embeddings.llm_model_name",
+ Document: "documents.document_hash, documents.llm_model_name",
+ TokenLevelEmbedding: "token_level_embeddings.token_hash, token_level_embeddings.llm_model_name",
+ TokenLevelEmbeddingBundle: "token_level_embedding_bundles.input_text_hash, token_level_embedding_bundles.llm_model_name",
+ AudioTranscript: "audio_transcripts.audio_file_hash"
+ }.get(type(write_operation))
+ if unique_constraint_msg and unique_constraint_msg in str(e):
+ logger.warning(f"Embedding already exists in the database for given input and llm_model_name: {e}")
+ await session.rollback()
+ else:
+ raise
+
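+    # Illustrative startup wiring (names assumed): create one queue and one writer per process,
+    # seed the dedup set, then run this consumer as a long-lived background task:
+    #   queue = asyncio.Queue()
+    #   db_writer = DatabaseWriter(queue)
+    #   await db_writer.initialize_processing_hashes()
+    #   asyncio.create_task(db_writer.dedicated_db_writer())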
+ async def dedicated_db_writer(self):
+ while True:
+ write_operations_batch = await self.queue.get()
+ async with AsyncSessionLocal() as session:
+ try:
+ for write_operation in write_operations_batch:
+ session.add(write_operation)
+ await session.flush() # Flush to get the IDs
+ await session.commit()
+ for write_operation in write_operations_batch:
+ hash_to_remove = self._get_hash_from_operation(write_operation)
+ if hash_to_remove is not None and hash_to_remove in self.processing_hashes:
+ self.processing_hashes.remove(hash_to_remove)
+ except IntegrityError as e:
+ await self._handle_integrity_error(e, write_operation, session)
+ except SQLAlchemyError as e:
+ logger.error(f"Database error: {e}")
+ await session.rollback()
+ except Exception as e:
+ tb = traceback.format_exc()
+ logger.error(f"Unexpected error: {e}\n{tb}")
+ await session.rollback()
+ self.queue.task_done()
+
+ async def enqueue_write(self, write_operations):
+ write_operations = [op for op in write_operations if self._get_hash_from_operation(op) not in self.processing_hashes] # Filter out write operations for hashes that are already being processed
+ if not write_operations: # If there are no write operations left after filtering, return early
+ return
+ for op in write_operations: # Add the hashes of the write operations to the set
+ hash_value = self._get_hash_from_operation(op)
+ if hash_value:
+ self.processing_hashes.add(hash_value)
+ await self.queue.put(write_operations)
+
+
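+# With the defaults RETRY_DELAY_BASE_SECONDS=1 and JITTER_FACTOR=0.1, successive waits are roughly
+# 2s, 4s, 8s, ... (base * 2**retries) plus up to 0.1s of random jitter, so contending writers
+# de-synchronize instead of retrying against the locked database in lockstep.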
+async def execute_with_retry(func, *args, **kwargs):
+ retries = 0
+ while retries < MAX_RETRIES:
+ try:
+ return await func(*args, **kwargs)
+ except OperationalError as e:
+ if 'database is locked' in str(e):
+ retries += 1
+ sleep_time = RETRY_DELAY_BASE_SECONDS * (2 ** retries) + (random.random() * JITTER_FACTOR) # Implementing exponential backoff with jitter
+ logger.warning(f"Database is locked. Retrying ({retries}/{MAX_RETRIES})... Waiting for {sleep_time} seconds")
+ await asyncio.sleep(sleep_time)
+ else:
+ raise
+    raise OperationalError("Database is locked after multiple retries", None, None)
+
+async def initialize_db():
+ logger.info("Initializing database, creating tables, and setting SQLite PRAGMAs...")
+ list_of_sqlite_pragma_strings = ["PRAGMA journal_mode=WAL;", "PRAGMA synchronous = NORMAL;", "PRAGMA cache_size = -1048576;", "PRAGMA busy_timeout = 2000;", "PRAGMA wal_autocheckpoint = 100;"]
+ list_of_sqlite_pragma_justification_strings = ["Set SQLite to use Write-Ahead Logging (WAL) mode (from default DELETE mode) so that reads and writes can occur simultaneously",
+ "Set synchronous mode to NORMAL (from FULL) so that writes are not blocked by reads",
+ "Set cache size to 1GB (from default 2MB) so that more data can be cached in memory and not read from disk; to make this 256MB, set it to -262144 instead",
+ "Increase the busy timeout to 2 seconds so that the database waits",
+ "Set the WAL autocheckpoint to 100 (from default 1000) so that the WAL file is checkpointed more frequently"]
+ assert(len(list_of_sqlite_pragma_strings) == len(list_of_sqlite_pragma_justification_strings))
+ async with engine.begin() as conn:
+        for pragma_string, justification in zip(list_of_sqlite_pragma_strings, list_of_sqlite_pragma_justification_strings):
+            await conn.execute(sql_text(pragma_string))
+            logger.info(f"Executed SQLite PRAGMA: {pragma_string}")
+            logger.info(f"Justification: {justification}")
+ await conn.run_sync(Base.metadata.create_all) # Create tables if they don't exist
+ logger.info("Database initialization completed.")
+
+def get_db_writer() -> DatabaseWriter:
+ return db_writer # Return the existing DatabaseWriter instance
diff --git a/llama_knife_sticker.webp b/image_files/llama_knife_sticker.webp
similarity index 100%
rename from llama_knife_sticker.webp
rename to image_files/llama_knife_sticker.webp
diff --git a/llama_knife_sticker2.jpg b/image_files/llama_knife_sticker2.jpg
similarity index 100%
rename from llama_knife_sticker2.jpg
rename to image_files/llama_knife_sticker2.jpg
diff --git a/swiss_army_llama__swagger_screenshot.png b/image_files/swiss_army_llama__swagger_screenshot.png
similarity index 100%
rename from swiss_army_llama__swagger_screenshot.png
rename to image_files/swiss_army_llama__swagger_screenshot.png
diff --git a/swiss_army_llama__swagger_screenshot_running.png b/image_files/swiss_army_llama__swagger_screenshot_running.png
similarity index 100%
rename from swiss_army_llama__swagger_screenshot_running.png
rename to image_files/swiss_army_llama__swagger_screenshot_running.png
diff --git a/swiss_army_llama_logo.webp b/image_files/swiss_army_llama_logo.webp
similarity index 100%
rename from swiss_army_llama_logo.webp
rename to image_files/swiss_army_llama_logo.webp
diff --git a/log_viewer_functions.py b/log_viewer_functions.py
index ac2c787..d271560 100644
--- a/log_viewer_functions.py
+++ b/log_viewer_functions.py
@@ -48,7 +48,6 @@ def highlight_rules_func(text):
text = text.replace('#COLOR13_OPEN#', '').replace('#COLOR13_CLOSE#', '')
return text
-
def show_logs_incremental_func(minutes: int, last_position: int):
new_logs = []
now = datetime.now(timezone('UTC')) # get current time, make it timezone-aware
@@ -75,21 +74,22 @@ def show_logs_incremental_func(minutes: int, last_position: int):
def show_logs_func(minutes: int = 5):
- # read the entire log file and generate HTML with logs up to `minutes` minutes from now
with open(log_file_path, "r") as f:
lines = f.readlines()
logs = []
- now = datetime.now(timezone('UTC')) # get current time, make it timezone-aware
+ now = datetime.now(timezone('UTC'))
for line in lines:
if line.strip() == "":
continue
- if line[0].isdigit():
- log_datetime_str = line.split(" - ")[0] # assuming the datetime is at the start of each line
- log_datetime = datetime.strptime(log_datetime_str, "%Y-%m-%d %H:%M:%S,%f") # parse the datetime string to a datetime object
- log_datetime = log_datetime.replace(tzinfo=timezone('UTC')) # set the datetime object timezone to UTC to match `now`
- if now - log_datetime <= timedelta(minutes=minutes): # if the log is within `minutes` minutes from now
- logs.append(highlight_rules_func(line.rstrip('\n'))) # add the highlighted log to the list and strip any newline at the end
- logs_as_string = "
".join(logs) # joining with
directly
+ try:
+ log_datetime_str = line.split(" - ")[0]
+ log_datetime = datetime.strptime(log_datetime_str, "%Y-%m-%d %H:%M:%S,%f")
+ log_datetime = log_datetime.replace(tzinfo=timezone('UTC'))
+ if now - log_datetime <= timedelta(minutes=minutes):
+ logs.append(highlight_rules_func(line.rstrip('\n')))
+ except (ValueError, IndexError):
+ logs.append(highlight_rules_func(line.rstrip('\n'))) # Line didn't meet datetime parsing criteria, continue with processing
+ logs_as_string = "
".join(logs)
    logs_as_string_newlines_rendered = logs_as_string.replace("\n", "<br>")
logs_as_string_newlines_rendered_font_specified = """
diff --git a/logger_config.py b/logger_config.py
new file mode 100644
index 0000000..a21fe5a
--- /dev/null
+++ b/logger_config.py
@@ -0,0 +1,35 @@
+import logging
+import os
+import shutil
+import queue
+from logging.handlers import RotatingFileHandler, QueueHandler, QueueListener
+
+logger = logging.getLogger("swiss_army_llama")
+
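+# All records funnel through a QueueHandler, so emitting a log line never blocks the caller on
+# file I/O; a QueueListener thread drains the queue into the rotating file handler and the stream
+# handler. Rotated log files are moved into old_logs/ by the custom namer/rotator hooks below.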
+def setup_logger():
+ if logger.handlers:
+ return logger
+ old_logs_dir = 'old_logs'
+ if not os.path.exists(old_logs_dir):
+ os.makedirs(old_logs_dir)
+ logger.setLevel(logging.INFO)
+ formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+ log_file_path = 'swiss_army_llama.log'
+ log_queue = queue.Queue(-1) # Create a queue for the handlers
+ fh = RotatingFileHandler(log_file_path, maxBytes=10*1024*1024, backupCount=5)
+ fh.setFormatter(formatter)
+ def namer(default_log_name): # Function to move rotated logs to the old_logs directory
+ return os.path.join(old_logs_dir, os.path.basename(default_log_name))
+ def rotator(source, dest):
+ shutil.move(source, dest)
+ fh.namer = namer
+ fh.rotator = rotator
+ sh = logging.StreamHandler() # Stream handler
+ sh.setFormatter(formatter)
+ queue_handler = QueueHandler(log_queue) # Create QueueHandler
+ queue_handler.setFormatter(formatter)
+ logger.addHandler(queue_handler)
+ listener = QueueListener(log_queue, fh, sh) # Create QueueListener with real handlers
+ listener.start()
+ logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) # Configure SQLalchemy logging
+ return logger
diff --git a/misc_utility_functions.py b/misc_utility_functions.py
new file mode 100644
index 0000000..7c73ffd
--- /dev/null
+++ b/misc_utility_functions.py
@@ -0,0 +1,215 @@
+from logger_config import setup_logger
+from database_functions import AsyncSessionLocal
+import socket
+import os
+import re
+import json
+import io
+import shutil
+import subprocess
+import numpy as np
+import psutil
+import faiss
+from typing import Any
+from collections import defaultdict
+from sqlalchemy import text as sql_text
+from decouple import config
+
+logger = setup_logger()
+RAMDISK_PATH = config("RAMDISK_PATH", default="/mnt/ramdisk", cast=str)
+RAMDISK_SIZE_IN_GB = config("RAMDISK_SIZE_IN_GB", default=1, cast=int)  # default size is an assumption; set RAMDISK_SIZE_IN_GB in .env
+
+def clean_filename_for_url_func(dirty_filename: str) -> str:
+ clean_filename = re.sub(r'[^\w\s]', '', dirty_filename) # Remove special characters and replace spaces with underscores
+ clean_filename = clean_filename.replace(' ', '_')
+ return clean_filename
+
+def is_redis_running(host='localhost', port=6379):
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ try:
+ s.connect((host, port))
+ return True
+ except ConnectionRefusedError:
+ return False
+ finally:
+ s.close()
+
+
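+# Cosine similarity is implemented as inner product over L2-normalized vectors: after
+# faiss.normalize_L2, IndexFlatIP scores are exact cosine similarities. A query must follow the
+# same recipe (illustrative sketch; variable names assumed):
+#   query = np.array([query_embedding], dtype='float32')
+#   faiss.normalize_L2(query)
+#   similarities, indices = faiss_indexes[llm_model_name].search(query, k=5)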
+async def build_faiss_indexes():
+ global faiss_indexes, token_faiss_indexes, associated_texts_by_model
+ if os.environ.get("FAISS_SETUP_DONE") == "1":
+ logger.info("Faiss indexes already built by another worker. Skipping.")
+ return faiss_indexes, token_faiss_indexes, associated_texts_by_model
+ faiss_indexes = {}
+ token_faiss_indexes = {} # Separate FAISS indexes for token-level embeddings
+ associated_texts_by_model = defaultdict(list) # Create a dictionary to store associated texts by model name
+ async with AsyncSessionLocal() as session:
+ result = await session.execute(sql_text("SELECT llm_model_name, text, embedding_json FROM embeddings")) # Query regular embeddings
+ token_result = await session.execute(sql_text("SELECT llm_model_name, token, token_level_embedding_json FROM token_level_embeddings")) # Query token-level embeddings
+ embeddings_by_model = defaultdict(list)
+ token_embeddings_by_model = defaultdict(list)
+ for row in result.fetchall(): # Process regular embeddings
+ llm_model_name = row[0]
+ associated_texts_by_model[llm_model_name].append(row[1]) # Store the associated text by model name
+ embeddings_by_model[llm_model_name].append((row[1], json.loads(row[2])))
+ for row in token_result.fetchall(): # Process token-level embeddings
+ llm_model_name = row[0]
+ token_embeddings_by_model[llm_model_name].append(json.loads(row[2]))
+ for llm_model_name, embeddings in embeddings_by_model.items():
+ logger.info(f"Building Faiss index over embeddings for model {llm_model_name}...")
+ embeddings_array = np.array([e[1] for e in embeddings]).astype('float32')
+ if embeddings_array.size == 0:
+ logger.error(f"No embeddings were loaded from the database for model {llm_model_name}, so nothing to build the Faiss index with!")
+ continue
+ logger.info(f"Loaded {len(embeddings_array)} embeddings for model {llm_model_name}.")
+ logger.info(f"Embedding dimension for model {llm_model_name}: {embeddings_array.shape[1]}")
+ logger.info(f"Normalizing {len(embeddings_array)} embeddings for model {llm_model_name}...")
+ faiss.normalize_L2(embeddings_array) # Normalize the vectors for cosine similarity
+ faiss_index = faiss.IndexFlatIP(embeddings_array.shape[1]) # Use IndexFlatIP for cosine similarity
+ faiss_index.add(embeddings_array)
+ logger.info(f"Faiss index built for model {llm_model_name}.")
+ faiss_indexes[llm_model_name] = faiss_index # Store the index by model name
+ for llm_model_name, token_embeddings in token_embeddings_by_model.items():
+ token_embeddings_array = np.array(token_embeddings).astype('float32')
+ if token_embeddings_array.size == 0:
+ logger.error(f"No token-level embeddings were loaded from the database for model {llm_model_name}, so nothing to build the Faiss index with!")
+ continue
+ logger.info(f"Normalizing {len(token_embeddings_array)} token-level embeddings for model {llm_model_name}...")
+ faiss.normalize_L2(token_embeddings_array) # Normalize the vectors for cosine similarity
+ token_faiss_index = faiss.IndexFlatIP(token_embeddings_array.shape[1]) # Use IndexFlatIP for cosine similarity
+ token_faiss_index.add(token_embeddings_array)
+ logger.info(f"Token-level Faiss index built for model {llm_model_name}.")
+ token_faiss_indexes[llm_model_name] = token_faiss_index # Store the token-level index by model name
+ os.environ["FAISS_SETUP_DONE"] = "1"
+ logger.info("Faiss indexes built.")
+ return faiss_indexes, token_faiss_indexes, associated_texts_by_model
+
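+# JSONAggregator fuses several JSON completions into one consensus object: each completion is
+# flattened to path->value pairs (e.g. {"a": {"b": 1}} becomes {"a->b": 1}), every completion
+# containing a path votes for its value there, and the winning value per path is written back
+# into a nested result via set_value_by_path.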
+class JSONAggregator:
+ def __init__(self):
+ self.completions = []
+ self.aggregate_result = None
+
+ @staticmethod
+ def weighted_vote(values, weights):
+ tally = defaultdict(float)
+ for v, w in zip(values, weights):
+ tally[v] += w
+ return max(tally, key=tally.get)
+
+ @staticmethod
+ def flatten_json(json_obj, parent_key='', sep='->'):
+ items = {}
+ for k, v in json_obj.items():
+ new_key = f"{parent_key}{sep}{k}" if parent_key else k
+ if isinstance(v, dict):
+ items.update(JSONAggregator.flatten_json(v, new_key, sep=sep))
+ else:
+ items[new_key] = v
+ return items
+
+ @staticmethod
+ def get_value_by_path(json_obj, path, sep='->'):
+ keys = path.split(sep)
+ item = json_obj
+ for k in keys:
+ item = item[k]
+ return item
+
+ @staticmethod
+ def set_value_by_path(json_obj, path, value, sep='->'):
+ keys = path.split(sep)
+ item = json_obj
+ for k in keys[:-1]:
+ item = item.setdefault(k, {})
+ item[keys[-1]] = value
+
+ def calculate_path_weights(self):
+ all_paths = []
+ for j in self.completions:
+ all_paths += list(self.flatten_json(j).keys())
+ path_weights = defaultdict(float)
+ for path in all_paths:
+ path_weights[path] += 1.0
+ return path_weights
+
+ def aggregate(self):
+ path_weights = self.calculate_path_weights()
+ aggregate = {}
+ for path, weight in path_weights.items():
+ values = [self.get_value_by_path(j, path) for j in self.completions if path in self.flatten_json(j)]
+ weights = [weight] * len(values)
+ aggregate_value = self.weighted_vote(values, weights)
+ self.set_value_by_path(aggregate, path, aggregate_value)
+ self.aggregate_result = aggregate
+
+class FakeUploadFile:
+ def __init__(self, filename: str, content: Any, content_type: str = 'text/plain'):
+ self.filename = filename
+ self.content_type = content_type
+ self.file = io.BytesIO(content)
+ def read(self, size: int = -1) -> bytes:
+ return self.file.read(size)
+ def seek(self, offset: int, whence: int = 0) -> int:
+ return self.file.seek(offset, whence)
+ def tell(self) -> int:
+ return self.file.tell()
+
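+# Min-max normalization of an average log-probability into [0, 1]: e.g. with min_logprob=-2.0 and
+# max_logprob=0.0, an avg_logprob of -0.5 maps to (-0.5 - (-2.0)) / 2.0 = 0.75. A degenerate range
+# falls back to 0.5.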
+def normalize_logprobs(avg_logprob, min_logprob, max_logprob):
+ range_logprob = max_logprob - min_logprob
+ return (avg_logprob - min_logprob) / range_logprob if range_logprob != 0 else 0.5
+
+def remove_pagination_breaks(text: str) -> str:
+ text = re.sub(r'-(\n)(?=[a-z])', '', text) # Remove hyphens at the end of lines when the word continues on the next line
+ text = re.sub(r'(?<=\w)(? total_ram_gb:
+ raise ValueError(f"Cannot allocate {RAMDISK_SIZE_IN_GB}G for RAM Disk. Total system RAM is {total_ram_gb:.2f}G.")
+ logger.info("Setting up RAM Disk...")
+ os.makedirs(RAMDISK_PATH, exist_ok=True)
+ mount_command = ["sudo", "mount", "-t", "tmpfs", "-o", f"size={ramdisk_size_str}", "tmpfs", RAMDISK_PATH]
+ subprocess.run(mount_command, check=True)
+ logger.info(f"RAM Disk set up at {RAMDISK_PATH} with size {ramdisk_size_gb}G")
+ os.environ["RAMDISK_SETUP_DONE"] = "1"
+
+def copy_models_to_ramdisk(models_directory, ramdisk_directory):
+ total_size = sum(os.path.getsize(os.path.join(models_directory, model)) for model in os.listdir(models_directory))
+ free_ram = psutil.virtual_memory().free
+ if total_size > free_ram:
+ logger.warning(f"Not enough space on RAM Disk. Required: {total_size}, Available: {free_ram}. Rebuilding RAM Disk.")
+ clear_ramdisk()
+ free_ram = psutil.virtual_memory().free # Recompute the available RAM after clearing the RAM disk
+ if total_size > free_ram:
+ logger.error(f"Still not enough space on RAM Disk even after clearing. Required: {total_size}, Available: {free_ram}.")
+ raise ValueError("Not enough RAM space to copy models.")
+ os.makedirs(ramdisk_directory, exist_ok=True)
+ for model in os.listdir(models_directory):
+ src_path = os.path.join(models_directory, model)
+ dest_path = os.path.join(ramdisk_directory, model)
+ if os.path.exists(dest_path) and os.path.getsize(dest_path) == os.path.getsize(src_path): # Check if the file already exists in the RAM disk and has the same size
+ logger.info(f"Model {model} already exists in RAM Disk and is the same size. Skipping copy.")
+ continue
+ shutil.copyfile(src_path, dest_path)
+ logger.info(f"Copied model {model} to RAM Disk at {dest_path}")
+
+def clear_ramdisk():
+ while True:
+ cmd_check = f"sudo mount | grep {RAMDISK_PATH}"
+ result = subprocess.run(cmd_check, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8')
+ if RAMDISK_PATH not in result:
+ break # Exit the loop if the RAMDISK_PATH is not in the mount list
+ cmd_umount = f"sudo umount -l {RAMDISK_PATH}"
+ subprocess.run(cmd_umount, shell=True, check=True)
+ logger.info(f"Cleared RAM Disk at {RAMDISK_PATH}")
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 2f2f90a..0403a1a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,3 +19,6 @@ faster-whisper
textract
pytest
pytz
+uvloop
+aioredis
+aioredlock
\ No newline at end of file
diff --git a/service_functions.py b/service_functions.py
new file mode 100644
index 0000000..553c7b6
--- /dev/null
+++ b/service_functions.py
@@ -0,0 +1,614 @@
+from logger_config import setup_logger
+import shared_resources
+from shared_resources import load_model, token_level_embedding_model_cache, text_completion_model_cache
+from database_functions import AsyncSessionLocal, DatabaseWriter, execute_with_retry
+from misc_utility_functions import clean_filename_for_url_func, FakeUploadFile, sophisticated_sentence_splitter, merge_transcript_segments_into_combined_text
+from embeddings_data_models import TextEmbedding, DocumentEmbedding, Document, TokenLevelEmbedding, TokenLevelEmbeddingBundleCombinedFeatureVector, AudioTranscript
+from embeddings_data_models import EmbeddingRequest, TextCompletionRequest
+from embeddings_data_models import TextCompletionResponse, AudioTranscriptResponse
+import os
+import shutil
+import psutil
+import glob
+import json
+import asyncio
+import zipfile
+import tempfile
+import time
+from datetime import datetime
+from hashlib import sha3_256
+from urllib.parse import quote
+import numpy as np
+import pandas as pd
+import textract
+from sqlalchemy import text as sql_text
+from sqlalchemy import select
+from fastapi import HTTPException, Request, UploadFile, File
+from fastapi.concurrency import run_in_threadpool
+from typing import List, Optional, Tuple
+from decouple import config
+from faster_whisper import WhisperModel
+from llama_cpp import Llama, LlamaGrammar
+
+logger = setup_logger()
+SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT = config("SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT", default=8089, cast=int)
+DEFAULT_MODEL_NAME = config("DEFAULT_MODEL_NAME", default="openchat_v3.2_super", cast=str)
+LLM_CONTEXT_SIZE_IN_TOKENS = config("LLM_CONTEXT_SIZE_IN_TOKENS", default=512, cast=int)
+TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS = config("TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS", default=4000, cast=int)
+DEFAULT_MAX_COMPLETION_TOKENS = config("DEFAULT_MAX_COMPLETION_TOKENS", default=100, cast=int)
+DEFAULT_NUMBER_OF_COMPLETIONS_TO_GENERATE = config("DEFAULT_NUMBER_OF_COMPLETIONS_TO_GENERATE", default=4, cast=int)
+DEFAULT_COMPLETION_TEMPERATURE = config("DEFAULT_COMPLETION_TEMPERATURE", default=0.7, cast=float)
+MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING = config("MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING", default=15, cast=int)
+USE_PARALLEL_INFERENCE_QUEUE = config("USE_PARALLEL_INFERENCE_QUEUE", default=False, cast=bool)
+MAX_CONCURRENT_PARALLEL_INFERENCE_TASKS = config("MAX_CONCURRENT_PARALLEL_INFERENCE_TASKS", default=10, cast=int)
+USE_RAMDISK = config("USE_RAMDISK", default=False, cast=bool)
+RAMDISK_PATH = config("RAMDISK_PATH", default="/mnt/ramdisk", cast=str)
+BASE_DIRECTORY = os.path.dirname(os.path.abspath(__file__))
+
+async def get_transcript_from_db(audio_file_hash: str):
+ return await execute_with_retry(_get_transcript_from_db, audio_file_hash)
+
+async def _get_transcript_from_db(audio_file_hash: str) -> Optional[dict]:
+ async with AsyncSessionLocal() as session:
+ result = await session.execute(
+ sql_text("SELECT * FROM audio_transcripts WHERE audio_file_hash=:audio_file_hash"),
+ {"audio_file_hash": audio_file_hash},
+ )
+ row = result.fetchone()
+ if row:
+ try:
+ segments_json = json.loads(row.segments_json)
+ combined_transcript_text_list_of_metadata_dicts = json.loads(row.combined_transcript_text_list_of_metadata_dicts)
+ info_json = json.loads(row.info_json)
+ if hasattr(info_json, '__dict__'):
+ info_json = vars(info_json)
+ except json.JSONDecodeError as e:
+ raise ValueError(f"JSON Decode Error: {e}")
+ if not isinstance(segments_json, list) or not isinstance(combined_transcript_text_list_of_metadata_dicts, list) or not isinstance(info_json, dict):
+ logger.error(f"Type of segments_json: {type(segments_json)}, Value: {segments_json}")
+ logger.error(f"Type of combined_transcript_text_list_of_metadata_dicts: {type(combined_transcript_text_list_of_metadata_dicts)}, Value: {combined_transcript_text_list_of_metadata_dicts}")
+ logger.error(f"Type of info_json: {type(info_json)}, Value: {info_json}")
+ raise ValueError("Deserialized JSON does not match the expected format.")
+ audio_transcript_response = {
+ "id": row.id,
+ "audio_file_name": row.audio_file_name,
+ "audio_file_size_mb": row.audio_file_size_mb,
+ "segments_json": segments_json,
+ "combined_transcript_text": row.combined_transcript_text,
+ "combined_transcript_text_list_of_metadata_dicts": combined_transcript_text_list_of_metadata_dicts,
+ "info_json": info_json,
+ "ip_address": row.ip_address,
+ "request_time": row.request_time,
+ "response_time": row.response_time,
+ "total_time": row.total_time,
+ "url_to_download_zip_file_of_embeddings": ""
+ }
+ return AudioTranscriptResponse(**audio_transcript_response)
+ return None
+
+async def save_transcript_to_db(audio_file_hash, audio_file_name, audio_file_size_mb, transcript_segments, info, ip_address, request_time, response_time, total_time, combined_transcript_text, combined_transcript_text_list_of_metadata_dicts):
+ existing_transcript = await get_transcript_from_db(audio_file_hash)
+ if existing_transcript:
+ return existing_transcript
+ audio_transcript = AudioTranscript(
+ audio_file_hash=audio_file_hash,
+ audio_file_name=audio_file_name,
+ audio_file_size_mb=audio_file_size_mb,
+ segments_json=json.dumps(transcript_segments),
+ combined_transcript_text=combined_transcript_text,
+ combined_transcript_text_list_of_metadata_dicts=json.dumps(combined_transcript_text_list_of_metadata_dicts),
+ info_json=json.dumps(info),
+ ip_address=ip_address,
+ request_time=request_time,
+ response_time=response_time,
+ total_time=total_time
+ )
+ await shared_resources.db_writer.enqueue_write([audio_transcript])
+
+
+async def compute_and_store_transcript_embeddings(audio_file_name, list_of_transcript_sentences, llm_model_name, ip_address, combined_transcript_text, req: Request):
+ logger.info(f"Now computing embeddings for entire transcript of {audio_file_name}...")
+ zip_dir = 'generated_transcript_embeddings_zip_files'
+ if not os.path.exists(zip_dir):
+ os.makedirs(zip_dir)
+ sanitized_file_name = clean_filename_for_url_func(audio_file_name)
+ document_name = f"automatic_whisper_transcript_of__{sanitized_file_name}"
+ file_hash = sha3_256(combined_transcript_text.encode('utf-8')).hexdigest()
+ computed_embeddings = await compute_embeddings_for_document(list_of_transcript_sentences, llm_model_name, ip_address, file_hash)
+ zip_file_path = f"{zip_dir}/{quote(document_name)}.zip"
+ with zipfile.ZipFile(zip_file_path, 'w') as zipf:
+ zipf.writestr("embeddings.txt", json.dumps(computed_embeddings))
+ download_url = f"download/{quote(document_name)}.zip"
+ full_download_url = f"{req.base_url}{download_url}"
+ logger.info(f"Generated download URL for transcript embeddings: {full_download_url}")
+ fake_upload_file = FakeUploadFile(filename=document_name, content=combined_transcript_text.encode(), content_type='text/plain')
+ logger.info(f"Storing transcript embeddings for {audio_file_name} in the database...")
+ await store_document_embeddings_in_db(fake_upload_file, file_hash, combined_transcript_text.encode(), json.dumps(computed_embeddings).encode(), computed_embeddings, llm_model_name, ip_address, datetime.utcnow())
+ return full_download_url
+
+async def compute_transcript_with_whisper_from_audio_func(audio_file_hash, audio_file_path, audio_file_name, audio_file_size_mb, ip_address, req: Request, compute_embeddings_for_resulting_transcript_document=True, llm_model_name=DEFAULT_MODEL_NAME):
+ model_size = "large-v2"
+ logger.info(f"Loading Whisper model {model_size}...")
+ num_workers = 1 if psutil.virtual_memory().total < 32 * (1024 ** 3) else min(4, max(1, int((psutil.virtual_memory().total - 32 * (1024 ** 3)) / (4 * (1024 ** 3))))) # Only use more than 1 worker if there is at least 32GB of RAM; then use 1 worker per additional 4GB of RAM up to 4 workers max
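+    # e.g. a 64GB machine gets min(4, int(32/4)) = 4 workers, a 40GB machine gets 2, and anything below 32GB gets 1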
+ model = await run_in_threadpool(WhisperModel, model_size, device="cpu", compute_type="auto", cpu_threads=os.cpu_count(), num_workers=num_workers)
+ request_time = datetime.utcnow()
+ logger.info(f"Computing transcript for {audio_file_name} which has a {audio_file_size_mb :.2f}MB file size...")
+ segments, info = await run_in_threadpool(model.transcribe, audio_file_path, beam_size=20)
+ if not segments:
+ logger.warning(f"No segments were returned for file {audio_file_name}.")
+ return [], {}, "", [], request_time, datetime.utcnow(), 0, ""
+ segment_details = []
+ for idx, segment in enumerate(segments):
+ details = {
+ "start": round(segment.start, 2),
+ "end": round(segment.end, 2),
+ "text": segment.text,
+ "avg_logprob": round(segment.avg_logprob, 2)
+ }
+ logger.info(f"Details of transcript segment {idx} from file {audio_file_name}: {details}")
+ segment_details.append(details)
+ combined_transcript_text, combined_transcript_text_list_of_metadata_dicts, list_of_transcript_sentences = merge_transcript_segments_into_combined_text(segment_details)
+ if compute_embeddings_for_resulting_transcript_document:
+ download_url = await compute_and_store_transcript_embeddings(audio_file_name, list_of_transcript_sentences, llm_model_name, ip_address, combined_transcript_text, req)
+ else:
+ download_url = ''
+ response_time = datetime.utcnow()
+ total_time = (response_time - request_time).total_seconds()
+ logger.info(f"Transcript computed in {total_time} seconds.")
+ await save_transcript_to_db(audio_file_hash, audio_file_name, audio_file_size_mb, segment_details, info, ip_address, request_time, response_time, total_time, combined_transcript_text, combined_transcript_text_list_of_metadata_dicts)
+ info_dict = info._asdict()
+ return segment_details, info_dict, combined_transcript_text, combined_transcript_text_list_of_metadata_dicts, request_time, response_time, total_time, download_url
+
+async def get_or_compute_transcript(file: UploadFile, compute_embeddings_for_resulting_transcript_document: bool, llm_model_name: str, req: Request = None) -> dict:
+ request_time = datetime.utcnow()
+ ip_address = req.client.host if req else "127.0.0.1"
+ file_contents = await file.read()
+ audio_file_hash = sha3_256(file_contents).hexdigest()
+ file.file.seek(0) # Reset file pointer after read
+ unique_id = f"transcript_{audio_file_hash}_{llm_model_name}"
+ lock = await shared_resources.lock_manager.lock(unique_id)
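+    # Each Uvicorn worker is a separate process, so the in-memory dedup set alone cannot stop two
+    # workers from transcribing the same file; the Redis-backed lock (aioredlock) makes the
+    # (audio hash, model) pair exclusive across workers, and losers fall through to "already processing".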
+ if lock.valid:
+ try:
+ existing_audio_transcript = await get_transcript_from_db(audio_file_hash)
+ if existing_audio_transcript:
+ return existing_audio_transcript
+ current_position = file.file.tell()
+ file.file.seek(0, os.SEEK_END)
+ audio_file_size_mb = file.file.tell() / (1024 * 1024)
+ file.file.seek(current_position)
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
+ shutil.copyfileobj(file.file, tmp_file)
+ audio_file_name = tmp_file.name
+ segment_details, info, combined_transcript_text, combined_transcript_text_list_of_metadata_dicts, request_time, response_time, total_time, download_url = await compute_transcript_with_whisper_from_audio_func(audio_file_hash, audio_file_name, file.filename, audio_file_size_mb, ip_address, req, compute_embeddings_for_resulting_transcript_document, llm_model_name)
+ audio_transcript_response = {
+ "audio_file_hash": audio_file_hash,
+ "audio_file_name": file.filename,
+ "audio_file_size_mb": audio_file_size_mb,
+ "segments_json": segment_details,
+ "combined_transcript_text": combined_transcript_text,
+ "combined_transcript_text_list_of_metadata_dicts": combined_transcript_text_list_of_metadata_dicts,
+ "info_json": info,
+ "ip_address": ip_address,
+ "request_time": request_time,
+ "response_time": response_time,
+ "total_time": total_time,
+ "url_to_download_zip_file_of_embeddings": download_url if compute_embeddings_for_resulting_transcript_document else ""
+ }
+ os.remove(audio_file_name)
+ return AudioTranscriptResponse(**audio_transcript_response)
+ finally:
+ await shared_resources.lock_manager.unlock(lock)
+ else:
+ return {"status": "already processing"}
+
+
+# Core embedding functions start here:
+
+
+def add_model_url(new_url: str) -> str:
+ corrected_url = new_url
+ if '/blob/main/' in new_url:
+ corrected_url = new_url.replace('/blob/main/', '/resolve/main/')
+ json_path = os.path.join(BASE_DIRECTORY, "model_urls.json")
+ with open(json_path, "r") as f:
+ existing_urls = json.load(f)
+ if corrected_url not in existing_urls:
+ logger.info(f"Model URL not found in database. Adding {new_url} now...")
+ existing_urls.append(corrected_url)
+ with open(json_path, "w") as f:
+ json.dump(existing_urls, f)
+ logger.info(f"Model URL added: {new_url}")
+ else:
+ logger.info("Model URL already exists.")
+ return corrected_url
+
+async def get_embedding_from_db(text: str, llm_model_name: str):
+ text_hash = sha3_256(text.encode('utf-8')).hexdigest() # Compute the hash
+ return await execute_with_retry(_get_embedding_from_db, text_hash, llm_model_name)
+
+async def _get_embedding_from_db(text_hash: str, llm_model_name: str) -> Optional[dict]:
+ async with AsyncSessionLocal() as session:
+ result = await session.execute(
+ sql_text("SELECT embedding_json FROM embeddings WHERE text_hash=:text_hash AND llm_model_name=:llm_model_name"),
+ {"text_hash": text_hash, "llm_model_name": llm_model_name},
+ )
+ row = result.fetchone()
+ if row:
+ embedding_json = row[0]
+ logger.info(f"Embedding found in database for text hash '{text_hash}' using model '{llm_model_name}'")
+ return json.loads(embedding_json)
+ return None
+
+async def get_or_compute_embedding(request: EmbeddingRequest, req: Request = None, client_ip: str = None, document_file_hash: str = None) -> dict:
+ request_time = datetime.utcnow() # Capture request time as datetime object
+ ip_address = client_ip or (req.client.host if req else "localhost") # If client_ip is provided, use it; otherwise, try to get from req; if not available, default to "localhost"
+ logger.info(f"Received request for embedding for '{request.text}' using model '{request.llm_model_name}' from IP address '{ip_address}'")
+ embedding_list = await get_embedding_from_db(request.text, request.llm_model_name) # Check if embedding exists in the database
+ if embedding_list is not None:
+ response_time = datetime.utcnow() # Capture response time as datetime object
+ total_time = (response_time - request_time).total_seconds() # Calculate time taken in seconds
+ logger.info(f"Embedding found in database for '{request.text}' using model '{request.llm_model_name}'; returning in {total_time:.4f} seconds")
+ return {"embedding": embedding_list}
+ model = load_model(request.llm_model_name)
+ embedding_list = calculate_sentence_embedding(model, request.text) # Compute the embedding if not in the database
+ if embedding_list is None:
+ logger.error(f"Could not calculate the embedding for the given text: '{request.text}' using model '{request.llm_model_name}!'")
+ raise HTTPException(status_code=400, detail="Could not calculate the embedding for the given text")
+ embedding_json = json.dumps(embedding_list) # Serialize the numpy array to JSON and save to the database
+ response_time = datetime.utcnow() # Capture response time as datetime object
+ total_time = (response_time - request_time).total_seconds() # Calculate total time using datetime objects
+ word_length_of_input_text = len(request.text.split())
+ if word_length_of_input_text > 0:
+ logger.info(f"Embedding calculated for '{request.text}' using model '{request.llm_model_name}' in {total_time} seconds, or an average of {total_time/word_length_of_input_text :.2f} seconds per word. Now saving to database...")
+ await save_embedding_to_db(request.text, request.llm_model_name, embedding_json, ip_address, request_time, response_time, total_time, document_file_hash)
+ return {"embedding": embedding_list}
+
+async def save_embedding_to_db(text: str, llm_model_name: str, embedding_json: str, ip_address: str, request_time: datetime, response_time: datetime, total_time: float, document_file_hash: str = None):
+ existing_embedding = await get_embedding_from_db(text, llm_model_name) # Check if the embedding already exists
+ if existing_embedding is not None:
+ return existing_embedding
+ return await execute_with_retry(_save_embedding_to_db, text, llm_model_name, embedding_json, ip_address, request_time, response_time, total_time, document_file_hash)
+
+async def _save_embedding_to_db(text: str, llm_model_name: str, embedding_json: str, ip_address: str, request_time: datetime, response_time: datetime, total_time: float, document_file_hash: str = None):
+ existing_embedding = await get_embedding_from_db(text, llm_model_name)
+ if existing_embedding:
+ return existing_embedding
+ embedding = TextEmbedding(
+ text=text,
+ llm_model_name=llm_model_name,
+ embedding_json=embedding_json,
+ ip_address=ip_address,
+ request_time=request_time,
+ response_time=response_time,
+ total_time=total_time,
+ document_file_hash=document_file_hash
+ )
+ await shared_resources.db_writer.enqueue_write([embedding]) # Enqueue the write operation using the db_writer instance
+
+
+def load_token_level_embedding_model(llm_model_name: str, raise_http_exception: bool = True):
+ try:
+ if llm_model_name in token_level_embedding_model_cache: # Check if the model is already loaded in the cache
+ return token_level_embedding_model_cache[llm_model_name]
+ models_dir = os.path.join(RAMDISK_PATH, 'models') if USE_RAMDISK else os.path.join(BASE_DIRECTORY, 'models') # Determine the model directory path
+ matching_files = glob.glob(os.path.join(models_dir, f"{llm_model_name}*")) # Search for matching model files
+ if not matching_files:
+ logger.error(f"No model file found matching: {llm_model_name}")
+ raise FileNotFoundError
+ matching_files.sort(key=os.path.getmtime, reverse=True) # Sort the files based on modification time (recently modified files first)
+ model_file_path = matching_files[0]
+ model_instance = Llama(model_path=model_file_path, embedding=True, n_ctx=LLM_CONTEXT_SIZE_IN_TOKENS, verbose=False) # Load the model
+ token_level_embedding_model_cache[llm_model_name] = model_instance # Cache the loaded model
+ return model_instance
+ except TypeError as e:
+ logger.error(f"TypeError occurred while loading the model: {e}")
+ raise
+ except Exception as e:
+ logger.error(f"Exception occurred while loading the model: {e}")
+ if raise_http_exception:
+ raise HTTPException(status_code=404, detail="Model file not found")
+ else:
+ raise FileNotFoundError(f"No model file found matching: {llm_model_name}")
+
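+# The combined feature vector stacks the column-wise mean, min, max, and std of the token
+# embeddings, so a model with embedding dimension d always yields a fixed-length 4*d vector,
+# independent of how many tokens the input text contains.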
+async def compute_token_level_embedding_bundle_combined_feature_vector(token_level_embeddings) -> List[float]:
+ start_time = datetime.utcnow()
+ logger.info("Extracting token-level embeddings from the bundle")
+ parsed_df = pd.read_json(token_level_embeddings) # Parse the json_content back to a DataFrame
+ token_level_embeddings = list(parsed_df['embedding'])
+ embeddings = np.array(token_level_embeddings) # Convert the list of embeddings to a NumPy array
+ logger.info(f"Computing column-wise means/mins/maxes/std_devs of the embeddings... (shape: {embeddings.shape})")
+ assert(len(embeddings) > 0)
+ means = np.mean(embeddings, axis=0)
+ mins = np.min(embeddings, axis=0)
+ maxes = np.max(embeddings, axis=0)
+ stds = np.std(embeddings, axis=0)
+ logger.info("Concatenating the computed statistics to form the combined feature vector")
+ combined_feature_vector = np.concatenate([means, mins, maxes, stds])
+ end_time = datetime.utcnow()
+ total_time = (end_time - start_time).total_seconds()
+ logger.info(f"Computed the token-level embedding bundle's combined feature vector computed in {total_time: .2f} seconds.")
+ return combined_feature_vector.tolist()
+
+
+async def get_or_compute_token_level_embedding_bundle_combined_feature_vector(token_level_embedding_bundle_id, token_level_embeddings, db_writer: DatabaseWriter) -> List[float]:
+    request_time = datetime.utcnow()
+ logger.info(f"Checking for existing combined feature vector for token-level embedding bundle ID: {token_level_embedding_bundle_id}")
+ async with AsyncSessionLocal() as session:
+ result = await session.execute(
+ select(TokenLevelEmbeddingBundleCombinedFeatureVector)
+ .filter(TokenLevelEmbeddingBundleCombinedFeatureVector.token_level_embedding_bundle_id == token_level_embedding_bundle_id)
+ )
+ existing_combined_feature_vector = result.scalar_one_or_none()
+ if existing_combined_feature_vector:
+ response_time = datetime.utcnow()
+ total_time = (response_time - request_time).total_seconds()
+ logger.info(f"Found existing combined feature vector for token-level embedding bundle ID: {token_level_embedding_bundle_id}. Returning cached result in {total_time:.2f} seconds.")
+ return json.loads(existing_combined_feature_vector.combined_feature_vector_json) # Parse the JSON string into a list
+ logger.info(f"No cached combined feature_vector found for token-level embedding bundle ID: {token_level_embedding_bundle_id}. Computing now...")
+ combined_feature_vector = await compute_token_level_embedding_bundle_combined_feature_vector(token_level_embeddings)
+ combined_feature_vector_db_object = TokenLevelEmbeddingBundleCombinedFeatureVector(
+ token_level_embedding_bundle_id=token_level_embedding_bundle_id,
+ combined_feature_vector_json=json.dumps(combined_feature_vector) # Convert the list to a JSON string
+ )
+ logger.info(f"Writing combined feature vector for database write for token-level embedding bundle ID: {token_level_embedding_bundle_id} to the database...")
+ await db_writer.enqueue_write([combined_feature_vector_db_object])
+ return combined_feature_vector
+
+
+async def calculate_token_level_embeddings(text: str, llm_model_name: str, client_ip: str, token_level_embedding_bundle_id: int) -> List[np.array]:
+ request_time = datetime.utcnow()
+ logger.info(f"Starting token-level embedding calculation for text: '{text}' using model: '{llm_model_name}'")
+ logger.info(f"Loading model: '{llm_model_name}'")
+ llm = load_token_level_embedding_model(llm_model_name) # Assuming this method returns an instance of the Llama class
+ token_embeddings = []
+ tokens = text.split() # Simple whitespace tokenizer; can be replaced with a more advanced one if needed
+ logger.info(f"Tokenized text into {len(tokens)} tokens")
+ for idx, token in enumerate(tokens, start=1):
+ try: # Check if the embedding is already available in the database
+ existing_embedding = await get_token_level_embedding_from_db(token, llm_model_name)
+ if existing_embedding is not None:
+ token_embeddings.append(np.array(existing_embedding))
+ logger.info(f"Embedding retrieved from database for token '{token}'")
+ continue
+ logger.info(f"Processing token {idx} of {len(tokens)}: '{token}'")
+ token_embedding = llm.embed(token)
+ token_embedding_array = np.array(token_embedding)
+ token_embeddings.append(token_embedding_array)
+ response_time = datetime.utcnow()
+ token_level_embedding_json = json.dumps(token_embedding_array.tolist())
+ await store_token_level_embeddings_in_db(token, llm_model_name, token_level_embedding_json, client_ip, request_time, response_time, token_level_embedding_bundle_id)
+ except RuntimeError as e:
+ logger.error(f"Failed to calculate embedding for token '{token}': {e}")
+ logger.info(f"Completed token embedding calculation for all tokens in text: '{text}'")
+ return token_embeddings
+
+async def get_token_level_embedding_from_db(token: str, llm_model_name: str) -> Optional[List[float]]:
+ token_hash = sha3_256(token.encode('utf-8')).hexdigest() # Compute the hash
+ async with AsyncSessionLocal() as session:
+ result = await session.execute(
+ sql_text("SELECT token_level_embedding_json FROM token_level_embeddings WHERE token_hash=:token_hash AND llm_model_name=:llm_model_name"),
+ {"token_hash": token_hash, "llm_model_name": llm_model_name},
+ )
+ row = result.fetchone()
+ if row:
+ embedding_json = row[0]
+ logger.info(f"Embedding found in database for token hash '{token_hash}' using model '{llm_model_name}'")
+ return json.loads(embedding_json)
+ return None
+
+async def store_token_level_embeddings_in_db(token: str, llm_model_name: str, token_level_embedding_json: str, ip_address: str, request_time: datetime, response_time: datetime, token_level_embedding_bundle_id: int):
+ total_time = (response_time - request_time).total_seconds()
+ embedding = TokenLevelEmbedding(
+ token=token,
+ llm_model_name=llm_model_name,
+ token_level_embedding_json=token_level_embedding_json,
+ ip_address=ip_address,
+ request_time=request_time,
+ response_time=response_time,
+ total_time=total_time,
+ token_level_embedding_bundle_id=token_level_embedding_bundle_id
+ )
+ await shared_resources.db_writer.enqueue_write([embedding]) # Enqueue the write operation for the token-level embedding
+
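+# If embedding fails (typically because the text overflows the model's context window), the text
+# is trimmed by 10% and retried, up to 3 attempts, before giving up and returning None.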
+def calculate_sentence_embedding(llama: Llama, text: str) -> np.array:
+ sentence_embedding = None
+ retry_count = 0
+ while sentence_embedding is None and retry_count < 3:
+ try:
+ if retry_count > 0:
+ logger.info(f"Attempting again calculate sentence embedding. Attempt number {retry_count + 1}")
+ sentence_embedding = llama.embed_query(text)
+ except TypeError as e:
+ logger.error(f"TypeError in calculate_sentence_embedding: {e}")
+ raise
+ except Exception as e:
+ logger.error(f"Exception in calculate_sentence_embedding: {e}")
+ text = text[:-int(len(text) * 0.1)]
+ retry_count += 1
+ logger.info(f"Trimming sentence due to too many tokens. New length: {len(text)}")
+ if sentence_embedding is None:
+ logger.error("Failed to calculate sentence embedding after multiple attempts")
+ return sentence_embedding
+
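+# Parallel mode fans out one coroutine per string with asyncio.gather while a Semaphore caps
+# concurrent inference at MAX_CONCURRENT_PARALLEL_INFERENCE_TASKS; failures degrade to a None
+# embedding for that string and are filtered out of the returned results.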
+async def compute_embeddings_for_document(strings: list, llm_model_name: str, client_ip: str, document_file_hash: str) -> List[Tuple[str, np.array]]:
+ from swiss_army_llama import get_embedding_vector_for_string
+ results = []
+ if USE_PARALLEL_INFERENCE_QUEUE:
+ logger.info(f"Using parallel inference queue to compute embeddings for {len(strings)} strings")
+ start_time = time.perf_counter() # Record the start time
+ semaphore = asyncio.Semaphore(MAX_CONCURRENT_PARALLEL_INFERENCE_TASKS)
+ async def compute_embedding(text): # Define a function to compute the embedding for a given text
+ try:
+ async with semaphore: # Acquire a semaphore slot
+ request = EmbeddingRequest(text=text, llm_model_name=llm_model_name)
+ embedding = await get_embedding_vector_for_string(request, client_ip=client_ip, document_file_hash=document_file_hash)
+ return text, embedding["embedding"]
+ except Exception as e:
+ logger.error(f"Error computing embedding for text '{text}': {e}")
+ return text, None
+ results = await asyncio.gather(*[compute_embedding(s) for s in strings]) # Use asyncio.gather to run the tasks concurrently
+ end_time = time.perf_counter() # Record the end time
+ duration = end_time - start_time
+ if len(strings) > 0:
+ logger.info(f"Parallel inference task for {len(strings)} strings completed in {duration:.2f} seconds; {duration / len(strings):.2f} seconds per string")
+ else: # Compute embeddings sequentially
+ logger.info(f"Using sequential inference to compute embeddings for {len(strings)} strings")
+ start_time = time.perf_counter() # Record the start time
+ for s in strings:
+ embedding_request = EmbeddingRequest(text=s, llm_model_name=llm_model_name)
+ embedding = await get_embedding_vector_for_string(embedding_request, client_ip=client_ip, document_file_hash=document_file_hash)
+ results.append((s, embedding["embedding"]))
+ end_time = time.perf_counter() # Record the end time
+ duration = end_time - start_time
+ if len(strings) > 0:
+ logger.info(f"Sequential inference task for {len(strings)} strings completed in {duration:.2f} seconds; {duration / len(strings):.2f} seconds per string")
+ filtered_results = [(text, embedding) for text, embedding in results if embedding is not None] # Filter out results with None embeddings (applicable to parallel processing) and return
+ return filtered_results
+
+async def parse_submitted_document_file_into_sentence_strings_func(temp_file_path: str, mime_type: str):
+ strings = []
+ if mime_type.startswith('text/'):
+ with open(temp_file_path, 'r') as buffer:
+ content = buffer.read()
+ else:
+ try:
+ content = textract.process(temp_file_path).decode('utf-8')
+ except UnicodeDecodeError:
+ try:
+ content = textract.process(temp_file_path).decode('unicode_escape')
+ except Exception as e:
+ logger.error(f"Error while processing file: {e}, mime_type: {mime_type}")
+ raise HTTPException(status_code=400, detail=f"Unsupported file type or error: {e}")
+ except Exception as e:
+ logger.error(f"Error while processing file: {e}, mime_type: {mime_type}")
+ raise HTTPException(status_code=400, detail=f"Unsupported file type or error: {e}")
+ sentences = sophisticated_sentence_splitter(content)
+ if len(sentences) == 0 and temp_file_path.lower().endswith('.pdf'):
+ logger.info("No sentences found, attempting OCR using Tesseract.")
+ try:
+ content = textract.process(temp_file_path, method='tesseract').decode('utf-8')
+ sentences = sophisticated_sentence_splitter(content)
+ except Exception as e:
+ logger.error(f"Error while processing file with OCR: {e}")
+ raise HTTPException(status_code=400, detail=f"OCR failed: {e}")
+ if len(sentences) == 0:
+ logger.info("No sentences found in the document")
+ raise HTTPException(status_code=400, detail="No sentences found in the document")
+ logger.info(f"Extracted {len(sentences)} sentences from the document")
+ strings = [s.strip() for s in sentences if len(s.strip()) > MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING]
+ return strings
+
+async def _get_document_from_db(file_hash: str):
+ async with AsyncSessionLocal() as session:
+ result = await session.execute(select(Document).filter(Document.document_hash == file_hash))
+ return result.scalar_one_or_none()
+
+async def store_document_embeddings_in_db(file: UploadFile, file_hash: str, original_file_content: bytes, json_content: bytes, results: List[Tuple[str, np.array]], llm_model_name: str, client_ip: str, request_time: datetime):
+ document = await _get_document_from_db(file_hash) # First, check if a Document with the same hash already exists
+ if not document: # If not, create a new Document object
+ document = Document(document_hash=file_hash, llm_model_name=llm_model_name)
+ await shared_resources.db_writer.enqueue_write([document])
+ document_embedding = DocumentEmbedding(
+ filename=file.filename,
+ mimetype=file.content_type,
+ file_hash=file_hash,
+ llm_model_name=llm_model_name,
+ file_data=original_file_content,
+ document_embedding_results_json=json.loads(json_content.decode()),
+ ip_address=client_ip,
+ request_time=request_time,
+ response_time=datetime.utcnow(),
+ total_time=(datetime.utcnow() - request_time).total_seconds()
+ )
+ document.document_embeddings.append(document_embedding) # Associate it with the Document
+ document.update_hash() # This will trigger the SQLAlchemy event to update the document_hash
+ await shared_resources.db_writer.enqueue_write([document, document_embedding]) # Enqueue the write operation for the document embedding
+ write_operations = [] # Collect text embeddings to write
+ logger.info(f"Storing {len(results)} text embeddings in database")
+ for text, embedding in results:
+ embedding_entry = await _get_embedding_from_db(text, llm_model_name)
+ if not embedding_entry:
+ embedding_entry = TextEmbedding(
+ text=text,
+ llm_model_name=llm_model_name,
+ embedding_json=json.dumps(embedding),
+ ip_address=client_ip,
+ request_time=request_time,
+ response_time=datetime.utcnow(),
+ total_time=(datetime.utcnow() - request_time).total_seconds(),
+ document_file_hash=file_hash # Link it to the DocumentEmbedding via file_hash
+ )
+            write_operations.append(embedding_entry)
+ await shared_resources.db_writer.enqueue_write(write_operations) # Enqueue the write operation for text embeddings
+
+def load_text_completion_model(llm_model_name: str, raise_http_exception: bool = True):
+ try:
+ if llm_model_name in text_completion_model_cache: # Check if the model is already loaded in the cache
+ return text_completion_model_cache[llm_model_name]
+ models_dir = os.path.join(RAMDISK_PATH, 'models') if USE_RAMDISK else os.path.join(BASE_DIRECTORY, 'models') # Determine the model directory path
+ matching_files = glob.glob(os.path.join(models_dir, f"{llm_model_name}*")) # Search for matching model files
+ if not matching_files:
+ logger.error(f"No model file found matching: {llm_model_name}")
+ raise FileNotFoundError
+ matching_files.sort(key=os.path.getmtime, reverse=True) # Sort the files based on modification time (recently modified files first)
+ model_file_path = matching_files[0]
+ model_instance = Llama(model_path=model_file_path, embedding=True, n_ctx=TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS, verbose=False) # Load the model
+ text_completion_model_cache[llm_model_name] = model_instance # Cache the loaded model
+ return model_instance
+ except TypeError as e:
+ logger.error(f"TypeError occurred while loading the model: {e}")
+ raise
+ except Exception as e:
+ logger.error(f"Exception occurred while loading the model: {e}")
+ if raise_http_exception:
+ raise HTTPException(status_code=404, detail="Model file not found")
+ else:
+ raise FileNotFoundError(f"No model file found matching: {llm_model_name}")
+
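+# Run the requested number of completions against the loaded model, optionally constrained by a
+# GBNF grammar file matched by name from ./grammar_files.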
+async def generate_completion_from_llm(request: TextCompletionRequest, req: Request = None, client_ip: str = None) -> List[TextCompletionResponse]:
+ request_time = datetime.utcnow()
+ logger.info(f"Starting text completion calculation using model: '{request.llm_model_name}'for input prompt: '{request.input_prompt}'")
+ logger.info(f"Loading model: '{request.llm_model_name}'")
+ llm = load_text_completion_model(request.llm_model_name)
+ logger.info(f"Done loading model: '{request.llm_model_name}'")
+ list_of_llm_outputs = []
+ grammar_file_string_lower = request.grammar_file_string.lower() if request.grammar_file_string else ""
+ if grammar_file_string_lower:
+ list_of_grammar_files = glob.glob("./grammar_files/*.gbnf")
+ matching_grammar_files = [x for x in list_of_grammar_files if grammar_file_string_lower in os.path.splitext(os.path.basename(x).lower())[0]]
+ if len(matching_grammar_files) == 0:
+ logger.error(f"No grammar file found matching: {request.grammar_file_string}")
+            raise FileNotFoundError(f"No grammar file found matching: {request.grammar_file_string}")
+ matching_grammar_files.sort(key=os.path.getmtime, reverse=True)
+ grammar_file_path = matching_grammar_files[0]
+ logger.info(f"Loading selected grammar file: '{grammar_file_path}'")
+ llama_grammar = LlamaGrammar.from_file(grammar_file_path)
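+        # Passing the compiled grammar into each llm() call in the loop below constrains sampling so
+        # the output stays valid under the grammar (e.g. well-formed JSON when a 'json' grammar is selected)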
+ for ii in range(request.number_of_completions_to_generate):
+ logger.info(f"Generating completion {ii+1} of {request.number_of_completions_to_generate} with model {request.llm_model_name} for input prompt: '{request.input_prompt}'")
+ output = llm(prompt=request.input_prompt, grammar=llama_grammar, max_tokens=request.number_of_tokens_to_generate, temperature=request.temperature)
+ list_of_llm_outputs.append(output)
+ else:
+ for ii in range(request.number_of_completions_to_generate):
+ output = llm(prompt=request.input_prompt, max_tokens=request.number_of_tokens_to_generate, temperature=request.temperature)
+ list_of_llm_outputs.append(output)
+ response_time = datetime.utcnow()
+ total_time_per_completion = ((response_time - request_time).total_seconds()) / request.number_of_completions_to_generate
+ list_of_responses = []
+ for idx, current_completion_output in enumerate(list_of_llm_outputs):
+ generated_text = current_completion_output['choices'][0]['text']
+ if request.grammar_file_string == 'json':
+ generated_text = generated_text.encode('unicode_escape').decode()
+ llm_model_usage_json = json.dumps(current_completion_output['usage'])
+ logger.info(f"Completed text completion {idx} in an average of {total_time_per_completion:.2f} seconds for input prompt: '{request.input_prompt}'; Beginning of generated text: \n'{generated_text[:100]}'")
+ response = TextCompletionResponse(input_prompt = request.input_prompt,
+ llm_model_name = request.llm_model_name,
+ grammar_file_string = request.grammar_file_string,
+ number_of_tokens_to_generate = request.number_of_tokens_to_generate,
+ number_of_completions_to_generate = request.number_of_completions_to_generate,
+ time_taken_in_seconds = float(total_time_per_completion),
+ generated_text = generated_text,
+ llm_model_usage_json = llm_model_usage_json)
+ list_of_responses.append(response)
+ return list_of_responses
diff --git a/shared_resources.py b/shared_resources.py
new file mode 100644
index 0000000..eee389b
--- /dev/null
+++ b/shared_resources.py
@@ -0,0 +1,149 @@
+from misc_utility_functions import is_redis_running, build_faiss_indexes
+from database_functions import DatabaseWriter, initialize_db
+from ramdisk_functions import setup_ramdisk, copy_models_to_ramdisk, check_that_user_has_required_permissions_to_manage_ramdisks
+from logger_config import setup_logger
+from aioredlock import Aioredlock
+import aioredis
+import asyncio
+import subprocess
+import urllib.request
+import os
+import glob
+import json
+from typing import List, Tuple, Dict
+from langchain.embeddings import LlamaCppEmbeddings
+from decouple import config
+from fastapi import HTTPException
+
+logger = setup_logger()
+embedding_model_cache = {} # Model cache to store loaded models
+token_level_embedding_model_cache = {} # Model cache to store loaded token-level embedding models
+text_completion_model_cache = {} # Model cache to store loaded text completion models
+
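+# Configuration values below are read via python-decouple, falling back to the defaults shown here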
+SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT = config("SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT", default=8089, cast=int)
+DEFAULT_MODEL_NAME = config("DEFAULT_MODEL_NAME", default="openchat_v3.2_super", cast=str)
+LLM_CONTEXT_SIZE_IN_TOKENS = config("LLM_CONTEXT_SIZE_IN_TOKENS", default=512, cast=int)
+TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS = config("TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS", default=4000, cast=int)
+DEFAULT_MAX_COMPLETION_TOKENS = config("DEFAULT_MAX_COMPLETION_TOKENS", default=100, cast=int)
+DEFAULT_NUMBER_OF_COMPLETIONS_TO_GENERATE = config("DEFAULT_NUMBER_OF_COMPLETIONS_TO_GENERATE", default=4, cast=int)
+DEFAULT_COMPLETION_TEMPERATURE = config("DEFAULT_COMPLETION_TEMPERATURE", default=0.7, cast=float)
+MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING = config("MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING", default=15, cast=int)
+USE_PARALLEL_INFERENCE_QUEUE = config("USE_PARALLEL_INFERENCE_QUEUE", default=False, cast=bool)
+MAX_CONCURRENT_PARALLEL_INFERENCE_TASKS = config("MAX_CONCURRENT_PARALLEL_INFERENCE_TASKS", default=10, cast=int)
+USE_RAMDISK = config("USE_RAMDISK", default=False, cast=bool)
+RAMDISK_PATH = config("RAMDISK_PATH", default="/mnt/ramdisk", cast=str)
+BASE_DIRECTORY = os.path.dirname(os.path.abspath(__file__))
+
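+# One-time startup initialization: bring up Redis and the lock manager, initialize the SQLite
+# database, start the dedicated writer task, download and load the models, and build the Faiss indexes.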
+async def initialize_globals():
+    global db_writer, faiss_indexes, token_faiss_indexes, associated_texts_by_model, redis, lock_manager, USE_RAMDISK
+ if not is_redis_running():
+ logger.info("Starting Redis server...")
+ subprocess.Popen(['redis-server'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ await asyncio.sleep(1) # Sleep for 1 second to give Redis time to start
+ redis = await aioredis.create_redis_pool('redis://localhost')
+ lock_manager = Aioredlock([redis])
+ await initialize_db()
+ queue = asyncio.Queue()
+ db_writer = DatabaseWriter(queue)
+ await db_writer.initialize_processing_hashes()
+ asyncio.create_task(db_writer.dedicated_db_writer())
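+    # The writer task runs for the lifetime of the event loop, draining queued batches so that
+    # all SQLite writes funnel through a single session at a time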
+ if USE_RAMDISK and not check_that_user_has_required_permissions_to_manage_ramdisks():
+ USE_RAMDISK = False
+ elif USE_RAMDISK:
+ setup_ramdisk()
+ list_of_downloaded_model_names, download_status = download_models()
+ for llm_model_name in list_of_downloaded_model_names:
+ try:
+ load_model(llm_model_name, raise_http_exception=False)
+ except FileNotFoundError as e:
+ logger.error(e)
+ faiss_indexes, token_faiss_indexes, associated_texts_by_model = await build_faiss_indexes()
+
+# Other module-level shared state, populated by initialize_globals() at startup:
+db_writer = None
+faiss_indexes = None
+token_faiss_indexes = None
+associated_texts_by_model = None
+redis = None
+lock_manager = None
+
+
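+# Make sure every model listed in model_urls.json is present locally, seeding that file with a
+# default set of GGUF download URLs on first run and mirroring the models to the RAM disk when enabled.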
+def download_models() -> Tuple[List[str], List[Dict[str, str]]]:
+ download_status = []
+ json_path = os.path.join(BASE_DIRECTORY, "model_urls.json")
+ if not os.path.exists(json_path):
+ initial_model_urls = [
+ 'https://huggingface.co/TheBloke/Yarn-Llama-2-7B-128K-GGUF/resolve/main/yarn-llama-2-7b-128k.Q4_K_M.gguf',
+ 'https://huggingface.co/TheBloke/openchat_v3.2_super-GGUF/resolve/main/openchat_v3.2_super.Q4_K_M.gguf',
+ 'https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/resolve/main/mistral-7b-instruct-v0.1.Q6_K.gguf'
+ ]
+ with open(json_path, "w") as f:
+ json.dump(initial_model_urls, f)
+ with open(json_path, "r") as f:
+ list_of_model_download_urls = json.load(f)
+ model_names = [os.path.basename(url) for url in list_of_model_download_urls]
+ current_file_path = os.path.abspath(__file__)
+ base_dir = os.path.dirname(current_file_path)
+ models_dir = os.path.join(base_dir, 'models')
+ logger.info("Checking models directory...")
+ if USE_RAMDISK:
+ ramdisk_models_dir = os.path.join(RAMDISK_PATH, 'models')
+ if not os.path.exists(RAMDISK_PATH):
+ setup_ramdisk()
+ if all(os.path.exists(os.path.join(ramdisk_models_dir, llm_model_name)) for llm_model_name in model_names):
+ logger.info("Models found in RAM Disk.")
+ for url in list_of_model_download_urls:
+ download_status.append({"url": url, "status": "success", "message": "Model found in RAM Disk."})
+ return model_names, download_status
+ if not os.path.exists(models_dir):
+ os.makedirs(models_dir)
+ logger.info(f"Created models directory: {models_dir}")
+ else:
+ logger.info(f"Models directory exists: {models_dir}")
+ for url, model_name_with_extension in zip(list_of_model_download_urls, model_names):
+ status = {"url": url, "status": "success", "message": "File already exists."}
+ filename = os.path.join(models_dir, model_name_with_extension)
+ if not os.path.exists(filename):
+ logger.info(f"Downloading model {model_name_with_extension} from {url}...")
+ urllib.request.urlretrieve(url, filename)
+ file_size = os.path.getsize(filename) / (1024 * 1024) # Convert bytes to MB
+ if file_size < 100:
+ os.remove(filename)
+ status["status"] = "failure"
+ status["message"] = "Downloaded file is too small, probably not a valid model file."
+ else:
+ logger.info(f"Downloaded: {filename}")
+ else:
+ logger.info(f"File already exists: {filename}")
+ download_status.append(status)
+ if USE_RAMDISK:
+ copy_models_to_ramdisk(models_dir, ramdisk_models_dir)
+ logger.info("Model downloads completed.")
+ return model_names, download_status
+
+
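+# Load (or return from cache) the llama.cpp embedding model whose filename starts with llm_model_name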
+def load_model(llm_model_name: str, raise_http_exception: bool = True):
+ try:
+ models_dir = os.path.join(RAMDISK_PATH, 'models') if USE_RAMDISK else os.path.join(BASE_DIRECTORY, 'models')
+ if llm_model_name in embedding_model_cache:
+ return embedding_model_cache[llm_model_name]
+ matching_files = glob.glob(os.path.join(models_dir, f"{llm_model_name}*"))
+ if not matching_files:
+ logger.error(f"No model file found matching: {llm_model_name}")
+ raise FileNotFoundError
+ matching_files.sort(key=os.path.getmtime, reverse=True)
+ model_file_path = matching_files[0]
+ model_instance = LlamaCppEmbeddings(model_path=model_file_path, use_mlock=True, n_ctx=LLM_CONTEXT_SIZE_IN_TOKENS)
+ model_instance.client.verbose = False
+ embedding_model_cache[llm_model_name] = model_instance
+ return model_instance
+ except TypeError as e:
+ logger.error(f"TypeError occurred while loading the model: {e}")
+ raise
+ except Exception as e:
+ logger.error(f"Exception occurred while loading the model: {e}")
+ if raise_http_exception:
+ raise HTTPException(status_code=404, detail="Model file not found")
+ else:
+ raise FileNotFoundError(f"No model file found matching: {llm_model_name}")
\ No newline at end of file
diff --git a/swiss_army_llama.py b/swiss_army_llama.py
index 359d3fb..740df5e 100644
--- a/swiss_army_llama.py
+++ b/swiss_army_llama.py
@@ -1,80 +1,53 @@
-from embeddings_data_models import Base, TextEmbedding, DocumentEmbedding, Document, TokenLevelEmbedding, TokenLevelEmbeddingBundle, TokenLevelEmbeddingBundleCombinedFeatureVector, AudioTranscript
+import shared_resources
+from shared_resources import initialize_globals, download_models
+from logger_config import setup_logger
+from database_functions import AsyncSessionLocal, DatabaseWriter, get_db_writer
+from ramdisk_functions import clear_ramdisk
+from misc_utility_functions import build_faiss_indexes
+from embeddings_data_models import DocumentEmbedding, TokenLevelEmbeddingBundle
from embeddings_data_models import EmbeddingRequest, SemanticSearchRequest, AdvancedSemanticSearchRequest, SimilarityRequest, TextCompletionRequest
-from embeddings_data_models import EmbeddingResponse, SemanticSearchResponse, AdvancedSemanticSearchResponse, SimilarityResponse, AllStringsResponse, AllDocumentsResponse, TextCompletionResponse, AudioTranscriptResponse
+from embeddings_data_models import EmbeddingResponse, SemanticSearchResponse, AdvancedSemanticSearchResponse, SimilarityResponse, AllStringsResponse, AllDocumentsResponse, TextCompletionResponse
from embeddings_data_models import ShowLogsIncrementalModel
+from service_functions import get_or_compute_embedding, get_or_compute_transcript, add_model_url, get_or_compute_token_level_embedding_bundle_combined_feature_vector, calculate_token_level_embeddings
+from service_functions import parse_submitted_document_file_into_sentence_strings_func, compute_embeddings_for_document, store_document_embeddings_in_db, generate_completion_from_llm
from log_viewer_functions import show_logs_incremental_func, show_logs_func
+from uvicorn_config import option
import asyncio
-import io
import glob
import json
-import logging
import os
-import random
import re
-import shutil
-import subprocess
import tempfile
-import time
import traceback
-import urllib.request
import zipfile
-from collections import defaultdict
from datetime import datetime
from hashlib import sha3_256
-from logging.handlers import RotatingFileHandler
-from typing import List, Optional, Tuple, Dict, Any
-from urllib.parse import quote, unquote
+from typing import List, Optional, Dict, Any
+from urllib.parse import unquote
import numpy as np
from decouple import config
import uvicorn
-import psutil
-import textract
import fastapi
from fastapi import FastAPI, HTTPException, Request, UploadFile, File, Depends
from fastapi.responses import JSONResponse, FileResponse, HTMLResponse, Response
-from fastapi.concurrency import run_in_threadpool
-from langchain.embeddings import LlamaCppEmbeddings
from sqlalchemy import select
from sqlalchemy import text as sql_text
-from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError
-from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
-from sqlalchemy.orm import sessionmaker, joinedload
+from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.orm import joinedload
import faiss
import pandas as pd
from magic import Magic
-from llama_cpp import Llama, LlamaGrammar
import fast_vector_similarity as fvs
-from faster_whisper import WhisperModel
+import uvloop
+
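+# uvloop swaps in a libuv-based event loop, which is typically faster than asyncio's default loop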
+asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+logger = setup_logger()
# Note: the Ramdisk setup and teardown requires sudo; to enable password-less sudo, edit your sudoers file with `sudo visudo`.
# Add the following lines, replacing username with your actual username
# username ALL=(ALL) NOPASSWD: /bin/mount -t tmpfs -o size=*G tmpfs /mnt/ramdisk
# username ALL=(ALL) NOPASSWD: /bin/umount /mnt/ramdisk
-# Setup logging
-old_logs_dir = 'old_logs' # Ensure the old_logs directory exists
-if not os.path.exists(old_logs_dir):
- os.makedirs(old_logs_dir)
-logger = logging.getLogger()
-logger.setLevel(logging.INFO)
-formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
-log_file_path = 'swiss_army_llama.log'
-fh = RotatingFileHandler(log_file_path, maxBytes=10*1024*1024, backupCount=5)
-fh.setFormatter(formatter)
-logger.addHandler(fh)
-def namer(default_log_name): # Move rotated logs to the old_logs directory
- return os.path.join(old_logs_dir, os.path.basename(default_log_name))
-def rotator(source, dest):
- shutil.move(source, dest)
-fh.namer = namer
-fh.rotator = rotator
-sh = logging.StreamHandler()
-sh.setFormatter(formatter)
-logger.addHandler(sh)
-logger = logging.getLogger(__name__)
-logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING)
-configured_logger = logger
-
# Global variables
use_hardcoded_security_token = 0
if use_hardcoded_security_token:
@@ -82,1056 +55,17 @@ def rotator(source, dest):
USE_SECURITY_TOKEN = config("USE_SECURITY_TOKEN", default=False, cast=bool)
else:
USE_SECURITY_TOKEN = False
-DATABASE_URL = "sqlite+aiosqlite:///swiss_army_llama.sqlite"
-SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT = config("SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT", default=8089, cast=int)
DEFAULT_MODEL_NAME = config("DEFAULT_MODEL_NAME", default="openchat_v3.2_super", cast=str)
-LLM_CONTEXT_SIZE_IN_TOKENS = config("LLM_CONTEXT_SIZE_IN_TOKENS", default=512, cast=int)
-TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS = config("TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS", default=4000, cast=int)
-DEFAULT_MAX_COMPLETION_TOKENS = config("DEFAULT_MAX_COMPLETION_TOKENS", default=100, cast=int)
-DEFAULT_NUMBER_OF_COMPLETIONS_TO_GENERATE = config("DEFAULT_NUMBER_OF_COMPLETIONS_TO_GENERATE", default=4, cast=int)
-DEFAULT_COMPLETION_TEMPERATURE = config("DEFAULT_COMPLETION_TEMPERATURE", default=0.7, cast=float)
-MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING = config("MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING", default=15, cast=int)
-USE_PARALLEL_INFERENCE_QUEUE = config("USE_PARALLEL_INFERENCE_QUEUE", default=False, cast=bool)
-MAX_CONCURRENT_PARALLEL_INFERENCE_TASKS = config("MAX_CONCURRENT_PARALLEL_INFERENCE_TASKS", default=10, cast=int)
USE_RAMDISK = config("USE_RAMDISK", default=False, cast=bool)
RAMDISK_PATH = config("RAMDISK_PATH", default="/mnt/ramdisk", cast=str)
-RAMDISK_SIZE_IN_GB = config("RAMDISK_SIZE_IN_GB", default=1, cast=int)
-MAX_RETRIES = config("MAX_RETRIES", default=3, cast=int)
-DB_WRITE_BATCH_SIZE = config("DB_WRITE_BATCH_SIZE", default=25, cast=int)
-RETRY_DELAY_BASE_SECONDS = config("RETRY_DELAY_BASE_SECONDS", default=1, cast=int)
-JITTER_FACTOR = config("JITTER_FACTOR", default=0.1, cast=float)
BASE_DIRECTORY = os.path.dirname(os.path.abspath(__file__))
-embedding_model_cache = {} # Model cache to store loaded models
-token_level_embedding_model_cache = {} # Model cache to store loaded token-level embedding models
-text_completion_model_cache = {} # Model cache to store loaded text completion models
+
logger.info(f"USE_RAMDISK is set to: {USE_RAMDISK}")
-db_writer = None
description_string = """
🇨🇭🎖️🦙 Swiss Army Llama is your One-Stop-Shop to Quickly and Conveniently Integrate Powerful Local LLM Functionality into your Project via a REST API.
"""
app = FastAPI(title="Swiss Army Llama", description=description_string, docs_url="/") # Set the Swagger UI to root
-engine = create_async_engine(DATABASE_URL, echo=False, connect_args={"check_same_thread": False})
-AsyncSessionLocal = sessionmaker(
- bind=engine,
- class_=AsyncSession,
- expire_on_commit=False,
- autoflush=False
-)
-
-# Misc. utility functions and db writer class:
-def clean_filename_for_url_func(dirty_filename: str) -> str:
- clean_filename = re.sub(r'[^\w\s]', '', dirty_filename) # Remove special characters and replace spaces with underscores
- clean_filename = clean_filename.replace(' ', '_')
- return clean_filename
-
-class DatabaseWriter:
- def __init__(self, queue):
- self.queue = queue
- self.processing_hashes = set() # Set to store the hashes if everything that is currently being processed in the queue (to avoid duplicates of the same task being added to the queue)
-
- def _get_hash_from_operation(self, operation):
- attr_name = {
- TextEmbedding: 'text_hash',
- DocumentEmbedding: 'file_hash',
- Document: 'document_hash',
- TokenLevelEmbedding: 'token_hash',
- TokenLevelEmbeddingBundle: 'input_text_hash',
- TokenLevelEmbeddingBundleCombinedFeatureVector: 'combined_feature_vector_hash',
- AudioTranscript: 'audio_file_hash'
- }.get(type(operation))
- hash_value = getattr(operation, attr_name, None)
- llm_model_name = getattr(operation, 'llm_model_name', None)
- return f"{hash_value}_{llm_model_name}" if hash_value and llm_model_name else None
-
- async def initialize_processing_hashes(self, chunk_size=1000):
- start_time = datetime.utcnow()
- async with AsyncSessionLocal() as session:
- queries = [
- (select(TextEmbedding.text_hash, TextEmbedding.llm_model_name), True),
- (select(DocumentEmbedding.file_hash, DocumentEmbedding.llm_model_name), True),
- (select(Document.document_hash, Document.llm_model_name), True),
- (select(TokenLevelEmbedding.token_hash, TokenLevelEmbedding.llm_model_name), True),
- (select(TokenLevelEmbeddingBundle.input_text_hash, TokenLevelEmbeddingBundle.llm_model_name), True),
- (select(TokenLevelEmbeddingBundleCombinedFeatureVector.combined_feature_vector_hash, TokenLevelEmbeddingBundleCombinedFeatureVector.llm_model_name), True),
- (select(AudioTranscript.audio_file_hash), False)
- ]
- for query, has_llm in queries:
- offset = 0
- while True:
- result = await session.execute(query.limit(chunk_size).offset(offset))
- rows = result.fetchall()
- if not rows:
- break
- for row in rows:
- if has_llm:
- hash_with_model = f"{row[0]}_{row[1]}"
- else:
- hash_with_model = row[0]
- self.processing_hashes.add(hash_with_model)
- offset += chunk_size
- end_time = datetime.utcnow()
- total_time = (end_time - start_time).total_seconds()
- if len(self.processing_hashes) > 0:
- logger.info(f"Finished initializing set of input hash/llm_model_name combinations that are either currently being processed or have already been processed. Set size: {len(self.processing_hashes)}; Took {total_time} seconds, for an average of {total_time / len(self.processing_hashes)} seconds per hash.")
-
- async def _handle_integrity_error(self, e, write_operation, session):
- unique_constraint_msg = {
- TextEmbedding: "token_embeddings.token_hash, token_embeddings.llm_model_name",
- DocumentEmbedding: "document_embeddings.file_hash, document_embeddings.llm_model_name",
- Document: "documents.document_hash, documents.llm_model_name",
- TokenLevelEmbedding: "token_level_embeddings.token_hash, token_level_embeddings.llm_model_name",
- TokenLevelEmbeddingBundle: "token_level_embedding_bundles.input_text_hash, token_level_embedding_bundles.llm_model_name",
- AudioTranscript: "audio_transcripts.audio_file_hash"
- }.get(type(write_operation))
- if unique_constraint_msg and unique_constraint_msg in str(e):
- logger.warning(f"Embedding already exists in the database for given input and llm_model_name: {e}")
- await session.rollback()
- else:
- raise
-
- async def dedicated_db_writer(self):
- while True:
- write_operations_batch = await self.queue.get()
- async with AsyncSessionLocal() as session:
- try:
- for write_operation in write_operations_batch:
- session.add(write_operation)
- await session.flush() # Flush to get the IDs
- await session.commit()
- for write_operation in write_operations_batch:
- hash_to_remove = self._get_hash_from_operation(write_operation)
- if hash_to_remove is not None and hash_to_remove in self.processing_hashes:
- self.processing_hashes.remove(hash_to_remove)
- except IntegrityError as e:
- await self._handle_integrity_error(e, write_operation, session)
- except SQLAlchemyError as e:
- logger.error(f"Database error: {e}")
- await session.rollback()
- except Exception as e:
- tb = traceback.format_exc()
- logger.error(f"Unexpected error: {e}\n{tb}")
- await session.rollback()
- self.queue.task_done()
-
- async def enqueue_write(self, write_operations):
- write_operations = [op for op in write_operations if self._get_hash_from_operation(op) not in self.processing_hashes] # Filter out write operations for hashes that are already being processed
- if not write_operations: # If there are no write operations left after filtering, return early
- return
- for op in write_operations: # Add the hashes of the write operations to the set
- hash_value = self._get_hash_from_operation(op)
- if hash_value:
- self.processing_hashes.add(hash_value)
- await self.queue.put(write_operations)
-
-
-async def execute_with_retry(func, *args, **kwargs):
- retries = 0
- while retries < MAX_RETRIES:
- try:
- return await func(*args, **kwargs)
- except OperationalError as e:
- if 'database is locked' in str(e):
- retries += 1
- sleep_time = RETRY_DELAY_BASE_SECONDS * (2 ** retries) + (random.random() * JITTER_FACTOR) # Implementing exponential backoff with jitter
- logger.warning(f"Database is locked. Retrying ({retries}/{MAX_RETRIES})... Waiting for {sleep_time} seconds")
- await asyncio.sleep(sleep_time)
- else:
- raise
- raise OperationalError("Database is locked after multiple retries")
-
-async def initialize_db():
- logger.info("Initializing database, creating tables, and setting SQLite PRAGMAs...")
- list_of_sqlite_pragma_strings = ["PRAGMA journal_mode=WAL;", "PRAGMA synchronous = NORMAL;", "PRAGMA cache_size = -1048576;", "PRAGMA busy_timeout = 2000;", "PRAGMA wal_autocheckpoint = 100;"]
- list_of_sqlite_pragma_justification_strings = ["Set SQLite to use Write-Ahead Logging (WAL) mode (from default DELETE mode) so that reads and writes can occur simultaneously",
- "Set synchronous mode to NORMAL (from FULL) so that writes are not blocked by reads",
- "Set cache size to 1GB (from default 2MB) so that more data can be cached in memory and not read from disk; to make this 256MB, set it to -262144 instead",
- "Increase the busy timeout to 2 seconds so that the database waits",
- "Set the WAL autocheckpoint to 100 (from default 1000) so that the WAL file is checkpointed more frequently"]
- assert(len(list_of_sqlite_pragma_strings) == len(list_of_sqlite_pragma_justification_strings))
- async with engine.begin() as conn:
- for pragma_string in list_of_sqlite_pragma_strings:
- await conn.execute(sql_text(pragma_string))
- logger.info(f"Executed SQLite PRAGMA: {pragma_string}")
- logger.info(f"Justification: {list_of_sqlite_pragma_justification_strings[list_of_sqlite_pragma_strings.index(pragma_string)]}")
- await conn.run_sync(Base.metadata.create_all) # Create tables if they don't exist
- logger.info("Database initialization completed.")
-
-def get_db_writer() -> DatabaseWriter:
- return db_writer # Return the existing DatabaseWriter instance
-
-def check_that_user_has_required_permissions_to_manage_ramdisks():
- try: # Try to run a harmless command with sudo to test if the user has password-less sudo permissions
- result = subprocess.run(["sudo", "ls"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
- if "password" in result.stderr.lower():
- raise PermissionError("Password required for sudo")
- logger.info("User has sufficient permissions to manage RAM Disks.")
- return True
- except (PermissionError, subprocess.CalledProcessError) as e:
- logger.info("Sorry, current user does not have sufficient permissions to manage RAM Disks! Disabling RAM Disks for now...")
- logger.debug(f"Permission check error detail: {e}")
- return False
-
-def setup_ramdisk():
- cmd_check = f"sudo mount | grep {RAMDISK_PATH}" # Check if RAM disk already exists at the path
- result = subprocess.run(cmd_check, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8')
- if RAMDISK_PATH in result:
- logger.info(f"RAM Disk already set up at {RAMDISK_PATH}. Skipping setup.")
- return
- total_ram_gb = psutil.virtual_memory().total / (1024 ** 3)
- free_ram_gb = psutil.virtual_memory().free / (1024 ** 3)
- buffer_gb = 2 # buffer to ensure we don't use all the free RAM
- ramdisk_size_gb = max(min(RAMDISK_SIZE_IN_GB, free_ram_gb - buffer_gb), 0.1)
- ramdisk_size_mb = int(ramdisk_size_gb * 1024)
- ramdisk_size_str = f"{ramdisk_size_mb}M"
- logger.info(f"Total RAM: {total_ram_gb}G")
- logger.info(f"Free RAM: {free_ram_gb}G")
- logger.info(f"Calculated RAM Disk Size: {ramdisk_size_gb}G")
- if RAMDISK_SIZE_IN_GB > total_ram_gb:
- raise ValueError(f"Cannot allocate {RAMDISK_SIZE_IN_GB}G for RAM Disk. Total system RAM is {total_ram_gb:.2f}G.")
- logger.info("Setting up RAM Disk...")
- os.makedirs(RAMDISK_PATH, exist_ok=True)
- mount_command = ["sudo", "mount", "-t", "tmpfs", "-o", f"size={ramdisk_size_str}", "tmpfs", RAMDISK_PATH]
- subprocess.run(mount_command, check=True)
- logger.info(f"RAM Disk set up at {RAMDISK_PATH} with size {ramdisk_size_gb}G")
-
-def copy_models_to_ramdisk(models_directory, ramdisk_directory):
- total_size = sum(os.path.getsize(os.path.join(models_directory, model)) for model in os.listdir(models_directory))
- free_ram = psutil.virtual_memory().free
- if total_size > free_ram:
- logger.warning(f"Not enough space on RAM Disk. Required: {total_size}, Available: {free_ram}. Rebuilding RAM Disk.")
- clear_ramdisk()
- free_ram = psutil.virtual_memory().free # Recompute the available RAM after clearing the RAM disk
- if total_size > free_ram:
- logger.error(f"Still not enough space on RAM Disk even after clearing. Required: {total_size}, Available: {free_ram}.")
- raise ValueError("Not enough RAM space to copy models.")
- setup_ramdisk()
- os.makedirs(ramdisk_directory, exist_ok=True)
- for model in os.listdir(models_directory):
- shutil.copyfile(os.path.join(models_directory, model), os.path.join(ramdisk_directory, model))
- logger.info(f"Copied model {model} to RAM Disk at {os.path.join(ramdisk_directory, model)}")
-
-def clear_ramdisk():
- while True:
- cmd_check = f"sudo mount | grep {RAMDISK_PATH}"
- result = subprocess.run(cmd_check, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8')
- if RAMDISK_PATH not in result:
- break # Exit the loop if the RAMDISK_PATH is not in the mount list
- cmd_umount = f"sudo umount -l {RAMDISK_PATH}"
- subprocess.run(cmd_umount, shell=True, check=True)
- logger.info(f"Cleared RAM Disk at {RAMDISK_PATH}")
-
-async def build_faiss_indexes():
- global faiss_indexes, token_faiss_indexes, associated_texts_by_model
- faiss_indexes = {}
- token_faiss_indexes = {} # Separate FAISS indexes for token-level embeddings
- associated_texts_by_model = defaultdict(list) # Create a dictionary to store associated texts by model name
- async with AsyncSessionLocal() as session:
- result = await session.execute(sql_text("SELECT llm_model_name, text, embedding_json FROM embeddings")) # Query regular embeddings
- token_result = await session.execute(sql_text("SELECT llm_model_name, token, token_level_embedding_json FROM token_level_embeddings")) # Query token-level embeddings
- embeddings_by_model = defaultdict(list)
- token_embeddings_by_model = defaultdict(list)
- for row in result.fetchall(): # Process regular embeddings
- llm_model_name = row[0]
- associated_texts_by_model[llm_model_name].append(row[1]) # Store the associated text by model name
- embeddings_by_model[llm_model_name].append((row[1], json.loads(row[2])))
- for row in token_result.fetchall(): # Process token-level embeddings
- llm_model_name = row[0]
- token_embeddings_by_model[llm_model_name].append(json.loads(row[2]))
- for llm_model_name, embeddings in embeddings_by_model.items():
- logger.info(f"Building Faiss index over embeddings for model {llm_model_name}...")
- embeddings_array = np.array([e[1] for e in embeddings]).astype('float32')
- if embeddings_array.size == 0:
- logger.error(f"No embeddings were loaded from the database for model {llm_model_name}, so nothing to build the Faiss index with!")
- continue
- logger.info(f"Loaded {len(embeddings_array)} embeddings for model {llm_model_name}.")
- logger.info(f"Embedding dimension for model {llm_model_name}: {embeddings_array.shape[1]}")
- logger.info(f"Normalizing {len(embeddings_array)} embeddings for model {llm_model_name}...")
- faiss.normalize_L2(embeddings_array) # Normalize the vectors for cosine similarity
- faiss_index = faiss.IndexFlatIP(embeddings_array.shape[1]) # Use IndexFlatIP for cosine similarity
- faiss_index.add(embeddings_array)
- logger.info(f"Faiss index built for model {llm_model_name}.")
- faiss_indexes[llm_model_name] = faiss_index # Store the index by model name
- for llm_model_name, token_embeddings in token_embeddings_by_model.items():
- token_embeddings_array = np.array(token_embeddings).astype('float32')
- if token_embeddings_array.size == 0:
- logger.error(f"No token-level embeddings were loaded from the database for model {llm_model_name}, so nothing to build the Faiss index with!")
- continue
- logger.info(f"Normalizing {len(token_embeddings_array)} token-level embeddings for model {llm_model_name}...")
- faiss.normalize_L2(token_embeddings_array) # Normalize the vectors for cosine similarity
- token_faiss_index = faiss.IndexFlatIP(token_embeddings_array.shape[1]) # Use IndexFlatIP for cosine similarity
- token_faiss_index.add(token_embeddings_array)
- logger.info(f"Token-level Faiss index built for model {llm_model_name}.")
- token_faiss_indexes[llm_model_name] = token_faiss_index # Store the token-level index by model name
- return faiss_indexes, token_faiss_indexes, associated_texts_by_model
-
-class JSONAggregator:
- def __init__(self):
- self.completions = []
- self.aggregate_result = None
-
- @staticmethod
- def weighted_vote(values, weights):
- tally = defaultdict(float)
- for v, w in zip(values, weights):
- tally[v] += w
- return max(tally, key=tally.get)
-
- @staticmethod
- def flatten_json(json_obj, parent_key='', sep='->'):
- items = {}
- for k, v in json_obj.items():
- new_key = f"{parent_key}{sep}{k}" if parent_key else k
- if isinstance(v, dict):
- items.update(JSONAggregator.flatten_json(v, new_key, sep=sep))
- else:
- items[new_key] = v
- return items
-
- @staticmethod
- def get_value_by_path(json_obj, path, sep='->'):
- keys = path.split(sep)
- item = json_obj
- for k in keys:
- item = item[k]
- return item
-
- @staticmethod
- def set_value_by_path(json_obj, path, value, sep='->'):
- keys = path.split(sep)
- item = json_obj
- for k in keys[:-1]:
- item = item.setdefault(k, {})
- item[keys[-1]] = value
-
- def calculate_path_weights(self):
- all_paths = []
- for j in self.completions:
- all_paths += list(self.flatten_json(j).keys())
- path_weights = defaultdict(float)
- for path in all_paths:
- path_weights[path] += 1.0
- return path_weights
-
- def aggregate(self):
- path_weights = self.calculate_path_weights()
- aggregate = {}
- for path, weight in path_weights.items():
- values = [self.get_value_by_path(j, path) for j in self.completions if path in self.flatten_json(j)]
- weights = [weight] * len(values)
- aggregate_value = self.weighted_vote(values, weights)
- self.set_value_by_path(aggregate, path, aggregate_value)
- self.aggregate_result = aggregate
-
-class FakeUploadFile:
- def __init__(self, filename: str, content: Any, content_type: str = 'text/plain'):
- self.filename = filename
- self.content_type = content_type
- self.file = io.BytesIO(content)
- def read(self, size: int = -1) -> bytes:
- return self.file.read(size)
- def seek(self, offset: int, whence: int = 0) -> int:
- return self.file.seek(offset, whence)
- def tell(self) -> int:
- return self.file.tell()
-
-async def get_transcript_from_db(audio_file_hash: str):
- return await execute_with_retry(_get_transcript_from_db, audio_file_hash)
-
-async def _get_transcript_from_db(audio_file_hash: str) -> Optional[dict]:
- async with AsyncSessionLocal() as session:
- result = await session.execute(
- sql_text("SELECT * FROM audio_transcripts WHERE audio_file_hash=:audio_file_hash"),
- {"audio_file_hash": audio_file_hash},
- )
- row = result.fetchone()
- if row:
- try:
- segments_json = json.loads(row.segments_json)
- combined_transcript_text_list_of_metadata_dicts = json.loads(row.combined_transcript_text_list_of_metadata_dicts)
- info_json = json.loads(row.info_json)
- if hasattr(info_json, '__dict__'):
- info_json = vars(info_json)
- except json.JSONDecodeError as e:
- raise ValueError(f"JSON Decode Error: {e}")
- if not isinstance(segments_json, list) or not isinstance(combined_transcript_text_list_of_metadata_dicts, list) or not isinstance(info_json, dict):
- logger.error(f"Type of segments_json: {type(segments_json)}, Value: {segments_json}")
- logger.error(f"Type of combined_transcript_text_list_of_metadata_dicts: {type(combined_transcript_text_list_of_metadata_dicts)}, Value: {combined_transcript_text_list_of_metadata_dicts}")
- logger.error(f"Type of info_json: {type(info_json)}, Value: {info_json}")
- raise ValueError("Deserialized JSON does not match the expected format.")
- audio_transcript_response = {
- "id": row.id,
- "audio_file_name": row.audio_file_name,
- "audio_file_size_mb": row.audio_file_size_mb,
- "segments_json": segments_json,
- "combined_transcript_text": row.combined_transcript_text,
- "combined_transcript_text_list_of_metadata_dicts": combined_transcript_text_list_of_metadata_dicts,
- "info_json": info_json,
- "ip_address": row.ip_address,
- "request_time": row.request_time,
- "response_time": row.response_time,
- "total_time": row.total_time,
- "url_to_download_zip_file_of_embeddings": ""
- }
- return AudioTranscriptResponse(**audio_transcript_response)
- return None
-
-async def save_transcript_to_db(audio_file_hash, audio_file_name, audio_file_size_mb, transcript_segments, info, ip_address, request_time, response_time, total_time, combined_transcript_text, combined_transcript_text_list_of_metadata_dicts):
- existing_transcript = await get_transcript_from_db(audio_file_hash)
- if existing_transcript:
- return existing_transcript
- audio_transcript = AudioTranscript(
- audio_file_hash=audio_file_hash,
- audio_file_name=audio_file_name,
- audio_file_size_mb=audio_file_size_mb,
- segments_json=json.dumps(transcript_segments),
- combined_transcript_text=combined_transcript_text,
- combined_transcript_text_list_of_metadata_dicts=json.dumps(combined_transcript_text_list_of_metadata_dicts),
- info_json=json.dumps(info),
- ip_address=ip_address,
- request_time=request_time,
- response_time=response_time,
- total_time=total_time
- )
- await db_writer.enqueue_write([audio_transcript])
-
-def normalize_logprobs(avg_logprob, min_logprob, max_logprob):
- range_logprob = max_logprob - min_logprob
- return (avg_logprob - min_logprob) / range_logprob if range_logprob != 0 else 0.5
-
-def remove_pagination_breaks(text: str) -> str:
- text = re.sub(r'-(\n)(?=[a-z])', '', text) # Remove hyphens at the end of lines when the word continues on the next line
- text = re.sub(r'(?<=\w)(? dict:
- request_time = datetime.utcnow()
- ip_address = req.client.host if req else "127.0.0.1"
- file_contents = await file.read()
- audio_file_hash = sha3_256(file_contents).hexdigest()
- file.file.seek(0) # Reset file pointer after read
- existing_audio_transcript = await get_transcript_from_db(audio_file_hash)
- if existing_audio_transcript:
- return existing_audio_transcript
- current_position = file.file.tell()
- file.file.seek(0, os.SEEK_END)
- audio_file_size_mb = file.file.tell() / (1024 * 1024)
- file.file.seek(current_position)
- with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
- shutil.copyfileobj(file.file, tmp_file)
- audio_file_name = tmp_file.name
- segment_details, info, combined_transcript_text, combined_transcript_text_list_of_metadata_dicts, request_time, response_time, total_time, download_url = await compute_transcript_with_whisper_from_audio_func(audio_file_hash, audio_file_name, file.filename, audio_file_size_mb, ip_address, req, compute_embeddings_for_resulting_transcript_document, llm_model_name)
- audio_transcript_response = {
- "audio_file_hash": audio_file_hash,
- "audio_file_name": file.filename,
- "audio_file_size_mb": audio_file_size_mb,
- "segments_json": segment_details,
- "combined_transcript_text": combined_transcript_text,
- "combined_transcript_text_list_of_metadata_dicts": combined_transcript_text_list_of_metadata_dicts,
- "info_json": info,
- "ip_address": ip_address,
- "request_time": request_time,
- "response_time": response_time,
- "total_time": total_time,
- "url_to_download_zip_file_of_embeddings": download_url if compute_embeddings_for_resulting_transcript_document else ""
- }
- os.remove(audio_file_name)
- return AudioTranscriptResponse(**audio_transcript_response)
-
-
-# Core embedding functions start here:
-
-def download_models() -> Tuple[List[str], List[Dict[str, str]]]:
- download_status = []
- json_path = os.path.join(BASE_DIRECTORY, "model_urls.json")
- if not os.path.exists(json_path):
- initial_model_urls = [
- 'https://huggingface.co/TheBloke/Yarn-Llama-2-13B-128K-GGUF/resolve/main/yarn-llama-2-13b-128k.Q4_K_M.gguf',
- 'https://huggingface.co/TheBloke/Yarn-Llama-2-7B-128K-GGUF/resolve/main/yarn-llama-2-7b-128k.Q4_K_M.gguf',
- 'https://huggingface.co/TheBloke/openchat_v3.2_super-GGUF/resolve/main/openchat_v3.2_super.Q4_K_M.gguf',
- 'https://huggingface.co/TheBloke/Phind-CodeLlama-34B-Python-v1-GGUF/resolve/main/phind-codellama-34b-python-v1.Q4_K_M.gguf',
- 'https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/resolve/main/mistral-7b-instruct-v0.1.Q6_K.gguf'
- ]
- with open(json_path, "w") as f:
- json.dump(initial_model_urls, f)
- with open(json_path, "r") as f:
- list_of_model_download_urls = json.load(f)
- model_names = [os.path.basename(url) for url in list_of_model_download_urls]
- current_file_path = os.path.abspath(__file__)
- base_dir = os.path.dirname(current_file_path)
- models_dir = os.path.join(base_dir, 'models')
- logger.info("Checking models directory...")
- if USE_RAMDISK:
- ramdisk_models_dir = os.path.join(RAMDISK_PATH, 'models')
- if not os.path.exists(RAMDISK_PATH):
- setup_ramdisk()
- if all(os.path.exists(os.path.join(ramdisk_models_dir, llm_model_name)) for llm_model_name in model_names):
- logger.info("Models found in RAM Disk.")
- for url in list_of_model_download_urls:
- download_status.append({"url": url, "status": "success", "message": "Model found in RAM Disk."})
- return model_names, download_status
- if not os.path.exists(models_dir):
- os.makedirs(models_dir)
- logger.info(f"Created models directory: {models_dir}")
- else:
- logger.info(f"Models directory exists: {models_dir}")
- for url, model_name_with_extension in zip(list_of_model_download_urls, model_names):
- status = {"url": url, "status": "success", "message": "File already exists."}
- filename = os.path.join(models_dir, model_name_with_extension)
- if not os.path.exists(filename):
- logger.info(f"Downloading model {model_name_with_extension} from {url}...")
- urllib.request.urlretrieve(url, filename)
- file_size = os.path.getsize(filename) / (1024 * 1024) # Convert bytes to MB
- if file_size < 100:
- os.remove(filename)
- status["status"] = "failure"
- status["message"] = "Downloaded file is too small, probably not a valid model file."
- else:
- logger.info(f"Downloaded: {filename}")
- else:
- logger.info(f"File already exists: {filename}")
- download_status.append(status)
- if USE_RAMDISK:
- copy_models_to_ramdisk(models_dir, ramdisk_models_dir)
- logger.info("Model downloads completed.")
- return model_names, download_status
-
-def add_model_url(new_url: str) -> str:
- corrected_url = new_url
- if '/blob/main/' in new_url:
- corrected_url = new_url.replace('/blob/main/', '/resolve/main/')
- json_path = os.path.join(BASE_DIRECTORY, "model_urls.json")
- with open(json_path, "r") as f:
- existing_urls = json.load(f)
- if corrected_url not in existing_urls:
- logger.info(f"Model URL not found in database. Adding {new_url} now...")
- existing_urls.append(corrected_url)
- with open(json_path, "w") as f:
- json.dump(existing_urls, f)
- logger.info(f"Model URL added: {new_url}")
- else:
- logger.info("Model URL already exists.")
- return corrected_url
-
-async def get_embedding_from_db(text: str, llm_model_name: str):
- text_hash = sha3_256(text.encode('utf-8')).hexdigest() # Compute the hash
- return await execute_with_retry(_get_embedding_from_db, text_hash, llm_model_name)
-
-async def _get_embedding_from_db(text_hash: str, llm_model_name: str) -> Optional[dict]:
- async with AsyncSessionLocal() as session:
- result = await session.execute(
- sql_text("SELECT embedding_json FROM embeddings WHERE text_hash=:text_hash AND llm_model_name=:llm_model_name"),
- {"text_hash": text_hash, "llm_model_name": llm_model_name},
- )
- row = result.fetchone()
- if row:
- embedding_json = row[0]
- logger.info(f"Embedding found in database for text hash '{text_hash}' using model '{llm_model_name}'")
- return json.loads(embedding_json)
- return None
-
-async def get_or_compute_embedding(request: EmbeddingRequest, req: Request = None, client_ip: str = None, document_file_hash: str = None) -> dict:
- request_time = datetime.utcnow() # Capture request time as datetime object
- ip_address = client_ip or (req.client.host if req else "localhost") # If client_ip is provided, use it; otherwise, try to get from req; if not available, default to "localhost"
- logger.info(f"Received request for embedding for '{request.text}' using model '{request.llm_model_name}' from IP address '{ip_address}'")
- embedding_list = await get_embedding_from_db(request.text, request.llm_model_name) # Check if embedding exists in the database
- if embedding_list is not None:
- response_time = datetime.utcnow() # Capture response time as datetime object
- total_time = (response_time - request_time).total_seconds() # Calculate time taken in seconds
- logger.info(f"Embedding found in database for '{request.text}' using model '{request.llm_model_name}'; returning in {total_time:.4f} seconds")
- return {"embedding": embedding_list}
- model = load_model(request.llm_model_name)
- embedding_list = calculate_sentence_embedding(model, request.text) # Compute the embedding if not in the database
- if embedding_list is None:
- logger.error(f"Could not calculate the embedding for the given text: '{request.text}' using model '{request.llm_model_name}!'")
- raise HTTPException(status_code=400, detail="Could not calculate the embedding for the given text")
- embedding_json = json.dumps(embedding_list) # Serialize the numpy array to JSON and save to the database
- response_time = datetime.utcnow() # Capture response time as datetime object
- total_time = (response_time - request_time).total_seconds() # Calculate total time using datetime objects
- word_length_of_input_text = len(request.text.split())
- if word_length_of_input_text > 0:
- logger.info(f"Embedding calculated for '{request.text}' using model '{request.llm_model_name}' in {total_time} seconds, or an average of {total_time/word_length_of_input_text :.2f} seconds per word. Now saving to database...")
- await save_embedding_to_db(request.text, request.llm_model_name, embedding_json, ip_address, request_time, response_time, total_time, document_file_hash)
- return {"embedding": embedding_list}
-
-async def save_embedding_to_db(text: str, llm_model_name: str, embedding_json: str, ip_address: str, request_time: datetime, response_time: datetime, total_time: float, document_file_hash: str = None):
- existing_embedding = await get_embedding_from_db(text, llm_model_name) # Check if the embedding already exists
- if existing_embedding is not None:
- return existing_embedding
- return await execute_with_retry(_save_embedding_to_db, text, llm_model_name, embedding_json, ip_address, request_time, response_time, total_time, document_file_hash)
-
-async def _save_embedding_to_db(text: str, llm_model_name: str, embedding_json: str, ip_address: str, request_time: datetime, response_time: datetime, total_time: float, document_file_hash: str = None):
- existing_embedding = await get_embedding_from_db(text, llm_model_name)
- if existing_embedding:
- return existing_embedding
- embedding = TextEmbedding(
- text=text,
- llm_model_name=llm_model_name,
- embedding_json=embedding_json,
- ip_address=ip_address,
- request_time=request_time,
- response_time=response_time,
- total_time=total_time,
- document_file_hash=document_file_hash
- )
- await db_writer.enqueue_write([embedding]) # Enqueue the write operation using the db_writer instance
-
-def load_model(llm_model_name: str, raise_http_exception: bool = True):
- try:
- models_dir = os.path.join(RAMDISK_PATH, 'models') if USE_RAMDISK else os.path.join(BASE_DIRECTORY, 'models')
- if llm_model_name in embedding_model_cache:
- return embedding_model_cache[llm_model_name]
- matching_files = glob.glob(os.path.join(models_dir, f"{llm_model_name}*"))
- if not matching_files:
- logger.error(f"No model file found matching: {llm_model_name}")
- raise FileNotFoundError
- matching_files.sort(key=os.path.getmtime, reverse=True)
- model_file_path = matching_files[0]
- model_instance = LlamaCppEmbeddings(model_path=model_file_path, use_mlock=True, n_ctx=LLM_CONTEXT_SIZE_IN_TOKENS)
- model_instance.client.verbose = False
- embedding_model_cache[llm_model_name] = model_instance
- return model_instance
- except TypeError as e:
- logger.error(f"TypeError occurred while loading the model: {e}")
- raise
- except Exception as e:
- logger.error(f"Exception occurred while loading the model: {e}")
- if raise_http_exception:
- raise HTTPException(status_code=404, detail="Model file not found")
- else:
- raise FileNotFoundError(f"No model file found matching: {llm_model_name}")
-
-def load_token_level_embedding_model(llm_model_name: str, raise_http_exception: bool = True):
- try:
- if llm_model_name in token_level_embedding_model_cache: # Check if the model is already loaded in the cache
- return token_level_embedding_model_cache[llm_model_name]
- models_dir = os.path.join(RAMDISK_PATH, 'models') if USE_RAMDISK else os.path.join(BASE_DIRECTORY, 'models') # Determine the model directory path
- matching_files = glob.glob(os.path.join(models_dir, f"{llm_model_name}*")) # Search for matching model files
- if not matching_files:
- logger.error(f"No model file found matching: {llm_model_name}")
- raise FileNotFoundError
- matching_files.sort(key=os.path.getmtime, reverse=True) # Sort the files based on modification time (recently modified files first)
- model_file_path = matching_files[0]
- model_instance = Llama(model_path=model_file_path, embedding=True, n_ctx=LLM_CONTEXT_SIZE_IN_TOKENS, verbose=False) # Load the model
- token_level_embedding_model_cache[llm_model_name] = model_instance # Cache the loaded model
- return model_instance
- except TypeError as e:
- logger.error(f"TypeError occurred while loading the model: {e}")
- raise
- except Exception as e:
- logger.error(f"Exception occurred while loading the model: {e}")
- if raise_http_exception:
- raise HTTPException(status_code=404, detail="Model file not found")
- else:
- raise FileNotFoundError(f"No model file found matching: {llm_model_name}")
-
-async def compute_token_level_embedding_bundle_combined_feature_vector(token_level_embeddings) -> List[float]:
- start_time = datetime.utcnow()
- logger.info("Extracting token-level embeddings from the bundle")
- parsed_df = pd.read_json(token_level_embeddings) # Parse the json_content back to a DataFrame
- token_level_embeddings = list(parsed_df['embedding'])
- embeddings = np.array(token_level_embeddings) # Convert the list of embeddings to a NumPy array
- logger.info(f"Computing column-wise means/mins/maxes/std_devs of the embeddings... (shape: {embeddings.shape})")
- assert(len(embeddings) > 0)
- means = np.mean(embeddings, axis=0)
- mins = np.min(embeddings, axis=0)
- maxes = np.max(embeddings, axis=0)
- stds = np.std(embeddings, axis=0)
- logger.info("Concatenating the computed statistics to form the combined feature vector")
- combined_feature_vector = np.concatenate([means, mins, maxes, stds])
- end_time = datetime.utcnow()
- total_time = (end_time - start_time).total_seconds()
- logger.info(f"Computed the token-level embedding bundle's combined feature vector computed in {total_time: .2f} seconds.")
- return combined_feature_vector.tolist()
-
-async def get_or_compute_token_level_embedding_bundle_combined_feature_vector(token_level_embedding_bundle_id, token_level_embeddings, db_writer: DatabaseWriter) -> List[float]:
- request_time = datetime.utcnow()
- logger.info(f"Checking for existing combined feature vector for token-level embedding bundle ID: {token_level_embedding_bundle_id}")
- async with AsyncSessionLocal() as session:
- result = await session.execute(
- select(TokenLevelEmbeddingBundleCombinedFeatureVector)
- .filter(TokenLevelEmbeddingBundleCombinedFeatureVector.token_level_embedding_bundle_id == token_level_embedding_bundle_id)
- )
- existing_combined_feature_vector = result.scalar_one_or_none()
- if existing_combined_feature_vector:
- response_time = datetime.utcnow()
- total_time = (response_time - request_time).total_seconds()
- logger.info(f"Found existing combined feature vector for token-level embedding bundle ID: {token_level_embedding_bundle_id}. Returning cached result in {total_time:.2f} seconds.")
- return json.loads(existing_combined_feature_vector.combined_feature_vector_json) # Parse the JSON string into a list
- logger.info(f"No cached combined feature_vector found for token-level embedding bundle ID: {token_level_embedding_bundle_id}. Computing now...")
- combined_feature_vector = await compute_token_level_embedding_bundle_combined_feature_vector(token_level_embeddings)
- combined_feature_vector_db_object = TokenLevelEmbeddingBundleCombinedFeatureVector(
- token_level_embedding_bundle_id=token_level_embedding_bundle_id,
- combined_feature_vector_json=json.dumps(combined_feature_vector) # Convert the list to a JSON string
- )
- logger.info(f"Writing combined feature vector for database write for token-level embedding bundle ID: {token_level_embedding_bundle_id} to the database...")
- await db_writer.enqueue_write([combined_feature_vector_db_object])
- return combined_feature_vector
-
-async def calculate_token_level_embeddings(text: str, llm_model_name: str, client_ip: str, token_level_embedding_bundle_id: int) -> List[np.array]:
- request_time = datetime.utcnow()
- logger.info(f"Starting token-level embedding calculation for text: '{text}' using model: '{llm_model_name}'")
- logger.info(f"Loading model: '{llm_model_name}'")
- llm = load_token_level_embedding_model(llm_model_name) # Assuming this method returns an instance of the Llama class
- token_embeddings = []
- tokens = text.split() # Simple whitespace tokenizer; can be replaced with a more advanced one if needed
- logger.info(f"Tokenized text into {len(tokens)} tokens")
- for idx, token in enumerate(tokens, start=1):
- try: # Check if the embedding is already available in the database
- existing_embedding = await get_token_level_embedding_from_db(token, llm_model_name)
- if existing_embedding is not None:
- token_embeddings.append(np.array(existing_embedding))
- logger.info(f"Embedding retrieved from database for token '{token}'")
- continue
- logger.info(f"Processing token {idx} of {len(tokens)}: '{token}'")
- token_embedding = llm.embed(token)
- token_embedding_array = np.array(token_embedding)
- token_embeddings.append(token_embedding_array)
- response_time = datetime.utcnow()
- token_level_embedding_json = json.dumps(token_embedding_array.tolist())
- await store_token_level_embeddings_in_db(token, llm_model_name, token_level_embedding_json, client_ip, request_time, response_time, token_level_embedding_bundle_id)
- except RuntimeError as e:
- logger.error(f"Failed to calculate embedding for token '{token}': {e}")
- logger.info(f"Completed token embedding calculation for all tokens in text: '{text}'")
- return token_embeddings
-
-async def get_token_level_embedding_from_db(token: str, llm_model_name: str) -> Optional[List[float]]:
- token_hash = sha3_256(token.encode('utf-8')).hexdigest() # Compute the hash
- async with AsyncSessionLocal() as session:
- result = await session.execute(
- sql_text("SELECT token_level_embedding_json FROM token_level_embeddings WHERE token_hash=:token_hash AND llm_model_name=:llm_model_name"),
- {"token_hash": token_hash, "llm_model_name": llm_model_name},
- )
- row = result.fetchone()
- if row:
- embedding_json = row[0]
- logger.info(f"Embedding found in database for token hash '{token_hash}' using model '{llm_model_name}'")
- return json.loads(embedding_json)
- return None
-
-async def store_token_level_embeddings_in_db(token: str, llm_model_name: str, token_level_embedding_json: str, ip_address: str, request_time: datetime, response_time: datetime, token_level_embedding_bundle_id: int):
- total_time = (response_time - request_time).total_seconds()
- embedding = TokenLevelEmbedding(
- token=token,
- llm_model_name=llm_model_name,
- token_level_embedding_json=token_level_embedding_json,
- ip_address=ip_address,
- request_time=request_time,
- response_time=response_time,
- total_time=total_time,
- token_level_embedding_bundle_id=token_level_embedding_bundle_id
- )
- await db_writer.enqueue_write([embedding]) # Enqueue the write operation for the token-level embedding
-
-def calculate_sentence_embedding(llama: Llama, text: str) -> np.array:
- sentence_embedding = None
- retry_count = 0
- while sentence_embedding is None and retry_count < 3:
- try:
- if retry_count > 0:
- logger.info(f"Attempting again calculate sentence embedding. Attempt number {retry_count + 1}")
- sentence_embedding = llama.embed_query(text)
- except TypeError as e:
- logger.error(f"TypeError in calculate_sentence_embedding: {e}")
- raise
- except Exception as e:
- logger.error(f"Exception in calculate_sentence_embedding: {e}")
- text = text[:-int(len(text) * 0.1)]
- retry_count += 1
- logger.info(f"Trimming sentence due to too many tokens. New length: {len(text)}")
- if sentence_embedding is None:
- logger.error("Failed to calculate sentence embedding after multiple attempts")
- return sentence_embedding
-
-async def compute_embeddings_for_document(strings: list, llm_model_name: str, client_ip: str, document_file_hash: str) -> List[Tuple[str, np.array]]:
- results = []
- if USE_PARALLEL_INFERENCE_QUEUE:
- logger.info(f"Using parallel inference queue to compute embeddings for {len(strings)} strings")
- start_time = time.perf_counter() # Record the start time
- semaphore = asyncio.Semaphore(MAX_CONCURRENT_PARALLEL_INFERENCE_TASKS)
- async def compute_embedding(text): # Define a function to compute the embedding for a given text
- try:
- async with semaphore: # Acquire a semaphore slot
- request = EmbeddingRequest(text=text, llm_model_name=llm_model_name)
- embedding = await get_embedding_vector_for_string(request, client_ip=client_ip, document_file_hash=document_file_hash)
- return text, embedding["embedding"]
- except Exception as e:
- logger.error(f"Error computing embedding for text '{text}': {e}")
- return text, None
- results = await asyncio.gather(*[compute_embedding(s) for s in strings]) # Use asyncio.gather to run the tasks concurrently
- end_time = time.perf_counter() # Record the end time
- duration = end_time - start_time
- if len(strings) > 0:
- logger.info(f"Parallel inference task for {len(strings)} strings completed in {duration:.2f} seconds; {duration / len(strings):.2f} seconds per string")
- else: # Compute embeddings sequentially
- logger.info(f"Using sequential inference to compute embeddings for {len(strings)} strings")
- start_time = time.perf_counter() # Record the start time
- for s in strings:
- embedding_request = EmbeddingRequest(text=s, llm_model_name=llm_model_name)
- embedding = await get_embedding_vector_for_string(embedding_request, client_ip=client_ip, document_file_hash=document_file_hash)
- results.append((s, embedding["embedding"]))
- end_time = time.perf_counter() # Record the end time
- duration = end_time - start_time
- if len(strings) > 0:
- logger.info(f"Sequential inference task for {len(strings)} strings completed in {duration:.2f} seconds; {duration / len(strings):.2f} seconds per string")
- filtered_results = [(text, embedding) for text, embedding in results if embedding is not None] # Filter out results with None embeddings (applicable to parallel processing) and return
- return filtered_results
-
-async def parse_submitted_document_file_into_sentence_strings_func(temp_file_path: str, mime_type: str):
- strings = []
- if mime_type.startswith('text/'):
- with open(temp_file_path, 'r') as buffer:
- content = buffer.read()
- else:
- try:
- content = textract.process(temp_file_path).decode('utf-8')
- except UnicodeDecodeError:
- try:
- content = textract.process(temp_file_path).decode('unicode_escape')
- except Exception as e:
- logger.error(f"Error while processing file: {e}, mime_type: {mime_type}")
- raise HTTPException(status_code=400, detail=f"Unsupported file type or error: {e}")
- except Exception as e:
- logger.error(f"Error while processing file: {e}, mime_type: {mime_type}")
- raise HTTPException(status_code=400, detail=f"Unsupported file type or error: {e}")
- sentences = sophisticated_sentence_splitter(content)
- if len(sentences) == 0 and temp_file_path.lower().endswith('.pdf'):
- logger.info("No sentences found, attempting OCR using Tesseract.")
- try:
- content = textract.process(temp_file_path, method='tesseract').decode('utf-8')
- sentences = sophisticated_sentence_splitter(content)
- except Exception as e:
- logger.error(f"Error while processing file with OCR: {e}")
- raise HTTPException(status_code=400, detail=f"OCR failed: {e}")
- if len(sentences) == 0:
- logger.info("No sentences found in the document")
- raise HTTPException(status_code=400, detail="No sentences found in the document")
- logger.info(f"Extracted {len(sentences)} sentences from the document")
- strings = [s.strip() for s in sentences if len(s.strip()) > MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING]
- return strings
-
-async def _get_document_from_db(file_hash: str):
- async with AsyncSessionLocal() as session:
- result = await session.execute(select(Document).filter(Document.document_hash == file_hash))
- return result.scalar_one_or_none()
-
-async def store_document_embeddings_in_db(file: File, file_hash: str, original_file_content: bytes, json_content: bytes, results: List[Tuple[str, np.array]], llm_model_name: str, client_ip: str, request_time: datetime):
- document = await _get_document_from_db(file_hash) # First, check if a Document with the same hash already exists
- if not document: # If not, create a new Document object
- document = Document(document_hash=file_hash, llm_model_name=llm_model_name)
- await db_writer.enqueue_write([document])
- document_embedding = DocumentEmbedding(
- filename=file.filename,
- mimetype=file.content_type,
- file_hash=file_hash,
- llm_model_name=llm_model_name,
- file_data=original_file_content,
- document_embedding_results_json=json.loads(json_content.decode()),
- ip_address=client_ip,
- request_time=request_time,
- response_time=datetime.utcnow(),
- total_time=(datetime.utcnow() - request_time).total_seconds()
- )
- document.document_embeddings.append(document_embedding) # Associate it with the Document
- document.update_hash() # This will trigger the SQLAlchemy event to update the document_hash
- await db_writer.enqueue_write([document, document_embedding]) # Enqueue the write operation for the document embedding
- write_operations = [] # Collect text embeddings to write
- logger.info(f"Storing {len(results)} text embeddings in database")
- for text, embedding in results:
- embedding_entry = await _get_embedding_from_db(text, llm_model_name)
- if not embedding_entry:
- embedding_entry = TextEmbedding(
- text=text,
- llm_model_name=llm_model_name,
- embedding_json=json.dumps(embedding),
- ip_address=client_ip,
- request_time=request_time,
- response_time=datetime.utcnow(),
- total_time=(datetime.utcnow() - request_time).total_seconds(),
- document_file_hash=file_hash # Link it to the DocumentEmbedding via file_hash
-            )
-            write_operations.append(embedding_entry) # Enqueue only newly created embeddings; rows already in the database don't need rewriting
- await db_writer.enqueue_write(write_operations) # Enqueue the write operation for text embeddings
-
-def load_text_completion_model(llm_model_name: str, raise_http_exception: bool = True):
- try:
- if llm_model_name in text_completion_model_cache: # Check if the model is already loaded in the cache
- return text_completion_model_cache[llm_model_name]
- models_dir = os.path.join(RAMDISK_PATH, 'models') if USE_RAMDISK else os.path.join(BASE_DIRECTORY, 'models') # Determine the model directory path
- matching_files = glob.glob(os.path.join(models_dir, f"{llm_model_name}*")) # Search for matching model files
- if not matching_files:
- logger.error(f"No model file found matching: {llm_model_name}")
- raise FileNotFoundError
- matching_files.sort(key=os.path.getmtime, reverse=True) # Sort the files based on modification time (recently modified files first)
- model_file_path = matching_files[0]
- model_instance = Llama(model_path=model_file_path, embedding=True, n_ctx=TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS, verbose=False) # Load the model
- text_completion_model_cache[llm_model_name] = model_instance # Cache the loaded model
- return model_instance
- except TypeError as e:
- logger.error(f"TypeError occurred while loading the model: {e}")
- raise
- except Exception as e:
- logger.error(f"Exception occurred while loading the model: {e}")
- if raise_http_exception:
- raise HTTPException(status_code=404, detail="Model file not found")
- else:
- raise FileNotFoundError(f"No model file found matching: {llm_model_name}")
-
-async def generate_completion_from_llm(request: TextCompletionRequest, req: Request = None, client_ip: str = None) -> List[TextCompletionResponse]:
- request_time = datetime.utcnow()
- logger.info(f"Starting text completion calculation using model: '{request.llm_model_name}'for input prompt: '{request.input_prompt}'")
- logger.info(f"Loading model: '{request.llm_model_name}'")
- llm = load_text_completion_model(request.llm_model_name)
- logger.info(f"Done loading model: '{request.llm_model_name}'")
- list_of_llm_outputs = []
- if request.grammar_file_string != "":
- list_of_grammar_files = glob.glob("./grammar_files/*.gbnf")
- matching_grammar_files = [x for x in list_of_grammar_files if request.grammar_file_string in x]
- if len(matching_grammar_files) == 0:
- logger.error(f"No grammar file found matching: {request.grammar_file_string}")
- raise FileNotFoundError
- matching_grammar_files.sort(key=os.path.getmtime, reverse=True) # Sort the files based on modification time (recently modified files first)
- grammar_file_path = matching_grammar_files[0]
- logger.info(f"Loading selected grammar file: '{grammar_file_path}'")
- llama_grammar = LlamaGrammar.from_file(grammar_file_path)
- for ii in range(request.number_of_completions_to_generate):
- logger.info(f"Generating completion {ii+1} of {request.number_of_completions_to_generate} with model {request.llm_model_name} for input prompt: '{request.input_prompt}'")
- output = llm(prompt=request.input_prompt, grammar=llama_grammar, max_tokens=request.number_of_tokens_to_generate, temperature=request.temperature)
- list_of_llm_outputs.append(output)
- else:
- for ii in range(request.number_of_completions_to_generate):
- output = llm(prompt=request.input_prompt, max_tokens=request.number_of_tokens_to_generate, temperature=request.temperature)
- list_of_llm_outputs.append(output)
- response_time = datetime.utcnow()
- total_time_per_completion = ((response_time - request_time).total_seconds()) / request.number_of_completions_to_generate
- list_of_responses = []
- for idx, current_completion_output in enumerate(list_of_llm_outputs):
- generated_text = current_completion_output['choices'][0]['text']
- if request.grammar_file_string == 'json':
- generated_text = generated_text.encode('unicode_escape').decode()
- llm_model_usage_json = json.dumps(current_completion_output['usage'])
- logger.info(f"Completed text completion {idx} in an average of {total_time_per_completion:.2f} seconds for input prompt: '{request.input_prompt}'; Beginning of generated text: \n'{generated_text[:100]}'")
- response = TextCompletionResponse(input_prompt = request.input_prompt,
- llm_model_name = request.llm_model_name,
- grammar_file_string = request.grammar_file_string,
- number_of_tokens_to_generate = request.number_of_tokens_to_generate,
- number_of_completions_to_generate = request.number_of_completions_to_generate,
- time_taken_in_seconds = float(total_time_per_completion),
- generated_text = generated_text,
- llm_model_usage_json = llm_model_usage_json)
- list_of_responses.append(response)
- return list_of_responses
@app.exception_handler(SQLAlchemyError)
@@ -1144,7 +78,6 @@ async def general_exception_handler(request: Request, exc: Exception) -> JSONRes
logger.exception(exc)
return JSONResponse(status_code=500, content={"message": "An unexpected error occurred"})
-#FastAPI Endpoints start here:
@app.get("/", include_in_schema=False)
async def custom_swagger_ui_html():
@@ -1164,7 +97,7 @@ async def custom_swagger_ui_html():
### Example Response:
```json
{
- "model_names": ["yarn-llama-2-7b-128k", "yarn-llama-2-13b-128k", "openchat_v3.2_super", "phind-codellama-34b-python-v1", "my_super_custom_model"]
+ "model_names": ["yarn-llama-2-7b-128k", "openchat_v3.2_super", "mistral-7b-instruct-v0.1", "my_super_custom_model"]
}
```""",
response_description="A JSON object containing the list of available model names.")
@@ -1283,15 +216,23 @@ async def get_all_stored_documents(req: Request, token: str = None) -> AllDocume
async def add_new_model(model_url: str, token: str = None) -> Dict[str, Any]:
if USE_SECURITY_TOKEN and (token is None or token != SECURITY_TOKEN):
raise HTTPException(status_code=403, detail="Unauthorized")
- decoded_model_url = unquote(model_url)
- if not decoded_model_url.endswith('.gguf'):
- return {"status": "error", "message": "Model URL must point to a .gguf file."}
- corrected_model_url = add_model_url(decoded_model_url)
- _, download_status = download_models()
- status_dict = {status["url"]: status for status in download_status}
- if corrected_model_url in status_dict:
- return {"status": status_dict[corrected_model_url]["status"], "message": status_dict[corrected_model_url]["message"]}
- return {"status": "unknown", "message": "Unexpected error."}
+ unique_id = f"add_model_{hash(model_url)}" # Generate a unique lock ID based on the model_url
+ lock = await shared_resources.lock_manager.lock(unique_id)
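+    # lock.valid is False when another worker already holds the lock for this URL; in that case we bail out below rather than starting a duplicate download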
+ if lock.valid:
+ try:
+ decoded_model_url = unquote(model_url)
+ if not decoded_model_url.endswith('.gguf'):
+ return {"status": "error", "message": "Model URL must point to a .gguf file."}
+ corrected_model_url = add_model_url(decoded_model_url)
+ _, download_status = download_models()
+ status_dict = {status["url"]: status for status in download_status}
+ if corrected_model_url in status_dict:
+ return {"status": status_dict[corrected_model_url]["status"], "message": status_dict[corrected_model_url]["message"]}
+ return {"status": "unknown", "message": "Unexpected error."}
+ finally:
+ await shared_resources.lock_manager.unlock(lock)
+ else:
+ return {"status": "already processing", "message": "Another worker is already processing this model URL."}
@@ -1594,44 +535,53 @@ async def compute_similarity_between_strings(request: SimilarityRequest, req: Re
response_description="A JSON object containing the query text along with the most similar strings and similarity scores.")
async def search_stored_embeddings_with_query_string_for_semantic_similarity(request: SemanticSearchRequest, req: Request, token: str = None) -> SemanticSearchResponse:
global faiss_indexes, token_faiss_indexes, associated_texts_by_model
- faiss_indexes, token_faiss_indexes, associated_texts_by_model = await build_faiss_indexes()
- request_time = datetime.utcnow()
- llm_model_name = request.llm_model_name
- num_results = request.number_of_most_similar_strings_to_return
- total_entries = len(associated_texts_by_model[llm_model_name]) # Get the total number of entries for the model
- num_results = min(num_results, total_entries) # Ensure num_results doesn't exceed the total number of entries
- logger.info(f"Received request to find {num_results} most similar strings for query text: `{request.query_text}` using model: {llm_model_name}")
- if USE_SECURITY_TOKEN and use_hardcoded_security_token and (token is None or token != SECURITY_TOKEN):
- raise HTTPException(status_code=403, detail="Unauthorized")
- try:
- logger.info(f"Computing embedding for input text: {request.query_text}")
- embedding_request = EmbeddingRequest(text=request.query_text, llm_model_name=request.llm_model_name)
- embedding_response = await get_embedding_vector_for_string(embedding_request, req)
- input_embedding = np.array(embedding_response["embedding"]).astype('float32').reshape(1, -1)
- faiss.normalize_L2(input_embedding) # Normalize the input vector for cosine similarity
- logger.info(f"Computed embedding for input text: {request.query_text}")
- faiss_index = faiss_indexes.get(llm_model_name) # Retrieve the correct FAISS index for the llm_model_name
- if faiss_index is None:
- raise HTTPException(status_code=400, detail=f"No FAISS index found for model: {llm_model_name}")
- logger.info("Searching for the most similar string in the FAISS index")
- similarities, indices = faiss_index.search(input_embedding.reshape(1, -1), num_results) # Search for num_results similar strings
- results = [] # Create an empty list to store the results
- for ii in range(num_results):
- similarity = float(similarities[0][ii]) # Convert numpy.float32 to native float
- most_similar_text = associated_texts_by_model[llm_model_name][indices[0][ii]]
- if most_similar_text != request.query_text: # Don't return the query text as a result
- results.append({"search_result_text": most_similar_text, "similarity_to_query_text": similarity})
- response_time = datetime.utcnow()
- total_time = (response_time - request_time).total_seconds()
- logger.info(f"Finished searching for the most similar string in the FAISS index in {total_time} seconds. Found {len(results)} results, returning the top {num_results}.")
- logger.info(f"Found most similar strings for query string {request.query_text}: {results}")
- return {"query_text": request.query_text, "results": results} # Return the response matching the SemanticSearchResponse model
- except Exception as e:
- logger.error(f"An error occurred while processing the request: {e}")
- logger.error(traceback.format_exc()) # Print the traceback
- raise HTTPException(status_code=500, detail="Internal Server Error")
-
-
+ unique_id = f"semantic_search_{request.query_text}_{request.llm_model_name}_{request.number_of_most_similar_strings_to_return}" # Unique ID for this operation
+ lock = await shared_resources.lock_manager.lock(unique_id)
+ if lock.valid:
+ try:
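+            # Refresh the FAISS indexes first so embeddings stored since the last build are searchable; build_faiss_indexes is assumed to reuse cached indexes when nothing has changed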
+ faiss_indexes, token_faiss_indexes, associated_texts_by_model = await build_faiss_indexes()
+ request_time = datetime.utcnow()
+ llm_model_name = request.llm_model_name
+ num_results = request.number_of_most_similar_strings_to_return
+ total_entries = len(associated_texts_by_model[llm_model_name]) # Get the total number of entries for the model
+ num_results = min(num_results, total_entries) # Ensure num_results doesn't exceed the total number of entries
+ logger.info(f"Received request to find {num_results} most similar strings for query text: `{request.query_text}` using model: {llm_model_name}")
+ if USE_SECURITY_TOKEN and use_hardcoded_security_token and (token is None or token != SECURITY_TOKEN):
+ raise HTTPException(status_code=403, detail="Unauthorized")
+ try:
+ logger.info(f"Computing embedding for input text: {request.query_text}")
+ embedding_request = EmbeddingRequest(text=request.query_text, llm_model_name=request.llm_model_name)
+ embedding_response = await get_embedding_vector_for_string(embedding_request, req)
+ input_embedding = np.array(embedding_response["embedding"]).astype('float32').reshape(1, -1)
+ faiss.normalize_L2(input_embedding) # Normalize the input vector for cosine similarity
+ logger.info(f"Computed embedding for input text: {request.query_text}")
+ faiss_index = faiss_indexes.get(llm_model_name) # Retrieve the correct FAISS index for the llm_model_name
+ if faiss_index is None:
+ raise HTTPException(status_code=400, detail=f"No FAISS index found for model: {llm_model_name}")
+ logger.info("Searching for the most similar string in the FAISS index")
+ similarities, indices = faiss_index.search(input_embedding.reshape(1, -1), num_results) # Search for num_results similar strings
+ results = [] # Create an empty list to store the results
+ for ii in range(num_results):
+ similarity = float(similarities[0][ii]) # Convert numpy.float32 to native float
+ most_similar_text = associated_texts_by_model[llm_model_name][indices[0][ii]]
+ if most_similar_text != request.query_text: # Don't return the query text as a result
+ results.append({"search_result_text": most_similar_text, "similarity_to_query_text": similarity})
+ response_time = datetime.utcnow()
+ total_time = (response_time - request_time).total_seconds()
+ logger.info(f"Finished searching for the most similar string in the FAISS index in {total_time} seconds. Found {len(results)} results, returning the top {num_results}.")
+ logger.info(f"Found most similar strings for query string {request.query_text}: {results}")
+ return {"query_text": request.query_text, "results": results} # Return the response matching the SemanticSearchResponse model
+            except HTTPException:
+                raise # Propagate deliberate HTTP errors (e.g., the 400 above for a missing FAISS index) instead of remapping them to 500 below
+            except Exception as e:
+ logger.error(f"An error occurred while processing the request: {e}")
+ logger.error(traceback.format_exc()) # Print the traceback
+ raise HTTPException(status_code=500, detail="Internal Server Error")
+ finally:
+ await shared_resources.lock_manager.unlock(lock)
+    else:
+        raise HTTPException(status_code=429, detail="Another worker is already processing an identical search; please retry shortly.") # A bare status dict would fail validation against the SemanticSearchResponse response model
+
+
@app.post("/advanced_search_stored_embeddings_with_query_string_for_semantic_similarity/",
response_model=AdvancedSemanticSearchResponse,
summary="Advanced Semantic Search with Two-Step Similarity Measures",
@@ -1676,52 +626,60 @@ async def search_stored_embeddings_with_query_string_for_semantic_similarity(req
response_description="A JSON object containing the query text and the most similar strings, along with their similarity scores for multiple measures.")
async def advanced_search_stored_embeddings_with_query_string_for_semantic_similarity(request: AdvancedSemanticSearchRequest, req: Request, token: str = None) -> AdvancedSemanticSearchResponse:
global faiss_indexes, token_faiss_indexes, associated_texts_by_model
- faiss_indexes, token_faiss_indexes, associated_texts_by_model = await build_faiss_indexes()
- request_time = datetime.utcnow()
- llm_model_name = request.llm_model_name
- total_entries = len(associated_texts_by_model[llm_model_name])
- num_results = max([1, int((1 - request.similarity_filter_percentage) * total_entries)])
- logger.info(f"Received request to find {num_results} most similar strings for query text: `{request.query_text}` using model: {llm_model_name}")
- if USE_SECURITY_TOKEN and use_hardcoded_security_token and (token is None or token != SECURITY_TOKEN):
- raise HTTPException(status_code=403, detail="Unauthorized")
- try:
- logger.info(f"Computing embedding for input text: {request.query_text}")
- embedding_request = EmbeddingRequest(text=request.query_text, llm_model_name=llm_model_name)
- embedding_response = await get_embedding_vector_for_string(embedding_request, req)
- input_embedding = np.array(embedding_response["embedding"]).astype('float32').reshape(1, -1)
- faiss.normalize_L2(input_embedding)
- logger.info(f"Computed embedding for input text: {request.query_text}")
- faiss_index = faiss_indexes.get(llm_model_name)
- if faiss_index is None:
- raise HTTPException(status_code=400, detail=f"No FAISS index found for model: {llm_model_name}")
- _, indices = faiss_index.search(input_embedding, num_results)
- filtered_indices = indices[0]
- similarity_results = []
- for idx in filtered_indices:
- associated_text = associated_texts_by_model[llm_model_name][idx]
- embedding_request = EmbeddingRequest(text=associated_text, llm_model_name=llm_model_name)
- embedding_response = await get_embedding_vector_for_string(embedding_request, req)
- filtered_embedding = np.array(embedding_response["embedding"])
- params = {
- "vector_1": input_embedding.tolist()[0],
- "vector_2": filtered_embedding.tolist(),
- "similarity_measure": "all"
- }
- similarity_stats_str = fvs.py_compute_vector_similarity_stats(json.dumps(params))
- similarity_stats_json = json.loads(similarity_stats_str)
- similarity_results.append({
- "search_result_text": associated_text,
- "similarity_to_query_text": similarity_stats_json
- })
- num_to_return = request.number_of_most_similar_strings_to_return if request.number_of_most_similar_strings_to_return is not None else len(similarity_results)
- results = sorted(similarity_results, key=lambda x: x["similarity_to_query_text"]["hoeffding_d"], reverse=True)[:num_to_return]
- response_time = datetime.utcnow()
- total_time = (response_time - request_time).total_seconds()
- logger.info(f"Finished advanced search in {total_time} seconds. Found {len(results)} results.")
- return {"query_text": request.query_text, "results": results}
- except Exception as e:
- logger.error(f"An error occurred while processing the request: {e}")
- raise HTTPException(status_code=500, detail="Internal Server Error")
+ unique_id = f"advanced_semantic_search_{request.query_text}_{request.llm_model_name}_{request.similarity_filter_percentage}_{request.number_of_most_similar_strings_to_return}"
+ lock = await shared_resources.lock_manager.lock(unique_id)
+ if lock.valid:
+ try:
+ faiss_indexes, token_faiss_indexes, associated_texts_by_model = await build_faiss_indexes()
+ request_time = datetime.utcnow()
+ llm_model_name = request.llm_model_name
+ total_entries = len(associated_texts_by_model[llm_model_name])
+ num_results = max([1, int((1 - request.similarity_filter_percentage) * total_entries)])
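+            # E.g., with similarity_filter_percentage=0.98 and 5000 stored entries, the coarse FAISS pass below keeps the top 100 candidates for the exact similarity measures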
+ logger.info(f"Received request to find {num_results} most similar strings for query text: `{request.query_text}` using model: {llm_model_name}")
+ if USE_SECURITY_TOKEN and use_hardcoded_security_token and (token is None or token != SECURITY_TOKEN):
+ raise HTTPException(status_code=403, detail="Unauthorized")
+ try:
+ logger.info(f"Computing embedding for input text: {request.query_text}")
+ embedding_request = EmbeddingRequest(text=request.query_text, llm_model_name=llm_model_name)
+ embedding_response = await get_embedding_vector_for_string(embedding_request, req)
+ input_embedding = np.array(embedding_response["embedding"]).astype('float32').reshape(1, -1)
+ faiss.normalize_L2(input_embedding)
+ logger.info(f"Computed embedding for input text: {request.query_text}")
+ faiss_index = faiss_indexes.get(llm_model_name)
+ if faiss_index is None:
+ raise HTTPException(status_code=400, detail=f"No FAISS index found for model: {llm_model_name}")
+ _, indices = faiss_index.search(input_embedding, num_results)
+ filtered_indices = indices[0]
+ similarity_results = []
+ for idx in filtered_indices:
+ associated_text = associated_texts_by_model[llm_model_name][idx]
+ embedding_request = EmbeddingRequest(text=associated_text, llm_model_name=llm_model_name)
+ embedding_response = await get_embedding_vector_for_string(embedding_request, req)
+ filtered_embedding = np.array(embedding_response["embedding"])
+ params = {
+ "vector_1": input_embedding.tolist()[0],
+ "vector_2": filtered_embedding.tolist(),
+ "similarity_measure": "all"
+ }
+ similarity_stats_str = fvs.py_compute_vector_similarity_stats(json.dumps(params))
+ similarity_stats_json = json.loads(similarity_stats_str)
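+                    # similarity_stats_json holds one score per measure computed by fast_vector_similarity; only hoeffding_d drives the final ranking below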
+ similarity_results.append({
+ "search_result_text": associated_text,
+ "similarity_to_query_text": similarity_stats_json
+ })
+ num_to_return = request.number_of_most_similar_strings_to_return if request.number_of_most_similar_strings_to_return is not None else len(similarity_results)
+ results = sorted(similarity_results, key=lambda x: x["similarity_to_query_text"]["hoeffding_d"], reverse=True)[:num_to_return]
+ response_time = datetime.utcnow()
+ total_time = (response_time - request_time).total_seconds()
+ logger.info(f"Finished advanced search in {total_time} seconds. Found {len(results)} results.")
+ return {"query_text": request.query_text, "results": results}
+            except HTTPException:
+                raise # Preserve deliberate HTTP errors (e.g., the 400 above for a missing FAISS index) rather than converting them to 500s
+            except Exception as e:
+ logger.error(f"An error occurred while processing the request: {e}")
+ raise HTTPException(status_code=500, detail="Internal Server Error")
+ finally:
+ await shared_resources.lock_manager.unlock(lock)
+    else:
+        raise HTTPException(status_code=429, detail="Another worker is already processing an identical search; please retry shortly.") # A bare status dict would fail validation against the AdvancedSemanticSearchResponse response model
@@ -1775,41 +733,50 @@ async def get_all_embedding_vectors_for_document(file: UploadFile = File(...),
hash_obj.update(chunk)
file_hash = hash_obj.hexdigest()
logger.info(f"SHA3-256 hash of submitted file: {file_hash}")
- async with AsyncSessionLocal() as session: # Check if the document has been processed before
- result = await session.execute(select(DocumentEmbedding).filter(DocumentEmbedding.file_hash == file_hash, DocumentEmbedding.llm_model_name == llm_model_name))
- existing_document_embedding = result.scalar_one_or_none()
- if existing_document_embedding: # If the document has been processed before, return the existing result
- logger.info(f"Document {file.filename} has been processed before, returning existing result")
- json_content = json.dumps(existing_document_embedding.document_embedding_results_json).encode()
- else: # If the document has not been processed, continue processing
- mime = Magic(mime=True)
- mime_type = mime.from_file(temp_file_path)
- logger.info(f"Received request to extract embeddings for document {file.filename} with MIME type: {mime_type} and size: {os.path.getsize(temp_file_path)} bytes from IP address: {client_ip}")
- strings = await parse_submitted_document_file_into_sentence_strings_func(temp_file_path, mime_type)
- results = await compute_embeddings_for_document(strings, llm_model_name, client_ip, file_hash) # Compute the embeddings and json_content for new documents
- df = pd.DataFrame(results, columns=['text', 'embedding'])
- json_content = df.to_json(orient=json_format or 'records').encode()
- with open(temp_file_path, 'rb') as file_buffer: # Store the results in the database
- original_file_content = file_buffer.read()
- await store_document_embeddings_in_db(file, file_hash, original_file_content, json_content, results, llm_model_name, client_ip, request_time)
- overall_total_time = (datetime.utcnow() - request_time).total_seconds()
- logger.info(f"Done getting all embeddings for document {file.filename} containing {len(strings)} with model {llm_model_name}")
- json_content_length = len(json_content)
- if len(json_content) > 0:
- logger.info(f"The response took {overall_total_time} seconds to generate, or {overall_total_time / (len(strings)/1000.0)} seconds per thousand input tokens and {overall_total_time / (float(json_content_length)/1000000.0)} seconds per million output characters.")
- if send_back_json_or_zip_file == 'json': # Assume 'json' response should be sent back
- logger.info(f"Returning JSON response for document {file.filename} containing {len(strings)} with model {llm_model_name}; first 100 characters out of {json_content_length} total of JSON response: {json_content[:100]}")
- return JSONResponse(content=json.loads(json_content.decode())) # Decode the content and parse it as JSON
- else: # Assume 'zip' file should be sent back
- original_filename_without_extension, _ = os.path.splitext(file.filename)
- json_file_path = f"/tmp/{original_filename_without_extension}.json"
- with open(json_file_path, 'wb') as json_file: # Write the JSON content as bytes
- json_file.write(json_content)
- zip_file_path = f"/tmp/{original_filename_without_extension}.zip"
- with zipfile.ZipFile(zip_file_path, 'w') as zipf:
- zipf.write(json_file_path, os.path.basename(json_file_path))
- logger.info(f"Returning ZIP response for document {file.filename} containing {len(strings)} with model {llm_model_name}; first 100 characters out of {json_content_length} total of JSON response: {json_content[:100]}")
- return FileResponse(zip_file_path, headers={"Content-Disposition": f"attachment; filename={original_filename_without_extension}.zip"})
+ unique_id = f"document_embedding_{file_hash}_{llm_model_name}"
+ lock = await shared_resources.lock_manager.lock(unique_id)
+ if lock.valid:
+        try:
+            strings = [] # Stays empty when a previously computed result is returned, so the logging below must not assume it is populated
+            async with AsyncSessionLocal() as session: # Check if the document has been processed before
+ result = await session.execute(select(DocumentEmbedding).filter(DocumentEmbedding.file_hash == file_hash, DocumentEmbedding.llm_model_name == llm_model_name))
+ existing_document_embedding = result.scalar_one_or_none()
+ if existing_document_embedding: # If the document has been processed before, return the existing result
+ logger.info(f"Document {file.filename} has been processed before, returning existing result")
+ json_content = json.dumps(existing_document_embedding.document_embedding_results_json).encode()
+ else: # If the document has not been processed, continue processing
+ mime = Magic(mime=True)
+ mime_type = mime.from_file(temp_file_path)
+ logger.info(f"Received request to extract embeddings for document {file.filename} with MIME type: {mime_type} and size: {os.path.getsize(temp_file_path)} bytes from IP address: {client_ip}")
+ strings = await parse_submitted_document_file_into_sentence_strings_func(temp_file_path, mime_type)
+ results = await compute_embeddings_for_document(strings, llm_model_name, client_ip, file_hash) # Compute the embeddings and json_content for new documents
+ df = pd.DataFrame(results, columns=['text', 'embedding'])
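+                    # json_format is presumably a query parameter on this endpoint; it defaults to 'records', i.e., a JSON list of {"text": ..., "embedding": ...} objects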
+ json_content = df.to_json(orient=json_format or 'records').encode()
+ with open(temp_file_path, 'rb') as file_buffer: # Store the results in the database
+ original_file_content = file_buffer.read()
+ await store_document_embeddings_in_db(file, file_hash, original_file_content, json_content, results, llm_model_name, client_ip, request_time)
+ overall_total_time = (datetime.utcnow() - request_time).total_seconds()
+ logger.info(f"Done getting all embeddings for document {file.filename} containing {len(strings)} with model {llm_model_name}")
+ json_content_length = len(json_content)
+ if len(json_content) > 0:
+ logger.info(f"The response took {overall_total_time} seconds to generate, or {overall_total_time / (len(strings)/1000.0)} seconds per thousand input tokens and {overall_total_time / (float(json_content_length)/1000000.0)} seconds per million output characters.")
+ if send_back_json_or_zip_file == 'json': # Assume 'json' response should be sent back
+ logger.info(f"Returning JSON response for document {file.filename} containing {len(strings)} with model {llm_model_name}; first 100 characters out of {json_content_length} total of JSON response: {json_content[:100]}")
+ return JSONResponse(content=json.loads(json_content.decode())) # Decode the content and parse it as JSON
+ else: # Assume 'zip' file should be sent back
+ original_filename_without_extension, _ = os.path.splitext(file.filename)
+ json_file_path = f"/tmp/{original_filename_without_extension}.json"
+ with open(json_file_path, 'wb') as json_file: # Write the JSON content as bytes
+ json_file.write(json_content)
+ zip_file_path = f"/tmp/{original_filename_without_extension}.zip"
+ with zipfile.ZipFile(zip_file_path, 'w') as zipf:
+ zipf.write(json_file_path, os.path.basename(json_file_path))
+ logger.info(f"Returning ZIP response for document {file.filename} containing {len(strings)} with model {llm_model_name}; first 100 characters out of {json_content_length} total of JSON response: {json_content[:100]}")
+ return FileResponse(zip_file_path, headers={"Content-Disposition": f"attachment; filename={original_filename_without_extension}.zip"})
+ finally:
+ await shared_resources.lock_manager.unlock(lock)
+    else:
+        return {"status": "already processing", "message": "Another worker is already processing this document with the same model; please retry shortly."}
+
@app.post("/get_text_completions_from_input_prompt/",
@@ -1833,7 +800,7 @@ async def get_all_embedding_vectors_for_document(file: UploadFile = File(...),
```json
{
"input_prompt": "The Kings of France in the 17th Century:",
- "llm_model_name": "phind-codellama-34b-python-v1",
+ "llm_model_name": "mistral-7b-instruct-v0.1",
"temperature": 0.95,
"grammar_file_string": "json",
"number_of_tokens_to_generate": 500,
@@ -1849,7 +816,7 @@ async def get_all_embedding_vectors_for_document(file: UploadFile = File(...),
[
{
"input_prompt": "The Kings of France in the 17th Century:",
- "llm_model_name": "phind-codellama-34b-python-v1",
+ "llm_model_name": "mistral-7b-instruct-v0.1",
"grammar_file_string": "json",
"number_of_tokens_to_generate": 500,
"number_of_completions_to_generate": 3,
@@ -1859,7 +826,7 @@ async def get_all_embedding_vectors_for_document(file: UploadFile = File(...),
},
{
"input_prompt": "The Kings of France in the 17th Century:",
- "llm_model_name": "phind-codellama-34b-python-v1",
+ "llm_model_name": "mistral-7b-instruct-v0.1",
"grammar_file_string": "json",
"number_of_tokens_to_generate": 500,
"number_of_completions_to_generate": 3,
@@ -1869,7 +836,7 @@ async def get_all_embedding_vectors_for_document(file: UploadFile = File(...),
},
{
"input_prompt": "The Kings of France in the 17th Century:",
- "llm_model_name": "phind-codellama-34b-python-v1",
+ "llm_model_name": "mistral-7b-instruct-v0.1",
"grammar_file_string": "json",
"number_of_tokens_to_generate": 500,
"number_of_completions_to_generate": 3,
@@ -1884,7 +851,15 @@ async def get_text_completions_from_input_prompt(request: TextCompletionRequest,
logger.warning(f"Unauthorized request from client IP {client_ip}")
raise HTTPException(status_code=403, detail="Unauthorized")
try:
- return await generate_completion_from_llm(request, req, client_ip)
+ unique_id = f"text_completion_{hash(request.input_prompt)}_{request.llm_model_name}"
+ lock = await shared_resources.lock_manager.lock(unique_id)
+ if lock.valid:
+ try:
+ return await generate_completion_from_llm(request, req, client_ip)
+ finally:
+ await shared_resources.lock_manager.unlock(lock)
+ else:
+ return {"status": "already processing"}
except Exception as e:
logger.error(f"An error occurred while processing the request: {e}")
logger.error(traceback.format_exc()) # Print the traceback
@@ -1936,24 +911,7 @@ async def clear_ramdisk_endpoint(token: str = None):
@app.on_event("startup")
async def startup_event():
- global db_writer, faiss_indexes, token_faiss_indexes, associated_texts_by_model
- await initialize_db()
- queue = asyncio.Queue()
- db_writer = DatabaseWriter(queue)
- await db_writer.initialize_processing_hashes()
- asyncio.create_task(db_writer.dedicated_db_writer())
- global USE_RAMDISK
- if USE_RAMDISK and not check_that_user_has_required_permissions_to_manage_ramdisks():
- USE_RAMDISK = False
- elif USE_RAMDISK:
- setup_ramdisk()
- list_of_downloaded_model_names, download_status = download_models()
- for llm_model_name in list_of_downloaded_model_names:
- try:
- load_model(llm_model_name, raise_http_exception=False)
- except FileNotFoundError as e:
- logger.error(e)
- faiss_indexes, token_faiss_indexes, associated_texts_by_model = await build_faiss_indexes()
+ await initialize_globals()
@app.get("/download/{file_name}")
@@ -1985,4 +943,4 @@ def show_logs_default():
if __name__ == "__main__":
- uvicorn.run(app, host="0.0.0.0", port=SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT)
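+    # With workers > 1, Uvicorn needs the app as an import string ("module:attribute") rather than an app object; "option" is presumably imported from the new uvicorn_config module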
+ uvicorn.run("swiss_army_llama:app", **option)
diff --git a/uvicorn_config.py b/uvicorn_config.py
new file mode 100644
index 0000000..ac4f89f
--- /dev/null
+++ b/uvicorn_config.py
@@ -0,0 +1,9 @@
+from decouple import config
+SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT = config("SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT", default=8089, cast=int)
+UVICORN_NUMBER_OF_WORKERS = config("UVICORN_NUMBER_OF_WORKERS", default=3, cast=int)
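+# Each worker is a separate process with its own in-memory model cache and no shared state, which is why the endpoints coordinate through the distributed lock manager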
+
+option = {
+ "host": "0.0.0.0",
+ "port": SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT,
+ "workers": UVICORN_NUMBER_OF_WORKERS
+}