diff --git a/src/prudp/packet.rs b/src/prudp/packet.rs index 12bf7e4..e0b6aa2 100644 --- a/src/prudp/packet.rs +++ b/src/prudp/packet.rs @@ -10,6 +10,7 @@ use md5::{Md5, Digest}; use thiserror::Error; use v_byte_macros::{EnumTryInto, SwapEndian}; use crate::endianness::{IS_BIG_ENDIAN, IS_LITTLE_ENDIAN, ReadExtensions}; +use crate::prudp::packet::flags::ACK; use crate::prudp::packet::PacketOption::{ConnectionSignature, FragmentId, InitialSequenceId, MaximumSubstreamId, SupportedFunctions}; use crate::prudp::sockaddr::PRUDPSockAddr; @@ -170,7 +171,7 @@ impl PacketOption{ 2 => FragmentId(data_cursor.read_struct(IS_BIG_ENDIAN)?), 3 => InitialSequenceId(data_cursor.read_struct(IS_BIG_ENDIAN)?), 4 => MaximumSubstreamId(data_cursor.read_struct(IS_BIG_ENDIAN)?), - _ => unsafe{ unreachable_unchecked() } + _ => unreachable!() }; Ok(val) @@ -242,9 +243,7 @@ impl OptionId { 2 => 1, 3 => 2, 4 => 1, - // Getting here would mean that the invariant has been violated, thus this isnt my - // problem lmao - _ => unsafe { unreachable_unchecked() } + _ => unreachable!() } } } @@ -332,6 +331,30 @@ impl PRUDPPacket { }) } + pub fn base_acknowledgement_packet(&self) -> Self{ + let base = self.base_response_packet(); + + let mut flags = self.header.types_and_flags.flags(0); + + flags.set_flag(ACK); + + let options = self.options + .iter() + .filter(|o| matches!(o, FragmentId(_))) + .collect(); + + Self{ + header: PRUDPHeader{ + types_and_flags: flags, + sequence_id: self.header.sequence_id, + substream_id: self.header.substream_id, + ..base.header + }, + options, + ..base + } + } + pub fn source_sockaddr(&self, socket_addr_v4: SocketAddrV4) -> PRUDPSockAddr { PRUDPSockAddr { regular_socket_addr: socket_addr_v4, diff --git a/src/prudp/socket.rs b/src/prudp/socket.rs index e3e1e0c..a474ebb 100644 --- a/src/prudp/socket.rs +++ b/src/prudp/socket.rs @@ -1,18 +1,18 @@ use std::array; -use std::collections::HashMap; +use std::collections::{HashMap, VecDeque}; use std::io::Write; use std::ops::Deref; use tokio::net::UdpSocket; use std::sync::{Arc}; use tokio::sync::{Mutex, RwLock}; use hmac::{Hmac, Mac}; -use log::{error, info, trace}; +use log::{error, info, trace, warn}; use rand::random; use rc4::consts::U256; use rustls::internal::msgs::handshake::SessionId; use tokio::sync::mpsc::{channel, Receiver, Sender}; use crate::prudp::packet::{flags, PacketOption, PRUDPPacket, types, VirtualPort}; -use crate::prudp::packet::flags::{ACK, HAS_SIZE, MULTI_ACK, RELIABLE}; +use crate::prudp::packet::flags::{ACK, HAS_SIZE, MULTI_ACK, NEED_ACK, RELIABLE}; use crate::prudp::packet::PacketOption::{ConnectionSignature, MaximumSubstreamId, SupportedFunctions}; use crate::prudp::packet::types::{CONNECT, DATA, SYN}; use crate::prudp::router::{Error, Router}; @@ -31,6 +31,8 @@ pub struct SocketImpl { connection_creation_sender: Sender, } + + #[derive(Debug)] pub struct Connection { sock_addr: PRUDPSockAddr, @@ -38,6 +40,9 @@ pub struct Connection { signature: [u8; 16], server_signature: [u8; 16], session_id: u8, + reliable_client_counter: u16, + reliable_server_counter: u16, + reliable_client_queue: VecDeque, } @@ -116,6 +121,9 @@ impl SocketImpl { signature: [0; 16], server_signature: [0; 16], session_id: 0, + reliable_client_queue: VecDeque::new(), + reliable_client_counter: 0, + reliable_server_counter: 0, }))); } drop(conn); @@ -147,11 +155,6 @@ impl SocketImpl { unimplemented!() } - if (packet.header.types_and_flags.get_flags() & RELIABLE) != 0 { - error!("unreliable packets are unimplemented"); - unimplemented!() - } - match packet.header.types_and_flags.get_types() { SYN => { @@ -170,7 +173,6 @@ impl SocketImpl { for options in &packet.options { match options { SupportedFunctions(functions) => { - response_packet.options.push(SupportedFunctions(*functions & 0x04)) } MaximumSubstreamId(max_substream) => { @@ -232,12 +234,42 @@ impl SocketImpl { response_packet.calculate_and_assign_signature(self.access_key, None, Some(conn.server_signature)); let mut vec = Vec::new(); - response_packet.write_to(&mut vec).expect("somehow failed to convert backet to bytes"); self.socket.send_to(&vec, connection.regular_socket_addr).await.expect("failed to send data back"); } DATA => { + if (packet.header.types_and_flags.get_flags() & RELIABLE) != 0 { + match conn.reliable_client_queue.binary_search_by_key(&conn.reliable_client_counter, |p| p.header.sequence_id) { + Ok(_) => warn!("recieved packet twice"), + Err(position) => conn.reliable_client_queue.insert(position, packet.clone()), + } + + if (packet.header.types_and_flags.get_flags() & NEED_ACK) != 0{ + let mut ack = packet.base_acknowledgement_packet(); + ack.set_sizes(); + ack.calculate_and_assign_signature(self.access_key, None, Some(conn.server_signature)); + + let mut vec = Vec::new(); + ack.write_to(&mut vec).expect("somehow failed to convert backet to bytes"); + + self.socket.send_to(&vec, connection.regular_socket_addr).await.expect("failed to send data back"); + } + + while let Some(packet) = + conn.reliable_client_queue + .front() + .is_some_and(|v| v.header.sequence_id == conn.reliable_client_counter) + .then(|| conn.reliable_client_queue.pop_front()) + .flatten(){ + conn.reliable_client_counter = conn.reliable_client_counter.overflowing_add(1).0; + + // ignored + } + } else { + error!("unreliable packets are unimplemented"); + unimplemented!() + } info!("{:?}", packet); } _ => unimplemented!("unimplemented packet type: {}", packet.header.types_and_flags.get_types()) @@ -265,26 +297,24 @@ mod test { let packet_2 = PRUDPPacket::new(&mut Cursor::new(packet_2)).unwrap(); - - let (send, recv) = channel(100); - let sock = SocketImpl{ + let sock = SocketImpl { connections: Default::default(), access_key: "6f599f81", virtual_port: VirtualPort(0), socket: Arc::new(UdpSocket::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 10000)).await.unwrap()), - connection_creation_sender: send + connection_creation_sender: send, }; println!("sent: {:?}", packet_1); - sock.process_packet(PRUDPSockAddr{ + sock.process_packet(PRUDPSockAddr { virtual_port: VirtualPort(0), - regular_socket_addr: SocketAddrV4::new(Ipv4Addr::LOCALHOST, 2469) + regular_socket_addr: SocketAddrV4::new(Ipv4Addr::LOCALHOST, 2469), }, &packet_1).await; println!("sent: {:?}", packet_2); - sock.process_packet(PRUDPSockAddr{ + sock.process_packet(PRUDPSockAddr { virtual_port: VirtualPort(0), - regular_socket_addr: SocketAddrV4::new(Ipv4Addr::LOCALHOST, 2469) + regular_socket_addr: SocketAddrV4::new(Ipv4Addr::LOCALHOST, 2469), }, &packet_2).await; } } \ No newline at end of file