changes...

This commit is contained in:
DJMrTV 2025-01-26 12:09:56 +01:00
commit 40ca10651f
18 changed files with 998 additions and 112 deletions

View file

@ -1,74 +1,101 @@
use std::array;
use std::collections::{HashMap, VecDeque};
use std::future::Future;
use std::io::Write;
use std::ops::Deref;
use std::pin::Pin;
use tokio::net::UdpSocket;
use std::sync::{Arc};
use tokio::sync::{Mutex, RwLock};
use tokio::sync::{Mutex, MutexGuard, RwLock};
use hmac::{Hmac, Mac};
use log::{error, info, trace, warn};
use rand::random;
use rc4::consts::U256;
use rc4::consts::{U256, U5};
use rc4::{Rc4, Rc4Core, StreamCipher};
use rc4::cipher::{KeySizeUser, StreamCipherCoreWrapper};
use rustls::internal::msgs::handshake::SessionId;
use tokio::io::AsyncWriteExt;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::task::JoinHandle;
use tokio_stream::wrappers::ReceiverStream;
use crate::prudp::packet::{flags, PacketOption, PRUDPPacket, types, VirtualPort};
use crate::prudp::packet::flags::{ACK, HAS_SIZE, MULTI_ACK, NEED_ACK, RELIABLE};
use crate::prudp::packet::PacketOption::{ConnectionSignature, MaximumSubstreamId, SupportedFunctions};
use crate::prudp::packet::types::{CONNECT, DATA, SYN};
use crate::prudp::packet::types::{CONNECT, DATA, PING, SYN};
use crate::prudp::router::{Error, Router};
use crate::prudp::sockaddr::PRUDPSockAddr;
use rc4::KeyInit;
// due to the way this is designed crashing the router thread causes deadlock, sorry ;-;
// (maybe i will fix that some day)
/// PRUDP Socket for accepting connections to then send and recieve data from those clients
pub struct Socket(Arc<SocketImpl>, Arc<Router>, Receiver<Connection>);
#[derive(Debug)]
pub struct SocketImpl {
virtual_port: VirtualPort,
socket: Arc<UdpSocket>,
access_key: &'static str,
connections: RwLock<HashMap<PRUDPSockAddr, Arc<Mutex<Connection>>>>,
connection_creation_sender: Sender<Connection>,
pub struct Socket {
socket_data: Arc<SocketData>,
router: Arc<Router>,
}
type OnConnectHandlerFn = Box<dyn Fn(PRUDPPacket) -> Pin<Box<dyn Future<Output=(bool, (Box<dyn StreamCipher + Send + Sync>, Box<dyn StreamCipher + Send + Sync>))> + Send + Sync>> + Send + Sync>;
type OnDataHandlerFn = Box<dyn for<'a> Fn(PRUDPPacket, Arc<SocketData>, &'a mut MutexGuard<'_, ConnectionData>) -> Pin<Box<dyn Future<Output=()> + 'a + Send + Sync>> + Send + Sync>;
#[derive(Debug)]
pub struct Connection {
sock_addr: PRUDPSockAddr,
id: u64,
signature: [u8; 16],
server_signature: [u8; 16],
session_id: u8,
reliable_client_counter: u16,
reliable_server_counter: u16,
reliable_client_queue: VecDeque<PRUDPPacket>,
pub struct SocketData {
virtual_port: VirtualPort,
pub socket: Arc<UdpSocket>,
pub access_key: &'static str,
connections: RwLock<HashMap<PRUDPSockAddr, Arc<Mutex<ConnectionData>>>>,
on_connect_handler: OnConnectHandlerFn,
on_data_handler: OnDataHandlerFn,
}
pub struct ActiveConnectionData {
pub reliable_client_counter: u16,
pub reliable_server_counter: u16,
pub reliable_client_queue: VecDeque<PRUDPPacket>,
pub connection_data_channel: Sender<Vec<u8>>,
pub server_encryption: Box<dyn StreamCipher + Send + Sync>,
pub client_decryption: Box<dyn StreamCipher + Send + Sync>,
pub server_session_id: u8,
}
pub struct ConnectionData {
pub sock_addr: PRUDPSockAddr,
pub id: u64,
pub signature: [u8; 16],
pub server_signature: [u8; 16],
pub active_connection_data: Option<ActiveConnectionData>,
}
impl Socket {
pub async fn new(router: Arc<Router>, port: VirtualPort, access_key: &'static str) -> Result<Self, Error> {
pub async fn new(
router: Arc<Router>,
port: VirtualPort,
access_key: &'static str,
on_connection_handler: OnConnectHandlerFn,
on_data_handler: OnDataHandlerFn,
) -> Result<Self, Error> {
trace!("creating socket on router at {} on virtual port {:?}", router.get_own_address(), port);
let (send, recv) = channel(20);
let socket = Arc::new(
SocketImpl::new(&router, send, port, access_key)
let socket_data = Arc::new(
SocketData::new_unbound(&router, port, access_key, on_connection_handler, on_data_handler)
);
router.add_socket(socket.clone()).await?;
router.add_socket(socket_data.clone()).await?;
Ok(Self(socket, router, recv))
}
pub async fn accept(&mut self) -> Option<Connection> {
self.2.recv().await
Ok(Self {
socket_data,
router,
})
}
}
impl Drop for Socket {
fn drop(&mut self) {
{
let router = self.1.clone();
let router = self.router.clone();
let virtual_port = self.virtual_port;
trace!("socket dropped socket will be removed from router soon");
@ -82,21 +109,27 @@ impl Drop for Socket {
}
impl Deref for Socket {
type Target = SocketImpl;
type Target = SocketData;
fn deref(&self) -> &Self::Target {
&self.0
&self.socket_data
}
}
impl SocketImpl {
fn new(router: &Router, connection_creation_sender: Sender<Connection>, port: VirtualPort, access_key: &'static str) -> Self {
SocketImpl {
impl SocketData {
fn new_unbound(router: &Router,
port: VirtualPort,
access_key: &'static str,
on_connect_handler: OnConnectHandlerFn,
on_data_handler: OnDataHandlerFn,
) -> Self {
SocketData {
socket: router.get_udp_socket(),
virtual_port: port,
connections: Default::default(),
access_key,
connection_creation_sender,
on_connect_handler,
on_data_handler,
}
}
@ -104,26 +137,22 @@ impl SocketImpl {
self.virtual_port
}
pub async fn process_packet(&self, connection: PRUDPSockAddr, packet: &PRUDPPacket) {
info!("recieved packet on endpoint");
pub async fn process_packet(self: &Arc<Self>, client_address: PRUDPSockAddr, packet: &PRUDPPacket) {
let conn = self.connections.read().await;
if !conn.contains_key(&connection) {
if !conn.contains_key(&client_address) {
drop(conn);
let mut conn = self.connections.write().await;
//only insert if we STILL dont have the connection preventing double insertion
if !conn.contains_key(&connection) {
conn.insert(connection, Arc::new(Mutex::new(Connection {
sock_addr: connection,
if !conn.contains_key(&client_address) {
conn.insert(client_address, Arc::new(Mutex::new(ConnectionData {
sock_addr: client_address,
id: random(),
signature: [0; 16],
server_signature: [0; 16],
session_id: 0,
reliable_client_queue: VecDeque::new(),
reliable_client_counter: 0,
reliable_server_counter: 0,
active_connection_data: None,
})));
}
drop(conn);
@ -133,7 +162,7 @@ impl SocketImpl {
let connections = self.connections.read().await;
let Some(conn) = connections.get(&connection) else {
let Some(conn) = connections.get(&client_address) else {
error!("connection is still not present after making sure connection is present, giving up.");
return;
};
@ -143,7 +172,7 @@ impl SocketImpl {
// dont keep holding the connections list unnescesarily
drop(connections);
let mut conn = conn.lock().await;
let mut connection = conn.lock().await;
if (packet.header.types_and_flags.get_flags() & ACK) != 0 {
info!("acknowledgement recieved");
@ -152,7 +181,7 @@ impl SocketImpl {
if (packet.header.types_and_flags.get_flags() & MULTI_ACK) != 0 {
info!("acknowledgement recieved");
unimplemented!()
return;
}
@ -166,9 +195,9 @@ impl SocketImpl {
response_packet.header.types_and_flags.set_flag(ACK);
response_packet.header.types_and_flags.set_flag(HAS_SIZE);
conn.signature = connection.calculate_connection_signature();
connection.signature = client_address.calculate_connection_signature();
response_packet.options.push(ConnectionSignature(conn.signature));
response_packet.options.push(ConnectionSignature(connection.signature));
for options in &packet.options {
match options {
@ -190,7 +219,7 @@ impl SocketImpl {
response_packet.write_to(&mut vec).expect("somehow failed to convert backet to bytes");
self.socket.send_to(&vec, connection.regular_socket_addr).await.expect("failed to send data back");
self.socket.send_to(&vec, client_address.regular_socket_addr).await.expect("failed to send data back");
}
CONNECT => {
info!("got connect");
@ -202,18 +231,23 @@ impl SocketImpl {
response_packet.header.types_and_flags.set_flag(HAS_SIZE);
// todo: (or not) sliding windows and stuff
conn.session_id = packet.header.session_id;
response_packet.header.session_id = conn.session_id;
response_packet.header.session_id = packet.header.session_id;
response_packet.header.sequence_id = 1;
response_packet.options.push(ConnectionSignature(Default::default()));
let mut init_seq_id = 0;
for option in &packet.options {
match option {
MaximumSubstreamId(max_substream) => response_packet.options.push(MaximumSubstreamId(*max_substream)),
SupportedFunctions(funcs) => response_packet.options.push(SupportedFunctions(*funcs)),
ConnectionSignature(sig) => {
conn.server_signature = *sig
connection.server_signature = *sig
}
PacketOption::InitialSequenceId(id) => {
init_seq_id = *id;
}
_ => { /* ? */ }
}
@ -225,53 +259,125 @@ impl SocketImpl {
// todo: implement something to do secure servers
if conn.server_signature == <[u8; 16] as Default>::default() {
if connection.server_signature == <[u8; 16] as Default>::default() {
error!("didn't get connection signature from client")
}
response_packet.set_sizes();
response_packet.calculate_and_assign_signature(self.access_key, None, Some(conn.server_signature));
response_packet.calculate_and_assign_signature(self.access_key, None, Some(connection.server_signature));
let mut vec = Vec::new();
response_packet.write_to(&mut vec).expect("somehow failed to convert backet to bytes");
self.socket.send_to(&vec, connection.regular_socket_addr).await.expect("failed to send data back");
self.socket.send_to(&vec, client_address.regular_socket_addr).await.expect("failed to send data back");
let (send, recv) = channel(100);
let (accepted, (client_decryption, server_encryption))
= (self.on_connect_handler)(packet.clone()).await;
if !accepted {
// rejected
return;
}
connection.active_connection_data = Some(ActiveConnectionData {
connection_data_channel: send,
client_decryption,
server_encryption,
reliable_client_queue: VecDeque::new(),
reliable_client_counter: 2,
reliable_server_counter: 1,
server_session_id: packet.header.session_id,
});
}
DATA => {
if (packet.header.types_and_flags.get_flags() & RELIABLE) != 0 {
match conn.reliable_client_queue.binary_search_by_key(&conn.reliable_client_counter, |p| p.header.sequence_id) {
let Some(active_connection) = connection.active_connection_data.as_mut() else {
error!("got data packet on non active connection!");
return;
};
info!("ctr: {}, packet seq: {}", active_connection.reliable_client_counter, packet.header.sequence_id);
match active_connection.reliable_client_queue.binary_search_by_key(&packet.header.sequence_id, |p| p.header.sequence_id) {
Ok(_) => warn!("recieved packet twice"),
Err(position) => conn.reliable_client_queue.insert(position, packet.clone()),
Err(position) => active_connection.reliable_client_queue.insert(position, packet.clone()),
}
if (packet.header.types_and_flags.get_flags() & NEED_ACK) != 0{
if (packet.header.types_and_flags.get_flags() & NEED_ACK) != 0 {
let mut ack = packet.base_acknowledgement_packet();
ack.header.session_id = active_connection.server_session_id;
ack.set_sizes();
ack.calculate_and_assign_signature(self.access_key, None, Some(conn.server_signature));
ack.calculate_and_assign_signature(self.access_key, None, Some(connection.server_signature));
let mut vec = Vec::new();
ack.write_to(&mut vec).expect("somehow failed to convert backet to bytes");
self.socket.send_to(&vec, connection.regular_socket_addr).await.expect("failed to send data back");
self.socket.send_to(&vec, client_address.regular_socket_addr).await.expect("failed to send data back");
}
while let Some(packet) =
conn.reliable_client_queue
while let Some(mut packet) = {
connection.active_connection_data.as_mut().map(|a|
a.reliable_client_queue
.front()
.is_some_and(|v| v.header.sequence_id == conn.reliable_client_counter)
.then(|| conn.reliable_client_queue.pop_front())
.flatten(){
conn.reliable_client_counter = conn.reliable_client_counter.overflowing_add(1).0;
.is_some_and(|v| v.header.sequence_id == a.reliable_client_counter)
.then(|| a.reliable_client_queue.pop_front())).flatten().flatten()
} {
if packet.options.iter().any(|v| match v{
PacketOption::FragmentId(f) => (*f != 0),
_ => false,
}){
error!("fragmented packets are unsupported right now")
}
// ignored
let active_connection = connection.active_connection_data.as_mut()
.expect("we litterally just recieved a packet which requires the connection to be active, failing this should be impossible");
active_connection.reliable_client_counter = active_connection.reliable_client_counter.overflowing_add(1).0;
active_connection.client_decryption.apply_keystream(&mut packet.payload);
// we cant divert this off to another thread we HAVE to process it now to keep order
(self.on_data_handler)(packet, self.clone(), &mut connection).await;
// ignored for now
}
} else {
error!("unreliable packets are unimplemented");
unimplemented!()
}
info!("{:?}", packet);
//info!("{:?}", packet);
}
PING => {
let ConnectionData {
active_connection_data,
server_signature,
..
} = &mut *connection;
if (packet.header.types_and_flags.get_flags() & NEED_ACK) != 0 {
let Some(active_connection) = active_connection_data.as_mut() else {
error!("got data packet on non active connection!");
return;
};
let mut ack = packet.base_acknowledgement_packet();
ack.header.session_id = active_connection.server_session_id;
ack.set_sizes();
ack.calculate_and_assign_signature(self.access_key, None, Some(*server_signature));
let mut vec = Vec::new();
ack.write_to(&mut vec).expect("somehow failed to convert backet to bytes");
self.socket.send_to(&vec, client_address.regular_socket_addr).await.expect("failed to send data back");
}
}
_ => unimplemented!("unimplemented packet type: {}", packet.header.types_and_flags.get_types())
}
}
@ -286,9 +392,9 @@ mod test {
use tokio::sync::mpsc::channel;
use crate::prudp::packet::{PRUDPPacket, VirtualPort};
use crate::prudp::sockaddr::PRUDPSockAddr;
use crate::prudp::socket::SocketImpl;
use crate::prudp::socket::SocketData;
#[tokio::test]
/*#[tokio::test]
async fn test_connect() {
let packet_1 = [234, 208, 1, 27, 0, 0, 175, 161, 192, 0, 0, 0, 0, 0, 36, 21, 233, 179, 203, 154, 57, 222, 219, 9, 21, 2, 29, 172, 56, 92, 0, 4, 4, 1, 0, 0, 1, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 1, 0];
let packet_2 = [234, 208, 1, 31, 0, 0, 175, 161, 225, 0, 249, 0, 1, 0, 40, 168, 31, 138, 58, 193, 30, 134, 3, 232, 205, 245, 28, 155, 193, 198, 0, 4, 0, 0, 0, 0, 1, 16, 211, 240, 113, 188, 227, 114, 114, 30, 157, 179, 246, 55, 233, 240, 44, 197, 3, 2, 247, 244, 4, 1, 0];
@ -299,13 +405,13 @@ mod test {
let (send, recv) = channel(100);
let sock = SocketImpl {
let sock = Arc::new(SocketData {
connections: Default::default(),
access_key: "6f599f81",
virtual_port: VirtualPort(0),
socket: Arc::new(UdpSocket::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 10000)).await.unwrap()),
connection_creation_sender: send,
};
});
println!("sent: {:?}", packet_1);
sock.process_packet(PRUDPSockAddr {
virtual_port: VirtualPort(0),
@ -316,5 +422,5 @@ mod test {
virtual_port: VirtualPort(0),
regular_socket_addr: SocketAddrV4::new(Ipv4Addr::LOCALHOST, 2469),
}, &packet_2).await;
}
}*/
}