Skip to content

Commit

Permalink
feat: add fal tool for image generation
Browse files Browse the repository at this point in the history
  • Loading branch information
hwzhuhao committed Oct 11, 2024
1 parent 1c1e008 commit 326547a
Show file tree
Hide file tree
Showing 13 changed files with 738 additions and 23 deletions.
7 changes: 7 additions & 0 deletions api/core/tools/provider/builtin/fal/_assets/icon.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
17 changes: 17 additions & 0 deletions api/core/tools/provider/builtin/fal/fal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import requests

from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class FalProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
url = "https://queue.fal.run/fal-ai/flux/schnell"
headers = {
"Content-Type": "application/json",
"Authorization": f"Key {credentials.get('fal_api_key')}",
}
data = {"prompt": "cute girl, blue eyes, white hair, anime style."}
response = requests.post(url, headers=headers, data=data)
if response.status_code != 200:
raise ToolProviderCredentialValidationError("Fal API key is invalid")
21 changes: 21 additions & 0 deletions api/core/tools/provider/builtin/fal/fal.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
identity:
author: zhuhao
name: fal
label:
en_US: Fal
zh_CN: Fal
description:
en_US: The image generation API provided by fal.ai includes Flux, AuraFlow, Stable Diffusion and Kolors models.
zh_CN: Fal提供的图片生成API, 包含Flux,AuraFlow,Stable Diffusion和Kolors模型。
icon: icon.svg
tags:
- image
credentials_for_provider:
fal_api_key:
type: secret-input
required: true
label:
en_US: Fal API Key
placeholder:
en_US: Please input your Fal API key
url: https://fal.ai/dashboard/keys
35 changes: 35 additions & 0 deletions api/core/tools/provider/builtin/fal/tools/auraflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Any, Union

import fal_client

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class AuraFlowTool(BuiltinTool):
"""
A tool for generating image via Fal.ai AuraFlow model
"""

def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
api_key = self.runtime.credentials.get("fal_api_key", "")
if not api_key:
return self.create_text_message("Please input fal api key")
client = fal_client.SyncClient(key=api_key)
payload = {
"prompt": tool_parameters.get("prompt"),
"num_images": tool_parameters.get("num_images", 1),
"seed": tool_parameters.get("seed"),
"guidance_scale": tool_parameters.get("guidance_scale", 3.5),
"num_inference_steps": tool_parameters.get("num_inference_steps", 50),
"expand_prompt": tool_parameters.get("expand_prompt", True),
}
handler = client.submit("fal-ai/aura-flow", arguments=payload)
request_id = handler.request_id
res = client.result("fal-ai/aura-flow", request_id)
result = []
for image in res.get("images", []):
result.append(self.create_image_message(image=image.get("url"), save_as=self.VariableKey.IMAGE.value))
return result
78 changes: 78 additions & 0 deletions api/core/tools/provider/builtin/fal/tools/auraflow.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
identity:
name: auraflow
author: zhuhao
label:
en_US: AuraFlow
icon: icon.svg
description:
human:
en_US: Generate image via Fal.ai AuraFlow model.
llm: This tool is used to generate image from prompt via Fal.ai AuraFlow model.
parameters:
- name: prompt
type: string
required: true
label:
en_US: prompt
zh_Hans: 提示词
human_description:
en_US: The text prompt used to generate the image.
zh_Hans: 建议用英文的生成图片提示词以获得更好的生成效果。
llm_description: this prompt text will be used to generate image.
form: llm
- name: num_inference_steps
type: number
required: true
default: 28
min: 1
max: 100
label:
en_US: Num Inference Steps
zh_Hans: 生成图片的步数
form: form
human_description:
en_US: The number of inference steps to perform. More steps produce higher quality but take longer.
zh_Hans: 执行的推理步骤数量。更多的步骤可以产生更高质量的结果,但需要更长的时间。
- name: seed
type: number
min: 0
max: 9999999999
label:
en_US: Seed
zh_Hans: 种子
human_description:
en_US: The same seed and prompt can produce similar images.
zh_Hans: 相同的种子和提示可以产生相似的图像。
form: form
- name: guidance_scale
type: number
required: false
default: 3.5
label:
en_US: Guidance Scale
zh_Hans: 引导缩放比例
form: form
human_description:
en_US: The Guidance scale is a measure of how close you want the model to stick to your prompt when looking for a related image to you.
zh_Hans: 无分类器引导比例是衡量在寻找相关图像展示时最贴近提示的一个尺度.
- name: num_images
type: number
default: 1
label:
en_US: Image Number
zh_Hans: 图像数量
human_description:
en_US: The number of images to generate
zh_Hans: 生成图像的数量
form: form
- name: expand_prompt
type: boolean
default: true
label:
en_US: prompt expansion
zh_Hans: 是否进行提示词扩展
human_description:
en_US: Whether to perform prompt expansion
zh_Hans: 是否进行提示词扩展
llm_description: Perform prompt expansion if true.
form: llm
40 changes: 40 additions & 0 deletions api/core/tools/provider/builtin/fal/tools/flux.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Any, Union

import fal_client

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

FLUX_MODEL = {"schnell": "fal-ai/flux/schnell", "dev": "fal-ai/flux/dev", "pro": ""}


class FluxTool(BuiltinTool):
"""
A tool for generating image via Fal.ai Flux model
"""

