Skip to content

Commit

Permalink
add progress bar
Browse files Browse the repository at this point in the history
  • Loading branch information
skier233 committed May 30, 2024
1 parent 7c05319 commit ec140c9
Showing 1 changed file with 23 additions and 8 deletions.
31 changes: 23 additions & 8 deletions plugins/AITagger/ai_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
import csv
from typing import Any

# plugins don't start in the right directory, let's switch to the local directory
os.chdir(os.path.dirname(os.path.realpath(__file__)))

# ----------------- Setup -----------------
os.chdir(os.path.dirname(os.path.realpath(__file__)))

def install(package):
try:
Expand Down Expand Up @@ -67,6 +65,8 @@ def install(package):
min_durations = {}
required_durations = {}
semaphore = asyncio.Semaphore(config.CONCURRENT_TASK_LIMIT)
progress = 0
increment = 0.0

# ----------------- Main Execution -----------------

Expand Down Expand Up @@ -126,17 +126,21 @@ async def run(json_input, output):
# ----------------- High Level Calls -----------------

async def tag_images():
global increment
images = stash.find_images(f={"tags": {"value":tagme_tag_id, "modifier":"INCLUDES"}}, fragment="id files {path}")
if images:
image_batches = [images[i:i + config.IMAGE_REQUEST_BATCH_SIZE] for i in range(0, len(images), config.IMAGE_REQUEST_BATCH_SIZE)]
increment = 1.0 / len(image_batches)
tasks = [__tag_images(batch) for batch in image_batches]
await asyncio.gather(*tasks)
else:
log.info("No images to tag")


async def tag_scenes():
global increment
scenes = stash.find_scenes(f={"tags": {"value":tagme_tag_id, "modifier":"INCLUDES"}}, fragment="id files {path}")
increment = 1.0 / len(scenes)
if scenes:
tasks = [__tag_scene(scene) for scene in scenes]
await asyncio.gather(*tasks)
Expand All @@ -155,12 +159,15 @@ async def __tag_images(images):
#TODO
try:
server_results = ImageResult(**await process_images_async(imagePaths))
except pydantic.ValidationError as e:
process_server_image_results(server_results, imageIds)
except Exception as e:
log.error(f"Failed to process images: {e}")
add_error_images(imageIds)
stash.update_images({"ids": imageIds, "tag_ids": {"ids": [tagme_tag_id], "mode": "REMOVE"}})
return
process_server_image_results(server_results, imageIds)
finally:
increment_progress()



def process_server_image_results(server_results, imageIds):
results = server_results.result
Expand Down Expand Up @@ -197,12 +204,15 @@ async def __tag_scene(scene):
sceneId = scene['id']
try:
server_result = VideoResult(**await process_video_async(scenePath))
except pydantic.ValidationError as e:
process_server_video_result(server_result, sceneId, scenePath)
except Exception as e:
log.error(f"Failed to process video: {e}")
add_error_scene(sceneId)
stash.update_scenes({"ids": [sceneId], "tag_ids": {"ids": [tagme_tag_id], "mode": "REMOVE"}})
return
process_server_video_result(server_result, sceneId, scenePath)
finally:
increment_progress()


def process_server_video_result(server_result, sceneId, scenePath):
results = server_result.result
Expand Down Expand Up @@ -347,4 +357,9 @@ def parse_csv(file_path):
max_gaps[server_tag] = max_gap
required_durations[server_tag] = required_duration

def increment_progress():
global progress
global increment
progress += increment
log.progress(progress)
asyncio.run(main())

0 comments on commit ec140c9

Please sign in to comment.