Skip to content

Commit

Permalink
support controlent_detect command
Browse files Browse the repository at this point in the history
  • Loading branch information
SpenserCai committed Aug 20, 2023
1 parent 71a1e4c commit f47e7e3
Show file tree
Hide file tree
Showing 4 changed files with 260 additions and 28 deletions.
3 changes: 2 additions & 1 deletion dbot/slash_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* @Date: 2023-08-16 22:10:00
* @version:
* @LastEditors: SpenserCai
* @LastEditTime: 2023-08-19 18:35:33
* @LastEditTime: 2023-08-20 13:42:49
* @Description: file content
*/
package dbot
Expand Down Expand Up @@ -42,4 +42,5 @@ func (dbot *DiscordBot) GenerateCommandList() {
dbot.AppCommand = append(dbot.AppCommand, slash_handler.SlashHandler{}.RembgOptions())
dbot.AppCommand = append(dbot.AppCommand, slash_handler.SlashHandler{}.ExtraSingleOptions())
dbot.AppCommand = append(dbot.AppCommand, slash_handler.SlashHandler{}.PngInfoOptions())
dbot.AppCommand = append(dbot.AppCommand, slash_handler.SlashHandler{}.ControlnetDetectOptions())
}
202 changes: 202 additions & 0 deletions dbot/slash_handler/controlnet_detect.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
/*
* @Author: SpenserCai
* @Date: 2023-08-20 12:45:58
* @version:
* @LastEditors: SpenserCai
* @LastEditTime: 2023-08-20 14:52:44
* @Description: file content
*/

package slash_handler

import (
"fmt"
"log"
"strings"

"github.com/SpenserCai/sd-webui-discord/cluster"
"github.com/SpenserCai/sd-webui-discord/global"
"github.com/SpenserCai/sd-webui-discord/utils"

"github.com/SpenserCai/sd-webui-go/intersvc"
"github.com/bwmarrin/discordgo"
)

func (shdl SlashHandler) controlnetModuleChoice() []*discordgo.ApplicationCommandOptionChoice {
exclued := []string{
"clip_vision",
"t2ia_color_grid",
"pidinet",
"pidinet_safe",
"t2ia_sketch_pidi",
"scribble_pidinet",
"scribble_xdog",
"scribble_hed",
"normal_bae",
"lineart_realistic",
"lineart_coarse",
"lineart_anime",
"pidinet",
"pidinet_safe",
"pidinet_sketch",
"pidinet_scribble",
"inpaint_global_harmonious",
"inpaint_only",
"inpaint_only+lama",
"normal_map",
"invert",
"shuffle",
"tile_colorfix",
"tile_colorfix+sharp",
"reference_adain+attn",
"mediapipe_face",
}
choices := []*discordgo.ApplicationCommandOptionChoice{}
modulesvc := &intersvc.ControlnetModuleList{}
modulesvc.Action(global.ClusterManager.GetNodeAuto().StableClient)
if modulesvc.Error != nil {
log.Println(modulesvc.Error)
return choices
}
model_list := modulesvc.GetResponse().ModuleList
for _, model := range model_list {
choices = append(choices, &discordgo.ApplicationCommandOptionChoice{
Name: model,
Value: model,
})
}
newChoices := []*discordgo.ApplicationCommandOptionChoice{}
for _, choice := range choices {
exclu := false
for _, ex := range exclued {
if strings.Contains(choice.Name, ex) {
exclu = true
break
}
}
if !exclu {
newChoices = append(newChoices, choice)
}
}
return newChoices
}

func (shdl SlashHandler) ControlnetDetectOptions() *discordgo.ApplicationCommand {
return &discordgo.ApplicationCommand{
Name: "controlnet_detect",
Description: "Remove background from image",
Options: []*discordgo.ApplicationCommandOption{
{
Type: discordgo.ApplicationCommandOptionString,
Name: "image_url",
Description: "The url of the images,split by ','",
Required: true,
},
{
Type: discordgo.ApplicationCommandOptionString,
Name: "module",
Description: "The module to use",
Required: true,
Choices: shdl.controlnetModuleChoice(),
},
{
Type: discordgo.ApplicationCommandOptionInteger,
Name: "processor_res",
Description: "The resolution of the processor",
Required: false,
},
{
Type: discordgo.ApplicationCommandOptionNumber,
Name: "threshold_a",
Description: "The threshold of the processor",
Required: false,
},
{
Type: discordgo.ApplicationCommandOptionNumber,
Name: "threshold_b",
Description: "The threshold of the processor",
Required: false,
},
},
}
}

