From d53349c26409e3bfb8436bbde5c92d30134b3b89 Mon Sep 17 00:00:00 2001 From: DJMrTV Date: Tue, 21 Jan 2025 20:10:58 +0100 Subject: [PATCH] refactor for new naming, new api and async --- Cargo.lock | 155 +++++++++++++++++++++++++++++++ Cargo.toml | 3 + src/main.rs | 36 +++---- src/prudp/connection.rs | 7 ++ src/prudp/endpoint.rs | 85 ----------------- src/prudp/mod.rs | 7 +- src/prudp/packet.rs | 175 +++++++++++++++++++++++++++++----- src/prudp/router.rs | 173 ++++++++++++++++++++++++++++++++++ src/prudp/server.rs | 121 ------------------------ src/prudp/socket.rs | 201 ++++++++++++++++++++++++++++++++++++++++ 10 files changed, 717 insertions(+), 246 deletions(-) create mode 100644 src/prudp/connection.rs delete mode 100644 src/prudp/endpoint.rs create mode 100644 src/prudp/router.rs delete mode 100644 src/prudp/server.rs create mode 100644 src/prudp/socket.rs diff --git a/Cargo.lock b/Cargo.lock index 88e20d1..8e8a43c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,21 @@ # It is not intended for manual editing. version = 4 +[[package]] +name = "addr2line" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" +dependencies = [ + "gimli", +] + +[[package]] +name = "adler2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" + [[package]] name = "aho-corasick" version = "1.1.3" @@ -63,6 +78,21 @@ dependencies = [ "paste", ] +[[package]] +name = "backtrace" +version = "0.3.74" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" +dependencies = [ + "addr2line", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", + "windows-targets", +] + [[package]] name = "bindgen" version = "0.69.5" @@ -92,6 +122,15 @@ version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "bumpalo" version = "3.16.0" @@ -219,6 +258,17 @@ dependencies = [ "powerfmt", ] +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", + "subtle", +] + [[package]] name = "dotenv" version = "0.15.0" @@ -289,12 +339,27 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "gimli" +version = "0.31.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" + [[package]] name = "glob" version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "home" version = "0.5.11" @@ -410,6 +475,16 @@ version = "0.4.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" +[[package]] +name = "md-5" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" +dependencies = [ + "cfg-if", + "digest", +] + [[package]] name = "memchr" version = "2.7.4" @@ -422,6 +497,26 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" +[[package]] +name = "miniz_oxide" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8402cab7aefae129c6977bb0ff1b8fd9a04eb5b51efc50a70bea51cda0c7924" +dependencies = [ + "adler2", +] + +[[package]] +name = "mio" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" +dependencies = [ + "libc", + "wasi 0.11.0+wasi-snapshot-preview1", + "windows-sys 0.52.0", +] + [[package]] name = "nom" version = "7.1.3" @@ -456,6 +551,15 @@ dependencies = [ "libc", ] +[[package]] +name = "object" +version = "0.36.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87" +dependencies = [ + "memchr", +] + [[package]] name = "once_cell" version = "1.20.2" @@ -468,6 +572,12 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "pin-project-lite" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + [[package]] name = "powerfmt" version = "0.2.0" @@ -595,6 +705,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rustc-demangle" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" + [[package]] name = "rustc-hash" version = "1.1.0" @@ -690,6 +806,16 @@ dependencies = [ "time", ] +[[package]] +name = "socket2" +version = "0.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "spin" version = "0.9.8" @@ -704,13 +830,16 @@ dependencies = [ "bytemuck", "chrono", "dotenv", + "hmac", "log", + "md-5", "once_cell", "rand", "rc4", "rustls", "simplelog", "thiserror", + "tokio", "v_byte_macros", ] @@ -804,6 +933,32 @@ dependencies = [ "time-core", ] +[[package]] +name = "tokio" +version = "1.43.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d61fa4ffa3de412bfea335c6ecff681de2b609ba3c77ef3e00e521813a9ed9e" +dependencies = [ + "backtrace", + "libc", + "mio", + "pin-project-lite", + "socket2", + "tokio-macros", + "windows-sys 0.52.0", +] + +[[package]] +name = "tokio-macros" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.96", +] + [[package]] name = "typenum" version = "1.17.0" diff --git a/Cargo.toml b/Cargo.toml index c3e391c..af41ebd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,3 +16,6 @@ log = "0.4.25" anyhow = "1.0.95" rand = "0.9.0-beta.3" rustls = "^0.23.21" +hmac = "0.12.1" +md-5 = "^0.10.6" +tokio = { version = "1.43.0", features = ["macros", "rt-multi-thread", "net", "sync"] } diff --git a/src/main.rs b/src/main.rs index 1bbf293..8c95b7d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,9 +6,9 @@ use chrono::Local; use log::{info, trace}; use once_cell::sync::Lazy; use simplelog::{ColorChoice, CombinedLogger, Config, LevelFilter, TerminalMode, TermLogger, WriteLogger}; -use crate::prudp::endpoint::Endpoint; +use crate::prudp::socket::{Socket, SocketImpl}; use crate::prudp::packet::VirtualPort; -use crate::prudp::server::NexServer; +use crate::prudp::router::Router; mod endianness; mod prudp; @@ -27,7 +27,8 @@ static OWN_IP: Lazy = Lazy::new(||{ .expect("no public ip specified") }); -fn main() { +#[tokio::main] +async fn main() { CombinedLogger::init( vec![ TermLogger::new(LevelFilter::Info, Config::default(), TerminalMode::Mixed, ColorChoice::Auto), @@ -40,23 +41,26 @@ fn main() { dotenv::dotenv().ok(); + start_servers().await; +} + +async fn start_servers(){ info!("starting auth server on {}:{}", *OWN_IP, *AUTH_SERVER_PORT); - let (auth_server, auth_server_join_handle) = - NexServer::new(SocketAddrV4::new(*OWN_IP, *AUTH_SERVER_PORT)) - .expect("unable to startauth server"); + let auth_server_router = + Router::new(SocketAddrV4::new(*OWN_IP, *AUTH_SERVER_PORT)).await + .expect("unable to startauth server"); info!("setting up endpoints"); - let auth_endpoints = vec![ - Endpoint::new(auth_server.socket.try_clone().unwrap(), VirtualPort::new(1,10)) - ]; + let mut socket = + Socket::new( + auth_server_router.clone(), + VirtualPort::new(1,10), + "6f599f81" + ).await.expect("unable to create socket"); - auth_server.endpoints.set(auth_endpoints) - .expect("endpoints were somehow set by something else???"); - - - trace!("joining auth server"); - - auth_server_join_handle.join().unwrap(); + let Some(connection) = socket.accept().await else { + return; + }; } diff --git a/src/prudp/connection.rs b/src/prudp/connection.rs new file mode 100644 index 0000000..2ee80e4 --- /dev/null +++ b/src/prudp/connection.rs @@ -0,0 +1,7 @@ +use tokio::sync::mpsc::Receiver; + +//struct Connection(Arc, Receiver<>); + +struct ConnectionImpl{ + +} \ No newline at end of file diff --git a/src/prudp/endpoint.rs b/src/prudp/endpoint.rs deleted file mode 100644 index dcada53..0000000 --- a/src/prudp/endpoint.rs +++ /dev/null @@ -1,85 +0,0 @@ -use std::collections::HashMap; -use std::net::UdpSocket; -use std::sync::{Arc, RwLock}; -use log::{error, info}; -use rand::random; -use crate::prudp::packet::{flags, PRUDPPacket, types, VirtualPort}; -use crate::prudp::sockaddr::PRUDPSockAddr; - -#[derive(Debug)] -pub struct Endpoint{ - virtual_port: VirtualPort, - socket: UdpSocket, - connections: RwLock> -} - -#[derive(Debug)] -pub struct Connection{ - sock_addr: PRUDPSockAddr, - id: u64 -} - -impl Endpoint{ - pub fn new(socket: UdpSocket, port: VirtualPort) -> Self{ - Self{ - socket, - virtual_port: port, - connections: Default::default() - } - } - - pub fn get_virual_port(&self) -> VirtualPort{ - self.virtual_port - } - - pub fn process_packet(&self, connection: PRUDPSockAddr, packet: &PRUDPPacket){ - info!("recieved packet on endpoint"); - - let conn = self.connections.read().expect("poison"); - - if !conn.contains_key(&connection){ - drop(conn); - - let mut conn = self.connections.write().expect("poison"); - //only insert if we STILL dont have the connection preventing double insertion - if !conn.contains_key(&connection) { - conn.insert(connection, Connection { - sock_addr: connection, - id: random() - }); - } - drop(conn); - } else { - drop(conn); - } - - let conn = self.connections.read().expect("poison"); - - let Some(conn) = conn.get(&connection) else { - error!("connection is still not present after making sure connection is present, giving up."); - return; - }; - - if //((packet.header.types_and_flags.get_flags() & flags::NEED_ACK) != 0) || - ((packet.header.types_and_flags.get_flags() & flags::ACK) != 0) || - ((packet.header.types_and_flags.get_flags() & flags::RELIABLE) != 0) || - ((packet.header.types_and_flags.get_flags() & flags::MULTI_ACK) != 0) { - let copy = packet.header.types_and_flags; - - unimplemented!("{:?}", copy) - } - - - match packet.header.types_and_flags.get_types() { - types::SYN => { - // reset heartbeat? - let response_header = packet.base_response_header(); - - - } - _ => unimplemented!() - } - - - } -} \ No newline at end of file diff --git a/src/prudp/mod.rs b/src/prudp/mod.rs index c17f1fe..d47b899 100644 --- a/src/prudp/mod.rs +++ b/src/prudp/mod.rs @@ -1,5 +1,6 @@ pub mod packet; -pub mod server; -pub mod endpoint; +pub mod router; +pub mod socket; mod auth_module; -mod sockaddr; \ No newline at end of file +mod sockaddr; +mod connection; \ No newline at end of file diff --git a/src/prudp/packet.rs b/src/prudp/packet.rs index 606e738..f267bc7 100644 --- a/src/prudp/packet.rs +++ b/src/prudp/packet.rs @@ -1,15 +1,20 @@ use std::fmt::{Debug, Formatter}; use std::hint::unreachable_unchecked; use std::io; -use std::io::{Cursor, ErrorKind, Read, Seek}; +use std::io::{Cursor, ErrorKind, Read, Seek, Write}; use std::net::SocketAddrV4; use bytemuck::{Pod, Zeroable}; +use hmac::{Hmac, Mac}; use log::{error, warn}; +use md5::{Md5, Digest}; use thiserror::Error; use v_byte_macros::{EnumTryInto, SwapEndian}; use crate::endianness::{IS_BIG_ENDIAN, IS_LITTLE_ENDIAN, ReadExtensions}; +use crate::prudp::packet::PacketOption::{ConnectionSignature, FragmentId, InitialSequenceId, MaximumSubstreamId, SupportedFunctions}; use crate::prudp::sockaddr::PRUDPSockAddr; +type Md5Hmac = Hmac; + #[derive(Error, Debug)] pub enum Error { #[error("{0}")] @@ -34,21 +39,29 @@ pub type Result = std::result::Result; pub struct TypesFlags(u16); impl TypesFlags { - pub fn get_types(self) -> u8 { + pub const fn get_types(self) -> u8 { (self.0 & 0x000F) as u8 } - pub fn get_flags(self) -> u16 { + pub const fn get_flags(self) -> u16 { (self.0 & 0xFFF0) >> 4 } - pub fn types(self, val: u8) -> Self { + pub const fn types(self, val: u8) -> Self { Self((self.0 & 0xFFF0) | (val as u16 & 0x000F)) } - pub fn flags(self, val: u16) -> Self { + pub const fn flags(self, val: u16) -> Self { Self((self.0 & 0x000F) | ((val << 4) & 0xFFF0)) } + + pub const fn set_flag(&mut self, val: u16){ + self.0 |= (val & 0xFFF) << 4; + } + + pub const fn set_types(&mut self, val: u8){ + self.0 |= val as u16 & 0x0F; + } } pub mod flags { @@ -116,7 +129,7 @@ impl Debug for VirtualPort { } } -#[repr(C, packed)] +#[repr(C)] #[derive(Debug, Copy, Clone, Pod, Zeroable, SwapEndian)] pub struct PRUDPHeader { magic: [u8; 2], @@ -136,11 +149,64 @@ enum PacketSpecificData { E = 0x10 } +#[derive(Debug, Clone)] +pub enum PacketOption{ + SupportedFunctions(u32), + ConnectionSignature([u8; 16]), + FragmentId(u8), + InitialSequenceId(u16), + MaximumSubstreamId(u8) +} + +impl PacketOption{ + fn from(option_id: OptionId, option_data: &[u8]) -> io::Result{ + let mut data_cursor = Cursor::new(option_data); + let val = match option_id.into(){ + 0 => SupportedFunctions(data_cursor.read_struct(IS_BIG_ENDIAN)?), + 1 => ConnectionSignature(data_cursor.read_struct(IS_BIG_ENDIAN)?), + 2 => FragmentId(data_cursor.read_struct(IS_BIG_ENDIAN)?), + 3 => InitialSequenceId(data_cursor.read_struct(IS_BIG_ENDIAN)?), + 4 => MaximumSubstreamId(data_cursor.read_struct(IS_BIG_ENDIAN)?), + _ => unsafe{ unreachable_unchecked() } + }; + + Ok(val) + } + + fn write_to_stream(&self, stream: &mut impl Write) -> io::Result<()> { + match self { + SupportedFunctions(v) => { + stream.write_all(&[0, size_of_val(v) as u8])?; + stream.write_all(&v.to_le_bytes())?; + } + ConnectionSignature(v) => { + stream.write_all(&[1, size_of_val(v) as u8])?; + stream.write_all(v)?; + } + FragmentId(v) => { + stream.write_all(&[2, size_of_val(v) as u8])?; + stream.write_all(&v.to_le_bytes())?; + } + InitialSequenceId(v) => { + stream.write_all(&[3, size_of_val(v) as u8])?; + stream.write_all(&v.to_le_bytes())?; + } + MaximumSubstreamId(v) => { + stream.write_all(&[4, size_of_val(v) as u8])?; + stream.write_all(&v.to_le_bytes())?; + } + } + + Ok(()) + } +} + #[derive(Debug, Clone)] pub struct PRUDPPacket { pub header: PRUDPHeader, + pub packet_signature: [u8; 16], pub payload: Vec, - pub options: Vec<(u8, Vec)>, + pub options: Vec, } #[derive(Copy, Clone, Debug)] @@ -190,7 +256,7 @@ impl PRUDPPacket { } //discard it for now - let _: [u8; 16] = reader.read_struct(IS_BIG_ENDIAN)?; + let packet_signature: [u8; 16] = reader.read_struct(IS_BIG_ENDIAN)?; assert_eq!(reader.stream_position().ok(), Some(14 + 16)); @@ -236,7 +302,7 @@ impl PRUDPPacket { break; } - options.push((option_id.into(), option_data)); + options.push(PacketOption::from(option_id, &option_data)?); } @@ -244,8 +310,11 @@ impl PRUDPPacket { reader.read_exact(&mut payload)?; + + Ok(Self { header, + packet_signature, payload, options, }) @@ -258,20 +327,84 @@ impl PRUDPPacket { } } - pub fn base_response_header(&self) -> PRUDPHeader { - PRUDPHeader { - magic: [0xEA, 0xD0], - types_and_flags: TypesFlags(0), - destination_port: self.header.source_port, - source_port: self.header.destination_port, - payload_size: 0, - version: 1, - packet_specific_size: 0, - sequence_id: 0, - session_id: 0, - substream_id: 0, + fn generate_options_bytes(&self) -> Vec{ + let mut vec = Vec::new(); + for option in &self.options{ + option.write_to_stream(&mut vec).expect("vec should always automatically be able to extend"); } + + vec + } + + pub fn calculate_signature_value(&self, access_key: &str, session_key: Option<[u8; 32]>, connection_signature: Option<[u8; 16]>) -> [u8; 16]{ + let access_key_bytes = access_key.as_bytes(); + let access_key_sum: u32 = access_key_bytes.iter().map(|v| *v as u32).sum(); + let access_key_sum_bytes: [u8; 4] = access_key_sum.to_le_bytes(); + + let header_data: [u8; 8] = bytemuck::bytes_of(&self.header)[0x8..].try_into().unwrap(); + + let option_bytes = self.generate_options_bytes(); + + let mut md5 = md5::Md5::default(); + + md5.update(access_key_bytes); + let key = md5.finalize(); + + let mut hmac = Md5Hmac::new_from_slice(&key).expect("fuck"); + + hmac.write(&header_data).expect("error during hmac calculation"); + if let Some(session_key) = session_key { + hmac.write(&session_key).expect("error during hmac calculation"); + } + hmac.write(&access_key_sum_bytes).expect("error during hmac calculation"); + if let Some(connection_signature) = connection_signature { + hmac.write(&connection_signature).expect("error during hmac calculation"); + } + + hmac.write(&option_bytes).expect("error during hmac calculation"); + + hmac.write_all(&self.payload).expect("error during hmac calculation"); + + hmac.finalize().into_bytes()[0..16].try_into().expect("invalid hmac size") + } + + pub fn calculate_and_assign_signature(&mut self, access_key: &str, session_key: Option<[u8; 32]>, connection_signature: Option<[u8; 16]>){ + self.packet_signature = self.calculate_signature_value(access_key, session_key, connection_signature); + } + + pub fn base_response_packet(&self) -> Self { + Self { + header: PRUDPHeader { + magic: [0xEA, 0xD0], + types_and_flags: TypesFlags(0), + destination_port: self.header.source_port, + source_port: self.header.destination_port, + payload_size: 0, + version: 1, + packet_specific_size: 0, + sequence_id: 0, + session_id: 0, + substream_id: 0, + + }, + packet_signature: [0; 16], + payload: Default::default(), + options: Default::default() + } + } + + pub fn write_to(&self, writer: &mut impl Write) -> io::Result<()>{ + writer.write_all(bytemuck::bytes_of(&self.header))?; + writer.write_all(&self.packet_signature)?; + + for option in &self.options{ + option.write_to_stream(writer)?; + } + + writer.write_all(&self.payload)?; + + Ok(()) } } diff --git a/src/prudp/router.rs b/src/prudp/router.rs new file mode 100644 index 0000000..6db9a84 --- /dev/null +++ b/src/prudp/router.rs @@ -0,0 +1,173 @@ +use std::{env, io, thread}; +use std::cell::OnceCell; +use std::io::Cursor; +use std::marker::PhantomData; +use tokio::net::UdpSocket; +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; +use std::net::SocketAddr::V4; +use std::ops::{Deref, DerefMut}; +use std::sync::{Arc, OnceLock}; +use std::sync::atomic::{AtomicBool, Ordering}; +use tokio::task::JoinHandle; +use once_cell::sync::Lazy; +use log::{error, info, trace, warn}; +use thiserror::Error; +use tokio::sync::RwLock; +use crate::prudp::auth_module::AuthModule; +use crate::prudp::socket::{Socket, SocketImpl}; +use crate::prudp::packet::{PRUDPPacket, VirtualPort}; +use crate::prudp::router::Error::VirtualPortTaken; +use crate::prudp::sockaddr::PRUDPSockAddr; + +static SERVER_DATAGRAMS: Lazy = Lazy::new(||{ + env::var("SERVER_DATAGRAM_COUNT").ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(1) +}); + +pub struct Router { + endpoints: RwLock<[Option>; 16]>, + running: AtomicBool, + socket: Arc, + //pub auth_module: Arc + _no_outside_construction: PhantomData<()> +} +#[derive(Debug, Error)] +pub enum Error{ + #[error("tried to register socket to a port which is already taken (port: {0})")] + VirtualPortTaken(u8) +} + + +impl Router { + fn process_prudp_packet(&self, packet: &PRUDPPacket){ + + } + async fn process_prudp_packets<'a>(&self, socket: &'a UdpSocket, addr: SocketAddrV4, udp_message: &[u8]){ + let mut stream = Cursor::new(udp_message); + + while stream.position() as usize != udp_message.len() { + let packet = match PRUDPPacket::new(&mut stream){ + Ok(p) => p, + Err(e) => { + error!("Somebody({}) is fucking with the servers or their connection is bad", addr); + break; + }, + }; + + trace!("got valid prudp packet from someone({}): \n{:?}", addr, packet); + + let connection = packet.source_sockaddr(addr); + + let endpoints = self.endpoints.read().await; + + let Some(endpoint) = endpoints[packet.header.destination_port.get_port_number() as usize].as_ref() else { + error!("connection to invalid endpoint({}) attempted by {}", packet.header.destination_port.get_port_number(), connection.regular_socket_addr); + continue; + }; + + let endpoint = endpoint.clone(); + + // Dont keep the locked structure for too long + drop(endpoints); + + trace!("sending packet to endpoint"); + + endpoint.process_packet(connection, &packet).await; + } + } + + async fn server_thread_send_entry(self: Arc, socket: Arc){ + info!("starting datagram thread"); + + while self.running.load(Ordering::Relaxed) { + // yes we actually allow the max udp to be read lol + let mut msg_buffer = vec![0u8; 65507]; + + let (len, addr) = socket.recv_from(&mut msg_buffer) + .await.expect("Datagram thread crashed due to unexpected error from recv_from"); + + let V4(addr) = addr else { + error!("somehow got ipv6 packet...? ignoring"); + continue; + }; + + let current_msg = &msg_buffer[0..len]; + info!("attempting to process message"); + + self.process_prudp_packets(&socket, addr, current_msg).await; + } + } + + pub async fn new(addr: SocketAddrV4) -> io::Result>{ + trace!("starting router on {}", addr); + + let socket = Arc::new(UdpSocket::bind(addr).await?); + + let own_impl = Router { + endpoints: Default::default(), + running: AtomicBool::new(true), + socket: socket.clone(), + _no_outside_construction: Default::default() + }; + + let arc = Arc::new(own_impl); + + + { + let socket = socket.clone(); + let server= arc.clone(); + + tokio::spawn(async { + server.server_thread_send_entry(socket).await; + }); + } + + { + let socket = socket.clone(); + let server= arc.clone(); + + tokio::spawn(async { + //server thread sender entry + // todo: make this run in the socket cause that makes more sense + //server.server_thread_recieve_entry(socket).await; + }); + } + + + Ok(arc) + } + + pub fn get_udp_socket(&self) -> Arc{ + self.socket.clone() + } + + // This will remove a socket from the router, this renders all instances of that socket unable + // to recieve any more data making the error out on trying to for example recieve connections + pub async fn remove_socket(&self, virtual_port: VirtualPort){ + self.endpoints.write().await[virtual_port.get_port_number() as usize] = None; + } + + // returns Some(()) i + pub(crate) async fn add_socket(&self, socket: Arc) -> Result<(), Error>{ + let mut endpoints = self.endpoints.write().await; + + let idx = socket.get_virual_port().get_port_number() as usize; + + if endpoints[idx].is_none() { + endpoints[idx] = Some(socket); + } else { + return Err(VirtualPortTaken(idx as u8)); + } + + Ok(()) + } + + pub fn get_own_address(&self) -> SocketAddrV4{ + match self.socket.local_addr().expect("unable to get socket address"){ + SocketAddr::V4(v4) => v4, + _ => unreachable!() + } + } +} + diff --git a/src/prudp/server.rs b/src/prudp/server.rs deleted file mode 100644 index 44f4500..0000000 --- a/src/prudp/server.rs +++ /dev/null @@ -1,121 +0,0 @@ -use std::{env, io, thread}; -use std::cell::OnceCell; -use std::io::Cursor; -use std::marker::PhantomData; -use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4, TcpStream, UdpSocket}; -use std::net::SocketAddr::V4; -use std::ops::{Deref, DerefMut}; -use std::sync::{Arc, Mutex, OnceLock, RwLock}; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::thread::JoinHandle; -use once_cell::sync::Lazy; -use log::{error, info, trace, warn}; -use crate::prudp::auth_module::AuthModule; -use crate::prudp::endpoint::Endpoint; -use crate::prudp::packet::{PRUDPPacket, VirtualPort}; -use crate::prudp::sockaddr::PRUDPSockAddr; - -static SERVER_DATAGRAMS: Lazy = Lazy::new(||{ - env::var("SERVER_DATAGRAM_COUNT").ok() - .and_then(|s| s.parse().ok()) - .unwrap_or(1) -}); - -pub struct NexServer{ - pub endpoints: OnceLock>, - pub socket: UdpSocket, - pub running: AtomicBool, - //pub auth_module: Arc - _no_outside_construction: PhantomData<()> -} - - -impl NexServer{ - fn process_prudp_packet(&self, packet: &PRUDPPacket){ - - } - fn process_prudp_packets<'a>(&self, socket: &'a UdpSocket, addr: SocketAddrV4, udp_message: &[u8]){ - let mut stream = Cursor::new(udp_message); - - while stream.position() as usize != udp_message.len() { - let packet = match PRUDPPacket::new(&mut stream){ - Ok(p) => p, - Err(e) => { - error!("Somebody({}) is fucking with the servers or their connection is bad", addr); - break; - }, - }; - - trace!("got valid prudp packet from someone({}): \n{:?}", addr, packet); - - let connection = packet.source_sockaddr(addr); - - let Some(endpoints) = self.endpoints.get() else{ - warn!("Got a message: ignoring because the server is still starting or the endpoints havent been set up"); - break; - }; - - let Some(endpoint) = endpoints.iter().find(|e|{ - e.get_virual_port().get_port_number() == packet.header.destination_port.get_port_number() - }) else { - error!("connection to invalid endpoint({}) attempted by {}", packet.header.destination_port.get_port_number(), connection.regular_socket_addr); - continue; - }; - - trace!("sending packet to endpoint"); - - endpoint.process_packet(connection, &packet); - } - } - - fn server_thread_entry(self: Arc, socket: UdpSocket){ - info!("starting datagram thread"); - - while self.running.load(Ordering::Relaxed) { - // yes we actually allow the max udp to be read lol - let mut msg_buffer = vec![0u8; 65507]; - - let (len, addr) = socket.recv_from(&mut msg_buffer) - .expect("Datagram thread crashed due to unexpected error from recv_from"); - - let V4(addr) = addr else { - error!("somehow got ipv6 packet...? ignoring"); - continue; - }; - - let current_msg = &msg_buffer[0..len]; - info!("attempting to process message"); - self.process_prudp_packets(&socket, addr, current_msg); - } - } - - pub fn new(addr: SocketAddrV4) -> io::Result<(Arc, JoinHandle<()>)>{ - let socket = UdpSocket::bind(addr)?; - - let own_impl = NexServer{ - endpoints: Default::default(), - running: AtomicBool::new(true), - socket: socket.try_clone().unwrap(), - _no_outside_construction: Default::default() - }; - - let arc = Arc::new(own_impl); - - let mut thread = None; - - for _ in 0..*SERVER_DATAGRAMS { - let socket = socket.try_clone().unwrap(); - let server= arc.clone(); - - thread = Some(thread::spawn(move || { - server.server_thread_entry(socket); - })); - } - - let thread = thread.expect("cannot have less than 1 thread for a server"); - - - Ok((arc, thread)) - } -} - diff --git a/src/prudp/socket.rs b/src/prudp/socket.rs new file mode 100644 index 0000000..3c07a52 --- /dev/null +++ b/src/prudp/socket.rs @@ -0,0 +1,201 @@ +use std::array; +use std::collections::HashMap; +use std::io::Write; +use std::ops::Deref; +use tokio::net::UdpSocket; +use std::sync::{Arc}; +use tokio::sync::{Mutex, RwLock}; +use hmac::{Hmac, Mac}; +use log::{error, info, trace}; +use rand::random; +use rc4::consts::U256; +use tokio::sync::mpsc::{channel, Receiver, Sender}; +use crate::prudp::packet::{flags, PacketOption, PRUDPPacket, types, VirtualPort}; +use crate::prudp::packet::PacketOption::{MaximumSubstreamId, SupportedFunctions}; +use crate::prudp::packet::types::SYN; +use crate::prudp::router::{Error, Router}; +use crate::prudp::sockaddr::PRUDPSockAddr; + + +type Md5Hmac = Hmac; + +/// PRUDP Socket for accepting connections to then send and recieve data from those clients +pub struct Socket(Arc, Arc, Receiver); + +#[derive(Debug)] +pub struct SocketImpl { + virtual_port: VirtualPort, + socket: Arc, + access_key: &'static str, + connections: RwLock>>>, + connection_creation_sender: Sender +} + +#[derive(Debug)] +pub struct Connection { + sock_addr: PRUDPSockAddr, + id: u64, + signature: [u8; 16], +} + + + +impl Socket { + pub async fn new(router: Arc, port: VirtualPort, access_key: &'static str) -> Result { + trace!("creating socket on router at {} on virtual port {:?}", router.get_own_address(), port); + let (send, recv) = channel(20); + + let socket = Arc::new( + SocketImpl::new(&router, send, port, access_key) + ); + + router.add_socket(socket.clone()).await?; + + Ok(Self(socket, router, recv)) + } + + pub async fn accept(&mut self) -> Option{ + self.2.recv().await + } +} + +impl Drop for Socket{ + fn drop(&mut self) { + { + let router = self.1.clone(); + + let virtual_port = self.virtual_port; + trace!("socket dropped socket will be removed from router soon"); + // it's not that important to remove it immediately so we can delay the deletion a bit if needed + tokio::spawn(async move { + router.remove_socket(virtual_port).await; + trace!("socket removed from router successfully"); + }); + } + } +} + +impl Deref for Socket{ + type Target = SocketImpl; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl SocketImpl { + fn new(router: &Router, connection_creation_sender: Sender, port: VirtualPort, access_key: &'static str) -> Self { + SocketImpl { + socket: router.get_udp_socket(), + virtual_port: port, + connections: Default::default(), + access_key, + connection_creation_sender + } + } + + pub fn get_virual_port(&self) -> VirtualPort { + self.virtual_port + } + + pub async fn process_packet(&self, connection: PRUDPSockAddr, packet: &PRUDPPacket) { + info!("recieved packet on endpoint"); + + let conn = self.connections.read().await; + + if !conn.contains_key(&connection) { + drop(conn); + + let mut conn = self.connections.write().await; + //only insert if we STILL dont have the connection preventing double insertion + if !conn.contains_key(&connection) { + conn.insert(connection, Arc::new(Mutex::new(Connection { + sock_addr: connection, + id: random(), + signature: [0; 16], + }))); + } + drop(conn); + } else { + drop(conn); + } + + let connections = self.connections.read().await; + + let Some(conn) = connections.get(&connection) else { + error!("connection is still not present after making sure connection is present, giving up."); + return; + }; + + let conn = conn.clone(); + + // dont keep holding the connections list unnescesarily + drop(connections); + + let mut conn = conn.lock().await; + + if //((packet.header.types_and_flags.get_flags() & flags::NEED_ACK) != 0) || + ((packet.header.types_and_flags.get_flags() & flags::ACK) != 0) || + ((packet.header.types_and_flags.get_flags() & flags::RELIABLE) != 0) || + ((packet.header.types_and_flags.get_flags() & flags::MULTI_ACK) != 0) { + let copy = packet.header.types_and_flags; + + unimplemented!("{:?}", copy) + } + + + match packet.header.types_and_flags.get_types() { + types::SYN => { + // reset heartbeat? + let mut response_packet = packet.base_response_packet(); + + response_packet.header.types_and_flags.set_types(SYN); + response_packet.header.types_and_flags.set_flag(flags::ACK); + response_packet.header.types_and_flags.set_flag(flags::HAS_SIZE); + + let mut hmac = Md5Hmac::new_from_slice(&[0; 16]).expect("fuck"); + + let mut data = connection.regular_socket_addr.ip().octets().to_vec(); + data.extend_from_slice(&connection.regular_socket_addr.port().to_be_bytes()); + + hmac.write_all(&data).expect("figuring this out was complete ass"); + let result: [u8; 16] = hmac.finalize().into_bytes()[0..16].try_into().expect("fuck"); + + conn.signature = result; + + response_packet.options.push(PacketOption::ConnectionSignature(result)); + + response_packet.calculate_and_assign_signature(self.access_key, None, None); + + for options in &packet.options{ + match options{ + SupportedFunctions(functions) => { + response_packet.options.push(SupportedFunctions(*functions)) + } + MaximumSubstreamId(max_substream) => { + response_packet.options.push(MaximumSubstreamId(*max_substream)) + }, + _ => {/* ??? */} + } + } + + let mut vec = Vec::new(); + + response_packet.write_to(&mut vec).expect("somehow failed to convert backet to bytes"); + + self.socket.send_to(&vec, connection.regular_socket_addr).await.expect("failed to send data back"); + } + _ => unimplemented!("unimplemented packet type: {}", packet.header.types_and_flags.get_types()) + } + } +} + +#[cfg(test)] +mod test { + use hmac::Mac; + use crate::prudp::socket::Md5Hmac; + + #[test] + fn fuck() { + let hmac = Md5Hmac::new_from_slice(&[0; 16]).expect("fuck"); + } +} \ No newline at end of file