diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 1652f28..c4a2c8b 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -1,7 +1,6 @@ extern crate proc_macro; use proc_macro2::TokenTree; -use quote::__private::ext::RepToTokensExt; use proc_macro::TokenStream; use syn::{parse_macro_input, DeriveInput, Data}; @@ -23,7 +22,7 @@ pub fn rmc_serialize(input: TokenStream) -> TokenStream { panic!("rmc struct type MUST be a struct"); }; - /// generate base data + // generate base data let serialize_base_content = { let mut serialize_content = quote! {}; @@ -87,7 +86,7 @@ pub fn rmc_serialize(input: TokenStream) -> TokenStream { } }; - /// generate base with extends stuff + // generate base with extends stuff let serialize_base_content = if let Some(attr) = struct_attr{ let tokens = attr.tokens.clone(); diff --git a/src/kerberos/mod.rs b/src/kerberos/mod.rs index dac31fb..cc48eaf 100644 --- a/src/kerberos/mod.rs +++ b/src/kerberos/mod.rs @@ -25,9 +25,9 @@ pub fn derive_key(pid: u32, password: [u8; 16]) -> [u8; 16]{ key } -#[derive(Pod, Zeroable, Copy, Clone)] +#[derive(Pod, Zeroable, Copy, Clone, Debug, Eq, PartialEq)] #[repr(transparent)] -pub struct KerberosDateTime(u64); +pub struct KerberosDateTime(pub u64); impl KerberosDateTime{ pub fn new(second: u64, minute: u64, hour: u64, day: u64, month: u64, year:u64 ) -> Self { diff --git a/src/main.rs b/src/main.rs index e3d3448..04423f0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,7 @@ #![allow(dead_code)] use std::{env, fs}; +use std::collections::BTreeMap; use std::fs::File; use std::net::{Ipv4Addr, SocketAddrV4}; use std::sync::Arc; @@ -11,7 +12,7 @@ use once_cell::sync::Lazy; use rc4::{KeyInit, Rc4, StreamCipher}; use rc4::consts::U5; use simplelog::{ColorChoice, CombinedLogger, Config, LevelFilter, TerminalMode, TermLogger, WriteLogger}; -use tokio::sync::Mutex; +use tokio::sync::{Mutex, RwLock}; use tokio::task::JoinHandle; use crate::nex::account::Account; use crate::protocols::{auth, block_if_maintenance}; @@ -151,7 +152,7 @@ async fn start_auth_server() -> AuthServer{ }), Box::new(move |packet, socket, connection|{ let rmcserver = rmcserver.clone(); - Box::pin(async move { rmcserver.process_message(packet, &socket, connection).await; }) + Box::pin(async move { rmcserver.process_message(packet, socket, connection).await; }) }) ).await.expect("unable to create socket"); @@ -177,13 +178,16 @@ async fn start_secure_server() -> SecureServer{ info!("setting up endpoints"); - let matchmake_data = Arc::new(Mutex::new( - MatchmakeData{} + let matchmake_data = Arc::new(RwLock::new( + MatchmakeData{ + matchmake_sessions: BTreeMap::new() + } )); let rmcserver = RMCProtocolServer::new(Box::new([ Box::new(block_if_maintenance), Box::new(protocols::secure::bound_protocol()), + Box::new(protocols::matchmake::bound_protocol(matchmake_data.clone())), Box::new(protocols::matchmake_extension::bound_protocol(matchmake_data)) ])); @@ -218,7 +222,7 @@ async fn start_secure_server() -> SecureServer{ }), Box::new(move |packet, socket, connection|{ let rmcserver = rmcserver.clone(); - Box::pin(async move { rmcserver.process_message(packet, &socket, connection).await; }) + Box::pin(async move { rmcserver.process_message(packet, socket, connection).await; }) }) ).await.expect("unable to create socket"); diff --git a/src/protocols/auth/method_login.rs b/src/protocols/auth/method_login.rs index afd578d..8e20a0c 100644 --- a/src/protocols/auth/method_login.rs +++ b/src/protocols/auth/method_login.rs @@ -1,8 +1,10 @@ use std::io::Cursor; +use std::sync::Arc; use log::error; +use tokio::sync::Mutex; use crate::nex::account::Account; use crate::protocols::auth::AuthProtocolConfig; -use crate::prudp::socket::ConnectionData; +use crate::prudp::socket::{ConnectionData, SocketData}; use crate::rmc::message::RMCMessage; use crate::rmc::response::{ErrorCode, RMCResponseResult}; use crate::rmc::structures::RmcSerialize; @@ -13,7 +15,7 @@ pub async fn login(rmcmessage: &RMCMessage, _name: &str) -> RMCResponseResult{ rmcmessage.error_result_with_code(ErrorCode::Core_NotImplemented) } -pub async fn login_raw_params(rmcmessage: &RMCMessage, _: &mut ConnectionData, data: AuthProtocolConfig) -> RMCResponseResult{ +pub async fn login_raw_params(rmcmessage: &RMCMessage, _: &Arc, _: &Arc>, data: AuthProtocolConfig) -> RMCResponseResult{ let mut reader = Cursor::new(&rmcmessage.rest_of_data); let Ok(str) = String::deserialize(&mut reader) else { diff --git a/src/protocols/auth/method_login_ex.rs b/src/protocols/auth/method_login_ex.rs index 8dfbafa..61ed0c7 100644 --- a/src/protocols/auth/method_login_ex.rs +++ b/src/protocols/auth/method_login_ex.rs @@ -1,13 +1,15 @@ use std::io::{Cursor, Write}; +use std::sync::Arc; use bytemuck::bytes_of; use hex::encode; use log::{error}; +use tokio::sync::Mutex; use crate::grpc::account; use crate::kerberos::KerberosDateTime; use crate::nex::account::Account; use crate::protocols::auth::AuthProtocolConfig; use crate::protocols::auth::ticket_generation::generate_ticket; -use crate::prudp::socket::ConnectionData; +use crate::prudp::socket::{ConnectionData, SocketData}; use crate::rmc; use crate::rmc::message::RMCMessage; use crate::rmc::response::{ErrorCode, RMCResponseResult}; @@ -51,7 +53,7 @@ pub async fn login_ex(rmcmessage: &RMCMessage, proto_data: AuthProtocolConfig, p return rmcmessage.success_with_data(response); } -pub async fn login_ex_raw_params(rmcmessage: &RMCMessage, _: &mut ConnectionData, data: AuthProtocolConfig) -> RMCResponseResult{ +pub async fn login_ex_raw_params(rmcmessage: &RMCMessage, _: &Arc, _: &Arc>, data: AuthProtocolConfig) -> RMCResponseResult{ let mut reader = Cursor::new(&rmcmessage.rest_of_data); let Ok(str) = String::deserialize(&mut reader) else { diff --git a/src/protocols/auth/method_request_ticket.rs b/src/protocols/auth/method_request_ticket.rs index c8113cb..6a04041 100644 --- a/src/protocols/auth/method_request_ticket.rs +++ b/src/protocols/auth/method_request_ticket.rs @@ -1,11 +1,13 @@ use std::io::Cursor; +use std::sync::Arc; use log::error; +use tokio::sync::Mutex; use crate::endianness::{IS_BIG_ENDIAN, ReadExtensions}; use crate::grpc::account; use crate::protocols::auth::{AuthProtocolConfig, get_login_data_by_pid}; use crate::protocols::auth::method_login_ex::login_ex; use crate::protocols::auth::ticket_generation::generate_ticket; -use crate::prudp::socket::ConnectionData; +use crate::prudp::socket::{ConnectionData, SocketData}; use crate::rmc::message::RMCMessage; use crate::rmc::response::{ErrorCode, RMCResponseResult}; use crate::rmc::response::ErrorCode::Core_Unknown; @@ -39,7 +41,7 @@ pub async fn request_ticket(rmcmessage: &RMCMessage, data: AuthProtocolConfig, s rmcmessage.success_with_data(response) } -pub async fn request_ticket_raw_params(rmcmessage: &RMCMessage, _: &mut ConnectionData, data: AuthProtocolConfig) -> RMCResponseResult{ +pub async fn request_ticket_raw_params(rmcmessage: &RMCMessage, _: &Arc, _: &Arc>, data: AuthProtocolConfig) -> RMCResponseResult{ let mut reader = Cursor::new(&rmcmessage.rest_of_data); let Ok(source_pid) = reader.read_struct(IS_BIG_ENDIAN) else { diff --git a/src/protocols/matchmake/method_unregister_gathering.rs b/src/protocols/matchmake/method_unregister_gathering.rs new file mode 100644 index 0000000..28f9c5b --- /dev/null +++ b/src/protocols/matchmake/method_unregister_gathering.rs @@ -0,0 +1,37 @@ +use std::io::Cursor; +use std::sync::Arc; +use log::info; +use tokio::sync::{Mutex, RwLock}; +use crate::protocols::matchmake_common::MatchmakeData; +use crate::prudp::socket::{ConnectionData, SocketData}; +use crate::rmc::message::RMCMessage; +use crate::rmc::response::{ErrorCode, RMCResponseResult}; +use crate::rmc::structures::qresult::QResult; +use crate::rmc::structures::RmcSerialize; + +pub async fn unregister_gathering(rmcmessage: &RMCMessage, gid: u32, data: Arc>) -> RMCResponseResult{ + let mut rd = data.write().await; + + rd.matchmake_sessions.remove(&gid); + + let result = QResult::success(ErrorCode::Core_Unknown); + + let mut response = Vec::new(); + + result.serialize(&mut response).expect("aaa"); + + rmcmessage.success_with_data(response) +} + +pub async fn unregister_gathering_raw_params(rmcmessage: &RMCMessage, _: &Arc, _: &Arc>, data: Arc>) -> RMCResponseResult{ + let mut reader = Cursor::new(&rmcmessage.rest_of_data); + + let Ok(gid) = u32::deserialize(&mut reader) else { + return rmcmessage.error_result_with_code(ErrorCode::Core_InvalidArgument); + }; + + + + + unregister_gathering(rmcmessage, gid, data).await +} \ No newline at end of file diff --git a/src/protocols/matchmake/mod.rs b/src/protocols/matchmake/mod.rs new file mode 100644 index 0000000..cc85e23 --- /dev/null +++ b/src/protocols/matchmake/mod.rs @@ -0,0 +1,13 @@ +mod method_unregister_gathering; + +use std::sync::Arc; +use tokio::sync::RwLock; +use crate::define_protocol; +use crate::protocols::matchmake::method_unregister_gathering::unregister_gathering_raw_params; +use crate::protocols::matchmake_common::MatchmakeData; + +define_protocol!{ + 21(matchmake_data: Arc>) => { + 2 => unregister_gathering_raw_params + } +} \ No newline at end of file diff --git a/src/protocols/matchmake_common/mod.rs b/src/protocols/matchmake_common/mod.rs index ed88475..ce725e2 100644 --- a/src/protocols/matchmake_common/mod.rs +++ b/src/protocols/matchmake_common/mod.rs @@ -1,4 +1,63 @@ -pub struct MatchmakeData{ +use std::collections::{BTreeMap, HashMap}; +use std::sync::Arc; +use log::error; +use tokio::sync::Mutex; +use crate::protocols::notification::Notification; +use crate::prudp::socket::{ConnectionData, SocketData}; +use crate::rmc::structures::matchmake::MatchmakeSession; +pub struct ExtendedMatchmakeSession{ + pub session: MatchmakeSession, + pub connected_players: Vec>>, } +pub struct MatchmakeData{ + pub(crate) matchmake_sessions: BTreeMap>> +} + +impl ExtendedMatchmakeSession{ + pub async fn add_player(&mut self, socket: &SocketData, conn: Arc>, join_msg: String) { + let Some(pid) = conn.lock().await.active_connection_data.as_ref() + .map(|c| + c.active_secure_connection_data.as_ref() + .map(|c| c.pid + ) + ).flatten() else { + error!("tried to add player without secure connection"); + return + }; + + self.connected_players.push(conn); + + + for conn in &self.connected_players{ + let Some(other_pid) = conn.lock().await.active_connection_data.as_ref() + .map(|c| + c.active_secure_connection_data.as_ref() + .map(|c| c.pid + ) + ).flatten() else { + error!("tried to send connection notification to player secure connection"); + return + }; + + let mut conn = conn.lock().await; + + conn.send_notification(socket, Notification{ + pid_source: pid, + notif_type: 3001, + param_1: self.session.gathering.self_gid, + param_2: other_pid, + str_param: join_msg.clone(), + }).await; + + + } + } +} + +impl MatchmakeData { + pub async fn try_find_session_with_criteria(&self, ) -> Option>>{ + None + } +} \ No newline at end of file diff --git a/src/protocols/matchmake_extension/method_auto_matchmake_with_param_postpone.rs b/src/protocols/matchmake_extension/method_auto_matchmake_with_param_postpone.rs index fd90c5c..47da00b 100644 --- a/src/protocols/matchmake_extension/method_auto_matchmake_with_param_postpone.rs +++ b/src/protocols/matchmake_extension/method_auto_matchmake_with_param_postpone.rs @@ -1,9 +1,89 @@ +use std::io::Cursor; use std::sync::Arc; -use tokio::sync::Mutex; -use crate::protocols::matchmake_common::MatchmakeData; -use crate::prudp::socket::ConnectionData; +use rand::random; +use tokio::sync::{Mutex, RwLock}; +use crate::protocols::matchmake_common::{ExtendedMatchmakeSession, MatchmakeData}; +use crate::prudp::socket::{ConnectionData, SocketData}; use crate::rmc::message::RMCMessage; +use crate::rmc::response::{ErrorCode, RMCResponseResult}; +use crate::rmc::structures::matchmake::{AutoMatchmakeParam, MatchmakeSession}; +use crate::rmc::structures::RmcSerialize; -pub async fn auto_matchmake_with_param_postpone_raw_params(rmcmessage: &RMCMessage, _: &mut ConnectionData, data: Arc>){ + +pub async fn auto_matchmake_with_param_postpone( + rmcmessage: &RMCMessage, + conn: &Arc>, + socket: &Arc, + mm_data: Arc>, + auto_matchmake_param: AutoMatchmakeParam +) -> RMCResponseResult{ + println!("auto_matchmake_with_param_postpone: {:?}", auto_matchmake_param); + let locked_conn = conn.lock().await; + let Some(secure_conn) = + locked_conn.active_connection_data.as_ref().map(|a| a.active_secure_connection_data.as_ref()).flatten() else { + return rmcmessage.error_result_with_code(ErrorCode::Core_Exception); + }; + + let pid = secure_conn.pid; + + drop(locked_conn); + + let mm_data_read = mm_data.read().await; + //todo: there is a bit of a race condition here, i dont have any idea on how to fix it though... + let session = if let Some(session) = mm_data_read.try_find_session_with_criteria().await{ + session + } else { + // drop it first so that we dont cause a deadlock, also drop it right here so we dont hold + // up anything else unnescesarily + drop(mm_data_read); + + let gid = random(); + + let mut matchmake_session = auto_matchmake_param.matchmake_session.clone(); + matchmake_session.gathering.self_gid = gid; + matchmake_session.gathering.host_pid = pid; + matchmake_session.gathering.owner_pid = pid; + + + + let mut mm_data = mm_data.write().await; + + let session = Arc::new(Mutex::new(ExtendedMatchmakeSession{ + session: matchmake_session.clone(), + connected_players: Vec::new() + })); + + mm_data.matchmake_sessions.insert(gid, session.clone()); + + session + }; + + let mut session = session.lock().await; + + //todo: refactor so that this works + session.add_player(socket, conn.clone(), auto_matchmake_param.join_message).await; + + let mut response = Vec::new(); + + session.session.serialize(&mut response).expect("unable to serialize matchmake session"); + + rmcmessage.success_with_data(response) +} + +pub async fn auto_matchmake_with_param_postpone_raw_params( + rmcmessage: &RMCMessage, + socket: &Arc, + connection_data: &Arc>, + data: Arc> +) -> RMCResponseResult{ + let mut reader = Cursor::new(&rmcmessage.rest_of_data); + + let Ok(matchmake_param) = AutoMatchmakeParam::deserialize(&mut reader) else { + return rmcmessage.error_result_with_code(ErrorCode::Core_InvalidArgument); + }; + + + + auto_matchmake_with_param_postpone(rmcmessage, connection_data, socket, data, matchmake_param).await } \ No newline at end of file diff --git a/src/protocols/matchmake_extension/method_get_playing_session.rs b/src/protocols/matchmake_extension/method_get_playing_session.rs index c82f730..ff784d9 100644 --- a/src/protocols/matchmake_extension/method_get_playing_session.rs +++ b/src/protocols/matchmake_extension/method_get_playing_session.rs @@ -1,15 +1,16 @@ use std::io::Cursor; use std::sync::Arc; -use tokio::sync::Mutex; +use log::info; +use tokio::sync::{Mutex, RwLock}; use crate::protocols::matchmake_common::MatchmakeData; -use crate::prudp::socket::ConnectionData; +use crate::prudp::socket::{ConnectionData, SocketData}; use crate::rmc::message::RMCMessage; use crate::rmc::response::{ErrorCode, RMCResponseResult}; use crate::rmc::structures::RmcSerialize; type PIDList = Vec; -async fn get_playing_session(rmcmessage: &RMCMessage, data: Arc>) -> RMCResponseResult { +async fn get_playing_session(rmcmessage: &RMCMessage, data: Arc>) -> RMCResponseResult { //todo: propperly implement this let cheeseburger = PIDList::new(); @@ -21,12 +22,14 @@ async fn get_playing_session(rmcmessage: &RMCMessage, data: Arc>) -> RMCResponseResult{ +pub async fn get_playing_session_raw_params(rmcmessage: &RMCMessage, _: &Arc, _: &Arc>, data: Arc>) -> RMCResponseResult{ let mut reader = Cursor::new(&rmcmessage.rest_of_data); let Ok(list) = PIDList::deserialize(&mut reader) else { return rmcmessage.error_result_with_code(ErrorCode::FPD_FriendNotExists); }; + info!("get_playing_session got called with {:?}", list); + get_playing_session(rmcmessage, data).await } \ No newline at end of file diff --git a/src/protocols/matchmake_extension/mod.rs b/src/protocols/matchmake_extension/mod.rs index 47dbc31..080db7b 100644 --- a/src/protocols/matchmake_extension/mod.rs +++ b/src/protocols/matchmake_extension/mod.rs @@ -2,13 +2,15 @@ mod method_get_playing_session; mod method_auto_matchmake_with_param_postpone; use std::sync::Arc; -use tokio::sync::Mutex; +use tokio::sync::{Mutex, RwLock}; use crate::define_protocol; use crate::protocols::matchmake_common::MatchmakeData; use method_get_playing_session::get_playing_session_raw_params; +use method_auto_matchmake_with_param_postpone::auto_matchmake_with_param_postpone_raw_params; define_protocol!{ - 109(matchmake_data: Arc>) => { - 16 => get_playing_session_raw_params + 109(matchmake_data: Arc>) => { + 16 => get_playing_session_raw_params, + 40 => auto_matchmake_with_param_postpone_raw_params } } \ No newline at end of file diff --git a/src/protocols/mod.rs b/src/protocols/mod.rs index 12397aa..9f0f188 100644 --- a/src/protocols/mod.rs +++ b/src/protocols/mod.rs @@ -1,10 +1,12 @@ use std::env; use std::future::Future; use std::pin::Pin; +use std::sync::Arc; use log::warn; use once_cell::sync::Lazy; +use tokio::sync::Mutex; use crate::grpc; -use crate::prudp::socket::ConnectionData; +use crate::prudp::socket::{ConnectionData, SocketData}; use crate::rmc::message::RMCMessage; use crate::rmc::response::{ErrorCode, RMCResponse}; @@ -14,7 +16,8 @@ pub mod server; pub mod secure; pub mod matchmake_extension; pub mod matchmake_common; - +pub mod matchmake; +mod notification; static IS_MAINTENANCE: Lazy = Lazy::new(|| { env::var("IS_MAINTENANCE") @@ -30,8 +33,10 @@ static BYPASS_LEVEL: Lazy = Lazy::new(|| { }); -pub fn block_if_maintenance<'a>(rmcmessage: &'a RMCMessage, conn: &'a mut ConnectionData) -> Pin> + Send + 'a)>> { +pub fn block_if_maintenance<'a>(rmcmessage: &'a RMCMessage, _: &'a Arc , conn: &'a Arc>) -> Pin> + Send + 'a)>> { Box::pin(async move { + let mut conn = conn.lock().await; + if let Some(active_conn) = conn.active_connection_data.as_ref() { if let Some(secure_conn) = active_conn.active_secure_connection_data.as_ref() { if let Ok(mut client) = grpc::account::Client::new().await { @@ -62,7 +67,7 @@ pub fn block_if_maintenance<'a>(rmcmessage: &'a RMCMessage, conn: &'a mut Connec macro_rules! define_protocol { ($id:literal ($($varname:ident : $ty:ty),*) => {$($func_id:literal => $func:path),*} ) => { #[allow(unused_parens)] - async fn protocol (rmcmessage: &crate::RMCMessage, connection: &mut crate::protocols::ConnectionData, $($varname : $ty),*) -> Option{ + async fn protocol (rmcmessage: &crate::RMCMessage, socket: &::std::sync::Arc, connection: &::std::sync::Arc<::tokio::sync::Mutex>, $($varname : $ty),*) -> Option{ if rmcmessage.protocol_id != $id{ return None; } @@ -71,7 +76,7 @@ macro_rules! define_protocol { let response_result = match rmcmessage.method_id{ $( - $func_id => $func ( rmcmessage, connection, self_data).await, + $func_id => $func ( rmcmessage, socket, connection, self_data).await, )* _ => { log::error!("invalid method id sent to protocol {}: {:?}", $id, rmcmessage.method_id); @@ -90,10 +95,10 @@ macro_rules! define_protocol { }) } #[allow(unused_parens)] - pub fn bound_protocol($($varname : $ty,)*) -> Box Fn(&'message_lifetime crate::RMCMessage, &'message_lifetime mut crate::protocols::ConnectionData) + pub fn bound_protocol($($varname : $ty,)*) -> Box Fn(&'message_lifetime crate::RMCMessage, &'message_lifetime ::std::sync::Arc, &'message_lifetime ::std::sync::Arc<::tokio::sync::Mutex>) -> ::std::pin::Pin> + Send + 'message_lifetime>> + Send + Sync>{ Box::new( - move |v, cd| { + move |v, s, cd| { Box::pin({ $( let $varname = $varname.clone(); @@ -103,7 +108,7 @@ macro_rules! define_protocol { $( let $varname = $varname.clone(); )* - protocol(v, cd, $($varname,)*).await + protocol(v, s, cd, $($varname,)*).await } }) } diff --git a/src/protocols/notification/mod.rs b/src/protocols/notification/mod.rs new file mode 100644 index 0000000..7c96090 --- /dev/null +++ b/src/protocols/notification/mod.rs @@ -0,0 +1,48 @@ +use macros::RmcSerialize; +use rand::random; +use crate::prudp::packet::{PRUDPHeader, PRUDPPacket, TypesFlags}; +use crate::prudp::packet::flags::{NEED_ACK, RELIABLE}; +use crate::prudp::packet::types::DATA; +use crate::prudp::socket::{ConnectionData, SocketData}; +use crate::rmc::message::RMCMessage; +use crate::rmc::structures::RmcSerialize; + +#[derive(RmcSerialize)] +#[rmc_struct(0)] +pub struct Notification{ + pub pid_source: u32, + pub notif_type: u32, + pub param_1: u32, + pub param_2: u32, + pub str_param: String, +} + +impl ConnectionData{ + pub async fn send_notification(&mut self, socket: &SocketData, notif: Notification){ + + let mut data = Vec::new(); + + notif.serialize(&mut data).expect("unable to write"); + + let message = RMCMessage{ + protocol_id: 0xE, + method_id: 1, + call_id: random(), + rest_of_data: data + }; + + let prudp_packet = PRUDPPacket{ + header: PRUDPHeader{ + types_and_flags: TypesFlags::default().types(DATA).flags(NEED_ACK | RELIABLE), + source_port: socket.get_virual_port(), + destination_port: self.sock_addr.virtual_port, + ..Default::default() + }, + options: Vec::new(), + payload: message.to_data(), + packet_signature: [0;16] + }; + + self.finish_and_send_packet_to(socket, prudp_packet).await; + } +} \ No newline at end of file diff --git a/src/protocols/secure/method_register.rs b/src/protocols/secure/method_register.rs index e50439d..eb99421 100644 --- a/src/protocols/secure/method_register.rs +++ b/src/protocols/secure/method_register.rs @@ -1,8 +1,10 @@ use std::io::{Cursor, Write}; +use std::sync::Arc; use bytemuck::bytes_of; use log::{error, warn}; +use tokio::sync::Mutex; use crate::protocols::auth::AuthProtocolConfig; -use crate::prudp::socket::ConnectionData; +use crate::prudp::socket::{ConnectionData, SocketData}; use crate::prudp::station_url::{nat_types, StationUrl}; use crate::prudp::station_url::Type::PRUDPS; use crate::prudp::station_url::UrlOptions::{Address, NatFiltering, NatMapping, NatType, Port, PrincipalID, RVConnectionID}; @@ -14,9 +16,9 @@ use crate::rmc::structures::RmcSerialize; type StringList = Vec; -pub async fn register(rmcmessage: &RMCMessage, station_urls: Vec, conn_data: &mut ConnectionData) -> RMCResponseResult{ - - let Some(active_connection_data) = conn_data.active_connection_data.as_ref() else { +pub async fn register(rmcmessage: &RMCMessage, station_urls: Vec, conn_data: &Arc>) -> RMCResponseResult{ + let locked = conn_data.lock().await; + let Some(active_connection_data) = locked.active_connection_data.as_ref() else { return rmcmessage.error_result_with_code(ErrorCode::RendezVous_NotAuthenticated) }; @@ -28,8 +30,8 @@ pub async fn register(rmcmessage: &RMCMessage, station_urls: Vec, co url_type: PRUDPS, options: vec![ RVConnectionID(active_connection_data.connection_id), - Address(*conn_data.sock_addr.regular_socket_addr.ip()), - Port(conn_data.sock_addr.regular_socket_addr.port()), + Address(*locked.sock_addr.regular_socket_addr.ip()), + Port(locked.sock_addr.regular_socket_addr.port()), NatFiltering(0), NatMapping(0), NatType(nat_types::BEHIND_NAT), @@ -50,7 +52,7 @@ pub async fn register(rmcmessage: &RMCMessage, station_urls: Vec, co rmcmessage.success_with_data(response) } -pub async fn register_raw_params(rmcmessage: &RMCMessage, conn_data: &mut ConnectionData, _: ()) -> RMCResponseResult{ +pub async fn register_raw_params(rmcmessage: &RMCMessage, _: &Arc, conn_data: &Arc>, _: ()) -> RMCResponseResult{ let mut reader = Cursor::new(&rmcmessage.rest_of_data); let Ok(station_urls) = StringList::deserialize(&mut reader) else { diff --git a/src/protocols/secure/method_send_report.rs b/src/protocols/secure/method_send_report.rs index 0104855..6b8b44f 100644 --- a/src/protocols/secure/method_send_report.rs +++ b/src/protocols/secure/method_send_report.rs @@ -1,8 +1,10 @@ use std::io::Cursor; +use std::sync::Arc; use log::error; +use tokio::sync::Mutex; use crate::endianness::{IS_BIG_ENDIAN, ReadExtensions}; use crate::protocols::secure::method_register::register; -use crate::prudp::socket::ConnectionData; +use crate::prudp::socket::{ConnectionData, SocketData}; use crate::prudp::station_url::StationUrl; use crate::rmc::message::RMCMessage; use crate::rmc::response::{ErrorCode, RMCResponseResult}; @@ -20,7 +22,7 @@ pub async fn send_report(rmcmessage: &RMCMessage, report_id: u32, data: Vec) return rmcmessage.success_with_data(Vec::new()); } -pub async fn send_report_raw_params(rmcmessage: &RMCMessage, conn_data: &mut ConnectionData, _: ()) -> RMCResponseResult{ +pub async fn send_report_raw_params(rmcmessage: &RMCMessage, _: &Arc, conn_data: &Arc>, _: ()) -> RMCResponseResult{ let mut reader = Cursor::new(&rmcmessage.rest_of_data); let Ok(error_id) = reader.read_struct(IS_BIG_ENDIAN) else { diff --git a/src/protocols/server.rs b/src/protocols/server.rs index 4aabf81..9927937 100644 --- a/src/protocols/server.rs +++ b/src/protocols/server.rs @@ -3,13 +3,14 @@ use std::io::Cursor; use std::pin::Pin; use std::sync::Arc; use log::error; +use tokio::sync::Mutex; use crate::prudp::packet::PRUDPPacket; use crate::prudp::socket::{ConnectionData, SocketData}; use crate::rmc::message::RMCMessage; use crate::rmc::response::{RMCResponse, RMCResponseResult, send_response}; use crate::rmc::response::ErrorCode::Core_NotImplemented; -type ContainedProtocolList = Box<[Box Fn(&'a RMCMessage, &'a mut ConnectionData) -> Pin> + Send + 'a>> + Send + Sync>]>; +type ContainedProtocolList = Box<[Box Fn(&'a RMCMessage, &'a Arc, &'a Arc>) -> Pin> + Send + 'a>> + Send + Sync>]>; pub struct RMCProtocolServer(ContainedProtocolList); @@ -18,27 +19,33 @@ impl RMCProtocolServer{ Arc::new(Self(protocols)) } - pub async fn process_message(&self, packet: PRUDPPacket, socket: &SocketData, connection: &mut ConnectionData){ + pub async fn process_message(&self, packet: PRUDPPacket, socket: Arc, connection: Arc>){ let Ok(rmc) = RMCMessage::new(&mut Cursor::new(&packet.payload)) else { error!("error reading rmc message"); return; }; + println!("got rmc message {},{}", rmc.protocol_id, rmc.method_id); + for proto in &self.0 { - if let Some(response) = proto(&rmc, connection).await { - send_response(&packet, &socket, connection, response).await; + if let Some(response) = proto(&rmc, &socket, &connection).await { + + let mut locked = connection.lock().await; + send_response(&packet, &socket, &mut locked, response).await; + drop(locked); return; } } error!("tried to send message to unimplemented protocol {} with method id {}", rmc.protocol_id, rmc.method_id); - - send_response(&packet, &socket, connection, RMCResponse{ + let mut locked = connection.lock().await; + send_response(&packet, &socket, &mut locked, RMCResponse{ protocol_id: rmc.protocol_id as u8, response_result: RMCResponseResult::Error { call_id: rmc.call_id, error_code: Core_NotImplemented } }).await; + } } \ No newline at end of file diff --git a/src/prudp/socket.rs b/src/prudp/socket.rs index 27c88c9..f32528b 100644 --- a/src/prudp/socket.rs +++ b/src/prudp/socket.rs @@ -27,7 +27,7 @@ pub struct Socket { type OnConnectHandlerFn = Box Pin, Vec, Option)>> + Send>> + Send + Sync>; -type OnDataHandlerFn = Box Fn(PRUDPPacket, Arc, &'a mut MutexGuard<'_, ConnectionData>) -> Pin + 'a + Send>> + Send + Sync>; +type OnDataHandlerFn = Box, Arc>) -> Pin + Send>> + Send + Sync>; pub struct ActiveSecureConnectionData { pub(crate) pid: u32, @@ -38,7 +38,7 @@ pub struct SocketData { virtual_port: VirtualPort, pub socket: Arc, pub access_key: &'static str, - connections: RwLock>>>, + connections: RwLock>, Arc>)>>, on_connect_handler: OnConnectHandlerFn, on_data_handler: OnDataHandlerFn, @@ -69,6 +69,8 @@ pub struct ConnectionData { } + + impl Socket { pub async fn new( router: Arc, @@ -146,14 +148,14 @@ impl SocketData { let mut conn = self.connections.write().await; //only insert if we STILL dont have the connection preventing double insertion if !conn.contains_key(&client_address) { - conn.insert(client_address, Arc::new(Mutex::new(ConnectionData { + conn.insert(client_address, (Arc::new(Mutex::new(ConnectionData { sock_addr: client_address, id: random(), signature: [0; 16], server_signature: [0; 16], active_connection_data: None, - }))); + })), Arc::new(Mutex::new(())))); } drop(conn); } else { @@ -172,20 +174,24 @@ impl SocketData { // dont keep holding the connections list unnescesarily drop(connections); - let mut connection = conn.lock().await; + let mut connection = conn.0.lock().await; + //let _mutual_exclusion_packet_handeling_mtx = conn.1.lock().await; if (packet.header.types_and_flags.get_flags() & ACK) != 0 { //todo: handle acknowledgements and resending packets propperly + println!("got ack"); return; } if (packet.header.types_and_flags.get_flags() & MULTI_ACK) != 0 { + println!("got ack"); return; } match packet.header.types_and_flags.get_types() { SYN => { + println!("got syn"); // reset heartbeat? let mut response_packet = packet.base_response_packet(); @@ -220,6 +226,7 @@ impl SocketData { self.socket.send_to(&vec, client_address.regular_socket_addr).await.expect("failed to send data back"); } CONNECT => { + println!("got connect"); let Some(MaximumSubstreamId(max_substream)) = packet.options.iter().find(|v| matches!(v, MaximumSubstreamId(_))) else { return; }; @@ -328,13 +335,18 @@ impl SocketData { self.socket.send_to(&vec, client_address.regular_socket_addr).await.expect("failed to send data back"); } - + drop(connection); while let Some(mut packet) = { - connection.active_connection_data.as_mut().map(|a| + let mut locked = conn.0.lock().await; + + let packet = locked.active_connection_data.as_mut().map(|a| a.reliable_client_queue .front() .is_some_and(|v| v.header.sequence_id == a.reliable_client_counter) - .then(|| a.reliable_client_queue.pop_front())).flatten().flatten() + .then(|| a.reliable_client_queue.pop_front())).flatten().flatten(); + + drop(locked); + packet } { if packet.options.iter().any(|v| match v{ PacketOption::FragmentId(f) => *f != 0, @@ -343,7 +355,9 @@ impl SocketData { error!("fragmented packets are unsupported right now") } - let active_connection = connection.active_connection_data.as_mut() + let mut locked = conn.0.lock().await; + + let active_connection = locked.active_connection_data.as_mut() .expect("we litterally just recieved a packet which requires the connection to be active, failing this should be impossible"); active_connection.reliable_client_counter = active_connection.reliable_client_counter.overflowing_add(1).0; @@ -354,9 +368,10 @@ impl SocketData { stream.apply_keystream(&mut packet.payload); + drop(locked); // we cant divert this off to another thread we HAVE to process it now to keep order - (self.on_data_handler)(packet, self.clone(), &mut connection).await; + (self.on_data_handler)(packet, self.clone(), conn.0.clone()).await; // ignored for now } } else { @@ -397,7 +412,7 @@ impl SocketData { } } DISCONNECT => { - + println!("got disconnect"); let Some(active_connection) = &connection.active_connection_data else { return; }; @@ -466,6 +481,8 @@ impl ConnectionData{ error!("unable to send packet to destination: {}", e); } } + + } #[cfg(test)] diff --git a/src/rmc/message.rs b/src/rmc/message.rs index 7607672..68598d2 100644 --- a/src/rmc/message.rs +++ b/src/rmc/message.rs @@ -1,10 +1,11 @@ use std::io; -use std::io::{Read, Seek}; +use std::io::{Read, Seek, Write}; +use bytemuck::bytes_of; use log::error; use crate::endianness::{IS_BIG_ENDIAN, ReadExtensions}; use crate::rmc::response::{ErrorCode, RMCResponseResult}; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct RMCMessage{ pub protocol_id: u16, pub call_id: u32, @@ -52,6 +53,23 @@ impl RMCMessage{ }) } + pub fn to_data(&self) -> Vec{ + let size = (1 + 4 + 4 + self.rest_of_data.len()) as u32; + + let mut output = Vec::new(); + + output.write_all(bytes_of(&size)).expect("unable to write size"); + + let proto_id = self.protocol_id as u8; + + output.write_all(bytes_of(&proto_id)).expect("unable to write size"); + + output.write_all(bytes_of(&self.call_id)).expect("unable to write size"); + output.write_all(bytes_of(&self.method_id)).expect("unable to write size"); + + output + } + pub fn error_result_with_code(&self, error_code: ErrorCode) -> RMCResponseResult{ RMCResponseResult::Error { call_id: self.call_id, diff --git a/src/rmc/response.rs b/src/rmc/response.rs index cc82b17..96b82f3 100644 --- a/src/rmc/response.rs +++ b/src/rmc/response.rs @@ -1,6 +1,7 @@ use std::io; use std::io::{Write}; use std::mem::transmute; +use std::time::Duration; use bytemuck::bytes_of; use crate::prudp::packet::{PRUDPPacket}; use crate::prudp::packet::flags::{NEED_ACK, RELIABLE}; @@ -100,6 +101,8 @@ pub async fn send_response(original_packet: &PRUDPPacket, socket: &SocketData, c packet.payload = rmcresponse.to_data(); + //tokio::time::sleep(Duration::from_millis(500)).await; + connection.finish_and_send_packet_to(socket, packet).await; } diff --git a/src/rmc/structures/list.rs b/src/rmc/structures/list.rs index 7b17f63..7f6c97c 100644 --- a/src/rmc/structures/list.rs +++ b/src/rmc/structures/list.rs @@ -6,7 +6,7 @@ use crate::rmc::structures::RmcSerialize; impl RmcSerialize for Vec{ fn serialize(&self, writer: &mut dyn Write) -> crate::rmc::structures::Result<()> { - let u32_len = self.len(); + let u32_len = self.len() as u32; writer.write_all(bytes_of(&u32_len))?; for e in self{ diff --git a/src/rmc/structures/matchmake.rs b/src/rmc/structures/matchmake.rs index a63107b..5406512 100644 --- a/src/rmc/structures/matchmake.rs +++ b/src/rmc/structures/matchmake.rs @@ -4,51 +4,82 @@ use crate::rmc::structures::RmcSerialize; use crate::rmc::structures::variant::Variant; // rmc structure -#[derive(RmcSerialize)] +#[derive(RmcSerialize, Debug, Clone)] #[rmc_struct(0)] -struct Gathering{ - self_gid: u32, - owner_pid: u32, - host_pid: u32, - minimum_participants: u16, - maximum_participants: u16, - participant_policy: u32, - policy_argument: u32, - flags: u32, - state: u32, - description: String +pub struct Gathering { + pub self_gid: u32, + pub owner_pid: u32, + pub host_pid: u32, + pub minimum_participants: u16, + pub maximum_participants: u16, + pub participant_policy: u32, + pub policy_argument: u32, + pub flags: u32, + pub state: u32, + pub description: String, } // rmc structure -#[derive(RmcSerialize)] +#[derive(RmcSerialize, Debug, Clone)] #[rmc_struct(0)] -struct MatchmakeParam{ - params: Vec<(String, Variant)> +pub struct MatchmakeParam { + pub params: Vec<(String, Variant)>, } // rmc structure -#[derive(RmcSerialize)] +#[derive(RmcSerialize, Debug, Clone)] #[rmc_struct(3)] -struct MatchmakeSession{ +pub struct MatchmakeSession { //inherits from #[extends] - gathering: Gathering, + pub gathering: Gathering, - gamemode: u32, - attributes: Vec, - open_participation: bool, - matchmake_system_type: u32, - application_buffer: Vec, - participation_count: u32, - progress_score: u8, - session_key: Vec, - option0: u32, - matchmake_param: MatchmakeParam, - datetime: KerberosDateTime, - user_password: String, - refer_gid: u32, - user_password_enabled: bool, - system_password_enabled: bool + pub gamemode: u32, + pub attributes: Vec, + pub open_participation: bool, + pub matchmake_system_type: u32, + pub application_buffer: Vec, + pub participation_count: u32, + pub progress_score: u8, + pub session_key: Vec, + pub option0: u32, + pub matchmake_param: MatchmakeParam, + pub datetime: KerberosDateTime, + pub user_password: String, + pub refer_gid: u32, + pub user_password_enabled: bool, + pub system_password_enabled: bool, } +#[derive(RmcSerialize, Debug, Clone)] +#[rmc_struct(3)] +pub struct MatchmakeSessionSearchCriteria { + pub attribs: Vec, + pub game_mode: String, + pub minimum_participants: String, + pub maximum_participants: String, + pub matchmake_system_type: String, + pub vacant_only: bool, + pub exclude_locked: bool, + pub exclude_non_host_pid: bool, + pub selection_method: u32, + pub vacant_participants: u16, + pub matchmake_param: MatchmakeParam, + pub exclude_user_password_set: bool, + pub exclude_system_password_set: bool, + pub refer_gid: u32, +} + +#[derive(RmcSerialize, Debug, Clone)] +#[rmc_struct(0)] +pub struct AutoMatchmakeParam { + pub matchmake_session: MatchmakeSession, + pub additional_participants: Vec, + pub gid_for_participation_check: u32, + pub auto_matchmake_option: u32, + pub join_message: String, + pub participation_count: u16, + pub search_criteria: Vec, + pub target_gids: Vec, +} \ No newline at end of file diff --git a/src/rmc/structures/string.rs b/src/rmc/structures/string.rs index afaeb03..f2c63f1 100644 --- a/src/rmc/structures/string.rs +++ b/src/rmc/structures/string.rs @@ -27,7 +27,7 @@ impl RmcSerialize for &str{ panic!("cannot serialize to &str") } fn serialize(&self, writer: &mut dyn Write) -> Result<()> { - let u16_len: u16 = self.len() as u16; + let u16_len: u16 = (self.len() + 1) as u16; writer.write(bytes_of(&u16_len))?; writer.write(self.as_bytes())?; diff --git a/src/rmc/structures/variant.rs b/src/rmc/structures/variant.rs index 04e9981..7472b03 100644 --- a/src/rmc/structures/variant.rs +++ b/src/rmc/structures/variant.rs @@ -3,6 +3,7 @@ use crate::kerberos::KerberosDateTime; use crate::rmc::structures; use crate::rmc::structures::RmcSerialize; +#[derive(Debug, Clone)] pub enum Variant{ None, SInt64(i64),