//! WebSocket Relay Server //! //! Routes frames from source→receivers and inputs from receivers→source. //! Roles are determined by the connection path: //! - `/source` — the frame producer (DreamStack renderer) //! - `/stream` — frame consumers (thin receivers) //! //! The relay is intentionally dumb — it just forwards bytes. //! The protocol semantics live in the source and receiver. use std::net::SocketAddr; use std::sync::Arc; use futures_util::{SinkExt, StreamExt}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::{broadcast, mpsc, RwLock}; use tokio_tungstenite::tungstenite::Message; /// Relay server configuration. pub struct RelayConfig { /// Address to bind to. pub addr: SocketAddr, /// Maximum number of receivers. pub max_receivers: usize, /// Frame broadcast channel capacity. pub frame_buffer_size: usize, } impl Default for RelayConfig { fn default() -> Self { Self { addr: "0.0.0.0:9100".parse().unwrap(), max_receivers: 64, frame_buffer_size: 16, } } } /// Stats tracked by the relay. #[derive(Debug, Default, Clone)] pub struct RelayStats { pub frames_relayed: u64, pub bytes_relayed: u64, pub inputs_relayed: u64, pub connected_receivers: usize, pub source_connected: bool, } /// Shared relay state. struct RelayState { /// Broadcast channel: source → all receivers (frames) frame_tx: broadcast::Sender>, /// Channel: receivers → source (input events) input_tx: mpsc::Sender>, input_rx: Option>>, /// Live stats stats: RelayStats, } /// Run the WebSocket relay server. pub async fn run_relay(config: RelayConfig) -> Result<(), Box> { let listener = TcpListener::bind(&config.addr).await?; eprintln!("╔══════════════════════════════════════════════════╗"); eprintln!("║ DreamStack Bitstream Relay v0.1.0 ║"); eprintln!("║ ║"); eprintln!("║ Source: ws://{}/source ║", config.addr); eprintln!("║ Receiver: ws://{}/stream ║", config.addr); eprintln!("╚══════════════════════════════════════════════════╝"); let (frame_tx, _) = broadcast::channel(config.frame_buffer_size); let (input_tx, input_rx) = mpsc::channel(256); let state = Arc::new(RwLock::new(RelayState { frame_tx, input_tx, input_rx: Some(input_rx), stats: RelayStats::default(), })); while let Ok((stream, addr)) = listener.accept().await { let state = state.clone(); tokio::spawn(handle_connection(stream, addr, state)); } Ok(()) } async fn handle_connection( stream: TcpStream, addr: SocketAddr, state: Arc>, ) { // Peek at the HTTP upgrade request to determine role let ws_stream = match tokio_tungstenite::accept_hdr_async( stream, |_req: &tokio_tungstenite::tungstenite::handshake::server::Request, res: tokio_tungstenite::tungstenite::handshake::server::Response| { // We'll extract the path from the URI later via a different mechanism Ok(res) }, ) .await { Ok(ws) => ws, Err(e) => { eprintln!("[relay] WebSocket handshake failed from {}: {}", addr, e); return; } }; // For simplicity in the PoC, first connection = source, subsequent = receivers. // A production version would parse the URI path. let is_source = { let s = state.read().await; !s.stats.source_connected }; if is_source { eprintln!("[relay] Source connected: {}", addr); handle_source(ws_stream, addr, state).await; } else { eprintln!("[relay] Receiver connected: {}", addr); handle_receiver(ws_stream, addr, state).await; } } async fn handle_source( ws_stream: tokio_tungstenite::WebSocketStream, addr: SocketAddr, state: Arc>, ) { let (mut ws_sink, mut ws_source) = ws_stream.split(); // Mark source as connected and take the input_rx let input_rx = { let mut s = state.write().await; s.stats.source_connected = true; s.input_rx.take() }; let frame_tx = { let s = state.read().await; s.frame_tx.clone() }; // Forward input events from receivers → source if let Some(mut input_rx) = input_rx { let state_clone = state.clone(); tokio::spawn(async move { while let Some(input_bytes) = input_rx.recv().await { let msg = Message::Binary(input_bytes.into()); if ws_sink.send(msg).await.is_err() { break; } let mut s = state_clone.write().await; s.stats.inputs_relayed += 1; } }); } // Receive frames from source → broadcast to receivers while let Some(Ok(msg)) = ws_source.next().await { if let Message::Binary(data) = msg { let data_vec: Vec = data.into(); { let mut s = state.write().await; s.stats.frames_relayed += 1; s.stats.bytes_relayed += data_vec.len() as u64; } // Broadcast to all receivers (ignore send errors = no receivers) let _ = frame_tx.send(data_vec); } } // Source disconnected eprintln!("[relay] Source disconnected: {}", addr); let mut s = state.write().await; s.stats.source_connected = false; } async fn handle_receiver( ws_stream: tokio_tungstenite::WebSocketStream, addr: SocketAddr, state: Arc>, ) { let (mut ws_sink, mut ws_source) = ws_stream.split(); // Subscribe to frame broadcast let mut frame_rx = { let mut s = state.write().await; s.stats.connected_receivers += 1; s.frame_tx.subscribe() }; let input_tx = { let s = state.read().await; s.input_tx.clone() }; // Forward frames from broadcast → this receiver let _state_clone = state.clone(); let addr_clone = addr; let send_task = tokio::spawn(async move { loop { match frame_rx.recv().await { Ok(frame_bytes) => { let msg = Message::Binary(frame_bytes.into()); if ws_sink.send(msg).await.is_err() { break; } } Err(broadcast::error::RecvError::Lagged(n)) => { eprintln!("[relay] Receiver {} lagged by {} frames", addr_clone, n); } Err(_) => break, } } }); // Forward input events from this receiver → source while let Some(Ok(msg)) = ws_source.next().await { if let Message::Binary(data) = msg { let _ = input_tx.send(data.into()).await; } } // Receiver disconnected send_task.abort(); eprintln!("[relay] Receiver disconnected: {}", addr); let mut s = state.write().await; s.stats.connected_receivers = s.stats.connected_receivers.saturating_sub(1); } // ─── Tests ─── #[cfg(test)] mod tests { use super::*; #[test] fn default_config() { let config = RelayConfig::default(); assert_eq!(config.addr, "0.0.0.0:9100".parse::().unwrap()); assert_eq!(config.max_receivers, 64); assert_eq!(config.frame_buffer_size, 16); } #[test] fn stats_default() { let stats = RelayStats::default(); assert_eq!(stats.frames_relayed, 0); assert_eq!(stats.bytes_relayed, 0); assert!(!stats.source_connected); } }