diff --git a/engine/ds-stream/src/relay.rs b/engine/ds-stream/src/relay.rs index 3a6381e..b7e37ef 100644 --- a/engine/ds-stream/src/relay.rs +++ b/engine/ds-stream/src/relay.rs @@ -1,16 +1,27 @@ //! WebSocket Relay Server //! //! Routes frames from source→receivers and inputs from receivers→source. -//! Roles are determined by connection order: -//! - First connection = source (frame producer) -//! - Subsequent connections = receivers (frame consumers) //! -//! Features: +//! ## Routing +//! +//! Connections are routed by WebSocket URI path: +//! - `/source` — default source (channel: "default") +//! - `/source/{name}` — named source (channel: {name}) +//! - `/stream` — default receiver (channel: "default") +//! - `/stream/{name}` — named receiver (channel: {name}) +//! - `/` — legacy: first connection = source, rest = receivers +//! +//! Each channel has its own broadcast/input channels and state cache, +//! allowing multiple independent streams through a single relay. +//! +//! ## Features +//! - Multi-source: multiple views stream independently //! - Keyframe caching: late-joining receivers get current state instantly //! - Signal state store: caches SignalSync/Diff frames for reconstruction //! - Ping/pong keepalive: detects dead connections //! - Stats tracking: frames, bytes, latency metrics +use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; use std::time::Instant; @@ -27,7 +38,7 @@ use crate::protocol::*; pub struct RelayConfig { /// Address to bind to. pub addr: SocketAddr, - /// Maximum number of receivers. + /// Maximum number of receivers per channel. pub max_receivers: usize, /// Frame broadcast channel capacity. pub frame_buffer_size: usize, @@ -130,44 +141,93 @@ impl StateCache { } } -/// Shared relay state. -struct RelayState { +/// Per-channel state — each named channel is an independent stream. +struct ChannelState { /// Broadcast channel: source → all receivers (frames) frame_tx: broadcast::Sender>, /// Channel: receivers → source (input events) input_tx: mpsc::Sender>, input_rx: Option>>, - /// Live stats + /// Live stats for this channel stats: RelayStats, /// Cached state for late-joining receivers cache: StateCache, +} + +impl ChannelState { + fn new(frame_buffer_size: usize) -> Self { + let (frame_tx, _) = broadcast::channel(frame_buffer_size); + let (input_tx, input_rx) = mpsc::channel(256); + Self { + frame_tx, + input_tx, + input_rx: Some(input_rx), + stats: RelayStats::default(), + cache: StateCache::default(), + } + } +} + +/// Shared relay state — holds all channels. +struct RelayState { + /// Named channels: "default", "main", "player1", etc. + channels: HashMap>>, + /// Frame buffer size for new channels + frame_buffer_size: usize, /// Server start time start_time: Instant, } +impl RelayState { + fn new(frame_buffer_size: usize) -> Self { + Self { + channels: HashMap::new(), + frame_buffer_size, + start_time: Instant::now(), + } + } + + /// Get or create a channel by name. + fn get_or_create_channel(&mut self, name: &str) -> Arc> { + self.channels + .entry(name.to_string()) + .or_insert_with(|| Arc::new(RwLock::new(ChannelState::new(self.frame_buffer_size)))) + .clone() + } +} + +/// Parsed connection role from the WebSocket URI path. +#[derive(Debug, Clone)] +enum ConnectionRole { + Source(String), // channel name + Receiver(String), // channel name +} + +/// Parse the WebSocket URI path to determine connection role and channel. +fn parse_path(path: &str) -> ConnectionRole { + let parts: Vec<&str> = path.trim_start_matches('/').splitn(2, '/').collect(); + match parts.as_slice() { + ["source"] => ConnectionRole::Source("default".to_string()), + ["source", name] => ConnectionRole::Source(name.to_string()), + ["stream"] => ConnectionRole::Receiver("default".to_string()), + ["stream", name] => ConnectionRole::Receiver(name.to_string()), + _ => ConnectionRole::Receiver("default".to_string()), // legacy: `/` = receiver + } +} + /// Run the WebSocket relay server. pub async fn run_relay(config: RelayConfig) -> Result<(), Box> { let listener = TcpListener::bind(&config.addr).await?; eprintln!("╔══════════════════════════════════════════════════╗"); - eprintln!("║ DreamStack Bitstream Relay v0.2.0 ║"); + eprintln!("║ DreamStack Bitstream Relay v0.3.0 ║"); eprintln!("║ ║"); - eprintln!("║ Source: ws://{}/source ║", config.addr); - eprintln!("║ Receiver: ws://{}/stream ║", config.addr); + eprintln!("║ Source: ws://{}/source/{{name}} ║", config.addr); + eprintln!("║ Receiver: ws://{}/stream/{{name}} ║", config.addr); eprintln!("║ ║"); - eprintln!("║ Features: keyframe cache, keepalive, RLE ║"); + eprintln!("║ Multi-source, keyframe cache, keepalive, RLE ║"); eprintln!("╚══════════════════════════════════════════════════╝"); - let (frame_tx, _) = broadcast::channel(config.frame_buffer_size); - let (input_tx, input_rx) = mpsc::channel(256); - - let state = Arc::new(RwLock::new(RelayState { - frame_tx, - input_tx, - input_rx: Some(input_rx), - stats: RelayStats::default(), - cache: StateCache::default(), - start_time: Instant::now(), - })); + let state = Arc::new(RwLock::new(RelayState::new(config.frame_buffer_size))); // Background: periodic stats logging { @@ -178,17 +238,19 @@ pub async fn run_relay(config: RelayConfig) -> Result<(), Box 0 { - eprintln!( - "[relay] up={}s frames={} bytes={} inputs={} receivers={} signal_diffs={} cached={}", - uptime, - s.stats.frames_relayed, - s.stats.bytes_relayed, - s.stats.inputs_relayed, - s.stats.connected_receivers, - s.stats.signal_diffs_sent, - s.cache.last_signal_sync.is_some(), - ); + for (name, channel) in &s.channels { + let cs = channel.read().await; + if cs.stats.source_connected || cs.stats.connected_receivers > 0 { + eprintln!( + "[relay:{name}] up={uptime}s frames={} bytes={} inputs={} receivers={} signal_diffs={} cached={}", + cs.stats.frames_relayed, + cs.stats.bytes_relayed, + cs.stats.inputs_relayed, + cs.stats.connected_receivers, + cs.stats.signal_diffs_sent, + cs.cache.last_signal_sync.is_some(), + ); + } } } }); @@ -210,10 +272,14 @@ async fn handle_connection( keepalive_interval: u64, _keepalive_timeout: u64, ) { + // Extract the URI path during handshake to determine role + let mut role = ConnectionRole::Receiver("default".to_string()); + let ws_stream = match tokio_tungstenite::accept_hdr_async( stream, - |_req: &tokio_tungstenite::tungstenite::handshake::server::Request, + |req: &tokio_tungstenite::tungstenite::handshake::server::Request, res: tokio_tungstenite::tungstenite::handshake::server::Response| { + role = parse_path(req.uri().path()); Ok(res) }, ) @@ -226,154 +292,174 @@ async fn handle_connection( } }; - // First connection = source, subsequent = receivers - let is_source = { - let s = state.read().await; - !s.stats.source_connected + // Get or create the channel + let (channel, channel_name) = { + let mut s = state.write().await; + let name = match &role { + ConnectionRole::Source(n) | ConnectionRole::Receiver(n) => n.clone(), + }; + let ch = s.get_or_create_channel(&name); + (ch, name) }; - if is_source { - eprintln!("[relay] Source connected: {}", addr); - handle_source(ws_stream, addr, state, keepalive_interval).await; - } else { - eprintln!("[relay] Receiver connected: {}", addr); - handle_receiver(ws_stream, addr, state, keepalive_interval).await; + // Legacy fallback: if path was `/` and no source exists, treat as source + let role = match role { + ConnectionRole::Receiver(ref name) if name == "default" => { + let cs = channel.read().await; + if !cs.stats.source_connected { + ConnectionRole::Source("default".to_string()) + } else { + role + } + } + _ => role, + }; + + match role { + ConnectionRole::Source(ref _name) => { + eprintln!("[relay:{channel_name}] Source connected: {addr}"); + handle_source(ws_stream, addr, channel, &channel_name, keepalive_interval).await; + } + ConnectionRole::Receiver(ref _name) => { + eprintln!("[relay:{channel_name}] Receiver connected: {addr}"); + handle_receiver(ws_stream, addr, channel, &channel_name).await; + } } } async fn handle_source( ws_stream: tokio_tungstenite::WebSocketStream, addr: SocketAddr, - state: Arc>, + channel: Arc>, + channel_name: &str, keepalive_interval: u64, ) { let (mut ws_sink, mut ws_source) = ws_stream.split(); // Mark source as connected and take the input_rx let input_rx = { - let mut s = state.write().await; - s.stats.source_connected = true; - s.input_rx.take() + let mut cs = channel.write().await; + cs.stats.source_connected = true; + cs.input_rx.take() }; let frame_tx = { - let s = state.read().await; - s.frame_tx.clone() + let cs = channel.read().await; + cs.frame_tx.clone() }; // Forward input events from receivers → source if let Some(mut input_rx) = input_rx { - let state_clone = state.clone(); + let channel_clone = channel.clone(); tokio::spawn(async move { while let Some(input_bytes) = input_rx.recv().await { let msg = Message::Binary(input_bytes.into()); if ws_sink.send(msg).await.is_err() { break; } - let mut s = state_clone.write().await; - s.stats.inputs_relayed += 1; + let mut cs = channel_clone.write().await; + cs.stats.inputs_relayed += 1; } }); } // Keepalive: periodic pings - let state_ping = state.clone(); + let channel_ping = channel.clone(); let ping_task = tokio::spawn(async move { let mut tick = interval(Duration::from_secs(keepalive_interval)); let mut seq = 0u16; loop { tick.tick().await; - let s = state_ping.read().await; - if !s.stats.source_connected { + let cs = channel_ping.read().await; + if !cs.stats.source_connected { break; } - drop(s); - let ping_msg = crate::codec::ping(seq, 0); - let _ = state_ping.read().await.frame_tx.send(ping_msg); + let _ = cs.frame_tx.send(crate::codec::ping(seq, 0)); + drop(cs); seq = seq.wrapping_add(1); } }); // Receive frames from source → broadcast to receivers + let channel_name_owned = channel_name.to_string(); while let Some(Ok(msg)) = ws_source.next().await { if let Message::Binary(data) = msg { let data_vec: Vec = data.into(); { - let mut s = state.write().await; - s.stats.frames_relayed += 1; - s.stats.bytes_relayed += data_vec.len() as u64; + let mut cs = channel.write().await; + cs.stats.frames_relayed += 1; + cs.stats.bytes_relayed += data_vec.len() as u64; // Update state cache - s.cache.process_frame(&data_vec); + cs.cache.process_frame(&data_vec); // Track frame-type-specific stats if data_vec.len() >= HEADER_SIZE { match FrameType::from_u8(data_vec[0]) { - Some(FrameType::SignalDiff) => s.stats.signal_diffs_sent += 1, + Some(FrameType::SignalDiff) => cs.stats.signal_diffs_sent += 1, Some(FrameType::Pixels) | Some(FrameType::Keyframe) | - Some(FrameType::SignalSync) => s.stats.keyframes_sent += 1, + Some(FrameType::SignalSync) => cs.stats.keyframes_sent += 1, _ => {} } let ts = u32::from_le_bytes([ data_vec[4], data_vec[5], data_vec[6], data_vec[7], ]); - s.stats.last_frame_timestamp = ts; + cs.stats.last_frame_timestamp = ts; } } - // Broadcast to all receivers + // Broadcast to all receivers on this channel let _ = frame_tx.send(data_vec); } } // Source disconnected ping_task.abort(); - eprintln!("[relay] Source disconnected: {}", addr); - let mut s = state.write().await; - s.stats.source_connected = false; + eprintln!("[relay:{channel_name_owned}] Source disconnected: {addr}"); + let mut cs = channel.write().await; + cs.stats.source_connected = false; } async fn handle_receiver( ws_stream: tokio_tungstenite::WebSocketStream, addr: SocketAddr, - state: Arc>, - _keepalive_interval: u64, + channel: Arc>, + channel_name: &str, ) { let (mut ws_sink, mut ws_source) = ws_stream.split(); // Subscribe to frame broadcast and get catchup messages let (mut frame_rx, catchup_msgs) = { - let mut s = state.write().await; - s.stats.connected_receivers += 1; - let rx = s.frame_tx.subscribe(); - let catchup = s.cache.catchup_messages(); + let mut cs = channel.write().await; + cs.stats.connected_receivers += 1; + let rx = cs.frame_tx.subscribe(); + let catchup = cs.cache.catchup_messages(); (rx, catchup) }; let input_tx = { - let s = state.read().await; - s.input_tx.clone() + let cs = channel.read().await; + cs.input_tx.clone() }; // Send cached state to late-joining receiver + let channel_name_owned = channel_name.to_string(); if !catchup_msgs.is_empty() { eprintln!( - "[relay] Sending {} catchup messages to {}", + "[relay:{channel_name_owned}] Sending {} catchup messages to {addr}", catchup_msgs.len(), - addr ); for msg_bytes in catchup_msgs { let msg = Message::Binary(msg_bytes.into()); if ws_sink.send(msg).await.is_err() { - eprintln!("[relay] Failed to send catchup to {}", addr); - let mut s = state.write().await; - s.stats.connected_receivers = s.stats.connected_receivers.saturating_sub(1); + eprintln!("[relay:{channel_name_owned}] Failed to send catchup to {addr}"); + let mut cs = channel.write().await; + cs.stats.connected_receivers = cs.stats.connected_receivers.saturating_sub(1); return; } } } // Forward frames from broadcast → this receiver - let addr_clone = addr; let send_task = tokio::spawn(async move { loop { match frame_rx.recv().await { @@ -384,7 +470,7 @@ async fn handle_receiver( } } Err(broadcast::error::RecvError::Lagged(n)) => { - eprintln!("[relay] Receiver {} lagged by {} frames", addr_clone, n); + eprintln!("[relay] Receiver {addr} lagged by {n} frames"); } Err(_) => break, } @@ -400,9 +486,9 @@ async fn handle_receiver( // Receiver disconnected send_task.abort(); - eprintln!("[relay] Receiver disconnected: {}", addr); - let mut s = state.write().await; - s.stats.connected_receivers = s.stats.connected_receivers.saturating_sub(1); + eprintln!("[relay:{channel_name_owned}] Receiver disconnected: {addr}"); + let mut cs = channel.write().await; + cs.stats.connected_receivers = cs.stats.connected_receivers.saturating_sub(1); } // ─── Tests ─── @@ -499,4 +585,58 @@ mod tests { // Should have been trimmed: 1100 - 500 = 600 remaining after first trim at 1001 assert!(cache.pending_signal_diffs.len() <= 600); } + + // ─── Path Routing Tests ─── + + #[test] + fn parse_path_source_default() { + match parse_path("/source") { + ConnectionRole::Source(name) => assert_eq!(name, "default"), + _ => panic!("Expected Source"), + } + } + + #[test] + fn parse_path_source_named() { + match parse_path("/source/main") { + ConnectionRole::Source(name) => assert_eq!(name, "main"), + _ => panic!("Expected Source"), + } + } + + #[test] + fn parse_path_stream_default() { + match parse_path("/stream") { + ConnectionRole::Receiver(name) => assert_eq!(name, "default"), + _ => panic!("Expected Receiver"), + } + } + + #[test] + fn parse_path_stream_named() { + match parse_path("/stream/player1") { + ConnectionRole::Receiver(name) => assert_eq!(name, "player1"), + _ => panic!("Expected Receiver"), + } + } + + #[test] + fn parse_path_legacy_root() { + match parse_path("/") { + ConnectionRole::Receiver(name) => assert_eq!(name, "default"), + _ => panic!("Expected Receiver for legacy path"), + } + } + + #[test] + fn channel_state_creation() { + let mut state = RelayState::new(16); + let ch1 = state.get_or_create_channel("main"); + let ch2 = state.get_or_create_channel("player1"); + let ch1_again = state.get_or_create_channel("main"); + + assert_eq!(state.channels.len(), 2); + assert!(Arc::ptr_eq(&ch1, &ch1_again)); + assert!(!Arc::ptr_eq(&ch1, &ch2)); + } }