From a5c6a31c6415c525cb9a549a84016a8d834c1779 Mon Sep 17 00:00:00 2001 From: Rehan Date: Sat, 22 Jun 2024 21:04:49 -0400 Subject: [PATCH] Add OpenAI batch job support And many other little changes and fixes --- VisioNomicon/args_handler.py | 35 +++++++---- VisioNomicon/gpt.py | 110 ++++++++++++++++++++++++++++++++--- VisioNomicon/main.py | 62 ++++++++++++++------ 3 files changed, 170 insertions(+), 37 deletions(-) diff --git a/VisioNomicon/args_handler.py b/VisioNomicon/args_handler.py index 23c5f6b..8d508dd 100644 --- a/VisioNomicon/args_handler.py +++ b/VisioNomicon/args_handler.py @@ -75,6 +75,18 @@ def parse_cli_args(): action="store_true", help="If error retries limit is reached, map file to original name instead of returning an error", ) + parser.add_argument( + "-b", + "--create-batch", + action="store_true", + help="Create batch job through OpenAI API", + ) + parser.add_argument( + "-B", + "--retrieve-batch", + action="store_true", + help="Retrieve batch job output through OpenAI API. Run this 24 hours after creating the batch job.", + ) # if flag with value, equals value # if flag with no value, equals const value @@ -103,16 +115,16 @@ def parse_cli_args(): parser.error("-u/--undo must not be used with any other arguments.") #################################################################################### - if args.files is not None and len(args.files) == 0: + if args.files == NO_VAL: parser.error("-f/--files requires a value") - if args.output is not None and args.execute is not None: + if args.output and args.execute: parser.error( "instead of using -o/--output along with -x/--execute, use -ox/--mapex" ) - if args.mapex is not None: - if args.output is not None or args.execute is not None: + if args.mapex: + if args.output or args.execute: parser.error( "-ox/--mapex should be used without -o/--output or -x/--execute" ) @@ -120,10 +132,13 @@ def parse_cli_args(): args.output = args.mapex args.execute = args.mapex - if args.output is not None and args.files is None: + if args.output and not args.files: parser.error("-o/--output must be used with -f/--files") - if args.template is None: + if args.create_batch and not args.files: + parser.error("-b/--create-batch must be used with -f/--files") + + if args.template == NO_VAL: parser.error("used -t/--template with no value") supported_ext = [".png", ".jpeg", ".jpg", ".webp", ".gif"] @@ -131,7 +146,7 @@ def parse_cli_args(): # # get absolute paths where we need them # - if args.files is not None: + if args.files: args.files = [os.path.abspath(path) for path in args.files] clean_paths = args.files.copy() @@ -148,13 +163,13 @@ def parse_cli_args(): parser.error("Filetype {} not supported".format(image_ext)) args.files = clean_paths - if args.output is not None and args.output != NO_VAL: + if args.output and args.output != NO_VAL: args.output = os.path.abspath(args.output) - if args.execute is not None and args.execute != NO_VAL: + if args.execute and args.execute != NO_VAL: args.execute = os.path.abspath(args.execute) - if args.undo is not None and args.undo != NO_VAL: + if args.undo and args.undo != NO_VAL: args.undo = os.path.abspath(args.undo) return args diff --git a/VisioNomicon/gpt.py b/VisioNomicon/gpt.py index c29268c..1c6868d 100644 --- a/VisioNomicon/gpt.py +++ b/VisioNomicon/gpt.py @@ -1,11 +1,14 @@ -from openai import OpenAI +import openai +import json +import io from pathlib import Path import os import requests import base64 import sys +from constants import API_KEY, NAMING_PROMPT, MODEL -API_KEY = "" +RETRIEVED_JSON = {} def set_api_key(): @@ -16,6 +19,98 @@ def set_api_key(): API_KEY = os.environ.get("OPENAI_API_KEY") if API_KEY == "" else API_KEY +def batch(filepaths: list[str], base64_strs: list[str], template: str, data_dir: str): + batch_reqs = [] + for filepath, base64_str in zip(filepaths, base64_strs): + batch_reqs.append( + { + "custom_id": filepath, + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": MODEL, + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": NAMING_PROMPT.format(template=template), + }, + { + "type": "image_url", + "image_url": { + "url": base64_str, + "detail": "auto", + }, + }, + ], + } + ], + "temperature": 0.7, + }, + } + ) + + set_api_key() + bytes_buffer = io.BytesIO() + # write to bytes buffer + # doing this to avoid having to write file to disk then pull back from disk to send + for entry in batch_reqs: + json_line = json.dumps(entry) + "\n" + bytes_buffer.write(json_line.encode("utf-8")) + + # reset buffer position to prepare to send + bytes_buffer.seek(0) + + file_upload_response = openai.files.create(file=bytes_buffer, purpose="batch") + + # create batch request from uploaded requests file + # only 24h completion window is available for now + batch = openai.batches.create( + input_file_id=file_upload_response.id, + endpoint="/v1/chat/completions", + completion_window="24h", + ) + # write batch id to file to retrieve later + with open(f"{data_dir}/batch_id", "w") as file: + file.write(batch.id) + + +def image_to_name_retrieve(image_path: str) -> str: + global RETRIEVED_JSON + + if not RETRIEVED_JSON: + # get file_id for completed responses + file_id = "" + # get batch id from file + data_dir = os.environ.get("XDG_DATA_HOME") + data_dir = ( + data_dir if data_dir else os.path.abspath("~/.local/share") + ) + "/visionomicon/" + with open(f"{data_dir}/batch_id", "r") as f: + file_id = openai.batches.retrieve(f.read()).output_file_id + + # could occur if batch not complete yet + if file_id is None: + print("Error during batch retrieval, maybe the job isn't complete yet.") + sys.exit() + + try: + # get responses in a json str + response_str = openai.files.content(file_id).content.decode("utf-8") + # output file for responses may be expired or deleted + except openai.NotFoundError: + print("Error during batch retrieval, output file could not be retrieved.") + sys.exit() + # each response in own json + response_jsons = [json.loads(s) for s in response_str.split("\n") if s.strip()] + RETRIEVED_JSON = {s["custom_id"]: s for s in response_jsons} + return RETRIEVED_JSON[image_path]["response"]["body"]["choices"][0]["message"][ + "content" + ].strip() + + def image_to_name(image_path: str, args) -> str: template: str = args.template @@ -34,14 +129,14 @@ def encode_image(image_path: str): headers = {"Content-Type": "application/json", "Authorization": f"Bearer {API_KEY}"} payload = { - "model": "gpt-4o", + "model": MODEL, "messages": [ { "role": "user", "content": [ { "type": "text", - "text": f"Generate a filename for an image by analyzing its content and utilizing a user-provided template. Placeholders enclosed in square brackets (e.g., [Subject], [Color], [Action]) will be used, which represent specific elements to be incorporated in the filename. Replace the placeholders accurately and succinctly with terms pulled from the image content, removing the brackets in the final filename. For instance, if the template reads '[MainSubject]_in_[Setting]', the filename might be 'Cat_in_Garden'. Construct the filename omitting the file extension and any other text. Assure that every placeholder is filled with precise, image-derived information, conforming to typical filename length restrictions. The given template is '{template}'.", + "text": NAMING_PROMPT.format(template=template), }, { "type": "image_url", @@ -64,7 +159,7 @@ def encode_image(image_path: str): try: return response_json["choices"][0]["message"]["content"] - except: + except KeyError: print("OpenAI Unexpected Response:", response_json["error"]["message"]) i < args.error_retries and print("retrying...\n") @@ -79,10 +174,9 @@ def encode_image(image_path: str): def name_validation(name: str, template: str): set_api_key() - client = OpenAI() - completion = client.chat.completions.create( - model="gpt-4-1106-preview", + completion = openai.chat.completions.create( + model=MODEL, messages=[ { "role": "system", diff --git a/VisioNomicon/main.py b/VisioNomicon/main.py index 6c63534..7c42315 100644 --- a/VisioNomicon/main.py +++ b/VisioNomicon/main.py @@ -1,30 +1,34 @@ import os +import sys +import base64 import json import copy import glob -from VisioNomicon.args_handler import * -from VisioNomicon.gpt import * +from constants import get_data_dir +from args_handler import parse_cli_args, NO_VAL +from gpt import ( + image_to_name, + name_validation, + image_to_name_retrieve, + batch, +) from datetime import datetime -DATA_DIR = "" - def main(): - # get data dir - global DATA_DIR - DATA_DIR = ( - os.environ.get("XDG_DATA_HOME") - if "XDG_DATA_HOME" in os.environ - else os.path.abspath("~/.local/share") - ) + "/visionomicon/" - # make data dir if doesn't exist - not os.path.exists(DATA_DIR) and os.makedirs(DATA_DIR) + data_dir = get_data_dir() + if not os.path.exists(data_dir): + os.makedirs(data_dir) args = parse_cli_args() + if args.create_batch: + create_batch(args) + print("Batch job created.") + return 0 # if creating mapping - if args.files is not None: + if args.files: new_filepaths: list[str] = generate_mapping(args) # have new and old, put them together into a json and save @@ -33,7 +37,7 @@ def main(): # if executing or undoing if args.undo or args.execute: rel_mapping_fp = args.execute if args.execute else args.undo - rename_from_mapping(rel_mapping_fp, args.undo is not None) + rename_from_mapping(rel_mapping_fp, args.undo) def rename_from_mapping(rel_mapping_fp: str, undo: bool = False): @@ -61,7 +65,7 @@ def get_mapping_name(cli_fp: str): return cli_fp else: # Join the directory with the file pattern - file_pattern = os.path.join(DATA_DIR, "*.json") + file_pattern = os.path.join(get_data_dir(), "*.json") # Get list of files matching the file pattern files = glob.glob(file_pattern) @@ -84,11 +88,23 @@ def save_mapping(args, new_filepaths: list[str]): def generate_mapping_name(args) -> str: return ( args.output - if args.output != NO_VAL - else DATA_DIR + datetime.now().strftime("mapping-%Y-%m-%d-%H-%M-%S.json") + if args.output and args.output != NO_VAL + else get_data_dir() + datetime.now().strftime("mapping-%Y-%m-%d-%H-%M-%S.json") ) +def create_batch(args): + base64_strs = [] + for fp in args.files: + _, image_ext = os.path.splitext(fp) + with open(fp, "rb") as image_file: + base64_strs.append( + f"data:image/{image_ext};base64,{base64.b64encode(image_file.read()).decode("utf-8")}" + ) + + batch(args.files, base64_strs, args.template, get_data_dir()) + + def generate_mapping(args) -> list[str]: og_filepaths: list[str] = args.files new_filepaths: list[str] = copy.deepcopy(og_filepaths) @@ -97,11 +113,19 @@ def generate_mapping(args) -> list[str]: slicepoint = new_filepaths[i].rindex("/") + 1 new_filepaths[i] = new_filepaths[i][:slicepoint] + new_fp = "" + new_filename = "" + new_name = "" + image_ext = "" for i in range(len(og_filepaths)): image_path = og_filepaths[i] for j in range(args.validation_retries + 1): print("Generating name...") - new_name = image_to_name(image_path, args) + new_name = ( + image_to_name_retrieve(image_path) + if args.retrieve_batch + else image_to_name(image_path, args) + ) print("Generated name {}".format(new_name)) _, image_ext = os.path.splitext(image_path)