Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Help branch patch for video file instead of path #13

Merged
merged 6 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 39 additions & 23 deletions runner/app/pipelines/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,11 @@ def frames_compactor(
def video_shredder(video_data, is_file_path=True) -> np.ndarray:
"""
Extract frames from a video file or in-memory video data and return them as a NumPy array.

Args:
video_data (str or BytesIO): Path to the input video file or in-memory video data.
is_file_path (bool): Indicates if video_data is a file path (True) or in-memory data (False).

video_data (str or BytesIO): Path to the input video file or in-memory video data.
is_file_path (bool): Indicates if video_data is a file path (True) or in-memory data (False).
Returns:
np.ndarray: Array of frames with shape (num_frames, height, width, channels).
np.ndarray: Array of frames with shape (num_frames, height, width, channels).
"""
if is_file_path:
# Handle file-based video input
Expand All @@ -195,32 +193,50 @@ def video_shredder(video_data, is_file_path=True) -> np.ndarray:
with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_file:
temp_file.write(video_data)
temp_file_path = temp_file.name

# Open the temporary video file
video_capture = cv2.VideoCapture(temp_file_path)
# Open the temporary video file
video_capture = cv2.VideoCapture(temp_file_path)

if not video_capture.isOpened():
raise ValueError("Error opening video data")

# Get the video frame rate
fps = video_capture.get(cv2.CAP_PROP_FPS)

# Get the video frame count
frame_count = video_capture.get(cv2.CAP_PROP_FRAME_COUNT)

# Create a list to store the extracted frames
frames = []
success, frame = video_capture.read()

while success:
frames.append(frame)
# Extract frames based on the video frame rate and timing
for i in range(int(frame_count)):
# Get the current video frame timestamp
timestamp = i / fps

# Set the current position of the video capture
video_capture.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000)

# Extract the frame at the current timestamp
success, frame = video_capture.read()

# Add the extracted frame to the list of frames
frames.append(frame)

# Release the video capture
video_capture.release()

# Delete the temporary file if it was created
if not is_file_path:
os.remove(temp_file_path)

# Convert list of frames to a NumPy array
# Convert the list of frames to a NumPy array
frames_array = np.array(frames)

print(f"Extracted {frames_array.shape[0]} frames from video in shape of {frames_array.shape}")

return frames_array


class SafetyChecker:
"""Checks images for unsafe or inappropriate content using a pretrained model.

Expand Down Expand Up @@ -281,13 +297,11 @@ class DirectoryReader:
def __init__(self, dir: str):
self.paths = sorted(
glob.glob(os.path.join(dir, "*")),
key=lambda x: int(os.path.basename(x).split(".")[0]),
key=lambda x: (int(os.path.basename(x).split(".")[0]), x)
)
self.nb_frames = len(self.paths)
self.idx = 0

assert self.nb_frames > 0, "no frames found in directory"

first_img = Image.open(self.paths[0])
self.height = first_img.height
self.width = first_img.width
Expand All @@ -301,14 +315,16 @@ def reset(self):
def get_frame(self):
if self.idx >= self.nb_frames:
return None

path = self.paths[self.idx]
self.idx += 1

img = Image.open(path)
transforms = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])

return transforms(img)
try:
img = Image.open(path)
transforms = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
frame = transforms(img)
self.idx += 1
return frame
except Exception as e:
logger.error(f"Error reading frame {self.idx}: {e}")
return None

class DirectoryWriter:
def __init__(self, dir: str):
Expand Down
34 changes: 22 additions & 12 deletions runner/app/routes/frame_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
)
async def frame_interpolation(
model_id: Annotated[str, Form()] = "",
video: Annotated[UploadFile, File()]=None,
video: Annotated[UploadFile, File()]= None,
inter_frames: Annotated[int, Form()] = 2,
token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
Expand All @@ -64,11 +64,13 @@ async def frame_interpolation(
# Extract frames from video
video_data = await video.read()
frames = video_shredder(video_data, is_file_path=False)

# Save frames to temporary directory
for i, frame in enumerate(frames):
frame_path = os.path.join(temp_input_dir, f"{i}.png")
cv2.imwrite(frame_path, frame)
try:
frame_path = os.path.join(temp_input_dir, f"{i}.png")
cv2.imwrite(frame_path, frame)
except Exception as e:
logger.error(f"Error saving frame {i}: {e}")

# Create DirectoryReader and DirectoryWriter
reader = DirectoryReader(temp_input_dir)
Expand All @@ -81,10 +83,18 @@ async def frame_interpolation(

# Collect output frames
output_frames = []
for frame_path in sorted(glob.glob(os.path.join(temp_output_dir, "*.png"))):
frame = Image.open(frame_path)
output_frames.append(frame)
# Wrap output frames in a list of batches (with a single batch in this case)
path_to_file = sorted(
glob.glob(os.path.join(temp_output_dir, "*")),
key=lambda x: (int(os.path.basename(x).split(".")[0]), x)
)
for frame_in_path in path_to_file:
try:
frame = Image.open(frame_in_path)
output_frames.append(frame)
except Exception as e:
logger.error(f"Error reading frame {frame_in_path}: {e}")

# Wrap output frames in a list of batches (with a single batch in this case)
output_images = [[{"url": image_to_data_url(frame), "seed": 0, "nsfw": False} for frame in output_frames]]

except Exception as e:
Expand All @@ -96,11 +106,11 @@ async def frame_interpolation(
)
finally:
# Clean up temporary directories
for file_path in glob.glob(os.path.join(temp_input_dir, "*")):
os.remove(file_path)
for path_to_file in glob.glob(os.path.join(temp_input_dir, "*")):
os.remove(path_to_file)
os.rmdir(temp_input_dir)
for file_path in glob.glob(os.path.join(temp_output_dir, "*")):
os.remove(file_path)
for path_to_file in glob.glob(os.path.join(temp_output_dir, "*")):
os.remove(path_to_file)
os.rmdir(temp_output_dir)

return {"frames": output_images}
73 changes: 0 additions & 73 deletions runner/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -332,79 +332,6 @@
]
}
},
"/llm-generate": {
"post": {
"summary": "Llm Generate",
"operationId": "llm_generate",
"requestBody": {
"content": {
"application/x-www-form-urlencoded": {
"schema": {
"$ref": "#/components/schemas/Body_llm_generate_llm_generate_post"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/LlmResponse"
}
}
}
},
"400": {
"description": "Bad Request",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPError"
}
}
}
},
"401": {
"description": "Unauthorized",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPError"
}
}
}
},
"500": {
"description": "Internal Server Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPError"
}
}
}
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
}
}
},
"security": [
{
"HTTPBearer": []
}
]
}
},
"/frame-interpolation": {
"post": {
"summary": "Frame Interpolation",
Expand Down
57 changes: 27 additions & 30 deletions worker/runner.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading