diff --git a/prudpv0/src/server.rs b/prudpv0/src/server.rs index 63fefe7..4921b3f 100644 --- a/prudpv0/src/server.rs +++ b/prudpv0/src/server.rs @@ -73,7 +73,7 @@ pub struct Server { param: ProxyStartupParam, socket: UdpSocket, crypto: C, - connections: RwLock>>>, + connections: RwLock>>>, } impl Server { @@ -167,8 +167,9 @@ impl Server { self.clone().send_data_packet(conn.clone(), &data).await; } } - async fn timeout_thread(self: Arc, conn: Arc>) { + async fn timeout_thread(self: Arc, conn: Weak>) { loop { + let Some(conn) = conn.upgrade() else { break }; sleep(Duration::from_secs(3)); let mut inner = conn.inner.lock().await; @@ -214,7 +215,7 @@ impl Server { drop(inner); let mut conns = self.connections.write().await; - conns.remove(&conn.addr); + conns.remove(&(conn.addr, conn.session_id)); drop(conns); break; } @@ -284,10 +285,10 @@ impl Server { }); let mut conns = self.connections.write().await; - if conns.contains_key(&addr) { + if conns.contains_key(&(addr, header.session_id)) { error!("client already connected but tried to connect again"); } - conns.insert(addr, conn.clone()); + conns.insert((addr, header.session_id), conn.clone()); drop(conns); spawn({ @@ -297,7 +298,7 @@ impl Server { }); spawn({ let this = self.clone(); - let conn = conn.clone(); + let conn = Arc::downgrade(&conn); this.timeout_thread(conn) }); @@ -325,7 +326,7 @@ impl Server { return; }; - let Some(res) = self.get_connection(addr).await else { + let Some(res) = self.get_connection((addr, header.session_id)).await else { warn!("data packet on inactive connection from: {:?}", addr); return; }; @@ -372,7 +373,7 @@ impl Server { info!("got ping"); let header = packet.header().unwrap(); - let Some(conn) = self.get_connection(addr).await else { + let Some(conn) = self.get_connection((addr, header.session_id)).await else { warn!("ping on inactive connection: {:?}", addr); return; }; @@ -398,7 +399,7 @@ impl Server { info!("got disconnect"); let header = packet.header().unwrap(); - let Some(conn) = self.get_connection(addr).await else { + let Some(conn) = self.get_connection((addr, header.session_id)).await else { warn!("ping on inactive connection: {:?}", addr); return; }; @@ -415,14 +416,17 @@ impl Server { drop(inner); let mut conns = self.connections.write().await; - conns.remove(&addr); + conns.remove(&(addr, header.session_id)); drop(conns); self.socket.send_to(&packet, addr.regular_socket_addr).await; self.socket.send_to(&packet, addr.regular_socket_addr).await; self.socket.send_to(&packet, addr.regular_socket_addr).await; } - async fn get_connection(&self, addr: PRUDPSockAddr) -> Option>> { + async fn get_connection( + &self, + addr: (PRUDPSockAddr, u8), + ) -> Option>> { let rd = self.connections.read().await; let res = rd.get(&addr).cloned(); drop(rd); @@ -444,7 +448,7 @@ impl Server { let addr = PRUDPSockAddr::new(SocketAddr::V4(addr), header.source); - if let Some(conn) = self.get_connection(addr).await { + if let Some(conn) = self.get_connection((addr, header.session_id)).await { let mut inner = conn.inner.lock().await; inner.last_action = Instant::now(); drop(inner);