diff --git a/.env b/.env index bbbf0ef..b043103 100644 --- a/.env +++ b/.env @@ -8,6 +8,7 @@ DEFAULT_MAX_COMPLETION_TOKENS=1000 DEFAULT_NUMBER_OF_COMPLETIONS_TO_GENERATE=1 DEFAULT_COMPLETION_TEMPERATURE=0.7 LLAMA_EMBEDDING_SERVER_LISTEN_PORT=8089 +UVICORN_NUMBER_OF_WORKERS=2 MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING=15 MAX_RETRIES=10 DB_WRITE_BATCH_SIZE=25 diff --git a/README.md b/README.md index 0064201..5ea7912 100644 --- a/README.md +++ b/README.md @@ -109,6 +109,7 @@ fast_vector_similarity faster-whisper textract pytz +uvloop ``` ## Running the Application @@ -137,6 +138,7 @@ You can configure the service easily by editing the included `.env` file. Here's - `DEFAULT_MODEL_NAME`: Default model name to use. (e.g., `yarn-llama-2-13b-128k`) - `LLM_CONTEXT_SIZE_IN_TOKENS`: Context size in tokens for LLM. (e.g., `512`) - `SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT`: Port number for the service. (e.g., `8089`) +- `UVICORN_NUMBER_OF_WORKERS`: Number of workers for Uvicorn. (e.g., `2`) - `MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING`: Minimum string length for document embedding. (e.g., `15`) - `MAX_RETRIES`: Maximum retries for locked database. (e.g., `10`) - `DB_WRITE_BATCH_SIZE`: Database write batch size. (e.g., `25`) diff --git a/database_functions.py b/database_functions.py new file mode 100644 index 0000000..2b5a52f --- /dev/null +++ b/database_functions.py @@ -0,0 +1,162 @@ +from embeddings_data_models import Base, TextEmbedding, DocumentEmbedding, Document, TokenLevelEmbedding, TokenLevelEmbeddingBundle, TokenLevelEmbeddingBundleCombinedFeatureVector, AudioTranscript +from logger_config import setup_logger +import traceback +import asyncio +import random +from sqlalchemy import select +from sqlalchemy import text as sql_text +from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker +from datetime import datetime +from decouple import config + +logger = setup_logger() +db_writer = None +DATABASE_URL = "sqlite+aiosqlite:///swiss_army_llama.sqlite" +MAX_RETRIES = config("MAX_RETRIES", default=3, cast=int) +DB_WRITE_BATCH_SIZE = config("DB_WRITE_BATCH_SIZE", default=25, cast=int) +RETRY_DELAY_BASE_SECONDS = config("RETRY_DELAY_BASE_SECONDS", default=1, cast=int) +JITTER_FACTOR = config("JITTER_FACTOR", default=0.1, cast=float) + +engine = create_async_engine(DATABASE_URL, echo=False, connect_args={"check_same_thread": False}) +AsyncSessionLocal = sessionmaker( + bind=engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False +) +class DatabaseWriter: + def __init__(self, queue): + self.queue = queue + self.processing_hashes = set() # Set to store the hashes of everything that is currently being processed in the queue (to avoid adding duplicates of the same task to the queue) + + def _get_hash_from_operation(self, operation): + attr_name = { + TextEmbedding: 'text_hash', + DocumentEmbedding: 'file_hash', + Document: 'document_hash', + TokenLevelEmbedding: 'token_hash', + TokenLevelEmbeddingBundle: 'input_text_hash', + TokenLevelEmbeddingBundleCombinedFeatureVector: 'combined_feature_vector_hash', + AudioTranscript: 'audio_file_hash' + }.get(type(operation)) + hash_value = getattr(operation, attr_name, None) + llm_model_name = getattr(operation, 'llm_model_name', None) + return f"{hash_value}_{llm_model_name}" if hash_value and llm_model_name else None + + async def initialize_processing_hashes(self, chunk_size=1000): + start_time = 
datetime.utcnow() + async with AsyncSessionLocal() as session: + queries = [ + (select(TextEmbedding.text_hash, TextEmbedding.llm_model_name), True), + (select(DocumentEmbedding.file_hash, DocumentEmbedding.llm_model_name), True), + (select(Document.document_hash, Document.llm_model_name), True), + (select(TokenLevelEmbedding.token_hash, TokenLevelEmbedding.llm_model_name), True), + (select(TokenLevelEmbeddingBundle.input_text_hash, TokenLevelEmbeddingBundle.llm_model_name), True), + (select(TokenLevelEmbeddingBundleCombinedFeatureVector.combined_feature_vector_hash, TokenLevelEmbeddingBundleCombinedFeatureVector.llm_model_name), True), + (select(AudioTranscript.audio_file_hash), False) + ] + for query, has_llm in queries: + offset = 0 + while True: + result = await session.execute(query.limit(chunk_size).offset(offset)) + rows = result.fetchall() + if not rows: + break + for row in rows: + if has_llm: + hash_with_model = f"{row[0]}_{row[1]}" + else: + hash_with_model = row[0] + self.processing_hashes.add(hash_with_model) + offset += chunk_size + end_time = datetime.utcnow() + total_time = (end_time - start_time).total_seconds() + if len(self.processing_hashes) > 0: + logger.info(f"Finished initializing set of input hash/llm_model_name combinations that are either currently being processed or have already been processed. Set size: {len(self.processing_hashes)}; Took {total_time} seconds, for an average of {total_time / len(self.processing_hashes)} seconds per hash.") + + async def _handle_integrity_error(self, e, write_operation, session): + unique_constraint_msg = { + TextEmbedding: "embeddings.text_hash, embeddings.llm_model_name", + DocumentEmbedding: "document_embeddings.file_hash, document_embeddings.llm_model_name", + Document: "documents.document_hash, documents.llm_model_name", + TokenLevelEmbedding: "token_level_embeddings.token_hash, token_level_embeddings.llm_model_name", + TokenLevelEmbeddingBundle: "token_level_embedding_bundles.input_text_hash, token_level_embedding_bundles.llm_model_name", + AudioTranscript: "audio_transcripts.audio_file_hash" + }.get(type(write_operation)) + if unique_constraint_msg and unique_constraint_msg in str(e): + logger.warning(f"Embedding already exists in the database for given input and llm_model_name: {e}") + await session.rollback() + else: + raise + + async def dedicated_db_writer(self): + while True: + write_operations_batch = await self.queue.get() + async with AsyncSessionLocal() as session: + try: + for write_operation in write_operations_batch: + session.add(write_operation) + await session.flush() # Flush to get the IDs + await session.commit() + for write_operation in write_operations_batch: + hash_to_remove = self._get_hash_from_operation(write_operation) + if hash_to_remove is not None and hash_to_remove in self.processing_hashes: + self.processing_hashes.remove(hash_to_remove) + except IntegrityError as e: + await self._handle_integrity_error(e, write_operation, session) + except SQLAlchemyError as e: + logger.error(f"Database error: {e}") + await session.rollback() + except Exception as e: + tb = traceback.format_exc() + logger.error(f"Unexpected error: {e}\n{tb}") + await session.rollback() + self.queue.task_done() + + async def enqueue_write(self, write_operations): + write_operations = [op for op in write_operations if self._get_hash_from_operation(op) not in self.processing_hashes] # Filter out write operations for hashes that are already being processed + if not write_operations: # If there are no write operations 
left after filtering, return early + return + for op in write_operations: # Add the hashes of the write operations to the set + hash_value = self._get_hash_from_operation(op) + if hash_value: + self.processing_hashes.add(hash_value) + await self.queue.put(write_operations) + + +async def execute_with_retry(func, *args, **kwargs): + retries = 0 + while retries < MAX_RETRIES: + try: + return await func(*args, **kwargs) + except OperationalError as e: + if 'database is locked' in str(e): + retries += 1 + sleep_time = RETRY_DELAY_BASE_SECONDS * (2 ** retries) + (random.random() * JITTER_FACTOR) # Implementing exponential backoff with jitter + logger.warning(f"Database is locked. Retrying ({retries}/{MAX_RETRIES})... Waiting for {sleep_time} seconds") + await asyncio.sleep(sleep_time) + else: + raise + raise OperationalError("Database is locked after multiple retries", None, None) # OperationalError takes (statement, params, orig) positional arguments + +async def initialize_db(): + logger.info("Initializing database, creating tables, and setting SQLite PRAGMAs...") + list_of_sqlite_pragma_strings = ["PRAGMA journal_mode=WAL;", "PRAGMA synchronous = NORMAL;", "PRAGMA cache_size = -1048576;", "PRAGMA busy_timeout = 2000;", "PRAGMA wal_autocheckpoint = 100;"] + list_of_sqlite_pragma_justification_strings = ["Set SQLite to use Write-Ahead Logging (WAL) mode (from default DELETE mode) so that reads and writes can occur simultaneously", + "Set synchronous mode to NORMAL (from FULL) so that writes are not blocked by reads", + "Set cache size to 1GB (from default 2MB) so that more data can be cached in memory and not read from disk; to make this 256MB, set it to -262144 instead", + "Increase the busy timeout to 2 seconds so that the database waits that long for locks to clear before raising an error", + "Set the WAL autocheckpoint to 100 (from default 1000) so that the WAL file is checkpointed more frequently"] + assert(len(list_of_sqlite_pragma_strings) == len(list_of_sqlite_pragma_justification_strings)) + async with engine.begin() as conn: + for pragma_string in list_of_sqlite_pragma_strings: + await conn.execute(sql_text(pragma_string)) + logger.info(f"Executed SQLite PRAGMA: {pragma_string}") + logger.info(f"Justification: {list_of_sqlite_pragma_justification_strings[list_of_sqlite_pragma_strings.index(pragma_string)]}") + await conn.run_sync(Base.metadata.create_all) # Create tables if they don't exist + logger.info("Database initialization completed.") + +def get_db_writer() -> DatabaseWriter: + return db_writer # Return the existing DatabaseWriter instance diff --git a/llama_knife_sticker.webp b/image_files/llama_knife_sticker.webp similarity index 100% rename from llama_knife_sticker.webp rename to image_files/llama_knife_sticker.webp diff --git a/llama_knife_sticker2.jpg b/image_files/llama_knife_sticker2.jpg similarity index 100% rename from llama_knife_sticker2.jpg rename to image_files/llama_knife_sticker2.jpg diff --git a/swiss_army_llama__swagger_screenshot.png b/image_files/swiss_army_llama__swagger_screenshot.png similarity index 100% rename from swiss_army_llama__swagger_screenshot.png rename to image_files/swiss_army_llama__swagger_screenshot.png diff --git a/swiss_army_llama__swagger_screenshot_running.png b/image_files/swiss_army_llama__swagger_screenshot_running.png similarity index 100% rename from swiss_army_llama__swagger_screenshot_running.png rename to image_files/swiss_army_llama__swagger_screenshot_running.png diff --git a/swiss_army_llama_logo.webp b/image_files/swiss_army_llama_logo.webp similarity index 100% rename from swiss_army_llama_logo.webp rename to image_files/swiss_army_llama_logo.webp diff 
--git a/log_viewer_functions.py b/log_viewer_functions.py index ac2c787..d271560 100644 --- a/log_viewer_functions.py +++ b/log_viewer_functions.py @@ -48,7 +48,6 @@ def highlight_rules_func(text): text = text.replace('#COLOR13_OPEN#', '').replace('#COLOR13_CLOSE#', '') return text - def show_logs_incremental_func(minutes: int, last_position: int): new_logs = [] now = datetime.now(timezone('UTC')) # get current time, make it timezone-aware @@ -75,21 +74,22 @@ def show_logs_incremental_func(minutes: int, last_position: int): def show_logs_func(minutes: int = 5): - # read the entire log file and generate HTML with logs up to `minutes` minutes from now with open(log_file_path, "r") as f: lines = f.readlines() logs = [] - now = datetime.now(timezone('UTC')) # get current time, make it timezone-aware + now = datetime.now(timezone('UTC')) for line in lines: if line.strip() == "": continue - if line[0].isdigit(): - log_datetime_str = line.split(" - ")[0] # assuming the datetime is at the start of each line - log_datetime = datetime.strptime(log_datetime_str, "%Y-%m-%d %H:%M:%S,%f") # parse the datetime string to a datetime object - log_datetime = log_datetime.replace(tzinfo=timezone('UTC')) # set the datetime object timezone to UTC to match `now` - if now - log_datetime <= timedelta(minutes=minutes): # if the log is within `minutes` minutes from now - logs.append(highlight_rules_func(line.rstrip('\n'))) # add the highlighted log to the list and strip any newline at the end - logs_as_string = "<br>".join(logs) # joining with <br> directly + try: + log_datetime_str = line.split(" - ")[0] + log_datetime = datetime.strptime(log_datetime_str, "%Y-%m-%d %H:%M:%S,%f") + log_datetime = log_datetime.replace(tzinfo=timezone('UTC')) + if now - log_datetime <= timedelta(minutes=minutes): + logs.append(highlight_rules_func(line.rstrip('\n'))) + except (ValueError, IndexError): + logs.append(highlight_rules_func(line.rstrip('\n'))) # Line didn't meet datetime parsing criteria, continue with processing + logs_as_string = "<br>".join(logs) logs_as_string_newlines_rendered = logs_as_string.replace("\n", "<br>
") logs_as_string_newlines_rendered_font_specified = """ diff --git a/logger_config.py b/logger_config.py new file mode 100644 index 0000000..a21fe5a --- /dev/null +++ b/logger_config.py @@ -0,0 +1,35 @@ +import logging +import os +import shutil +import queue +from logging.handlers import RotatingFileHandler, QueueHandler, QueueListener + +logger = logging.getLogger("swiss_army_llama") + +def setup_logger(): + if logger.handlers: + return logger + old_logs_dir = 'old_logs' + if not os.path.exists(old_logs_dir): + os.makedirs(old_logs_dir) + logger.setLevel(logging.INFO) + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + log_file_path = 'swiss_army_llama.log' + log_queue = queue.Queue(-1) # Create a queue for the handlers + fh = RotatingFileHandler(log_file_path, maxBytes=10*1024*1024, backupCount=5) + fh.setFormatter(formatter) + def namer(default_log_name): # Function to move rotated logs to the old_logs directory + return os.path.join(old_logs_dir, os.path.basename(default_log_name)) + def rotator(source, dest): + shutil.move(source, dest) + fh.namer = namer + fh.rotator = rotator + sh = logging.StreamHandler() # Stream handler + sh.setFormatter(formatter) + queue_handler = QueueHandler(log_queue) # Create QueueHandler + queue_handler.setFormatter(formatter) + logger.addHandler(queue_handler) + listener = QueueListener(log_queue, fh, sh) # Create QueueListener with real handlers + listener.start() + logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) # Configure SQLalchemy logging + return logger diff --git a/misc_utility_functions.py b/misc_utility_functions.py new file mode 100644 index 0000000..7c73ffd --- /dev/null +++ b/misc_utility_functions.py @@ -0,0 +1,215 @@ +from logger_config import setup_logger +from database_functions import AsyncSessionLocal +import socket +import os +import re +import json +import io +import numpy as np +import faiss +from typing import Any +from collections import defaultdict +from sqlalchemy import text as sql_text + +logger = setup_logger() + +def clean_filename_for_url_func(dirty_filename: str) -> str: + clean_filename = re.sub(r'[^\w\s]', '', dirty_filename) # Remove special characters and replace spaces with underscores + clean_filename = clean_filename.replace(' ', '_') + return clean_filename + +def is_redis_running(host='localhost', port=6379): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + s.connect((host, port)) + return True + except ConnectionRefusedError: + return False + finally: + s.close() + + +async def build_faiss_indexes(): + global faiss_indexes, token_faiss_indexes, associated_texts_by_model + if os.environ.get("FAISS_SETUP_DONE") == "1": + logger.info("Faiss indexes already built by another worker. 
Skipping.") + return faiss_indexes, token_faiss_indexes, associated_texts_by_model + faiss_indexes = {} + token_faiss_indexes = {} # Separate FAISS indexes for token-level embeddings + associated_texts_by_model = defaultdict(list) # Create a dictionary to store associated texts by model name + async with AsyncSessionLocal() as session: + result = await session.execute(sql_text("SELECT llm_model_name, text, embedding_json FROM embeddings")) # Query regular embeddings + token_result = await session.execute(sql_text("SELECT llm_model_name, token, token_level_embedding_json FROM token_level_embeddings")) # Query token-level embeddings + embeddings_by_model = defaultdict(list) + token_embeddings_by_model = defaultdict(list) + for row in result.fetchall(): # Process regular embeddings + llm_model_name = row[0] + associated_texts_by_model[llm_model_name].append(row[1]) # Store the associated text by model name + embeddings_by_model[llm_model_name].append((row[1], json.loads(row[2]))) + for row in token_result.fetchall(): # Process token-level embeddings + llm_model_name = row[0] + token_embeddings_by_model[llm_model_name].append(json.loads(row[2])) + for llm_model_name, embeddings in embeddings_by_model.items(): + logger.info(f"Building Faiss index over embeddings for model {llm_model_name}...") + embeddings_array = np.array([e[1] for e in embeddings]).astype('float32') + if embeddings_array.size == 0: + logger.error(f"No embeddings were loaded from the database for model {llm_model_name}, so nothing to build the Faiss index with!") + continue + logger.info(f"Loaded {len(embeddings_array)} embeddings for model {llm_model_name}.") + logger.info(f"Embedding dimension for model {llm_model_name}: {embeddings_array.shape[1]}") + logger.info(f"Normalizing {len(embeddings_array)} embeddings for model {llm_model_name}...") + faiss.normalize_L2(embeddings_array) # Normalize the vectors for cosine similarity + faiss_index = faiss.IndexFlatIP(embeddings_array.shape[1]) # Use IndexFlatIP for cosine similarity + faiss_index.add(embeddings_array) + logger.info(f"Faiss index built for model {llm_model_name}.") + faiss_indexes[llm_model_name] = faiss_index # Store the index by model name + for llm_model_name, token_embeddings in token_embeddings_by_model.items(): + token_embeddings_array = np.array(token_embeddings).astype('float32') + if token_embeddings_array.size == 0: + logger.error(f"No token-level embeddings were loaded from the database for model {llm_model_name}, so nothing to build the Faiss index with!") + continue + logger.info(f"Normalizing {len(token_embeddings_array)} token-level embeddings for model {llm_model_name}...") + faiss.normalize_L2(token_embeddings_array) # Normalize the vectors for cosine similarity + token_faiss_index = faiss.IndexFlatIP(token_embeddings_array.shape[1]) # Use IndexFlatIP for cosine similarity + token_faiss_index.add(token_embeddings_array) + logger.info(f"Token-level Faiss index built for model {llm_model_name}.") + token_faiss_indexes[llm_model_name] = token_faiss_index # Store the token-level index by model name + os.environ["FAISS_SETUP_DONE"] = "1" + logger.info("Faiss indexes built.") + return faiss_indexes, token_faiss_indexes, associated_texts_by_model + +class JSONAggregator: + def __init__(self): + self.completions = [] + self.aggregate_result = None + + @staticmethod + def weighted_vote(values, weights): + tally = defaultdict(float) + for v, w in zip(values, weights): + tally[v] += w + return max(tally, key=tally.get) + + @staticmethod + def 
flatten_json(json_obj, parent_key='', sep='->'): + items = {} + for k, v in json_obj.items(): + new_key = f"{parent_key}{sep}{k}" if parent_key else k + if isinstance(v, dict): + items.update(JSONAggregator.flatten_json(v, new_key, sep=sep)) + else: + items[new_key] = v + return items + + @staticmethod + def get_value_by_path(json_obj, path, sep='->'): + keys = path.split(sep) + item = json_obj + for k in keys: + item = item[k] + return item + + @staticmethod + def set_value_by_path(json_obj, path, value, sep='->'): + keys = path.split(sep) + item = json_obj + for k in keys[:-1]: + item = item.setdefault(k, {}) + item[keys[-1]] = value + + def calculate_path_weights(self): + all_paths = [] + for j in self.completions: + all_paths += list(self.flatten_json(j).keys()) + path_weights = defaultdict(float) + for path in all_paths: + path_weights[path] += 1.0 + return path_weights + + def aggregate(self): + path_weights = self.calculate_path_weights() + aggregate = {} + for path, weight in path_weights.items(): + values = [self.get_value_by_path(j, path) for j in self.completions if path in self.flatten_json(j)] + weights = [weight] * len(values) + aggregate_value = self.weighted_vote(values, weights) + self.set_value_by_path(aggregate, path, aggregate_value) + self.aggregate_result = aggregate + +class FakeUploadFile: + def __init__(self, filename: str, content: Any, content_type: str = 'text/plain'): + self.filename = filename + self.content_type = content_type + self.file = io.BytesIO(content) + def read(self, size: int = -1) -> bytes: + return self.file.read(size) + def seek(self, offset: int, whence: int = 0) -> int: + return self.file.seek(offset, whence) + def tell(self) -> int: + return self.file.tell() + +def normalize_logprobs(avg_logprob, min_logprob, max_logprob): + range_logprob = max_logprob - min_logprob + return (avg_logprob - min_logprob) / range_logprob if range_logprob != 0 else 0.5 + +def remove_pagination_breaks(text: str) -> str: + text = re.sub(r'-(\n)(?=[a-z])', '', text) # Remove hyphens at the end of lines when the word continues on the next line + text = re.sub(r'(?<=\w)(? total_ram_gb: + raise ValueError(f"Cannot allocate {RAMDISK_SIZE_IN_GB}G for RAM Disk. Total system RAM is {total_ram_gb:.2f}G.") + logger.info("Setting up RAM Disk...") + os.makedirs(RAMDISK_PATH, exist_ok=True) + mount_command = ["sudo", "mount", "-t", "tmpfs", "-o", f"size={ramdisk_size_str}", "tmpfs", RAMDISK_PATH] + subprocess.run(mount_command, check=True) + logger.info(f"RAM Disk set up at {RAMDISK_PATH} with size {ramdisk_size_gb}G") + os.environ["RAMDISK_SETUP_DONE"] = "1" + +def copy_models_to_ramdisk(models_directory, ramdisk_directory): + total_size = sum(os.path.getsize(os.path.join(models_directory, model)) for model in os.listdir(models_directory)) + free_ram = psutil.virtual_memory().free + if total_size > free_ram: + logger.warning(f"Not enough space on RAM Disk. Required: {total_size}, Available: {free_ram}. Rebuilding RAM Disk.") + clear_ramdisk() + free_ram = psutil.virtual_memory().free # Recompute the available RAM after clearing the RAM disk + if total_size > free_ram: + logger.error(f"Still not enough space on RAM Disk even after clearing. 
Required: {total_size}, Available: {free_ram}.") + raise ValueError("Not enough RAM space to copy models.") + os.makedirs(ramdisk_directory, exist_ok=True) + for model in os.listdir(models_directory): + src_path = os.path.join(models_directory, model) + dest_path = os.path.join(ramdisk_directory, model) + if os.path.exists(dest_path) and os.path.getsize(dest_path) == os.path.getsize(src_path): # Check if the file already exists in the RAM disk and has the same size + logger.info(f"Model {model} already exists in RAM Disk and is the same size. Skipping copy.") + continue + shutil.copyfile(src_path, dest_path) + logger.info(f"Copied model {model} to RAM Disk at {dest_path}") + +def clear_ramdisk(): + while True: + cmd_check = f"sudo mount | grep {RAMDISK_PATH}" + result = subprocess.run(cmd_check, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8') + if RAMDISK_PATH not in result: + break # Exit the loop if the RAMDISK_PATH is not in the mount list + cmd_umount = f"sudo umount -l {RAMDISK_PATH}" + subprocess.run(cmd_umount, shell=True, check=True) + logger.info(f"Cleared RAM Disk at {RAMDISK_PATH}") \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 2f2f90a..0403a1a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,3 +19,6 @@ faster-whisper textract pytest pytz +uvloop +aioredis +aioredlock \ No newline at end of file diff --git a/service_functions.py b/service_functions.py new file mode 100644 index 0000000..553c7b6 --- /dev/null +++ b/service_functions.py @@ -0,0 +1,614 @@ +from logger_config import setup_logger +import shared_resources +from shared_resources import load_model, token_level_embedding_model_cache, text_completion_model_cache +from database_functions import AsyncSessionLocal, DatabaseWriter, execute_with_retry +from misc_utility_functions import clean_filename_for_url_func, FakeUploadFile, sophisticated_sentence_splitter, merge_transcript_segments_into_combined_text +from embeddings_data_models import TextEmbedding, DocumentEmbedding, Document, TokenLevelEmbedding, TokenLevelEmbeddingBundleCombinedFeatureVector, AudioTranscript +from embeddings_data_models import EmbeddingRequest, TextCompletionRequest +from embeddings_data_models import TextCompletionResponse, AudioTranscriptResponse +import os +import shutil +import psutil +import glob +import json +import asyncio +import zipfile +import tempfile +import time +from datetime import datetime +from hashlib import sha3_256 +from urllib.parse import quote +import numpy as np +import pandas as pd +import textract +from sqlalchemy import text as sql_text +from sqlalchemy import select +from fastapi import HTTPException, Request, UploadFile, File +from fastapi.concurrency import run_in_threadpool +from typing import List, Optional, Tuple +from decouple import config +from faster_whisper import WhisperModel +from llama_cpp import Llama, LlamaGrammar + +logger = setup_logger() +SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT = config("SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT", default=8089, cast=int) +DEFAULT_MODEL_NAME = config("DEFAULT_MODEL_NAME", default="openchat_v3.2_super", cast=str) +LLM_CONTEXT_SIZE_IN_TOKENS = config("LLM_CONTEXT_SIZE_IN_TOKENS", default=512, cast=int) +TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS = config("TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS", default=4000, cast=int) +DEFAULT_MAX_COMPLETION_TOKENS = config("DEFAULT_MAX_COMPLETION_TOKENS", default=100, cast=int) +DEFAULT_NUMBER_OF_COMPLETIONS_TO_GENERATE = config("DEFAULT_NUMBER_OF_COMPLETIONS_TO_GENERATE", default=4, cast=int) 
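+# Note: config() here is python-decouple's reader; each of these settings is taken from the .env file (or the process environment), with the stated default used when the key is absent.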
+DEFAULT_COMPLETION_TEMPERATURE = config("DEFAULT_COMPLETION_TEMPERATURE", default=0.7, cast=float) +MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING = config("MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING", default=15, cast=int) +USE_PARALLEL_INFERENCE_QUEUE = config("USE_PARALLEL_INFERENCE_QUEUE", default=False, cast=bool) +MAX_CONCURRENT_PARALLEL_INFERENCE_TASKS = config("MAX_CONCURRENT_PARALLEL_INFERENCE_TASKS", default=10, cast=int) +USE_RAMDISK = config("USE_RAMDISK", default=False, cast=bool) +RAMDISK_PATH = config("RAMDISK_PATH", default="/mnt/ramdisk", cast=str) +BASE_DIRECTORY = os.path.dirname(os.path.abspath(__file__)) + +async def get_transcript_from_db(audio_file_hash: str): + return await execute_with_retry(_get_transcript_from_db, audio_file_hash) + +async def _get_transcript_from_db(audio_file_hash: str) -> Optional[dict]: + async with AsyncSessionLocal() as session: + result = await session.execute( + sql_text("SELECT * FROM audio_transcripts WHERE audio_file_hash=:audio_file_hash"), + {"audio_file_hash": audio_file_hash}, + ) + row = result.fetchone() + if row: + try: + segments_json = json.loads(row.segments_json) + combined_transcript_text_list_of_metadata_dicts = json.loads(row.combined_transcript_text_list_of_metadata_dicts) + info_json = json.loads(row.info_json) + if hasattr(info_json, '__dict__'): + info_json = vars(info_json) + except json.JSONDecodeError as e: + raise ValueError(f"JSON Decode Error: {e}") + if not isinstance(segments_json, list) or not isinstance(combined_transcript_text_list_of_metadata_dicts, list) or not isinstance(info_json, dict): + logger.error(f"Type of segments_json: {type(segments_json)}, Value: {segments_json}") + logger.error(f"Type of combined_transcript_text_list_of_metadata_dicts: {type(combined_transcript_text_list_of_metadata_dicts)}, Value: {combined_transcript_text_list_of_metadata_dicts}") + logger.error(f"Type of info_json: {type(info_json)}, Value: {info_json}") + raise ValueError("Deserialized JSON does not match the expected format.") + audio_transcript_response = { + "id": row.id, + "audio_file_name": row.audio_file_name, + "audio_file_size_mb": row.audio_file_size_mb, + "segments_json": segments_json, + "combined_transcript_text": row.combined_transcript_text, + "combined_transcript_text_list_of_metadata_dicts": combined_transcript_text_list_of_metadata_dicts, + "info_json": info_json, + "ip_address": row.ip_address, + "request_time": row.request_time, + "response_time": row.response_time, + "total_time": row.total_time, + "url_to_download_zip_file_of_embeddings": "" + } + return AudioTranscriptResponse(**audio_transcript_response) + return None + +async def save_transcript_to_db(audio_file_hash, audio_file_name, audio_file_size_mb, transcript_segments, info, ip_address, request_time, response_time, total_time, combined_transcript_text, combined_transcript_text_list_of_metadata_dicts): + existing_transcript = await get_transcript_from_db(audio_file_hash) + if existing_transcript: + return existing_transcript + audio_transcript = AudioTranscript( + audio_file_hash=audio_file_hash, + audio_file_name=audio_file_name, + audio_file_size_mb=audio_file_size_mb, + segments_json=json.dumps(transcript_segments), + combined_transcript_text=combined_transcript_text, + combined_transcript_text_list_of_metadata_dicts=json.dumps(combined_transcript_text_list_of_metadata_dicts), + info_json=json.dumps(info), + ip_address=ip_address, + request_time=request_time, + response_time=response_time, + total_time=total_time + ) + await 
shared_resources.db_writer.enqueue_write([audio_transcript]) + + +async def compute_and_store_transcript_embeddings(audio_file_name, list_of_transcript_sentences, llm_model_name, ip_address, combined_transcript_text, req: Request): + logger.info(f"Now computing embeddings for entire transcript of {audio_file_name}...") + zip_dir = 'generated_transcript_embeddings_zip_files' + if not os.path.exists(zip_dir): + os.makedirs(zip_dir) + sanitized_file_name = clean_filename_for_url_func(audio_file_name) + document_name = f"automatic_whisper_transcript_of__{sanitized_file_name}" + file_hash = sha3_256(combined_transcript_text.encode('utf-8')).hexdigest() + computed_embeddings = await compute_embeddings_for_document(list_of_transcript_sentences, llm_model_name, ip_address, file_hash) + zip_file_path = f"{zip_dir}/{quote(document_name)}.zip" + with zipfile.ZipFile(zip_file_path, 'w') as zipf: + zipf.writestr("embeddings.txt", json.dumps(computed_embeddings)) + download_url = f"download/{quote(document_name)}.zip" + full_download_url = f"{req.base_url}{download_url}" + logger.info(f"Generated download URL for transcript embeddings: {full_download_url}") + fake_upload_file = FakeUploadFile(filename=document_name, content=combined_transcript_text.encode(), content_type='text/plain') + logger.info(f"Storing transcript embeddings for {audio_file_name} in the database...") + await store_document_embeddings_in_db(fake_upload_file, file_hash, combined_transcript_text.encode(), json.dumps(computed_embeddings).encode(), computed_embeddings, llm_model_name, ip_address, datetime.utcnow()) + return full_download_url + +async def compute_transcript_with_whisper_from_audio_func(audio_file_hash, audio_file_path, audio_file_name, audio_file_size_mb, ip_address, req: Request, compute_embeddings_for_resulting_transcript_document=True, llm_model_name=DEFAULT_MODEL_NAME): + model_size = "large-v2" + logger.info(f"Loading Whisper model {model_size}...") + num_workers = 1 if psutil.virtual_memory().total < 32 * (1024 ** 3) else min(4, max(1, int((psutil.virtual_memory().total - 32 * (1024 ** 3)) / (4 * (1024 ** 3))))) # Only use more than 1 worker if there is at least 32GB of RAM; then use 1 worker per additional 4GB of RAM up to 4 workers max + model = await run_in_threadpool(WhisperModel, model_size, device="cpu", compute_type="auto", cpu_threads=os.cpu_count(), num_workers=num_workers) + request_time = datetime.utcnow() + logger.info(f"Computing transcript for {audio_file_name} which has a {audio_file_size_mb :.2f}MB file size...") + segments, info = await run_in_threadpool(model.transcribe, audio_file_path, beam_size=20) + if not segments: + logger.warning(f"No segments were returned for file {audio_file_name}.") + return [], {}, "", [], request_time, datetime.utcnow(), 0, "" + segment_details = [] + for idx, segment in enumerate(segments): + details = { + "start": round(segment.start, 2), + "end": round(segment.end, 2), + "text": segment.text, + "avg_logprob": round(segment.avg_logprob, 2) + } + logger.info(f"Details of transcript segment {idx} from file {audio_file_name}: {details}") + segment_details.append(details) + combined_transcript_text, combined_transcript_text_list_of_metadata_dicts, list_of_transcript_sentences = merge_transcript_segments_into_combined_text(segment_details) + if compute_embeddings_for_resulting_transcript_document: + download_url = await compute_and_store_transcript_embeddings(audio_file_name, list_of_transcript_sentences, llm_model_name, ip_address, combined_transcript_text, req) + else: 
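+        # Embeddings were not requested for this transcript, so there is no ZIP download URL to return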
+ download_url = '' + response_time = datetime.utcnow() + total_time = (response_time - request_time).total_seconds() + logger.info(f"Transcript computed in {total_time} seconds.") + await save_transcript_to_db(audio_file_hash, audio_file_name, audio_file_size_mb, segment_details, info, ip_address, request_time, response_time, total_time, combined_transcript_text, combined_transcript_text_list_of_metadata_dicts) + info_dict = info._asdict() + return segment_details, info_dict, combined_transcript_text, combined_transcript_text_list_of_metadata_dicts, request_time, response_time, total_time, download_url + +async def get_or_compute_transcript(file: UploadFile, compute_embeddings_for_resulting_transcript_document: bool, llm_model_name: str, req: Request = None) -> dict: + request_time = datetime.utcnow() + ip_address = req.client.host if req else "127.0.0.1" + file_contents = await file.read() + audio_file_hash = sha3_256(file_contents).hexdigest() + file.file.seek(0) # Reset file pointer after read + unique_id = f"transcript_{audio_file_hash}_{llm_model_name}" + lock = await shared_resources.lock_manager.lock(unique_id) + if lock.valid: + try: + existing_audio_transcript = await get_transcript_from_db(audio_file_hash) + if existing_audio_transcript: + return existing_audio_transcript + current_position = file.file.tell() + file.file.seek(0, os.SEEK_END) + audio_file_size_mb = file.file.tell() / (1024 * 1024) + file.file.seek(current_position) + with tempfile.NamedTemporaryFile(delete=False) as tmp_file: + shutil.copyfileobj(file.file, tmp_file) + audio_file_name = tmp_file.name + segment_details, info, combined_transcript_text, combined_transcript_text_list_of_metadata_dicts, request_time, response_time, total_time, download_url = await compute_transcript_with_whisper_from_audio_func(audio_file_hash, audio_file_name, file.filename, audio_file_size_mb, ip_address, req, compute_embeddings_for_resulting_transcript_document, llm_model_name) + audio_transcript_response = { + "audio_file_hash": audio_file_hash, + "audio_file_name": file.filename, + "audio_file_size_mb": audio_file_size_mb, + "segments_json": segment_details, + "combined_transcript_text": combined_transcript_text, + "combined_transcript_text_list_of_metadata_dicts": combined_transcript_text_list_of_metadata_dicts, + "info_json": info, + "ip_address": ip_address, + "request_time": request_time, + "response_time": response_time, + "total_time": total_time, + "url_to_download_zip_file_of_embeddings": download_url if compute_embeddings_for_resulting_transcript_document else "" + } + os.remove(audio_file_name) + return AudioTranscriptResponse(**audio_transcript_response) + finally: + await shared_resources.lock_manager.unlock(lock) + else: + return {"status": "already processing"} + + +# Core embedding functions start here: + + +def add_model_url(new_url: str) -> str: + corrected_url = new_url + if '/blob/main/' in new_url: + corrected_url = new_url.replace('/blob/main/', '/resolve/main/') + json_path = os.path.join(BASE_DIRECTORY, "model_urls.json") + with open(json_path, "r") as f: + existing_urls = json.load(f) + if corrected_url not in existing_urls: + logger.info(f"Model URL not found in database. 
Adding {new_url} now...") + existing_urls.append(corrected_url) + with open(json_path, "w") as f: + json.dump(existing_urls, f) + logger.info(f"Model URL added: {new_url}") + else: + logger.info("Model URL already exists.") + return corrected_url + +async def get_embedding_from_db(text: str, llm_model_name: str): + text_hash = sha3_256(text.encode('utf-8')).hexdigest() # Compute the hash + return await execute_with_retry(_get_embedding_from_db, text_hash, llm_model_name) + +async def _get_embedding_from_db(text_hash: str, llm_model_name: str) -> Optional[dict]: + async with AsyncSessionLocal() as session: + result = await session.execute( + sql_text("SELECT embedding_json FROM embeddings WHERE text_hash=:text_hash AND llm_model_name=:llm_model_name"), + {"text_hash": text_hash, "llm_model_name": llm_model_name}, + ) + row = result.fetchone() + if row: + embedding_json = row[0] + logger.info(f"Embedding found in database for text hash '{text_hash}' using model '{llm_model_name}'") + return json.loads(embedding_json) + return None + +async def get_or_compute_embedding(request: EmbeddingRequest, req: Request = None, client_ip: str = None, document_file_hash: str = None) -> dict: + request_time = datetime.utcnow() # Capture request time as datetime object + ip_address = client_ip or (req.client.host if req else "localhost") # If client_ip is provided, use it; otherwise, try to get from req; if not available, default to "localhost" + logger.info(f"Received request for embedding for '{request.text}' using model '{request.llm_model_name}' from IP address '{ip_address}'") + embedding_list = await get_embedding_from_db(request.text, request.llm_model_name) # Check if embedding exists in the database + if embedding_list is not None: + response_time = datetime.utcnow() # Capture response time as datetime object + total_time = (response_time - request_time).total_seconds() # Calculate time taken in seconds + logger.info(f"Embedding found in database for '{request.text}' using model '{request.llm_model_name}'; returning in {total_time:.4f} seconds") + return {"embedding": embedding_list} + model = load_model(request.llm_model_name) + embedding_list = calculate_sentence_embedding(model, request.text) # Compute the embedding if not in the database + if embedding_list is None: + logger.error(f"Could not calculate the embedding for the given text: '{request.text}' using model '{request.llm_model_name}!'") + raise HTTPException(status_code=400, detail="Could not calculate the embedding for the given text") + embedding_json = json.dumps(embedding_list) # Serialize the numpy array to JSON and save to the database + response_time = datetime.utcnow() # Capture response time as datetime object + total_time = (response_time - request_time).total_seconds() # Calculate total time using datetime objects + word_length_of_input_text = len(request.text.split()) + if word_length_of_input_text > 0: + logger.info(f"Embedding calculated for '{request.text}' using model '{request.llm_model_name}' in {total_time} seconds, or an average of {total_time/word_length_of_input_text :.2f} seconds per word. 
Now saving to database...") + await save_embedding_to_db(request.text, request.llm_model_name, embedding_json, ip_address, request_time, response_time, total_time, document_file_hash) + return {"embedding": embedding_list} + +async def save_embedding_to_db(text: str, llm_model_name: str, embedding_json: str, ip_address: str, request_time: datetime, response_time: datetime, total_time: float, document_file_hash: str = None): + existing_embedding = await get_embedding_from_db(text, llm_model_name) # Check if the embedding already exists + if existing_embedding is not None: + return existing_embedding + return await execute_with_retry(_save_embedding_to_db, text, llm_model_name, embedding_json, ip_address, request_time, response_time, total_time, document_file_hash) + +async def _save_embedding_to_db(text: str, llm_model_name: str, embedding_json: str, ip_address: str, request_time: datetime, response_time: datetime, total_time: float, document_file_hash: str = None): + existing_embedding = await get_embedding_from_db(text, llm_model_name) + if existing_embedding: + return existing_embedding + embedding = TextEmbedding( + text=text, + llm_model_name=llm_model_name, + embedding_json=embedding_json, + ip_address=ip_address, + request_time=request_time, + response_time=response_time, + total_time=total_time, + document_file_hash=document_file_hash + ) + await shared_resources.db_writer.enqueue_write([embedding]) # Enqueue the write operation using the db_writer instance + + +def load_token_level_embedding_model(llm_model_name: str, raise_http_exception: bool = True): + try: + if llm_model_name in token_level_embedding_model_cache: # Check if the model is already loaded in the cache + return token_level_embedding_model_cache[llm_model_name] + models_dir = os.path.join(RAMDISK_PATH, 'models') if USE_RAMDISK else os.path.join(BASE_DIRECTORY, 'models') # Determine the model directory path + matching_files = glob.glob(os.path.join(models_dir, f"{llm_model_name}*")) # Search for matching model files + if not matching_files: + logger.error(f"No model file found matching: {llm_model_name}") + raise FileNotFoundError + matching_files.sort(key=os.path.getmtime, reverse=True) # Sort the files based on modification time (recently modified files first) + model_file_path = matching_files[0] + model_instance = Llama(model_path=model_file_path, embedding=True, n_ctx=LLM_CONTEXT_SIZE_IN_TOKENS, verbose=False) # Load the model + token_level_embedding_model_cache[llm_model_name] = model_instance # Cache the loaded model + return model_instance + except TypeError as e: + logger.error(f"TypeError occurred while loading the model: {e}") + raise + except Exception as e: + logger.error(f"Exception occurred while loading the model: {e}") + if raise_http_exception: + raise HTTPException(status_code=404, detail="Model file not found") + else: + raise FileNotFoundError(f"No model file found matching: {llm_model_name}") + +async def compute_token_level_embedding_bundle_combined_feature_vector(token_level_embeddings) -> List[float]: + start_time = datetime.utcnow() + logger.info("Extracting token-level embeddings from the bundle") + parsed_df = pd.read_json(token_level_embeddings) # Parse the json_content back to a DataFrame + token_level_embeddings = list(parsed_df['embedding']) + embeddings = np.array(token_level_embeddings) # Convert the list of embeddings to a NumPy array + logger.info(f"Computing column-wise means/mins/maxes/std_devs of the embeddings... 
(shape: {embeddings.shape})") + assert(len(embeddings) > 0) + means = np.mean(embeddings, axis=0) + mins = np.min(embeddings, axis=0) + maxes = np.max(embeddings, axis=0) + stds = np.std(embeddings, axis=0) + logger.info("Concatenating the computed statistics to form the combined feature vector") + combined_feature_vector = np.concatenate([means, mins, maxes, stds]) + end_time = datetime.utcnow() + total_time = (end_time - start_time).total_seconds() + logger.info(f"Computed the token-level embedding bundle's combined feature vector in {total_time:.2f} seconds.") + return combined_feature_vector.tolist() + + +async def get_or_compute_token_level_embedding_bundle_combined_feature_vector(token_level_embedding_bundle_id, token_level_embeddings, db_writer: DatabaseWriter) -> List[float]: + request_time = datetime.utcnow() + logger.info(f"Checking for existing combined feature vector for token-level embedding bundle ID: {token_level_embedding_bundle_id}") + async with AsyncSessionLocal() as session: + result = await session.execute( + select(TokenLevelEmbeddingBundleCombinedFeatureVector) + .filter(TokenLevelEmbeddingBundleCombinedFeatureVector.token_level_embedding_bundle_id == token_level_embedding_bundle_id) + ) + existing_combined_feature_vector = result.scalar_one_or_none() + if existing_combined_feature_vector: + response_time = datetime.utcnow() + total_time = (response_time - request_time).total_seconds() + logger.info(f"Found existing combined feature vector for token-level embedding bundle ID: {token_level_embedding_bundle_id}. Returning cached result in {total_time:.2f} seconds.") + return json.loads(existing_combined_feature_vector.combined_feature_vector_json) # Parse the JSON string into a list + logger.info(f"No cached combined feature vector found for token-level embedding bundle ID: {token_level_embedding_bundle_id}. 
Computing now...") + combined_feature_vector = await compute_token_level_embedding_bundle_combined_feature_vector(token_level_embeddings) + combined_feature_vector_db_object = TokenLevelEmbeddingBundleCombinedFeatureVector( + token_level_embedding_bundle_id=token_level_embedding_bundle_id, + combined_feature_vector_json=json.dumps(combined_feature_vector) # Convert the list to a JSON string + ) + logger.info(f"Enqueuing database write of the combined feature vector for token-level embedding bundle ID: {token_level_embedding_bundle_id}...") + await db_writer.enqueue_write([combined_feature_vector_db_object]) + return combined_feature_vector + + +async def calculate_token_level_embeddings(text: str, llm_model_name: str, client_ip: str, token_level_embedding_bundle_id: int) -> List[np.array]: + request_time = datetime.utcnow() + logger.info(f"Starting token-level embedding calculation for text: '{text}' using model: '{llm_model_name}'") + logger.info(f"Loading model: '{llm_model_name}'") + llm = load_token_level_embedding_model(llm_model_name) # Assuming this method returns an instance of the Llama class + token_embeddings = [] + tokens = text.split() # Simple whitespace tokenizer; can be replaced with a more advanced one if needed + logger.info(f"Tokenized text into {len(tokens)} tokens") + for idx, token in enumerate(tokens, start=1): + try: # Check if the embedding is already available in the database + existing_embedding = await get_token_level_embedding_from_db(token, llm_model_name) + if existing_embedding is not None: + token_embeddings.append(np.array(existing_embedding)) + logger.info(f"Embedding retrieved from database for token '{token}'") + continue + logger.info(f"Processing token {idx} of {len(tokens)}: '{token}'") + token_embedding = llm.embed(token) + token_embedding_array = np.array(token_embedding) + token_embeddings.append(token_embedding_array) + response_time = datetime.utcnow() + token_level_embedding_json = json.dumps(token_embedding_array.tolist()) + await store_token_level_embeddings_in_db(token, llm_model_name, token_level_embedding_json, client_ip, request_time, response_time, token_level_embedding_bundle_id) + except RuntimeError as e: + logger.error(f"Failed to calculate embedding for token '{token}': {e}") + logger.info(f"Completed token embedding calculation for all tokens in text: '{text}'") + return token_embeddings + +async def get_token_level_embedding_from_db(token: str, llm_model_name: str) -> Optional[List[float]]: + token_hash = sha3_256(token.encode('utf-8')).hexdigest() # Compute the hash + async with AsyncSessionLocal() as session: + result = await session.execute( + sql_text("SELECT token_level_embedding_json FROM token_level_embeddings WHERE token_hash=:token_hash AND llm_model_name=:llm_model_name"), + {"token_hash": token_hash, "llm_model_name": llm_model_name}, + ) + row = result.fetchone() + if row: + embedding_json = row[0] + logger.info(f"Embedding found in database for token hash '{token_hash}' using model '{llm_model_name}'") + return json.loads(embedding_json) + return None + +async def store_token_level_embeddings_in_db(token: str, llm_model_name: str, token_level_embedding_json: str, ip_address: str, request_time: datetime, response_time: datetime, token_level_embedding_bundle_id: int): + total_time = (response_time - request_time).total_seconds() + embedding = TokenLevelEmbedding( + token=token, + llm_model_name=llm_model_name, + token_level_embedding_json=token_level_embedding_json, + ip_address=ip_address, 
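+        # The timing fields below record when the embedding request started and finished, plus the elapsed total_time in seconds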
request_time=request_time, + response_time=response_time, + total_time=total_time, + token_level_embedding_bundle_id=token_level_embedding_bundle_id + ) + await shared_resources.db_writer.enqueue_write([embedding]) # Enqueue the write operation for the token-level embedding + +def calculate_sentence_embedding(llama: Llama, text: str) -> np.array: + sentence_embedding = None + retry_count = 0 + while sentence_embedding is None and retry_count < 3: + try: + if retry_count > 0: + logger.info(f"Retrying sentence embedding calculation. Attempt number {retry_count + 1}") + sentence_embedding = llama.embed_query(text) + except TypeError as e: + logger.error(f"TypeError in calculate_sentence_embedding: {e}") + raise + except Exception as e: + logger.error(f"Exception in calculate_sentence_embedding: {e}") + text = text[:-int(len(text) * 0.1)] + retry_count += 1 + logger.info(f"Trimming sentence due to too many tokens. New length: {len(text)}") + if sentence_embedding is None: + logger.error("Failed to calculate sentence embedding after multiple attempts") + return sentence_embedding + +async def compute_embeddings_for_document(strings: list, llm_model_name: str, client_ip: str, document_file_hash: str) -> List[Tuple[str, np.array]]: + from swiss_army_llama import get_embedding_vector_for_string + results = [] + if USE_PARALLEL_INFERENCE_QUEUE: + logger.info(f"Using parallel inference queue to compute embeddings for {len(strings)} strings") + start_time = time.perf_counter() # Record the start time + semaphore = asyncio.Semaphore(MAX_CONCURRENT_PARALLEL_INFERENCE_TASKS) + async def compute_embedding(text): # Define a function to compute the embedding for a given text + try: + async with semaphore: # Acquire a semaphore slot + request = EmbeddingRequest(text=text, llm_model_name=llm_model_name) + embedding = await get_embedding_vector_for_string(request, client_ip=client_ip, document_file_hash=document_file_hash) + return text, embedding["embedding"] + except Exception as e: + logger.error(f"Error computing embedding for text '{text}': {e}") + return text, None + results = await asyncio.gather(*[compute_embedding(s) for s in strings]) # Use asyncio.gather to run the tasks concurrently + end_time = time.perf_counter() # Record the end time + duration = end_time - start_time + if len(strings) > 0: + logger.info(f"Parallel inference task for {len(strings)} strings completed in {duration:.2f} seconds; {duration / len(strings):.2f} seconds per string") + else: # Compute embeddings sequentially + logger.info(f"Using sequential inference to compute embeddings for {len(strings)} strings") + start_time = time.perf_counter() # Record the start time + for s in strings: + embedding_request = EmbeddingRequest(text=s, llm_model_name=llm_model_name) + embedding = await get_embedding_vector_for_string(embedding_request, client_ip=client_ip, document_file_hash=document_file_hash) + results.append((s, embedding["embedding"])) + end_time = time.perf_counter() # Record the end time + duration = end_time - start_time + if len(strings) > 0: + logger.info(f"Sequential inference task for {len(strings)} strings completed in {duration:.2f} seconds; {duration / len(strings):.2f} seconds per string") + filtered_results = [(text, embedding) for text, embedding in results if embedding is not None] # Filter out results with None embeddings (applicable to parallel processing) and return + return filtered_results + +async def parse_submitted_document_file_into_sentence_strings_func(temp_file_path: str, mime_type: str): 
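+    # Extracts plain text from the uploaded file (textract handles non-text MIME types, with a Tesseract OCR fallback for PDFs), splits it into sentences, and keeps only strings longer than MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING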
strings = [] + if mime_type.startswith('text/'): + with open(temp_file_path, 'r') as buffer: + content = buffer.read() + else: + try: + content = textract.process(temp_file_path).decode('utf-8') + except UnicodeDecodeError: + try: + content = textract.process(temp_file_path).decode('unicode_escape') + except Exception as e: + logger.error(f"Error while processing file: {e}, mime_type: {mime_type}") + raise HTTPException(status_code=400, detail=f"Unsupported file type or error: {e}") + except Exception as e: + logger.error(f"Error while processing file: {e}, mime_type: {mime_type}") + raise HTTPException(status_code=400, detail=f"Unsupported file type or error: {e}") + sentences = sophisticated_sentence_splitter(content) + if len(sentences) == 0 and temp_file_path.lower().endswith('.pdf'): + logger.info("No sentences found, attempting OCR using Tesseract.") + try: + content = textract.process(temp_file_path, method='tesseract').decode('utf-8') + sentences = sophisticated_sentence_splitter(content) + except Exception as e: + logger.error(f"Error while processing file with OCR: {e}") + raise HTTPException(status_code=400, detail=f"OCR failed: {e}") + if len(sentences) == 0: + logger.info("No sentences found in the document") + raise HTTPException(status_code=400, detail="No sentences found in the document") + logger.info(f"Extracted {len(sentences)} sentences from the document") + strings = [s.strip() for s in sentences if len(s.strip()) > MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING] + return strings + +async def _get_document_from_db(file_hash: str): + async with AsyncSessionLocal() as session: + result = await session.execute(select(Document).filter(Document.document_hash == file_hash)) + return result.scalar_one_or_none() + +async def store_document_embeddings_in_db(file: File, file_hash: str, original_file_content: bytes, json_content: bytes, results: List[Tuple[str, np.array]], llm_model_name: str, client_ip: str, request_time: datetime): + document = await _get_document_from_db(file_hash) # First, check if a Document with the same hash already exists + if not document: # If not, create a new Document object + document = Document(document_hash=file_hash, llm_model_name=llm_model_name) + await shared_resources.db_writer.enqueue_write([document]) + document_embedding = DocumentEmbedding( + filename=file.filename, + mimetype=file.content_type, + file_hash=file_hash, + llm_model_name=llm_model_name, + file_data=original_file_content, + document_embedding_results_json=json.loads(json_content.decode()), + ip_address=client_ip, + request_time=request_time, + response_time=datetime.utcnow(), + total_time=(datetime.utcnow() - request_time).total_seconds() + ) + document.document_embeddings.append(document_embedding) # Associate it with the Document + document.update_hash() # This will trigger the SQLAlchemy event to update the document_hash + await shared_resources.db_writer.enqueue_write([document, document_embedding]) # Enqueue the write operation for the document embedding + write_operations = [] # Collect text embeddings to write + logger.info(f"Storing {len(results)} text embeddings in database") + for text, embedding in results: + embedding_entry = await _get_embedding_from_db(text, llm_model_name) + if not embedding_entry: + embedding_entry = TextEmbedding( + text=text, + llm_model_name=llm_model_name, + embedding_json=json.dumps(embedding), + ip_address=client_ip, + request_time=request_time, + response_time=datetime.utcnow(), + total_time=(datetime.utcnow() - 
request_time).total_seconds(), + document_file_hash=file_hash # Link it to the DocumentEmbedding via file_hash + ) + write_operations.append(embedding_entry) # Only newly created embeddings are appended; rows already in the database don't need to be re-written + await shared_resources.db_writer.enqueue_write(write_operations) # Enqueue the write operation for text embeddings + +def load_text_completion_model(llm_model_name: str, raise_http_exception: bool = True): + try: + if llm_model_name in text_completion_model_cache: # Check if the model is already loaded in the cache + return text_completion_model_cache[llm_model_name] + models_dir = os.path.join(RAMDISK_PATH, 'models') if USE_RAMDISK else os.path.join(BASE_DIRECTORY, 'models') # Determine the model directory path + matching_files = glob.glob(os.path.join(models_dir, f"{llm_model_name}*")) # Search for matching model files + if not matching_files: + logger.error(f"No model file found matching: {llm_model_name}") + raise FileNotFoundError + matching_files.sort(key=os.path.getmtime, reverse=True) # Sort the files based on modification time (recently modified files first) + model_file_path = matching_files[0] + model_instance = Llama(model_path=model_file_path, embedding=True, n_ctx=TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS, verbose=False) # Load the model + text_completion_model_cache[llm_model_name] = model_instance # Cache the loaded model + return model_instance + except TypeError as e: + logger.error(f"TypeError occurred while loading the model: {e}") + raise + except Exception as e: + logger.error(f"Exception occurred while loading the model: {e}") + if raise_http_exception: + raise HTTPException(status_code=404, detail="Model file not found") + else: + raise FileNotFoundError(f"No model file found matching: {llm_model_name}") + +async def generate_completion_from_llm(request: TextCompletionRequest, req: Request = None, client_ip: str = None) -> List[TextCompletionResponse]: + request_time = datetime.utcnow() + logger.info(f"Starting text completion calculation using model: '{request.llm_model_name}' for input prompt: '{request.input_prompt}'") + logger.info(f"Loading model: '{request.llm_model_name}'") + llm = load_text_completion_model(request.llm_model_name) + logger.info(f"Done loading model: '{request.llm_model_name}'") + list_of_llm_outputs = [] + grammar_file_string_lower = request.grammar_file_string.lower() if request.grammar_file_string else "" + if grammar_file_string_lower: + list_of_grammar_files = glob.glob("./grammar_files/*.gbnf") + matching_grammar_files = [x for x in list_of_grammar_files if grammar_file_string_lower in os.path.splitext(os.path.basename(x).lower())[0]] + if len(matching_grammar_files) == 0: + logger.error(f"No grammar file found matching: {request.grammar_file_string}") + raise FileNotFoundError + matching_grammar_files.sort(key=os.path.getmtime, reverse=True) + grammar_file_path = matching_grammar_files[0] + logger.info(f"Loading selected grammar file: '{grammar_file_path}'") + llama_grammar = LlamaGrammar.from_file(grammar_file_path) + for ii in range(request.number_of_completions_to_generate): + logger.info(f"Generating completion {ii+1} of {request.number_of_completions_to_generate} with model {request.llm_model_name} for input prompt: '{request.input_prompt}'") + output = llm(prompt=request.input_prompt, grammar=llama_grammar, max_tokens=request.number_of_tokens_to_generate, temperature=request.temperature) + list_of_llm_outputs.append(output) + else: + for ii in range(request.number_of_completions_to_generate): + output = llm(prompt=request.input_prompt, 
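+                         # No grammar file was requested, so the completions are generated unconstrained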
+            output = llm(prompt=request.input_prompt, max_tokens=request.number_of_tokens_to_generate, temperature=request.temperature)
+            list_of_llm_outputs.append(output)
+    response_time = datetime.utcnow()
+    total_time_per_completion = ((response_time - request_time).total_seconds()) / request.number_of_completions_to_generate
+    list_of_responses = []
+    for idx, current_completion_output in enumerate(list_of_llm_outputs):
+        generated_text = current_completion_output['choices'][0]['text']
+        if request.grammar_file_string == 'json':
+            generated_text = generated_text.encode('unicode_escape').decode()
+        llm_model_usage_json = json.dumps(current_completion_output['usage'])
+        logger.info(f"Completed text completion {idx+1} in an average of {total_time_per_completion:.2f} seconds for input prompt: '{request.input_prompt}'; Beginning of generated text: \n'{generated_text[:100]}'")
+        response = TextCompletionResponse(input_prompt = request.input_prompt,
+                                          llm_model_name = request.llm_model_name,
+                                          grammar_file_string = request.grammar_file_string,
+                                          number_of_tokens_to_generate = request.number_of_tokens_to_generate,
+                                          number_of_completions_to_generate = request.number_of_completions_to_generate,
+                                          time_taken_in_seconds = float(total_time_per_completion),
+                                          generated_text = generated_text,
+                                          llm_model_usage_json = llm_model_usage_json)
+        list_of_responses.append(response)
+    return list_of_responses
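The grammar lookup above is what enables structured output. A minimal sketch of how a GBNF-constrained completion behaves; the grammar file name and its contents are hypothetical, not shipped in this diff:

```
# Hypothetical: ./grammar_files/yes_no.gbnf containing the single rule
#   root ::= "yes" | "no"
from llama_cpp import Llama, LlamaGrammar

grammar = LlamaGrammar.from_file("./grammar_files/yes_no.gbnf")
llm = Llama(model_path="models/openchat_v3.2_super.Q4_K_M.gguf", n_ctx=4000, verbose=False)
output = llm(prompt="Is the sky blue? Answer: ", grammar=grammar, max_tokens=4, temperature=0.7)
print(output['choices'][0]['text'])  # constrained to 'yes' or 'no'
```

diff --git a/shared_resources.py b/shared_resources.py
new file mode 100644
index 0000000..eee389b
--- /dev/null
+++ b/shared_resources.py
@@ -0,0 +1,149 @@
+from misc_utility_functions import is_redis_running, build_faiss_indexes
+from database_functions import DatabaseWriter, initialize_db
+from ramdisk_functions import setup_ramdisk, copy_models_to_ramdisk, check_that_user_has_required_permissions_to_manage_ramdisks
+from logger_config import setup_logger
+from aioredlock import Aioredlock
+import aioredis
+import asyncio
+import subprocess
+import urllib.request
+import os
+import glob
+import json
+from typing import List, Tuple, Dict
+from langchain.embeddings import LlamaCppEmbeddings
+from decouple import config
+from fastapi import HTTPException
+
+logger = setup_logger()
+embedding_model_cache = {}  # Model cache to store loaded models
+token_level_embedding_model_cache = {}  # Model cache to store loaded token-level embedding models
+text_completion_model_cache = {}  # Model cache to store loaded text completion models
+
+SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT = config("SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT", default=8089, cast=int)
+DEFAULT_MODEL_NAME = config("DEFAULT_MODEL_NAME", default="openchat_v3.2_super", cast=str)
+LLM_CONTEXT_SIZE_IN_TOKENS = config("LLM_CONTEXT_SIZE_IN_TOKENS", default=512, cast=int)
+TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS = config("TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS", default=4000, cast=int)
+DEFAULT_MAX_COMPLETION_TOKENS = config("DEFAULT_MAX_COMPLETION_TOKENS", default=100, cast=int)
+DEFAULT_NUMBER_OF_COMPLETIONS_TO_GENERATE = config("DEFAULT_NUMBER_OF_COMPLETIONS_TO_GENERATE", default=4, cast=int)
+DEFAULT_COMPLETION_TEMPERATURE = config("DEFAULT_COMPLETION_TEMPERATURE", default=0.7, cast=float)
+MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING = config("MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING", default=15, cast=int)
+USE_PARALLEL_INFERENCE_QUEUE = config("USE_PARALLEL_INFERENCE_QUEUE", default=False, cast=bool)
+MAX_CONCURRENT_PARALLEL_INFERENCE_TASKS = config("MAX_CONCURRENT_PARALLEL_INFERENCE_TASKS", default=10, cast=int)
+USE_RAMDISK = config("USE_RAMDISK", default=False, cast=bool)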
config("RAMDISK_PATH", default="/mnt/ramdisk", cast=str) +BASE_DIRECTORY = os.path.dirname(os.path.abspath(__file__)) + +async def initialize_globals(): + global db_writer, faiss_indexes, token_faiss_indexes, associated_texts_by_model, redis, lock_manager + if not is_redis_running(): + logger.info("Starting Redis server...") + subprocess.Popen(['redis-server'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + await asyncio.sleep(1) # Sleep for 1 second to give Redis time to start + redis = await aioredis.create_redis_pool('redis://localhost') + lock_manager = Aioredlock([redis]) + await initialize_db() + queue = asyncio.Queue() + db_writer = DatabaseWriter(queue) + await db_writer.initialize_processing_hashes() + asyncio.create_task(db_writer.dedicated_db_writer()) + global USE_RAMDISK + if USE_RAMDISK and not check_that_user_has_required_permissions_to_manage_ramdisks(): + USE_RAMDISK = False + elif USE_RAMDISK: + setup_ramdisk() + list_of_downloaded_model_names, download_status = download_models() + for llm_model_name in list_of_downloaded_model_names: + try: + load_model(llm_model_name, raise_http_exception=False) + except FileNotFoundError as e: + logger.error(e) + faiss_indexes, token_faiss_indexes, associated_texts_by_model = await build_faiss_indexes() + +# other shared variables and methods +db_writer = None +faiss_indexes = None +token_faiss_indexes = None +associated_texts_by_model = None +redis = None +lock_manager = None + + +def download_models() -> Tuple[List[str], List[Dict[str, str]]]: + download_status = [] + json_path = os.path.join(BASE_DIRECTORY, "model_urls.json") + if not os.path.exists(json_path): + initial_model_urls = [ + 'https://huggingface.co/TheBloke/Yarn-Llama-2-7B-128K-GGUF/resolve/main/yarn-llama-2-7b-128k.Q4_K_M.gguf', + 'https://huggingface.co/TheBloke/openchat_v3.2_super-GGUF/resolve/main/openchat_v3.2_super.Q4_K_M.gguf', + 'https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/resolve/main/mistral-7b-instruct-v0.1.Q6_K.gguf' + ] + with open(json_path, "w") as f: + json.dump(initial_model_urls, f) + with open(json_path, "r") as f: + list_of_model_download_urls = json.load(f) + model_names = [os.path.basename(url) for url in list_of_model_download_urls] + current_file_path = os.path.abspath(__file__) + base_dir = os.path.dirname(current_file_path) + models_dir = os.path.join(base_dir, 'models') + logger.info("Checking models directory...") + if USE_RAMDISK: + ramdisk_models_dir = os.path.join(RAMDISK_PATH, 'models') + if not os.path.exists(RAMDISK_PATH): + setup_ramdisk() + if all(os.path.exists(os.path.join(ramdisk_models_dir, llm_model_name)) for llm_model_name in model_names): + logger.info("Models found in RAM Disk.") + for url in list_of_model_download_urls: + download_status.append({"url": url, "status": "success", "message": "Model found in RAM Disk."}) + return model_names, download_status + if not os.path.exists(models_dir): + os.makedirs(models_dir) + logger.info(f"Created models directory: {models_dir}") + else: + logger.info(f"Models directory exists: {models_dir}") + for url, model_name_with_extension in zip(list_of_model_download_urls, model_names): + status = {"url": url, "status": "success", "message": "File already exists."} + filename = os.path.join(models_dir, model_name_with_extension) + if not os.path.exists(filename): + logger.info(f"Downloading model {model_name_with_extension} from {url}...") + urllib.request.urlretrieve(url, filename) + file_size = os.path.getsize(filename) / (1024 * 1024) # Convert bytes to MB + if 
+            if file_size < 100:
+                os.remove(filename)
+                status["status"] = "failure"
+                status["message"] = "Downloaded file is too small, probably not a valid model file."
+            else:
+                logger.info(f"Downloaded: {filename}")
+        else:
+            logger.info(f"File already exists: {filename}")
+        download_status.append(status)
+    if USE_RAMDISK:
+        copy_models_to_ramdisk(models_dir, ramdisk_models_dir)
+    logger.info("Model downloads completed.")
+    return model_names, download_status
+
+
+def load_model(llm_model_name: str, raise_http_exception: bool = True):
+    try:
+        models_dir = os.path.join(RAMDISK_PATH, 'models') if USE_RAMDISK else os.path.join(BASE_DIRECTORY, 'models')
+        if llm_model_name in embedding_model_cache:
+            return embedding_model_cache[llm_model_name]
+        matching_files = glob.glob(os.path.join(models_dir, f"{llm_model_name}*"))
+        if not matching_files:
+            logger.error(f"No model file found matching: {llm_model_name}")
+            raise FileNotFoundError
+        matching_files.sort(key=os.path.getmtime, reverse=True)
+        model_file_path = matching_files[0]
+        model_instance = LlamaCppEmbeddings(model_path=model_file_path, use_mlock=True, n_ctx=LLM_CONTEXT_SIZE_IN_TOKENS)
+        model_instance.client.verbose = False
+        embedding_model_cache[llm_model_name] = model_instance
+        return model_instance
+    except TypeError as e:
+        logger.error(f"TypeError occurred while loading the model: {e}")
+        raise
+    except Exception as e:
+        logger.error(f"Exception occurred while loading the model: {e}")
+        if raise_http_exception:
+            raise HTTPException(status_code=404, detail="Model file not found")
+        else:
+            raise FileNotFoundError(f"No model file found matching: {llm_model_name}")
\ No newline at end of file
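For orientation, a minimal sketch of what load_model() returns in practice; the model name assumes the default download set above has already been fetched:

```
# Illustrative only: the cached LlamaCppEmbeddings instance can embed strings directly.
from shared_resources import load_model

model = load_model("openchat_v3.2_super")           # loads (or returns the cached) GGUF model
vector = model.embed_query("A short test string.")  # a list[float] embedding
print(len(vector))                                  # embedding dimensionality
```

diff --git a/swiss_army_llama.py b/swiss_army_llama.py
index 359d3fb..740df5e 100644
--- a/swiss_army_llama.py
+++ b/swiss_army_llama.py
@@ -1,80 +1,53 @@
-from embeddings_data_models import Base, TextEmbedding, DocumentEmbedding, Document, TokenLevelEmbedding, TokenLevelEmbeddingBundle, TokenLevelEmbeddingBundleCombinedFeatureVector, AudioTranscript
+import shared_resources
+from shared_resources import initialize_globals, download_models
+from logger_config import setup_logger
+from database_functions import AsyncSessionLocal, DatabaseWriter, get_db_writer
+from ramdisk_functions import clear_ramdisk
+from misc_utility_functions import build_faiss_indexes
+from embeddings_data_models import DocumentEmbedding, TokenLevelEmbeddingBundle
 from embeddings_data_models import EmbeddingRequest, SemanticSearchRequest, AdvancedSemanticSearchRequest, SimilarityRequest, TextCompletionRequest
-from embeddings_data_models import EmbeddingResponse, SemanticSearchResponse, AdvancedSemanticSearchResponse, SimilarityResponse, AllStringsResponse, AllDocumentsResponse, TextCompletionResponse, AudioTranscriptResponse
+from embeddings_data_models import EmbeddingResponse, SemanticSearchResponse, AdvancedSemanticSearchResponse, SimilarityResponse, AllStringsResponse, AllDocumentsResponse, TextCompletionResponse
 from embeddings_data_models import ShowLogsIncrementalModel
+from service_functions import get_or_compute_embedding, get_or_compute_transcript, add_model_url, get_or_compute_token_level_embedding_bundle_combined_feature_vector, calculate_token_level_embeddings
+from service_functions import parse_submitted_document_file_into_sentence_strings_func, compute_embeddings_for_document, store_document_embeddings_in_db, generate_completion_from_llm
 from log_viewer_functions import show_logs_incremental_func, show_logs_func
+from uvicorn_config import option
 import asyncio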
-import io import glob import json -import logging import os -import random import re -import shutil -import subprocess import tempfile -import time import traceback -import urllib.request import zipfile -from collections import defaultdict from datetime import datetime from hashlib import sha3_256 -from logging.handlers import RotatingFileHandler -from typing import List, Optional, Tuple, Dict, Any -from urllib.parse import quote, unquote +from typing import List, Optional, Dict, Any +from urllib.parse import unquote import numpy as np from decouple import config import uvicorn -import psutil -import textract import fastapi from fastapi import FastAPI, HTTPException, Request, UploadFile, File, Depends from fastapi.responses import JSONResponse, FileResponse, HTMLResponse, Response -from fastapi.concurrency import run_in_threadpool -from langchain.embeddings import LlamaCppEmbeddings from sqlalchemy import select from sqlalchemy import text as sql_text -from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError -from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession -from sqlalchemy.orm import sessionmaker, joinedload +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import joinedload import faiss import pandas as pd from magic import Magic -from llama_cpp import Llama, LlamaGrammar import fast_vector_similarity as fvs -from faster_whisper import WhisperModel +import uvloop + +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +logger = setup_logger() # Note: the Ramdisk setup and teardown requires sudo; to enable password-less sudo, edit your sudoers file with `sudo visudo`. # Add the following lines, replacing username with your actual username # username ALL=(ALL) NOPASSWD: /bin/mount -t tmpfs -o size=*G tmpfs /mnt/ramdisk # username ALL=(ALL) NOPASSWD: /bin/umount /mnt/ramdisk -# Setup logging -old_logs_dir = 'old_logs' # Ensure the old_logs directory exists -if not os.path.exists(old_logs_dir): - os.makedirs(old_logs_dir) -logger = logging.getLogger() -logger.setLevel(logging.INFO) -formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') -log_file_path = 'swiss_army_llama.log' -fh = RotatingFileHandler(log_file_path, maxBytes=10*1024*1024, backupCount=5) -fh.setFormatter(formatter) -logger.addHandler(fh) -def namer(default_log_name): # Move rotated logs to the old_logs directory - return os.path.join(old_logs_dir, os.path.basename(default_log_name)) -def rotator(source, dest): - shutil.move(source, dest) -fh.namer = namer -fh.rotator = rotator -sh = logging.StreamHandler() -sh.setFormatter(formatter) -logger.addHandler(sh) -logger = logging.getLogger(__name__) -logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) -configured_logger = logger - # Global variables use_hardcoded_security_token = 0 if use_hardcoded_security_token: @@ -82,1056 +55,17 @@ def rotator(source, dest): USE_SECURITY_TOKEN = config("USE_SECURITY_TOKEN", default=False, cast=bool) else: USE_SECURITY_TOKEN = False -DATABASE_URL = "sqlite+aiosqlite:///swiss_army_llama.sqlite" -SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT = config("SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT", default=8089, cast=int) DEFAULT_MODEL_NAME = config("DEFAULT_MODEL_NAME", default="openchat_v3.2_super", cast=str) -LLM_CONTEXT_SIZE_IN_TOKENS = config("LLM_CONTEXT_SIZE_IN_TOKENS", default=512, cast=int) -TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS = config("TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS", default=4000, cast=int) -DEFAULT_MAX_COMPLETION_TOKENS = 
config("DEFAULT_MAX_COMPLETION_TOKENS", default=100, cast=int) -DEFAULT_NUMBER_OF_COMPLETIONS_TO_GENERATE = config("DEFAULT_NUMBER_OF_COMPLETIONS_TO_GENERATE", default=4, cast=int) -DEFAULT_COMPLETION_TEMPERATURE = config("DEFAULT_COMPLETION_TEMPERATURE", default=0.7, cast=float) -MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING = config("MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING", default=15, cast=int) -USE_PARALLEL_INFERENCE_QUEUE = config("USE_PARALLEL_INFERENCE_QUEUE", default=False, cast=bool) -MAX_CONCURRENT_PARALLEL_INFERENCE_TASKS = config("MAX_CONCURRENT_PARALLEL_INFERENCE_TASKS", default=10, cast=int) USE_RAMDISK = config("USE_RAMDISK", default=False, cast=bool) RAMDISK_PATH = config("RAMDISK_PATH", default="/mnt/ramdisk", cast=str) -RAMDISK_SIZE_IN_GB = config("RAMDISK_SIZE_IN_GB", default=1, cast=int) -MAX_RETRIES = config("MAX_RETRIES", default=3, cast=int) -DB_WRITE_BATCH_SIZE = config("DB_WRITE_BATCH_SIZE", default=25, cast=int) -RETRY_DELAY_BASE_SECONDS = config("RETRY_DELAY_BASE_SECONDS", default=1, cast=int) -JITTER_FACTOR = config("JITTER_FACTOR", default=0.1, cast=float) BASE_DIRECTORY = os.path.dirname(os.path.abspath(__file__)) -embedding_model_cache = {} # Model cache to store loaded models -token_level_embedding_model_cache = {} # Model cache to store loaded token-level embedding models -text_completion_model_cache = {} # Model cache to store loaded text completion models + logger.info(f"USE_RAMDISK is set to: {USE_RAMDISK}") -db_writer = None description_string = """ 🇨🇭🎖️🦙 Swiss Army Llama is your One-Stop-Shop to Quickly and Conveniently Integrate Powerful Local LLM Functionality into your Project via a REST API. """ app = FastAPI(title="Swiss Army Llama", description=description_string, docs_url="/") # Set the Swagger UI to root -engine = create_async_engine(DATABASE_URL, echo=False, connect_args={"check_same_thread": False}) -AsyncSessionLocal = sessionmaker( - bind=engine, - class_=AsyncSession, - expire_on_commit=False, - autoflush=False -) - -# Misc. 
utility functions and db writer class: -def clean_filename_for_url_func(dirty_filename: str) -> str: - clean_filename = re.sub(r'[^\w\s]', '', dirty_filename) # Remove special characters and replace spaces with underscores - clean_filename = clean_filename.replace(' ', '_') - return clean_filename - -class DatabaseWriter: - def __init__(self, queue): - self.queue = queue - self.processing_hashes = set() # Set to store the hashes if everything that is currently being processed in the queue (to avoid duplicates of the same task being added to the queue) - - def _get_hash_from_operation(self, operation): - attr_name = { - TextEmbedding: 'text_hash', - DocumentEmbedding: 'file_hash', - Document: 'document_hash', - TokenLevelEmbedding: 'token_hash', - TokenLevelEmbeddingBundle: 'input_text_hash', - TokenLevelEmbeddingBundleCombinedFeatureVector: 'combined_feature_vector_hash', - AudioTranscript: 'audio_file_hash' - }.get(type(operation)) - hash_value = getattr(operation, attr_name, None) - llm_model_name = getattr(operation, 'llm_model_name', None) - return f"{hash_value}_{llm_model_name}" if hash_value and llm_model_name else None - - async def initialize_processing_hashes(self, chunk_size=1000): - start_time = datetime.utcnow() - async with AsyncSessionLocal() as session: - queries = [ - (select(TextEmbedding.text_hash, TextEmbedding.llm_model_name), True), - (select(DocumentEmbedding.file_hash, DocumentEmbedding.llm_model_name), True), - (select(Document.document_hash, Document.llm_model_name), True), - (select(TokenLevelEmbedding.token_hash, TokenLevelEmbedding.llm_model_name), True), - (select(TokenLevelEmbeddingBundle.input_text_hash, TokenLevelEmbeddingBundle.llm_model_name), True), - (select(TokenLevelEmbeddingBundleCombinedFeatureVector.combined_feature_vector_hash, TokenLevelEmbeddingBundleCombinedFeatureVector.llm_model_name), True), - (select(AudioTranscript.audio_file_hash), False) - ] - for query, has_llm in queries: - offset = 0 - while True: - result = await session.execute(query.limit(chunk_size).offset(offset)) - rows = result.fetchall() - if not rows: - break - for row in rows: - if has_llm: - hash_with_model = f"{row[0]}_{row[1]}" - else: - hash_with_model = row[0] - self.processing_hashes.add(hash_with_model) - offset += chunk_size - end_time = datetime.utcnow() - total_time = (end_time - start_time).total_seconds() - if len(self.processing_hashes) > 0: - logger.info(f"Finished initializing set of input hash/llm_model_name combinations that are either currently being processed or have already been processed. 
Set size: {len(self.processing_hashes)}; Took {total_time} seconds, for an average of {total_time / len(self.processing_hashes)} seconds per hash.") - - async def _handle_integrity_error(self, e, write_operation, session): - unique_constraint_msg = { - TextEmbedding: "token_embeddings.token_hash, token_embeddings.llm_model_name", - DocumentEmbedding: "document_embeddings.file_hash, document_embeddings.llm_model_name", - Document: "documents.document_hash, documents.llm_model_name", - TokenLevelEmbedding: "token_level_embeddings.token_hash, token_level_embeddings.llm_model_name", - TokenLevelEmbeddingBundle: "token_level_embedding_bundles.input_text_hash, token_level_embedding_bundles.llm_model_name", - AudioTranscript: "audio_transcripts.audio_file_hash" - }.get(type(write_operation)) - if unique_constraint_msg and unique_constraint_msg in str(e): - logger.warning(f"Embedding already exists in the database for given input and llm_model_name: {e}") - await session.rollback() - else: - raise - - async def dedicated_db_writer(self): - while True: - write_operations_batch = await self.queue.get() - async with AsyncSessionLocal() as session: - try: - for write_operation in write_operations_batch: - session.add(write_operation) - await session.flush() # Flush to get the IDs - await session.commit() - for write_operation in write_operations_batch: - hash_to_remove = self._get_hash_from_operation(write_operation) - if hash_to_remove is not None and hash_to_remove in self.processing_hashes: - self.processing_hashes.remove(hash_to_remove) - except IntegrityError as e: - await self._handle_integrity_error(e, write_operation, session) - except SQLAlchemyError as e: - logger.error(f"Database error: {e}") - await session.rollback() - except Exception as e: - tb = traceback.format_exc() - logger.error(f"Unexpected error: {e}\n{tb}") - await session.rollback() - self.queue.task_done() - - async def enqueue_write(self, write_operations): - write_operations = [op for op in write_operations if self._get_hash_from_operation(op) not in self.processing_hashes] # Filter out write operations for hashes that are already being processed - if not write_operations: # If there are no write operations left after filtering, return early - return - for op in write_operations: # Add the hashes of the write operations to the set - hash_value = self._get_hash_from_operation(op) - if hash_value: - self.processing_hashes.add(hash_value) - await self.queue.put(write_operations) - - -async def execute_with_retry(func, *args, **kwargs): - retries = 0 - while retries < MAX_RETRIES: - try: - return await func(*args, **kwargs) - except OperationalError as e: - if 'database is locked' in str(e): - retries += 1 - sleep_time = RETRY_DELAY_BASE_SECONDS * (2 ** retries) + (random.random() * JITTER_FACTOR) # Implementing exponential backoff with jitter - logger.warning(f"Database is locked. Retrying ({retries}/{MAX_RETRIES})... 
Waiting for {sleep_time} seconds") - await asyncio.sleep(sleep_time) - else: - raise - raise OperationalError("Database is locked after multiple retries") - -async def initialize_db(): - logger.info("Initializing database, creating tables, and setting SQLite PRAGMAs...") - list_of_sqlite_pragma_strings = ["PRAGMA journal_mode=WAL;", "PRAGMA synchronous = NORMAL;", "PRAGMA cache_size = -1048576;", "PRAGMA busy_timeout = 2000;", "PRAGMA wal_autocheckpoint = 100;"] - list_of_sqlite_pragma_justification_strings = ["Set SQLite to use Write-Ahead Logging (WAL) mode (from default DELETE mode) so that reads and writes can occur simultaneously", - "Set synchronous mode to NORMAL (from FULL) so that writes are not blocked by reads", - "Set cache size to 1GB (from default 2MB) so that more data can be cached in memory and not read from disk; to make this 256MB, set it to -262144 instead", - "Increase the busy timeout to 2 seconds so that the database waits", - "Set the WAL autocheckpoint to 100 (from default 1000) so that the WAL file is checkpointed more frequently"] - assert(len(list_of_sqlite_pragma_strings) == len(list_of_sqlite_pragma_justification_strings)) - async with engine.begin() as conn: - for pragma_string in list_of_sqlite_pragma_strings: - await conn.execute(sql_text(pragma_string)) - logger.info(f"Executed SQLite PRAGMA: {pragma_string}") - logger.info(f"Justification: {list_of_sqlite_pragma_justification_strings[list_of_sqlite_pragma_strings.index(pragma_string)]}") - await conn.run_sync(Base.metadata.create_all) # Create tables if they don't exist - logger.info("Database initialization completed.") - -def get_db_writer() -> DatabaseWriter: - return db_writer # Return the existing DatabaseWriter instance - -def check_that_user_has_required_permissions_to_manage_ramdisks(): - try: # Try to run a harmless command with sudo to test if the user has password-less sudo permissions - result = subprocess.run(["sudo", "ls"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) - if "password" in result.stderr.lower(): - raise PermissionError("Password required for sudo") - logger.info("User has sufficient permissions to manage RAM Disks.") - return True - except (PermissionError, subprocess.CalledProcessError) as e: - logger.info("Sorry, current user does not have sufficient permissions to manage RAM Disks! Disabling RAM Disks for now...") - logger.debug(f"Permission check error detail: {e}") - return False - -def setup_ramdisk(): - cmd_check = f"sudo mount | grep {RAMDISK_PATH}" # Check if RAM disk already exists at the path - result = subprocess.run(cmd_check, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8') - if RAMDISK_PATH in result: - logger.info(f"RAM Disk already set up at {RAMDISK_PATH}. Skipping setup.") - return - total_ram_gb = psutil.virtual_memory().total / (1024 ** 3) - free_ram_gb = psutil.virtual_memory().free / (1024 ** 3) - buffer_gb = 2 # buffer to ensure we don't use all the free RAM - ramdisk_size_gb = max(min(RAMDISK_SIZE_IN_GB, free_ram_gb - buffer_gb), 0.1) - ramdisk_size_mb = int(ramdisk_size_gb * 1024) - ramdisk_size_str = f"{ramdisk_size_mb}M" - logger.info(f"Total RAM: {total_ram_gb}G") - logger.info(f"Free RAM: {free_ram_gb}G") - logger.info(f"Calculated RAM Disk Size: {ramdisk_size_gb}G") - if RAMDISK_SIZE_IN_GB > total_ram_gb: - raise ValueError(f"Cannot allocate {RAMDISK_SIZE_IN_GB}G for RAM Disk. 
Total system RAM is {total_ram_gb:.2f}G.") - logger.info("Setting up RAM Disk...") - os.makedirs(RAMDISK_PATH, exist_ok=True) - mount_command = ["sudo", "mount", "-t", "tmpfs", "-o", f"size={ramdisk_size_str}", "tmpfs", RAMDISK_PATH] - subprocess.run(mount_command, check=True) - logger.info(f"RAM Disk set up at {RAMDISK_PATH} with size {ramdisk_size_gb}G") - -def copy_models_to_ramdisk(models_directory, ramdisk_directory): - total_size = sum(os.path.getsize(os.path.join(models_directory, model)) for model in os.listdir(models_directory)) - free_ram = psutil.virtual_memory().free - if total_size > free_ram: - logger.warning(f"Not enough space on RAM Disk. Required: {total_size}, Available: {free_ram}. Rebuilding RAM Disk.") - clear_ramdisk() - free_ram = psutil.virtual_memory().free # Recompute the available RAM after clearing the RAM disk - if total_size > free_ram: - logger.error(f"Still not enough space on RAM Disk even after clearing. Required: {total_size}, Available: {free_ram}.") - raise ValueError("Not enough RAM space to copy models.") - setup_ramdisk() - os.makedirs(ramdisk_directory, exist_ok=True) - for model in os.listdir(models_directory): - shutil.copyfile(os.path.join(models_directory, model), os.path.join(ramdisk_directory, model)) - logger.info(f"Copied model {model} to RAM Disk at {os.path.join(ramdisk_directory, model)}") - -def clear_ramdisk(): - while True: - cmd_check = f"sudo mount | grep {RAMDISK_PATH}" - result = subprocess.run(cmd_check, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8') - if RAMDISK_PATH not in result: - break # Exit the loop if the RAMDISK_PATH is not in the mount list - cmd_umount = f"sudo umount -l {RAMDISK_PATH}" - subprocess.run(cmd_umount, shell=True, check=True) - logger.info(f"Cleared RAM Disk at {RAMDISK_PATH}") - -async def build_faiss_indexes(): - global faiss_indexes, token_faiss_indexes, associated_texts_by_model - faiss_indexes = {} - token_faiss_indexes = {} # Separate FAISS indexes for token-level embeddings - associated_texts_by_model = defaultdict(list) # Create a dictionary to store associated texts by model name - async with AsyncSessionLocal() as session: - result = await session.execute(sql_text("SELECT llm_model_name, text, embedding_json FROM embeddings")) # Query regular embeddings - token_result = await session.execute(sql_text("SELECT llm_model_name, token, token_level_embedding_json FROM token_level_embeddings")) # Query token-level embeddings - embeddings_by_model = defaultdict(list) - token_embeddings_by_model = defaultdict(list) - for row in result.fetchall(): # Process regular embeddings - llm_model_name = row[0] - associated_texts_by_model[llm_model_name].append(row[1]) # Store the associated text by model name - embeddings_by_model[llm_model_name].append((row[1], json.loads(row[2]))) - for row in token_result.fetchall(): # Process token-level embeddings - llm_model_name = row[0] - token_embeddings_by_model[llm_model_name].append(json.loads(row[2])) - for llm_model_name, embeddings in embeddings_by_model.items(): - logger.info(f"Building Faiss index over embeddings for model {llm_model_name}...") - embeddings_array = np.array([e[1] for e in embeddings]).astype('float32') - if embeddings_array.size == 0: - logger.error(f"No embeddings were loaded from the database for model {llm_model_name}, so nothing to build the Faiss index with!") - continue - logger.info(f"Loaded {len(embeddings_array)} embeddings for model {llm_model_name}.") - logger.info(f"Embedding dimension for model {llm_model_name}: 
{embeddings_array.shape[1]}") - logger.info(f"Normalizing {len(embeddings_array)} embeddings for model {llm_model_name}...") - faiss.normalize_L2(embeddings_array) # Normalize the vectors for cosine similarity - faiss_index = faiss.IndexFlatIP(embeddings_array.shape[1]) # Use IndexFlatIP for cosine similarity - faiss_index.add(embeddings_array) - logger.info(f"Faiss index built for model {llm_model_name}.") - faiss_indexes[llm_model_name] = faiss_index # Store the index by model name - for llm_model_name, token_embeddings in token_embeddings_by_model.items(): - token_embeddings_array = np.array(token_embeddings).astype('float32') - if token_embeddings_array.size == 0: - logger.error(f"No token-level embeddings were loaded from the database for model {llm_model_name}, so nothing to build the Faiss index with!") - continue - logger.info(f"Normalizing {len(token_embeddings_array)} token-level embeddings for model {llm_model_name}...") - faiss.normalize_L2(token_embeddings_array) # Normalize the vectors for cosine similarity - token_faiss_index = faiss.IndexFlatIP(token_embeddings_array.shape[1]) # Use IndexFlatIP for cosine similarity - token_faiss_index.add(token_embeddings_array) - logger.info(f"Token-level Faiss index built for model {llm_model_name}.") - token_faiss_indexes[llm_model_name] = token_faiss_index # Store the token-level index by model name - return faiss_indexes, token_faiss_indexes, associated_texts_by_model - -class JSONAggregator: - def __init__(self): - self.completions = [] - self.aggregate_result = None - - @staticmethod - def weighted_vote(values, weights): - tally = defaultdict(float) - for v, w in zip(values, weights): - tally[v] += w - return max(tally, key=tally.get) - - @staticmethod - def flatten_json(json_obj, parent_key='', sep='->'): - items = {} - for k, v in json_obj.items(): - new_key = f"{parent_key}{sep}{k}" if parent_key else k - if isinstance(v, dict): - items.update(JSONAggregator.flatten_json(v, new_key, sep=sep)) - else: - items[new_key] = v - return items - - @staticmethod - def get_value_by_path(json_obj, path, sep='->'): - keys = path.split(sep) - item = json_obj - for k in keys: - item = item[k] - return item - - @staticmethod - def set_value_by_path(json_obj, path, value, sep='->'): - keys = path.split(sep) - item = json_obj - for k in keys[:-1]: - item = item.setdefault(k, {}) - item[keys[-1]] = value - - def calculate_path_weights(self): - all_paths = [] - for j in self.completions: - all_paths += list(self.flatten_json(j).keys()) - path_weights = defaultdict(float) - for path in all_paths: - path_weights[path] += 1.0 - return path_weights - - def aggregate(self): - path_weights = self.calculate_path_weights() - aggregate = {} - for path, weight in path_weights.items(): - values = [self.get_value_by_path(j, path) for j in self.completions if path in self.flatten_json(j)] - weights = [weight] * len(values) - aggregate_value = self.weighted_vote(values, weights) - self.set_value_by_path(aggregate, path, aggregate_value) - self.aggregate_result = aggregate - -class FakeUploadFile: - def __init__(self, filename: str, content: Any, content_type: str = 'text/plain'): - self.filename = filename - self.content_type = content_type - self.file = io.BytesIO(content) - def read(self, size: int = -1) -> bytes: - return self.file.read(size) - def seek(self, offset: int, whence: int = 0) -> int: - return self.file.seek(offset, whence) - def tell(self) -> int: - return self.file.tell() - -async def get_transcript_from_db(audio_file_hash: str): - return await 
-
-async def _get_transcript_from_db(audio_file_hash: str) -> Optional[dict]:
-    async with AsyncSessionLocal() as session:
-        result = await session.execute(
-            sql_text("SELECT * FROM audio_transcripts WHERE audio_file_hash=:audio_file_hash"),
-            {"audio_file_hash": audio_file_hash},
-        )
-        row = result.fetchone()
-        if row:
-            try:
-                segments_json = json.loads(row.segments_json)
-                combined_transcript_text_list_of_metadata_dicts = json.loads(row.combined_transcript_text_list_of_metadata_dicts)
-                info_json = json.loads(row.info_json)
-                if hasattr(info_json, '__dict__'):
-                    info_json = vars(info_json)
-            except json.JSONDecodeError as e:
-                raise ValueError(f"JSON Decode Error: {e}")
-            if not isinstance(segments_json, list) or not isinstance(combined_transcript_text_list_of_metadata_dicts, list) or not isinstance(info_json, dict):
-                logger.error(f"Type of segments_json: {type(segments_json)}, Value: {segments_json}")
-                logger.error(f"Type of combined_transcript_text_list_of_metadata_dicts: {type(combined_transcript_text_list_of_metadata_dicts)}, Value: {combined_transcript_text_list_of_metadata_dicts}")
-                logger.error(f"Type of info_json: {type(info_json)}, Value: {info_json}")
-                raise ValueError("Deserialized JSON does not match the expected format.")
-            audio_transcript_response = {
-                "id": row.id,
-                "audio_file_name": row.audio_file_name,
-                "audio_file_size_mb": row.audio_file_size_mb,
-                "segments_json": segments_json,
-                "combined_transcript_text": row.combined_transcript_text,
-                "combined_transcript_text_list_of_metadata_dicts": combined_transcript_text_list_of_metadata_dicts,
-                "info_json": info_json,
-                "ip_address": row.ip_address,
-                "request_time": row.request_time,
-                "response_time": row.response_time,
-                "total_time": row.total_time,
-                "url_to_download_zip_file_of_embeddings": ""
-            }
-            return AudioTranscriptResponse(**audio_transcript_response)
-        return None
-
-async def save_transcript_to_db(audio_file_hash, audio_file_name, audio_file_size_mb, transcript_segments, info, ip_address, request_time, response_time, total_time, combined_transcript_text, combined_transcript_text_list_of_metadata_dicts):
-    existing_transcript = await get_transcript_from_db(audio_file_hash)
-    if existing_transcript:
-        return existing_transcript
-    audio_transcript = AudioTranscript(
-        audio_file_hash=audio_file_hash,
-        audio_file_name=audio_file_name,
-        audio_file_size_mb=audio_file_size_mb,
-        segments_json=json.dumps(transcript_segments),
-        combined_transcript_text=combined_transcript_text,
-        combined_transcript_text_list_of_metadata_dicts=json.dumps(combined_transcript_text_list_of_metadata_dicts),
-        info_json=json.dumps(info),
-        ip_address=ip_address,
-        request_time=request_time,
-        response_time=response_time,
-        total_time=total_time
-    )
-    await db_writer.enqueue_write([audio_transcript])
-
-def normalize_logprobs(avg_logprob, min_logprob, max_logprob):
-    range_logprob = max_logprob - min_logprob
-    return (avg_logprob - min_logprob) / range_logprob if range_logprob != 0 else 0.5
-
-def remove_pagination_breaks(text: str) -> str:
-    text = re.sub(r'-(\n)(?=[a-z])', '', text)  # Remove hyphens at the end of lines when the word continues on the next line
-    text = re.sub(r'(?<=\w)(?<![.?!])\n(?=\w)', ' ', text)  # Join lines that break mid-sentence
-    return text
-
-async def get_or_compute_transcript(file: UploadFile, compute_embeddings_for_resulting_transcript_document: bool, llm_model_name: str, req: Request = None) -> dict:
-    request_time = datetime.utcnow()
-    ip_address = req.client.host if req else "127.0.0.1"
-    file_contents = await file.read()
-    audio_file_hash = sha3_256(file_contents).hexdigest()
-    file.file.seek(0)  # Reset file pointer after read
-    existing_audio_transcript = await get_transcript_from_db(audio_file_hash)
-    if existing_audio_transcript:
-        return existing_audio_transcript
-    current_position = file.file.tell()
-    file.file.seek(0, os.SEEK_END)
-    audio_file_size_mb = file.file.tell() / (1024 * 1024)
-    file.file.seek(current_position)
-    with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
-        shutil.copyfileobj(file.file, tmp_file)
-        audio_file_name = tmp_file.name
-    segment_details, info, combined_transcript_text, combined_transcript_text_list_of_metadata_dicts, request_time, response_time, total_time, download_url = await compute_transcript_with_whisper_from_audio_func(audio_file_hash, audio_file_name, file.filename, audio_file_size_mb, ip_address, req, compute_embeddings_for_resulting_transcript_document, llm_model_name)
-    audio_transcript_response = {
-        "audio_file_hash": audio_file_hash,
-        "audio_file_name": file.filename,
-        "audio_file_size_mb": audio_file_size_mb,
-        "segments_json": segment_details,
-        "combined_transcript_text": combined_transcript_text,
-        "combined_transcript_text_list_of_metadata_dicts": combined_transcript_text_list_of_metadata_dicts,
-        "info_json": info,
-        "ip_address": ip_address,
-        "request_time": request_time,
-        "response_time": response_time,
-        "total_time": total_time,
-        "url_to_download_zip_file_of_embeddings": download_url if compute_embeddings_for_resulting_transcript_document else ""
-    }
-    os.remove(audio_file_name)
-    return AudioTranscriptResponse(**audio_transcript_response)
-
-
-# Core embedding functions start here:
-
-def download_models() -> Tuple[List[str], List[Dict[str, str]]]:
-    download_status = []
-    json_path = os.path.join(BASE_DIRECTORY, "model_urls.json")
-    if not os.path.exists(json_path):
-        initial_model_urls = [
-            'https://huggingface.co/TheBloke/Yarn-Llama-2-13B-128K-GGUF/resolve/main/yarn-llama-2-13b-128k.Q4_K_M.gguf',
-            'https://huggingface.co/TheBloke/Yarn-Llama-2-7B-128K-GGUF/resolve/main/yarn-llama-2-7b-128k.Q4_K_M.gguf',
-            'https://huggingface.co/TheBloke/openchat_v3.2_super-GGUF/resolve/main/openchat_v3.2_super.Q4_K_M.gguf',
-            'https://huggingface.co/TheBloke/Phind-CodeLlama-34B-Python-v1-GGUF/resolve/main/phind-codellama-34b-python-v1.Q4_K_M.gguf',
-            'https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/resolve/main/mistral-7b-instruct-v0.1.Q6_K.gguf'
-        ]
-        with open(json_path, "w") as f:
-            json.dump(initial_model_urls, f)
-    with open(json_path, "r") as f:
-        list_of_model_download_urls = json.load(f)
-    model_names = [os.path.basename(url) for url in list_of_model_download_urls]
-    current_file_path = os.path.abspath(__file__)
-    base_dir = os.path.dirname(current_file_path)
-    models_dir = os.path.join(base_dir, 'models')
-    logger.info("Checking models directory...")
-    if USE_RAMDISK:
-        ramdisk_models_dir = os.path.join(RAMDISK_PATH, 'models')
-        if not os.path.exists(RAMDISK_PATH):
-            setup_ramdisk()
-        if all(os.path.exists(os.path.join(ramdisk_models_dir, llm_model_name)) for llm_model_name in model_names):
-            logger.info("Models found in RAM Disk.")
-            for url in list_of_model_download_urls:
-                download_status.append({"url": url, "status": "success", "message": "Model found in RAM Disk."})
-            return model_names, download_status
-    if not os.path.exists(models_dir):
-        os.makedirs(models_dir)
-        logger.info(f"Created 
models directory: {models_dir}") - else: - logger.info(f"Models directory exists: {models_dir}") - for url, model_name_with_extension in zip(list_of_model_download_urls, model_names): - status = {"url": url, "status": "success", "message": "File already exists."} - filename = os.path.join(models_dir, model_name_with_extension) - if not os.path.exists(filename): - logger.info(f"Downloading model {model_name_with_extension} from {url}...") - urllib.request.urlretrieve(url, filename) - file_size = os.path.getsize(filename) / (1024 * 1024) # Convert bytes to MB - if file_size < 100: - os.remove(filename) - status["status"] = "failure" - status["message"] = "Downloaded file is too small, probably not a valid model file." - else: - logger.info(f"Downloaded: {filename}") - else: - logger.info(f"File already exists: {filename}") - download_status.append(status) - if USE_RAMDISK: - copy_models_to_ramdisk(models_dir, ramdisk_models_dir) - logger.info("Model downloads completed.") - return model_names, download_status - -def add_model_url(new_url: str) -> str: - corrected_url = new_url - if '/blob/main/' in new_url: - corrected_url = new_url.replace('/blob/main/', '/resolve/main/') - json_path = os.path.join(BASE_DIRECTORY, "model_urls.json") - with open(json_path, "r") as f: - existing_urls = json.load(f) - if corrected_url not in existing_urls: - logger.info(f"Model URL not found in database. Adding {new_url} now...") - existing_urls.append(corrected_url) - with open(json_path, "w") as f: - json.dump(existing_urls, f) - logger.info(f"Model URL added: {new_url}") - else: - logger.info("Model URL already exists.") - return corrected_url - -async def get_embedding_from_db(text: str, llm_model_name: str): - text_hash = sha3_256(text.encode('utf-8')).hexdigest() # Compute the hash - return await execute_with_retry(_get_embedding_from_db, text_hash, llm_model_name) - -async def _get_embedding_from_db(text_hash: str, llm_model_name: str) -> Optional[dict]: - async with AsyncSessionLocal() as session: - result = await session.execute( - sql_text("SELECT embedding_json FROM embeddings WHERE text_hash=:text_hash AND llm_model_name=:llm_model_name"), - {"text_hash": text_hash, "llm_model_name": llm_model_name}, - ) - row = result.fetchone() - if row: - embedding_json = row[0] - logger.info(f"Embedding found in database for text hash '{text_hash}' using model '{llm_model_name}'") - return json.loads(embedding_json) - return None - -async def get_or_compute_embedding(request: EmbeddingRequest, req: Request = None, client_ip: str = None, document_file_hash: str = None) -> dict: - request_time = datetime.utcnow() # Capture request time as datetime object - ip_address = client_ip or (req.client.host if req else "localhost") # If client_ip is provided, use it; otherwise, try to get from req; if not available, default to "localhost" - logger.info(f"Received request for embedding for '{request.text}' using model '{request.llm_model_name}' from IP address '{ip_address}'") - embedding_list = await get_embedding_from_db(request.text, request.llm_model_name) # Check if embedding exists in the database - if embedding_list is not None: - response_time = datetime.utcnow() # Capture response time as datetime object - total_time = (response_time - request_time).total_seconds() # Calculate time taken in seconds - logger.info(f"Embedding found in database for '{request.text}' using model '{request.llm_model_name}'; returning in {total_time:.4f} seconds") - return {"embedding": embedding_list} - model = 
load_model(request.llm_model_name) - embedding_list = calculate_sentence_embedding(model, request.text) # Compute the embedding if not in the database - if embedding_list is None: - logger.error(f"Could not calculate the embedding for the given text: '{request.text}' using model '{request.llm_model_name}!'") - raise HTTPException(status_code=400, detail="Could not calculate the embedding for the given text") - embedding_json = json.dumps(embedding_list) # Serialize the numpy array to JSON and save to the database - response_time = datetime.utcnow() # Capture response time as datetime object - total_time = (response_time - request_time).total_seconds() # Calculate total time using datetime objects - word_length_of_input_text = len(request.text.split()) - if word_length_of_input_text > 0: - logger.info(f"Embedding calculated for '{request.text}' using model '{request.llm_model_name}' in {total_time} seconds, or an average of {total_time/word_length_of_input_text :.2f} seconds per word. Now saving to database...") - await save_embedding_to_db(request.text, request.llm_model_name, embedding_json, ip_address, request_time, response_time, total_time, document_file_hash) - return {"embedding": embedding_list} - -async def save_embedding_to_db(text: str, llm_model_name: str, embedding_json: str, ip_address: str, request_time: datetime, response_time: datetime, total_time: float, document_file_hash: str = None): - existing_embedding = await get_embedding_from_db(text, llm_model_name) # Check if the embedding already exists - if existing_embedding is not None: - return existing_embedding - return await execute_with_retry(_save_embedding_to_db, text, llm_model_name, embedding_json, ip_address, request_time, response_time, total_time, document_file_hash) - -async def _save_embedding_to_db(text: str, llm_model_name: str, embedding_json: str, ip_address: str, request_time: datetime, response_time: datetime, total_time: float, document_file_hash: str = None): - existing_embedding = await get_embedding_from_db(text, llm_model_name) - if existing_embedding: - return existing_embedding - embedding = TextEmbedding( - text=text, - llm_model_name=llm_model_name, - embedding_json=embedding_json, - ip_address=ip_address, - request_time=request_time, - response_time=response_time, - total_time=total_time, - document_file_hash=document_file_hash - ) - await db_writer.enqueue_write([embedding]) # Enqueue the write operation using the db_writer instance - -def load_model(llm_model_name: str, raise_http_exception: bool = True): - try: - models_dir = os.path.join(RAMDISK_PATH, 'models') if USE_RAMDISK else os.path.join(BASE_DIRECTORY, 'models') - if llm_model_name in embedding_model_cache: - return embedding_model_cache[llm_model_name] - matching_files = glob.glob(os.path.join(models_dir, f"{llm_model_name}*")) - if not matching_files: - logger.error(f"No model file found matching: {llm_model_name}") - raise FileNotFoundError - matching_files.sort(key=os.path.getmtime, reverse=True) - model_file_path = matching_files[0] - model_instance = LlamaCppEmbeddings(model_path=model_file_path, use_mlock=True, n_ctx=LLM_CONTEXT_SIZE_IN_TOKENS) - model_instance.client.verbose = False - embedding_model_cache[llm_model_name] = model_instance - return model_instance - except TypeError as e: - logger.error(f"TypeError occurred while loading the model: {e}") - raise - except Exception as e: - logger.error(f"Exception occurred while loading the model: {e}") - if raise_http_exception: - raise HTTPException(status_code=404, 
detail="Model file not found") - else: - raise FileNotFoundError(f"No model file found matching: {llm_model_name}") - -def load_token_level_embedding_model(llm_model_name: str, raise_http_exception: bool = True): - try: - if llm_model_name in token_level_embedding_model_cache: # Check if the model is already loaded in the cache - return token_level_embedding_model_cache[llm_model_name] - models_dir = os.path.join(RAMDISK_PATH, 'models') if USE_RAMDISK else os.path.join(BASE_DIRECTORY, 'models') # Determine the model directory path - matching_files = glob.glob(os.path.join(models_dir, f"{llm_model_name}*")) # Search for matching model files - if not matching_files: - logger.error(f"No model file found matching: {llm_model_name}") - raise FileNotFoundError - matching_files.sort(key=os.path.getmtime, reverse=True) # Sort the files based on modification time (recently modified files first) - model_file_path = matching_files[0] - model_instance = Llama(model_path=model_file_path, embedding=True, n_ctx=LLM_CONTEXT_SIZE_IN_TOKENS, verbose=False) # Load the model - token_level_embedding_model_cache[llm_model_name] = model_instance # Cache the loaded model - return model_instance - except TypeError as e: - logger.error(f"TypeError occurred while loading the model: {e}") - raise - except Exception as e: - logger.error(f"Exception occurred while loading the model: {e}") - if raise_http_exception: - raise HTTPException(status_code=404, detail="Model file not found") - else: - raise FileNotFoundError(f"No model file found matching: {llm_model_name}") - -async def compute_token_level_embedding_bundle_combined_feature_vector(token_level_embeddings) -> List[float]: - start_time = datetime.utcnow() - logger.info("Extracting token-level embeddings from the bundle") - parsed_df = pd.read_json(token_level_embeddings) # Parse the json_content back to a DataFrame - token_level_embeddings = list(parsed_df['embedding']) - embeddings = np.array(token_level_embeddings) # Convert the list of embeddings to a NumPy array - logger.info(f"Computing column-wise means/mins/maxes/std_devs of the embeddings... 
(shape: {embeddings.shape})")
-    assert(len(embeddings) > 0)
-    means = np.mean(embeddings, axis=0)
-    mins = np.min(embeddings, axis=0)
-    maxes = np.max(embeddings, axis=0)
-    stds = np.std(embeddings, axis=0)
-    logger.info("Concatenating the computed statistics to form the combined feature vector")
-    combined_feature_vector = np.concatenate([means, mins, maxes, stds])
-    end_time = datetime.utcnow()
-    total_time = (end_time - start_time).total_seconds()
-    logger.info(f"Computed the token-level embedding bundle's combined feature vector in {total_time:.2f} seconds.")
-    return combined_feature_vector.tolist()
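A toy illustration of the statistic-concatenation above: N token embeddings of dimension D collapse into a single fixed-length vector of length 4*D, regardless of N:

```
# Illustrative only: combined feature vector = [means | mins | maxes | stds].
import numpy as np

token_embeddings = np.random.rand(10, 8)  # 10 tokens, 8-dim embeddings (toy sizes)
combined = np.concatenate([
    np.mean(token_embeddings, axis=0),
    np.min(token_embeddings, axis=0),
    np.max(token_embeddings, axis=0),
    np.std(token_embeddings, axis=0),
])
assert combined.shape == (4 * token_embeddings.shape[1],)  # length 32 here
```

-
-async def get_or_compute_token_level_embedding_bundle_combined_feature_vector(token_level_embedding_bundle_id, token_level_embeddings, db_writer: DatabaseWriter) -> List[float]:
-    request_time = datetime.utcnow()
-    logger.info(f"Checking for existing combined feature vector for token-level embedding bundle ID: {token_level_embedding_bundle_id}")
-    async with AsyncSessionLocal() as session:
-        result = await session.execute(
-            select(TokenLevelEmbeddingBundleCombinedFeatureVector)
-            .filter(TokenLevelEmbeddingBundleCombinedFeatureVector.token_level_embedding_bundle_id == token_level_embedding_bundle_id)
-        )
-        existing_combined_feature_vector = result.scalar_one_or_none()
-        if existing_combined_feature_vector:
-            response_time = datetime.utcnow()
-            total_time = (response_time - request_time).total_seconds()
-            logger.info(f"Found existing combined feature vector for token-level embedding bundle ID: {token_level_embedding_bundle_id}. Returning cached result in {total_time:.2f} seconds.")
-            return json.loads(existing_combined_feature_vector.combined_feature_vector_json)  # Parse the JSON string into a list
-    logger.info(f"No cached combined feature_vector found for token-level embedding bundle ID: {token_level_embedding_bundle_id}. 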
Computing now...") - combined_feature_vector = await compute_token_level_embedding_bundle_combined_feature_vector(token_level_embeddings) - combined_feature_vector_db_object = TokenLevelEmbeddingBundleCombinedFeatureVector( - token_level_embedding_bundle_id=token_level_embedding_bundle_id, - combined_feature_vector_json=json.dumps(combined_feature_vector) # Convert the list to a JSON string - ) - logger.info(f"Writing combined feature vector for database write for token-level embedding bundle ID: {token_level_embedding_bundle_id} to the database...") - await db_writer.enqueue_write([combined_feature_vector_db_object]) - return combined_feature_vector - -async def calculate_token_level_embeddings(text: str, llm_model_name: str, client_ip: str, token_level_embedding_bundle_id: int) -> List[np.array]: - request_time = datetime.utcnow() - logger.info(f"Starting token-level embedding calculation for text: '{text}' using model: '{llm_model_name}'") - logger.info(f"Loading model: '{llm_model_name}'") - llm = load_token_level_embedding_model(llm_model_name) # Assuming this method returns an instance of the Llama class - token_embeddings = [] - tokens = text.split() # Simple whitespace tokenizer; can be replaced with a more advanced one if needed - logger.info(f"Tokenized text into {len(tokens)} tokens") - for idx, token in enumerate(tokens, start=1): - try: # Check if the embedding is already available in the database - existing_embedding = await get_token_level_embedding_from_db(token, llm_model_name) - if existing_embedding is not None: - token_embeddings.append(np.array(existing_embedding)) - logger.info(f"Embedding retrieved from database for token '{token}'") - continue - logger.info(f"Processing token {idx} of {len(tokens)}: '{token}'") - token_embedding = llm.embed(token) - token_embedding_array = np.array(token_embedding) - token_embeddings.append(token_embedding_array) - response_time = datetime.utcnow() - token_level_embedding_json = json.dumps(token_embedding_array.tolist()) - await store_token_level_embeddings_in_db(token, llm_model_name, token_level_embedding_json, client_ip, request_time, response_time, token_level_embedding_bundle_id) - except RuntimeError as e: - logger.error(f"Failed to calculate embedding for token '{token}': {e}") - logger.info(f"Completed token embedding calculation for all tokens in text: '{text}'") - return token_embeddings - -async def get_token_level_embedding_from_db(token: str, llm_model_name: str) -> Optional[List[float]]: - token_hash = sha3_256(token.encode('utf-8')).hexdigest() # Compute the hash - async with AsyncSessionLocal() as session: - result = await session.execute( - sql_text("SELECT token_level_embedding_json FROM token_level_embeddings WHERE token_hash=:token_hash AND llm_model_name=:llm_model_name"), - {"token_hash": token_hash, "llm_model_name": llm_model_name}, - ) - row = result.fetchone() - if row: - embedding_json = row[0] - logger.info(f"Embedding found in database for token hash '{token_hash}' using model '{llm_model_name}'") - return json.loads(embedding_json) - return None - -async def store_token_level_embeddings_in_db(token: str, llm_model_name: str, token_level_embedding_json: str, ip_address: str, request_time: datetime, response_time: datetime, token_level_embedding_bundle_id: int): - total_time = (response_time - request_time).total_seconds() - embedding = TokenLevelEmbedding( - token=token, - llm_model_name=llm_model_name, - token_level_embedding_json=token_level_embedding_json, - ip_address=ip_address, - 
request_time=request_time,
-        response_time=response_time,
-        total_time=total_time,
-        token_level_embedding_bundle_id=token_level_embedding_bundle_id
-    )
-    await db_writer.enqueue_write([embedding])  # Enqueue the write operation for the token-level embedding
-
-def calculate_sentence_embedding(llama: Llama, text: str) -> np.array:
-    sentence_embedding = None
-    retry_count = 0
-    while sentence_embedding is None and retry_count < 3:
-        try:
-            if retry_count > 0:
-                logger.info(f"Attempting again calculate sentence embedding. Attempt number {retry_count + 1}")
-            sentence_embedding = llama.embed_query(text)
-        except TypeError as e:
-            logger.error(f"TypeError in calculate_sentence_embedding: {e}")
-            raise
-        except Exception as e:
-            logger.error(f"Exception in calculate_sentence_embedding: {e}")
-            text = text[:-int(len(text) * 0.1)]
-            retry_count += 1
-            logger.info(f"Trimming sentence due to too many tokens. New length: {len(text)}")
-    if sentence_embedding is None:
-        logger.error("Failed to calculate sentence embedding after multiple attempts")
-    return sentence_embedding
-
-async def compute_embeddings_for_document(strings: list, llm_model_name: str, client_ip: str, document_file_hash: str) -> List[Tuple[str, np.array]]:
-    results = []
-    if USE_PARALLEL_INFERENCE_QUEUE:
-        logger.info(f"Using parallel inference queue to compute embeddings for {len(strings)} strings")
-        start_time = time.perf_counter()  # Record the start time
-        semaphore = asyncio.Semaphore(MAX_CONCURRENT_PARALLEL_INFERENCE_TASKS)
-        async def compute_embedding(text):  # Define a function to compute the embedding for a given text
-            try:
-                async with semaphore:  # Acquire a semaphore slot
-                    request = EmbeddingRequest(text=text, llm_model_name=llm_model_name)
-                    embedding = await get_embedding_vector_for_string(request, client_ip=client_ip, document_file_hash=document_file_hash)
-                    return text, embedding["embedding"]
-            except Exception as e:
-                logger.error(f"Error computing embedding for text '{text}': {e}")
-                return text, None
-        results = await asyncio.gather(*[compute_embedding(s) for s in strings])  # Use asyncio.gather to run the tasks concurrently
-        end_time = time.perf_counter()  # Record the end time
-        duration = end_time - start_time
-        if len(strings) > 0:
-            logger.info(f"Parallel inference task for {len(strings)} strings completed in {duration:.2f} seconds; {duration / len(strings):.2f} seconds per string")
-    else:  # Compute embeddings sequentially
-        logger.info(f"Using sequential inference to compute embeddings for {len(strings)} strings")
-        start_time = time.perf_counter()  # Record the start time
-        for s in strings:
-            embedding_request = EmbeddingRequest(text=s, llm_model_name=llm_model_name)
-            embedding = await get_embedding_vector_for_string(embedding_request, client_ip=client_ip, document_file_hash=document_file_hash)
-            results.append((s, embedding["embedding"]))
-        end_time = time.perf_counter()  # Record the end time
-        duration = end_time - start_time
-        if len(strings) > 0:
-            logger.info(f"Sequential inference task for {len(strings)} strings completed in {duration:.2f} seconds; {duration / len(strings):.2f} seconds per string")
-    filtered_results = [(text, embedding) for text, embedding in results if embedding is not None]  # Filter out results with None embeddings (applicable to parallel processing) and return
-    return filtered_results
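The semaphore-plus-gather shape above generalizes well; a minimal generic sketch of the same bounded-concurrency pattern, with names of my own choosing:

```
# Illustrative only: cap concurrent async work while still gathering all results.
import asyncio

async def bounded_gather(worker, items, max_concurrency=10):
    semaphore = asyncio.Semaphore(max_concurrency)
    async def run_one(item):
        async with semaphore:  # at most max_concurrency workers inside this block
            return await worker(item)
    return await asyncio.gather(*[run_one(item) for item in items])
```

-
-async def parse_submitted_document_file_into_sentence_strings_func(temp_file_path: str, mime_type: str):
-    strings = []
-    if mime_type.startswith('text/'):
-        with open(temp_file_path, 'r') as 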
-async def parse_submitted_document_file_into_sentence_strings_func(temp_file_path: str, mime_type: str):
-    strings = []
-    if mime_type.startswith('text/'):
-        with open(temp_file_path, 'r') as buffer:
-            content = buffer.read()
-    else:
-        try:
-            content = textract.process(temp_file_path).decode('utf-8')
-        except UnicodeDecodeError:
-            try:
-                content = textract.process(temp_file_path).decode('unicode_escape')
-            except Exception as e:
-                logger.error(f"Error while processing file: {e}, mime_type: {mime_type}")
-                raise HTTPException(status_code=400, detail=f"Unsupported file type or error: {e}")
-        except Exception as e:
-            logger.error(f"Error while processing file: {e}, mime_type: {mime_type}")
-            raise HTTPException(status_code=400, detail=f"Unsupported file type or error: {e}")
-    sentences = sophisticated_sentence_splitter(content)
-    if len(sentences) == 0 and temp_file_path.lower().endswith('.pdf'):
-        logger.info("No sentences found, attempting OCR using Tesseract.")
-        try:
-            content = textract.process(temp_file_path, method='tesseract').decode('utf-8')
-            sentences = sophisticated_sentence_splitter(content)
-        except Exception as e:
-            logger.error(f"Error while processing file with OCR: {e}")
-            raise HTTPException(status_code=400, detail=f"OCR failed: {e}")
-    if len(sentences) == 0:
-        logger.info("No sentences found in the document")
-        raise HTTPException(status_code=400, detail="No sentences found in the document")
-    logger.info(f"Extracted {len(sentences)} sentences from the document")
-    strings = [s.strip() for s in sentences if len(s.strip()) > MINIMUM_STRING_LENGTH_FOR_DOCUMENT_EMBEDDING]
-    return strings
-
-async def _get_document_from_db(file_hash: str):
-    async with AsyncSessionLocal() as session:
-        result = await session.execute(select(Document).filter(Document.document_hash == file_hash))
-        return result.scalar_one_or_none()
-
-async def store_document_embeddings_in_db(file: File, file_hash: str, original_file_content: bytes, json_content: bytes, results: List[Tuple[str, np.array]], llm_model_name: str, client_ip: str, request_time: datetime):
-    document = await _get_document_from_db(file_hash)  # First, check if a Document with the same hash already exists
-    if not document:  # If not, create a new Document object
-        document = Document(document_hash=file_hash, llm_model_name=llm_model_name)
-        await db_writer.enqueue_write([document])
-    document_embedding = DocumentEmbedding(
-        filename=file.filename,
-        mimetype=file.content_type,
-        file_hash=file_hash,
-        llm_model_name=llm_model_name,
-        file_data=original_file_content,
-        document_embedding_results_json=json.loads(json_content.decode()),
-        ip_address=client_ip,
-        request_time=request_time,
-        response_time=datetime.utcnow(),
-        total_time=(datetime.utcnow() - request_time).total_seconds()
-    )
-    document.document_embeddings.append(document_embedding)  # Associate it with the Document
-    document.update_hash()  # This will trigger the SQLAlchemy event to update the document_hash
-    await db_writer.enqueue_write([document, document_embedding])  # Enqueue the write operation for the document embedding
-    write_operations = []  # Collect text embeddings to write
-    logger.info(f"Storing {len(results)} text embeddings in database")
-    for text, embedding in results:
-        embedding_entry = await _get_embedding_from_db(text, llm_model_name)
-        if not embedding_entry:
-            embedding_entry = TextEmbedding(
-                text=text,
-                llm_model_name=llm_model_name,
-                embedding_json=json.dumps(embedding),
-                ip_address=client_ip,
-                request_time=request_time,
-                response_time=datetime.utcnow(),
-                total_time=(datetime.utcnow() - request_time).total_seconds(),
-                document_file_hash=file_hash  # Link it to the DocumentEmbedding via file_hash
-            )
-        else:
-            write_operations.append(embedding_entry)
-    await db_writer.enqueue_write(write_operations)  # Enqueue the write operation for text embeddings
-
-def load_text_completion_model(llm_model_name: str, raise_http_exception: bool = True):
-    try:
-        if llm_model_name in text_completion_model_cache:  # Check if the model is already loaded in the cache
-            return text_completion_model_cache[llm_model_name]
-        models_dir = os.path.join(RAMDISK_PATH, 'models') if USE_RAMDISK else os.path.join(BASE_DIRECTORY, 'models')  # Determine the model directory path
-        matching_files = glob.glob(os.path.join(models_dir, f"{llm_model_name}*"))  # Search for matching model files
-        if not matching_files:
-            logger.error(f"No model file found matching: {llm_model_name}")
-            raise FileNotFoundError
-        matching_files.sort(key=os.path.getmtime, reverse=True)  # Sort the files based on modification time (recently modified files first)
-        model_file_path = matching_files[0]
-        model_instance = Llama(model_path=model_file_path, embedding=True, n_ctx=TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS, verbose=False)  # Load the model
-        text_completion_model_cache[llm_model_name] = model_instance  # Cache the loaded model
-        return model_instance
-    except TypeError as e:
-        logger.error(f"TypeError occurred while loading the model: {e}")
-        raise
-    except Exception as e:
-        logger.error(f"Exception occurred while loading the model: {e}")
-        if raise_http_exception:
-            raise HTTPException(status_code=404, detail="Model file not found")
-        else:
-            raise FileNotFoundError(f"No model file found matching: {llm_model_name}")
-
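Because `load_text_completion_model` caches the `Llama` instance by name, repeated completion requests against the same model skip the expensive load. A usage sketch (the model name is hypothetical; the call and response shape match the llama-cpp-python API used in the surrounding code):

```python
llm = load_text_completion_model("yarn-llama-2-7b-128k")        # loads and caches on first call
llm_again = load_text_completion_model("yarn-llama-2-7b-128k")  # returns the cached instance
assert llm is llm_again

output = llm(prompt="Q: Name the capital of France. A:", max_tokens=16, temperature=0.7)
print(output["choices"][0]["text"])  # completion text, read the same way in generate_completion_from_llm
```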
-async def generate_completion_from_llm(request: TextCompletionRequest, req: Request = None, client_ip: str = None) -> List[TextCompletionResponse]:
-    request_time = datetime.utcnow()
-    logger.info(f"Starting text completion calculation using model: '{request.llm_model_name}' for input prompt: '{request.input_prompt}'")
-    logger.info(f"Loading model: '{request.llm_model_name}'")
-    llm = load_text_completion_model(request.llm_model_name)
-    logger.info(f"Done loading model: '{request.llm_model_name}'")
-    list_of_llm_outputs = []
-    if request.grammar_file_string != "":
-        list_of_grammar_files = glob.glob("./grammar_files/*.gbnf")
-        matching_grammar_files = [x for x in list_of_grammar_files if request.grammar_file_string in x]
-        if len(matching_grammar_files) == 0:
-            logger.error(f"No grammar file found matching: {request.grammar_file_string}")
-            raise FileNotFoundError
-        matching_grammar_files.sort(key=os.path.getmtime, reverse=True)  # Sort the files based on modification time (recently modified files first)
-        grammar_file_path = matching_grammar_files[0]
-        logger.info(f"Loading selected grammar file: '{grammar_file_path}'")
-        llama_grammar = LlamaGrammar.from_file(grammar_file_path)
-        for ii in range(request.number_of_completions_to_generate):
-            logger.info(f"Generating completion {ii+1} of {request.number_of_completions_to_generate} with model {request.llm_model_name} for input prompt: '{request.input_prompt}'")
-            output = llm(prompt=request.input_prompt, grammar=llama_grammar, max_tokens=request.number_of_tokens_to_generate, temperature=request.temperature)
-            list_of_llm_outputs.append(output)
-    else:
-        for ii in range(request.number_of_completions_to_generate):
-            output = llm(prompt=request.input_prompt, max_tokens=request.number_of_tokens_to_generate, temperature=request.temperature)
-            list_of_llm_outputs.append(output)
-    response_time = datetime.utcnow()
-    total_time_per_completion = ((response_time - request_time).total_seconds()) / request.number_of_completions_to_generate
-    list_of_responses = []
-    for idx, current_completion_output in enumerate(list_of_llm_outputs):
-        generated_text = current_completion_output['choices'][0]['text']
-        if request.grammar_file_string == 'json':
-            generated_text = generated_text.encode('unicode_escape').decode()
-        llm_model_usage_json = json.dumps(current_completion_output['usage'])
-        logger.info(f"Completed text completion {idx} in an average of {total_time_per_completion:.2f} seconds for input prompt: '{request.input_prompt}'; Beginning of generated text: \n'{generated_text[:100]}'")
-        response = TextCompletionResponse(input_prompt=request.input_prompt,
-                                          llm_model_name=request.llm_model_name,
-                                          grammar_file_string=request.grammar_file_string,
-                                          number_of_tokens_to_generate=request.number_of_tokens_to_generate,
-                                          number_of_completions_to_generate=request.number_of_completions_to_generate,
-                                          time_taken_in_seconds=float(total_time_per_completion),
-                                          generated_text=generated_text,
-                                          llm_model_usage_json=llm_model_usage_json)
-        list_of_responses.append(response)
-    return list_of_responses
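The `grammar_file_string` branch above constrains generation with a GBNF grammar via `LlamaGrammar.from_file`. For readers unfamiliar with the mechanism, here is a minimal sketch using the sibling `from_string` constructor; the tiny grammar and model path are illustrative, not files from this repo:

```python
from llama_cpp import Llama, LlamaGrammar

yes_no_grammar = LlamaGrammar.from_string('root ::= "yes" | "no"')  # GBNF: output must be exactly "yes" or "no"
llm = Llama(model_path="models/yarn-llama-2-7b-128k.gguf", verbose=False)  # hypothetical local model file
output = llm(prompt="Is Paris in France? Answer:", grammar=yes_no_grammar, max_tokens=4)
print(output["choices"][0]["text"])  # token sampling was restricted to strings the grammar accepts
```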
@@ -1144,7 +78,6 @@ async def general_exception_handler(request: Request, exc: Exception) -> JSONRes
     logger.exception(exc)
     return JSONResponse(status_code=500, content={"message": "An unexpected error occurred"})
 
-#FastAPI Endpoints start here:
 
 @app.get("/", include_in_schema=False)
 async def custom_swagger_ui_html():
@@ -1164,7 +97,7 @@ async def custom_swagger_ui_html():
     ### Example Response:
     ```json
     {
-        "model_names": ["yarn-llama-2-7b-128k", "yarn-llama-2-13b-128k", "openchat_v3.2_super", "phind-codellama-34b-python-v1", "my_super_custom_model"]
+        "model_names": ["yarn-llama-2-7b-128k", "openchat_v3.2_super", "mistral-7b-instruct-v0.1", "my_super_custom_model"]
     }
     ```""",
     response_description="A JSON object containing the list of available model names.")
@@ -1283,15 +216,23 @@ async def get_all_stored_documents(req: Request, token: str = None) -> AllDocume
 async def add_new_model(model_url: str, token: str = None) -> Dict[str, Any]:
     if USE_SECURITY_TOKEN and (token is None or token != SECURITY_TOKEN):
         raise HTTPException(status_code=403, detail="Unauthorized")
-    decoded_model_url = unquote(model_url)
-    if not decoded_model_url.endswith('.gguf'):
-        return {"status": "error", "message": "Model URL must point to a .gguf file."}
-    corrected_model_url = add_model_url(decoded_model_url)
-    _, download_status = download_models()
-    status_dict = {status["url"]: status for status in download_status}
-    if corrected_model_url in status_dict:
-        return {"status": status_dict[corrected_model_url]["status"], "message": status_dict[corrected_model_url]["message"]}
-    return {"status": "unknown", "message": "Unexpected error."}
+    unique_id = f"add_model_{hash(model_url)}"  # Generate a unique lock ID based on the model_url
+    lock = await shared_resources.lock_manager.lock(unique_id)
+    if lock.valid:
+        try:
+            decoded_model_url = unquote(model_url)
+            if not decoded_model_url.endswith('.gguf'):
+                return {"status": "error", "message": "Model URL must point to a .gguf file."}
+            corrected_model_url = add_model_url(decoded_model_url)
+            _, download_status = download_models()
+            status_dict = {status["url"]: status for status in download_status}
+            if corrected_model_url in status_dict:
+                return {"status": status_dict[corrected_model_url]["status"], "message": status_dict[corrected_model_url]["message"]}
+            return {"status": "unknown", "message": "Unexpected error."}
+        finally:
+            await shared_resources.lock_manager.unlock(lock)
+    else:
+        return {"status": "already processing", "message": "Another worker is already processing this model URL."}
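This `lock → if lock.valid → try/finally → unlock` shape now repeats in every endpoint below. If the lock manager follows the aioredlock-style API the calls suggest (an assumption; the manager itself lives in `shared_resources`, which is not shown in this diff), the boilerplate could be factored into a small async context manager. The `guarded_lock` helper below is hypothetical, offered only as a sketch:

```python
from contextlib import asynccontextmanager

@asynccontextmanager
async def guarded_lock(lock_manager, unique_id: str):
    # Assumes lock_manager.lock() returns a lock object with a .valid flag,
    # as the endpoints above do; lock_manager itself is not defined here.
    lock = await lock_manager.lock(unique_id)
    try:
        yield lock  # callers check lock.valid to see whether they won the lock
    finally:
        if lock.valid:
            await lock_manager.unlock(lock)

# Hypothetical usage inside an endpoint:
# async with guarded_lock(shared_resources.lock_manager, f"add_model_{hash(model_url)}") as lock:
#     if not lock.valid:
#         return {"status": "already processing"}
#     ...  # do the work
```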
@@ -1594,44 +535,53 @@ async def compute_similarity_between_strings(request: SimilarityRequest, req: Re
     response_description="A JSON object containing the query text along with the most similar strings and similarity scores.")
 async def search_stored_embeddings_with_query_string_for_semantic_similarity(request: SemanticSearchRequest, req: Request, token: str = None) -> SemanticSearchResponse:
     global faiss_indexes, token_faiss_indexes, associated_texts_by_model
-    faiss_indexes, token_faiss_indexes, associated_texts_by_model = await build_faiss_indexes()
-    request_time = datetime.utcnow()
-    llm_model_name = request.llm_model_name
-    num_results = request.number_of_most_similar_strings_to_return
-    total_entries = len(associated_texts_by_model[llm_model_name])  # Get the total number of entries for the model
-    num_results = min(num_results, total_entries)  # Ensure num_results doesn't exceed the total number of entries
-    logger.info(f"Received request to find {num_results} most similar strings for query text: `{request.query_text}` using model: {llm_model_name}")
-    if USE_SECURITY_TOKEN and use_hardcoded_security_token and (token is None or token != SECURITY_TOKEN):
-        raise HTTPException(status_code=403, detail="Unauthorized")
-    try:
-        logger.info(f"Computing embedding for input text: {request.query_text}")
-        embedding_request = EmbeddingRequest(text=request.query_text, llm_model_name=request.llm_model_name)
-        embedding_response = await get_embedding_vector_for_string(embedding_request, req)
-        input_embedding = np.array(embedding_response["embedding"]).astype('float32').reshape(1, -1)
-        faiss.normalize_L2(input_embedding)  # Normalize the input vector for cosine similarity
-        logger.info(f"Computed embedding for input text: {request.query_text}")
-        faiss_index = faiss_indexes.get(llm_model_name)  # Retrieve the correct FAISS index for the llm_model_name
-        if faiss_index is None:
-            raise HTTPException(status_code=400, detail=f"No FAISS index found for model: {llm_model_name}")
-        logger.info("Searching for the most similar string in the FAISS index")
-        similarities, indices = faiss_index.search(input_embedding.reshape(1, -1), num_results)  # Search for num_results similar strings
-        results = []  # Create an empty list to store the results
-        for ii in range(num_results):
-            similarity = float(similarities[0][ii])  # Convert numpy.float32 to native float
-            most_similar_text = associated_texts_by_model[llm_model_name][indices[0][ii]]
-            if most_similar_text != request.query_text:  # Don't return the query text as a result
-                results.append({"search_result_text": most_similar_text, "similarity_to_query_text": similarity})
-        response_time = datetime.utcnow()
-        total_time = (response_time - request_time).total_seconds()
-        logger.info(f"Finished searching for the most similar string in the FAISS index in {total_time} seconds. Found {len(results)} results, returning the top {num_results}.")
-        logger.info(f"Found most similar strings for query string {request.query_text}: {results}")
-        return {"query_text": request.query_text, "results": results}  # Return the response matching the SemanticSearchResponse model
-    except Exception as e:
-        logger.error(f"An error occurred while processing the request: {e}")
-        logger.error(traceback.format_exc())  # Print the traceback
-        raise HTTPException(status_code=500, detail="Internal Server Error")
-
-
+    unique_id = f"semantic_search_{request.query_text}_{request.llm_model_name}_{request.number_of_most_similar_strings_to_return}"  # Unique ID for this operation
+    lock = await shared_resources.lock_manager.lock(unique_id)
+    if lock.valid:
+        try:
+            faiss_indexes, token_faiss_indexes, associated_texts_by_model = await build_faiss_indexes()
+            request_time = datetime.utcnow()
+            llm_model_name = request.llm_model_name
+            num_results = request.number_of_most_similar_strings_to_return
+            total_entries = len(associated_texts_by_model[llm_model_name])  # Get the total number of entries for the model
+            num_results = min(num_results, total_entries)  # Ensure num_results doesn't exceed the total number of entries
+            logger.info(f"Received request to find {num_results} most similar strings for query text: `{request.query_text}` using model: {llm_model_name}")
+            if USE_SECURITY_TOKEN and use_hardcoded_security_token and (token is None or token != SECURITY_TOKEN):
+                raise HTTPException(status_code=403, detail="Unauthorized")
+            try:
+                logger.info(f"Computing embedding for input text: {request.query_text}")
+                embedding_request = EmbeddingRequest(text=request.query_text, llm_model_name=request.llm_model_name)
+                embedding_response = await get_embedding_vector_for_string(embedding_request, req)
+                input_embedding = np.array(embedding_response["embedding"]).astype('float32').reshape(1, -1)
+                faiss.normalize_L2(input_embedding)  # Normalize the input vector for cosine similarity
+                logger.info(f"Computed embedding for input text: {request.query_text}")
+                faiss_index = faiss_indexes.get(llm_model_name)  # Retrieve the correct FAISS index for the llm_model_name
+                if faiss_index is None:
+                    raise HTTPException(status_code=400, detail=f"No FAISS index found for model: {llm_model_name}")
+                logger.info("Searching for the most similar string in the FAISS index")
+                similarities, indices = faiss_index.search(input_embedding.reshape(1, -1), num_results)  # Search for num_results similar strings
+                results = []  # Create an empty list to store the results
+                for ii in range(num_results):
+                    similarity = float(similarities[0][ii])  # Convert numpy.float32 to native float
+                    most_similar_text = associated_texts_by_model[llm_model_name][indices[0][ii]]
+                    if most_similar_text != request.query_text:  # Don't return the query text as a result
+                        results.append({"search_result_text": most_similar_text, "similarity_to_query_text": similarity})
+                response_time = datetime.utcnow()
+                total_time = (response_time - request_time).total_seconds()
+                logger.info(f"Finished searching for the most similar string in the FAISS index in {total_time} seconds. Found {len(results)} results, returning the top {num_results}.")
+                logger.info(f"Found most similar strings for query string {request.query_text}: {results}")
+                return {"query_text": request.query_text, "results": results}  # Return the response matching the SemanticSearchResponse model
+            except Exception as e:
+                logger.error(f"An error occurred while processing the request: {e}")
+                logger.error(traceback.format_exc())  # Print the traceback
+                raise HTTPException(status_code=500, detail="Internal Server Error")
+        finally:
+            await shared_resources.lock_manager.unlock(lock)
+    else:
+        return {"status": "already processing"}
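Why `faiss.normalize_L2` before searching: once both the stored vectors and the query are unit-length, an inner-product search returns cosine similarities directly. A self-contained illustration (this sketch assumes an `IndexFlatIP` index, which cosine search implies; the actual index construction happens in `build_faiss_indexes`, which is not shown here):

```python
import faiss
import numpy as np

dimension = 8
stored = np.random.rand(100, dimension).astype('float32')
faiss.normalize_L2(stored)            # make every row unit-length, in place
index = faiss.IndexFlatIP(dimension)  # inner product == cosine on normalized vectors
index.add(stored)

query = np.random.rand(1, dimension).astype('float32')
faiss.normalize_L2(query)
similarities, indices = index.search(query, 5)  # top-5 cosine scores and their row indices
```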
+
+
 
 @app.post("/advanced_search_stored_embeddings_with_query_string_for_semantic_similarity/",
           response_model=AdvancedSemanticSearchResponse,
           summary="Advanced Semantic Search with Two-Step Similarity Measures",
@@ -1676,52 +626,60 @@ async def search_stored_embeddings_with_query_string_for_semantic_similarity(req
     response_description="A JSON object containing the query text and the most similar strings, along with their similarity scores for multiple measures.")
 async def advanced_search_stored_embeddings_with_query_string_for_semantic_similarity(request: AdvancedSemanticSearchRequest, req: Request, token: str = None) -> AdvancedSemanticSearchResponse:
     global faiss_indexes, token_faiss_indexes, associated_texts_by_model
-    faiss_indexes, token_faiss_indexes, associated_texts_by_model = await build_faiss_indexes()
-    request_time = datetime.utcnow()
-    llm_model_name = request.llm_model_name
-    total_entries = len(associated_texts_by_model[llm_model_name])
-    num_results = max([1, int((1 - request.similarity_filter_percentage) * total_entries)])
-    logger.info(f"Received request to find {num_results} most similar strings for query text: `{request.query_text}` using model: {llm_model_name}")
-    if USE_SECURITY_TOKEN and use_hardcoded_security_token and (token is None or token != SECURITY_TOKEN):
-        raise HTTPException(status_code=403, detail="Unauthorized")
-    try:
-        logger.info(f"Computing embedding for input text: {request.query_text}")
-        embedding_request = EmbeddingRequest(text=request.query_text, llm_model_name=llm_model_name)
-        embedding_response = await get_embedding_vector_for_string(embedding_request, req)
-        input_embedding = np.array(embedding_response["embedding"]).astype('float32').reshape(1, -1)
-        faiss.normalize_L2(input_embedding)
-        logger.info(f"Computed embedding for input text: {request.query_text}")
-        faiss_index = faiss_indexes.get(llm_model_name)
-        if faiss_index is None:
-            raise HTTPException(status_code=400, detail=f"No FAISS index found for model: {llm_model_name}")
-        _, indices = faiss_index.search(input_embedding, num_results)
-        filtered_indices = indices[0]
-        similarity_results = []
-        for idx in filtered_indices:
-            associated_text = associated_texts_by_model[llm_model_name][idx]
-            embedding_request = EmbeddingRequest(text=associated_text, llm_model_name=llm_model_name)
-            embedding_response = await get_embedding_vector_for_string(embedding_request, req)
-            filtered_embedding = np.array(embedding_response["embedding"])
-            params = {
-                "vector_1": input_embedding.tolist()[0],
-                "vector_2": filtered_embedding.tolist(),
-                "similarity_measure": "all"
-            }
-            similarity_stats_str = fvs.py_compute_vector_similarity_stats(json.dumps(params))
-            similarity_stats_json = json.loads(similarity_stats_str)
-            similarity_results.append({
-                "search_result_text": associated_text,
-                "similarity_to_query_text": similarity_stats_json
-            })
-        num_to_return = request.number_of_most_similar_strings_to_return if request.number_of_most_similar_strings_to_return is not None else len(similarity_results)
-        results = sorted(similarity_results, key=lambda x: x["similarity_to_query_text"]["hoeffding_d"], reverse=True)[:num_to_return]
-        response_time = datetime.utcnow()
-        total_time = (response_time - request_time).total_seconds()
-        logger.info(f"Finished advanced search in {total_time} seconds. Found {len(results)} results.")
-        return {"query_text": request.query_text, "results": results}
-    except Exception as e:
-        logger.error(f"An error occurred while processing the request: {e}")
-        raise HTTPException(status_code=500, detail="Internal Server Error")
+    unique_id = f"advanced_semantic_search_{request.query_text}_{request.llm_model_name}_{request.similarity_filter_percentage}_{request.number_of_most_similar_strings_to_return}"
+    lock = await shared_resources.lock_manager.lock(unique_id)
+    if lock.valid:
+        try:
+            faiss_indexes, token_faiss_indexes, associated_texts_by_model = await build_faiss_indexes()
+            request_time = datetime.utcnow()
+            llm_model_name = request.llm_model_name
+            total_entries = len(associated_texts_by_model[llm_model_name])
+            num_results = max([1, int((1 - request.similarity_filter_percentage) * total_entries)])
+            logger.info(f"Received request to find {num_results} most similar strings for query text: `{request.query_text}` using model: {llm_model_name}")
+            if USE_SECURITY_TOKEN and use_hardcoded_security_token and (token is None or token != SECURITY_TOKEN):
+                raise HTTPException(status_code=403, detail="Unauthorized")
+            try:
+                logger.info(f"Computing embedding for input text: {request.query_text}")
+                embedding_request = EmbeddingRequest(text=request.query_text, llm_model_name=llm_model_name)
+                embedding_response = await get_embedding_vector_for_string(embedding_request, req)
+                input_embedding = np.array(embedding_response["embedding"]).astype('float32').reshape(1, -1)
+                faiss.normalize_L2(input_embedding)
+                logger.info(f"Computed embedding for input text: {request.query_text}")
+                faiss_index = faiss_indexes.get(llm_model_name)
+                if faiss_index is None:
+                    raise HTTPException(status_code=400, detail=f"No FAISS index found for model: {llm_model_name}")
+                _, indices = faiss_index.search(input_embedding, num_results)
+                filtered_indices = indices[0]
+                similarity_results = []
+                for idx in filtered_indices:
+                    associated_text = associated_texts_by_model[llm_model_name][idx]
+                    embedding_request = EmbeddingRequest(text=associated_text, llm_model_name=llm_model_name)
+                    embedding_response = await get_embedding_vector_for_string(embedding_request, req)
+                    filtered_embedding = np.array(embedding_response["embedding"])
+                    params = {
+                        "vector_1": input_embedding.tolist()[0],
+                        "vector_2": filtered_embedding.tolist(),
+                        "similarity_measure": "all"
+                    }
+                    similarity_stats_str = fvs.py_compute_vector_similarity_stats(json.dumps(params))
+                    similarity_stats_json = json.loads(similarity_stats_str)
+                    similarity_results.append({
+                        "search_result_text": associated_text,
+                        "similarity_to_query_text": similarity_stats_json
+                    })
+                num_to_return = request.number_of_most_similar_strings_to_return if request.number_of_most_similar_strings_to_return is not None else len(similarity_results)
+                results = sorted(similarity_results, key=lambda x: x["similarity_to_query_text"]["hoeffding_d"], reverse=True)[:num_to_return]
+                response_time = datetime.utcnow()
+                total_time = (response_time - request_time).total_seconds()
+                logger.info(f"Finished advanced search in {total_time} seconds. Found {len(results)} results.")
+                return {"query_text": request.query_text, "results": results}
+            except Exception as e:
+                logger.error(f"An error occurred while processing the request: {e}")
+                raise HTTPException(status_code=500, detail="Internal Server Error")
+        finally:
+            await shared_resources.lock_manager.unlock(lock)
+    else:
+        return {"status": "already processing"}
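The second stage above re-scores each FAISS candidate exactly: `fast_vector_similarity` takes a JSON payload and returns a JSON string of similarity measures, and the final sort key `hoeffding_d` is one of the keys in that result, as the sorting lambda shows. A minimal standalone call (the vectors are illustrative):

```python
import json
import fast_vector_similarity as fvs

params = {
    "vector_1": [0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 0.80],
    "vector_2": [0.11, 0.25, 0.28, 0.41, 0.52, 0.58, 0.71, 0.79],
    "similarity_measure": "all",
}
stats = json.loads(fvs.py_compute_vector_similarity_stats(json.dumps(params)))
print(stats["hoeffding_d"])  # the measure the endpoint sorts on
```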
@@ -1775,41 +733,50 @@ async def get_all_embedding_vectors_for_document(file: UploadFile = File(...),
             hash_obj.update(chunk)
     file_hash = hash_obj.hexdigest()
     logger.info(f"SHA3-256 hash of submitted file: {file_hash}")
-    async with AsyncSessionLocal() as session:  # Check if the document has been processed before
-        result = await session.execute(select(DocumentEmbedding).filter(DocumentEmbedding.file_hash == file_hash, DocumentEmbedding.llm_model_name == llm_model_name))
-        existing_document_embedding = result.scalar_one_or_none()
-        if existing_document_embedding:  # If the document has been processed before, return the existing result
-            logger.info(f"Document {file.filename} has been processed before, returning existing result")
-            json_content = json.dumps(existing_document_embedding.document_embedding_results_json).encode()
-        else:  # If the document has not been processed, continue processing
-            mime = Magic(mime=True)
-            mime_type = mime.from_file(temp_file_path)
-            logger.info(f"Received request to extract embeddings for document {file.filename} with MIME type: {mime_type} and size: {os.path.getsize(temp_file_path)} bytes from IP address: {client_ip}")
-            strings = await parse_submitted_document_file_into_sentence_strings_func(temp_file_path, mime_type)
-            results = await compute_embeddings_for_document(strings, llm_model_name, client_ip, file_hash)  # Compute the embeddings and json_content for new documents
-            df = pd.DataFrame(results, columns=['text', 'embedding'])
-            json_content = df.to_json(orient=json_format or 'records').encode()
-            with open(temp_file_path, 'rb') as file_buffer:  # Store the results in the database
-                original_file_content = file_buffer.read()
-            await store_document_embeddings_in_db(file, file_hash, original_file_content, json_content, results, llm_model_name, client_ip, request_time)
-    overall_total_time = (datetime.utcnow() - request_time).total_seconds()
-    logger.info(f"Done getting all embeddings for document {file.filename} containing {len(strings)} with model {llm_model_name}")
-    json_content_length = len(json_content)
-    if len(json_content) > 0:
-        logger.info(f"The response took {overall_total_time} seconds to generate, or {overall_total_time / (len(strings)/1000.0)} seconds per thousand input tokens and {overall_total_time / (float(json_content_length)/1000000.0)} seconds per million output characters.")
-    if send_back_json_or_zip_file == 'json':  # Assume 'json' response should be sent back
-        logger.info(f"Returning JSON response for document {file.filename} containing {len(strings)} with model {llm_model_name}; first 100 characters out of {json_content_length} total of JSON response: {json_content[:100]}")
-        return JSONResponse(content=json.loads(json_content.decode()))  # Decode the content and parse it as JSON
-    else:  # Assume 'zip' file should be sent back
-        original_filename_without_extension, _ = os.path.splitext(file.filename)
-        json_file_path = f"/tmp/{original_filename_without_extension}.json"
-        with open(json_file_path, 'wb') as json_file:  # Write the JSON content as bytes
-            json_file.write(json_content)
-        zip_file_path = f"/tmp/{original_filename_without_extension}.zip"
-        with zipfile.ZipFile(zip_file_path, 'w') as zipf:
-            zipf.write(json_file_path, os.path.basename(json_file_path))
-        logger.info(f"Returning ZIP response for document {file.filename} containing {len(strings)} with model {llm_model_name}; first 100 characters out of {json_content_length} total of JSON response: {json_content[:100]}")
-        return FileResponse(zip_file_path, headers={"Content-Disposition": f"attachment; filename={original_filename_without_extension}.zip"})
+    unique_id = f"document_embedding_{file_hash}_{llm_model_name}"
+    lock = await shared_resources.lock_manager.lock(unique_id)
+    if lock.valid:
+        try:
+            async with AsyncSessionLocal() as session:  # Check if the document has been processed before
+                result = await session.execute(select(DocumentEmbedding).filter(DocumentEmbedding.file_hash == file_hash, DocumentEmbedding.llm_model_name == llm_model_name))
+                existing_document_embedding = result.scalar_one_or_none()
+                if existing_document_embedding:  # If the document has been processed before, return the existing result
+                    logger.info(f"Document {file.filename} has been processed before, returning existing result")
+                    json_content = json.dumps(existing_document_embedding.document_embedding_results_json).encode()
+                else:  # If the document has not been processed, continue processing
+                    mime = Magic(mime=True)
+                    mime_type = mime.from_file(temp_file_path)
+                    logger.info(f"Received request to extract embeddings for document {file.filename} with MIME type: {mime_type} and size: {os.path.getsize(temp_file_path)} bytes from IP address: {client_ip}")
+                    strings = await parse_submitted_document_file_into_sentence_strings_func(temp_file_path, mime_type)
+                    results = await compute_embeddings_for_document(strings, llm_model_name, client_ip, file_hash)  # Compute the embeddings and json_content for new documents
+                    df = pd.DataFrame(results, columns=['text', 'embedding'])
+                    json_content = df.to_json(orient=json_format or 'records').encode()
+                    with open(temp_file_path, 'rb') as file_buffer:  # Store the results in the database
+                        original_file_content = file_buffer.read()
+                    await store_document_embeddings_in_db(file, file_hash, original_file_content, json_content, results, llm_model_name, client_ip, request_time)
+                overall_total_time = (datetime.utcnow() - request_time).total_seconds()
+                logger.info(f"Done getting all embeddings for document {file.filename} containing {len(strings)} strings with model {llm_model_name}")
+                json_content_length = len(json_content)
+                if len(json_content) > 0:
+                    logger.info(f"The response took {overall_total_time} seconds to generate, or {overall_total_time / (len(strings)/1000.0)} seconds per thousand input strings and {overall_total_time / (float(json_content_length)/1000000.0)} seconds per million output characters.")
+                if send_back_json_or_zip_file == 'json':  # Assume 'json' response should be sent back
+                    logger.info(f"Returning JSON response for document {file.filename} containing {len(strings)} strings with model {llm_model_name}; first 100 characters out of {json_content_length} total of JSON response: {json_content[:100]}")
+                    return JSONResponse(content=json.loads(json_content.decode()))  # Decode the content and parse it as JSON
+                else:  # Assume 'zip' file should be sent back
+                    original_filename_without_extension, _ = os.path.splitext(file.filename)
+                    json_file_path = f"/tmp/{original_filename_without_extension}.json"
+                    with open(json_file_path, 'wb') as json_file:  # Write the JSON content as bytes
+                        json_file.write(json_content)
+                    zip_file_path = f"/tmp/{original_filename_without_extension}.zip"
+                    with zipfile.ZipFile(zip_file_path, 'w') as zipf:
+                        zipf.write(json_file_path, os.path.basename(json_file_path))
+                    logger.info(f"Returning ZIP response for document {file.filename} containing {len(strings)} strings with model {llm_model_name}; first 100 characters out of {json_content_length} total of JSON response: {json_content[:100]}")
+                    return FileResponse(zip_file_path, headers={"Content-Disposition": f"attachment; filename={original_filename_without_extension}.zip"})
+        finally:
+            await shared_resources.lock_manager.unlock(lock)
+    else:
+        return {"status": "already processing"}
+
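For orientation, a hypothetical client-side call to this endpoint (the field and parameter names mirror the ones used in the handler above, but whether each is a query or form parameter depends on the full function signature, which this hunk does not show):

```python
import requests

with open("report.pdf", "rb") as f:
    response = requests.post(
        "http://localhost:8089/get_all_embedding_vectors_for_document/",
        files={"file": ("report.pdf", f)},  # the UploadFile parameter
        params={"llm_model_name": "yarn-llama-2-7b-128k", "send_back_json_or_zip_file": "json"},
    )
print(response.status_code)
```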
f"/tmp/{original_filename_without_extension}.zip" + with zipfile.ZipFile(zip_file_path, 'w') as zipf: + zipf.write(json_file_path, os.path.basename(json_file_path)) + logger.info(f"Returning ZIP response for document {file.filename} containing {len(strings)} with model {llm_model_name}; first 100 characters out of {json_content_length} total of JSON response: {json_content[:100]}") + return FileResponse(zip_file_path, headers={"Content-Disposition": f"attachment; filename={original_filename_without_extension}.zip"}) + finally: + await shared_resources.lock_manager.unlock(lock) + else: + return {"status": "already processing"} + @app.post("/get_text_completions_from_input_prompt/", @@ -1833,7 +800,7 @@ async def get_all_embedding_vectors_for_document(file: UploadFile = File(...), ```json { "input_prompt": "The Kings of France in the 17th Century:", - "llm_model_name": "phind-codellama-34b-python-v1", + "llm_model_name": "mistral-7b-instruct-v0.1", "temperature": 0.95, "grammar_file_string": "json", "number_of_tokens_to_generate": 500, @@ -1849,7 +816,7 @@ async def get_all_embedding_vectors_for_document(file: UploadFile = File(...), [ { "input_prompt": "The Kings of France in the 17th Century:", - "llm_model_name": "phind-codellama-34b-python-v1", + "llm_model_name": "mistral-7b-instruct-v0.1", "grammar_file_string": "json", "number_of_tokens_to_generate": 500, "number_of_completions_to_generate": 3, @@ -1859,7 +826,7 @@ async def get_all_embedding_vectors_for_document(file: UploadFile = File(...), }, { "input_prompt": "The Kings of France in the 17th Century:", - "llm_model_name": "phind-codellama-34b-python-v1", + "llm_model_name": "mistral-7b-instruct-v0.1", "grammar_file_string": "json", "number_of_tokens_to_generate": 500, "number_of_completions_to_generate": 3, @@ -1869,7 +836,7 @@ async def get_all_embedding_vectors_for_document(file: UploadFile = File(...), }, { "input_prompt": "The Kings of France in the 17th Century:", - "llm_model_name": "phind-codellama-34b-python-v1", + "llm_model_name": "mistral-7b-instruct-v0.1", "grammar_file_string": "json", "number_of_tokens_to_generate": 500, "number_of_completions_to_generate": 3, @@ -1884,7 +851,15 @@ async def get_text_completions_from_input_prompt(request: TextCompletionRequest, logger.warning(f"Unauthorized request from client IP {client_ip}") raise HTTPException(status_code=403, detail="Unauthorized") try: - return await generate_completion_from_llm(request, req, client_ip) + unique_id = f"text_completion_{hash(request.input_prompt)}_{request.llm_model_name}" + lock = await shared_resources.lock_manager.lock(unique_id) + if lock.valid: + try: + return await generate_completion_from_llm(request, req, client_ip) + finally: + await shared_resources.lock_manager.unlock(lock) + else: + return {"status": "already processing"} except Exception as e: logger.error(f"An error occurred while processing the request: {e}") logger.error(traceback.format_exc()) # Print the traceback @@ -1936,24 +911,7 @@ async def clear_ramdisk_endpoint(token: str = None): @app.on_event("startup") async def startup_event(): - global db_writer, faiss_indexes, token_faiss_indexes, associated_texts_by_model - await initialize_db() - queue = asyncio.Queue() - db_writer = DatabaseWriter(queue) - await db_writer.initialize_processing_hashes() - asyncio.create_task(db_writer.dedicated_db_writer()) - global USE_RAMDISK - if USE_RAMDISK and not check_that_user_has_required_permissions_to_manage_ramdisks(): - USE_RAMDISK = False - elif USE_RAMDISK: - setup_ramdisk() - 
@@ -1936,24 +911,7 @@ async def clear_ramdisk_endpoint(token: str = None):
 
 @app.on_event("startup")
 async def startup_event():
-    global db_writer, faiss_indexes, token_faiss_indexes, associated_texts_by_model
-    await initialize_db()
-    queue = asyncio.Queue()
-    db_writer = DatabaseWriter(queue)
-    await db_writer.initialize_processing_hashes()
-    asyncio.create_task(db_writer.dedicated_db_writer())
-    global USE_RAMDISK
-    if USE_RAMDISK and not check_that_user_has_required_permissions_to_manage_ramdisks():
-        USE_RAMDISK = False
-    elif USE_RAMDISK:
-        setup_ramdisk()
-    list_of_downloaded_model_names, download_status = download_models()
-    for llm_model_name in list_of_downloaded_model_names:
-        try:
-            load_model(llm_model_name, raise_http_exception=False)
-        except FileNotFoundError as e:
-            logger.error(e)
-    faiss_indexes, token_faiss_indexes, associated_texts_by_model = await build_faiss_indexes()
+    await initialize_globals()
 
 
 @app.get("/download/{file_name}")
@@ -1985,4 +943,4 @@ def show_logs_default():
 
 if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT)
+    uvicorn.run("swiss_army_llama:app", **option)
diff --git a/uvicorn_config.py b/uvicorn_config.py
new file mode 100644
index 0000000..ac4f89f
--- /dev/null
+++ b/uvicorn_config.py
@@ -0,0 +1,9 @@
+from decouple import config
+SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT = config("SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT", default=8089, cast=int)
+UVICORN_NUMBER_OF_WORKERS = config("UVICORN_NUMBER_OF_WORKERS", default=3, cast=int)
+
+option = {
+    "host": "0.0.0.0",
+    "port": SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT,
+    "workers": UVICORN_NUMBER_OF_WORKERS
+}
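A note on the `__main__` change: Uvicorn can only spawn multiple worker processes when the application is passed as an import string rather than an object, which is why `uvicorn.run(app, ...)` becomes `uvicorn.run("swiss_army_llama:app", **option)`. A minimal sketch of how the pieces fit together, with the equivalent CLI invocation as a comment (port/worker values mirror the `.env` examples):

```python
import uvicorn
from uvicorn_config import option  # {"host": ..., "port": ..., "workers": ...}

# Equivalent shell command:
#   uvicorn swiss_army_llama:app --host 0.0.0.0 --port 8089 --workers 2
# Passing an app *object* with workers > 1 would fall back to a single process.
uvicorn.run("swiss_army_llama:app", **option)
```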