Skip to content

Commit

Permalink
Merge pull request #445 from skier233/fixtimestampbug
Browse files Browse the repository at this point in the history
[AI Tagger] Fix zip issues and timestamp issues
  • Loading branch information
Maista6969 authored Oct 1, 2024
2 parents b02d5c2 + 628d0e7 commit b14e196
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 8 deletions.
2 changes: 1 addition & 1 deletion plugins/AITagger/ai_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ async def process_images_async(image_paths, threshold=config.IMAGE_THRESHOLD, re
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=config.SERVER_TIMEOUT)) as session:
return await post_api_async(session, 'process_images/', {"paths": image_paths, "threshold": threshold, "return_confidence": return_confidence})

async def process_video_async(video_path, frame_interval=config.FRAME_INTERVAL,threshold=config.AI_VIDEO_THRESHOLD, return_confidence=True ,vr_video=False):
async def process_video_async(video_path, vr_video=False, frame_interval=config.FRAME_INTERVAL,threshold=config.AI_VIDEO_THRESHOLD, return_confidence=True):
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=config.SERVER_TIMEOUT)) as session:
return await post_api_async(session, 'process_video/', {"path": video_path, "frame_interval": frame_interval, "threshold": threshold, "return_confidence": return_confidence, "vr_video": vr_video})

Expand Down
2 changes: 1 addition & 1 deletion plugins/AITagger/ai_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ async def __tag_scene(scene):
vr_video = media_handler.is_vr_scene(scene.get('tags'))
if vr_video:
log.info(f"Processing VR video {scenePath}")
server_result = await ai_server.process_video_async(mutated_path, vr_video)
server_result = await ai_server.process_video_async(video_path=mutated_path, vr_video=vr_video)
if server_result is None:
log.error("Server returned no results")
media_handler.add_error_scene(sceneId)
Expand Down
2 changes: 1 addition & 1 deletion plugins/AITagger/ai_tagger.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: AI Tagger
description: Tag videos and Images with Locally hosted AI using Skier's Patreon AI models
version: 1.7
version: 1.8
url: https://github.com/stashapp/CommunityScripts/tree/main/plugins/AITagger
exec:
- python
Expand Down
25 changes: 20 additions & 5 deletions plugins/AITagger/media_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def initialize(connection):
# ----------------- Tag Methods -----------------


tag_categories = ["actions", "bodyparts"]
tag_categories = ["actions", "bodyparts", "bdsm", "clothing", "describingperson", "environment", "describingbody", "describingimage", "describingscene", "sextoys"]

def get_all_tags_from_server_result(result):
alltags = []
Expand Down Expand Up @@ -91,7 +91,11 @@ def remove_tagme_tags_from_images(image_ids):
def add_tags_to_image(image_id, tag_ids):
stash.update_images({"ids": [image_id], "tag_ids": {"ids": tag_ids, "mode": "ADD"}})

worker_counter = 0

def get_image_paths_and_ids(images):
global worker_counter
counter_updated = False
imagePaths = []
imageIds = []
temp_files = []
Expand All @@ -100,14 +104,25 @@ def get_image_paths_and_ids(images):
imagePath = image['files'][0]['path']
imageId = image['id']
if '.zip' in imagePath:
if not counter_updated:
worker_counter += 1
counter_updated = True
zip_index = imagePath.index('.zip') + 4
zip_path, img_path = imagePath[:zip_index], imagePath[zip_index+1:].replace('\\', '/')

# Create a unique temporary directory for this worker
temp_dir = os.path.join(config.temp_image_dir, f"worker_{worker_counter}")
os.makedirs(temp_dir, exist_ok=True)

temp_path = os.path.join(temp_dir, img_path)
os.makedirs(os.path.dirname(temp_path), exist_ok=True)

with zipfile.ZipFile(zip_path, 'r') as zip_ref:
temp_path = os.path.join(config.temp_image_dir, img_path)
os.makedirs(os.path.dirname(temp_path), exist_ok=True)
zip_ref.extract(img_path, config.temp_image_dir)
zip_ref.extract(img_path, temp_dir)
imagePath = os.path.abspath(os.path.normpath(temp_path))
temp_files.append(imagePath)

temp_files.append(temp_path)
temp_files.append(temp_dir) # Ensure the temp directory is also added to temp_files
imagePaths.append(imagePath)
imageIds.append(imageId)
except IndexError:
Expand Down

0 comments on commit b14e196

Please sign in to comment.