From 8627064c1058277d492814f566ed84c6c837f321 Mon Sep 17 00:00:00 2001 From: Dicklesworthstone Date: Tue, 21 May 2024 14:47:05 -0400 Subject: [PATCH] Fix --- swiss_army_llama.py | 47 ++++++++++++++++++++++++++++++--------------- 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/swiss_army_llama.py b/swiss_army_llama.py index 3da564e..d10ef2a 100644 --- a/swiss_army_llama.py +++ b/swiss_army_llama.py @@ -765,7 +765,7 @@ async def advanced_search_stored_embeddings_with_query_string_for_semantic_simil return {"status": "already processing"} - + @app.post("/get_all_embedding_vectors_for_document/", summary="Get Embeddings for a Document", description="""Extract text embeddings for a document. This endpoint supports plain text, .doc/.docx (MS Word), PDF files, images (using Tesseract OCR), and many other file types supported by the textract library. @@ -813,16 +813,23 @@ async def get_all_embedding_vectors_for_document( if USE_SECURITY_TOKEN and use_hardcoded_security_token and (token is None or token != SECURITY_TOKEN): raise HTTPException(status_code=403, detail="Unauthorized") if file: - _, extension = os.path.splitext(file.filename) - temp_file_path = tempfile.NamedTemporaryFile(suffix=extension, delete=False).name + input_data_binary = await file.read() + result = magika.identify_bytes(input_data_binary) + detected_data_type = result.output.ct_label + temp_file = tempfile.NamedTemporaryFile(suffix=f".{detected_data_type}", delete=False) + temp_file_path = temp_file.name + logger.info(f"Temp file path: {temp_file_path}") with open(temp_file_path, 'wb') as buffer: - chunk_size = 1024 - chunk = await file.read(chunk_size) - while chunk: - buffer.write(chunk) - chunk = await file.read(chunk_size) + buffer.write(input_data_binary) elif url and hash and size: temp_file_path = await download_file(url, size, hash) + with open(temp_file_path, 'rb') as file: + input_data_binary = file.read() + result = magika.identify_bytes(input_data_binary) + detected_data_type = result.output.ct_label + new_temp_file_path = temp_file_path + f".{detected_data_type}" + os.rename(temp_file_path, new_temp_file_path) + temp_file_path = new_temp_file_path else: raise HTTPException(status_code=400, detail="Invalid input. Provide either a file or URL with hash and size.") @@ -830,7 +837,7 @@ async def get_all_embedding_vectors_for_document( # Verify file integrity hash_obj = sha3_256() with open(temp_file_path, 'rb') as buffer: - for chunk in iter(lambda: buffer.read(chunk_size), b''): + for chunk in iter(lambda: buffer.read(1024), b''): hash_obj.update(chunk) file_hash = hash_obj.hexdigest() logger.info(f"SHA3-256 hash of submitted file: {file_hash}") @@ -1140,17 +1147,26 @@ async def compute_transcript_with_whisper_from_audio( logger.warning(f"Unauthorized request from client_ip {client_ip}") raise HTTPException(status_code=403, detail="Unauthorized") if file: - temp_file_path = tempfile.NamedTemporaryFile(delete=False).name + input_data_binary = await file.read() + result = magika.identify_bytes(input_data_binary) + detected_data_type = result.output.ct_label + temp_file = tempfile.NamedTemporaryFile(suffix=f".{detected_data_type}", delete=False) + temp_file_path = temp_file.name + logger.info(f"Temp file path: {temp_file_path}") with open(temp_file_path, 'wb') as buffer: - chunk_size = 1024 - chunk = await file.read(chunk_size) - while chunk: - buffer.write(chunk) - chunk = await file.read(chunk_size) + buffer.write(input_data_binary) elif url and hash and size: temp_file_path = await download_file(url, size, hash) + with open(temp_file_path, 'rb') as file: + input_data_binary = file.read() + result = magika.identify_bytes(input_data_binary) + detected_data_type = result.output.ct_label + new_temp_file_path = temp_file_path + f".{detected_data_type}" + os.rename(temp_file_path, new_temp_file_path) + temp_file_path = new_temp_file_path else: raise HTTPException(status_code=400, detail="Invalid input. Provide either a file or URL with hash and size.") + audio_file_size_mb = os.path.getsize(temp_file_path) / (1024 * 1024) input_data = { "file_size_mb": audio_file_size_mb, @@ -1170,6 +1186,7 @@ async def compute_transcript_with_whisper_from_audio( end_resource_monitoring(context) + @app.post("/add_new_grammar_definition_file/", response_model=AddGrammarResponse, summary="Add or Update a Grammar Definition File",