Commit 1199b8d

Added alternative way of supplying binary files to various endpoints

Dicklesworthstone committed May 20, 2024
1 parent aff224a commit 1199b8d
Showing 2 changed files with 140 additions and 60 deletions.
27 changes: 26 additions & 1 deletion service_functions.py
@@ -708,4 +708,29 @@ def convert_document_to_sentences_func(file_path: str, mime_type: str) -> Dict[s
"total_input_file_size_in_bytes": total_input_file_size_in_bytes,
"total_text_size_in_characters": total_text_size_in_characters
}
return result
return result

async def download_file(url: str, expected_size: int, expected_hash: str) -> str:
    temp_file = tempfile.NamedTemporaryFile(delete=False)
    temp_file_path = temp_file.name
    hash_obj = hashlib.sha3_256()
    downloaded_size = 0
    try:
        async with httpx.AsyncClient() as client:
            async with client.stream("GET", url) as response:
                if response.status_code != 200:
                    raise HTTPException(status_code=400, detail="Failed to download file")
                # Stream the body in chunks, tracking size and SHA3-256 hash as we go
                async for chunk in response.aiter_bytes():
                    downloaded_size += len(chunk)
                    if downloaded_size > expected_size:
                        raise HTTPException(status_code=400, detail="Downloaded file size exceeds expected size")
                    temp_file.write(chunk)
                    hash_obj.update(chunk)
    except Exception:
        temp_file.close()  # Close and delete the partial file on any failure so it doesn't leak
        os.remove(temp_file_path)
        raise
    temp_file.close()
    if downloaded_size != expected_size:
        os.remove(temp_file_path)
        raise HTTPException(status_code=400, detail="Downloaded file size does not match expected size")
    if hash_obj.hexdigest() != expected_hash:
        os.remove(temp_file_path)
        raise HTTPException(status_code=400, detail="File hash mismatch")
    return temp_file_path
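
For reference, the `expected_size` and `expected_hash` that this helper verifies can be precomputed by whoever hosts or uploads the file. A minimal client-side sketch using only the standard library (the helper name is illustrative, not part of this commit):

import hashlib
import os

def compute_expected_size_and_hash(local_path: str):
    # Illustrative helper: produce the (size, SHA3-256 hex digest) pair that
    # download_file() checks against. Any client computing the same digest
    # over the same bytes will pass verification.
    hash_obj = hashlib.sha3_256()
    with open(local_path, 'rb') as f:
        for chunk in iter(lambda: f.read(8192), b''):
            hash_obj.update(chunk)
    return os.path.getsize(local_path), hash_obj.hexdigest()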
173 changes: 114 additions & 59 deletions swiss_army_llama.py
@@ -8,7 +8,7 @@
from embeddings_data_models import EmbeddingRequest, SemanticSearchRequest, AdvancedSemanticSearchRequest, SimilarityRequest, TextCompletionRequest, AddGrammarRequest
from embeddings_data_models import EmbeddingResponse, SemanticSearchResponse, AdvancedSemanticSearchResponse, SimilarityResponse, AllStringsResponse, AllDocumentsResponse, TextCompletionResponse, AddGrammarResponse
from embeddings_data_models import ShowLogsIncrementalModel
from service_functions import get_or_compute_embedding, get_or_compute_transcript, add_model_url, get_or_compute_token_level_embedding_bundle_combined_feature_vector, calculate_token_level_embeddings
from service_functions import get_or_compute_embedding, get_or_compute_transcript, add_model_url, get_or_compute_token_level_embedding_bundle_combined_feature_vector, calculate_token_level_embeddings, download_file
from service_functions import parse_submitted_document_file_into_sentence_strings_func, compute_embeddings_for_document, store_document_embeddings_in_db, generate_completion_from_llm, validate_bnf_grammar_func, convert_document_to_sentences_func
from grammar_builder import GrammarBuilder
from log_viewer_functions import show_logs_incremental_func, show_logs_func
@@ -765,6 +765,9 @@ async def advanced_search_stored_embeddings_with_query_string_for_semantic_simil
### Parameters:
- `file`: The uploaded document file (either plain text, .doc/.docx, PDF, etc.).
- `url`: URL of the document file to download.
- `hash`: SHA3-256 hash of the document file to verify integrity.
- `size`: Size of the document file in bytes to verify completeness.
- `llm_model_name`: The model used to calculate embeddings (optional).
- `json_format`: The format of the JSON response (optional, see details below).
- `send_back_json_or_zip_file`: Whether to return a JSON file or a ZIP file containing the embeddings file (optional, defaults to `zip`).
@@ -786,77 +789,93 @@ async def advanced_search_stored_embeddings_with_query_string_for_semantic_simil
- MS Word: Submit a `.doc` or `.docx` file.
- PDF: Submit a `.pdf` file.""",
response_description="Either a ZIP file containing the embeddings JSON file or a direct JSON response, depending on the value of `send_back_json_or_zip_file`.")
async def get_all_embedding_vectors_for_document(file: UploadFile = File(...),
llm_model_name: str = "bge-m3-q8_0",
json_format: str = 'records',
corpus_identifier_string: Optional[str] = None,
token: str = None,
send_back_json_or_zip_file: str = 'zip',
req: Request = None) -> Response:
async def get_all_embedding_vectors_for_document(
file: UploadFile = File(None),
url: str = Form(None),
hash: str = Form(None),
size: int = Form(None),
llm_model_name: str = "bge-m3-q8_0",
json_format: str = 'records',
corpus_identifier_string: Optional[str] = None,
token: str = None,
send_back_json_or_zip_file: str = 'zip',
req: Request = None
):
client_ip = req.client.host if req else "localhost"
request_time = datetime.utcnow()
if USE_SECURITY_TOKEN and use_hardcoded_security_token and (token is None or token != SECURITY_TOKEN): raise HTTPException(status_code=403, detail="Unauthorized") # noqa: E701
_, extension = os.path.splitext(file.filename)
temp_file = tempfile.NamedTemporaryFile(suffix=extension, delete=False)
temp_file_path = temp_file.name
with open(temp_file_path, 'wb') as buffer:
chunk_size = 1024
chunk = await file.read(chunk_size)
while chunk:
buffer.write(chunk)

if USE_SECURITY_TOKEN and use_hardcoded_security_token and (token is None or token != SECURITY_TOKEN):
raise HTTPException(status_code=403, detail="Unauthorized")

    chunk_size = 1024
    if file:
        _, extension = os.path.splitext(file.filename)
        temp_file_path = tempfile.NamedTemporaryFile(suffix=extension, delete=False).name
        with open(temp_file_path, 'wb') as buffer:
            # Stream the upload to disk in chunks; the hash is computed in a second pass below
            chunk = await file.read(chunk_size)
            while chunk:
                buffer.write(chunk)
                chunk = await file.read(chunk_size)
elif url and hash and size:
temp_file_path = await download_file(url, size, hash)
else:
raise HTTPException(status_code=400, detail="Invalid input. Provide either a file or URL with hash and size.")

hash_obj = hashlib.sha3_256()
with open(temp_file_path, 'rb') as buffer:
for chunk in iter(lambda: buffer.read(chunk_size), b''):
hash_obj.update(chunk)
file_hash = hash_obj.hexdigest()
logger.info(f"SHA3-256 hash of submitted file: {file_hash}")

if corpus_identifier_string is None:
corpus_identifier_string = file_hash

unique_id = f"document_embedding_{file_hash}_{llm_model_name}"
lock = await shared_resources.lock_manager.lock(unique_id)
if lock.valid:
try:
async with AsyncSessionLocal() as session: # Check if the document has been processed before
lock = await shared_resources.lock_manager.lock(unique_id)

if lock.valid:
try:
async with AsyncSession() as session:
result = await session.execute(select(DocumentEmbedding).filter(DocumentEmbedding.file_hash == file_hash, DocumentEmbedding.llm_model_name == llm_model_name))
existing_document_embedding = result.scalar_one_or_none()
if existing_document_embedding: # If the document has been processed before, return the existing result
logger.info(f"Document {file.filename} has been processed before, returning existing result")
if existing_document_embedding:
logger.info(f"Document {file.filename if file else url} has been processed before, returning existing result")
json_content = json.dumps(existing_document_embedding.document_embedding_results_json).encode()
else: # If the document has not been processed, continue processing
else:
mime = Magic(mime=True)
mime_type = mime.from_file(temp_file_path)
logger.info(f"Received request to extract embeddings for document {file.filename} with MIME type: {mime_type} and size: {os.path.getsize(temp_file_path)} bytes from IP address: {client_ip}")
mime_type = mime.from_file(temp_file_path)
logger.info(f"Received request to extract embeddings for document {file.filename if file else url} with MIME type: {mime_type} and size: {os.path.getsize(temp_file_path)} bytes from IP address: {client_ip}")
strings = await parse_submitted_document_file_into_sentence_strings_func(temp_file_path, mime_type)
results = await compute_embeddings_for_document(strings, llm_model_name, client_ip, file_hash) # Compute the embeddings and json_content for new documents
results = await compute_embeddings_for_document(strings, llm_model_name, client_ip, file_hash)
df = pd.DataFrame(results, columns=['text', 'embedding'])
json_content = df.to_json(orient=json_format or 'records').encode()
with open(temp_file_path, 'rb') as file_buffer: # Store the results in the database
with open(temp_file_path, 'rb') as file_buffer:
original_file_content = file_buffer.read()
await store_document_embeddings_in_db(file, file_hash, original_file_content, json_content, results, llm_model_name, client_ip, request_time, corpus_identifier_string)
overall_total_time = (datetime.utcnow() - request_time).total_seconds()
logger.info(f"Done getting all embeddings for document {file.filename} containing {len(strings)} with model {llm_model_name}")
logger.info(f"Done getting all embeddings for document {file.filename if file else url} containing {len(strings)} with model {llm_model_name}")
json_content_length = len(json_content)
if len(json_content) > 0:
if json_content_length > 0:
logger.info(f"The response took {overall_total_time} seconds to generate, or {overall_total_time / (len(strings)/1000.0)} seconds per thousand input tokens and {overall_total_time / (float(json_content_length)/1000000.0)} seconds per million output characters.")
if send_back_json_or_zip_file == 'json': # Assume 'json' response should be sent back
logger.info(f"Returning JSON response for document {file.filename} containing {len(strings)} with model {llm_model_name}; first 100 characters out of {json_content_length} total of JSON response: {json_content[:100]}")
return JSONResponse(content=json.loads(json_content.decode())) # Decode the content and parse it as JSON
else: # Assume 'zip' file should be sent back
original_filename_without_extension, _ = os.path.splitext(file.filename)
if send_back_json_or_zip_file == 'json':
logger.info(f"Returning JSON response for document {file.filename if file else url} containing {len(strings)} with model {llm_model_name}; first 100 characters out of {json_content_length} total of JSON response: {json_content[:100]}")
return JSONResponse(content=json.loads(json_content.decode()))
else:
original_filename_without_extension, _ = os.path.splitext(file.filename if file else os.path.basename(url))
json_file_path = f"/tmp/{original_filename_without_extension}.json"
with open(json_file_path, 'wb') as json_file: # Write the JSON content as bytes
with open(json_file_path, 'wb') as json_file:
json_file.write(json_content)
zip_file_path = f"/tmp/{original_filename_without_extension}.zip"
with zipfile.ZipFile(zip_file_path, 'w') as zipf:
zipf.write(json_file_path, os.path.basename(json_file_path))
logger.info(f"Returning ZIP response for document {file.filename} containing {len(strings)} with model {llm_model_name}; first 100 characters out of {json_content_length} total of JSON response: {json_content[:100]}")
logger.info(f"Returning ZIP response for document {file.filename if file else url} containing {len(strings)} with model {llm_model_name}; first 100 characters out of {json_content_length} total of JSON response: {json_content[:100]}")
return FileResponse(zip_file_path, headers={"Content-Disposition": f"attachment; filename={original_filename_without_extension}.zip"})
finally:
await shared_resources.lock_manager.unlock(lock)
else:
return {"status": "already processing"}
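
As an illustration, the new URL-based input could be exercised by a client like this. This is a sketch: the route path is inferred from the handler name, the port is an assumption, and `llm_model_name`/`send_back_json_or_zip_file` travel as query parameters since only `url`, `hash`, and `size` are declared as form fields:

import httpx

# Hypothetical client call; the base URL and port are assumptions, not part of this diff.
response = httpx.post(
    "http://localhost:8089/get_all_embedding_vectors_for_document/",
    params={"llm_model_name": "bge-m3-q8_0", "send_back_json_or_zip_file": "json"},
    data={
        "url": "https://example.com/report.pdf",      # placeholder document URL
        "hash": "<sha3-256 hex digest of the file>",  # precomputed, e.g. with hashlib.sha3_256
        "size": 123456,                               # exact byte count the server will enforce
    },
    timeout=None,
)
print(response.status_code, response.headers.get("content-type"))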



@app.post("/get_text_completions_from_input_prompt/",
@@ -1057,6 +1076,9 @@ async def turn_pydantic_model_description_into_bnf_grammar_for_llm(
### Parameters:
- `file`: The uploaded audio file.
- `url`: URL of the audio file to download.
- `hash`: SHA3-256 hash of the audio file to verify integrity.
- `size`: Size of the audio file in bytes to verify completeness.
- `compute_embeddings_for_resulting_transcript_document`: Boolean to indicate if document embeddings should be computed (optional, defaults to True).
- `llm_model_name`: The language model used for computing embeddings (optional, defaults to the default model name).
- `req`: HTTP Request object for additional request metadata (optional).
@@ -1075,20 +1097,39 @@ async def turn_pydantic_model_description_into_bnf_grammar_for_llm(
- Unauthorized requests are logged and result in a 403 status.
- All other errors result in a 500 status and are logged with their tracebacks.""",
response_description="A JSON object containing the complete transcription details, computational times, and an optional URL for downloading a ZIP file of the document embeddings.")
async def compute_transcript_with_whisper_from_audio(file: UploadFile,
compute_embeddings_for_resulting_transcript_document: Optional[bool] = True,
llm_model_name: Optional[str] = DEFAULT_MODEL_NAME,
corpus_identifier_string: Optional[str] = None,
req: Request = None,
token: str = None,
client_ip: str = None):
async def compute_transcript_with_whisper_from_audio(
file: UploadFile = File(None),
url: str = Form(None),
hash: str = Form(None),
size: int = Form(None),
compute_embeddings_for_resulting_transcript_document: Optional[bool] = True,
llm_model_name: Optional[str] = DEFAULT_MODEL_NAME,
corpus_identifier_string: Optional[str] = None,
req: Request = None,
token: str = None,
client_ip: str = None
):
if USE_SECURITY_TOKEN and use_hardcoded_security_token and (token is None or token != SECURITY_TOKEN):
logger.warning(f"Unauthorized request from client_ip {client_ip}")
raise HTTPException(status_code=403, detail="Unauthorized")
if file:
temp_file_path = tempfile.NamedTemporaryFile(delete=False).name
with open(temp_file_path, 'wb') as buffer:
chunk_size = 1024
chunk = await file.read(chunk_size)
while chunk:
buffer.write(chunk)
chunk = await file.read(chunk_size)
elif url and hash and size:
temp_file_path = await download_file(url, size, hash)
else:
raise HTTPException(status_code=400, detail="Invalid input. Provide either a file or URL with hash and size.")
try:
audio_transcript = await get_or_compute_transcript(file, compute_embeddings_for_resulting_transcript_document, llm_model_name, req, corpus_identifier_string)
return audio_transcript
audio_transcript = await get_or_compute_transcript(temp_file_path, compute_embeddings_for_resulting_transcript_document, llm_model_name, req, corpus_identifier_string)
os.remove(temp_file_path)
return JSONResponse(content=audio_transcript)
except Exception as e:
os.remove(temp_file_path)
logger.error(f"An error occurred while processing the request: {e}")
        logger.error(traceback.format_exc())  # Log the full traceback
raise HTTPException(status_code=500, detail="Internal Server Error")
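
The same two input modes apply to the audio endpoint. A brief sketch of both, with the route path inferred from the handler name and the port assumed:

import httpx

# Hypothetical client calls; route path and port are assumptions, not part of this diff.
base = "http://localhost:8089/compute_transcript_with_whisper_from_audio/"
with httpx.Client(timeout=None) as client:
    # Original mode: upload the audio bytes directly.
    with open("speech.mp3", "rb") as f:
        r1 = client.post(base, files={"file": ("speech.mp3", f, "audio/mpeg")})
    # New mode: have the server download and verify the file itself.
    r2 = client.post(base, data={
        "url": "https://example.com/speech.mp3",  # placeholder audio URL
        "hash": "<sha3-256 hex digest>",          # precomputed by the client
        "size": 2345678,                          # exact byte count
    })
print(r1.status_code, r2.status_code)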
@@ -1181,7 +1222,7 @@ async def clear_ramdisk_endpoint(token: str = None):
### Security:
If a security token is required by the application configuration, you must provide a valid `token` to access this endpoint. Unauthorized access will result in a 403 status code.""",
response_description="The ZIP file that was requested, or a status code indicating an error.")
async def download_file(file_name: str, token: str = None):
async def download_file_endpoint(file_name: str, token: str = None):
if USE_SECURITY_TOKEN and use_hardcoded_security_token and (token is None or token != SECURITY_TOKEN):
raise HTTPException(status_code=403, detail="Unauthorized")
decoded_file_name = unquote(file_name)
@@ -1252,6 +1293,9 @@ def show_logs_default():
### Parameters:
- `file`: The uploaded document file (supports plain text, .doc/.docx, PDF files, images using Tesseract OCR, and many other file types supported by the textract library).
- `url`: URL of the file to download.
- `hash`: SHA3-256 hash of the file to verify integrity.
- `size`: Size of the file in bytes to verify completeness.
- `token`: Security token (optional).
### Response:
@@ -1277,20 +1321,31 @@ def show_logs_default():
```""",
response_description="A JSON object containing the sentences extracted from the document and various statistics."
)
async def convert_document_to_sentences(file: UploadFile = File(...), token: str = None):
async def convert_document_to_sentences(
file: UploadFile = File(None),
url: str = Form(None),
hash: str = Form(None),
size: int = Form(None),
token: str = Form(None)
):
if USE_SECURITY_TOKEN and use_hardcoded_security_token and (token is None or token != SECURITY_TOKEN):
raise HTTPException(status_code=403, detail="Unauthorized")
_, extension = os.path.splitext(file.filename)
temp_file = tempfile.NamedTemporaryFile(suffix=extension, delete=False)
temp_file_path = temp_file.name
with open(temp_file_path, 'wb') as buffer:
chunk_size = 1024
chunk = await file.read(chunk_size)
while chunk:
buffer.write(chunk)
if file:
_, extension = os.path.splitext(file.filename)
temp_file = tempfile.NamedTemporaryFile(suffix=extension, delete=False)
temp_file_path = temp_file.name
with open(temp_file_path, 'wb') as buffer:
chunk_size = 1024
chunk = await file.read(chunk_size)
while chunk:
buffer.write(chunk)
chunk = await file.read(chunk_size)
elif url and hash and size:
temp_file_path = await download_file(url, size, hash)
else:
raise HTTPException(status_code=400, detail="Invalid input. Provide either a file or URL with hash and size.")
mime = Magic(mime=True)
mime_type = mime.from_file(temp_file_path)
result = convert_document_to_sentences_func(temp_file_path, mime_type)
os.remove(temp_file_path)
return JSONResponse(content=result)
return JSONResponse(content=result)
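
Putting the pieces together for the sentence-splitting endpoint, a client holding a local copy of the hosted document could do something like this (a sketch; the route path is inferred from the handler name and the port is assumed):

import hashlib
import os
import httpx

local_copy = "paper.pdf"  # placeholder: a local copy of the document hosted at the URL below
with open(local_copy, "rb") as f:
    digest = hashlib.sha3_256(f.read()).hexdigest()
size = os.path.getsize(local_copy)

r = httpx.post(
    "http://localhost:8089/convert_document_to_sentences/",
    data={"url": "https://example.com/paper.pdf", "hash": digest, "size": size},
    timeout=None,
)
print(r.json())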