def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
api_key = self.runtime.credentials.get("fal_api_key", "")
if not api_key:
return self.create_text_message("Please input fal api key")
client = fal_client.SyncClient(key=api_key)
model = tool_parameters.get("model", "schnell")
model_name = FLUX_MODEL.get(model)
payload = {
"prompt": tool_parameters.get("prompt"),
"image_size": tool_parameters.get("image_size", "landscape_4_3"),
"num_images": tool_parameters.get("num_images", 1),
"seed": tool_parameters.get("seed"),
"num_inference_steps": tool_parameters.get("num_inference_steps", 4),
"sync_mode": tool_parameters.get("sync_mode", True),
"enable_safety_checker": tool_parameters.get("enable_safety_checker", True),
}
handler = client.submit(model_name, arguments=payload)
request_id = handler.request_id
res = client.result(model_name, request_id)
result = []
for image in res.get("images", []):
result.append(self.create_image_message(image=image.get("url"), save_as=self.VariableKey.IMAGE.value))
return result
132 changes: 132 additions & 0 deletions api/core/tools/provider/builtin/fal/tools/flux.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
identity:
name: flux
author: zhuhao
label:
en_US: Flux
icon: icon.svg
description:
human:
en_US: Generate image via Fal.ai flux model.
llm: This tool is used to generate image from prompt via Fal.ai flux model.
parameters:
- name: model
type: select
required: true
options:
- value: schnell
label:
en_US: Flux.1-schnell
- value: dev
label:
en_US: Flux.1-dev
- value: pro
label:
en_US: Flux.1-pro
default: schnell
label:
en_US: Choose Image Model
zh_Hans: 选择生成图片的模型
form: form
- name: prompt
type: string
required: true
label:
en_US: prompt
zh_Hans: 提示词
human_description:
en_US: The text prompt used to generate the image.
zh_Hans: 建议用英文的生成图片提示词以获得更好的生成效果。
llm_description: this prompt text will be used to generate image.
form: llm
- name: image_size
type: select
required: true
options:
- value: square_hd
label:
en_US: square_hd
- value: square
label:
en_US: square
- value: portrait_4_3
label:
en_US: portrait_4_3
- value: portrait_16_9
label:
en_US: portrait_16_9
- value: landscape_4_3
label:
en_US: landscape_4_3
- value: landscape_16_9
label:
en_US: landscape_16_9
default: landscape_4_3
label:
en_US: The size of the generated image
zh_Hans: 选择生成的图片大小
form: form
- name: num_inference_steps
type: number
required: true
default: 28
min: 1
max: 100
label:
en_US: Num Inference Steps
zh_Hans: 生成图片的步数
form: form
human_description:
en_US: The number of inference steps to perform. More steps produce higher quality but take longer.
zh_Hans: 执行的推理步骤数量。更多的步骤可以产生更高质量的结果,但需要更长的时间。
- name: seed
type: number
min: 0
max: 9999999999
label:
en_US: Seed
zh_Hans: 种子
human_description:
en_US: The same seed and prompt can produce similar images.
zh_Hans: 相同的种子和提示可以产生相似的图像。
form: form
- name: guidance_scale
type: number
required: false
default: 3.5
label:
en_US: Guidance Scale
zh_Hans: 引导缩放比例
form: form
human_description:
en_US: The Guidance scale is a measure of how close you want the model to stick to your prompt when looking for a related image to you.
zh_Hans: 无分类器引导比例是衡量在寻找相关图像展示时最贴近提示的一个尺度.
- name: num_images
type: number
default: 1
label:
en_US: Image Number
zh_Hans: 图像数量
human_description:
en_US: The number of images to generate
zh_Hans: 生成图像的数量
form: form
- name: sync_mode
type: boolean
default: false
label:
en_US: Sync Mode
zh_Hans: 同步模式
human_description:
en_US: The function will wait for the image to be generated and uploaded before returning the response if true
zh_Hans: 为真时该函数将在返回响应之前等待图像生成并上传.
form: form
- name: enable_safety_checker
type: boolean
default: true
label:
en_US: Enable Safety Checker
zh_Hans: 开启安全检测
human_description:
en_US: The safety checker will be enabled if true
zh_Hans: 为真时开启安全检测.
form: form
38 changes: 38 additions & 0 deletions api/core/tools/provider/builtin/fal/tools/kolors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import Any, Union

import fal_client

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class AuraFlowTool(BuiltinTool):
"""
A tool for generating image via Fal.ai AuraFlow model
"""

def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
api_key = self.runtime.credentials.get("fal_api_key", "")
if not api_key:
return self.create_text_message("Please input fal api key")
client = fal_client.SyncClient(key=api_key)
payload = {
"prompt": tool_parameters.get("prompt"),
"negative_prompt": tool_parameters.get("negative_prompt", ""),
"image_size": tool_parameters.get("image_size", "square_hd"),
"num_images": tool_parameters.get("num_images", 1),
"seed": tool_parameters.get("seed"),
"guidance_scale": tool_parameters.get("guidance_scale", 5),
"num_inference_steps": tool_parameters.get("num_inference_steps", 50),
"sync_mode": tool_parameters.get("sync_mode", True),
"enable_safety_checker": tool_parameters.get("enable_safety_checker", True),
}
handler = client.submit("fal-ai/kolors", arguments=payload)
request_id = handler.request_id
res = client.result("fal-ai/kolors", request_id)
result = []
for image in res.get("images", []):
result.append(self.create_image_message(image=image.get("url"), save_as=self.VariableKey.IMAGE.value))
return result
Loading

0 comments on commit 326547a

Please sign in to comment.