From deeaf8a0ee5cfc417a2808ae1470ae82d917caaa Mon Sep 17 00:00:00 2001 From: JE Date: Mon, 2 Oct 2023 18:23:51 -0500 Subject: [PATCH] Big refactoring, splitting up the large `swiss_army_llama.py` file into several new code files: - `database_functions.py` - `log_viewer_functions.py` - `misc_utility_functions.py` - `ramdisk_functions.py` - `service_functions.py` - `shared_resources.py` - `uvicorn_config.py` Also, introduction of Redis-based locks to enable multiple uvicorn workers to run without stepping on each other's toes. This allows multiple concurrent clients to be served, and also allows the log viewer to be used while the service is running, without waiting. Fixes to the log viewing code. Added uvloop to requirements.txt to speed up asyncio. --- .env | 1 + README.md | 2 + database_functions.py | 162 ++ .../llama_knife_sticker.webp | Bin .../llama_knife_sticker2.jpg | Bin .../swiss_army_llama__swagger_screenshot.png | Bin ...army_llama__swagger_screenshot_running.png | Bin .../swiss_army_llama_logo.webp | Bin log_viewer_functions.py | 20 +- logger_config.py | 35 + misc_utility_functions.py | 215 +++ model_urls.json | 1 + ramdisk_functions.py | 79 + requirements.txt | 3 + service_functions.py | 614 +++++++ shared_resources.py | 149 ++ swiss_army_llama.py | 1438 +++-------------- uvicorn_config.py | 9 + 18 files changed, 1478 insertions(+), 1250 deletions(-) create mode 100644 database_functions.py rename llama_knife_sticker.webp => image_files/llama_knife_sticker.webp (100%) rename llama_knife_sticker2.jpg => image_files/llama_knife_sticker2.jpg (100%) rename swiss_army_llama__swagger_screenshot.png => image_files/swiss_army_llama__swagger_screenshot.png (100%) rename swiss_army_llama__swagger_screenshot_running.png => image_files/swiss_army_llama__swagger_screenshot_running.png (100%) rename swiss_army_llama_logo.webp => image_files/swiss_army_llama_logo.webp (100%) create mode 100644 logger_config.py create mode 100644 misc_utility_functions.py create mode 100644 model_urls.json create mode 100644 ramdisk_functions.py create mode 100644 service_functions.py create mode 100644 shared_resources.py create mode 100644 uvicorn_config.py diff --git a/.env b/.env index bbbf0ef..b043103 100644 --- a/.env +++ b/.env @@ -8,6 +8,7 @@ DEFAULT_MAX_COMPLETION_TOKENS=1000 DEFAULT_NUMBER_OF_COMPLETIONS_TO_GENERATE =1 DEFAULT_COMPLETION_TEMPERATURE=0.7 LLAMA_EMBEDDING_SERVER_LISTEN_PORT=8089 +UVICORN_NUMBER_OF_WORKERS=2 MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING=15 MAX_RETRIES=10 DB_WRITE_BATCH_SIZE=25 diff --git a/README.md b/README.md index 0064201..5ea7912 100644 --- a/README.md +++ b/README.md @@ -109,6 +109,7 @@ fast_vector_similarity faster-whisper textract pytz +uvloop ``` ## Running the Application @@ -137,6 +138,7 @@ You can configure the service easily by editing the included `.env` file. Here's - `DEFAULT_MODEL_NAME`: Default model name to use. (e.g., `yarn-llama-2-13b-128k`) - `LLM_CONTEXT_SIZE_IN_TOKENS`: Context size in tokens for LLM. (e.g., `512`) - `SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT`: Port number for the service. (e.g., `8089`) +- `UVICORN_NUMBER_OF_WORKERS`: Number of workers for Uvicorn. (e.g., `2`) - `MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING`: Minimum string length for document embedding. (e.g., `15`) - `MAX_RETRIES`: Maximum retries for locked database. (e.g., `10`) - `DB_WRITE_BATCH_SIZE`: Database write batch size. (e.g., `25`)
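The Redis-based locking mentioned in the commit message is provided by `aioredlock` (added to `requirements.txt` below) and used via `shared_resources.lock_manager`, e.g. in `get_or_compute_transcript` in `service_functions.py`. A minimal sketch of the pattern, assuming a local Redis instance on the default port (the helper name `compute_once` is illustrative, not code from this patch):

    from aioredlock import Aioredlock

    lock_manager = Aioredlock(["redis://localhost:6379"])

    async def compute_once(unique_id: str):
        lock = await lock_manager.lock(unique_id)  # e.g. f"transcript_{audio_file_hash}_{llm_model_name}"
        if not lock.valid:
            return {"status": "already processing"}  # another uvicorn worker holds the lock
        try:
            pass  # compute and store the expensive result here, exactly once across workers
        finally:
            await lock_manager.unlock(lock)

Because all workers contend on the same per-resource key, only one of them performs a given expensive computation while the others either wait or return early, which is what lets several uvicorn workers share the single SQLite database safely.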
diff --git a/database_functions.py b/database_functions.py new file mode 100644 index 0000000..2b5a52f --- /dev/null +++ b/database_functions.py @@ -0,0 +1,162 @@ +from embeddings_data_models import Base, TextEmbedding, DocumentEmbedding, Document, TokenLevelEmbedding, TokenLevelEmbeddingBundle, TokenLevelEmbeddingBundleCombinedFeatureVector, AudioTranscript +from logger_config import setup_logger +import traceback +import asyncio +import random +from sqlalchemy import select +from sqlalchemy import text as sql_text +from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker +from datetime import datetime +from decouple import config + +logger = setup_logger() +db_writer = None +DATABASE_URL = "sqlite+aiosqlite:///swiss_army_llama.sqlite" +MAX_RETRIES = config("MAX_RETRIES", default=3, cast=int) +DB_WRITE_BATCH_SIZE = config("DB_WRITE_BATCH_SIZE", default=25, cast=int) +RETRY_DELAY_BASE_SECONDS = config("RETRY_DELAY_BASE_SECONDS", default=1, cast=int) +JITTER_FACTOR = config("JITTER_FACTOR", default=0.1, cast=float) + +engine = create_async_engine(DATABASE_URL, echo=False, connect_args={"check_same_thread": False}) +AsyncSessionLocal = sessionmaker( + bind=engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False +) +class DatabaseWriter: + def __init__(self, queue): + self.queue = queue + self.processing_hashes = set() # Set to store the hashes of everything that is currently being processed in the queue (to avoid duplicates of the same task being added to the queue) + + def _get_hash_from_operation(self, operation): + attr_name = { + TextEmbedding: 'text_hash', + DocumentEmbedding: 'file_hash', + Document: 'document_hash', + TokenLevelEmbedding: 'token_hash', + TokenLevelEmbeddingBundle: 'input_text_hash', + TokenLevelEmbeddingBundleCombinedFeatureVector: 'combined_feature_vector_hash', + AudioTranscript: 'audio_file_hash' + }.get(type(operation)) + hash_value = getattr(operation, attr_name, None) + llm_model_name = getattr(operation, 'llm_model_name', None) + return f"{hash_value}_{llm_model_name}" if hash_value and llm_model_name else None + + async def initialize_processing_hashes(self, chunk_size=1000): + start_time = datetime.utcnow() + async with AsyncSessionLocal() as session: + queries = [ + (select(TextEmbedding.text_hash, TextEmbedding.llm_model_name), True), + (select(DocumentEmbedding.file_hash, DocumentEmbedding.llm_model_name), True), + (select(Document.document_hash, Document.llm_model_name), True), + (select(TokenLevelEmbedding.token_hash, TokenLevelEmbedding.llm_model_name), True), + (select(TokenLevelEmbeddingBundle.input_text_hash, TokenLevelEmbeddingBundle.llm_model_name), True), + (select(TokenLevelEmbeddingBundleCombinedFeatureVector.combined_feature_vector_hash, TokenLevelEmbeddingBundleCombinedFeatureVector.llm_model_name), True), + (select(AudioTranscript.audio_file_hash), False) + ] + for query, has_llm in queries: + offset = 0 + while True: + result = await session.execute(query.limit(chunk_size).offset(offset)) + rows = result.fetchall() + if not rows: + break + for row in rows: + if has_llm: + hash_with_model = f"{row[0]}_{row[1]}" + else: + hash_with_model = row[0] + self.processing_hashes.add(hash_with_model) + offset += chunk_size + end_time = datetime.utcnow() + total_time = (end_time - start_time).total_seconds() + if len(self.processing_hashes) > 0: + logger.info(f"Finished initializing set of 
input hash/llm_model_name combinations that are either currently being processed or have already been processed. Set size: {len(self.processing_hashes)}; Took {total_time} seconds, for an average of {total_time / len(self.processing_hashes)} seconds per hash.") + + async def _handle_integrity_error(self, e, write_operation, session): + unique_constraint_msg = { + TextEmbedding: "token_embeddings.token_hash, token_embeddings.llm_model_name", + DocumentEmbedding: "document_embeddings.file_hash, document_embeddings.llm_model_name", + Document: "documents.document_hash, documents.llm_model_name", + TokenLevelEmbedding: "token_level_embeddings.token_hash, token_level_embeddings.llm_model_name", + TokenLevelEmbeddingBundle: "token_level_embedding_bundles.input_text_hash, token_level_embedding_bundles.llm_model_name", + AudioTranscript: "audio_transcripts.audio_file_hash" + }.get(type(write_operation)) + if unique_constraint_msg and unique_constraint_msg in str(e): + logger.warning(f"Embedding already exists in the database for given input and llm_model_name: {e}") + await session.rollback() + else: + raise + + async def dedicated_db_writer(self): + while True: + write_operations_batch = await self.queue.get() + async with AsyncSessionLocal() as session: + try: + for write_operation in write_operations_batch: + session.add(write_operation) + await session.flush() # Flush to get the IDs + await session.commit() + for write_operation in write_operations_batch: + hash_to_remove = self._get_hash_from_operation(write_operation) + if hash_to_remove is not None and hash_to_remove in self.processing_hashes: + self.processing_hashes.remove(hash_to_remove) + except IntegrityError as e: + await self._handle_integrity_error(e, write_operation, session) + except SQLAlchemyError as e: + logger.error(f"Database error: {e}") + await session.rollback() + except Exception as e: + tb = traceback.format_exc() + logger.error(f"Unexpected error: {e}\n{tb}") + await session.rollback() + self.queue.task_done() + + async def enqueue_write(self, write_operations): + write_operations = [op for op in write_operations if self._get_hash_from_operation(op) not in self.processing_hashes] # Filter out write operations for hashes that are already being processed + if not write_operations: # If there are no write operations left after filtering, return early + return + for op in write_operations: # Add the hashes of the write operations to the set + hash_value = self._get_hash_from_operation(op) + if hash_value: + self.processing_hashes.add(hash_value) + await self.queue.put(write_operations) + + +async def execute_with_retry(func, *args, **kwargs): + retries = 0 + while retries < MAX_RETRIES: + try: + return await func(*args, **kwargs) + except OperationalError as e: + if 'database is locked' in str(e): + retries += 1 + sleep_time = RETRY_DELAY_BASE_SECONDS * (2 ** retries) + (random.random() * JITTER_FACTOR) # Implementing exponential backoff with jitter + logger.warning(f"Database is locked. Retrying ({retries}/{MAX_RETRIES})... 
Waiting for {sleep_time} seconds") + await asyncio.sleep(sleep_time) + else: + raise + raise OperationalError("Database is locked after multiple retries") + +async def initialize_db(): + logger.info("Initializing database, creating tables, and setting SQLite PRAGMAs...") + list_of_sqlite_pragma_strings = ["PRAGMA journal_mode=WAL;", "PRAGMA synchronous = NORMAL;", "PRAGMA cache_size = -1048576;", "PRAGMA busy_timeout = 2000;", "PRAGMA wal_autocheckpoint = 100;"] + list_of_sqlite_pragma_justification_strings = ["Set SQLite to use Write-Ahead Logging (WAL) mode (from default DELETE mode) so that reads and writes can occur simultaneously", + "Set synchronous mode to NORMAL (from FULL) so that writes are not blocked by reads", + "Set cache size to 1GB (from default 2MB) so that more data can be cached in memory and not read from disk; to make this 256MB, set it to -262144 instead", + "Increase the busy timeout to 2 seconds so that the database waits", + "Set the WAL autocheckpoint to 100 (from default 1000) so that the WAL file is checkpointed more frequently"] + assert(len(list_of_sqlite_pragma_strings) == len(list_of_sqlite_pragma_justification_strings)) + async with engine.begin() as conn: + for pragma_string in list_of_sqlite_pragma_strings: + await conn.execute(sql_text(pragma_string)) + logger.info(f"Executed SQLite PRAGMA: {pragma_string}") + logger.info(f"Justification: {list_of_sqlite_pragma_justification_strings[list_of_sqlite_pragma_strings.index(pragma_string)]}") + await conn.run_sync(Base.metadata.create_all) # Create tables if they don't exist + logger.info("Database initialization completed.") + +def get_db_writer() -> DatabaseWriter: + return db_writer # Return the existing DatabaseWriter instance diff --git a/llama_knife_sticker.webp b/image_files/llama_knife_sticker.webp similarity index 100% rename from llama_knife_sticker.webp rename to image_files/llama_knife_sticker.webp diff --git a/llama_knife_sticker2.jpg b/image_files/llama_knife_sticker2.jpg similarity index 100% rename from llama_knife_sticker2.jpg rename to image_files/llama_knife_sticker2.jpg diff --git a/swiss_army_llama__swagger_screenshot.png b/image_files/swiss_army_llama__swagger_screenshot.png similarity index 100% rename from swiss_army_llama__swagger_screenshot.png rename to image_files/swiss_army_llama__swagger_screenshot.png diff --git a/swiss_army_llama__swagger_screenshot_running.png b/image_files/swiss_army_llama__swagger_screenshot_running.png similarity index 100% rename from swiss_army_llama__swagger_screenshot_running.png rename to image_files/swiss_army_llama__swagger_screenshot_running.png diff --git a/swiss_army_llama_logo.webp b/image_files/swiss_army_llama_logo.webp similarity index 100% rename from swiss_army_llama_logo.webp rename to image_files/swiss_army_llama_logo.webp diff --git a/log_viewer_functions.py b/log_viewer_functions.py index ac2c787..d271560 100644 --- a/log_viewer_functions.py +++ b/log_viewer_functions.py @@ -48,7 +48,6 @@ def highlight_rules_func(text): text = text.replace('#COLOR13_OPEN#', '').replace('#COLOR13_CLOSE#', '') return text - def show_logs_incremental_func(minutes: int, last_position: int): new_logs = [] now = datetime.now(timezone('UTC')) # get current time, make it timezone-aware @@ -75,21 +74,22 @@ def show_logs_incremental_func(minutes: int, last_position: int): def show_logs_func(minutes: int = 5): - # read the entire log file and generate HTML with logs up to `minutes` minutes from now with open(log_file_path, "r") as f: lines = f.readlines() logs = [] 
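# Illustrative aside, not part of the patch: a worked example of the backoff schedule in
# execute_with_retry above. With the defaults RETRY_DELAY_BASE_SECONDS=1 and JITTER_FACTOR=0.1
# from database_functions.py, the waits before each retry of a locked database grow as
# 1 * 2**1 = 2s, 1 * 2**2 = 4s, 1 * 2**3 = 8s, and so on, each plus up to 0.1s of random jitter:
#
#     import random
#     for retries in range(1, 4):
#         sleep_time = 1 * (2 ** retries) + (random.random() * 0.1)
#         print(f"retry {retries}: sleeping ~{sleep_time:.2f}s")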
-    now = datetime.now(timezone('UTC')) # get current time, make it timezone-aware +    now = datetime.now(timezone('UTC')) for line in lines: if line.strip() == "": continue -        if line[0].isdigit(): -            log_datetime_str = line.split(" - ")[0] # assuming the datetime is at the start of each line -            log_datetime = datetime.strptime(log_datetime_str, "%Y-%m-%d %H:%M:%S,%f") # parse the datetime string to a datetime object -            log_datetime = log_datetime.replace(tzinfo=timezone('UTC')) # set the datetime object timezone to UTC to match `now` -            if now - log_datetime <= timedelta(minutes=minutes): # if the log is within `minutes` minutes from now -                logs.append(highlight_rules_func(line.rstrip('\n'))) # add the highlighted log to the list and strip any newline at the end -    logs_as_string = "<br>".join(logs) # joining with <br> directly +        try: +            log_datetime_str = line.split(" - ")[0] +            log_datetime = datetime.strptime(log_datetime_str, "%Y-%m-%d %H:%M:%S,%f") +            log_datetime = log_datetime.replace(tzinfo=timezone('UTC')) +            if now - log_datetime <= timedelta(minutes=minutes): +                logs.append(highlight_rules_func(line.rstrip('\n'))) +        except (ValueError, IndexError): +            logs.append(highlight_rules_func(line.rstrip('\n'))) # Line didn't meet datetime parsing criteria, continue with processing +    logs_as_string = "<br>".join(logs) logs_as_string_newlines_rendered = logs_as_string.replace("\n", "<br>") logs_as_string_newlines_rendered_font_specified = """ diff --git a/logger_config.py b/logger_config.py new file mode 100644 index 0000000..a21fe5a --- /dev/null +++ b/logger_config.py @@ -0,0 +1,35 @@ +import logging +import os +import shutil +import queue +from logging.handlers import RotatingFileHandler, QueueHandler, QueueListener + +logger = logging.getLogger("swiss_army_llama") + +def setup_logger(): + if logger.handlers: + return logger + old_logs_dir = 'old_logs' + if not os.path.exists(old_logs_dir): + os.makedirs(old_logs_dir) + logger.setLevel(logging.INFO) + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + log_file_path = 'swiss_army_llama.log' + log_queue = queue.Queue(-1) # Create a queue for the handlers + fh = RotatingFileHandler(log_file_path, maxBytes=10*1024*1024, backupCount=5) + fh.setFormatter(formatter) + def namer(default_log_name): # Function to move rotated logs to the old_logs directory + return os.path.join(old_logs_dir, os.path.basename(default_log_name)) + def rotator(source, dest): + shutil.move(source, dest) + fh.namer = namer + fh.rotator = rotator + sh = logging.StreamHandler() # Stream handler + sh.setFormatter(formatter) + queue_handler = QueueHandler(log_queue) # Create QueueHandler + queue_handler.setFormatter(formatter) + logger.addHandler(queue_handler) + listener = QueueListener(log_queue, fh, sh) # Create QueueListener with real handlers + listener.start() + logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) # Configure SQLalchemy logging + return logger diff --git a/misc_utility_functions.py b/misc_utility_functions.py new file mode 100644 index 0000000..7c73ffd --- /dev/null +++ b/misc_utility_functions.py @@ -0,0 +1,215 @@ +from logger_config import setup_logger +from database_functions import AsyncSessionLocal +import socket +import os +import re +import json +import io +import numpy as np +import faiss +from typing import Any +from collections import defaultdict +from sqlalchemy import text as sql_text + +logger = setup_logger() + +def clean_filename_for_url_func(dirty_filename: str) -> str: + clean_filename = re.sub(r'[^\w\s]', '', dirty_filename) # Remove special characters and replace spaces with underscores + clean_filename = clean_filename.replace(' ', '_') + return clean_filename + +def is_redis_running(host='localhost', port=6379): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + s.connect((host, port)) + return True + except ConnectionRefusedError: + return False + finally: + s.close() + + +async def build_faiss_indexes(): + global faiss_indexes, token_faiss_indexes, associated_texts_by_model + if os.environ.get("FAISS_SETUP_DONE") == "1": + logger.info("Faiss indexes already built by another worker. 
Skipping.") + return faiss_indexes, token_faiss_indexes, associated_texts_by_model + faiss_indexes = {} + token_faiss_indexes = {} # Separate FAISS indexes for token-level embeddings + associated_texts_by_model = defaultdict(list) # Create a dictionary to store associated texts by model name + async with AsyncSessionLocal() as session: + result = await session.execute(sql_text("SELECT llm_model_name, text, embedding_json FROM embeddings")) # Query regular embeddings + token_result = await session.execute(sql_text("SELECT llm_model_name, token, token_level_embedding_json FROM token_level_embeddings")) # Query token-level embeddings + embeddings_by_model = defaultdict(list) + token_embeddings_by_model = defaultdict(list) + for row in result.fetchall(): # Process regular embeddings + llm_model_name = row[0] + associated_texts_by_model[llm_model_name].append(row[1]) # Store the associated text by model name + embeddings_by_model[llm_model_name].append((row[1], json.loads(row[2]))) + for row in token_result.fetchall(): # Process token-level embeddings + llm_model_name = row[0] + token_embeddings_by_model[llm_model_name].append(json.loads(row[2])) + for llm_model_name, embeddings in embeddings_by_model.items(): + logger.info(f"Building Faiss index over embeddings for model {llm_model_name}...") + embeddings_array = np.array([e[1] for e in embeddings]).astype('float32') + if embeddings_array.size == 0: + logger.error(f"No embeddings were loaded from the database for model {llm_model_name}, so nothing to build the Faiss index with!") + continue + logger.info(f"Loaded {len(embeddings_array)} embeddings for model {llm_model_name}.") + logger.info(f"Embedding dimension for model {llm_model_name}: {embeddings_array.shape[1]}") + logger.info(f"Normalizing {len(embeddings_array)} embeddings for model {llm_model_name}...") + faiss.normalize_L2(embeddings_array) # Normalize the vectors for cosine similarity + faiss_index = faiss.IndexFlatIP(embeddings_array.shape[1]) # Use IndexFlatIP for cosine similarity + faiss_index.add(embeddings_array) + logger.info(f"Faiss index built for model {llm_model_name}.") + faiss_indexes[llm_model_name] = faiss_index # Store the index by model name + for llm_model_name, token_embeddings in token_embeddings_by_model.items(): + token_embeddings_array = np.array(token_embeddings).astype('float32') + if token_embeddings_array.size == 0: + logger.error(f"No token-level embeddings were loaded from the database for model {llm_model_name}, so nothing to build the Faiss index with!") + continue + logger.info(f"Normalizing {len(token_embeddings_array)} token-level embeddings for model {llm_model_name}...") + faiss.normalize_L2(token_embeddings_array) # Normalize the vectors for cosine similarity + token_faiss_index = faiss.IndexFlatIP(token_embeddings_array.shape[1]) # Use IndexFlatIP for cosine similarity + token_faiss_index.add(token_embeddings_array) + logger.info(f"Token-level Faiss index built for model {llm_model_name}.") + token_faiss_indexes[llm_model_name] = token_faiss_index # Store the token-level index by model name + os.environ["FAISS_SETUP_DONE"] = "1" + logger.info("Faiss indexes built.") + return faiss_indexes, token_faiss_indexes, associated_texts_by_model + +class JSONAggregator: + def __init__(self): + self.completions = [] + self.aggregate_result = None + + @staticmethod + def weighted_vote(values, weights): + tally = defaultdict(float) + for v, w in zip(values, weights): + tally[v] += w + return max(tally, key=tally.get) + + @staticmethod + def 
+ +class JSONAggregator: + def __init__(self): + self.completions = [] + self.aggregate_result = None + + @staticmethod + def weighted_vote(values, weights): + tally = defaultdict(float) + for v, w in zip(values, weights): + tally[v] += w + return max(tally, key=tally.get) + + @staticmethod + def flatten_json(json_obj, parent_key='', sep='->'): + items = {} + for k, v in json_obj.items(): + new_key = f"{parent_key}{sep}{k}" if parent_key else k + if isinstance(v, dict): + items.update(JSONAggregator.flatten_json(v, new_key, sep=sep)) + else: + items[new_key] = v + return items + + @staticmethod + def get_value_by_path(json_obj, path, sep='->'): + keys = path.split(sep) + item = json_obj + for k in keys: + item = item[k] + return item + + @staticmethod + def set_value_by_path(json_obj, path, value, sep='->'): + keys = path.split(sep) + item = json_obj + for k in keys[:-1]: + item = item.setdefault(k, {}) + item[keys[-1]] = value + + def calculate_path_weights(self): + all_paths = [] + for j in self.completions: + all_paths += list(self.flatten_json(j).keys()) + path_weights = defaultdict(float) + for path in all_paths: + path_weights[path] += 1.0 + return path_weights + + def aggregate(self): + path_weights = self.calculate_path_weights() + aggregate = {} + for path, weight in path_weights.items(): + values = [self.get_value_by_path(j, path) for j in self.completions if path in self.flatten_json(j)] + weights = [weight] * len(values) + aggregate_value = self.weighted_vote(values, weights) + self.set_value_by_path(aggregate, path, aggregate_value) + self.aggregate_result = aggregate + +class FakeUploadFile: + def __init__(self, filename: str, content: Any, content_type: str = 'text/plain'): + self.filename = filename + self.content_type = content_type + self.file = io.BytesIO(content) + def read(self, size: int = -1) -> bytes: + return self.file.read(size) + def seek(self, offset: int, whence: int = 0) -> int: + return self.file.seek(offset, whence) + def tell(self) -> int: + return self.file.tell() + +def normalize_logprobs(avg_logprob, min_logprob, max_logprob): + range_logprob = max_logprob - min_logprob + return (avg_logprob - min_logprob) / range_logprob if range_logprob != 0 else 0.5 + +def remove_pagination_breaks(text: str) -> str: + text = re.sub(r'-(\n)(?=[a-z])', '', text) # Remove hyphens at the end of lines when the word continues on the next line + text = re.sub(r'(?<=\w)(?<![.?!])\n(?=\w)', ' ', text) # Replace line breaks that are not preceded by sentence-ending punctuation + return text [...] + if RAMDISK_SIZE_IN_GB > total_ram_gb: + raise ValueError(f"Cannot allocate {RAMDISK_SIZE_IN_GB}G for RAM Disk. Total system RAM is {total_ram_gb:.2f}G.") + logger.info("Setting up RAM Disk...") + os.makedirs(RAMDISK_PATH, exist_ok=True) + mount_command = ["sudo", "mount", "-t", "tmpfs", "-o", f"size={ramdisk_size_str}", "tmpfs", RAMDISK_PATH] + subprocess.run(mount_command, check=True) + logger.info(f"RAM Disk set up at {RAMDISK_PATH} with size {ramdisk_size_gb}G") + os.environ["RAMDISK_SETUP_DONE"] = "1" + +def copy_models_to_ramdisk(models_directory, ramdisk_directory): + total_size = sum(os.path.getsize(os.path.join(models_directory, model)) for model in os.listdir(models_directory)) + free_ram = psutil.virtual_memory().free + if total_size > free_ram: + logger.warning(f"Not enough space on RAM Disk. Required: {total_size}, Available: {free_ram}. Rebuilding RAM Disk.") + clear_ramdisk() + free_ram = psutil.virtual_memory().free # Recompute the available RAM after clearing the RAM disk + if total_size > free_ram: + logger.error(f"Still not enough space on RAM Disk even after clearing. 
Required: {total_size}, Available: {free_ram}.") + raise ValueError("Not enough RAM space to copy models.") + os.makedirs(ramdisk_directory, exist_ok=True) + for model in os.listdir(models_directory): + src_path = os.path.join(models_directory, model) + dest_path = os.path.join(ramdisk_directory, model) + if os.path.exists(dest_path) and os.path.getsize(dest_path) == os.path.getsize(src_path): # Check if the file already exists in the RAM disk and has the same size + logger.info(f"Model {model} already exists in RAM Disk and is the same size. Skipping copy.") + continue + shutil.copyfile(src_path, dest_path) + logger.info(f"Copied model {model} to RAM Disk at {dest_path}") + +def clear_ramdisk(): + while True: + cmd_check = f"sudo mount | grep {RAMDISK_PATH}" + result = subprocess.run(cmd_check, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8') + if RAMDISK_PATH not in result: + break # Exit the loop if the RAMDISK_PATH is not in the mount list + cmd_umount = f"sudo umount -l {RAMDISK_PATH}" + subprocess.run(cmd_umount, shell=True, check=True) + logger.info(f"Cleared RAM Disk at {RAMDISK_PATH}") \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 2f2f90a..0403a1a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,3 +19,6 @@ faster-whisper textract pytest pytz +uvloop +aioredis +aioredlock \ No newline at end of file diff --git a/service_functions.py b/service_functions.py new file mode 100644 index 0000000..553c7b6 --- /dev/null +++ b/service_functions.py @@ -0,0 +1,614 @@ +from logger_config import setup_logger +import shared_resources +from shared_resources import load_model, token_level_embedding_model_cache, text_completion_model_cache +from database_functions import AsyncSessionLocal, DatabaseWriter, execute_with_retry +from misc_utility_functions import clean_filename_for_url_func, FakeUploadFile, sophisticated_sentence_splitter, merge_transcript_segments_into_combined_text +from embeddings_data_models import TextEmbedding, DocumentEmbedding, Document, TokenLevelEmbedding, TokenLevelEmbeddingBundleCombinedFeatureVector, AudioTranscript +from embeddings_data_models import EmbeddingRequest, TextCompletionRequest +from embeddings_data_models import TextCompletionResponse, AudioTranscriptResponse +import os +import shutil +import psutil +import glob +import json +import asyncio +import zipfile +import tempfile +import time +from datetime import datetime +from hashlib import sha3_256 +from urllib.parse import quote +import numpy as np +import pandas as pd +import textract +from sqlalchemy import text as sql_text +from sqlalchemy import select +from fastapi import HTTPException, Request, UploadFile, File +from fastapi.concurrency import run_in_threadpool +from typing import List, Optional, Tuple +from decouple import config +from faster_whisper import WhisperModel +from llama_cpp import Llama, LlamaGrammar + +logger = setup_logger() +SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT = config("SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT", default=8089, cast=int) +DEFAULT_MODEL_NAME = config("DEFAULT_MODEL_NAME", default="openchat_v3.2_super", cast=str) +LLM_CONTEXT_SIZE_IN_TOKENS = config("LLM_CONTEXT_SIZE_IN_TOKENS", default=512, cast=int) +TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS = config("TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS", default=4000, cast=int) +DEFAULT_MAX_COMPLETION_TOKENS = config("DEFAULT_MAX_COMPLETION_TOKENS", default=100, cast=int) +DEFAULT_NUMBER_OF_COMPLETIONS_TO_GENERATE = config("DEFAULT_NUMBER_OF_COMPLETIONS_TO_GENERATE", default=4, cast=int) 
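# Illustrative aside, not part of the patch: the uvloop dependency added to requirements.txt
# above is presumably wired up in the new uvicorn_config.py (whose contents this diff does not
# show). A typical arrangement, using UVICORN_NUMBER_OF_WORKERS=2 and the port 8089 from .env:
#
#     import uvicorn
#     uvicorn.run("swiss_army_llama:app", host="0.0.0.0", port=8089, workers=2, loop="uvloop")
#
# With workers > 1, uvicorn spawns separate processes and needs the app as an import string;
# that is why cross-worker coordination here relies on Redis-based locks rather than in-process state.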
+DEFAULT_COMPLETION_TEMPERATURE = config("DEFAULT_COMPLETION_TEMPERATURE", default=0.7, cast=float) +MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING = config("MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING", default=15, cast=int) +USE_PARALLEL_INFERENCE_QUEUE = config("USE_PARALLEL_INFERENCE_QUEUE", default=False, cast=bool) +MAX_CONCURRENT_PARALLEL_INFERENCE_TASKS = config("MAX_CONCURRENT_PARALLEL_INFERENCE_TASKS", default=10, cast=int) +USE_RAMDISK = config("USE_RAMDISK", default=False, cast=bool) +RAMDISK_PATH = config("RAMDISK_PATH", default="/mnt/ramdisk", cast=str) +BASE_DIRECTORY = os.path.dirname(os.path.abspath(__file__)) + +async def get_transcript_from_db(audio_file_hash: str): + return await execute_with_retry(_get_transcript_from_db, audio_file_hash) + +async def _get_transcript_from_db(audio_file_hash: str) -> Optional[dict]: + async with AsyncSessionLocal() as session: + result = await session.execute( + sql_text("SELECT * FROM audio_transcripts WHERE audio_file_hash=:audio_file_hash"), + {"audio_file_hash": audio_file_hash}, + ) + row = result.fetchone() + if row: + try: + segments_json = json.loads(row.segments_json) + combined_transcript_text_list_of_metadata_dicts = json.loads(row.combined_transcript_text_list_of_metadata_dicts) + info_json = json.loads(row.info_json) + if hasattr(info_json, '__dict__'): + info_json = vars(info_json) + except json.JSONDecodeError as e: + raise ValueError(f"JSON Decode Error: {e}") + if not isinstance(segments_json, list) or not isinstance(combined_transcript_text_list_of_metadata_dicts, list) or not isinstance(info_json, dict): + logger.error(f"Type of segments_json: {type(segments_json)}, Value: {segments_json}") + logger.error(f"Type of combined_transcript_text_list_of_metadata_dicts: {type(combined_transcript_text_list_of_metadata_dicts)}, Value: {combined_transcript_text_list_of_metadata_dicts}") + logger.error(f"Type of info_json: {type(info_json)}, Value: {info_json}") + raise ValueError("Deserialized JSON does not match the expected format.") + audio_transcript_response = { + "id": row.id, + "audio_file_name": row.audio_file_name, + "audio_file_size_mb": row.audio_file_size_mb, + "segments_json": segments_json, + "combined_transcript_text": row.combined_transcript_text, + "combined_transcript_text_list_of_metadata_dicts": combined_transcript_text_list_of_metadata_dicts, + "info_json": info_json, + "ip_address": row.ip_address, + "request_time": row.request_time, + "response_time": row.response_time, + "total_time": row.total_time, + "url_to_download_zip_file_of_embeddings": "" + } + return AudioTranscriptResponse(**audio_transcript_response) + return None + +async def save_transcript_to_db(audio_file_hash, audio_file_name, audio_file_size_mb, transcript_segments, info, ip_address, request_time, response_time, total_time, combined_transcript_text, combined_transcript_text_list_of_metadata_dicts): + existing_transcript = await get_transcript_from_db(audio_file_hash) + if existing_transcript: + return existing_transcript + audio_transcript = AudioTranscript( + audio_file_hash=audio_file_hash, + audio_file_name=audio_file_name, + audio_file_size_mb=audio_file_size_mb, + segments_json=json.dumps(transcript_segments), + combined_transcript_text=combined_transcript_text, + combined_transcript_text_list_of_metadata_dicts=json.dumps(combined_transcript_text_list_of_metadata_dicts), + info_json=json.dumps(info), + ip_address=ip_address, + request_time=request_time, + response_time=response_time, + total_time=total_time + ) + await 
shared_resources.db_writer.enqueue_write([audio_transcript]) + + +async def compute_and_store_transcript_embeddings(audio_file_name, list_of_transcript_sentences, llm_model_name, ip_address, combined_transcript_text, req: Request): + logger.info(f"Now computing embeddings for entire transcript of {audio_file_name}...") + zip_dir = 'generated_transcript_embeddings_zip_files' + if not os.path.exists(zip_dir): + os.makedirs(zip_dir) + sanitized_file_name = clean_filename_for_url_func(audio_file_name) + document_name = f"automatic_whisper_transcript_of__{sanitized_file_name}" + file_hash = sha3_256(combined_transcript_text.encode('utf-8')).hexdigest() + computed_embeddings = await compute_embeddings_for_document(list_of_transcript_sentences, llm_model_name, ip_address, file_hash) + zip_file_path = f"{zip_dir}/{quote(document_name)}.zip" + with zipfile.ZipFile(zip_file_path, 'w') as zipf: + zipf.writestr("embeddings.txt", json.dumps(computed_embeddings)) + download_url = f"download/{quote(document_name)}.zip" + full_download_url = f"{req.base_url}{download_url}" + logger.info(f"Generated download URL for transcript embeddings: {full_download_url}") + fake_upload_file = FakeUploadFile(filename=document_name, content=combined_transcript_text.encode(), content_type='text/plain') + logger.info(f"Storing transcript embeddings for {audio_file_name} in the database...") + await store_document_embeddings_in_db(fake_upload_file, file_hash, combined_transcript_text.encode(), json.dumps(computed_embeddings).encode(), computed_embeddings, llm_model_name, ip_address, datetime.utcnow()) + return full_download_url + +async def compute_transcript_with_whisper_from_audio_func(audio_file_hash, audio_file_path, audio_file_name, audio_file_size_mb, ip_address, req: Request, compute_embeddings_for_resulting_transcript_document=True, llm_model_name=DEFAULT_MODEL_NAME): + model_size = "large-v2" + logger.info(f"Loading Whisper model {model_size}...") + num_workers = 1 if psutil.virtual_memory().total < 32 * (1024 ** 3) else min(4, max(1, int((psutil.virtual_memory().total - 32 * (1024 ** 3)) / (4 * (1024 ** 3))))) # Only use more than 1 worker if there is at least 32GB of RAM; then use 1 worker per additional 4GB of RAM up to 4 workers max + model = await run_in_threadpool(WhisperModel, model_size, device="cpu", compute_type="auto", cpu_threads=os.cpu_count(), num_workers=num_workers) + request_time = datetime.utcnow() + logger.info(f"Computing transcript for {audio_file_name} which has a {audio_file_size_mb :.2f}MB file size...") + segments, info = await run_in_threadpool(model.transcribe, audio_file_path, beam_size=20) + if not segments: + logger.warning(f"No segments were returned for file {audio_file_name}.") + return [], {}, "", [], request_time, datetime.utcnow(), 0, "" + segment_details = [] + for idx, segment in enumerate(segments): + details = { + "start": round(segment.start, 2), + "end": round(segment.end, 2), + "text": segment.text, + "avg_logprob": round(segment.avg_logprob, 2) + } + logger.info(f"Details of transcript segment {idx} from file {audio_file_name}: {details}") + segment_details.append(details) + combined_transcript_text, combined_transcript_text_list_of_metadata_dicts, list_of_transcript_sentences = merge_transcript_segments_into_combined_text(segment_details) + if compute_embeddings_for_resulting_transcript_document: + download_url = await compute_and_store_transcript_embeddings(audio_file_name, list_of_transcript_sentences, llm_model_name, ip_address, combined_transcript_text, req) + else: 
+ download_url = '' + response_time = datetime.utcnow() + total_time = (response_time - request_time).total_seconds() + logger.info(f"Transcript computed in {total_time} seconds.") + await save_transcript_to_db(audio_file_hash, audio_file_name, audio_file_size_mb, segment_details, info, ip_address, request_time, response_time, total_time, combined_transcript_text, combined_transcript_text_list_of_metadata_dicts) + info_dict = info._asdict() + return segment_details, info_dict, combined_transcript_text, combined_transcript_text_list_of_metadata_dicts, request_time, response_time, total_time, download_url + +async def get_or_compute_transcript(file: UploadFile, compute_embeddings_for_resulting_transcript_document: bool, llm_model_name: str, req: Request = None) -> dict: + request_time = datetime.utcnow() + ip_address = req.client.host if req else "127.0.0.1" + file_contents = await file.read() + audio_file_hash = sha3_256(file_contents).hexdigest() + file.file.seek(0) # Reset file pointer after read + unique_id = f"transcript_{audio_file_hash}_{llm_model_name}" + lock = await shared_resources.lock_manager.lock(unique_id) + if lock.valid: + try: + existing_audio_transcript = await get_transcript_from_db(audio_file_hash) + if existing_audio_transcript: + return existing_audio_transcript + current_position = file.file.tell() + file.file.seek(0, os.SEEK_END) + audio_file_size_mb = file.file.tell() / (1024 * 1024) + file.file.seek(current_position) + with tempfile.NamedTemporaryFile(delete=False) as tmp_file: + shutil.copyfileobj(file.file, tmp_file) + audio_file_name = tmp_file.name + segment_details, info, combined_transcript_text, combined_transcript_text_list_of_metadata_dicts, request_time, response_time, total_time, download_url = await compute_transcript_with_whisper_from_audio_func(audio_file_hash, audio_file_name, file.filename, audio_file_size_mb, ip_address, req, compute_embeddings_for_resulting_transcript_document, llm_model_name) + audio_transcript_response = { + "audio_file_hash": audio_file_hash, + "audio_file_name": file.filename, + "audio_file_size_mb": audio_file_size_mb, + "segments_json": segment_details, + "combined_transcript_text": combined_transcript_text, + "combined_transcript_text_list_of_metadata_dicts": combined_transcript_text_list_of_metadata_dicts, + "info_json": info, + "ip_address": ip_address, + "request_time": request_time, + "response_time": response_time, + "total_time": total_time, + "url_to_download_zip_file_of_embeddings": download_url if compute_embeddings_for_resulting_transcript_document else "" + } + os.remove(audio_file_name) + return AudioTranscriptResponse(**audio_transcript_response) + finally: + await shared_resources.lock_manager.unlock(lock) + else: + return {"status": "already processing"} + + +# Core embedding functions start here: + + +def add_model_url(new_url: str) -> str: + corrected_url = new_url + if '/blob/main/' in new_url: + corrected_url = new_url.replace('/blob/main/', '/resolve/main/') + json_path = os.path.join(BASE_DIRECTORY, "model_urls.json") + with open(json_path, "r") as f: + existing_urls = json.load(f) + if corrected_url not in existing_urls: + logger.info(f"Model URL not found in database. 
Adding {new_url} now...") + existing_urls.append(corrected_url) + with open(json_path, "w") as f: + json.dump(existing_urls, f) + logger.info(f"Model URL added: {new_url}") + else: + logger.info("Model URL already exists.") + return corrected_url + +async def get_embedding_from_db(text: str, llm_model_name: str): + text_hash = sha3_256(text.encode('utf-8')).hexdigest() # Compute the hash + return await execute_with_retry(_get_embedding_from_db, text_hash, llm_model_name) + +async def _get_embedding_from_db(text_hash: str, llm_model_name: str) -> Optional[dict]: + async with AsyncSessionLocal() as session: + result = await session.execute( + sql_text("SELECT embedding_json FROM embeddings WHERE text_hash=:text_hash AND llm_model_name=:llm_model_name"), + {"text_hash": text_hash, "llm_model_name": llm_model_name}, + ) + row = result.fetchone() + if row: + embedding_json = row[0] + logger.info(f"Embedding found in database for text hash '{text_hash}' using model '{llm_model_name}'") + return json.loads(embedding_json) + return None + +async def get_or_compute_embedding(request: EmbeddingRequest, req: Request = None, client_ip: str = None, document_file_hash: str = None) -> dict: + request_time = datetime.utcnow() # Capture request time as datetime object + ip_address = client_ip or (req.client.host if req else "localhost") # If client_ip is provided, use it; otherwise, try to get from req; if not available, default to "localhost" + logger.info(f"Received request for embedding for '{request.text}' using model '{request.llm_model_name}' from IP address '{ip_address}'") + embedding_list = await get_embedding_from_db(request.text, request.llm_model_name) # Check if embedding exists in the database + if embedding_list is not None: + response_time = datetime.utcnow() # Capture response time as datetime object + total_time = (response_time - request_time).total_seconds() # Calculate time taken in seconds + logger.info(f"Embedding found in database for '{request.text}' using model '{request.llm_model_name}'; returning in {total_time:.4f} seconds") + return {"embedding": embedding_list} + model = load_model(request.llm_model_name) + embedding_list = calculate_sentence_embedding(model, request.text) # Compute the embedding if not in the database + if embedding_list is None: + logger.error(f"Could not calculate the embedding for the given text: '{request.text}' using model '{request.llm_model_name}!'") + raise HTTPException(status_code=400, detail="Could not calculate the embedding for the given text") + embedding_json = json.dumps(embedding_list) # Serialize the numpy array to JSON and save to the database + response_time = datetime.utcnow() # Capture response time as datetime object + total_time = (response_time - request_time).total_seconds() # Calculate total time using datetime objects + word_length_of_input_text = len(request.text.split()) + if word_length_of_input_text > 0: + logger.info(f"Embedding calculated for '{request.text}' using model '{request.llm_model_name}' in {total_time} seconds, or an average of {total_time/word_length_of_input_text :.2f} seconds per word. 
Now saving to database...") + await save_embedding_to_db(request.text, request.llm_model_name, embedding_json, ip_address, request_time, response_time, total_time, document_file_hash) + return {"embedding": embedding_list} + +async def save_embedding_to_db(text: str, llm_model_name: str, embedding_json: str, ip_address: str, request_time: datetime, response_time: datetime, total_time: float, document_file_hash: str = None): + existing_embedding = await get_embedding_from_db(text, llm_model_name) # Check if the embedding already exists + if existing_embedding is not None: + return existing_embedding + return await execute_with_retry(_save_embedding_to_db, text, llm_model_name, embedding_json, ip_address, request_time, response_time, total_time, document_file_hash) + +async def _save_embedding_to_db(text: str, llm_model_name: str, embedding_json: str, ip_address: str, request_time: datetime, response_time: datetime, total_time: float, document_file_hash: str = None): + existing_embedding = await get_embedding_from_db(text, llm_model_name) + if existing_embedding: + return existing_embedding + embedding = TextEmbedding( + text=text, + llm_model_name=llm_model_name, + embedding_json=embedding_json, + ip_address=ip_address, + request_time=request_time, + response_time=response_time, + total_time=total_time, + document_file_hash=document_file_hash + ) + await shared_resources.db_writer.enqueue_write([embedding]) # Enqueue the write operation using the db_writer instance + + +def load_token_level_embedding_model(llm_model_name: str, raise_http_exception: bool = True): + try: + if llm_model_name in token_level_embedding_model_cache: # Check if the model is already loaded in the cache + return token_level_embedding_model_cache[llm_model_name] + models_dir = os.path.join(RAMDISK_PATH, 'models') if USE_RAMDISK else os.path.join(BASE_DIRECTORY, 'models') # Determine the model directory path + matching_files = glob.glob(os.path.join(models_dir, f"{llm_model_name}*")) # Search for matching model files + if not matching_files: + logger.error(f"No model file found matching: {llm_model_name}") + raise FileNotFoundError + matching_files.sort(key=os.path.getmtime, reverse=True) # Sort the files based on modification time (recently modified files first) + model_file_path = matching_files[0] + model_instance = Llama(model_path=model_file_path, embedding=True, n_ctx=LLM_CONTEXT_SIZE_IN_TOKENS, verbose=False) # Load the model + token_level_embedding_model_cache[llm_model_name] = model_instance # Cache the loaded model + return model_instance + except TypeError as e: + logger.error(f"TypeError occurred while loading the model: {e}") + raise + except Exception as e: + logger.error(f"Exception occurred while loading the model: {e}") + if raise_http_exception: + raise HTTPException(status_code=404, detail="Model file not found") + else: + raise FileNotFoundError(f"No model file found matching: {llm_model_name}") + +async def compute_token_level_embedding_bundle_combined_feature_vector(token_level_embeddings) -> List[float]: + start_time = datetime.utcnow() + logger.info("Extracting token-level embeddings from the bundle") + parsed_df = pd.read_json(token_level_embeddings) # Parse the json_content back to a DataFrame + token_level_embeddings = list(parsed_df['embedding']) + embeddings = np.array(token_level_embeddings) # Convert the list of embeddings to a NumPy array + logger.info(f"Computing column-wise means/mins/maxes/std_devs of the embeddings... 
(shape: {embeddings.shape})") + assert(len(embeddings) > 0) + means = np.mean(embeddings, axis=0) + mins = np.min(embeddings, axis=0) + maxes = np.max(embeddings, axis=0) + stds = np.std(embeddings, axis=0) + logger.info("Concatenating the computed statistics to form the combined feature vector") + combined_feature_vector = np.concatenate([means, mins, maxes, stds]) + end_time = datetime.utcnow() + total_time = (end_time - start_time).total_seconds() + logger.info(f"Computed the token-level embedding bundle's combined feature vector in {total_time: .2f} seconds.") + return combined_feature_vector.tolist() + + +async def get_or_compute_token_level_embedding_bundle_combined_feature_vector(token_level_embedding_bundle_id, token_level_embeddings, db_writer: DatabaseWriter) -> List[float]: + request_time = datetime.utcnow() + logger.info(f"Checking for existing combined feature vector for token-level embedding bundle ID: {token_level_embedding_bundle_id}") + async with AsyncSessionLocal() as session: + result = await session.execute( + select(TokenLevelEmbeddingBundleCombinedFeatureVector) + .filter(TokenLevelEmbeddingBundleCombinedFeatureVector.token_level_embedding_bundle_id == token_level_embedding_bundle_id) + ) + existing_combined_feature_vector = result.scalar_one_or_none() + if existing_combined_feature_vector: + response_time = datetime.utcnow() + total_time = (response_time - request_time).total_seconds() + logger.info(f"Found existing combined feature vector for token-level embedding bundle ID: {token_level_embedding_bundle_id}. Returning cached result in {total_time:.2f} seconds.") + return json.loads(existing_combined_feature_vector.combined_feature_vector_json) # Parse the JSON string into a list + logger.info(f"No cached combined feature vector found for token-level embedding bundle ID: {token_level_embedding_bundle_id}. 
Computing now...") + combined_feature_vector = await compute_token_level_embedding_bundle_combined_feature_vector(token_level_embeddings) + combined_feature_vector_db_object = TokenLevelEmbeddingBundleCombinedFeatureVector( + token_level_embedding_bundle_id=token_level_embedding_bundle_id, + combined_feature_vector_json=json.dumps(combined_feature_vector) # Convert the list to a JSON string + ) + logger.info(f"Writing combined feature vector for database write for token-level embedding bundle ID: {token_level_embedding_bundle_id} to the database...") + await db_writer.enqueue_write([combined_feature_vector_db_object]) + return combined_feature_vector + + +async def calculate_token_level_embeddings(text: str, llm_model_name: str, client_ip: str, token_level_embedding_bundle_id: int) -> List[np.array]: + request_time = datetime.utcnow() + logger.info(f"Starting token-level embedding calculation for text: '{text}' using model: '{llm_model_name}'") + logger.info(f"Loading model: '{llm_model_name}'") + llm = load_token_level_embedding_model(llm_model_name) # Assuming this method returns an instance of the Llama class + token_embeddings = [] + tokens = text.split() # Simple whitespace tokenizer; can be replaced with a more advanced one if needed + logger.info(f"Tokenized text into {len(tokens)} tokens") + for idx, token in enumerate(tokens, start=1): + try: # Check if the embedding is already available in the database + existing_embedding = await get_token_level_embedding_from_db(token, llm_model_name) + if existing_embedding is not None: + token_embeddings.append(np.array(existing_embedding)) + logger.info(f"Embedding retrieved from database for token '{token}'") + continue + logger.info(f"Processing token {idx} of {len(tokens)}: '{token}'") + token_embedding = llm.embed(token) + token_embedding_array = np.array(token_embedding) + token_embeddings.append(token_embedding_array) + response_time = datetime.utcnow() + token_level_embedding_json = json.dumps(token_embedding_array.tolist()) + await store_token_level_embeddings_in_db(token, llm_model_name, token_level_embedding_json, client_ip, request_time, response_time, token_level_embedding_bundle_id) + except RuntimeError as e: + logger.error(f"Failed to calculate embedding for token '{token}': {e}") + logger.info(f"Completed token embedding calculation for all tokens in text: '{text}'") + return token_embeddings + +async def get_token_level_embedding_from_db(token: str, llm_model_name: str) -> Optional[List[float]]: + token_hash = sha3_256(token.encode('utf-8')).hexdigest() # Compute the hash + async with AsyncSessionLocal() as session: + result = await session.execute( + sql_text("SELECT token_level_embedding_json FROM token_level_embeddings WHERE token_hash=:token_hash AND llm_model_name=:llm_model_name"), + {"token_hash": token_hash, "llm_model_name": llm_model_name}, + ) + row = result.fetchone() + if row: + embedding_json = row[0] + logger.info(f"Embedding found in database for token hash '{token_hash}' using model '{llm_model_name}'") + return json.loads(embedding_json) + return None + +async def store_token_level_embeddings_in_db(token: str, llm_model_name: str, token_level_embedding_json: str, ip_address: str, request_time: datetime, response_time: datetime, token_level_embedding_bundle_id: int): + total_time = (response_time - request_time).total_seconds() + embedding = TokenLevelEmbedding( + token=token, + llm_model_name=llm_model_name, + token_level_embedding_json=token_level_embedding_json, + ip_address=ip_address, + 
request_time=request_time, + response_time=response_time, + total_time=total_time, + token_level_embedding_bundle_id=token_level_embedding_bundle_id + ) + await shared_resources.db_writer.enqueue_write([embedding]) # Enqueue the write operation for the token-level embedding + +def calculate_sentence_embedding(llama: Llama, text: str) -> np.array: + sentence_embedding = None + retry_count = 0 + while sentence_embedding is None and retry_count < 3: + try: + if retry_count > 0: + logger.info(f"Attempting again calculate sentence embedding. Attempt number {retry_count + 1}") + sentence_embedding = llama.embed_query(text) + except TypeError as e: + logger.error(f"TypeError in calculate_sentence_embedding: {e}") + raise + except Exception as e: + logger.error(f"Exception in calculate_sentence_embedding: {e}") + text = text[:-int(len(text) * 0.1)] + retry_count += 1 + logger.info(f"Trimming sentence due to too many tokens. New length: {len(text)}") + if sentence_embedding is None: + logger.error("Failed to calculate sentence embedding after multiple attempts") + return sentence_embedding + +async def compute_embeddings_for_document(strings: list, llm_model_name: str, client_ip: str, document_file_hash: str) -> List[Tuple[str, np.array]]: + from swiss_army_llama import get_embedding_vector_for_string + results = [] + if USE_PARALLEL_INFERENCE_QUEUE: + logger.info(f"Using parallel inference queue to compute embeddings for {len(strings)} strings") + start_time = time.perf_counter() # Record the start time + semaphore = asyncio.Semaphore(MAX_CONCURRENT_PARALLEL_INFERENCE_TASKS) + async def compute_embedding(text): # Define a function to compute the embedding for a given text + try: + async with semaphore: # Acquire a semaphore slot + request = EmbeddingRequest(text=text, llm_model_name=llm_model_name) + embedding = await get_embedding_vector_for_string(request, client_ip=client_ip, document_file_hash=document_file_hash) + return text, embedding["embedding"] + except Exception as e: + logger.error(f"Error computing embedding for text '{text}': {e}") + return text, None + results = await asyncio.gather(*[compute_embedding(s) for s in strings]) # Use asyncio.gather to run the tasks concurrently + end_time = time.perf_counter() # Record the end time + duration = end_time - start_time + if len(strings) > 0: + logger.info(f"Parallel inference task for {len(strings)} strings completed in {duration:.2f} seconds; {duration / len(strings):.2f} seconds per string") + else: # Compute embeddings sequentially + logger.info(f"Using sequential inference to compute embeddings for {len(strings)} strings") + start_time = time.perf_counter() # Record the start time + for s in strings: + embedding_request = EmbeddingRequest(text=s, llm_model_name=llm_model_name) + embedding = await get_embedding_vector_for_string(embedding_request, client_ip=client_ip, document_file_hash=document_file_hash) + results.append((s, embedding["embedding"])) + end_time = time.perf_counter() # Record the end time + duration = end_time - start_time + if len(strings) > 0: + logger.info(f"Sequential inference task for {len(strings)} strings completed in {duration:.2f} seconds; {duration / len(strings):.2f} seconds per string") + filtered_results = [(text, embedding) for text, embedding in results if embedding is not None] # Filter out results with None embeddings (applicable to parallel processing) and return + return filtered_results + +async def parse_submitted_document_file_into_sentence_strings_func(temp_file_path: str, mime_type: str): + 
strings = [] + if mime_type.startswith('text/'): + with open(temp_file_path, 'r') as buffer: + content = buffer.read() + else: + try: + content = textract.process(temp_file_path).decode('utf-8') + except UnicodeDecodeError: + try: + content = textract.process(temp_file_path).decode('unicode_escape') + except Exception as e: + logger.error(f"Error while processing file: {e}, mime_type: {mime_type}") + raise HTTPException(status_code=400, detail=f"Unsupported file type or error: {e}") + except Exception as e: + logger.error(f"Error while processing file: {e}, mime_type: {mime_type}") + raise HTTPException(status_code=400, detail=f"Unsupported file type or error: {e}") + sentences = sophisticated_sentence_splitter(content) + if len(sentences) == 0 and temp_file_path.lower().endswith('.pdf'): + logger.info("No sentences found, attempting OCR using Tesseract.") + try: + content = textract.process(temp_file_path, method='tesseract').decode('utf-8') + sentences = sophisticated_sentence_splitter(content) + except Exception as e: + logger.error(f"Error while processing file with OCR: {e}") + raise HTTPException(status_code=400, detail=f"OCR failed: {e}") + if len(sentences) == 0: + logger.info("No sentences found in the document") + raise HTTPException(status_code=400, detail="No sentences found in the document") + logger.info(f"Extracted {len(sentences)} sentences from the document") + strings = [s.strip() for s in sentences if len(s.strip()) > MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING] + return strings + +async def _get_document_from_db(file_hash: str): + async with AsyncSessionLocal() as session: + result = await session.execute(select(Document).filter(Document.document_hash == file_hash)) + return result.scalar_one_or_none() + +async def store_document_embeddings_in_db(file: File, file_hash: str, original_file_content: bytes, json_content: bytes, results: List[Tuple[str, np.array]], llm_model_name: str, client_ip: str, request_time: datetime): + document = await _get_document_from_db(file_hash) # First, check if a Document with the same hash already exists + if not document: # If not, create a new Document object + document = Document(document_hash=file_hash, llm_model_name=llm_model_name) + await shared_resources.db_writer.enqueue_write([document]) + document_embedding = DocumentEmbedding( + filename=file.filename, + mimetype=file.content_type, + file_hash=file_hash, + llm_model_name=llm_model_name, + file_data=original_file_content, + document_embedding_results_json=json.loads(json_content.decode()), + ip_address=client_ip, + request_time=request_time, + response_time=datetime.utcnow(), + total_time=(datetime.utcnow() - request_time).total_seconds() + ) + document.document_embeddings.append(document_embedding) # Associate it with the Document + document.update_hash() # This will trigger the SQLAlchemy event to update the document_hash + await shared_resources.db_writer.enqueue_write([document, document_embedding]) # Enqueue the write operation for the document embedding + write_operations = [] # Collect text embeddings to write + logger.info(f"Storing {len(results)} text embeddings in database") + for text, embedding in results: + embedding_entry = await _get_embedding_from_db(text, llm_model_name) + if not embedding_entry: + embedding_entry = TextEmbedding( + text=text, + llm_model_name=llm_model_name, + embedding_json=json.dumps(embedding), + ip_address=client_ip, + request_time=request_time, + response_time=datetime.utcnow(), + total_time=(datetime.utcnow() - 
request_time).total_seconds(), + document_file_hash=file_hash # Link it to the DocumentEmbedding via file_hash + ) + write_operations.append(embedding_entry) # Append only the newly created embedding; entries already in the database need no re-write + await shared_resources.db_writer.enqueue_write(write_operations) # Enqueue the write operations for the new text embeddings + +def load_text_completion_model(llm_model_name: str, raise_http_exception: bool = True): + try: + if llm_model_name in text_completion_model_cache: # Check if the model is already loaded in the cache + return text_completion_model_cache[llm_model_name] + models_dir = os.path.join(RAMDISK_PATH, 'models') if USE_RAMDISK else os.path.join(BASE_DIRECTORY, 'models') # Determine the model directory path + matching_files = glob.glob(os.path.join(models_dir, f"{llm_model_name}*")) # Search for matching model files + if not matching_files: + logger.error(f"No model file found matching: {llm_model_name}") + raise FileNotFoundError(f"No model file found matching: {llm_model_name}") + matching_files.sort(key=os.path.getmtime, reverse=True) # Sort the files based on modification time (recently modified files first) + model_file_path = matching_files[0] + model_instance = Llama(model_path=model_file_path, embedding=True, n_ctx=TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS, verbose=False) # Load the model + text_completion_model_cache[llm_model_name] = model_instance # Cache the loaded model + return model_instance + except TypeError as e: + logger.error(f"TypeError occurred while loading the model: {e}") + raise + except Exception as e: + logger.error(f"Exception occurred while loading the model: {e}") + if raise_http_exception: + raise HTTPException(status_code=404, detail="Model file not found") + else: + raise FileNotFoundError(f"No model file found matching: {llm_model_name}") + +async def generate_completion_from_llm(request: TextCompletionRequest, req: Request = None, client_ip: str = None) -> List[TextCompletionResponse]: + request_time = datetime.utcnow() + logger.info(f"Starting text completion calculation using model: '{request.llm_model_name}' for input prompt: '{request.input_prompt}'") + logger.info(f"Loading model: '{request.llm_model_name}'") + llm = load_text_completion_model(request.llm_model_name) + logger.info(f"Done loading model: '{request.llm_model_name}'") + list_of_llm_outputs = [] + grammar_file_string_lower = request.grammar_file_string.lower() if request.grammar_file_string else "" + if grammar_file_string_lower: + list_of_grammar_files = glob.glob("./grammar_files/*.gbnf") + matching_grammar_files = [x for x in list_of_grammar_files if grammar_file_string_lower in os.path.splitext(os.path.basename(x).lower())[0]] + if len(matching_grammar_files) == 0: + logger.error(f"No grammar file found matching: {request.grammar_file_string}") + raise FileNotFoundError(f"No grammar file found matching: {request.grammar_file_string}") + matching_grammar_files.sort(key=os.path.getmtime, reverse=True) + grammar_file_path = matching_grammar_files[0] + logger.info(f"Loading selected grammar file: '{grammar_file_path}'") + llama_grammar = LlamaGrammar.from_file(grammar_file_path) + for ii in range(request.number_of_completions_to_generate): + logger.info(f"Generating completion {ii+1} of {request.number_of_completions_to_generate} with model {request.llm_model_name} for input prompt: '{request.input_prompt}'") + output = llm(prompt=request.input_prompt, grammar=llama_grammar, max_tokens=request.number_of_tokens_to_generate, temperature=request.temperature) + list_of_llm_outputs.append(output) + else: + for ii in range(request.number_of_completions_to_generate): +
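# Hedged sketch of the grammar-constrained branch above, using llama-cpp-python's Llama and
# LlamaGrammar as the code does; the model path and grammar file below are placeholders. Passing
# grammar= restricts token sampling so the output must conform to the GBNF grammar (e.g. valid JSON).
from llama_cpp import Llama, LlamaGrammar

llm = Llama(model_path="models/openchat_v3.2_super.Q4_K_M.gguf", verbose=False)
json_grammar = LlamaGrammar.from_file("grammar_files/json.gbnf")  # same lookup directory as above
output = llm(prompt="Describe a llama as a JSON object.", grammar=json_grammar,
             max_tokens=200, temperature=0.7)
print(output['choices'][0]['text'])  # constrained to JSON syntax by the grammar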
output = llm(prompt=request.input_prompt, max_tokens=request.number_of_tokens_to_generate, temperature=request.temperature) + list_of_llm_outputs.append(output) + response_time = datetime.utcnow() + total_time_per_completion = ((response_time - request_time).total_seconds()) / request.number_of_completions_to_generate + list_of_responses = [] + for idx, current_completion_output in enumerate(list_of_llm_outputs): + generated_text = current_completion_output['choices'][0]['text'] + if request.grammar_file_string == 'json': + generated_text = generated_text.encode('unicode_escape').decode() + llm_model_usage_json = json.dumps(current_completion_output['usage']) + logger.info(f"Completed text completion {idx+1} in an average of {total_time_per_completion:.2f} seconds for input prompt: '{request.input_prompt}'; Beginning of generated text: \n'{generated_text[:100]}'") + response = TextCompletionResponse(input_prompt=request.input_prompt, + llm_model_name=request.llm_model_name, + grammar_file_string=request.grammar_file_string, + number_of_tokens_to_generate=request.number_of_tokens_to_generate, + number_of_completions_to_generate=request.number_of_completions_to_generate, + time_taken_in_seconds=float(total_time_per_completion), + generated_text=generated_text, + llm_model_usage_json=llm_model_usage_json) + list_of_responses.append(response) + return list_of_responses diff --git a/shared_resources.py b/shared_resources.py new file mode 100644 index 0000000..eee389b --- /dev/null +++ b/shared_resources.py @@ -0,0 +1,149 @@ +from misc_utility_functions import is_redis_running, build_faiss_indexes +from database_functions import DatabaseWriter, initialize_db +from ramdisk_functions import setup_ramdisk, copy_models_to_ramdisk, check_that_user_has_required_permissions_to_manage_ramdisks +from logger_config import setup_logger +from aioredlock import Aioredlock +import aioredis +import asyncio +import subprocess +import urllib.request +import os +import glob +import json +from typing import List, Tuple, Dict +from langchain.embeddings import LlamaCppEmbeddings +from decouple import config +from fastapi import HTTPException + +logger = setup_logger() +embedding_model_cache = {} # Model cache to store loaded models +token_level_embedding_model_cache = {} # Model cache to store loaded token-level embedding models +text_completion_model_cache = {} # Model cache to store loaded text completion models + +SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT = config("SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT", default=8089, cast=int) +DEFAULT_MODEL_NAME = config("DEFAULT_MODEL_NAME", default="openchat_v3.2_super", cast=str) +LLM_CONTEXT_SIZE_IN_TOKENS = config("LLM_CONTEXT_SIZE_IN_TOKENS", default=512, cast=int) +TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS = config("TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS", default=4000, cast=int) +DEFAULT_MAX_COMPLETION_TOKENS = config("DEFAULT_MAX_COMPLETION_TOKENS", default=100, cast=int) +DEFAULT_NUMBER_OF_COMPLETIONS_TO_GENERATE = config("DEFAULT_NUMBER_OF_COMPLETIONS_TO_GENERATE", default=4, cast=int) +DEFAULT_COMPLETION_TEMPERATURE = config("DEFAULT_COMPLETION_TEMPERATURE", default=0.7, cast=float) +MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING = config("MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING", default=15, cast=int) +USE_PARALLEL_INFERENCE_QUEUE = config("USE_PARALLEL_INFERENCE_QUEUE", default=False, cast=bool) +MAX_CONCURRENT_PARALLEL_INFERENCE_TASKS = config("MAX_CONCURRENT_PARALLEL_INFERENCE_TASKS", default=10, cast=int) +USE_RAMDISK = config("USE_RAMDISK", default=False, cast=bool) +RAMDISK_PATH = 
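# Hedged sketch of how these python-decouple settings resolve: config() reads the value from the
# .env file (or a real environment variable of the same name), applies the cast, and falls back to
# the default given here. For example, with UVICORN_NUMBER_OF_WORKERS=2 set in .env:
from decouple import config

workers = config("UVICORN_NUMBER_OF_WORKERS", default=1, cast=int)  # -> 2, taken from .env
use_ramdisk = config("USE_RAMDISK", default=False, cast=bool)       # -> False unless overridden
temperature = config("DEFAULT_COMPLETION_TEMPERATURE", default=0.7, cast=float)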
config("RAMDISK_PATH", default="/mnt/ramdisk", cast=str) +BASE_DIRECTORY = os.path.dirname(os.path.abspath(__file__)) + +async def initialize_globals(): + global db_writer, faiss_indexes, token_faiss_indexes, associated_texts_by_model, redis, lock_manager + if not is_redis_running(): + logger.info("Starting Redis server...") + subprocess.Popen(['redis-server'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + await asyncio.sleep(1) # Sleep for 1 second to give Redis time to start + redis = await aioredis.create_redis_pool('redis://localhost') + lock_manager = Aioredlock([redis]) + await initialize_db() + queue = asyncio.Queue() + db_writer = DatabaseWriter(queue) + await db_writer.initialize_processing_hashes() + asyncio.create_task(db_writer.dedicated_db_writer()) + global USE_RAMDISK + if USE_RAMDISK and not check_that_user_has_required_permissions_to_manage_ramdisks(): + USE_RAMDISK = False + elif USE_RAMDISK: + setup_ramdisk() + list_of_downloaded_model_names, download_status = download_models() + for llm_model_name in list_of_downloaded_model_names: + try: + load_model(llm_model_name, raise_http_exception=False) + except FileNotFoundError as e: + logger.error(e) + faiss_indexes, token_faiss_indexes, associated_texts_by_model = await build_faiss_indexes() + +# other shared variables and methods +db_writer = None +faiss_indexes = None +token_faiss_indexes = None +associated_texts_by_model = None +redis = None +lock_manager = None + + +def download_models() -> Tuple[List[str], List[Dict[str, str]]]: + download_status = [] + json_path = os.path.join(BASE_DIRECTORY, "model_urls.json") + if not os.path.exists(json_path): + initial_model_urls = [ + 'https://huggingface.co/TheBloke/Yarn-Llama-2-7B-128K-GGUF/resolve/main/yarn-llama-2-7b-128k.Q4_K_M.gguf', + 'https://huggingface.co/TheBloke/openchat_v3.2_super-GGUF/resolve/main/openchat_v3.2_super.Q4_K_M.gguf', + 'https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/resolve/main/mistral-7b-instruct-v0.1.Q6_K.gguf' + ] + with open(json_path, "w") as f: + json.dump(initial_model_urls, f) + with open(json_path, "r") as f: + list_of_model_download_urls = json.load(f) + model_names = [os.path.basename(url) for url in list_of_model_download_urls] + current_file_path = os.path.abspath(__file__) + base_dir = os.path.dirname(current_file_path) + models_dir = os.path.join(base_dir, 'models') + logger.info("Checking models directory...") + if USE_RAMDISK: + ramdisk_models_dir = os.path.join(RAMDISK_PATH, 'models') + if not os.path.exists(RAMDISK_PATH): + setup_ramdisk() + if all(os.path.exists(os.path.join(ramdisk_models_dir, llm_model_name)) for llm_model_name in model_names): + logger.info("Models found in RAM Disk.") + for url in list_of_model_download_urls: + download_status.append({"url": url, "status": "success", "message": "Model found in RAM Disk."}) + return model_names, download_status + if not os.path.exists(models_dir): + os.makedirs(models_dir) + logger.info(f"Created models directory: {models_dir}") + else: + logger.info(f"Models directory exists: {models_dir}") + for url, model_name_with_extension in zip(list_of_model_download_urls, model_names): + status = {"url": url, "status": "success", "message": "File already exists."} + filename = os.path.join(models_dir, model_name_with_extension) + if not os.path.exists(filename): + logger.info(f"Downloading model {model_name_with_extension} from {url}...") + urllib.request.urlretrieve(url, filename) + file_size = os.path.getsize(filename) / (1024 * 1024) # Convert bytes to MB + if 
file_size < 100: + os.remove(filename) + status["status"] = "failure" + status["message"] = "Downloaded file is too small, probably not a valid model file." + else: + logger.info(f"Downloaded: {filename}") + else: + logger.info(f"File already exists: {filename}") + download_status.append(status) + if USE_RAMDISK: + copy_models_to_ramdisk(models_dir, ramdisk_models_dir) + logger.info("Model downloads completed.") + return model_names, download_status + + +def load_model(llm_model_name: str, raise_http_exception: bool = True): + try: + models_dir = os.path.join(RAMDISK_PATH, 'models') if USE_RAMDISK else os.path.join(BASE_DIRECTORY, 'models') + if llm_model_name in embedding_model_cache: + return embedding_model_cache[llm_model_name] + matching_files = glob.glob(os.path.join(models_dir, f"{llm_model_name}*")) + if not matching_files: + logger.error(f"No model file found matching: {llm_model_name}") + raise FileNotFoundError + matching_files.sort(key=os.path.getmtime, reverse=True) + model_file_path = matching_files[0] + model_instance = LlamaCppEmbeddings(model_path=model_file_path, use_mlock=True, n_ctx=LLM_CONTEXT_SIZE_IN_TOKENS) + model_instance.client.verbose = False + embedding_model_cache[llm_model_name] = model_instance + return model_instance + except TypeError as e: + logger.error(f"TypeError occurred while loading the model: {e}") + raise + except Exception as e: + logger.error(f"Exception occurred while loading the model: {e}") + if raise_http_exception: + raise HTTPException(status_code=404, detail="Model file not found") + else: + raise FileNotFoundError(f"No model file found matching: {llm_model_name}") \ No newline at end of file diff --git a/swiss_army_llama.py b/swiss_army_llama.py index 359d3fb..740df5e 100644 --- a/swiss_army_llama.py +++ b/swiss_army_llama.py @@ -1,80 +1,53 @@ -from embeddings_data_models import Base, TextEmbedding, DocumentEmbedding, Document, TokenLevelEmbedding, TokenLevelEmbeddingBundle, TokenLevelEmbeddingBundleCombinedFeatureVector, AudioTranscript +import shared_resources +from shared_resources import initialize_globals, download_models +from logger_config import setup_logger +from database_functions import AsyncSessionLocal, DatabaseWriter, get_db_writer +from ramdisk_functions import clear_ramdisk +from misc_utility_functions import build_faiss_indexes +from embeddings_data_models import DocumentEmbedding, TokenLevelEmbeddingBundle from embeddings_data_models import EmbeddingRequest, SemanticSearchRequest, AdvancedSemanticSearchRequest, SimilarityRequest, TextCompletionRequest -from embeddings_data_models import EmbeddingResponse, SemanticSearchResponse, AdvancedSemanticSearchResponse, SimilarityResponse, AllStringsResponse, AllDocumentsResponse, TextCompletionResponse, AudioTranscriptResponse +from embeddings_data_models import EmbeddingResponse, SemanticSearchResponse, AdvancedSemanticSearchResponse, SimilarityResponse, AllStringsResponse, AllDocumentsResponse, TextCompletionResponse from embeddings_data_models import ShowLogsIncrementalModel +from service_functions import get_or_compute_embedding, get_or_compute_transcript, add_model_url, get_or_compute_token_level_embedding_bundle_combined_feature_vector, calculate_token_level_embeddings +from service_functions import parse_submitted_document_file_into_sentence_strings_func, compute_embeddings_for_document, store_document_embeddings_in_db, generate_completion_from_llm from log_viewer_functions import show_logs_incremental_func, show_logs_func +from uvicorn_config import option import asyncio 
-import io import glob import json -import logging import os -import random import re -import shutil -import subprocess import tempfile -import time import traceback -import urllib.request import zipfile -from collections import defaultdict from datetime import datetime from hashlib import sha3_256 -from logging.handlers import RotatingFileHandler -from typing import List, Optional, Tuple, Dict, Any -from urllib.parse import quote, unquote +from typing import List, Optional, Dict, Any +from urllib.parse import unquote import numpy as np from decouple import config import uvicorn -import psutil -import textract import fastapi from fastapi import FastAPI, HTTPException, Request, UploadFile, File, Depends from fastapi.responses import JSONResponse, FileResponse, HTMLResponse, Response -from fastapi.concurrency import run_in_threadpool -from langchain.embeddings import LlamaCppEmbeddings from sqlalchemy import select from sqlalchemy import text as sql_text -from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError -from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession -from sqlalchemy.orm import sessionmaker, joinedload +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import joinedload import faiss import pandas as pd from magic import Magic -from llama_cpp import Llama, LlamaGrammar import fast_vector_similarity as fvs -from faster_whisper import WhisperModel +import uvloop + +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +logger = setup_logger() # Note: the Ramdisk setup and teardown requires sudo; to enable password-less sudo, edit your sudoers file with `sudo visudo`. # Add the following lines, replacing username with your actual username # username ALL=(ALL) NOPASSWD: /bin/mount -t tmpfs -o size=*G tmpfs /mnt/ramdisk # username ALL=(ALL) NOPASSWD: /bin/umount /mnt/ramdisk -# Setup logging -old_logs_dir = 'old_logs' # Ensure the old_logs directory exists -if not os.path.exists(old_logs_dir): - os.makedirs(old_logs_dir) -logger = logging.getLogger() -logger.setLevel(logging.INFO) -formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') -log_file_path = 'swiss_army_llama.log' -fh = RotatingFileHandler(log_file_path, maxBytes=10*1024*1024, backupCount=5) -fh.setFormatter(formatter) -logger.addHandler(fh) -def namer(default_log_name): # Move rotated logs to the old_logs directory - return os.path.join(old_logs_dir, os.path.basename(default_log_name)) -def rotator(source, dest): - shutil.move(source, dest) -fh.namer = namer -fh.rotator = rotator -sh = logging.StreamHandler() -sh.setFormatter(formatter) -logger.addHandler(sh) -logger = logging.getLogger(__name__) -logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) -configured_logger = logger - # Global variables use_hardcoded_security_token = 0 if use_hardcoded_security_token: @@ -82,1056 +55,17 @@ def rotator(source, dest): USE_SECURITY_TOKEN = config("USE_SECURITY_TOKEN", default=False, cast=bool) else: USE_SECURITY_TOKEN = False -DATABASE_URL = "sqlite+aiosqlite:///swiss_army_llama.sqlite" -SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT = config("SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT", default=8089, cast=int) DEFAULT_MODEL_NAME = config("DEFAULT_MODEL_NAME", default="openchat_v3.2_super", cast=str) -LLM_CONTEXT_SIZE_IN_TOKENS = config("LLM_CONTEXT_SIZE_IN_TOKENS", default=512, cast=int) -TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS = config("TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS", default=4000, cast=int) -DEFAULT_MAX_COMPLETION_TOKENS = 
config("DEFAULT_MAX_COMPLETION_TOKENS", default=100, cast=int) -DEFAULT_NUMBER_OF_COMPLETIONS_TO_GENERATE = config("DEFAULT_NUMBER_OF_COMPLETIONS_TO_GENERATE", default=4, cast=int) -DEFAULT_COMPLETION_TEMPERATURE = config("DEFAULT_COMPLETION_TEMPERATURE", default=0.7, cast=float) -MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING = config("MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING", default=15, cast=int) -USE_PARALLEL_INFERENCE_QUEUE = config("USE_PARALLEL_INFERENCE_QUEUE", default=False, cast=bool) -MAX_CONCURRENT_PARALLEL_INFERENCE_TASKS = config("MAX_CONCURRENT_PARALLEL_INFERENCE_TASKS", default=10, cast=int) USE_RAMDISK = config("USE_RAMDISK", default=False, cast=bool) RAMDISK_PATH = config("RAMDISK_PATH", default="/mnt/ramdisk", cast=str) -RAMDISK_SIZE_IN_GB = config("RAMDISK_SIZE_IN_GB", default=1, cast=int) -MAX_RETRIES = config("MAX_RETRIES", default=3, cast=int) -DB_WRITE_BATCH_SIZE = config("DB_WRITE_BATCH_SIZE", default=25, cast=int) -RETRY_DELAY_BASE_SECONDS = config("RETRY_DELAY_BASE_SECONDS", default=1, cast=int) -JITTER_FACTOR = config("JITTER_FACTOR", default=0.1, cast=float) BASE_DIRECTORY = os.path.dirname(os.path.abspath(__file__)) -embedding_model_cache = {} # Model cache to store loaded models -token_level_embedding_model_cache = {} # Model cache to store loaded token-level embedding models -text_completion_model_cache = {} # Model cache to store loaded text completion models + logger.info(f"USE_RAMDISK is set to: {USE_RAMDISK}") -db_writer = None description_string = """ 🇨🇭🎖️🦙 Swiss Army Llama is your One-Stop-Shop to Quickly and Conveniently Integrate Powerful Local LLM Functionality into your Project via a REST API. """ app = FastAPI(title="Swiss Army Llama", description=description_string, docs_url="/") # Set the Swagger UI to root -engine = create_async_engine(DATABASE_URL, echo=False, connect_args={"check_same_thread": False}) -AsyncSessionLocal = sessionmaker( - bind=engine, - class_=AsyncSession, - expire_on_commit=False, - autoflush=False -) - -# Misc. 
utility functions and db writer class: -def clean_filename_for_url_func(dirty_filename: str) -> str: - clean_filename = re.sub(r'[^\w\s]', '', dirty_filename) # Remove special characters and replace spaces with underscores - clean_filename = clean_filename.replace(' ', '_') - return clean_filename - -class DatabaseWriter: - def __init__(self, queue): - self.queue = queue - self.processing_hashes = set() # Set to store the hashes if everything that is currently being processed in the queue (to avoid duplicates of the same task being added to the queue) - - def _get_hash_from_operation(self, operation): - attr_name = { - TextEmbedding: 'text_hash', - DocumentEmbedding: 'file_hash', - Document: 'document_hash', - TokenLevelEmbedding: 'token_hash', - TokenLevelEmbeddingBundle: 'input_text_hash', - TokenLevelEmbeddingBundleCombinedFeatureVector: 'combined_feature_vector_hash', - AudioTranscript: 'audio_file_hash' - }.get(type(operation)) - hash_value = getattr(operation, attr_name, None) - llm_model_name = getattr(operation, 'llm_model_name', None) - return f"{hash_value}_{llm_model_name}" if hash_value and llm_model_name else None - - async def initialize_processing_hashes(self, chunk_size=1000): - start_time = datetime.utcnow() - async with AsyncSessionLocal() as session: - queries = [ - (select(TextEmbedding.text_hash, TextEmbedding.llm_model_name), True), - (select(DocumentEmbedding.file_hash, DocumentEmbedding.llm_model_name), True), - (select(Document.document_hash, Document.llm_model_name), True), - (select(TokenLevelEmbedding.token_hash, TokenLevelEmbedding.llm_model_name), True), - (select(TokenLevelEmbeddingBundle.input_text_hash, TokenLevelEmbeddingBundle.llm_model_name), True), - (select(TokenLevelEmbeddingBundleCombinedFeatureVector.combined_feature_vector_hash, TokenLevelEmbeddingBundleCombinedFeatureVector.llm_model_name), True), - (select(AudioTranscript.audio_file_hash), False) - ] - for query, has_llm in queries: - offset = 0 - while True: - result = await session.execute(query.limit(chunk_size).offset(offset)) - rows = result.fetchall() - if not rows: - break - for row in rows: - if has_llm: - hash_with_model = f"{row[0]}_{row[1]}" - else: - hash_with_model = row[0] - self.processing_hashes.add(hash_with_model) - offset += chunk_size - end_time = datetime.utcnow() - total_time = (end_time - start_time).total_seconds() - if len(self.processing_hashes) > 0: - logger.info(f"Finished initializing set of input hash/llm_model_name combinations that are either currently being processed or have already been processed. 
Set size: {len(self.processing_hashes)}; Took {total_time} seconds, for an average of {total_time / len(self.processing_hashes)} seconds per hash.") - - async def _handle_integrity_error(self, e, write_operation, session): - unique_constraint_msg = { - TextEmbedding: "token_embeddings.token_hash, token_embeddings.llm_model_name", - DocumentEmbedding: "document_embeddings.file_hash, document_embeddings.llm_model_name", - Document: "documents.document_hash, documents.llm_model_name", - TokenLevelEmbedding: "token_level_embeddings.token_hash, token_level_embeddings.llm_model_name", - TokenLevelEmbeddingBundle: "token_level_embedding_bundles.input_text_hash, token_level_embedding_bundles.llm_model_name", - AudioTranscript: "audio_transcripts.audio_file_hash" - }.get(type(write_operation)) - if unique_constraint_msg and unique_constraint_msg in str(e): - logger.warning(f"Embedding already exists in the database for given input and llm_model_name: {e}") - await session.rollback() - else: - raise - - async def dedicated_db_writer(self): - while True: - write_operations_batch = await self.queue.get() - async with AsyncSessionLocal() as session: - try: - for write_operation in write_operations_batch: - session.add(write_operation) - await session.flush() # Flush to get the IDs - await session.commit() - for write_operation in write_operations_batch: - hash_to_remove = self._get_hash_from_operation(write_operation) - if hash_to_remove is not None and hash_to_remove in self.processing_hashes: - self.processing_hashes.remove(hash_to_remove) - except IntegrityError as e: - await self._handle_integrity_error(e, write_operation, session) - except SQLAlchemyError as e: - logger.error(f"Database error: {e}") - await session.rollback() - except Exception as e: - tb = traceback.format_exc() - logger.error(f"Unexpected error: {e}\n{tb}") - await session.rollback() - self.queue.task_done() - - async def enqueue_write(self, write_operations): - write_operations = [op for op in write_operations if self._get_hash_from_operation(op) not in self.processing_hashes] # Filter out write operations for hashes that are already being processed - if not write_operations: # If there are no write operations left after filtering, return early - return - for op in write_operations: # Add the hashes of the write operations to the set - hash_value = self._get_hash_from_operation(op) - if hash_value: - self.processing_hashes.add(hash_value) - await self.queue.put(write_operations) - - -async def execute_with_retry(func, *args, **kwargs): - retries = 0 - while retries < MAX_RETRIES: - try: - return await func(*args, **kwargs) - except OperationalError as e: - if 'database is locked' in str(e): - retries += 1 - sleep_time = RETRY_DELAY_BASE_SECONDS * (2 ** retries) + (random.random() * JITTER_FACTOR) # Implementing exponential backoff with jitter - logger.warning(f"Database is locked. Retrying ({retries}/{MAX_RETRIES})... 
Waiting for {sleep_time} seconds") - await asyncio.sleep(sleep_time) - else: - raise - raise OperationalError("Database is locked after multiple retries") - -async def initialize_db(): - logger.info("Initializing database, creating tables, and setting SQLite PRAGMAs...") - list_of_sqlite_pragma_strings = ["PRAGMA journal_mode=WAL;", "PRAGMA synchronous = NORMAL;", "PRAGMA cache_size = -1048576;", "PRAGMA busy_timeout = 2000;", "PRAGMA wal_autocheckpoint = 100;"] - list_of_sqlite_pragma_justification_strings = ["Set SQLite to use Write-Ahead Logging (WAL) mode (from default DELETE mode) so that reads and writes can occur simultaneously", - "Set synchronous mode to NORMAL (from FULL) so that writes are not blocked by reads", - "Set cache size to 1GB (from default 2MB) so that more data can be cached in memory and not read from disk; to make this 256MB, set it to -262144 instead", - "Increase the busy timeout to 2 seconds so that the database waits", - "Set the WAL autocheckpoint to 100 (from default 1000) so that the WAL file is checkpointed more frequently"] - assert(len(list_of_sqlite_pragma_strings) == len(list_of_sqlite_pragma_justification_strings)) - async with engine.begin() as conn: - for pragma_string in list_of_sqlite_pragma_strings: - await conn.execute(sql_text(pragma_string)) - logger.info(f"Executed SQLite PRAGMA: {pragma_string}") - logger.info(f"Justification: {list_of_sqlite_pragma_justification_strings[list_of_sqlite_pragma_strings.index(pragma_string)]}") - await conn.run_sync(Base.metadata.create_all) # Create tables if they don't exist - logger.info("Database initialization completed.") - -def get_db_writer() -> DatabaseWriter: - return db_writer # Return the existing DatabaseWriter instance - -def check_that_user_has_required_permissions_to_manage_ramdisks(): - try: # Try to run a harmless command with sudo to test if the user has password-less sudo permissions - result = subprocess.run(["sudo", "ls"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) - if "password" in result.stderr.lower(): - raise PermissionError("Password required for sudo") - logger.info("User has sufficient permissions to manage RAM Disks.") - return True - except (PermissionError, subprocess.CalledProcessError) as e: - logger.info("Sorry, current user does not have sufficient permissions to manage RAM Disks! Disabling RAM Disks for now...") - logger.debug(f"Permission check error detail: {e}") - return False - -def setup_ramdisk(): - cmd_check = f"sudo mount | grep {RAMDISK_PATH}" # Check if RAM disk already exists at the path - result = subprocess.run(cmd_check, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8') - if RAMDISK_PATH in result: - logger.info(f"RAM Disk already set up at {RAMDISK_PATH}. Skipping setup.") - return - total_ram_gb = psutil.virtual_memory().total / (1024 ** 3) - free_ram_gb = psutil.virtual_memory().free / (1024 ** 3) - buffer_gb = 2 # buffer to ensure we don't use all the free RAM - ramdisk_size_gb = max(min(RAMDISK_SIZE_IN_GB, free_ram_gb - buffer_gb), 0.1) - ramdisk_size_mb = int(ramdisk_size_gb * 1024) - ramdisk_size_str = f"{ramdisk_size_mb}M" - logger.info(f"Total RAM: {total_ram_gb}G") - logger.info(f"Free RAM: {free_ram_gb}G") - logger.info(f"Calculated RAM Disk Size: {ramdisk_size_gb}G") - if RAMDISK_SIZE_IN_GB > total_ram_gb: - raise ValueError(f"Cannot allocate {RAMDISK_SIZE_IN_GB}G for RAM Disk. 
Total system RAM is {total_ram_gb:.2f}G.") - logger.info("Setting up RAM Disk...") - os.makedirs(RAMDISK_PATH, exist_ok=True) - mount_command = ["sudo", "mount", "-t", "tmpfs", "-o", f"size={ramdisk_size_str}", "tmpfs", RAMDISK_PATH] - subprocess.run(mount_command, check=True) - logger.info(f"RAM Disk set up at {RAMDISK_PATH} with size {ramdisk_size_gb}G") - -def copy_models_to_ramdisk(models_directory, ramdisk_directory): - total_size = sum(os.path.getsize(os.path.join(models_directory, model)) for model in os.listdir(models_directory)) - free_ram = psutil.virtual_memory().free - if total_size > free_ram: - logger.warning(f"Not enough space on RAM Disk. Required: {total_size}, Available: {free_ram}. Rebuilding RAM Disk.") - clear_ramdisk() - free_ram = psutil.virtual_memory().free # Recompute the available RAM after clearing the RAM disk - if total_size > free_ram: - logger.error(f"Still not enough space on RAM Disk even after clearing. Required: {total_size}, Available: {free_ram}.") - raise ValueError("Not enough RAM space to copy models.") - setup_ramdisk() - os.makedirs(ramdisk_directory, exist_ok=True) - for model in os.listdir(models_directory): - shutil.copyfile(os.path.join(models_directory, model), os.path.join(ramdisk_directory, model)) - logger.info(f"Copied model {model} to RAM Disk at {os.path.join(ramdisk_directory, model)}") - -def clear_ramdisk(): - while True: - cmd_check = f"sudo mount | grep {RAMDISK_PATH}" - result = subprocess.run(cmd_check, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8') - if RAMDISK_PATH not in result: - break # Exit the loop if the RAMDISK_PATH is not in the mount list - cmd_umount = f"sudo umount -l {RAMDISK_PATH}" - subprocess.run(cmd_umount, shell=True, check=True) - logger.info(f"Cleared RAM Disk at {RAMDISK_PATH}") - -async def build_faiss_indexes(): - global faiss_indexes, token_faiss_indexes, associated_texts_by_model - faiss_indexes = {} - token_faiss_indexes = {} # Separate FAISS indexes for token-level embeddings - associated_texts_by_model = defaultdict(list) # Create a dictionary to store associated texts by model name - async with AsyncSessionLocal() as session: - result = await session.execute(sql_text("SELECT llm_model_name, text, embedding_json FROM embeddings")) # Query regular embeddings - token_result = await session.execute(sql_text("SELECT llm_model_name, token, token_level_embedding_json FROM token_level_embeddings")) # Query token-level embeddings - embeddings_by_model = defaultdict(list) - token_embeddings_by_model = defaultdict(list) - for row in result.fetchall(): # Process regular embeddings - llm_model_name = row[0] - associated_texts_by_model[llm_model_name].append(row[1]) # Store the associated text by model name - embeddings_by_model[llm_model_name].append((row[1], json.loads(row[2]))) - for row in token_result.fetchall(): # Process token-level embeddings - llm_model_name = row[0] - token_embeddings_by_model[llm_model_name].append(json.loads(row[2])) - for llm_model_name, embeddings in embeddings_by_model.items(): - logger.info(f"Building Faiss index over embeddings for model {llm_model_name}...") - embeddings_array = np.array([e[1] for e in embeddings]).astype('float32') - if embeddings_array.size == 0: - logger.error(f"No embeddings were loaded from the database for model {llm_model_name}, so nothing to build the Faiss index with!") - continue - logger.info(f"Loaded {len(embeddings_array)} embeddings for model {llm_model_name}.") - logger.info(f"Embedding dimension for model {llm_model_name}: 
{embeddings_array.shape[1]}") - logger.info(f"Normalizing {len(embeddings_array)} embeddings for model {llm_model_name}...") - faiss.normalize_L2(embeddings_array) # Normalize the vectors for cosine similarity - faiss_index = faiss.IndexFlatIP(embeddings_array.shape[1]) # Use IndexFlatIP for cosine similarity - faiss_index.add(embeddings_array) - logger.info(f"Faiss index built for model {llm_model_name}.") - faiss_indexes[llm_model_name] = faiss_index # Store the index by model name - for llm_model_name, token_embeddings in token_embeddings_by_model.items(): - token_embeddings_array = np.array(token_embeddings).astype('float32') - if token_embeddings_array.size == 0: - logger.error(f"No token-level embeddings were loaded from the database for model {llm_model_name}, so nothing to build the Faiss index with!") - continue - logger.info(f"Normalizing {len(token_embeddings_array)} token-level embeddings for model {llm_model_name}...") - faiss.normalize_L2(token_embeddings_array) # Normalize the vectors for cosine similarity - token_faiss_index = faiss.IndexFlatIP(token_embeddings_array.shape[1]) # Use IndexFlatIP for cosine similarity - token_faiss_index.add(token_embeddings_array) - logger.info(f"Token-level Faiss index built for model {llm_model_name}.") - token_faiss_indexes[llm_model_name] = token_faiss_index # Store the token-level index by model name - return faiss_indexes, token_faiss_indexes, associated_texts_by_model - -class JSONAggregator: - def __init__(self): - self.completions = [] - self.aggregate_result = None - - @staticmethod - def weighted_vote(values, weights): - tally = defaultdict(float) - for v, w in zip(values, weights): - tally[v] += w - return max(tally, key=tally.get) - - @staticmethod - def flatten_json(json_obj, parent_key='', sep='->'): - items = {} - for k, v in json_obj.items(): - new_key = f"{parent_key}{sep}{k}" if parent_key else k - if isinstance(v, dict): - items.update(JSONAggregator.flatten_json(v, new_key, sep=sep)) - else: - items[new_key] = v - return items - - @staticmethod - def get_value_by_path(json_obj, path, sep='->'): - keys = path.split(sep) - item = json_obj - for k in keys: - item = item[k] - return item - - @staticmethod - def set_value_by_path(json_obj, path, value, sep='->'): - keys = path.split(sep) - item = json_obj - for k in keys[:-1]: - item = item.setdefault(k, {}) - item[keys[-1]] = value - - def calculate_path_weights(self): - all_paths = [] - for j in self.completions: - all_paths += list(self.flatten_json(j).keys()) - path_weights = defaultdict(float) - for path in all_paths: - path_weights[path] += 1.0 - return path_weights - - def aggregate(self): - path_weights = self.calculate_path_weights() - aggregate = {} - for path, weight in path_weights.items(): - values = [self.get_value_by_path(j, path) for j in self.completions if path in self.flatten_json(j)] - weights = [weight] * len(values) - aggregate_value = self.weighted_vote(values, weights) - self.set_value_by_path(aggregate, path, aggregate_value) - self.aggregate_result = aggregate - -class FakeUploadFile: - def __init__(self, filename: str, content: Any, content_type: str = 'text/plain'): - self.filename = filename - self.content_type = content_type - self.file = io.BytesIO(content) - def read(self, size: int = -1) -> bytes: - return self.file.read(size) - def seek(self, offset: int, whence: int = 0) -> int: - return self.file.seek(offset, whence) - def tell(self) -> int: - return self.file.tell() - -async def get_transcript_from_db(audio_file_hash: str): - return await 
execute_with_retry(_get_transcript_from_db, audio_file_hash) - -async def _get_transcript_from_db(audio_file_hash: str) -> Optional[dict]: - async with AsyncSessionLocal() as session: - result = await session.execute( - sql_text("SELECT * FROM audio_transcripts WHERE audio_file_hash=:audio_file_hash"), - {"audio_file_hash": audio_file_hash}, - ) - row = result.fetchone() - if row: - try: - segments_json = json.loads(row.segments_json) - combined_transcript_text_list_of_metadata_dicts = json.loads(row.combined_transcript_text_list_of_metadata_dicts) - info_json = json.loads(row.info_json) - if hasattr(info_json, '__dict__'): - info_json = vars(info_json) - except json.JSONDecodeError as e: - raise ValueError(f"JSON Decode Error: {e}") - if not isinstance(segments_json, list) or not isinstance(combined_transcript_text_list_of_metadata_dicts, list) or not isinstance(info_json, dict): - logger.error(f"Type of segments_json: {type(segments_json)}, Value: {segments_json}") - logger.error(f"Type of combined_transcript_text_list_of_metadata_dicts: {type(combined_transcript_text_list_of_metadata_dicts)}, Value: {combined_transcript_text_list_of_metadata_dicts}") - logger.error(f"Type of info_json: {type(info_json)}, Value: {info_json}") - raise ValueError("Deserialized JSON does not match the expected format.") - audio_transcript_response = { - "id": row.id, - "audio_file_name": row.audio_file_name, - "audio_file_size_mb": row.audio_file_size_mb, - "segments_json": segments_json, - "combined_transcript_text": row.combined_transcript_text, - "combined_transcript_text_list_of_metadata_dicts": combined_transcript_text_list_of_metadata_dicts, - "info_json": info_json, - "ip_address": row.ip_address, - "request_time": row.request_time, - "response_time": row.response_time, - "total_time": row.total_time, - "url_to_download_zip_file_of_embeddings": "" - } - return AudioTranscriptResponse(**audio_transcript_response) - return None - -async def save_transcript_to_db(audio_file_hash, audio_file_name, audio_file_size_mb, transcript_segments, info, ip_address, request_time, response_time, total_time, combined_transcript_text, combined_transcript_text_list_of_metadata_dicts): - existing_transcript = await get_transcript_from_db(audio_file_hash) - if existing_transcript: - return existing_transcript - audio_transcript = AudioTranscript( - audio_file_hash=audio_file_hash, - audio_file_name=audio_file_name, - audio_file_size_mb=audio_file_size_mb, - segments_json=json.dumps(transcript_segments), - combined_transcript_text=combined_transcript_text, - combined_transcript_text_list_of_metadata_dicts=json.dumps(combined_transcript_text_list_of_metadata_dicts), - info_json=json.dumps(info), - ip_address=ip_address, - request_time=request_time, - response_time=response_time, - total_time=total_time - ) - await db_writer.enqueue_write([audio_transcript]) - -def normalize_logprobs(avg_logprob, min_logprob, max_logprob): - range_logprob = max_logprob - min_logprob - return (avg_logprob - min_logprob) / range_logprob if range_logprob != 0 else 0.5 - -def remove_pagination_breaks(text: str) -> str: - text = re.sub(r'-(\n)(?=[a-z])', '', text) # Remove hyphens at the end of lines when the word continues on the next line - text = re.sub(r'(?<=\w)(? 
dict: - request_time = datetime.utcnow() - ip_address = req.client.host if req else "127.0.0.1" - file_contents = await file.read() - audio_file_hash = sha3_256(file_contents).hexdigest() - file.file.seek(0) # Reset file pointer after read - existing_audio_transcript = await get_transcript_from_db(audio_file_hash) - if existing_audio_transcript: - return existing_audio_transcript - current_position = file.file.tell() - file.file.seek(0, os.SEEK_END) - audio_file_size_mb = file.file.tell() / (1024 * 1024) - file.file.seek(current_position) - with tempfile.NamedTemporaryFile(delete=False) as tmp_file: - shutil.copyfileobj(file.file, tmp_file) - audio_file_name = tmp_file.name - segment_details, info, combined_transcript_text, combined_transcript_text_list_of_metadata_dicts, request_time, response_time, total_time, download_url = await compute_transcript_with_whisper_from_audio_func(audio_file_hash, audio_file_name, file.filename, audio_file_size_mb, ip_address, req, compute_embeddings_for_resulting_transcript_document, llm_model_name) - audio_transcript_response = { - "audio_file_hash": audio_file_hash, - "audio_file_name": file.filename, - "audio_file_size_mb": audio_file_size_mb, - "segments_json": segment_details, - "combined_transcript_text": combined_transcript_text, - "combined_transcript_text_list_of_metadata_dicts": combined_transcript_text_list_of_metadata_dicts, - "info_json": info, - "ip_address": ip_address, - "request_time": request_time, - "response_time": response_time, - "total_time": total_time, - "url_to_download_zip_file_of_embeddings": download_url if compute_embeddings_for_resulting_transcript_document else "" - } - os.remove(audio_file_name) - return AudioTranscriptResponse(**audio_transcript_response) - - -# Core embedding functions start here: - -def download_models() -> Tuple[List[str], List[Dict[str, str]]]: - download_status = [] - json_path = os.path.join(BASE_DIRECTORY, "model_urls.json") - if not os.path.exists(json_path): - initial_model_urls = [ - 'https://huggingface.co/TheBloke/Yarn-Llama-2-13B-128K-GGUF/resolve/main/yarn-llama-2-13b-128k.Q4_K_M.gguf', - 'https://huggingface.co/TheBloke/Yarn-Llama-2-7B-128K-GGUF/resolve/main/yarn-llama-2-7b-128k.Q4_K_M.gguf', - 'https://huggingface.co/TheBloke/openchat_v3.2_super-GGUF/resolve/main/openchat_v3.2_super.Q4_K_M.gguf', - 'https://huggingface.co/TheBloke/Phind-CodeLlama-34B-Python-v1-GGUF/resolve/main/phind-codellama-34b-python-v1.Q4_K_M.gguf', - 'https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/resolve/main/mistral-7b-instruct-v0.1.Q6_K.gguf' - ] - with open(json_path, "w") as f: - json.dump(initial_model_urls, f) - with open(json_path, "r") as f: - list_of_model_download_urls = json.load(f) - model_names = [os.path.basename(url) for url in list_of_model_download_urls] - current_file_path = os.path.abspath(__file__) - base_dir = os.path.dirname(current_file_path) - models_dir = os.path.join(base_dir, 'models') - logger.info("Checking models directory...") - if USE_RAMDISK: - ramdisk_models_dir = os.path.join(RAMDISK_PATH, 'models') - if not os.path.exists(RAMDISK_PATH): - setup_ramdisk() - if all(os.path.exists(os.path.join(ramdisk_models_dir, llm_model_name)) for llm_model_name in model_names): - logger.info("Models found in RAM Disk.") - for url in list_of_model_download_urls: - download_status.append({"url": url, "status": "success", "message": "Model found in RAM Disk."}) - return model_names, download_status - if not os.path.exists(models_dir): - os.makedirs(models_dir) - logger.info(f"Created 
models directory: {models_dir}") - else: - logger.info(f"Models directory exists: {models_dir}") - for url, model_name_with_extension in zip(list_of_model_download_urls, model_names): - status = {"url": url, "status": "success", "message": "File already exists."} - filename = os.path.join(models_dir, model_name_with_extension) - if not os.path.exists(filename): - logger.info(f"Downloading model {model_name_with_extension} from {url}...") - urllib.request.urlretrieve(url, filename) - file_size = os.path.getsize(filename) / (1024 * 1024) # Convert bytes to MB - if file_size < 100: - os.remove(filename) - status["status"] = "failure" - status["message"] = "Downloaded file is too small, probably not a valid model file." - else: - logger.info(f"Downloaded: {filename}") - else: - logger.info(f"File already exists: {filename}") - download_status.append(status) - if USE_RAMDISK: - copy_models_to_ramdisk(models_dir, ramdisk_models_dir) - logger.info("Model downloads completed.") - return model_names, download_status - -def add_model_url(new_url: str) -> str: - corrected_url = new_url - if '/blob/main/' in new_url: - corrected_url = new_url.replace('/blob/main/', '/resolve/main/') - json_path = os.path.join(BASE_DIRECTORY, "model_urls.json") - with open(json_path, "r") as f: - existing_urls = json.load(f) - if corrected_url not in existing_urls: - logger.info(f"Model URL not found in database. Adding {new_url} now...") - existing_urls.append(corrected_url) - with open(json_path, "w") as f: - json.dump(existing_urls, f) - logger.info(f"Model URL added: {new_url}") - else: - logger.info("Model URL already exists.") - return corrected_url - -async def get_embedding_from_db(text: str, llm_model_name: str): - text_hash = sha3_256(text.encode('utf-8')).hexdigest() # Compute the hash - return await execute_with_retry(_get_embedding_from_db, text_hash, llm_model_name) - -async def _get_embedding_from_db(text_hash: str, llm_model_name: str) -> Optional[dict]: - async with AsyncSessionLocal() as session: - result = await session.execute( - sql_text("SELECT embedding_json FROM embeddings WHERE text_hash=:text_hash AND llm_model_name=:llm_model_name"), - {"text_hash": text_hash, "llm_model_name": llm_model_name}, - ) - row = result.fetchone() - if row: - embedding_json = row[0] - logger.info(f"Embedding found in database for text hash '{text_hash}' using model '{llm_model_name}'") - return json.loads(embedding_json) - return None - -async def get_or_compute_embedding(request: EmbeddingRequest, req: Request = None, client_ip: str = None, document_file_hash: str = None) -> dict: - request_time = datetime.utcnow() # Capture request time as datetime object - ip_address = client_ip or (req.client.host if req else "localhost") # If client_ip is provided, use it; otherwise, try to get from req; if not available, default to "localhost" - logger.info(f"Received request for embedding for '{request.text}' using model '{request.llm_model_name}' from IP address '{ip_address}'") - embedding_list = await get_embedding_from_db(request.text, request.llm_model_name) # Check if embedding exists in the database - if embedding_list is not None: - response_time = datetime.utcnow() # Capture response time as datetime object - total_time = (response_time - request_time).total_seconds() # Calculate time taken in seconds - logger.info(f"Embedding found in database for '{request.text}' using model '{request.llm_model_name}'; returning in {total_time:.4f} seconds") - return {"embedding": embedding_list} - model = 
load_model(request.llm_model_name) - embedding_list = calculate_sentence_embedding(model, request.text) # Compute the embedding if not in the database - if embedding_list is None: - logger.error(f"Could not calculate the embedding for the given text: '{request.text}' using model '{request.llm_model_name}!'") - raise HTTPException(status_code=400, detail="Could not calculate the embedding for the given text") - embedding_json = json.dumps(embedding_list) # Serialize the numpy array to JSON and save to the database - response_time = datetime.utcnow() # Capture response time as datetime object - total_time = (response_time - request_time).total_seconds() # Calculate total time using datetime objects - word_length_of_input_text = len(request.text.split()) - if word_length_of_input_text > 0: - logger.info(f"Embedding calculated for '{request.text}' using model '{request.llm_model_name}' in {total_time} seconds, or an average of {total_time/word_length_of_input_text :.2f} seconds per word. Now saving to database...") - await save_embedding_to_db(request.text, request.llm_model_name, embedding_json, ip_address, request_time, response_time, total_time, document_file_hash) - return {"embedding": embedding_list} - -async def save_embedding_to_db(text: str, llm_model_name: str, embedding_json: str, ip_address: str, request_time: datetime, response_time: datetime, total_time: float, document_file_hash: str = None): - existing_embedding = await get_embedding_from_db(text, llm_model_name) # Check if the embedding already exists - if existing_embedding is not None: - return existing_embedding - return await execute_with_retry(_save_embedding_to_db, text, llm_model_name, embedding_json, ip_address, request_time, response_time, total_time, document_file_hash) - -async def _save_embedding_to_db(text: str, llm_model_name: str, embedding_json: str, ip_address: str, request_time: datetime, response_time: datetime, total_time: float, document_file_hash: str = None): - existing_embedding = await get_embedding_from_db(text, llm_model_name) - if existing_embedding: - return existing_embedding - embedding = TextEmbedding( - text=text, - llm_model_name=llm_model_name, - embedding_json=embedding_json, - ip_address=ip_address, - request_time=request_time, - response_time=response_time, - total_time=total_time, - document_file_hash=document_file_hash - ) - await db_writer.enqueue_write([embedding]) # Enqueue the write operation using the db_writer instance - -def load_model(llm_model_name: str, raise_http_exception: bool = True): - try: - models_dir = os.path.join(RAMDISK_PATH, 'models') if USE_RAMDISK else os.path.join(BASE_DIRECTORY, 'models') - if llm_model_name in embedding_model_cache: - return embedding_model_cache[llm_model_name] - matching_files = glob.glob(os.path.join(models_dir, f"{llm_model_name}*")) - if not matching_files: - logger.error(f"No model file found matching: {llm_model_name}") - raise FileNotFoundError - matching_files.sort(key=os.path.getmtime, reverse=True) - model_file_path = matching_files[0] - model_instance = LlamaCppEmbeddings(model_path=model_file_path, use_mlock=True, n_ctx=LLM_CONTEXT_SIZE_IN_TOKENS) - model_instance.client.verbose = False - embedding_model_cache[llm_model_name] = model_instance - return model_instance - except TypeError as e: - logger.error(f"TypeError occurred while loading the model: {e}") - raise - except Exception as e: - logger.error(f"Exception occurred while loading the model: {e}") - if raise_http_exception: - raise HTTPException(status_code=404, 
detail="Model file not found") - else: - raise FileNotFoundError(f"No model file found matching: {llm_model_name}") - -def load_token_level_embedding_model(llm_model_name: str, raise_http_exception: bool = True): - try: - if llm_model_name in token_level_embedding_model_cache: # Check if the model is already loaded in the cache - return token_level_embedding_model_cache[llm_model_name] - models_dir = os.path.join(RAMDISK_PATH, 'models') if USE_RAMDISK else os.path.join(BASE_DIRECTORY, 'models') # Determine the model directory path - matching_files = glob.glob(os.path.join(models_dir, f"{llm_model_name}*")) # Search for matching model files - if not matching_files: - logger.error(f"No model file found matching: {llm_model_name}") - raise FileNotFoundError - matching_files.sort(key=os.path.getmtime, reverse=True) # Sort the files based on modification time (recently modified files first) - model_file_path = matching_files[0] - model_instance = Llama(model_path=model_file_path, embedding=True, n_ctx=LLM_CONTEXT_SIZE_IN_TOKENS, verbose=False) # Load the model - token_level_embedding_model_cache[llm_model_name] = model_instance # Cache the loaded model - return model_instance - except TypeError as e: - logger.error(f"TypeError occurred while loading the model: {e}") - raise - except Exception as e: - logger.error(f"Exception occurred while loading the model: {e}") - if raise_http_exception: - raise HTTPException(status_code=404, detail="Model file not found") - else: - raise FileNotFoundError(f"No model file found matching: {llm_model_name}") - -async def compute_token_level_embedding_bundle_combined_feature_vector(token_level_embeddings) -> List[float]: - start_time = datetime.utcnow() - logger.info("Extracting token-level embeddings from the bundle") - parsed_df = pd.read_json(token_level_embeddings) # Parse the json_content back to a DataFrame - token_level_embeddings = list(parsed_df['embedding']) - embeddings = np.array(token_level_embeddings) # Convert the list of embeddings to a NumPy array - logger.info(f"Computing column-wise means/mins/maxes/std_devs of the embeddings... 
(shape: {embeddings.shape})") - assert(len(embeddings) > 0) - means = np.mean(embeddings, axis=0) - mins = np.min(embeddings, axis=0) - maxes = np.max(embeddings, axis=0) - stds = np.std(embeddings, axis=0) - logger.info("Concatenating the computed statistics to form the combined feature vector") - combined_feature_vector = np.concatenate([means, mins, maxes, stds]) - end_time = datetime.utcnow() - total_time = (end_time - start_time).total_seconds() - logger.info(f"Computed the token-level embedding bundle's combined feature vector computed in {total_time: .2f} seconds.") - return combined_feature_vector.tolist() - -async def get_or_compute_token_level_embedding_bundle_combined_feature_vector(token_level_embedding_bundle_id, token_level_embeddings, db_writer: DatabaseWriter) -> List[float]: - request_time = datetime.utcnow() - logger.info(f"Checking for existing combined feature vector for token-level embedding bundle ID: {token_level_embedding_bundle_id}") - async with AsyncSessionLocal() as session: - result = await session.execute( - select(TokenLevelEmbeddingBundleCombinedFeatureVector) - .filter(TokenLevelEmbeddingBundleCombinedFeatureVector.token_level_embedding_bundle_id == token_level_embedding_bundle_id) - ) - existing_combined_feature_vector = result.scalar_one_or_none() - if existing_combined_feature_vector: - response_time = datetime.utcnow() - total_time = (response_time - request_time).total_seconds() - logger.info(f"Found existing combined feature vector for token-level embedding bundle ID: {token_level_embedding_bundle_id}. Returning cached result in {total_time:.2f} seconds.") - return json.loads(existing_combined_feature_vector.combined_feature_vector_json) # Parse the JSON string into a list - logger.info(f"No cached combined feature_vector found for token-level embedding bundle ID: {token_level_embedding_bundle_id}. 
Computing now...") - combined_feature_vector = await compute_token_level_embedding_bundle_combined_feature_vector(token_level_embeddings) - combined_feature_vector_db_object = TokenLevelEmbeddingBundleCombinedFeatureVector( - token_level_embedding_bundle_id=token_level_embedding_bundle_id, - combined_feature_vector_json=json.dumps(combined_feature_vector) # Convert the list to a JSON string - ) - logger.info(f"Writing combined feature vector for database write for token-level embedding bundle ID: {token_level_embedding_bundle_id} to the database...") - await db_writer.enqueue_write([combined_feature_vector_db_object]) - return combined_feature_vector - -async def calculate_token_level_embeddings(text: str, llm_model_name: str, client_ip: str, token_level_embedding_bundle_id: int) -> List[np.array]: - request_time = datetime.utcnow() - logger.info(f"Starting token-level embedding calculation for text: '{text}' using model: '{llm_model_name}'") - logger.info(f"Loading model: '{llm_model_name}'") - llm = load_token_level_embedding_model(llm_model_name) # Assuming this method returns an instance of the Llama class - token_embeddings = [] - tokens = text.split() # Simple whitespace tokenizer; can be replaced with a more advanced one if needed - logger.info(f"Tokenized text into {len(tokens)} tokens") - for idx, token in enumerate(tokens, start=1): - try: # Check if the embedding is already available in the database - existing_embedding = await get_token_level_embedding_from_db(token, llm_model_name) - if existing_embedding is not None: - token_embeddings.append(np.array(existing_embedding)) - logger.info(f"Embedding retrieved from database for token '{token}'") - continue - logger.info(f"Processing token {idx} of {len(tokens)}: '{token}'") - token_embedding = llm.embed(token) - token_embedding_array = np.array(token_embedding) - token_embeddings.append(token_embedding_array) - response_time = datetime.utcnow() - token_level_embedding_json = json.dumps(token_embedding_array.tolist()) - await store_token_level_embeddings_in_db(token, llm_model_name, token_level_embedding_json, client_ip, request_time, response_time, token_level_embedding_bundle_id) - except RuntimeError as e: - logger.error(f"Failed to calculate embedding for token '{token}': {e}") - logger.info(f"Completed token embedding calculation for all tokens in text: '{text}'") - return token_embeddings - -async def get_token_level_embedding_from_db(token: str, llm_model_name: str) -> Optional[List[float]]: - token_hash = sha3_256(token.encode('utf-8')).hexdigest() # Compute the hash - async with AsyncSessionLocal() as session: - result = await session.execute( - sql_text("SELECT token_level_embedding_json FROM token_level_embeddings WHERE token_hash=:token_hash AND llm_model_name=:llm_model_name"), - {"token_hash": token_hash, "llm_model_name": llm_model_name}, - ) - row = result.fetchone() - if row: - embedding_json = row[0] - logger.info(f"Embedding found in database for token hash '{token_hash}' using model '{llm_model_name}'") - return json.loads(embedding_json) - return None - -async def store_token_level_embeddings_in_db(token: str, llm_model_name: str, token_level_embedding_json: str, ip_address: str, request_time: datetime, response_time: datetime, token_level_embedding_bundle_id: int): - total_time = (response_time - request_time).total_seconds() - embedding = TokenLevelEmbedding( - token=token, - llm_model_name=llm_model_name, - token_level_embedding_json=token_level_embedding_json, - ip_address=ip_address, - 
-        request_time=request_time,
-        response_time=response_time,
-        total_time=total_time,
-        token_level_embedding_bundle_id=token_level_embedding_bundle_id
-    )
-    await db_writer.enqueue_write([embedding])  # Enqueue the write operation for the token-level embedding
-
-def calculate_sentence_embedding(llama: Llama, text: str) -> np.array:
-    sentence_embedding = None
-    retry_count = 0
-    while sentence_embedding is None and retry_count < 3:
-        try:
-            if retry_count > 0:
-                logger.info(f"Retrying sentence embedding calculation. Attempt number {retry_count + 1}")
-            sentence_embedding = llama.embed_query(text)
-        except TypeError as e:
-            logger.error(f"TypeError in calculate_sentence_embedding: {e}")
-            raise
-        except Exception as e:
-            logger.error(f"Exception in calculate_sentence_embedding: {e}")
-            text = text[:-int(len(text) * 0.1)]
-            retry_count += 1
-            logger.info(f"Trimming sentence due to too many tokens. New length: {len(text)}")
-    if sentence_embedding is None:
-        logger.error("Failed to calculate sentence embedding after multiple attempts")
-    return sentence_embedding
-
-async def compute_embeddings_for_document(strings: list, llm_model_name: str, client_ip: str, document_file_hash: str) -> List[Tuple[str, np.array]]:
-    results = []
-    if USE_PARALLEL_INFERENCE_QUEUE:
-        logger.info(f"Using parallel inference queue to compute embeddings for {len(strings)} strings")
-        start_time = time.perf_counter()  # Record the start time
-        semaphore = asyncio.Semaphore(MAX_CONCURRENT_PARALLEL_INFERENCE_TASKS)
-        async def compute_embedding(text):  # Define a function to compute the embedding for a given text
-            try:
-                async with semaphore:  # Acquire a semaphore slot
-                    request = EmbeddingRequest(text=text, llm_model_name=llm_model_name)
-                    embedding = await get_embedding_vector_for_string(request, client_ip=client_ip, document_file_hash=document_file_hash)
-                    return text, embedding["embedding"]
-            except Exception as e:
-                logger.error(f"Error computing embedding for text '{text}': {e}")
-                return text, None
-        results = await asyncio.gather(*[compute_embedding(s) for s in strings])  # Use asyncio.gather to run the tasks concurrently
-        end_time = time.perf_counter()  # Record the end time
-        duration = end_time - start_time
-        if len(strings) > 0:
-            logger.info(f"Parallel inference task for {len(strings)} strings completed in {duration:.2f} seconds; {duration / len(strings):.2f} seconds per string")
-    else:  # Compute embeddings sequentially
-        logger.info(f"Using sequential inference to compute embeddings for {len(strings)} strings")
-        start_time = time.perf_counter()  # Record the start time
-        for s in strings:
-            embedding_request = EmbeddingRequest(text=s, llm_model_name=llm_model_name)
-            embedding = await get_embedding_vector_for_string(embedding_request, client_ip=client_ip, document_file_hash=document_file_hash)
-            results.append((s, embedding["embedding"]))
-        end_time = time.perf_counter()  # Record the end time
-        duration = end_time - start_time
-        if len(strings) > 0:
-            logger.info(f"Sequential inference task for {len(strings)} strings completed in {duration:.2f} seconds; {duration / len(strings):.2f} seconds per string")
-    filtered_results = [(text, embedding) for text, embedding in results if embedding is not None]  # Filter out results with None embeddings (applicable to parallel processing) and return
-    return filtered_results
-
-async def parse_submitted_document_file_into_sentence_strings_func(temp_file_path: str, mime_type: str):
-    strings = []
-    if mime_type.startswith('text/'):
-        with open(temp_file_path, 'r') as buffer:
-            content = buffer.read()
-    else:
-        try:
-            content = textract.process(temp_file_path).decode('utf-8')
-        except UnicodeDecodeError:
-            try:
-                content = textract.process(temp_file_path).decode('unicode_escape')
-            except Exception as e:
-                logger.error(f"Error while processing file: {e}, mime_type: {mime_type}")
-                raise HTTPException(status_code=400, detail=f"Unsupported file type or error: {e}")
-        except Exception as e:
-            logger.error(f"Error while processing file: {e}, mime_type: {mime_type}")
-            raise HTTPException(status_code=400, detail=f"Unsupported file type or error: {e}")
-    sentences = sophisticated_sentence_splitter(content)
-    if len(sentences) == 0 and temp_file_path.lower().endswith('.pdf'):
-        logger.info("No sentences found, attempting OCR using Tesseract.")
-        try:
-            content = textract.process(temp_file_path, method='tesseract').decode('utf-8')
-            sentences = sophisticated_sentence_splitter(content)
-        except Exception as e:
-            logger.error(f"Error while processing file with OCR: {e}")
-            raise HTTPException(status_code=400, detail=f"OCR failed: {e}")
-    if len(sentences) == 0:
-        logger.info("No sentences found in the document")
-        raise HTTPException(status_code=400, detail="No sentences found in the document")
-    logger.info(f"Extracted {len(sentences)} sentences from the document")
-    strings = [s.strip() for s in sentences if len(s.strip()) > MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING]
-    return strings
-
-async def _get_document_from_db(file_hash: str):
-    async with AsyncSessionLocal() as session:
-        result = await session.execute(select(Document).filter(Document.document_hash == file_hash))
-        return result.scalar_one_or_none()
-
-async def store_document_embeddings_in_db(file: File, file_hash: str, original_file_content: bytes, json_content: bytes, results: List[Tuple[str, np.array]], llm_model_name: str, client_ip: str, request_time: datetime):
-    document = await _get_document_from_db(file_hash)  # First, check if a Document with the same hash already exists
-    if not document:  # If not, create a new Document object
-        document = Document(document_hash=file_hash, llm_model_name=llm_model_name)
-        await db_writer.enqueue_write([document])
-    document_embedding = DocumentEmbedding(
-        filename=file.filename,
-        mimetype=file.content_type,
-        file_hash=file_hash,
-        llm_model_name=llm_model_name,
-        file_data=original_file_content,
-        document_embedding_results_json=json.loads(json_content.decode()),
-        ip_address=client_ip,
-        request_time=request_time,
-        response_time=datetime.utcnow(),
-        total_time=(datetime.utcnow() - request_time).total_seconds()
-    )
-    document.document_embeddings.append(document_embedding)  # Associate it with the Document
-    document.update_hash()  # This will trigger the SQLAlchemy event to update the document_hash
-    await db_writer.enqueue_write([document, document_embedding])  # Enqueue the write operation for the document and its embedding
-    write_operations = []  # Collect text embeddings to write
-    logger.info(f"Storing {len(results)} text embeddings in database")
-    for text, embedding in results:
-        embedding_entry = await _get_embedding_from_db(text, llm_model_name)
-        if not embedding_entry:
-            embedding_entry = TextEmbedding(
-                text=text,
-                llm_model_name=llm_model_name,
-                embedding_json=json.dumps(embedding),
-                ip_address=client_ip,
-                request_time=request_time,
-                response_time=datetime.utcnow(),
-                total_time=(datetime.utcnow() - request_time).total_seconds(),
-                document_file_hash=file_hash  # Link it to the DocumentEmbedding via file_hash
-            )
-            write_operations.append(embedding_entry)  # Only newly created embeddings need a database write; existing ones are already stored
-    await db_writer.enqueue_write(write_operations)  # Enqueue the write operation for text embeddings
-
-def load_text_completion_model(llm_model_name: str, raise_http_exception: bool = True):
-    try:
-        if llm_model_name in text_completion_model_cache:  # Check if the model is already loaded in the cache
-            return text_completion_model_cache[llm_model_name]
-        models_dir = os.path.join(RAMDISK_PATH, 'models') if USE_RAMDISK else os.path.join(BASE_DIRECTORY, 'models')  # Determine the model directory path
-        matching_files = glob.glob(os.path.join(models_dir, f"{llm_model_name}*"))  # Search for matching model files
-        if not matching_files:
-            logger.error(f"No model file found matching: {llm_model_name}")
-            raise FileNotFoundError(f"No model file found matching: {llm_model_name}")
-        matching_files.sort(key=os.path.getmtime, reverse=True)  # Sort the files based on modification time (most recently modified first)
-        model_file_path = matching_files[0]
-        model_instance = Llama(model_path=model_file_path, embedding=True, n_ctx=TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS, verbose=False)  # Load the model
-        text_completion_model_cache[llm_model_name] = model_instance  # Cache the loaded model
-        return model_instance
-    except TypeError as e:
-        logger.error(f"TypeError occurred while loading the model: {e}")
-        raise
-    except Exception as e:
-        logger.error(f"Exception occurred while loading the model: {e}")
-        if raise_http_exception:
-            raise HTTPException(status_code=404, detail="Model file not found")
-        else:
-            raise FileNotFoundError(f"No model file found matching: {llm_model_name}")
-
-async def generate_completion_from_llm(request: TextCompletionRequest, req: Request = None, client_ip: str = None) -> List[TextCompletionResponse]:
-    request_time = datetime.utcnow()
-    logger.info(f"Starting text completion calculation using model: '{request.llm_model_name}' for input prompt: '{request.input_prompt}'")
-    logger.info(f"Loading model: '{request.llm_model_name}'")
-    llm = load_text_completion_model(request.llm_model_name)
-    logger.info(f"Done loading model: '{request.llm_model_name}'")
-    list_of_llm_outputs = []
-    if request.grammar_file_string != "":
-        list_of_grammar_files = glob.glob("./grammar_files/*.gbnf")
-        matching_grammar_files = [x for x in list_of_grammar_files if request.grammar_file_string in x]
-        if len(matching_grammar_files) == 0:
-            logger.error(f"No grammar file found matching: {request.grammar_file_string}")
-            raise FileNotFoundError(f"No grammar file found matching: {request.grammar_file_string}")
-        matching_grammar_files.sort(key=os.path.getmtime, reverse=True)  # Sort the files based on modification time (most recently modified first)
-        grammar_file_path = matching_grammar_files[0]
-        logger.info(f"Loading selected grammar file: '{grammar_file_path}'")
-        llama_grammar = LlamaGrammar.from_file(grammar_file_path)
-        for ii in range(request.number_of_completions_to_generate):
-            logger.info(f"Generating completion {ii+1} of {request.number_of_completions_to_generate} with model {request.llm_model_name} for input prompt: '{request.input_prompt}'")
-            output = llm(prompt=request.input_prompt, grammar=llama_grammar, max_tokens=request.number_of_tokens_to_generate, temperature=request.temperature)
-            list_of_llm_outputs.append(output)
-    else:
-        for ii in range(request.number_of_completions_to_generate):
-            output = llm(prompt=request.input_prompt, max_tokens=request.number_of_tokens_to_generate, temperature=request.temperature)
-            list_of_llm_outputs.append(output)
-    response_time = datetime.utcnow()
-    total_time_per_completion = ((response_time - request_time).total_seconds()) / request.number_of_completions_to_generate
-    list_of_responses = []
-    for idx, current_completion_output in enumerate(list_of_llm_outputs):
-        generated_text = current_completion_output['choices'][0]['text']
-        if request.grammar_file_string == 'json':
-            generated_text = generated_text.encode('unicode_escape').decode()
-        llm_model_usage_json = json.dumps(current_completion_output['usage'])
-        logger.info(f"Completed text completion {idx + 1} in an average of {total_time_per_completion:.2f} seconds for input prompt: '{request.input_prompt}'; Beginning of generated text: \n'{generated_text[:100]}'")
-        response = TextCompletionResponse(input_prompt=request.input_prompt,
-                                          llm_model_name=request.llm_model_name,
-                                          grammar_file_string=request.grammar_file_string,
-                                          number_of_tokens_to_generate=request.number_of_tokens_to_generate,
-                                          number_of_completions_to_generate=request.number_of_completions_to_generate,
-                                          time_taken_in_seconds=float(total_time_per_completion),
-                                          generated_text=generated_text,
-                                          llm_model_usage_json=llm_model_usage_json)
-        list_of_responses.append(response)
-    return list_of_responses
 
@@ -1144,7 +78,6 @@ async def general_exception_handler(request: Request, exc: Exception) -> JSONRes
     logger.exception(exc)
     return JSONResponse(status_code=500, content={"message": "An unexpected error occurred"})
 
-#FastAPI Endpoints start here:
 
 @app.get("/", include_in_schema=False)
 async def custom_swagger_ui_html():
@@ -1164,7 +97,7 @@ async def custom_swagger_ui_html():
     ### Example Response:
    ```json
    {
-        "model_names": ["yarn-llama-2-7b-128k", "yarn-llama-2-13b-128k", "openchat_v3.2_super", "phind-codellama-34b-python-v1", "my_super_custom_model"]
+        "model_names": ["yarn-llama-2-7b-128k", "openchat_v3.2_super", "mistral-7b-instruct-v0.1", "my_super_custom_model"]
    }
    ```""",
     response_description="A JSON object containing the list of available model names.")
@@ -1283,15 +216,23 @@ async def get_all_stored_documents(req: Request, token: str = None) -> AllDocume
 async def add_new_model(model_url: str, token: str = None) -> Dict[str, Any]:
     if USE_SECURITY_TOKEN and (token is None or token != SECURITY_TOKEN):
         raise HTTPException(status_code=403, detail="Unauthorized")
-    decoded_model_url = unquote(model_url)
-    if not decoded_model_url.endswith('.gguf'):
-        return {"status": "error", "message": "Model URL must point to a .gguf file."}
-    corrected_model_url = add_model_url(decoded_model_url)
-    _, download_status = download_models()
-    status_dict = {status["url"]: status for status in download_status}
-    if corrected_model_url in status_dict:
-        return {"status": status_dict[corrected_model_url]["status"], "message": status_dict[corrected_model_url]["message"]}
-    return {"status": "unknown", "message": "Unexpected error."}
+    unique_id = f"add_model_{sha3_256(model_url.encode('utf-8')).hexdigest()}"  # Lock ID from a stable digest of the model URL; the built-in hash() is salted per process, so each uvicorn worker would otherwise derive a different ID and the lock would never be shared
+    lock = await shared_resources.lock_manager.lock(unique_id)
+    if lock.valid:
+        try:
+            decoded_model_url = unquote(model_url)
+            if not decoded_model_url.endswith('.gguf'):
+                return {"status": "error", "message": "Model URL must point to a .gguf file."}
+            corrected_model_url = add_model_url(decoded_model_url)
+            _, download_status = download_models()
+            status_dict = {status["url"]: status for status in download_status}
+            if corrected_model_url in status_dict:
+                return {"status": status_dict[corrected_model_url]["status"], "message": status_dict[corrected_model_url]["message"]}
+            return {"status": "unknown", "message": "Unexpected error."}
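+        # Always release the Redis lock, even if the download raises, so the same URL can be retried later.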
+        finally:
+            await shared_resources.lock_manager.unlock(lock)
+    else:
+        return {"status": "already processing", "message": "Another worker is already processing this model URL."}
 
 
@@ -1594,44 +535,53 @@ async def compute_similarity_between_strings(request: SimilarityRequest, req: Re
           response_description="A JSON object containing the query text along with the most similar strings and similarity scores.")
 async def search_stored_embeddings_with_query_string_for_semantic_similarity(request: SemanticSearchRequest, req: Request, token: str = None) -> SemanticSearchResponse:
     global faiss_indexes, token_faiss_indexes, associated_texts_by_model
-    faiss_indexes, token_faiss_indexes, associated_texts_by_model = await build_faiss_indexes()
-    request_time = datetime.utcnow()
-    llm_model_name = request.llm_model_name
-    num_results = request.number_of_most_similar_strings_to_return
-    total_entries = len(associated_texts_by_model[llm_model_name])  # Get the total number of entries for the model
-    num_results = min(num_results, total_entries)  # Ensure num_results doesn't exceed the total number of entries
-    logger.info(f"Received request to find {num_results} most similar strings for query text: `{request.query_text}` using model: {llm_model_name}")
-    if USE_SECURITY_TOKEN and use_hardcoded_security_token and (token is None or token != SECURITY_TOKEN):
-        raise HTTPException(status_code=403, detail="Unauthorized")
-    try:
-        logger.info(f"Computing embedding for input text: {request.query_text}")
-        embedding_request = EmbeddingRequest(text=request.query_text, llm_model_name=request.llm_model_name)
-        embedding_response = await get_embedding_vector_for_string(embedding_request, req)
-        input_embedding = np.array(embedding_response["embedding"]).astype('float32').reshape(1, -1)
-        faiss.normalize_L2(input_embedding)  # Normalize the input vector for cosine similarity
-        logger.info(f"Computed embedding for input text: {request.query_text}")
-        faiss_index = faiss_indexes.get(llm_model_name)  # Retrieve the correct FAISS index for the llm_model_name
-        if faiss_index is None:
-            raise HTTPException(status_code=400, detail=f"No FAISS index found for model: {llm_model_name}")
-        logger.info("Searching for the most similar string in the FAISS index")
-        similarities, indices = faiss_index.search(input_embedding.reshape(1, -1), num_results)  # Search for num_results similar strings
-        results = []  # Create an empty list to store the results
-        for ii in range(num_results):
-            similarity = float(similarities[0][ii])  # Convert numpy.float32 to native float
-            most_similar_text = associated_texts_by_model[llm_model_name][indices[0][ii]]
-            if most_similar_text != request.query_text:  # Don't return the query text as a result
-                results.append({"search_result_text": most_similar_text, "similarity_to_query_text": similarity})
-        response_time = datetime.utcnow()
-        total_time = (response_time - request_time).total_seconds()
-        logger.info(f"Finished searching for the most similar string in the FAISS index in {total_time} seconds. Found {len(results)} results, returning the top {num_results}.")
-        logger.info(f"Found most similar strings for query string {request.query_text}: {results}")
-        return {"query_text": request.query_text, "results": results}  # Return the response matching the SemanticSearchResponse model
-    except Exception as e:
-        logger.error(f"An error occurred while processing the request: {e}")
-        logger.error(traceback.format_exc())  # Print the traceback
-        raise HTTPException(status_code=500, detail="Internal Server Error")
-
-
+    unique_id = f"semantic_search_{request.query_text}_{request.llm_model_name}_{request.number_of_most_similar_strings_to_return}"  # Unique ID for this operation
+    lock = await shared_resources.lock_manager.lock(unique_id)
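+    # The lock ID encodes the full request, so an identical search already running
+    # on another worker is not recomputed; the second caller is told to retry instead.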
+    if lock.valid:
+        try:
+            faiss_indexes, token_faiss_indexes, associated_texts_by_model = await build_faiss_indexes()
+            request_time = datetime.utcnow()
+            llm_model_name = request.llm_model_name
+            num_results = request.number_of_most_similar_strings_to_return
+            total_entries = len(associated_texts_by_model[llm_model_name])  # Get the total number of entries for the model
+            num_results = min(num_results, total_entries)  # Ensure num_results doesn't exceed the total number of entries
+            logger.info(f"Received request to find {num_results} most similar strings for query text: `{request.query_text}` using model: {llm_model_name}")
+            if USE_SECURITY_TOKEN and use_hardcoded_security_token and (token is None or token != SECURITY_TOKEN):
+                raise HTTPException(status_code=403, detail="Unauthorized")
+            try:
+                logger.info(f"Computing embedding for input text: {request.query_text}")
+                embedding_request = EmbeddingRequest(text=request.query_text, llm_model_name=request.llm_model_name)
+                embedding_response = await get_embedding_vector_for_string(embedding_request, req)
+                input_embedding = np.array(embedding_response["embedding"]).astype('float32').reshape(1, -1)
+                faiss.normalize_L2(input_embedding)  # Normalize the input vector for cosine similarity
+                logger.info(f"Computed embedding for input text: {request.query_text}")
+                faiss_index = faiss_indexes.get(llm_model_name)  # Retrieve the correct FAISS index for the llm_model_name
+                if faiss_index is None:
+                    raise HTTPException(status_code=400, detail=f"No FAISS index found for model: {llm_model_name}")
+                logger.info("Searching for the most similar string in the FAISS index")
+                similarities, indices = faiss_index.search(input_embedding.reshape(1, -1), num_results)  # Search for num_results similar strings
+                results = []  # Create an empty list to store the results
+                for ii in range(num_results):
+                    similarity = float(similarities[0][ii])  # Convert numpy.float32 to native float
+                    most_similar_text = associated_texts_by_model[llm_model_name][indices[0][ii]]
+                    if most_similar_text != request.query_text:  # Don't return the query text as a result
+                        results.append({"search_result_text": most_similar_text, "similarity_to_query_text": similarity})
+                response_time = datetime.utcnow()
+                total_time = (response_time - request_time).total_seconds()
+                logger.info(f"Finished searching for the most similar string in the FAISS index in {total_time:.2f} seconds. Found {len(results)} results, returning the top {num_results}.")
+                logger.info(f"Found most similar strings for query string {request.query_text}: {results}")
+                return {"query_text": request.query_text, "results": results}  # Return the response matching the SemanticSearchResponse model
+            except Exception as e:
+                logger.error(f"An error occurred while processing the request: {e}")
+                logger.error(traceback.format_exc())  # Print the traceback
+                raise HTTPException(status_code=500, detail="Internal Server Error")
+        finally:
+            await shared_resources.lock_manager.unlock(lock)
+    else:
+        # A bare status dict would fail validation against the SemanticSearchResponse
+        # return annotation, so signal the busy case with an explicit 429 instead.
+        raise HTTPException(status_code=429, detail="An identical search is already being processed; please retry shortly.")
+
+
 
 @app.post("/advanced_search_stored_embeddings_with_query_string_for_semantic_similarity/",
           response_model=AdvancedSemanticSearchResponse,
           summary="Advanced Semantic Search with Two-Step Similarity Measures",
@@ -1676,52 +626,60 @@ async def search_stored_embeddings_with_query_string_for_semantic_similarity(req
           response_description="A JSON object containing the query text and the most similar strings, along with their similarity scores for multiple measures.")
 async def advanced_search_stored_embeddings_with_query_string_for_semantic_similarity(request: AdvancedSemanticSearchRequest, req: Request, token: str = None) -> AdvancedSemanticSearchResponse:
     global faiss_indexes, token_faiss_indexes, associated_texts_by_model
-    faiss_indexes, token_faiss_indexes, associated_texts_by_model = await build_faiss_indexes()
-    request_time = datetime.utcnow()
-    llm_model_name = request.llm_model_name
-    total_entries = len(associated_texts_by_model[llm_model_name])
-    num_results = max([1, int((1 - request.similarity_filter_percentage) * total_entries)])
-    logger.info(f"Received request to find {num_results} most similar strings for query text: `{request.query_text}` using model: {llm_model_name}")
-    if USE_SECURITY_TOKEN and use_hardcoded_security_token and (token is None or token != SECURITY_TOKEN):
-        raise HTTPException(status_code=403, detail="Unauthorized")
-    try:
-        logger.info(f"Computing embedding for input text: {request.query_text}")
-        embedding_request = EmbeddingRequest(text=request.query_text, llm_model_name=llm_model_name)
-        embedding_response = await get_embedding_vector_for_string(embedding_request, req)
-        input_embedding = np.array(embedding_response["embedding"]).astype('float32').reshape(1, -1)
-        faiss.normalize_L2(input_embedding)
-        logger.info(f"Computed embedding for input text: {request.query_text}")
-        faiss_index = faiss_indexes.get(llm_model_name)
-        if faiss_index is None:
-            raise HTTPException(status_code=400, detail=f"No FAISS index found for model: {llm_model_name}")
-        _, indices = faiss_index.search(input_embedding, num_results)
-        filtered_indices = indices[0]
-        similarity_results = []
-        for idx in filtered_indices:
-            associated_text = associated_texts_by_model[llm_model_name][idx]
-            embedding_request = EmbeddingRequest(text=associated_text, llm_model_name=llm_model_name)
-            embedding_response = await get_embedding_vector_for_string(embedding_request, req)
-            filtered_embedding = np.array(embedding_response["embedding"])
-            params = {
-                "vector_1": input_embedding.tolist()[0],
-                "vector_2": filtered_embedding.tolist(),
-                "similarity_measure": "all"
-            }
-            similarity_stats_str = fvs.py_compute_vector_similarity_stats(json.dumps(params))
-            similarity_stats_json = json.loads(similarity_stats_str)
-            similarity_results.append({
-                "search_result_text": associated_text,
-                "similarity_to_query_text": similarity_stats_json
-            })
-        num_to_return = request.number_of_most_similar_strings_to_return if request.number_of_most_similar_strings_to_return is not None else len(similarity_results)
-        results = sorted(similarity_results, key=lambda x: x["similarity_to_query_text"]["hoeffding_d"], reverse=True)[:num_to_return]
-        response_time = datetime.utcnow()
-        total_time = (response_time - request_time).total_seconds()
-        logger.info(f"Finished advanced search in {total_time} seconds. Found {len(results)} results.")
-        return {"query_text": request.query_text, "results": results}
-    except Exception as e:
-        logger.error(f"An error occurred while processing the request: {e}")
-        raise HTTPException(status_code=500, detail="Internal Server Error")
+    unique_id = f"advanced_semantic_search_{request.query_text}_{request.llm_model_name}_{request.similarity_filter_percentage}_{request.number_of_most_similar_strings_to_return}"
+    lock = await shared_resources.lock_manager.lock(unique_id)
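+    # The two-step re-ranking below re-embeds every FAISS candidate, so it is far more
+    # expensive than the basic search; the lock keeps duplicate concurrent requests
+    # from running it in parallel on several workers.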
+    if lock.valid:
+        try:
+            faiss_indexes, token_faiss_indexes, associated_texts_by_model = await build_faiss_indexes()
+            request_time = datetime.utcnow()
+            llm_model_name = request.llm_model_name
+            total_entries = len(associated_texts_by_model[llm_model_name])
+            num_results = max([1, int((1 - request.similarity_filter_percentage) * total_entries)])
+            logger.info(f"Received request to find {num_results} most similar strings for query text: `{request.query_text}` using model: {llm_model_name}")
+            if USE_SECURITY_TOKEN and use_hardcoded_security_token and (token is None or token != SECURITY_TOKEN):
+                raise HTTPException(status_code=403, detail="Unauthorized")
+            try:
+                logger.info(f"Computing embedding for input text: {request.query_text}")
+                embedding_request = EmbeddingRequest(text=request.query_text, llm_model_name=llm_model_name)
+                embedding_response = await get_embedding_vector_for_string(embedding_request, req)
+                input_embedding = np.array(embedding_response["embedding"]).astype('float32').reshape(1, -1)
+                faiss.normalize_L2(input_embedding)
+                logger.info(f"Computed embedding for input text: {request.query_text}")
+                faiss_index = faiss_indexes.get(llm_model_name)
+                if faiss_index is None:
+                    raise HTTPException(status_code=400, detail=f"No FAISS index found for model: {llm_model_name}")
+                _, indices = faiss_index.search(input_embedding, num_results)
+                filtered_indices = indices[0]
+                similarity_results = []
+                for idx in filtered_indices:
+                    associated_text = associated_texts_by_model[llm_model_name][idx]
+                    embedding_request = EmbeddingRequest(text=associated_text, llm_model_name=llm_model_name)
+                    embedding_response = await get_embedding_vector_for_string(embedding_request, req)
+                    filtered_embedding = np.array(embedding_response["embedding"])
+                    params = {
+                        "vector_1": input_embedding.tolist()[0],
+                        "vector_2": filtered_embedding.tolist(),
+                        "similarity_measure": "all"
+                    }
+                    similarity_stats_str = fvs.py_compute_vector_similarity_stats(json.dumps(params))
+                    similarity_stats_json = json.loads(similarity_stats_str)
+                    similarity_results.append({
+                        "search_result_text": associated_text,
+                        "similarity_to_query_text": similarity_stats_json
+                    })
+                num_to_return = request.number_of_most_similar_strings_to_return if request.number_of_most_similar_strings_to_return is not None else len(similarity_results)
+                results = sorted(similarity_results, key=lambda x: x["similarity_to_query_text"]["hoeffding_d"], reverse=True)[:num_to_return]
+                response_time = datetime.utcnow()
+                total_time = (response_time - request_time).total_seconds()
+                logger.info(f"Finished advanced search in {total_time:.2f} seconds. Found {len(results)} results.")
+                return {"query_text": request.query_text, "results": results}
+            except Exception as e:
+                logger.error(f"An error occurred while processing the request: {e}")
+                raise HTTPException(status_code=500, detail="Internal Server Error")
+        finally:
+            await shared_resources.lock_manager.unlock(lock)
+    else:
+        # A bare status dict would fail validation against the AdvancedSemanticSearchResponse
+        # return annotation, so signal the busy case with an explicit 429 instead.
+        raise HTTPException(status_code=429, detail="An identical search is already being processed; please retry shortly.")
@@ -1775,41 +733,50 @@ async def get_all_embedding_vectors_for_document(file: UploadFile = File(...),
         hash_obj.update(chunk)
     file_hash = hash_obj.hexdigest()
     logger.info(f"SHA3-256 hash of submitted file: {file_hash}")
-    async with AsyncSessionLocal() as session:  # Check if the document has been processed before
-        result = await session.execute(select(DocumentEmbedding).filter(DocumentEmbedding.file_hash == file_hash, DocumentEmbedding.llm_model_name == llm_model_name))
-        existing_document_embedding = result.scalar_one_or_none()
-        if existing_document_embedding:  # If the document has been processed before, return the existing result
-            logger.info(f"Document {file.filename} has been processed before, returning existing result")
-            json_content = json.dumps(existing_document_embedding.document_embedding_results_json).encode()
-        else:  # If the document has not been processed, continue processing
-            mime = Magic(mime=True)
-            mime_type = mime.from_file(temp_file_path)
-            logger.info(f"Received request to extract embeddings for document {file.filename} with MIME type: {mime_type} and size: {os.path.getsize(temp_file_path)} bytes from IP address: {client_ip}")
-            strings = await parse_submitted_document_file_into_sentence_strings_func(temp_file_path, mime_type)
-            results = await compute_embeddings_for_document(strings, llm_model_name, client_ip, file_hash)  # Compute the embeddings and json_content for new documents
-            df = pd.DataFrame(results, columns=['text', 'embedding'])
-            json_content = df.to_json(orient=json_format or 'records').encode()
-            with open(temp_file_path, 'rb') as file_buffer:  # Store the results in the database
-                original_file_content = file_buffer.read()
-            await store_document_embeddings_in_db(file, file_hash, original_file_content, json_content, results, llm_model_name, client_ip, request_time)
-    overall_total_time = (datetime.utcnow() - request_time).total_seconds()
-    logger.info(f"Done getting all embeddings for document {file.filename} containing {len(strings)} sentences with model {llm_model_name}")
-    json_content_length = len(json_content)
-    if len(json_content) > 0:
-        logger.info(f"The response took {overall_total_time} seconds to generate, or {overall_total_time / (len(strings)/1000.0)} seconds per thousand input tokens and {overall_total_time / (float(json_content_length)/1000000.0)} seconds per million output characters.")
-    if send_back_json_or_zip_file == 'json':  # Assume 'json' response should be sent back
-        logger.info(f"Returning JSON response for document {file.filename} containing {len(strings)} sentences with model {llm_model_name}; first 100 characters out of {json_content_length} total of JSON response: {json_content[:100]}")
-        return JSONResponse(content=json.loads(json_content.decode()))  # Decode the content and parse it as JSON
-    else:  # Assume 'zip' file should be sent back
-        original_filename_without_extension, _ = os.path.splitext(file.filename)
-        json_file_path = f"/tmp/{original_filename_without_extension}.json"
-        with open(json_file_path, 'wb') as json_file:  # Write the JSON content as bytes
-            json_file.write(json_content)
-        zip_file_path = f"/tmp/{original_filename_without_extension}.zip"
-        with zipfile.ZipFile(zip_file_path, 'w') as zipf:
-            zipf.write(json_file_path, os.path.basename(json_file_path))
-        logger.info(f"Returning ZIP response for document {file.filename} containing {len(strings)} sentences with model {llm_model_name}; first 100 characters out of {json_content_length} total of JSON response: {json_content[:100]}")
-        return FileResponse(zip_file_path, headers={"Content-Disposition": f"attachment; filename={original_filename_without_extension}.zip"})
+    unique_id = f"document_embedding_{file_hash}_{llm_model_name}"
+    lock = await shared_resources.lock_manager.lock(unique_id)
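+    # Keyed on (file_hash, llm_model_name): if another worker is already embedding this
+    # exact document with this model, this request returns immediately instead of
+    # duplicating the expensive computation.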
+    if lock.valid:
+        try:
+            async with AsyncSessionLocal() as session:  # Check if the document has been processed before
+                result = await session.execute(select(DocumentEmbedding).filter(DocumentEmbedding.file_hash == file_hash, DocumentEmbedding.llm_model_name == llm_model_name))
+                existing_document_embedding = result.scalar_one_or_none()
+                if existing_document_embedding:  # If the document has been processed before, return the existing result
+                    logger.info(f"Document {file.filename} has been processed before, returning existing result")
+                    json_content = json.dumps(existing_document_embedding.document_embedding_results_json).encode()
+                    strings = []  # Sentence count is unknown for a cached result; keeps the logging below well-defined
+                else:  # If the document has not been processed, continue processing
+                    mime = Magic(mime=True)
+                    mime_type = mime.from_file(temp_file_path)
+                    logger.info(f"Received request to extract embeddings for document {file.filename} with MIME type: {mime_type} and size: {os.path.getsize(temp_file_path)} bytes from IP address: {client_ip}")
+                    strings = await parse_submitted_document_file_into_sentence_strings_func(temp_file_path, mime_type)
+                    results = await compute_embeddings_for_document(strings, llm_model_name, client_ip, file_hash)  # Compute the embeddings and json_content for new documents
+                    df = pd.DataFrame(results, columns=['text', 'embedding'])
+                    json_content = df.to_json(orient=json_format or 'records').encode()
+                    with open(temp_file_path, 'rb') as file_buffer:  # Store the results in the database
+                        original_file_content = file_buffer.read()
+                    await store_document_embeddings_in_db(file, file_hash, original_file_content, json_content, results, llm_model_name, client_ip, request_time)
+            overall_total_time = (datetime.utcnow() - request_time).total_seconds()
+            logger.info(f"Done getting all embeddings for document {file.filename} containing {len(strings)} sentences with model {llm_model_name}")
+            json_content_length = len(json_content)
+            if len(json_content) > 0 and len(strings) > 0:  # Guard against division by zero for cached results
+                logger.info(f"The response took {overall_total_time} seconds to generate, or {overall_total_time / (len(strings)/1000.0)} seconds per thousand input tokens and {overall_total_time / (float(json_content_length)/1000000.0)} seconds per million output characters.")
+            if send_back_json_or_zip_file == 'json':  # Assume 'json' response should be sent back
+                logger.info(f"Returning JSON response for document {file.filename} containing {len(strings)} sentences with model {llm_model_name}; first 100 characters out of {json_content_length} total of JSON response: {json_content[:100]}")
+                return JSONResponse(content=json.loads(json_content.decode()))  # Decode the content and parse it as JSON
+            else:  # Assume 'zip' file should be sent back
+                original_filename_without_extension, _ = os.path.splitext(file.filename)
+                json_file_path = f"/tmp/{original_filename_without_extension}.json"
+                with open(json_file_path, 'wb') as json_file:  # Write the JSON content as bytes
+                    json_file.write(json_content)
+                zip_file_path = f"/tmp/{original_filename_without_extension}.zip"
+                with zipfile.ZipFile(zip_file_path, 'w') as zipf:
+                    zipf.write(json_file_path, os.path.basename(json_file_path))
+                logger.info(f"Returning ZIP response for document {file.filename} containing {len(strings)} sentences with model {llm_model_name}; first 100 characters out of {json_content_length} total of JSON response: {json_content[:100]}")
+                return FileResponse(zip_file_path, headers={"Content-Disposition": f"attachment; filename={original_filename_without_extension}.zip"})
+        finally:
+            await shared_resources.lock_manager.unlock(lock)
+    else:
+        return {"status": "already processing"}
+
@@ -1833,7 +800,7 @@ async def get_all_embedding_vectors_for_document(file: UploadFile = File(...),
    ```json
    {
         "input_prompt": "The Kings of France in the 17th Century:",
-        "llm_model_name": "phind-codellama-34b-python-v1",
+        "llm_model_name": "mistral-7b-instruct-v0.1",
         "temperature": 0.95,
         "grammar_file_string": "json",
         "number_of_tokens_to_generate": 500,
@@ -1849,7 +816,7 @@ async def get_all_embedding_vectors_for_document(file: UploadFile = File(...),
    [
        {
             "input_prompt": "The Kings of France in the 17th Century:",
-            "llm_model_name": "phind-codellama-34b-python-v1",
+            "llm_model_name": "mistral-7b-instruct-v0.1",
             "grammar_file_string": "json",
             "number_of_tokens_to_generate": 500,
             "number_of_completions_to_generate": 3,
@@ -1859,7 +826,7 @@ async def get_all_embedding_vectors_for_document(file: UploadFile = File(...),
        },
        {
             "input_prompt": "The Kings of France in the 17th Century:",
-            "llm_model_name": "phind-codellama-34b-python-v1",
+            "llm_model_name": "mistral-7b-instruct-v0.1",
             "grammar_file_string": "json",
             "number_of_tokens_to_generate": 500,
             "number_of_completions_to_generate": 3,
@@ -1869,7 +836,7 @@ async def get_all_embedding_vectors_for_document(file: UploadFile = File(...),
        },
        {
             "input_prompt": "The Kings of France in the 17th Century:",
-            "llm_model_name": "phind-codellama-34b-python-v1",
+            "llm_model_name": "mistral-7b-instruct-v0.1",
             "grammar_file_string": "json",
             "number_of_tokens_to_generate": 500,
             "number_of_completions_to_generate": 3,
@@ -1884,7 +851,15 @@ async def get_text_completions_from_input_prompt(request: TextCompletionRequest,
         logger.warning(f"Unauthorized request from client IP {client_ip}")
         raise HTTPException(status_code=403, detail="Unauthorized")
     try:
-        return await generate_completion_from_llm(request, req, client_ip)
+        unique_id = f"text_completion_{sha3_256(request.input_prompt.encode('utf-8')).hexdigest()}_{request.llm_model_name}"  # Use a stable digest rather than hash(), which is salted per process and would differ across uvicorn workers
+        lock = await shared_resources.lock_manager.lock(unique_id)
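+        # Across all workers, only one completion runs per (prompt, model) pair at a
+        # time; a duplicate submission gets an "already processing" response instead.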
+        if lock.valid:
+            try:
+                return await generate_completion_from_llm(request, req, client_ip)
+            finally:
+                await shared_resources.lock_manager.unlock(lock)
+        else:
+            return {"status": "already processing"}
     except Exception as e:
         logger.error(f"An error occurred while processing the request: {e}")
         logger.error(traceback.format_exc())  # Print the traceback
@@ -1936,24 +911,7 @@ async def clear_ramdisk_endpoint(token: str = None):
 
 @app.on_event("startup")
 async def startup_event():
-    global db_writer, faiss_indexes, token_faiss_indexes, associated_texts_by_model
-    await initialize_db()
-    queue = asyncio.Queue()
-    db_writer = DatabaseWriter(queue)
-    await db_writer.initialize_processing_hashes()
-    asyncio.create_task(db_writer.dedicated_db_writer())
-    global USE_RAMDISK
-    if USE_RAMDISK and not check_that_user_has_required_permissions_to_manage_ramdisks():
-        USE_RAMDISK = False
-    elif USE_RAMDISK:
-        setup_ramdisk()
-    list_of_downloaded_model_names, download_status = download_models()
-    for llm_model_name in list_of_downloaded_model_names:
-        try:
-            load_model(llm_model_name, raise_http_exception=False)
-        except FileNotFoundError as e:
-            logger.error(e)
-    faiss_indexes, token_faiss_indexes, associated_texts_by_model = await build_faiss_indexes()
+    await initialize_globals()  # Consolidated startup: DB writer, ramdisk, model downloads, and FAISS indexes are now initialized in one place
 
 
@@ -1985,4 +943,4 @@ def show_logs_default():
 
 
 if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT)
+    uvicorn.run("swiss_army_llama:app", **option)
diff --git a/uvicorn_config.py b/uvicorn_config.py
new file mode 100644
index 0000000..ac4f89f
--- /dev/null
+++ b/uvicorn_config.py
@@ -0,0 +1,9 @@
+from decouple import config
+SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT = config("SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT", default=8089, cast=int)
+UVICORN_NUMBER_OF_WORKERS = config("UVICORN_NUMBER_OF_WORKERS", default=3, cast=int)
+
+option = {
+    "host": "0.0.0.0",
+    "port": SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT,
+    "workers": UVICORN_NUMBER_OF_WORKERS
+}
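+# Note: with workers > 1, uvicorn must be given the app as an import string
+# ("swiss_army_llama:app", as in the __main__ block of swiss_army_llama.py),
+# not as an app object, or the additional worker processes cannot be spawned.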