diff --git a/src/prudp/packet.rs b/src/prudp/packet.rs index 1637584..f70160f 100644 --- a/src/prudp/packet.rs +++ b/src/prudp/packet.rs @@ -87,7 +87,7 @@ impl Debug for TypesFlags { #[repr(transparent)] #[derive(PartialEq, Eq, Copy, Clone, Pod, Zeroable, SwapEndian, Hash)] -pub struct VirtualPort(u8); +pub struct VirtualPort(pub(crate) u8); impl VirtualPort { #[inline] @@ -201,6 +201,16 @@ impl PacketOption{ Ok(()) } + + fn write_size(&self) -> u8 { + match self { + SupportedFunctions(_) => 2 + 4, + ConnectionSignature(_) => 2 + 16, + FragmentId(_) => 2 + 1, + InitialSequenceId(_) => 2 + 2, + MaximumSubstreamId(_) => 2 + 1, + } + } } #[derive(Debug, Clone)] @@ -374,6 +384,11 @@ impl PRUDPPacket { self.packet_signature = self.calculate_signature_value(access_key, session_key, connection_signature); } + pub fn set_sizes(&mut self){ + self.header.packet_specific_size = self.options.iter().map(|o| o.write_size()).sum(); + self.header.payload_size = self.payload.len() as u16; + } + pub fn base_response_packet(&self) -> Self { Self { header: PRUDPHeader { @@ -426,7 +441,14 @@ mod test { let buf = vec![0; option_id.option_type_size() as usize]; - PacketOption::from(option_id, &buf).unwrap(); + let opt = PacketOption::from(option_id, &buf).unwrap(); + { + let mut write_buf = vec![]; + + opt.write_to_stream(&mut write_buf).unwrap(); + + assert_eq!(write_buf.len() as u8, opt.write_size()) + } } diff --git a/src/prudp/socket.rs b/src/prudp/socket.rs index 5da71d5..837357a 100644 --- a/src/prudp/socket.rs +++ b/src/prudp/socket.rs @@ -19,8 +19,6 @@ use crate::prudp::router::{Error, Router}; use crate::prudp::sockaddr::PRUDPSockAddr; - - /// PRUDP Socket for accepting connections to then send and recieve data from those clients pub struct Socket(Arc, Arc, Receiver); @@ -30,7 +28,7 @@ pub struct SocketImpl { socket: Arc, access_key: &'static str, connections: RwLock>>>, - connection_creation_sender: Sender + connection_creation_sender: Sender, } #[derive(Debug)] @@ -39,11 +37,10 @@ pub struct Connection { id: u64, signature: [u8; 16], server_signature: [u8; 16], - session_id: u8 + session_id: u8, } - impl Socket { pub async fn new(router: Arc, port: VirtualPort, access_key: &'static str) -> Result { trace!("creating socket on router at {} on virtual port {:?}", router.get_own_address(), port); @@ -58,12 +55,12 @@ impl Socket { Ok(Self(socket, router, recv)) } - pub async fn accept(&mut self) -> Option{ + pub async fn accept(&mut self) -> Option { self.2.recv().await } } -impl Drop for Socket{ +impl Drop for Socket { fn drop(&mut self) { { let router = self.1.clone(); @@ -79,7 +76,7 @@ impl Drop for Socket{ } } -impl Deref for Socket{ +impl Deref for Socket { type Target = SocketImpl; fn deref(&self) -> &Self::Target { &self.0 @@ -87,7 +84,6 @@ impl Deref for Socket{ } - impl SocketImpl { fn new(router: &Router, connection_creation_sender: Sender, port: VirtualPort, access_key: &'static str) -> Self { SocketImpl { @@ -95,7 +91,7 @@ impl SocketImpl { virtual_port: port, connections: Default::default(), access_key, - connection_creation_sender + connection_creation_sender, } } @@ -119,7 +115,7 @@ impl SocketImpl { id: random(), signature: [0; 16], server_signature: [0; 16], - session_id: 0 + session_id: 0, }))); } drop(conn); @@ -166,18 +162,20 @@ impl SocketImpl { response_packet.options.push(ConnectionSignature(conn.signature)); - for options in &packet.options{ - match options{ + for options in &packet.options { + match options { SupportedFunctions(functions) => { response_packet.options.push(SupportedFunctions(*functions)) } MaximumSubstreamId(max_substream) => { response_packet.options.push(MaximumSubstreamId(*max_substream)) - }, - _ => {/* ??? */} + } + _ => { /* ??? */ } } } + response_packet.set_sizes(); + response_packet.calculate_and_assign_signature(self.access_key, None, None); let mut vec = Vec::new(); @@ -202,14 +200,14 @@ impl SocketImpl { response_packet.options.push(ConnectionSignature(Default::default())); - for option in &packet.options{ - match option { + for option in &packet.options { + match option { MaximumSubstreamId(max_substream) => response_packet.options.push(MaximumSubstreamId(*max_substream)), SupportedFunctions(funcs) => response_packet.options.push(SupportedFunctions(*funcs)), ConnectionSignature(sig) => { conn.server_signature = *sig - }, - _ => {/* ? */} + } + _ => { /* ? */ } } } @@ -219,10 +217,12 @@ impl SocketImpl { // todo: implement something to do secure servers - if conn.server_signature == <[u8;16] as Default>::default(){ + if conn.server_signature == <[u8; 16] as Default>::default() { error!("didn't get connection signature from client") } + response_packet.set_sizes(); + response_packet.calculate_and_assign_signature(self.access_key, None, Some(conn.server_signature)); let mut vec = Vec::new(); @@ -233,6 +233,49 @@ impl SocketImpl { } _ => unimplemented!("unimplemented packet type: {}", packet.header.types_and_flags.get_types()) } - } } + +#[cfg(test)] +mod test { + use std::io::Cursor; + use std::net::{Ipv4Addr, SocketAddrV4}; + use std::sync::Arc; + use tokio::net::UdpSocket; + use tokio::sync::mpsc::channel; + use crate::prudp::packet::{PRUDPPacket, VirtualPort}; + use crate::prudp::sockaddr::PRUDPSockAddr; + use crate::prudp::socket::SocketImpl; + + #[tokio::test] + async fn test_connect() { + let packet_1 = [234, 208, 1, 27, 0, 0, 175, 161, 192, 0, 0, 0, 0, 0, 36, 21, 233, 179, 203, 154, 57, 222, 219, 9, 21, 2, 29, 172, 56, 92, 0, 4, 4, 1, 0, 0, 1, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 1, 0]; + let packet_2 = [234, 208, 1, 31, 0, 0, 175, 161, 225, 0, 249, 0, 1, 0, 40, 168, 31, 138, 58, 193, 30, 134, 3, 232, 205, 245, 28, 155, 193, 198, 0, 4, 0, 0, 0, 0, 1, 16, 211, 240, 113, 188, 227, 114, 114, 30, 157, 179, 246, 55, 233, 240, 44, 197, 3, 2, 247, 244, 4, 1, 0]; + + let packet_1 = PRUDPPacket::new(&mut Cursor::new(packet_1)).unwrap(); + let packet_2 = PRUDPPacket::new(&mut Cursor::new(packet_2)).unwrap(); + + + + + let (send, recv) = channel(100); + + let sock = SocketImpl{ + connections: Default::default(), + access_key: "6f599f81", + virtual_port: VirtualPort(0), + socket: Arc::new(UdpSocket::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 10000)).await.unwrap()), + connection_creation_sender: send + }; + println!("sent: {:?}", packet_1); + sock.process_packet(PRUDPSockAddr{ + virtual_port: VirtualPort(0), + regular_socket_addr: SocketAddrV4::new(Ipv4Addr::LOCALHOST, 2469) + }, &packet_1).await; + println!("sent: {:?}", packet_2); + sock.process_packet(PRUDPSockAddr{ + virtual_port: VirtualPort(0), + regular_socket_addr: SocketAddrV4::new(Ipv4Addr::LOCALHOST, 2469) + }, &packet_2).await; + } +} \ No newline at end of file