diff --git a/runner/app/pipelines/utils/utils.py b/runner/app/pipelines/utils/utils.py index 9de0b0c2..7a4a23c7 100644 --- a/runner/app/pipelines/utils/utils.py +++ b/runner/app/pipelines/utils/utils.py @@ -178,13 +178,11 @@ def frames_compactor( def video_shredder(video_data, is_file_path=True) -> np.ndarray: """ Extract frames from a video file or in-memory video data and return them as a NumPy array. - Args: - video_data (str or BytesIO): Path to the input video file or in-memory video data. - is_file_path (bool): Indicates if video_data is a file path (True) or in-memory data (False). - + video_data (str or BytesIO): Path to the input video file or in-memory video data. + is_file_path (bool): Indicates if video_data is a file path (True) or in-memory data (False). Returns: - np.ndarray: Array of frames with shape (num_frames, height, width, channels). + np.ndarray: Array of frames with shape (num_frames, height, width, channels). """ if is_file_path: # Handle file-based video input @@ -195,32 +193,50 @@ def video_shredder(video_data, is_file_path=True) -> np.ndarray: with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_file: temp_file.write(video_data) temp_file_path = temp_file.name - - # Open the temporary video file - video_capture = cv2.VideoCapture(temp_file_path) + # Open the temporary video file + video_capture = cv2.VideoCapture(temp_file_path) if not video_capture.isOpened(): raise ValueError("Error opening video data") + # Get the video frame rate + fps = video_capture.get(cv2.CAP_PROP_FPS) + + # Get the video frame count + frame_count = video_capture.get(cv2.CAP_PROP_FRAME_COUNT) + + # Create a list to store the extracted frames frames = [] - success, frame = video_capture.read() - while success: - frames.append(frame) + # Extract frames based on the video frame rate and timing + for i in range(int(frame_count)): + # Get the current video frame timestamp + timestamp = i / fps + + # Set the current position of the video capture + video_capture.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000) + + # Extract the frame at the current timestamp success, frame = video_capture.read() + + # Add the extracted frame to the list of frames + frames.append(frame) + # Release the video capture video_capture.release() - + # Delete the temporary file if it was created if not is_file_path: os.remove(temp_file_path) - # Convert list of frames to a NumPy array + # Convert the list of frames to a NumPy array frames_array = np.array(frames) + print(f"Extracted {frames_array.shape[0]} frames from video in shape of {frames_array.shape}") return frames_array + class SafetyChecker: """Checks images for unsafe or inappropriate content using a pretrained model. @@ -281,13 +297,11 @@ class DirectoryReader: def __init__(self, dir: str): self.paths = sorted( glob.glob(os.path.join(dir, "*")), - key=lambda x: int(os.path.basename(x).split(".")[0]), + key=lambda x: (int(os.path.basename(x).split(".")[0]), x) ) self.nb_frames = len(self.paths) self.idx = 0 - assert self.nb_frames > 0, "no frames found in directory" - first_img = Image.open(self.paths[0]) self.height = first_img.height self.width = first_img.width @@ -301,14 +315,16 @@ def reset(self): def get_frame(self): if self.idx >= self.nb_frames: return None - path = self.paths[self.idx] - self.idx += 1 - - img = Image.open(path) - transforms = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]) - - return transforms(img) + try: + img = Image.open(path) + transforms = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]) + frame = transforms(img) + self.idx += 1 + return frame + except Exception as e: + logger.error(f"Error reading frame {self.idx}: {e}") + return None class DirectoryWriter: def __init__(self, dir: str): diff --git a/runner/app/routes/frame_interpolation.py b/runner/app/routes/frame_interpolation.py index d11458e9..a5bcdb83 100644 --- a/runner/app/routes/frame_interpolation.py +++ b/runner/app/routes/frame_interpolation.py @@ -37,7 +37,7 @@ ) async def frame_interpolation( model_id: Annotated[str, Form()] = "", - video: Annotated[UploadFile, File()]=None, + video: Annotated[UploadFile, File()]= None, inter_frames: Annotated[int, Form()] = 2, token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)), ): @@ -64,11 +64,13 @@ async def frame_interpolation( # Extract frames from video video_data = await video.read() frames = video_shredder(video_data, is_file_path=False) - # Save frames to temporary directory for i, frame in enumerate(frames): - frame_path = os.path.join(temp_input_dir, f"{i}.png") - cv2.imwrite(frame_path, frame) + try: + frame_path = os.path.join(temp_input_dir, f"{i}.png") + cv2.imwrite(frame_path, frame) + except Exception as e: + logger.error(f"Error saving frame {i}: {e}") # Create DirectoryReader and DirectoryWriter reader = DirectoryReader(temp_input_dir) @@ -81,10 +83,18 @@ async def frame_interpolation( # Collect output frames output_frames = [] - for frame_path in sorted(glob.glob(os.path.join(temp_output_dir, "*.png"))): - frame = Image.open(frame_path) - output_frames.append(frame) - # Wrap output frames in a list of batches (with a single batch in this case) + path_to_file = sorted( + glob.glob(os.path.join(temp_output_dir, "*")), + key=lambda x: (int(os.path.basename(x).split(".")[0]), x) + ) + for frame_in_path in path_to_file: + try: + frame = Image.open(frame_in_path) + output_frames.append(frame) + except Exception as e: + logger.error(f"Error reading frame {frame_in_path}: {e}") + + # Wrap output frames in a list of batches (with a single batch in this case) output_images = [[{"url": image_to_data_url(frame), "seed": 0, "nsfw": False} for frame in output_frames]] except Exception as e: @@ -96,11 +106,11 @@ async def frame_interpolation( ) finally: # Clean up temporary directories - for file_path in glob.glob(os.path.join(temp_input_dir, "*")): - os.remove(file_path) + for path_to_file in glob.glob(os.path.join(temp_input_dir, "*")): + os.remove(path_to_file) os.rmdir(temp_input_dir) - for file_path in glob.glob(os.path.join(temp_output_dir, "*")): - os.remove(file_path) + for path_to_file in glob.glob(os.path.join(temp_output_dir, "*")): + os.remove(path_to_file) os.rmdir(temp_output_dir) return {"frames": output_images} diff --git a/runner/openapi.json b/runner/openapi.json index e55aa404..f3ee57dc 100644 --- a/runner/openapi.json +++ b/runner/openapi.json @@ -332,79 +332,6 @@ ] } }, - "/llm-generate": { - "post": { - "summary": "Llm Generate", - "operationId": "llm_generate", - "requestBody": { - "content": { - "application/x-www-form-urlencoded": { - "schema": { - "$ref": "#/components/schemas/Body_llm_generate_llm_generate_post" - } - } - }, - "required": true - }, - "responses": { - "200": { - "description": "Successful Response", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/LlmResponse" - } - } - } - }, - "400": { - "description": "Bad Request", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/HTTPError" - } - } - } - }, - "401": { - "description": "Unauthorized", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/HTTPError" - } - } - } - }, - "500": { - "description": "Internal Server Error", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/HTTPError" - } - } - } - }, - "422": { - "description": "Validation Error", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/HTTPValidationError" - } - } - } - } - }, - "security": [ - { - "HTTPBearer": [] - } - ] - } - }, "/frame-interpolation": { "post": { "summary": "Frame Interpolation", diff --git a/worker/runner.gen.go b/worker/runner.gen.go index 85b4b590..cd81856c 100644 --- a/worker/runner.gen.go +++ b/worker/runner.gen.go @@ -1659,36 +1659,33 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl // Base64 encoded, gzipped, json marshaled Swagger object var swaggerSpec = []string{ - "H4sIAAAAAAAC/+xaX2/juBH/KgTbRye2c5du4bdNercXNLsXrJ3rwyIwGGls8yKRKkkldgN/94JD/aH+", - "RTaSuEDqJ1vScOY3w/kNh5SeaSDjRAoQRtPJM9XBCmKGfz/fXP2ilFT2f6JkAspwwCexXtofw00EdEK/", - "6iUdULNJ7IU2iosl3W4HVMG/U64gpJMfOORuUAwpdBfj5P2fEBi6HdALGW7mLA25nBs5N7A2tatEatME", - "hTL2z0KqmBk6ofdcMLWhnlUUaUAd0FiGEM15aIeHsGBpZMd7I79aAXIV9vrpUHie7uZNVxgWisUw58KA", - "SmTEDJei9V57SFDGyeuKa2clvisrQ351MgUMO3IJ6pWxGdBHHkLPpPyBIm1xrYRwr0h0hZPHbAk28u5P", - "7bI9iMuUh0wEMNcBs3C8KHw6PS9RfsnkyBTlCggije9dJNHKy8G4QpGWQDqEL2AZ+1hQDelH9Kq5FbBk", - "hj/CPFEyTkynjm+ZHLlxcm2q0tjNgZ4noNoUjj19aUzQQU1uQDW0eqmLasUCFGDMDCRVGoxHo5raXJhM", - "UbhNaQkuH9ntl2YLMJt5sILgoWLZqBRK01MUI5coVqi5lzICJlAPQOhbnNrrNnDaKBBLs6oYG53+3bOV", - "SzTSoVbJktwrl7b1krYDlXpZiNWhftnOwkVt6v5Wwvm1Y6JWwJerahadf/LG/eaetw19DVNfxalYYhG7", - "T4MHMHUl47NPvhYrSS5QsqLNJ4DkGuYsXc47EmPkrQTfrDD5nC5Jd470c+rsfH9KHZwmTzyshWI8Ovu5", - "tPQvfN4cWaNIDzO607uLGVEUz5cgQDED1Yt2VsRsPTfyAYSu9GNsTWbu7psv6HtVv402EM9r3eIU75LW", - "pnFADcSJ9ThV4A+aebd3LFz1aemJbdeUpAmutcVvR6f1vyoYfXT8dP6hVrj91qjWuWuZ6N9ms5uOrU4I", - "hvHI/vurggWd0L8Myw3TMNstDYvtTB1gNtwDVtrqAPIHi3iIvWwvJG4g1n3Y6vq8xvofTlMBhCnFNpXW", - "uw1QG25gkVld5klQxasNM2k1K+nv/6R+S4ICL20BfAMt9pFb30EnUmjoYKfeOWJfIeTMj5PrNtvi1FgN", - "tD/XVVgtuK+juBu18p7kGpvKvNKJJX+e6iqT3EpAbvUuhFKefk+d55MPucUjF7uGL0IvnnxQ3+z1q5bw", - "VEW+3K2KevfmKcpopxEReX454C0ezWBtuicpWKXiYffUQnE/tS7d+Hpq2YVwbaor4Nr0emicUAbK867i", - "RIeTM4n5esMUc4681z64bMx3aMX/z7eo5x9th1q03nv22s2urpmzLYndu5pGMqiwl4nN7ws6+fHciNVz", - "A+KdR+RrGaCZFirXj0tB645W0N0oRREzmdm7fdS3fjhTmaQXqR1WcDyF6y5z5RliEag9l9J6ecs377WD", - "x/alNTN/Vzs1fKmiuUrbcGS3smrtxKANixPfVQ/3rHjeA934gtaY54TD2ACPdApSxc1mauPokNtW7AKY", - "AlWc0yMH3a1CycqYhG6tDi4W0lFaB4onmJwT+lkQliQRd9lKjCQqFeTzFUl4AhEXbjLypOaPkAAo+/x7", - "KgQaegSlna7R6fh0ZKMlExAs4XRCf8JbA5ows0LYQzztPjHyJA99voOSuKPjUlyF+dn8TGbzYSMI2tgu", - "HldZKQwIHBWnkeEJU2Zot1onITOsfG/Rl467HcZvq3NoKyHecMmGXp2NRjVcXlCHf2obnl1BVdZmtF2d", - "sWkaBKD1Io1IKTagP78hhHJT0mL/goXku5sPZ3d8GLu3gqVmJRX/D4RoePzTYQxnzpJfhOFmQ2ZSkmum", - "li7qZ2dvCqKxO2vCKUVIsYM7P9Tk4+shwSIyBfUIipTb3LxE4VrpF6cfd9u7AdVpHDO1yZlNZpIgt+3Q", - "YcuLm+7KgEvEVUX2VQWCRVG2wveWir1eOm1rxw82nO9ZN6pL4LFwdBeOI2f35SxyjlRJh8xd4UEM7geh", - "havunIa+Y9b7J0G75vzWdy2DiN7ghs72JsX5bXsJwk1Gttd45+5kh/dqB+5PqqdYxzpzrDNvV2fchwoz", - "6U5LaqQsPh55kZT59yMHIWX3K70Dk/K4+B9J+e6kdNRCUkZRfJK/MO2m5HUUf8mFXmKk7/v65Onp6QSZ", - "maoIRCBDd5S4Bz973u0emJv+K5IjM4/MfDtmXkcxKQiGvDSwNjs0sN5Z+c7E3P8Uq3oaf2xTj7z7ILyz", - "yV3rUrNvSropd5sJvG9n2vqJy5F5R+Z9EOblLNq6UVaNxkFVS8WLqstIpiG5lHGcCm425Asz8MQ2NPtg", - "BF+P6clwGCpg8cnSPT2NsuGngR1Ot3fb/wYAAP//yh4IQxgzAAA=", + "H4sIAAAAAAAC/+xZW2/buBL+KwTPeXRiO21ODvyWZHsxtpegdrsPRWAw0thmK5FakkrrDfzfFxzqQslS", + "ZMONF8j6yRcNZ765fENy9EADGSdSgDCajh6oDpYQM/x6eTN+pZRU9nuiZALKcMAnsV7YD8NNBHRE3+sF", + "7VGzSuwPbRQXC7pe96iCP1OuIKSjr7jktlcsKXQX6+TdNwgMXffolQxXM5aGXM6MnBn4aWq/EqnNJiiU", + "sV/mUsXM0BG944KpFfWsosgG1B6NZQjRjId2eQhzlkZ2vbfyvRUg47DTT4fC83Q7b9rCMFcshhkXBlQi", + "I2a4FI3/NYcEZZy8rrh2VuIbWxny2skUMOzKBag9Y9Oj9zyEjqR8QZGmuFZCuFMk2sLJY7YAG3n3pfaz", + "OYiLlIdMBDDTAbNwvChcnJ6XKN9kcmSCcgUEkcZ3LpJo5fFgjFGkIZAO4SNYhj4WVEO6Ee2VWwELZvg9", + "zBIl48S06viQyZEbJ9ekKo1dDvQsAdWkcOjpS2OCDmpyA2pDq1e6qFbMQQHGzEBSpcFwMKipzYXJBIWb", + "lJbg8pXtfmk2B7OaBUsIvlcsG5VCaXqCYuQaxQo1d1JGwATqAQh9ixP7uwmcNgrEwiwrxgan//ds5RIb", + "5VDrZEnulSvbekvbgkqdLMTuUP/ZzMJ5LXX/K+G8bknUEvhiWa2i8wtv3Vv3vGnpPkzdi1OxxCZ2lwbf", + "wdSVDM8ufC1WklyhZEWbTwDJNcxYupi1FMbA2wk+WGFymS5Ie410c+rsfHdKHZwmP3hYC8VwcPaytPQH", + "Pt9cWaNIBzPay7uNGWmCjb34bNnW/6nq7Mr9xfmzaqe7NcTG3DUk+u10etNyrg7BMB7Zb/9VMKcj+p9+", + "eTrvZ0fzfnF2rgPMlnvASlstQL6wiId4cOqExA3EugtbXZ93ivvNaSqAMKXYqnLOawLUhBtYZJbXeRFU", + "8WrDTFqtSvrxd+rvfyjw2HnTN9BgH7n1CXQihYYWduqtI/YeQs78OLmjTVOcNlqP9nNdhdWA21nawCv0", + "/IdPhg/2917dNVWRL/dZRZ3XphRltNOIiDzPHPAGj6bw07QnIlim4vv2iUBxPxHXbn09ET1qr22+gxZG", + "p4fGCWWgPO8qTrQ4OZWY3RummHPkqa4o5Zlpi1PSv/z2cP7cLg/FqWjHY1DmVK2mqzXbUNide08kgwp7", + "mVh9nNPR14eNWD1sQLz1iPxOBmimgcr1SRZo3XJwcn+UooiZTO2/XdS3fjhTmaQXqS32OxyQtLe5crxT", + "BGrHjafe3vJ7VW0m1LwRZeZvawOdxzqa67QbjmzXVq2dGLRhceK76uGeFs87oBtf0BrznHAYN8AjnYJU", + "cbOa2Dg65PbgcgVMgSpGqMhB91ehZGlMQtdWBxdz6SitA8UTLM4RvRSEJUnEXbUSI4lKBbkck4QnEHHh", + "kpEXNb+HBEDZ559SIdDQPSjtdA1Oh6cDGy2ZgGAJpyP6Av/q0YSZJcLu4yDyxMiTPPT5fcOmBUGMw3xs", + "OpVZPmwEQRt75sVdVgoDAlfFaWR4wpTp24vJScgMK0fKXeW43Zx0Xc2h7YT4hys29OpsMKjh8oLa/6Zt", + "eLYFVdmb0XY1Y5M0CEDreRqRUqxHX/5CCOURvsH+FQvJJ5cPZ3d4GLufBUvNUir+F4RoePjiMIYzZ8kr", + "YbhZkamU5B1TCxf1s7NfCmLjLrMJpxQhxX3n/FDJx8m9YBGZgLoHRcpLYd6icK/0m9PX2/Vtj+o0jpla", + "5cwmU0mQ23ZpH/v5SWWm3t4ZcIsYV2T3ahAsirIdvrNV7PQ+YF27rNtwPmXfqG6Bx8bR3jiOnN2Vs8g5", + "UiUdMneJYwu8D0IDV91Ugz5h1ftzk21rfu27lkFEb/BCZ88mxbSzuQXhJSO7azzx6WSLVx4HPp9UZz7H", + "PnPsM7+uz7h3yFPppiU1Uhbv9R8lZf5q/yCkbH/bcmBSHjf/IymfnJSOWkhKezveYqP0ZnKPUnK/23J1", + "6nfcDo/MeybMs8Vd2w2zN73tlPucCTztDtj44vnIvCPzngnzchat3SqrRuOiqqViIH4dyTQk1zKOU8HN", + "irxhBn6wFc1eTOMYXo/6/VABi08W7ulplC0/Dexyur5d/x0AAP//O5cYMRstAAA=", } // GetSwagger returns the content of the embedded swagger specification file