diff --git a/src/channel.rs b/src/channel.rs index 377443c..36816e4 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -32,7 +32,7 @@ use crate::{ events::{FetchAllUserChannelPermissions, SetUserChannelPermissions}, Persistence, }, - server::Server, + server::{response::IntoProtocol, Server}, }; #[derive(Copy, Clone)] @@ -441,7 +441,7 @@ impl Handler for Channel { } // send the channel's topic to the joining user - for message in ChannelTopic::new(self).into_messages(self.name.to_string(), true) { + for message in ChannelTopic::new(self, true).into_messages(&self.name) { msg.client.do_send(Broadcast { message, span: Span::current(), @@ -497,8 +497,7 @@ impl Handler for Channel { }); for (client, connection) in &self.clients { - for message in ChannelTopic::new(self).into_messages(connection.nick.to_string(), false) - { + for message in ChannelTopic::new(self, false).into_messages(&connection.nick) { client.do_send(Broadcast { message, span: Span::current(), @@ -569,7 +568,7 @@ impl Handler for Channel { #[instrument(parent = &msg.span, skip_all)] fn handle(&mut self, msg: ChannelFetchTopic, _ctx: &mut Self::Context) -> Self::Result { - MessageResult(ChannelTopic::new(self)) + MessageResult(ChannelTopic::new(self, msg.skip_on_none)) } } diff --git a/src/channel/response.rs b/src/channel/response.rs index c8a4bc1..5fc927b 100644 --- a/src/channel/response.rs +++ b/src/channel/response.rs @@ -4,25 +4,29 @@ use itertools::Itertools; use crate::{ channel::{permissions::Permission, Channel, CurrentChannelTopic}, connection::InitiatedConnection, + server::response::IntoProtocol, SERVER_NAME, }; pub struct ChannelTopic { pub channel_name: String, pub topic: Option, + pub skip_on_none: bool, } impl ChannelTopic { #[must_use] - pub fn new(channel: &Channel) -> Self { + pub fn new(channel: &Channel, skip_on_none: bool) -> Self { Self { channel_name: channel.name.to_string(), topic: channel.topic.clone(), + skip_on_none, } } +} - #[must_use] - pub fn into_messages(self, for_user: String, skip_on_none: bool) -> Vec { +impl IntoProtocol for ChannelTopic { + fn into_messages(self, for_user: &str) -> Vec { if let Some(topic) = self.topic { vec![ Message { @@ -43,7 +47,7 @@ impl ChannelTopic { command: Command::Response( Response::RPL_TOPICWHOTIME, vec![ - for_user, + for_user.to_string(), self.channel_name.to_string(), topic.set_by, topic.set_time.timestamp().to_string(), @@ -51,13 +55,17 @@ impl ChannelTopic { ), }, ] - } else if !skip_on_none { + } else if !self.skip_on_none { vec![Message { tags: None, prefix: Some(Prefix::ServerName(SERVER_NAME.to_string())), command: Command::Response( Response::RPL_NOTOPIC, - vec![for_user, self.channel_name, "No topic is set".to_string()], + vec![ + for_user.to_string(), + self.channel_name, + "No topic is set".to_string(), + ], ), }] } else { @@ -83,9 +91,10 @@ impl ChannelWhoList { .collect(), } } +} - #[must_use] - pub fn into_messages(self, for_user: &str) -> Vec { +impl IntoProtocol for ChannelWhoList { + fn into_messages(self, for_user: &str) -> Vec { let mut out = Vec::with_capacity(self.nick_list.len()); for (perm, conn) in self.nick_list { @@ -233,18 +242,17 @@ pub enum ChannelJoinRejectionReason { Banned, } -impl ChannelJoinRejectionReason { - #[must_use] - pub fn into_message(self) -> Message { +impl IntoProtocol for ChannelJoinRejectionReason { + fn into_messages(self, for_user: &str) -> Vec { match self { - Self::Banned => Message { + Self::Banned => vec![Message { tags: None, prefix: Some(Prefix::ServerName(SERVER_NAME.to_string())), command: Command::Response( Response::ERR_BANNEDFROMCHAN, - vec!["Cannot join channel (+b)".to_string()], + vec![for_user.to_string(), "Cannot join channel (+b)".to_string()], ), - }, + }], } } } diff --git a/src/client.rs b/src/client.rs index 7afc711..b00f71c 100644 --- a/src/client.rs +++ b/src/client.rs @@ -36,7 +36,7 @@ use crate::{ Persistence, }, server::{ - response::{NoSuchNick, WhoList}, + response::{IntoProtocol, WhoList}, Server, }, SERVER_NAME, @@ -159,10 +159,9 @@ impl Client { ctx: &mut Context, channel: &Addr, message: M, - map: impl FnOnce(M::Result, &Self) -> Vec + 'static, ) where M: actix::Message + Send + 'static, - M::Result: Send, + M::Result: Send + IntoProtocol, Channel: Handler, ::Context: ToEnvelope, { @@ -170,21 +169,17 @@ impl Client { .send(message) .into_actor(self) .map(move |result, ref mut this, _ctx| { - for message in (map)(result.unwrap(), this) { + for message in result.unwrap().into_messages(&this.connection.nick) { this.writer.write(message); } }); ctx.spawn(fut); } - fn server_send_map_write( - &self, - ctx: &mut Context, - message: M, - map: impl FnOnce(M::Result, &Self) -> Vec + 'static, - ) where + fn server_send_map_write(&self, ctx: &mut Context, message: M) + where M: actix::Message + Send + 'static, - M::Result: Send, + M::Result: Send + IntoProtocol, Server: Handler, ::Context: ToEnvelope, { @@ -193,7 +188,7 @@ impl Client { .send(message) .into_actor(self) .map(move |result, ref mut this, _ctx| { - for message in (map)(result.unwrap(), this) { + for message in result.unwrap().into_messages(&this.connection.nick) { this.writer.write(message); } }); @@ -304,7 +299,7 @@ impl Handler for Client { fn handle(&mut self, _msg: ForceDisconnect, ctx: &mut Self::Context) -> Self::Result { ctx.stop(); - MessageResult(true) + MessageResult(Ok(())) } } @@ -450,7 +445,9 @@ impl Handler for Client { Ok(v) => v, Err(error) => { error!(?error, "User failed to join channel"); - this.writer.write(error.into_message()); + for m in error.into_messages(&this.connection.nick) { + this.writer.write(m); + } continue; } }; @@ -718,8 +715,10 @@ impl StreamHandler> for Client { self.channel_send_map_write( ctx, channel, - ChannelFetchTopic { span }, - |res, this| res.into_messages(this.connection.nick.to_string(), false), + ChannelFetchTopic { + span, + skip_on_none: false, + }, ); } } @@ -740,9 +739,7 @@ impl StreamHandler> for Client { } Command::LIST(_, _) => { let span = Span::current(); - self.server_send_map_write(ctx, ChannelList { span }, |res, this| { - res.into_messages(this.connection.nick.to_string()) - }); + self.server_send_map_write(ctx, ChannelList { span }); } Command::INVITE(nick, channel) => { let Some(channel) = self.channels.get(&channel) else { @@ -800,15 +797,11 @@ impl StreamHandler> for Client { } Command::MOTD(_) => { let span = Span::current(); - self.server_send_map_write(ctx, ServerFetchMotd { span }, |res, this| { - res.into_messages(this.connection.nick.to_string()) - }); + self.server_send_map_write(ctx, ServerFetchMotd { span }); } Command::LUSERS(_, _) => { let span = Span::current(); - self.server_send_map_write(ctx, ServerListUsers { span }, |res, this| { - res.into_messages(&this.connection.nick) - }); + self.server_send_map_write(ctx, ServerListUsers { span }); } Command::VERSION(_) => { self.writer.write(Message { @@ -843,9 +836,7 @@ impl StreamHandler> for Client { } Command::ADMIN(_) => { let span = Span::current(); - self.server_send_map_write(ctx, ServerAdminInfo { span }, |res, this| { - res.into_messages(&this.connection.nick) - }); + self.server_send_map_write(ctx, ServerAdminInfo { span }); } Command::INFO(_) => { static INFO_STR: &str = include_str!("../text/info.txt"); @@ -874,15 +865,11 @@ impl StreamHandler> for Client { } Command::WHO(Some(query), _) => { let span = Span::current(); - self.server_send_map_write(ctx, FetchWhoList { span, query }, |res, this| { - res.into_messages(&this.connection.nick) - }); + self.server_send_map_write(ctx, FetchWhoList { span, query }); } Command::WHOIS(Some(query), _) => { let span = Span::current(); - self.server_send_map_write(ctx, FetchWhois { span, query }, |res, this| { - res.into_messages(&this.connection.nick) - }); + self.server_send_map_write(ctx, FetchWhois { span, query }); } Command::WHOWAS(_, _, _) => {} Command::KILL(nick, comment) => { @@ -936,16 +923,9 @@ impl StreamHandler> for Client { ctx, ForceDisconnect { span, - user: user.to_string(), + user, comment, }, - move |res, this| { - if res { - vec![] - } else { - NoSuchNick { nick: user }.into_messages(&this.connection.nick) - } - }, ); } Command::AUTHENTICATE(_) => { diff --git a/src/messages.rs b/src/messages.rs index ab1aec6..dcacfbc 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -7,6 +7,7 @@ use crate::{ channel::Channel, client::Client, connection::{InitiatedConnection, UserId}, + server::response::NoSuchNick, }; /// Sent when a user is connecting to the server. @@ -38,7 +39,7 @@ pub struct KillUser { } #[derive(Message, Clone)] -#[rtype(result = "bool")] +#[rtype(result = "Result<(), NoSuchNick>")] pub struct ForceDisconnect { pub span: Span, pub user: String, @@ -151,6 +152,7 @@ pub struct FetchUserPermission { #[rtype(result = "super::channel::response::ChannelTopic")] pub struct ChannelFetchTopic { pub span: Span, + pub skip_on_none: bool, } /// Retrieves the WHO list for the channel. diff --git a/src/server.rs b/src/server.rs index c3c386d..c762b0c 100644 --- a/src/server.rs +++ b/src/server.rs @@ -31,7 +31,7 @@ use crate::{ UserNickChangeInternal, Wallops, }, persistence::Persistence, - server::response::{AdminInfo, ListUsers, Motd, WhoList, Whois}, + server::response::{AdminInfo, IntoProtocol, ListUsers, Motd, NoSuchNick, WhoList, Whois}, SERVER_NAME, }; @@ -124,7 +124,7 @@ impl Handler for Server { }); } - for message in Motd::new(self).into_messages(msg.connection.nick.clone()) { + for message in Motd::new(self).into_messages(&msg.connection.nick) { msg.handle.do_send(Broadcast { span: Span::current(), message, @@ -313,9 +313,9 @@ impl Handler for Server { fn handle(&mut self, msg: ForceDisconnect, _ctx: &mut Self::Context) -> Self::Result { if let Some((handle, _)) = self.clients.iter().find(|(_, v)| v.nick == msg.user) { handle.do_send(msg); - MessageResult(true) + MessageResult(Ok(())) } else { - MessageResult(false) + MessageResult(Err(NoSuchNick { nick: msg.user })) } } } @@ -371,6 +371,7 @@ impl Handler for Server { .map(|channel| { let fetch_topic = channel.send(ChannelFetchTopic { span: Span::current(), + skip_on_none: true, }); let fetch_members = channel.send(ChannelMemberList { diff --git a/src/server/response.rs b/src/server/response.rs index 3f75c04..8c38ce7 100644 --- a/src/server/response.rs +++ b/src/server/response.rs @@ -11,9 +11,8 @@ pub struct Whois { pub channels: Vec<(Permission, String)>, } -impl Whois { - #[must_use] - pub fn into_messages(self, for_user: &str) -> Vec { +impl IntoProtocol for Whois { + fn into_messages(self, for_user: &str) -> Vec { macro_rules! msg { ($response:ident, $($payload:expr),*) => { @@ -88,9 +87,8 @@ pub struct NoSuchNick { pub nick: String, } -impl NoSuchNick { - #[must_use] - pub fn into_messages(self, for_user: &str) -> Vec { +impl IntoProtocol for NoSuchNick { + fn into_messages(self, for_user: &str) -> Vec { vec![Message { tags: None, prefix: Some(Prefix::ServerName(SERVER_NAME.to_string())), @@ -108,9 +106,8 @@ pub struct WhoList { pub query: String, } -impl WhoList { - #[must_use] - pub fn into_messages(self, for_user: &str) -> Vec { +impl IntoProtocol for WhoList { + fn into_messages(self, for_user: &str) -> Vec { let mut out: Vec<_> = self .list .into_iter() @@ -140,9 +137,8 @@ pub struct AdminInfo { pub email: String, } -impl AdminInfo { - #[must_use] - pub fn into_messages(self, for_user: &str) -> Vec { +impl IntoProtocol for AdminInfo { + fn into_messages(self, for_user: &str) -> Vec { macro_rules! msg { ($response:ident, $($payload:expr),*) => { @@ -177,9 +173,9 @@ pub struct ListUsers { pub channels_formed: usize, } -impl ListUsers { +impl IntoProtocol for ListUsers { #[must_use] - pub fn into_messages(self, for_user: &str) -> Vec { + fn into_messages(self, for_user: &str) -> Vec { macro_rules! msg { ($response:ident, $($payload:expr),*) => { @@ -253,11 +249,15 @@ impl Motd { motd: server.config.motd.clone(), } } +} +impl IntoProtocol for Motd { #[must_use] - pub fn into_messages(self, for_user: String) -> Vec { + fn into_messages(self, for_user: &str) -> Vec { + let mut out = Vec::new(); + if let Some(motd) = self.motd { - let mut motd_messages = vec![Message { + out.push(Message { tags: None, prefix: Some(Prefix::ServerName(SERVER_NAME.to_string())), command: Command::Response( @@ -267,9 +267,9 @@ impl Motd { format!("- {SERVER_NAME} Message of the day -"), ], ), - }]; + }); - motd_messages.extend(motd.trim().split('\n').map(|v| Message { + out.extend(motd.trim().split('\n').map(|v| Message { tags: None, prefix: Some(Prefix::ServerName(SERVER_NAME.to_string())), command: Command::Response( @@ -278,26 +278,26 @@ impl Motd { ), })); - motd_messages.push(Message { + out.push(Message { tags: None, prefix: Some(Prefix::ServerName(SERVER_NAME.to_string())), command: Command::Response( Response::RPL_ENDOFMOTD, - vec![for_user, "End of /MOTD command.".to_string()], + vec![for_user.to_string(), "End of /MOTD command.".to_string()], ), }); - - motd_messages } else { - vec![Message { + out.push(Message { tags: None, prefix: Some(Prefix::ServerName(SERVER_NAME.to_string())), command: Command::Response( Response::ERR_NOMOTD, - vec![for_user, "MOTD File is missing".to_string()], + vec![for_user.to_string(), "MOTD File is missing".to_string()], ), - }] + }); } + + out } } @@ -306,9 +306,9 @@ pub struct ChannelList { pub members: Vec, } -impl ChannelList { +impl IntoProtocol for ChannelList { #[must_use] - pub fn into_messages(self, for_user: String) -> Vec { + fn into_messages(self, for_user: &str) -> Vec { let mut messages = Vec::with_capacity(self.members.len() + 2); messages.push(Message { @@ -345,7 +345,7 @@ impl ChannelList { prefix: Some(Prefix::ServerName(SERVER_NAME.to_string())), command: Command::Response( Response::RPL_LISTEND, - vec![for_user, "End of /LIST".to_string()], + vec![for_user.to_string(), "End of /LIST".to_string()], ), }); @@ -358,3 +358,27 @@ pub struct ChannelListItem { pub client_count: usize, pub topic: Option, } + +pub trait IntoProtocol { + #[must_use] + fn into_messages(self, for_user: &str) -> Vec; +} + +impl IntoProtocol for () { + fn into_messages(self, _for_user: &str) -> Vec { + vec![] + } +} + +impl IntoProtocol for Result +where + T: IntoProtocol, + E: IntoProtocol, +{ + fn into_messages(self, for_user: &str) -> Vec { + match self { + Ok(v) => v.into_messages(for_user), + Err(e) => e.into_messages(for_user), + } + } +}