diff --git a/main.go b/main.go index 8899431..2654e50 100644 --- a/main.go +++ b/main.go @@ -10,6 +10,7 @@ import ( "golang.org/x/exp/slices" "github.com/diamondburned/arikawa/v3/api" + "github.com/diamondburned/arikawa/v3/discord" "github.com/diamondburned/arikawa/v3/gateway" "github.com/diamondburned/arikawa/v3/session" "github.com/joho/godotenv" @@ -34,19 +35,21 @@ func main() { // Parse the command line arguments. flag.Parse() s := session.New("Bot " + token) - s.AddHandler(func(c *gateway.MessageCreateEvent) { + + // Add a handler for the message create and update events. + handle := func(m discord.Message) { // Check if the message is in one of the specified channel IDs. - if !slices.Contains(flag.Args(), c.ChannelID.String()) { + if !slices.Contains(flag.Args(), m.ChannelID.String()) { return } // Check if the message has attachments or embeds. - if len(c.Message.Attachments) > 0 || containsEmbeds(c) { + if len(m.Attachments) > 0 || containsEmbeds(m) { return } // Send a DM to the user. - channel, err := s.CreatePrivateChannel(c.Author.ID) + channel, err := s.CreatePrivateChannel(m.Author.ID) if err != nil { log.Println("Failed to create private channel:", err) return @@ -59,15 +62,23 @@ func main() { } if _, err := s.SendMessageComplex(channel.ID, api.SendMessageData{ - Content: c.Message.Content, + Content: m.Content, }); err != nil { log.Println("Failed to send message:", err) } // Delete the message. - if err := s.DeleteMessage(c.ChannelID, c.ID, "No attachments"); err != nil { + if err := s.DeleteMessage(m.ChannelID, m.ID, "No attachments"); err != nil { log.Println("Failed to delete message:", err) } + } + + s.AddHandler(func(c *gateway.MessageCreateEvent) { + handle(c.Message) + }) + + s.AddHandler(func(c *gateway.MessageUpdateEvent) { + handle(c.Message) }) // Add the needed Gateway intents. @@ -91,10 +102,10 @@ func main() { } // containsEmbeds checks if the message contains embeds regular expressions. -func containsEmbeds(c *gateway.MessageCreateEvent) bool { - if len(c.Message.Embeds) > 0 { +func containsEmbeds(m discord.Message) bool { + if len(m.Embeds) > 0 { return true } - return embedRegex.MatchString(c.Message.Content) + return embedRegex.MatchString(m.Content) }