Skip to content

Commit

Permalink
Add OpenAI batch job support
Browse files Browse the repository at this point in the history
And many other little changes and fixes
  • Loading branch information
rehanzo committed Jun 23, 2024
1 parent e981256 commit a5c6a31
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 37 deletions.
35 changes: 25 additions & 10 deletions VisioNomicon/args_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,18 @@ def parse_cli_args():
action="store_true",
help="If error retries limit is reached, map file to original name instead of returning an error",
)
parser.add_argument(
"-b",
"--create-batch",
action="store_true",
help="Create batch job through OpenAI API",
)
parser.add_argument(
"-B",
"--retrieve-batch",
action="store_true",
help="Retrieve batch job output through OpenAI API. Run this 24 hours after creating the batch job.",
)

# if flag with value, equals value
# if flag with no value, equals const value
Expand Down Expand Up @@ -103,35 +115,38 @@ def parse_cli_args():
parser.error("-u/--undo must not be used with any other arguments.")
####################################################################################

if args.files is not None and len(args.files) == 0:
if args.files == NO_VAL:
parser.error("-f/--files requires a value")

if args.output is not None and args.execute is not None:
if args.output and args.execute:
parser.error(
"instead of using -o/--output along with -x/--execute, use -ox/--mapex"
)

if args.mapex is not None:
if args.output is not None or args.execute is not None:
if args.mapex:
if args.output or args.execute:
parser.error(
"-ox/--mapex should be used without -o/--output or -x/--execute"
)

args.output = args.mapex
args.execute = args.mapex

if args.output is not None and args.files is None:
if args.output and not args.files:
parser.error("-o/--output must be used with -f/--files")

if args.template is None:
if args.create_batch and not args.files:
parser.error("-b/--create-batch must be used with -f/--files")

if args.template == NO_VAL:
parser.error("used -t/--template with no value")

supported_ext = [".png", ".jpeg", ".jpg", ".webp", ".gif"]

#
# get absolute paths where we need them
#
if args.files is not None:
if args.files:
args.files = [os.path.abspath(path) for path in args.files]
clean_paths = args.files.copy()

Expand All @@ -148,13 +163,13 @@ def parse_cli_args():
parser.error("Filetype {} not supported".format(image_ext))
args.files = clean_paths

if args.output is not None and args.output != NO_VAL:
if args.output and args.output != NO_VAL:
args.output = os.path.abspath(args.output)

if args.execute is not None and args.execute != NO_VAL:
if args.execute and args.execute != NO_VAL:
args.execute = os.path.abspath(args.execute)

if args.undo is not None and args.undo != NO_VAL:
if args.undo and args.undo != NO_VAL:
args.undo = os.path.abspath(args.undo)

return args
110 changes: 102 additions & 8 deletions VisioNomicon/gpt.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from openai import OpenAI
import openai
import json
import io
from pathlib import Path
import os
import requests
import base64
import sys
from constants import API_KEY, NAMING_PROMPT, MODEL

API_KEY = ""
RETRIEVED_JSON = {}


def set_api_key():
Expand All @@ -16,6 +19,98 @@ def set_api_key():
API_KEY = os.environ.get("OPENAI_API_KEY") if API_KEY == "" else API_KEY


def batch(filepaths: list[str], base64_strs: list[str], template: str, data_dir: str):
batch_reqs = []
for filepath, base64_str in zip(filepaths, base64_strs):
batch_reqs.append(
{
"custom_id": filepath,
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": MODEL,
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": NAMING_PROMPT.format(template=template),
},
{
"type": "image_url",
"image_url": {
"url": base64_str,
"detail": "auto",
},
},
],
}
],
"temperature": 0.7,
},
}
)

set_api_key()
bytes_buffer = io.BytesIO()
# write to bytes buffer
# doing this to avoid having to write file to disk then pull back from disk to send
for entry in batch_reqs:
json_line = json.dumps(entry) + "\n"
bytes_buffer.write(json_line.encode("utf-8"))

# reset buffer position to prepare to send
bytes_buffer.seek(0)

file_upload_response = openai.files.create(file=bytes_buffer, purpose="batch")

# create batch request from uploaded requests file
# only 24h completion window is available for now
batch = openai.batches.create(
input_file_id=file_upload_response.id,
endpoint="/v1/chat/completions",
completion_window="24h",
)
# write batch id to file to retrieve later
with open(f"{data_dir}/batch_id", "w") as file:
file.write(batch.id)


def image_to_name_retrieve(image_path: str) -> str:
global RETRIEVED_JSON

if not RETRIEVED_JSON:
# get file_id for completed responses
file_id = ""
# get batch id from file
data_dir = os.environ.get("XDG_DATA_HOME")
data_dir = (
data_dir if data_dir else os.path.abspath("~/.local/share")
) + "/visionomicon/"
with open(f"{data_dir}/batch_id", "r") as f:
file_id = openai.batches.retrieve(f.read()).output_file_id

# could occur if batch not complete yet
if file_id is None:
print("Error during batch retrieval, maybe the job isn't complete yet.")
sys.exit()

try:
# get responses in a json str
response_str = openai.files.content(file_id).content.decode("utf-8")
# output file for responses may be expired or deleted
except openai.NotFoundError:
print("Error during batch retrieval, output file could not be retrieved.")
sys.exit()
# each response in own json
response_jsons = [json.loads(s) for s in response_str.split("\n") if s.strip()]
RETRIEVED_JSON = {s["custom_id"]: s for s in response_jsons}
return RETRIEVED_JSON[image_path]["response"]["body"]["choices"][0]["message"][
"content"
].strip()


