diff --git a/daras_ai_v2/gdrive_downloader.py b/daras_ai_v2/gdrive_downloader.py index 1c8fae060..3cc41c68f 100644 --- a/daras_ai_v2/gdrive_downloader.py +++ b/daras_ai_v2/gdrive_downloader.py @@ -2,7 +2,7 @@ from furl import furl import requests -from loguru import logger + from daras_ai_v2.exceptions import UserError from daras_ai_v2.functional import flatmap_parallel from daras_ai_v2.exceptions import raise_for_status @@ -64,88 +64,36 @@ def gdrive_list_urls_of_files_in_folder(f: furl, max_depth: int = 4) -> list[str def gdrive_download(f: furl, mime_type: str, export_links: dict) -> tuple[bytes, str]: from googleapiclient import discovery + from googleapiclient.http import MediaIoBaseDownload # get drive file id file_id = url_to_gdrive_file_id(f) # get metadata service = discovery.build("drive", "v3") - request, mime_type = service_request(service, file_id, f, mime_type) - file_bytes, mime_type = download_blob_file_content( - service, request, file_id, f, mime_type, export_links + if f.host != "drive.google.com": + # export google docs to appropriate type + export_mime_type, _ = docs_export_mimetype(f) + if f_url_export := export_links.get(export_mime_type, None): + r = requests.get(f_url_export) + file_bytes = r.content + raise_for_status(r) + return file_bytes, export_mime_type + + request = service.files().get_media( + fileId=file_id, + supportsAllDrives=True, ) - - return file_bytes, mime_type - - -def service_request( - service, file_id: str, f: furl, mime_type: str, retried_request=False -) -> tuple[any, str]: - # get files in drive directly - if f.host == "drive.google.com" or retried_request: - request = service.files().get_media( - fileId=file_id, - supportsAllDrives=True, - ) - # export google docs to appropriate type - else: - mime_type, _ = docs_export_mimetype(f) - request = service.files().export_media( - fileId=file_id, - mimeType=mime_type, - ) - return request, mime_type - - -def download_blob_file_content( - service, request, file_id: str, f: furl, mime_type: str, export_links: dict -) -> tuple[bytes, str]: - from googleapiclient.http import MediaIoBaseDownload - from googleapiclient.errors import HttpError - # download file = io.BytesIO() downloader = MediaIoBaseDownload(file, request) + done = False + while done is False: + _, done = downloader.next_chunk() + # print(f"Download {int(status.progress() * 100)}%") + file_bytes = file.getvalue() - if ( - mime_type - == "application/vnd.openxmlformats-officedocument.presentationml.presentation" - ): - # logger.debug(f"Downloading {str(f)!r} using export links") - f_url_export = export_links.get(mime_type, None) - if f_url_export: - - f_bytes = download_from_exportlinks(f_url_export) - else: - request = service.files().get_media( - fileId=file_id, - supportsAllDrives=True, - ) - downloader = MediaIoBaseDownload(file, request) - - done = False - while done is False: - _, done = downloader.next_chunk() - # print(f"Download {int(status.progress() * 100)}%") - f_bytes = file.getvalue() - - else: - done = False - while done is False: - _, done = downloader.next_chunk() - # print(f"Download {int(status.progress() * 100)}%") - f_bytes = file.getvalue() - - return f_bytes, mime_type - - -def download_from_exportlinks(f: furl) -> bytes: - try: - r = requests.get(f) - f_bytes = r.content - except requests.exceptions.RequestException as e: - raise_for_status(e) - return f_bytes + return file_bytes, mime_type def docs_export_mimetype(f: furl) -> tuple[str, str]: diff --git a/daras_ai_v2/glossary.py b/daras_ai_v2/glossary.py index 87618b87c..77c252173 100644 --- a/daras_ai_v2/glossary.py +++ b/daras_ai_v2/glossary.py @@ -15,7 +15,7 @@ def validate_glossary_document(document: str): metadata = doc_url_to_file_metadata(document) f_bytes, mime_type = download_content_bytes( - f_url=document, mime_type=metadata.mime_type + f_url=document, mime_type=metadata.mime_type, export_links=metadata.export_links ) df = tabular_bytes_to_str_df( f_name=metadata.name, f_bytes=f_bytes, mime_type=mime_type diff --git a/recipes/BulkRunner.py b/recipes/BulkRunner.py index 25c99a2dc..4eb7d654e 100644 --- a/recipes/BulkRunner.py +++ b/recipes/BulkRunner.py @@ -612,7 +612,7 @@ def get_columns(files: list[str]) -> list[str]: def read_df_any(f_url: str) -> "pd.DataFrame": file_meta = doc_url_to_file_metadata(f_url) f_bytes, mime_type = download_content_bytes( - f_url=f_url, mime_type=file_meta.mime_type + f_url=f_url, mime_type=file_meta.mime_type, export_links=file_meta.export_links ) df = tabular_bytes_to_any_df( f_name=file_meta.name, f_bytes=f_bytes, mime_type=mime_type diff --git a/recipes/DocExtract.py b/recipes/DocExtract.py index 23cf89bfe..0fa063379 100644 --- a/recipes/DocExtract.py +++ b/recipes/DocExtract.py @@ -475,7 +475,9 @@ def process_source( elif is_video: f = furl(webpage_url) if is_gdrive_url(f): - f_bytes, _ = gdrive_download(f, doc_meta.mime_type) + f_bytes, _ = gdrive_download( + f, doc_meta.mime_type, doc_meta.export_links + ) webpage_url = upload_file_from_bytes( doc_meta.name, f_bytes, content_type=doc_meta.mime_type )