diff --git a/src/gateway/client/dispatch.rs b/src/gateway/client/dispatch.rs index bf6aabd2b8..fba9e54b90 100644 --- a/src/gateway/client/dispatch.rs +++ b/src/gateway/client/dispatch.rs @@ -1,7 +1,6 @@ -#[cfg(feature = "framework")] use std::sync::Arc; -use super::event_handler::InternalEventHandler; +use super::event_handler::{EventHandler, RawEventHandler}; use super::{Context, FullEvent}; #[cfg(feature = "cache")] use crate::cache::{Cache, CacheUpdate}; @@ -48,18 +47,10 @@ pub(crate) async fn dispatch_model( event: Event, context: Context, #[cfg(feature = "framework")] framework: Option>, - event_handler: Option, + event_handler: Option>, + raw_event_handler: Option>, ) { - let (handler, raw_handler) = match event_handler { - Some(InternalEventHandler::Normal(handler)) => (Some(handler), None), - Some(InternalEventHandler::Both { - raw, - normal, - }) => (Some(normal), Some(raw)), - Some(InternalEventHandler::Raw(raw_handler)) => (None, Some(raw_handler)), - None => (None, None), - }; - if let Some(raw_handler) = raw_handler { + if let Some(raw_handler) = raw_event_handler { raw_handler.raw_event(context.clone(), &event).await; } @@ -78,7 +69,7 @@ pub(crate) async fn dispatch_model( framework.dispatch(&context, &full_event).await; } - if let Some(handler) = handler { + if let Some(handler) = event_handler { if let Some(extra_event) = extra_event { extra_event.dispatch(context.clone(), &*handler).await; } diff --git a/src/gateway/client/event_handler.rs b/src/gateway/client/event_handler.rs index 1e39d8be1e..3430930963 100644 --- a/src/gateway/client/event_handler.rs +++ b/src/gateway/client/event_handler.rs @@ -1,7 +1,6 @@ use std::collections::VecDeque; #[cfg(feature = "cache")] use std::num::NonZeroU16; -use std::sync::Arc; use async_trait::async_trait; use strum::{EnumCount, IntoStaticStr, VariantNames}; @@ -535,10 +534,3 @@ pub trait RawEventHandler: Send + Sync { true } } - -#[derive(Clone)] -pub enum InternalEventHandler { - Raw(Arc), - Normal(Arc), - Both { raw: Arc, normal: Arc }, -} diff --git a/src/gateway/client/mod.rs b/src/gateway/client/mod.rs index d3f225c36b..41094164ef 100644 --- a/src/gateway/client/mod.rs +++ b/src/gateway/client/mod.rs @@ -38,7 +38,7 @@ use futures::StreamExt as _; use tracing::debug; pub use self::context::Context; -pub use self::event_handler::{EventHandler, FullEvent, InternalEventHandler, RawEventHandler}; +pub use self::event_handler::{EventHandler, FullEvent, RawEventHandler}; #[cfg(feature = "cache")] use crate::cache::Cache; #[cfg(feature = "cache")] @@ -289,18 +289,8 @@ impl IntoFuture for ClientBuilder { let presence = self.presence; let http = self.http; - let event_handler = match (self.event_handler, self.raw_event_handler) { - (Some(normal), Some(raw)) => Some(InternalEventHandler::Both { - normal, - raw, - }), - (Some(h), None) => Some(InternalEventHandler::Normal(h)), - (None, Some(h)) => Some(InternalEventHandler::Raw(h)), - (None, None) => None, - }; - if let Some(ratelimiter) = &http.ratelimiter { - if let Some(InternalEventHandler::Normal(event_handler)) = &event_handler { + if let Some(event_handler) = &self.event_handler { let event_handler = Arc::clone(event_handler); ratelimiter.set_ratelimit_callback(Box::new(move |info| { let event_handler = Arc::clone(&event_handler); @@ -334,7 +324,8 @@ impl IntoFuture for ClientBuilder { let framework_cell = Arc::new(OnceLock::new()); let (shard_manager, shard_manager_ret_value) = ShardManager::new(ShardManagerOptions { data: Arc::clone(&data), - event_handler, + event_handler: self.event_handler, + raw_event_handler: self.raw_event_handler, #[cfg(feature = "framework")] framework: Arc::clone(&framework_cell), #[cfg(feature = "voice")] diff --git a/src/gateway/sharding/shard_manager.rs b/src/gateway/sharding/shard_manager.rs index 5bdfd07ef9..dba609ea3a 100644 --- a/src/gateway/sharding/shard_manager.rs +++ b/src/gateway/sharding/shard_manager.rs @@ -16,7 +16,7 @@ use super::{ShardId, ShardQueue, ShardQueuer, ShardQueuerMessage, ShardRunnerInf use crate::cache::Cache; #[cfg(feature = "framework")] use crate::framework::Framework; -use crate::gateway::client::InternalEventHandler; +use crate::gateway::client::{EventHandler, RawEventHandler}; #[cfg(feature = "voice")] use crate::gateway::VoiceGatewayManager; use crate::gateway::{ConnectionStage, GatewayError, PresenceData}; @@ -49,7 +49,7 @@ use crate::model::gateway::GatewayIntents; /// use std::env; /// use std::sync::{Arc, OnceLock}; /// -/// use serenity::gateway::client::{EventHandler, InternalEventHandler, RawEventHandler}; +/// use serenity::gateway::client::EventHandler; /// use serenity::gateway::{ShardManager, ShardManagerOptions}; /// use serenity::http::Http; /// use serenity::model::gateway::GatewayIntents; @@ -66,12 +66,13 @@ use crate::model::gateway::GatewayIntents; /// let data = Arc::new(()); /// let shard_total = gateway_info.shards; /// let ws_url = Arc::from(gateway_info.url); -/// let event_handler = Arc::new(Handler) as Arc; +/// let event_handler = Arc::new(Handler); /// let max_concurrency = std::num::NonZeroU16::MIN; /// /// ShardManager::new(ShardManagerOptions { /// data, -/// event_handler: Some(InternalEventHandler::Normal(event_handler)), +/// event_handler: Some(event_handler), +/// raw_event_handler: None, /// framework: Arc::new(OnceLock::new()), /// # #[cfg(feature = "voice")] /// # voice_manager: None, @@ -128,6 +129,7 @@ impl ShardManager { let mut shard_queuer = ShardQueuer { data: opt.data, event_handler: opt.event_handler, + raw_event_handler: opt.raw_event_handler, #[cfg(feature = "framework")] framework: opt.framework, last_start: None, @@ -356,7 +358,8 @@ impl Drop for ShardManager { pub struct ShardManagerOptions { pub data: Arc, - pub event_handler: Option, + pub event_handler: Option>, + pub raw_event_handler: Option>, #[cfg(feature = "framework")] pub framework: Arc>>, #[cfg(feature = "voice")] diff --git a/src/gateway/sharding/shard_queuer.rs b/src/gateway/sharding/shard_queuer.rs index dc9a4802fb..61896ef978 100644 --- a/src/gateway/sharding/shard_queuer.rs +++ b/src/gateway/sharding/shard_queuer.rs @@ -22,7 +22,7 @@ use super::{ use crate::cache::Cache; #[cfg(feature = "framework")] use crate::framework::Framework; -use crate::gateway::client::InternalEventHandler; +use crate::gateway::client::{EventHandler, RawEventHandler}; #[cfg(feature = "voice")] use crate::gateway::VoiceGatewayManager; use crate::gateway::{ConnectionStage, PresenceData, Shard, ShardRunnerMessage}; @@ -42,11 +42,10 @@ pub struct ShardQueuer { /// /// [`Client::data`]: crate::Client::data pub data: Arc, - /// A reference to [`EventHandler`] or [`RawEventHandler`]. - /// - /// [`EventHandler`]: crate::gateway::client::EventHandler - /// [`RawEventHandler`]: crate::gateway::client::RawEventHandler - pub event_handler: Option, + /// A reference to an [`EventHandler`]. + pub event_handler: Option>, + /// A reference to a [`RawEventHandler`]. + pub raw_event_handler: Option>, /// A copy of the framework #[cfg(feature = "framework")] pub framework: Arc>>, @@ -223,6 +222,7 @@ impl ShardQueuer { let mut runner = ShardRunner::new(ShardRunnerOptions { data: Arc::clone(&self.data), event_handler: self.event_handler.clone(), + raw_event_handler: self.raw_event_handler.clone(), #[cfg(feature = "framework")] framework: self.framework.get().cloned(), manager: Arc::clone(&self.manager), diff --git a/src/gateway/sharding/shard_runner.rs b/src/gateway/sharding/shard_runner.rs index fd7277e196..c8000c559f 100644 --- a/src/gateway/sharding/shard_runner.rs +++ b/src/gateway/sharding/shard_runner.rs @@ -16,7 +16,7 @@ use crate::cache::Cache; #[cfg(feature = "framework")] use crate::framework::Framework; use crate::gateway::client::dispatch::dispatch_model; -use crate::gateway::client::{Context, InternalEventHandler}; +use crate::gateway::client::{Context, EventHandler, RawEventHandler}; #[cfg(feature = "voice")] use crate::gateway::VoiceGatewayManager; use crate::gateway::{ActivityData, ChunkGuildFilter, GatewayError}; @@ -30,7 +30,8 @@ use crate::model::user::OnlineStatus; /// A runner for managing a [`Shard`] and its respective WebSocket client. pub struct ShardRunner { data: Arc, - event_handler: Option, + event_handler: Option>, + raw_event_handler: Option>, #[cfg(feature = "framework")] framework: Option>, manager: Arc, @@ -58,6 +59,7 @@ impl ShardRunner { runner_tx: tx, data: opt.data, event_handler: opt.event_handler, + raw_event_handler: opt.raw_event_handler, #[cfg(feature = "framework")] framework: opt.framework, manager: opt.manager, @@ -120,7 +122,7 @@ impl ShardRunner { if post != pre { self.update_manager().await; - if let Some(InternalEventHandler::Normal(event_handler)) = &self.event_handler { + if let Some(event_handler) = &self.event_handler { let event_handler = Arc::clone(event_handler); let context = self.make_context(); let event = ShardStageUpdateEvent { @@ -173,21 +175,14 @@ impl ShardRunner { if let Some(event) = event { let context = self.make_context(); - let can_dispatch = match &self.event_handler { - Some(InternalEventHandler::Normal(handler)) => { - handler.filter_event(&context, &event) - }, - Some(InternalEventHandler::Raw(handler)) => { - handler.filter_event(&context, &event) - }, - Some(InternalEventHandler::Both { - raw, - normal, - }) => { - raw.filter_event(&context, &event) && normal.filter_event(&context, &event) - }, - None => true, - }; + let can_dispatch = self + .event_handler + .as_ref() + .map_or(true, |handler| handler.filter_event(&context, &event)) + && self + .raw_event_handler + .as_ref() + .map_or(true, |handler| handler.filter_event(&context, &event)); if can_dispatch { #[cfg(feature = "collector")] @@ -214,6 +209,7 @@ impl ShardRunner { #[cfg(feature = "framework")] self.framework.clone(), self.event_handler.clone(), + self.raw_event_handler.clone(), ), ); } @@ -509,7 +505,8 @@ impl ShardRunner { /// Options to be passed to [`ShardRunner::new`]. pub struct ShardRunnerOptions { pub data: Arc, - pub event_handler: Option, + pub event_handler: Option>, + pub raw_event_handler: Option>, #[cfg(feature = "framework")] pub framework: Option>, pub manager: Arc,