tie connections to a session id + prudp socket address

This commit is contained in:
Maple 2026-04-11 20:02:10 +02:00
commit dd4015b2c4

View file

@ -73,7 +73,7 @@ pub struct Server<C: Crypto> {
param: ProxyStartupParam, param: ProxyStartupParam,
socket: UdpSocket, socket: UdpSocket,
crypto: C, crypto: C,
connections: RwLock<HashMap<PRUDPSockAddr, Arc<Connection<C::Instance>>>>, connections: RwLock<HashMap<(PRUDPSockAddr, u8), Arc<Connection<C::Instance>>>>,
} }
impl<C: Crypto> Server<C> { impl<C: Crypto> Server<C> {
@ -167,8 +167,9 @@ impl<C: Crypto> Server<C> {
self.clone().send_data_packet(conn.clone(), &data).await; self.clone().send_data_packet(conn.clone(), &data).await;
} }
} }
async fn timeout_thread(self: Arc<Self>, conn: Arc<Connection<C::Instance>>) { async fn timeout_thread(self: Arc<Self>, conn: Weak<Connection<C::Instance>>) {
loop { loop {
let Some(conn) = conn.upgrade() else { break };
sleep(Duration::from_secs(3)); sleep(Duration::from_secs(3));
let mut inner = conn.inner.lock().await; let mut inner = conn.inner.lock().await;
@ -214,7 +215,7 @@ impl<C: Crypto> Server<C> {
drop(inner); drop(inner);
let mut conns = self.connections.write().await; let mut conns = self.connections.write().await;
conns.remove(&conn.addr); conns.remove(&(conn.addr, conn.session_id));
drop(conns); drop(conns);
break; break;
} }
@ -284,10 +285,10 @@ impl<C: Crypto> Server<C> {
}); });
let mut conns = self.connections.write().await; let mut conns = self.connections.write().await;
if conns.contains_key(&addr) { if conns.contains_key(&(addr, header.session_id)) {
error!("client already connected but tried to connect again"); error!("client already connected but tried to connect again");
} }
conns.insert(addr, conn.clone()); conns.insert((addr, header.session_id), conn.clone());
drop(conns); drop(conns);
spawn({ spawn({
@ -297,7 +298,7 @@ impl<C: Crypto> Server<C> {
}); });
spawn({ spawn({
let this = self.clone(); let this = self.clone();
let conn = conn.clone(); let conn = Arc::downgrade(&conn);
this.timeout_thread(conn) this.timeout_thread(conn)
}); });
@ -325,7 +326,7 @@ impl<C: Crypto> Server<C> {
return; return;
}; };
let Some(res) = self.get_connection(addr).await else { let Some(res) = self.get_connection((addr, header.session_id)).await else {
warn!("data packet on inactive connection from: {:?}", addr); warn!("data packet on inactive connection from: {:?}", addr);
return; return;
}; };
@ -372,7 +373,7 @@ impl<C: Crypto> Server<C> {
info!("got ping"); info!("got ping");
let header = packet.header().unwrap(); let header = packet.header().unwrap();
let Some(conn) = self.get_connection(addr).await else { let Some(conn) = self.get_connection((addr, header.session_id)).await else {
warn!("ping on inactive connection: {:?}", addr); warn!("ping on inactive connection: {:?}", addr);
return; return;
}; };
@ -398,7 +399,7 @@ impl<C: Crypto> Server<C> {
info!("got disconnect"); info!("got disconnect");
let header = packet.header().unwrap(); let header = packet.header().unwrap();
let Some(conn) = self.get_connection(addr).await else { let Some(conn) = self.get_connection((addr, header.session_id)).await else {
warn!("ping on inactive connection: {:?}", addr); warn!("ping on inactive connection: {:?}", addr);
return; return;
}; };
@ -415,14 +416,17 @@ impl<C: Crypto> Server<C> {
drop(inner); drop(inner);
let mut conns = self.connections.write().await; let mut conns = self.connections.write().await;
conns.remove(&addr); conns.remove(&(addr, header.session_id));
drop(conns); drop(conns);
self.socket.send_to(&packet, addr.regular_socket_addr).await; self.socket.send_to(&packet, addr.regular_socket_addr).await;
self.socket.send_to(&packet, addr.regular_socket_addr).await; self.socket.send_to(&packet, addr.regular_socket_addr).await;
self.socket.send_to(&packet, addr.regular_socket_addr).await; self.socket.send_to(&packet, addr.regular_socket_addr).await;
} }
async fn get_connection(&self, addr: PRUDPSockAddr) -> Option<Arc<Connection<C::Instance>>> { async fn get_connection(
&self,
addr: (PRUDPSockAddr, u8),
) -> Option<Arc<Connection<C::Instance>>> {
let rd = self.connections.read().await; let rd = self.connections.read().await;
let res = rd.get(&addr).cloned(); let res = rd.get(&addr).cloned();
drop(rd); drop(rd);
@ -444,7 +448,7 @@ impl<C: Crypto> Server<C> {
let addr = PRUDPSockAddr::new(SocketAddr::V4(addr), header.source); let addr = PRUDPSockAddr::new(SocketAddr::V4(addr), header.source);
if let Some(conn) = self.get_connection(addr).await { if let Some(conn) = self.get_connection((addr, header.session_id)).await {
let mut inner = conn.inner.lock().await; let mut inner = conn.inner.lock().await;
inner.last_action = Instant::now(); inner.last_action = Instant::now();
drop(inner); drop(inner);