func (shdl SlashHandler) ControlnetDetectSetOptions(dsOpt []*discordgo.ApplicationCommandInteractionDataOption, opt *intersvc.ControlnetDetectRequest) {
opt.ControlnetProcessorRes = func() *int64 { v := int64(512); return &v }()
opt.ControlnetThresholda = func() *float64 { v := float64(64.0); return &v }()
opt.ControlnetThresholdb = func() *float64 { v := float64(64.0); return &v }()
for _, v := range dsOpt {
switch v.Name {
case "image_url":
imgUrls := strings.Split(v.StringValue(), ",")
imgs := []string{}
for _, imgUrl := range imgUrls {
img, _ := utils.GetImageBase64(imgUrl)
imgs = append(imgs, img)
}
opt.ControlnetInputImages = imgs
case "module":
opt.ControlnetModule = func() *string { v := v.StringValue(); return &v }()
case "processor_res":
opt.ControlnetProcessorRes = func() *int64 { v := v.IntValue(); return &v }()
case "threshold_a":
opt.ControlnetThresholda = func() *float64 { v := v.FloatValue(); return &v }()
case "threshold_b":
opt.ControlnetThresholdb = func() *float64 { v := v.FloatValue(); return &v }()
}
}
}

func (shdl SlashHandler) ControlnetDetectAction(s *discordgo.Session, i *discordgo.InteractionCreate, opt *intersvc.ControlnetDetectRequest, node *cluster.ClusterNode) {
msg, err := shdl.SendStateMessage("Running", s, i)
if err != nil {
log.Println(err)
return
}
if len(opt.ControlnetInputImages) > 4 {
s.FollowupMessageEdit(i.Interaction, msg.ID, &discordgo.WebhookEdit{
Content: func() *string { v := "Too many images, please input less than 4 images"; return &v }(),
})
return
}
controlnet_detect := &intersvc.ControlnetDetect{RequestItem: opt}
controlnet_detect.Action(node.StableClient)
if controlnet_detect.Error != nil {
s.FollowupMessageEdit(i.Interaction, msg.ID, &discordgo.WebhookEdit{
Content: func() *string { v := controlnet_detect.Error.Error(); return &v }(),
})
} else {
files := make([]*discordgo.File, 0)
for n, img := range controlnet_detect.GetResponse().Images {
imageReader, err := utils.GetImageReaderByBase64(img)
if err != nil {
s.FollowupMessageEdit(i.Interaction, msg.ID, &discordgo.WebhookEdit{
Content: func() *string { v := err.Error(); return &v }(),
})
return
}
files = append(files, &discordgo.File{
Name: fmt.Sprintf("result_%d.png", n),
Reader: imageReader,
ContentType: "image/png",
})
}
s.FollowupMessageEdit(i.Interaction, msg.ID, &discordgo.WebhookEdit{
Content: func() *string { v := "Success"; return &v }(),
Files: files,
})
}
}

func (shdl SlashHandler) ControlnetDetectCommandHandler(s *discordgo.Session, i *discordgo.InteractionCreate) {
option := &intersvc.ControlnetDetectRequest{}
shdl.ReportCommandInfo(s, i)
node := global.ClusterManager.GetNodeAuto()
action := func() (map[string]interface{}, error) {
shdl.ControlnetDetectSetOptions(i.ApplicationCommandData().Options, option)
shdl.ControlnetDetectAction(s, i, option, node)
return nil, nil
}
callback := func() {}
node.ActionQueue.AddTask(shdl.GenerateTaskID(i), action, callback)
}
20 changes: 11 additions & 9 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,22 @@ module github.com/SpenserCai/sd-webui-discord
go 1.19

require (
github.com/SpenserCai/sd-webui-go v0.2.3
github.com/SpenserCai/sd-webui-go v0.2.4
github.com/bwmarrin/discordgo v0.27.1
golang.org/x/text v0.3.7
golang.org/x/text v0.7.0
)

require (
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect
github.com/go-logr/logr v1.2.3 // indirect
github.com/go-logr/logr v1.2.4 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-openapi/analysis v0.21.4 // indirect
github.com/go-openapi/errors v0.20.4 // indirect
github.com/go-openapi/jsonpointer v0.19.5 // indirect
github.com/go-openapi/jsonreference v0.20.0 // indirect
github.com/go-openapi/jsonpointer v0.20.0 // indirect
github.com/go-openapi/jsonreference v0.20.2 // indirect
github.com/go-openapi/loads v0.21.2 // indirect
github.com/go-openapi/runtime v0.26.0 // indirect
github.com/go-openapi/spec v0.20.8 // indirect
github.com/go-openapi/spec v0.20.9 // indirect
github.com/go-openapi/strfmt v0.21.7 // indirect
github.com/go-openapi/swag v0.22.4 // indirect
github.com/go-openapi/validate v0.22.1 // indirect
Expand All @@ -28,9 +28,11 @@ require (
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/oklog/ulid v1.3.1 // indirect
github.com/opentracing/opentracing-go v1.2.0 // indirect
go.mongodb.org/mongo-driver v1.11.3 // indirect
go.opentelemetry.io/otel v1.14.0 // indirect
go.opentelemetry.io/otel/trace v1.14.0 // indirect
github.com/rogpeppe/go-internal v1.11.0 // indirect
go.mongodb.org/mongo-driver v1.12.1 // indirect
go.opentelemetry.io/otel v1.16.0 // indirect
go.opentelemetry.io/otel/metric v1.16.0 // indirect
go.opentelemetry.io/otel/trace v1.16.0 // indirect
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect
golang.org/x/sys v0.5.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
Expand Down
Loading

0 comments on commit f47e7e3

Please sign in to comment.