def image_to_name(image_path: str, args) -> str:
template: str = args.template

Expand All @@ -34,14 +129,14 @@ def encode_image(image_path: str):
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {API_KEY}"}

payload = {
"model": "gpt-4o",
"model": MODEL,
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": f"Generate a filename for an image by analyzing its content and utilizing a user-provided template. Placeholders enclosed in square brackets (e.g., [Subject], [Color], [Action]) will be used, which represent specific elements to be incorporated in the filename. Replace the placeholders accurately and succinctly with terms pulled from the image content, removing the brackets in the final filename. For instance, if the template reads '[MainSubject]_in_[Setting]', the filename might be 'Cat_in_Garden'. Construct the filename omitting the file extension and any other text. Assure that every placeholder is filled with precise, image-derived information, conforming to typical filename length restrictions. The given template is '{template}'.",
"text": NAMING_PROMPT.format(template=template),
},
{
"type": "image_url",
Expand All @@ -64,7 +159,7 @@ def encode_image(image_path: str):

try:
return response_json["choices"][0]["message"]["content"]
except:
except KeyError:
print("OpenAI Unexpected Response:", response_json["error"]["message"])
i < args.error_retries and print("retrying...\n")

Expand All @@ -79,10 +174,9 @@ def encode_image(image_path: str):

def name_validation(name: str, template: str):
set_api_key()
client = OpenAI()

completion = client.chat.completions.create(
model="gpt-4-1106-preview",
completion = openai.chat.completions.create(
model=MODEL,
messages=[
{
"role": "system",
Expand Down
62 changes: 43 additions & 19 deletions VisioNomicon/main.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,34 @@
import os
import sys
import base64
import json
import copy
import glob
from VisioNomicon.args_handler import *
from VisioNomicon.gpt import *
from constants import get_data_dir
from args_handler import parse_cli_args, NO_VAL
from gpt import (
image_to_name,
name_validation,
image_to_name_retrieve,
batch,
)
from datetime import datetime

DATA_DIR = ""


def main():
# get data dir
global DATA_DIR
DATA_DIR = (
os.environ.get("XDG_DATA_HOME")
if "XDG_DATA_HOME" in os.environ
else os.path.abspath("~/.local/share")
) + "/visionomicon/"

# make data dir if doesn't exist
not os.path.exists(DATA_DIR) and os.makedirs(DATA_DIR)
data_dir = get_data_dir()
if not os.path.exists(data_dir):
os.makedirs(data_dir)

args = parse_cli_args()
if args.create_batch:
create_batch(args)
print("Batch job created.")
return 0

# if creating mapping
if args.files is not None:
if args.files:
new_filepaths: list[str] = generate_mapping(args)

# have new and old, put them together into a json and save
Expand All @@ -33,7 +37,7 @@ def main():
# if executing or undoing
if args.undo or args.execute:
rel_mapping_fp = args.execute if args.execute else args.undo
rename_from_mapping(rel_mapping_fp, args.undo is not None)
rename_from_mapping(rel_mapping_fp, args.undo)


def rename_from_mapping(rel_mapping_fp: str, undo: bool = False):
Expand Down Expand Up @@ -61,7 +65,7 @@ def get_mapping_name(cli_fp: str):
return cli_fp
else:
# Join the directory with the file pattern
file_pattern = os.path.join(DATA_DIR, "*.json")
file_pattern = os.path.join(get_data_dir(), "*.json")

# Get list of files matching the file pattern
files = glob.glob(file_pattern)
Expand All @@ -84,11 +88,23 @@ def save_mapping(args, new_filepaths: list[str]):
def generate_mapping_name(args) -> str:
return (
args.output
if args.output != NO_VAL
else DATA_DIR + datetime.now().strftime("mapping-%Y-%m-%d-%H-%M-%S.json")
if args.output and args.output != NO_VAL
else get_data_dir() + datetime.now().strftime("mapping-%Y-%m-%d-%H-%M-%S.json")
)


def create_batch(args):
base64_strs = []
for fp in args.files:
_, image_ext = os.path.splitext(fp)
with open(fp, "rb") as image_file:
base64_strs.append(
f"data:image/{image_ext};base64,{base64.b64encode(image_file.read()).decode("utf-8")}"
)

batch(args.files, base64_strs, args.template, get_data_dir())


def generate_mapping(args) -> list[str]:
og_filepaths: list[str] = args.files
new_filepaths: list[str] = copy.deepcopy(og_filepaths)
Expand All @@ -97,11 +113,19 @@ def generate_mapping(args) -> list[str]:
slicepoint = new_filepaths[i].rindex("/") + 1
new_filepaths[i] = new_filepaths[i][:slicepoint]

new_fp = ""
new_filename = ""
new_name = ""
image_ext = ""
for i in range(len(og_filepaths)):
image_path = og_filepaths[i]
for j in range(args.validation_retries + 1):
print("Generating name...")
new_name = image_to_name(image_path, args)
new_name = (
image_to_name_retrieve(image_path)
if args.retrieve_batch
else image_to_name(image_path, args)
)
print("Generated name {}".format(new_name))

_, image_ext = os.path.splitext(image_path)
Expand Down

0 comments on commit a5c6a31

Please sign in to comment.