//! WebSocket Relay Server //! //! Routes frames from source→receivers and inputs from receivers→source. //! //! ## Routing //! //! Connections are routed by WebSocket URI path: //! - `/source` — default source (channel: "default") //! - `/source/{name}` — named source (channel: {name}) //! - `/stream` — default receiver (channel: "default") //! - `/stream/{name}` — named receiver (channel: {name}) //! - `/` — legacy: first connection = source, rest = receivers //! //! Each channel has its own broadcast/input channels and state cache, //! allowing multiple independent streams through a single relay. //! //! ## Features //! - Multi-source: multiple views stream independently //! - Keyframe caching: late-joining receivers get current state instantly //! - Signal state store: caches SignalSync/Diff frames for reconstruction //! - Ping/pong keepalive: detects dead connections //! - Stats tracking: frames, bytes, latency metrics use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; use std::time::Instant; use futures_util::{SinkExt, StreamExt}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::{broadcast, mpsc, RwLock}; use tokio::time::{interval, Duration}; use tokio_tungstenite::tungstenite::Message; use crate::protocol::*; /// Relay server configuration. pub struct RelayConfig { /// Address to bind to. pub addr: SocketAddr, /// Maximum number of receivers per channel. pub max_receivers: usize, /// Frame broadcast channel capacity. pub frame_buffer_size: usize, /// Keepalive interval in seconds. pub keepalive_interval_secs: u64, /// Keepalive timeout in seconds — disconnect after this many seconds without a pong. pub keepalive_timeout_secs: u64, } impl Default for RelayConfig { fn default() -> Self { Self { addr: "0.0.0.0:9100".parse().unwrap(), max_receivers: 64, frame_buffer_size: 16, keepalive_interval_secs: 10, keepalive_timeout_secs: 30, } } } /// Stats tracked by the relay. #[derive(Debug, Default, Clone)] pub struct RelayStats { pub frames_relayed: u64, pub bytes_relayed: u64, pub inputs_relayed: u64, pub connected_receivers: usize, pub source_connected: bool, /// Timestamp of last frame from source (milliseconds since start). pub last_frame_timestamp: u32, /// Total keyframes sent. pub keyframes_sent: u64, /// Total signal diffs sent. pub signal_diffs_sent: u64, /// Uptime in seconds. pub uptime_secs: u64, } /// Cached state for late-joining receivers. #[derive(Debug, Default, Clone)] pub struct StateCache { /// Last pixel keyframe (full RGBA frame). pub last_keyframe: Option>, /// Last signal sync frame (full JSON state). pub last_signal_sync: Option>, /// Accumulated signal diffs since last sync. /// Late-joining receivers get: last_signal_sync + all diffs. pub pending_signal_diffs: Vec>, } impl StateCache { /// Process an incoming frame and update cache. fn process_frame(&mut self, msg: &[u8]) { if msg.len() < HEADER_SIZE { return; } let frame_type = msg[0]; let flags = msg[1]; match FrameType::from_u8(frame_type) { // Cache keyframes (pixel or signal sync) Some(FrameType::Pixels) if flags & FLAG_KEYFRAME != 0 => { self.last_keyframe = Some(msg.to_vec()); } Some(FrameType::SignalSync) => { self.last_signal_sync = Some(msg.to_vec()); self.pending_signal_diffs.clear(); // sync resets diffs } Some(FrameType::SignalDiff) => { self.pending_signal_diffs.push(msg.to_vec()); // Cap accumulated diffs to prevent unbounded memory if self.pending_signal_diffs.len() > 1000 { self.pending_signal_diffs.drain(..500); } } Some(FrameType::Keyframe) => { self.last_keyframe = Some(msg.to_vec()); } _ => {} } } /// Get all messages a late-joining receiver needs to reconstruct current state. fn catchup_messages(&self) -> Vec> { let mut msgs = Vec::new(); // Send last signal sync first (if available) if let Some(ref sync) = self.last_signal_sync { msgs.push(sync.clone()); } // Then all accumulated diffs for diff in &self.pending_signal_diffs { msgs.push(diff.clone()); } // Then last pixel keyframe (if available) if let Some(ref kf) = self.last_keyframe { msgs.push(kf.clone()); } msgs } } /// Per-channel state — each named channel is an independent stream. struct ChannelState { /// Broadcast channel: source → all receivers (frames) frame_tx: broadcast::Sender>, /// Channel: receivers → source (input events) input_tx: mpsc::Sender>, input_rx: Option>>, /// Broadcast channel for WebRTC signaling (SDP/ICE as text) signaling_tx: broadcast::Sender, /// Live stats for this channel stats: RelayStats, /// Cached state for late-joining receivers cache: StateCache, } impl ChannelState { fn new(frame_buffer_size: usize) -> Self { let (frame_tx, _) = broadcast::channel(frame_buffer_size); let (input_tx, input_rx) = mpsc::channel(256); let (signaling_tx, _) = broadcast::channel(64); Self { frame_tx, input_tx, input_rx: Some(input_rx), signaling_tx, stats: RelayStats::default(), cache: StateCache::default(), } } } /// Shared relay state — holds all channels. struct RelayState { /// Named channels: "default", "main", "player1", etc. channels: HashMap>>, /// Frame buffer size for new channels frame_buffer_size: usize, /// Server start time start_time: Instant, } impl RelayState { fn new(frame_buffer_size: usize) -> Self { Self { channels: HashMap::new(), frame_buffer_size, start_time: Instant::now(), } } /// Get or create a channel by name. fn get_or_create_channel(&mut self, name: &str) -> Arc> { self.channels .entry(name.to_string()) .or_insert_with(|| Arc::new(RwLock::new(ChannelState::new(self.frame_buffer_size)))) .clone() } } /// Parsed connection role from the WebSocket URI path. #[derive(Debug, Clone)] enum ConnectionRole { Source(String), // channel name Receiver(String), // channel name Signaling(String), // channel name — WebRTC signaling } /// Parse the WebSocket URI path to determine connection role and channel. fn parse_path(path: &str) -> ConnectionRole { let parts: Vec<&str> = path.trim_start_matches('/').splitn(2, '/').collect(); match parts.as_slice() { ["source"] => ConnectionRole::Source("default".to_string()), ["source", name] => ConnectionRole::Source(name.to_string()), ["stream"] => ConnectionRole::Receiver("default".to_string()), ["stream", name] => ConnectionRole::Receiver(name.to_string()), ["signal"] => ConnectionRole::Signaling("default".to_string()), ["signal", name] => ConnectionRole::Signaling(name.to_string()), _ => ConnectionRole::Receiver("default".to_string()), // legacy: `/` = receiver } } /// Run the WebSocket relay server. pub async fn run_relay(config: RelayConfig) -> Result<(), Box> { let listener = TcpListener::bind(&config.addr).await?; eprintln!("╔══════════════════════════════════════════════════╗"); eprintln!("║ DreamStack Bitstream Relay v0.4.0 ║"); eprintln!("║ ║"); eprintln!("║ Source: ws://{}/source/{{name}} ║", config.addr); eprintln!("║ Receiver: ws://{}/stream/{{name}} ║", config.addr); eprintln!("║ Signal: ws://{}/signal/{{name}} ║", config.addr); eprintln!("║ ║"); eprintln!("║ Multi-source, WebRTC signaling, keyframe cache ║"); eprintln!("╚══════════════════════════════════════════════════╝"); let state = Arc::new(RwLock::new(RelayState::new(config.frame_buffer_size))); // Background: periodic stats logging { let state = state.clone(); tokio::spawn(async move { let mut tick = interval(Duration::from_secs(30)); loop { tick.tick().await; let s = state.read().await; let uptime = s.start_time.elapsed().as_secs(); for (name, channel) in &s.channels { let cs = channel.read().await; if cs.stats.source_connected || cs.stats.connected_receivers > 0 { eprintln!( "[relay:{name}] up={uptime}s frames={} bytes={} inputs={} receivers={} signal_diffs={} cached={}", cs.stats.frames_relayed, cs.stats.bytes_relayed, cs.stats.inputs_relayed, cs.stats.connected_receivers, cs.stats.signal_diffs_sent, cs.cache.last_signal_sync.is_some(), ); } } } }); } while let Ok((stream, addr)) = listener.accept().await { let state = state.clone(); let keepalive_interval = config.keepalive_interval_secs; let keepalive_timeout = config.keepalive_timeout_secs; tokio::spawn(handle_connection(stream, addr, state, keepalive_interval, keepalive_timeout)); } Ok(()) } async fn handle_connection( stream: TcpStream, addr: SocketAddr, state: Arc>, keepalive_interval: u64, _keepalive_timeout: u64, ) { // Extract the URI path during handshake to determine role let mut role = ConnectionRole::Receiver("default".to_string()); let ws_stream = match tokio_tungstenite::accept_hdr_async( stream, |req: &tokio_tungstenite::tungstenite::handshake::server::Request, res: tokio_tungstenite::tungstenite::handshake::server::Response| { role = parse_path(req.uri().path()); Ok(res) }, ) .await { Ok(ws) => ws, Err(e) => { eprintln!("[relay] WebSocket handshake failed from {}: {}", addr, e); return; } }; // Get or create the channel let (channel, channel_name) = { let mut s = state.write().await; let name = match &role { ConnectionRole::Source(n) | ConnectionRole::Receiver(n) | ConnectionRole::Signaling(n) => n.clone(), }; let ch = s.get_or_create_channel(&name); (ch, name) }; // Legacy fallback: if path was `/` and no source exists, treat as source let role = match role { ConnectionRole::Receiver(ref name) if name == "default" => { let cs = channel.read().await; if !cs.stats.source_connected { ConnectionRole::Source("default".to_string()) } else { role } } _ => role, }; match role { ConnectionRole::Source(ref _name) => { eprintln!("[relay:{channel_name}] Source connected: {addr}"); handle_source(ws_stream, addr, channel, &channel_name, keepalive_interval).await; } ConnectionRole::Receiver(ref _name) => { eprintln!("[relay:{channel_name}] Receiver connected: {addr}"); handle_receiver(ws_stream, addr, channel, &channel_name).await; } ConnectionRole::Signaling(ref _name) => { eprintln!("[relay:{channel_name}] Signaling peer connected: {addr}"); handle_signaling(ws_stream, addr, channel, &channel_name).await; } } } async fn handle_source( ws_stream: tokio_tungstenite::WebSocketStream, addr: SocketAddr, channel: Arc>, channel_name: &str, keepalive_interval: u64, ) { let (mut ws_sink, mut ws_source) = ws_stream.split(); // Mark source as connected and take the input_rx let input_rx = { let mut cs = channel.write().await; cs.stats.source_connected = true; cs.input_rx.take() }; let frame_tx = { let cs = channel.read().await; cs.frame_tx.clone() }; // Forward input events from receivers → source if let Some(mut input_rx) = input_rx { let channel_clone = channel.clone(); tokio::spawn(async move { while let Some(input_bytes) = input_rx.recv().await { let msg = Message::Binary(input_bytes.into()); if ws_sink.send(msg).await.is_err() { break; } let mut cs = channel_clone.write().await; cs.stats.inputs_relayed += 1; } }); } // Keepalive: periodic pings let channel_ping = channel.clone(); let ping_task = tokio::spawn(async move { let mut tick = interval(Duration::from_secs(keepalive_interval)); let mut seq = 0u16; loop { tick.tick().await; let cs = channel_ping.read().await; if !cs.stats.source_connected { break; } let _ = cs.frame_tx.send(crate::codec::ping(seq, 0)); drop(cs); seq = seq.wrapping_add(1); } }); // Receive frames from source → broadcast to receivers let channel_name_owned = channel_name.to_string(); while let Some(Ok(msg)) = ws_source.next().await { if let Message::Binary(data) = msg { let data_vec: Vec = data.into(); { let mut cs = channel.write().await; cs.stats.frames_relayed += 1; cs.stats.bytes_relayed += data_vec.len() as u64; // Update state cache cs.cache.process_frame(&data_vec); // Track frame-type-specific stats if data_vec.len() >= HEADER_SIZE { match FrameType::from_u8(data_vec[0]) { Some(FrameType::SignalDiff) => cs.stats.signal_diffs_sent += 1, Some(FrameType::Pixels) | Some(FrameType::Keyframe) | Some(FrameType::SignalSync) => cs.stats.keyframes_sent += 1, _ => {} } let ts = u32::from_le_bytes([ data_vec[4], data_vec[5], data_vec[6], data_vec[7], ]); cs.stats.last_frame_timestamp = ts; } } // Broadcast to all receivers on this channel let _ = frame_tx.send(data_vec); } } // Source disconnected ping_task.abort(); eprintln!("[relay:{channel_name_owned}] Source disconnected: {addr}"); let mut cs = channel.write().await; cs.stats.source_connected = false; } async fn handle_receiver( ws_stream: tokio_tungstenite::WebSocketStream, addr: SocketAddr, channel: Arc>, channel_name: &str, ) { let (mut ws_sink, mut ws_source) = ws_stream.split(); // Subscribe to frame broadcast and get catchup messages let (mut frame_rx, catchup_msgs) = { let mut cs = channel.write().await; cs.stats.connected_receivers += 1; let rx = cs.frame_tx.subscribe(); let catchup = cs.cache.catchup_messages(); (rx, catchup) }; let input_tx = { let cs = channel.read().await; cs.input_tx.clone() }; // Send cached state to late-joining receiver let channel_name_owned = channel_name.to_string(); if !catchup_msgs.is_empty() { eprintln!( "[relay:{channel_name_owned}] Sending {} catchup messages to {addr}", catchup_msgs.len(), ); for msg_bytes in catchup_msgs { let msg = Message::Binary(msg_bytes.into()); if ws_sink.send(msg).await.is_err() { eprintln!("[relay:{channel_name_owned}] Failed to send catchup to {addr}"); let mut cs = channel.write().await; cs.stats.connected_receivers = cs.stats.connected_receivers.saturating_sub(1); return; } } } // Forward frames from broadcast → this receiver let send_task = tokio::spawn(async move { loop { match frame_rx.recv().await { Ok(frame_bytes) => { let msg = Message::Binary(frame_bytes.into()); if ws_sink.send(msg).await.is_err() { break; } } Err(broadcast::error::RecvError::Lagged(n)) => { eprintln!("[relay] Receiver {addr} lagged by {n} frames"); } Err(_) => break, } } }); // Forward input events from this receiver → source while let Some(Ok(msg)) = ws_source.next().await { if let Message::Binary(data) = msg { let _ = input_tx.send(data.into()).await; } } // Receiver disconnected send_task.abort(); eprintln!("[relay:{channel_name_owned}] Receiver disconnected: {addr}"); let mut cs = channel.write().await; cs.stats.connected_receivers = cs.stats.connected_receivers.saturating_sub(1); } /// Handle a WebRTC signaling connection. /// /// Signaling peers exchange JSON messages (SDP offers/answers, ICE candidates) /// over WebSocket. The relay broadcasts all text messages to all other peers /// on the same channel, enabling peer-to-peer WebRTC setup. async fn handle_signaling( ws_stream: tokio_tungstenite::WebSocketStream, addr: SocketAddr, channel: Arc>, channel_name: &str, ) { let (mut ws_sink, mut ws_source) = ws_stream.split(); // Subscribe to signaling broadcast let mut sig_rx = { let cs = channel.read().await; cs.signaling_tx.subscribe() }; let sig_tx = { let cs = channel.read().await; cs.signaling_tx.clone() }; // Forward signaling messages from broadcast → this peer let channel_name_owned = channel_name.to_string(); let send_task = tokio::spawn(async move { loop { match sig_rx.recv().await { Ok(msg_text) => { let msg = Message::Text(msg_text.into()); if ws_sink.send(msg).await.is_err() { break; } } Err(broadcast::error::RecvError::Lagged(n)) => { eprintln!("[relay:{channel_name_owned}] Signaling peer lagged by {n} messages"); } Err(_) => break, } } }); // Forward signaling messages from this peer → broadcast to others while let Some(Ok(msg)) = ws_source.next().await { if let Message::Text(text) = msg { let text_str: String = text.into(); let _ = sig_tx.send(text_str); } } send_task.abort(); eprintln!("[relay:{channel_name}] Signaling peer disconnected: {addr}"); } // ─── Tests ─── #[cfg(test)] mod tests { use super::*; #[test] fn default_config() { let config = RelayConfig::default(); assert_eq!(config.addr, "0.0.0.0:9100".parse::().unwrap()); assert_eq!(config.max_receivers, 64); assert_eq!(config.frame_buffer_size, 16); assert_eq!(config.keepalive_interval_secs, 10); } #[test] fn stats_default() { let stats = RelayStats::default(); assert_eq!(stats.frames_relayed, 0); assert_eq!(stats.bytes_relayed, 0); assert!(!stats.source_connected); assert_eq!(stats.keyframes_sent, 0); assert_eq!(stats.signal_diffs_sent, 0); } #[test] fn state_cache_signal_sync() { let mut cache = StateCache::default(); // Simulate a SignalSync frame let sync_json = br#"{"count":0}"#; let sync_msg = crate::codec::signal_sync_frame(0, 100, sync_json); cache.process_frame(&sync_msg); assert!(cache.last_signal_sync.is_some()); assert_eq!(cache.pending_signal_diffs.len(), 0); // Simulate a SignalDiff let diff_json = br#"{"count":1}"#; let diff_msg = crate::codec::signal_diff_frame(1, 200, diff_json); cache.process_frame(&diff_msg); assert_eq!(cache.pending_signal_diffs.len(), 1); // Catchup should contain sync + diff let catchup = cache.catchup_messages(); assert_eq!(catchup.len(), 2); } #[test] fn state_cache_signal_sync_resets_diffs() { let mut cache = StateCache::default(); // Add some diffs for i in 0..5 { let json = format!(r#"{{"count":{}}}"#, i); let msg = crate::codec::signal_diff_frame(i, i as u32 * 100, json.as_bytes()); cache.process_frame(&msg); } assert_eq!(cache.pending_signal_diffs.len(), 5); // A new sync should clear diffs let sync = crate::codec::signal_sync_frame(10, 1000, b"{}"); cache.process_frame(&sync); assert_eq!(cache.pending_signal_diffs.len(), 0); assert!(cache.last_signal_sync.is_some()); } #[test] fn state_cache_keyframe() { let mut cache = StateCache::default(); // Simulate a keyframe pixel frame let kf = crate::codec::pixel_frame(0, 100, 10, 10, &vec![0xFF; 400]); cache.process_frame(&kf); assert!(cache.last_keyframe.is_some()); let catchup = cache.catchup_messages(); assert_eq!(catchup.len(), 1); } #[test] fn state_cache_diff_cap() { let mut cache = StateCache::default(); // Add more than 1000 diffs — should auto-trim for i in 0..1100u16 { let json = format!(r#"{{"n":{}}}"#, i); let msg = crate::codec::signal_diff_frame(i, i as u32, json.as_bytes()); cache.process_frame(&msg); } // Should have been trimmed: 1100 - 500 = 600 remaining after first trim at 1001 assert!(cache.pending_signal_diffs.len() <= 600); } // ─── Path Routing Tests ─── #[test] fn parse_path_source_default() { match parse_path("/source") { ConnectionRole::Source(name) => assert_eq!(name, "default"), _ => panic!("Expected Source"), } } #[test] fn parse_path_source_named() { match parse_path("/source/main") { ConnectionRole::Source(name) => assert_eq!(name, "main"), _ => panic!("Expected Source"), } } #[test] fn parse_path_stream_default() { match parse_path("/stream") { ConnectionRole::Receiver(name) => assert_eq!(name, "default"), _ => panic!("Expected Receiver"), } } #[test] fn parse_path_stream_named() { match parse_path("/stream/player1") { ConnectionRole::Receiver(name) => assert_eq!(name, "player1"), _ => panic!("Expected Receiver"), } } #[test] fn parse_path_signal_default() { match parse_path("/signal") { ConnectionRole::Signaling(name) => assert_eq!(name, "default"), _ => panic!("Expected Signaling"), } } #[test] fn parse_path_signal_named() { match parse_path("/signal/main") { ConnectionRole::Signaling(name) => assert_eq!(name, "main"), _ => panic!("Expected Signaling"), } } #[test] fn parse_path_legacy_root() { match parse_path("/") { ConnectionRole::Receiver(name) => assert_eq!(name, "default"), _ => panic!("Expected Receiver for legacy path"), } } #[test] fn channel_state_creation() { let mut state = RelayState::new(16); let ch1 = state.get_or_create_channel("main"); let ch2 = state.get_or_create_channel("player1"); let ch1_again = state.get_or_create_channel("main"); assert_eq!(state.channels.len(), 2); assert!(Arc::ptr_eq(&ch1, &ch1_again)); assert!(!Arc::ptr_eq(&ch1, &ch2)); } }