diff --git a/runner/gen_openapi.py b/runner/gen_openapi.py index fd13217a..7fde5ee3 100644 --- a/runner/gen_openapi.py +++ b/runner/gen_openapi.py @@ -1,12 +1,18 @@ import argparse +import copy import json import os -import copy import yaml from app.main import app, use_route_names_as_operation_ids -from app.routes import health, image_to_image, image_to_video, text_to_image, upscale, audio_to_text - +from app.routes import ( + audio_to_text, + health, + image_to_image, + image_to_video, + text_to_image, + upscale, +) from fastapi.openapi.utils import get_openapi # Specify Endpoints for OpenAPI schema generation. diff --git a/runner/openapi.json b/runner/openapi.json index bf4e35f7..4345e565 100644 --- a/runner/openapi.json +++ b/runner/openapi.json @@ -488,7 +488,7 @@ "num_inference_steps": { "type": "integer", "title": "Num Inference Steps", - "default": 25 + "default": 100 }, "num_images_per_prompt": { "type": "integer", @@ -589,7 +589,7 @@ "num_inference_steps": { "type": "integer", "title": "Num Inference Steps", - "default": 50 + "default": 75 } }, "type": "object", diff --git a/worker/docker.go b/worker/docker.go index 8d7f97e0..25951c3d 100644 --- a/worker/docker.go +++ b/worker/docker.go @@ -177,6 +177,7 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo }, ExposedPorts: nat.PortSet{ containerPort: struct{}{}, + "5678/tcp": struct{}{}, }, Labels: map[string]string{ containerCreatorLabel: containerCreator, @@ -205,6 +206,12 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo HostPort: containerHostPort, }, }, + "5678/tcp": []nat.PortBinding{ + { + HostIP: "0.0.0.0", + HostPort: "5678", + }, + }, }, } diff --git a/worker/multipart.go b/worker/multipart.go index de4fbd80..865b9114 100644 --- a/worker/multipart.go +++ b/worker/multipart.go @@ -75,6 +75,11 @@ func NewImageToImageMultipartWriter(w io.Writer, req ImageToImageMultipartReques return nil, err } } + if req.NumInferenceSteps != nil { + if err := mw.WriteField("num_inference_steps", strconv.Itoa(*req.NumInferenceSteps)); err != nil { + return nil, err + } + } if err := mw.Close(); err != nil { return nil, err @@ -142,6 +147,11 @@ func NewImageToVideoMultipartWriter(w io.Writer, req ImageToVideoMultipartReques return nil, err } } + if req.NumInferenceSteps != nil { + if err := mw.WriteField("num_inference_steps", strconv.Itoa(*req.NumInferenceSteps)); err != nil { + return nil, err + } + } if err := mw.Close(); err != nil { return nil, err @@ -187,6 +197,11 @@ func NewUpscaleMultipartWriter(w io.Writer, req UpscaleMultipartRequestBody) (*m return nil, err } } + if req.NumInferenceSteps != nil { + if err := mw.WriteField("num_inference_steps", strconv.Itoa(*req.NumInferenceSteps)); err != nil { + return nil, err + } + } if err := mw.Close(); err != nil { return nil, err diff --git a/worker/runner.gen.go b/worker/runner.gen.go index 6cd2d1b0..0dbe8036 100644 --- a/worker/runner.gen.go +++ b/worker/runner.gen.go @@ -1481,31 +1481,32 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl // Base64 encoded, gzipped, json marshaled Swagger object var swaggerSpec = []string{ - "H4sIAAAAAAAC/+xZ227bOBN+FYL/f+nEjttsF75Lst022B6C2u1eFIHBSGObrURqeUjrDfzuCw5liTpV", - "DtJ4gayvbFnDmW8O33BI39FIppkUIIymkzuqoxWkDL+eXV2+VEoq9z1TMgNlOOCbVC/dh+EmATqhb/WS", - "DqhZZ+5BG8XFkm42A6rgL8sVxHTyGZdcD4olhe5inbz5ApGhmwE9l/F6zmzM5dzIuYHvpvaUSW2aoFDG", - "fVlIlTJDJ/SGC6bWNLCKIg2oA5rKGJI5j93yGBbMJm59sPKtEyCXca+fHkXg6W7edIWBp2wJTtR/qT22", - "B2JpecxEBHMdMQchcOnF8WmJ7FUuR6YoV0AQNr0B5SCglR+H9BJFWkLqEf4Ay0mIBdWQfkQPSNSAClgy", - "w29hnimZZqZTx7tcjlx5uTZVNvU50PMMVJvCk0CfTQk6qMkVqIZWLgwsvXuoVixAAcbMQKYrSsenNa1b", - "WTJF2TadJbbtym63NFuAWc+jFURfK4aNslCanqIYuUCxQs2NlAkwgXoA4tDi1D23gdNGgViaVcXY6PjX", - "wNZWolENNeZlW6981dYpuAOTekl4y2OQ9cd2Ei5qmfulhPN7R6JWwJerahGdvgjWvfbv25Y+hKgPolQq", - "DZdifmOjr2DqSk7GL0ItTpKco2RFW1j/kmuYM7ucdxTGaBwQwAmTM7sk3TXyOJTaO02+8bgWipPR+Hlp", - "6U9831xZo0gPM7rLu4sZNsO+Xny2c+Ffq86+3J+OnlQ7vV9DbM1dS6Jfz2ZXHXNgDIbxxH37v4IFndD/", - "DctpcpiPksNi1qsDzJcHwEpbHUA+sYTHzHWSXkjcQKr7sNX1bUosv3lNBRCmFFujDyHauoI23MASs7rY", - "FkEVrzbM2GpV0vd/0HD/Q4G2ubPcGEoDLfaRWx9AZ1Jo6GCn3jlibyHmLIyTn2za4tRoPTrMdRVWC25v", - "qYFX6MW3kAzv3PODuqtVSSj3USW9Y75FGe01IqLAMw+8xaMZfDfdiYhWVnzdPREoHibiwq+vJ2JA3TEj", - "dNDB6PXQeKEcVOBdxYkOJ2cSs3vFFPOOPNYJpZyZdpiS/uOHhye22wVT0T3HoNypWk1Xa7alsHv3nkRG", - "FfYysX6/oJPPd41Y3TUgXgdEfiMjNNNC5frNC2jdMTj5H0pRxExm7tc+6js/vKlcMojUDvvdJzc3dre5", - "hWJpbb+558ZTb2/bc5VX3LMR5eZDlyp4WxzynbbhyG5t1dlJQRuWZqGrAe5Z8b4HugkFnbHACY+xAR7p", - "FFnFzXrq4uiRu8HlHJgCVVz5IQf9T4WSlTEZ3TgdXCykp7SOFM+wOCf0TBCWZQn31UqMJMoKcnZJMp5B", - "woVPxrao+S1kAMq9/2CFQEO3oLTXNTo+OR65aMkMBMs4ndBn+NOAZsysEPYQL86OjDzahn573nBpQRCX", - "8faabybzfLgIgjZu5sVdVgoDAlelNjE8Y8oM3cHkKGaGlVegfeW4273epppD1wnxB19s6NV4NKrhCoI6", - "/KJdeHYFVdmb0XY1Y1MbRaD1wiakFBvQ5z8RQjnCt9g/ZzH54PPh7Z7sx+5HwaxZScX/hhgNnzzbj+Hc", - "WfJSGG7WZCYlecPU0kd9PP6pIBpnmSacUoQU553TfSX/UhhQgiVkCuoWFCkPhdsWhXtl2Jw+X2+uB1Tb", - "NGVqvWU2mUmC3HZLhys8/OBUCS29wJ+N6CNyLjx97Uq5TehUDhG9wbHQdbjizqS9xeGokk8sj9zjdrg4", - "3XOXq54cD22uu80dOsx9O4z/I2om/ZmrRkq8Ee0lJc6T+yJl953tnklZnaIPpDyQ8hFI6amFpHQz9g4b", - "ZXCy/yElHzZzV+8ODtvhgXlPhHmuuGu7Yf5/UTflPuYCj7sDtv59dWDegXlPhHlbFm38KqdG46KqpeJa", - "7SKRNiYXMk2t4GZNXjED39ia5n9v4WWengyHsQKWHi392+MkX34cueV0c735JwAA//+y67pMESgAAA==", + "H4sIAAAAAAAC/+xZ227bOBN+FYL/f+nEhzabhe+SbLcNtoegdrsXRWAw0thmK5FaHtJ6A7/7gkNZomQp", + "cpDGC2R9Zcsaznxz+IZD+o5GMs2kAGE0Hd9RHS0hZfj17OrylVJSue+ZkhkowwHfpHrhPgw3CdAxfacX", + "tEfNKnMP2iguFnS97lEFf1muIKbjL7jkulcsKXQX6+TNV4gMXffouYxXM2ZjLmdGzgz8MLWnTGqzDQpl", + "3Je5VCkzdExvuGBqRQOrKLIFtUdTGUMy47FbHsOc2cStD1a+cwLkMu7006MIPN3Nm7Yw8JQtwIn6L7XH", + "5kAsLI+ZiGCmI+YgBC6dHp+UyF7ncmSCcgUEYdMbUA4CWrk/pJco0hBSj/AeLMMQC6oh3YgekageFbBg", + "ht/CLFMyzUyrjve5HLnyck2qbOpzoGcZqCaFw0CfTQk6qMkVqC2tXBhYePdQrZiDAoyZgUxXlQ4GNbUb", + "YTJB4SalJbjNyna/NJuDWc2iJUTfKpaNslCanqAYuUCxQs2NlAkwgXoA4tDixD03gdNGgViYZcXY4PjX", + "wNZGYqscatTLNl75sq1zcAcqdbLwlscg64/NLJzXUvdLCef3lkQtgS+W1So6OQ3WvfHvm5Y+hqmP4lQq", + "DZdidmOjb2DqSoaj01CLkyTnKFnRFhJAcg0zZhezlsIYjAICOGFyZhekvUa6OTU6eTil9k6T7zyuhWI4", + "GL0sLf2J77dX1ijSwYz28m5jhs2wsRefzVz416qzK/enJ8+qnT6sITbmriHRb6bTq5ZBMAbDeOK+/V/B", + "nI7p//rlONnPZ8l+MezVAebLA2ClrRYgn1nCY+Y6SSckbiDVXdjq+tYllt+8pgIIU4qt0IcQbV1BE25g", + "iVlebIqgilcbZmy1KumHP2i4/6FA0+BZbgylgQb7yK2PoDMpNLSwU+8csXcQcxbGyY82TXHaaj06zHUV", + "VgNub2kLr9Dz7yEZ3rvnR3VXq5JQ7pNKOud8izLaa0REgWceeINHU/hh2hMRLa34tnsiUDxMxIVfX09E", + "j7pzRuigg9HpofFCOajAu4oTLU5OJWb3iinmHXmqI0o5M+0wJf3HTw8nz+3wUExFDxyDcqdqNV2t2YbC", + "7tx7EhlV2MvE6sOcjr/cbcXqbgvidUDktzJCMw1Url+9gNYtg5P/oRRFzGTqfu2ivvPDm8olg0jtsN99", + "dnNje5ubK5bW9psHbjz19rY5V3nFHRtRbj50qYK3wSHfabcc2a2tOjspaMPSLHQ1wD0t3ndAN6GgMxY4", + "4TFugUc6RVZxs5q4OHrkbnA5B6ZAFXd+yEH/U6FkaUxG104HF3PpKa0jxTMszjE9E4RlWcJ9tRIjibKC", + "nF2SjGeQcOGTsSlqfgsZgHLvP1oh0NAtKO11DY6HxwMXLZmBYBmnY/oCf+rRjJklwu7jzdmRkUeb0G/O", + "Gy4tCOIy3tzzTWWeDxdB0MbNvLjLSmFA4KrUJoZnTJm+O5gcxcyw8g60qxx3u9hbV3PoOiH+4IsNvRoN", + "BjVcQVD7X7ULz66gKnsz2q5mbGKjCLSe24SUYj368idCKEf4BvvnLCYffT683eF+7H4SzJqlVPxviNHw", + "8MV+DOfOklfCcLMiUynJW6YWPuqj0U8FsXWW2YZTipDivHOyr+RfCgNKsIRMQN2CIuWhcNOicK8Mm9OX", + "6/V1j2qbpkytNswmU0mQ225pf4mHH5wqoaEX+LMRfULOhaevXSm3Dp3KIaI3OBa6DlfcmTS3OBxV8onl", + "iXvcDhene+5y1ZPjoc21t7lDh3loh/H/RE2lP3PVSIk3op2kxHlyX6Rsv7PdMymrU/SBlAdSPgEpPbWQ", + "lG7G3mGjDE7291LycTN39e7gsB0emPdMmOeKu7Yb5v8XtVPuUy7wtDtg499XB+YdmPdMmLdh0dqvcmo0", + "LqpaKq7VLhJpY3Ih09QKblbkNTPwna1o/vcWXubpcb8fK2Dp0cK/PU7y5ceRW07X1+t/AgAA///2pVcb", + "EigAAA==", } // GetSwagger returns the content of the embedded swagger specification file