From f7fb0febafd295920170682cde4cd4044e6e5ac7 Mon Sep 17 00:00:00 2001 From: igolaizola <11333576+igolaizola@users.noreply.github.com> Date: Wed, 24 Jan 2024 00:02:40 +0100 Subject: [PATCH] Add runway extend feature --- README.md | 7 ++++ cmd/vidai/main.go | 10 +---- pkg/runway/runway.go | 57 +++++++++++++++---------- pkg/runway/runway_test.go | 87 +++++++++++++++++++++++++++++++++++++++ vidai.go | 36 ++++++++-------- 5 files changed, 149 insertions(+), 48 deletions(-) create mode 100644 pkg/runway/runway_test.go diff --git a/README.md b/README.md index c95717f..ba15002 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,7 @@ This is a CLI tool for [RunwayML Gen-2](https://runwayml.com/) that adds some ex ## 🚀 Features - Generate videos directly from the command line using a text or image prompt. + - Use RunwayML's extend feature to generate longer videos. - Create or extend videos longer than 4 seconds by reusing the last frame of the video as the input for the next generation. - Other handy tools to edit videos, like generating loops or resizing videos. 
@@ -42,6 +43,12 @@ Generate a video from a text prompt: vidai generate --token RUNWAYML_TOKEN --text "a car in the middle of the road" --output car.mp4 ``` +Generate a video from an image prompt and extend it twice (using RunwayML's extend feature): + +```bash +vidai generate --token RUNWAYML_TOKEN --image car.jpg --output car.mp4 --extend 2 +``` + Extend a video by reusing the last frame twice: ```bash diff --git a/cmd/vidai/main.go b/cmd/vidai/main.go index 84d65a3..76e29c1 100644 --- a/cmd/vidai/main.go +++ b/cmd/vidai/main.go @@ -114,18 +114,12 @@ func newGenerateCommand() *ffcli.Command { return fmt.Errorf("image or text is required") } c := vidai.New(&cfg) - urls, err := c.Generate(ctx, *image, *text, *output, *extend, + u, err := c.Generate(ctx, *image, *text, *output, *extend, *interpolate, *upscale, *watermark) if err != nil { return err } - if len(urls) == 1 { - fmt.Printf("Video URL: %s\n", urls[0]) - } else { - for i, u := range urls { - fmt.Printf("Video URL %d: %s\n", i+1, u) - } - } + fmt.Printf("Video URL: %s\n", u) return nil }, } diff --git a/pkg/runway/runway.go b/pkg/runway/runway.go index 7a3f2f6..4c0264b 100644 --- a/pkg/runway/runway.go +++ b/pkg/runway/runway.go @@ -198,14 +198,18 @@ type createTaskRequest struct { } type gen2Options struct { - Interpolate bool `json:"interpolate"` - Seed int `json:"seed"` - Upscale bool `json:"upscale"` - TextPrompt string `json:"text_prompt"` - Watermark bool `json:"watermark"` - ImagePrompt string `json:"image_prompt"` - InitImage string `json:"init_image"` - Mode string `json:"mode"` + Interpolate bool `json:"interpolate"` + Seed int `json:"seed"` + Upscale bool `json:"upscale"` + TextPrompt string `json:"text_prompt"` + Watermark bool `json:"watermark"` + ImagePrompt string `json:"image_prompt,omitempty"` + InitImage string `json:"init_image,omitempty"` + Mode string `json:"mode"` + InitVideo string `json:"init_video,omitempty"` + MotionScore int `json:"motion_score"` + UseMotionScore bool 
`json:"use_motion_score"` + UseMotionVectors bool `json:"use_motion_vectors"` } type taskResponse struct { @@ -243,7 +247,7 @@ type artifact struct { ParentAssetGroupId string `json:"parentAssetGroupId"` Filename string `json:"filename"` URL string `json:"url"` - FileSize int `json:"fileSize"` + FileSize string `json:"fileSize"` IsDirectory bool `json:"isDirectory"` PreviewURLs []string `json:"previewUrls"` Private bool `json:"private"` @@ -251,13 +255,13 @@ type artifact struct { Deleted bool `json:"deleted"` Reported bool `json:"reported"` Metadata struct { - FrameRate int `json:"frameRate"` - Duration int `json:"duration"` - Dimensions []int `json:"dimensions"` + FrameRate int `json:"frameRate"` + Duration float32 `json:"duration"` + Dimensions []int `json:"dimensions"` } `json:"metadata"` } -func (c *Client) Generate(ctx context.Context, imageURL, textPrompt string, interpolate, upscale, watermark bool) (string, error) { +func (c *Client) Generate(ctx context.Context, assetURL, textPrompt string, interpolate, upscale, watermark, extend bool) (string, error) { // Load team ID if err := c.loadTeamID(ctx); err != nil { return "", fmt.Errorf("runway: couldn't load team id: %w", err) @@ -266,6 +270,14 @@ func (c *Client) Generate(ctx context.Context, imageURL, textPrompt string, inte // Generate seed seed := rand.Intn(1000000000) + var imageURL string + var videoURL string + if extend { + videoURL = assetURL + } else { + imageURL = assetURL + } + // Create task createReq := &createTaskRequest{ TaskType: "gen2", @@ -279,14 +291,17 @@ func (c *Client) Generate(ctx context.Context, imageURL, textPrompt string, inte }{ Seconds: 4, Gen2Options: gen2Options{ - Interpolate: interpolate, - Seed: seed, - Upscale: upscale, - TextPrompt: textPrompt, - Watermark: watermark, - ImagePrompt: imageURL, - InitImage: imageURL, - Mode: "gen2", + Interpolate: interpolate, + Seed: seed, + Upscale: upscale, + TextPrompt: textPrompt, + Watermark: watermark, + ImagePrompt: imageURL, + 
InitImage: imageURL, + InitVideo: videoURL, + Mode: "gen2", + UseMotionScore: true, + MotionScore: 22, }, Name: fmt.Sprintf("Gen-2, %d", seed), AssetGroupName: "Gen-2", diff --git a/pkg/runway/runway_test.go b/pkg/runway/runway_test.go new file mode 100644 index 0000000..83272de --- /dev/null +++ b/pkg/runway/runway_test.go @@ -0,0 +1,87 @@ +package runway + +import ( + "encoding/json" + "testing" +) + +func TestUnmarshal(t *testing.T) { + js := `{ + "task": { + "id": "00000000-0000-0000-0000-000000000000", + "name": "Gen-2, 100000", + "image": null, + "createdAt": "2024-01-01T01:01:01.001Z", + "updatedAt": "2024-01-01T01:01:01.001Z", + "taskType": "gen2", + "options": { + "seconds": 4, + "gen2Options": { + "interpolate": true, + "seed": 100000, + "upscale": true, + "text_prompt": "", + "watermark": false, + "image_prompt": "https://a.url.test", + "init_image": "https://a.url.test", + "mode": "gen2", + "motion_score": 22, + "use_motion_score": true, + "use_motion_vectors": false + }, + "name": "Gen-2, 100000", + "assetGroupName": "Gen-2", + "exploreMode": false, + "recordingEnabled": true + }, + "status": "SUCCEEDED", + "error": null, + "progressText": null, + "progressRatio": "1", + "placeInLine": null, + "estimatedTimeToStartSeconds": null, + "artifacts": [ + { + "id": "00000000-0000-0000-0000-000000000000", + "createdAt": "2024-01-01T01:01:01.001Z", + "updatedAt": "2024-01-01T01:01:01.001Z", + "userId": 100000, + "createdBy": 100000, + "taskId": "00000000-0000-0000-0000-000000000000", + "parentAssetGroupId": "00000000-0000-0000-0000-000000000000", + "filename": "Gen-2, 100000.mp4", + "url": "https://a.url.test", + "fileSize": "100000", + "isDirectory": false, + "previewUrls": [ + "https://a.url.test", + "https://a.url.test", + "https://a.url.test", + "https://a.url.test" + ], + "private": true, + "privateInTeam": true, + "deleted": false, + "reported": false, + "metadata": { + "frameRate": 24, + "duration": 8.1, + "dimensions": [ + 2816, + 1536 + ], + "size": { 
+ "width": 2816, + "height": 1536 + } + } + } + ], + "sharedAsset": null + } +}` + var resp taskResponse + if err := json.Unmarshal([]byte(js), &resp); err != nil { + t.Fatal(err) + } +} diff --git a/vidai.go b/vidai.go index ac833dc..116bec7 100644 --- a/vidai.go +++ b/vidai.go @@ -48,10 +48,10 @@ func New(cfg *Config) *Client { // Generate generates a video from an image and a text prompt. func (c *Client) Generate(ctx context.Context, image, text, output string, - extend int, interpolate, upscale, watermark bool) ([]string, error) { + extend int, interpolate, upscale, watermark bool) (string, error) { b, err := os.ReadFile(image) if err != nil { - return nil, fmt.Errorf("vidai: couldn't read image: %w", err) + return "", fmt.Errorf("vidai: couldn't read image: %w", err) } name := filepath.Base(image) @@ -59,12 +59,20 @@ func (c *Client) Generate(ctx context.Context, image, text, output string, if image != "" { imageURL, err = c.client.Upload(ctx, name, b) if err != nil { - return nil, fmt.Errorf("vidai: couldn't upload image: %w", err) + return "", fmt.Errorf("vidai: couldn't upload image: %w", err) } } - videoURL, err := c.client.Generate(ctx, imageURL, text, interpolate, upscale, watermark) + videoURL, err := c.client.Generate(ctx, imageURL, text, interpolate, upscale, watermark, false) if err != nil { - return nil, fmt.Errorf("vidai: couldn't generate video: %w", err) + return "", fmt.Errorf("vidai: couldn't generate video: %w", err) + } + + // Extend video + for i := 0; i < extend; i++ { + videoURL, err = c.client.Generate(ctx, videoURL, "", interpolate, upscale, watermark, true) + if err != nil { + return "", fmt.Errorf("vidai: couldn't extend video: %w", err) + } } // Use temp file if no output is set and we need to extend the video @@ -77,24 +85,14 @@ func (c *Client) Generate(ctx context.Context, image, text, output string, // Download video if videoPath != "" { if err := c.download(ctx, videoURL, videoPath); err != nil { - return nil, fmt.Errorf("vidai: 
couldn't download video: %w", err) - } - } - - // Extend video - if extend > 0 { - extendURLs, err := c.Extend(ctx, videoPath, output, extend, - interpolate, upscale, watermark) - if err != nil { - return nil, fmt.Errorf("vidai: couldn't extend video: %w", err) + return "", fmt.Errorf("vidai: couldn't download video: %w", err) } - return append([]string{output}, extendURLs...), nil } - return []string{videoURL}, nil + return videoURL, nil } -// Extend extends a video using the last frame of the previous video. +// Extend extends a video using the previous video. func (c *Client) Extend(ctx context.Context, input, output string, n int, interpolate, upscale, watermark bool) ([]string, error) { base := strings.TrimSuffix(filepath.Base(input), filepath.Ext(input)) @@ -133,7 +131,7 @@ func (c *Client) Extend(ctx context.Context, input, output string, n int, if err != nil { return nil, fmt.Errorf("vidai: couldn't upload image: %w", err) } - videoURL, err := c.client.Generate(ctx, imageURL, "", interpolate, upscale, watermark) + videoURL, err := c.client.Generate(ctx, imageURL, "", interpolate, upscale, watermark, false) if err != nil { return nil, fmt.Errorf("vidai: couldn't generate video: %w", err) }