//! WebSocket Relay Server //! //! Routes frames from source→receivers and inputs from receivers→source. //! //! ## Routing //! //! Connections are routed by WebSocket URI path: //! - `/source` — default source (channel: "default") //! - `/source/{name}` — named source (channel: {name}) //! - `/stream` — default receiver (channel: "default") //! - `/stream/{name}` — named receiver (channel: {name}) //! - `/signal` — WebRTC signaling (channel: "default") //! - `/signal/{name}` — WebRTC signaling (channel: {name}) //! - `/` — legacy: first connection = source, rest = receivers //! //! Each channel has its own broadcast/input channels and state cache, //! allowing multiple independent streams through a single relay. //! //! ## Production Features //! - Multi-source: multiple views stream independently //! - Keyframe caching: late-joining receivers get current state instantly //! - Signal state store: caches SignalSync/Diff frames for reconstruction //! - Ping/pong keepalive: detects dead connections //! - Stats tracking: frames, bytes, latency metrics //! - Max receiver limit per channel (configurable) //! - Channel GC: empty channels are cleaned up periodically //! - Source reconnection: cache preserved for seamless reconnect //! - Graceful shutdown: drain connections on SIGTERM use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; use std::time::Instant; use futures_util::{SinkExt, StreamExt}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::{broadcast, mpsc, RwLock}; use tokio::time::{interval, Duration}; use tokio_tungstenite::tungstenite::Message; use crate::protocol::*; /// Relay server configuration. pub struct RelayConfig { /// Address to bind to. pub addr: SocketAddr, /// Maximum number of receivers per channel. pub max_receivers: usize, /// Frame broadcast channel capacity. pub frame_buffer_size: usize, /// Keepalive interval in seconds. 
pub keepalive_interval_secs: u64, /// Keepalive timeout in seconds — disconnect after this many seconds without a pong. pub keepalive_timeout_secs: u64, /// Maximum number of channels. pub max_channels: usize, /// Channel GC interval in seconds — how often to scan for empty channels. pub channel_gc_interval_secs: u64, /// Source reconnect grace period in seconds — keep cache alive after source disconnect. pub source_reconnect_grace_secs: u64, /// Replay depth: number of frames to keep in ring buffer for time-travel replay. /// Set to 0 to disable replay (default: 0, catchup-only). Set >0 for full replay. pub replay_depth: usize, /// Upstream relay URLs for federation — frames are forwarded to these relays. pub federation_upstreams: Vec, /// Recording directory — if set, incoming frames are written to disk. pub recording_dir: Option, } impl Default for RelayConfig { fn default() -> Self { Self { addr: "0.0.0.0:9100".parse().unwrap(), max_receivers: 64, frame_buffer_size: 16, keepalive_interval_secs: 10, keepalive_timeout_secs: 30, max_channels: 256, channel_gc_interval_secs: 60, source_reconnect_grace_secs: 30, replay_depth: 0, federation_upstreams: Vec::new(), recording_dir: None, } } } /// Stats tracked by the relay. #[derive(Debug, Default, Clone)] pub struct RelayStats { pub frames_relayed: u64, pub bytes_relayed: u64, pub inputs_relayed: u64, pub connected_receivers: usize, pub source_connected: bool, /// Timestamp of last frame from source (milliseconds since start). pub last_frame_timestamp: u32, /// Total keyframes sent. pub keyframes_sent: u64, /// Total signal diffs sent. pub signal_diffs_sent: u64, /// Uptime in seconds. pub uptime_secs: u64, /// Peak receiver count. pub peak_receivers: usize, /// Total connections served. pub total_connections: u64, /// Rejected connections (over limit). pub rejected_connections: u64, } /// Cached state for late-joining receivers. 
#[derive(Debug, Default, Clone)] pub struct StateCache { /// Last pixel keyframe (full RGBA frame). pub last_keyframe: Option>, /// Last signal sync frame (full JSON state). pub last_signal_sync: Option>, /// Accumulated signal diffs since last sync. /// Late-joining receivers get: last_signal_sync + all diffs. pub pending_signal_diffs: Vec>, /// Replay ring buffer — stores the last N frames for time-travel replay. /// When replay_depth > 0, receivers can request historical frames. pub replay_buffer: Vec>, /// Maximum replay buffer depth (0 = disabled). pub replay_depth: usize, } impl StateCache { /// Process an incoming frame and update cache. fn process_frame(&mut self, msg: &[u8]) { if msg.len() < HEADER_SIZE { return; } let frame_type = msg[0]; let flags = msg[1]; // Add to replay ring buffer if enabled if self.replay_depth > 0 { self.replay_buffer.push(msg.to_vec()); if self.replay_buffer.len() > self.replay_depth { self.replay_buffer.remove(0); } } match FrameType::from_u8(frame_type) { // Cache keyframes (pixel or signal sync) Some(FrameType::Pixels) if flags & FLAG_KEYFRAME != 0 => { self.last_keyframe = Some(msg.to_vec()); } Some(FrameType::SignalSync) => { self.last_signal_sync = Some(msg.to_vec()); self.pending_signal_diffs.clear(); // sync resets diffs } Some(FrameType::SignalDiff) => { self.pending_signal_diffs.push(msg.to_vec()); // Cap accumulated diffs to prevent unbounded memory if self.pending_signal_diffs.len() > 1000 { self.pending_signal_diffs.drain(..500); } } Some(FrameType::Keyframe) => { self.last_keyframe = Some(msg.to_vec()); } _ => {} } } /// Get all messages a late-joining receiver needs to reconstruct current state. /// /// Instead of replaying hundreds of individual diffs, merges all accumulated /// diffs into the last sync frame to produce ONE consolidated state message. /// This dramatically reduces catchup time and bandwidth. 
fn catchup_messages(&self) -> Vec> { let mut msgs = Vec::new(); // Strategy: merge all signal diffs into last sync to build a consolidated frame if let Some(ref sync_frame) = self.last_signal_sync { if !self.pending_signal_diffs.is_empty() { // Parse the sync frame's JSON payload if sync_frame.len() >= HEADER_SIZE { let payload_len = u32::from_le_bytes([ sync_frame[12], sync_frame[13], sync_frame[14], sync_frame[15], ]) as usize; let sync_end = HEADER_SIZE + payload_len.min(sync_frame.len() - HEADER_SIZE); let sync_payload = &sync_frame[HEADER_SIZE..sync_end]; if let Ok(mut merged) = serde_json::from_slice::(sync_payload) { // Apply each diff on top for diff_frame in &self.pending_signal_diffs { if diff_frame.len() < HEADER_SIZE { continue; } let diff_pl_len = u32::from_le_bytes([ diff_frame[12], diff_frame[13], diff_frame[14], diff_frame[15], ]) as usize; let diff_end = HEADER_SIZE + diff_pl_len.min(diff_frame.len() - HEADER_SIZE); let diff_payload = &diff_frame[HEADER_SIZE..diff_end]; if let Ok(diff_obj) = serde_json::from_slice::(diff_payload) { if let (Some(merged_map), Some(diff_map)) = (merged.as_object_mut(), diff_obj.as_object()) { for (k, v) in diff_map { if k != "_pid" && k != "_v" { merged_map.insert(k.clone(), v.clone()); } } // Merge version counters if let (Some(merged_v), Some(diff_v)) = ( merged_map.get_mut("_v").and_then(|v| v.as_object_mut()), diff_map.get("_v").and_then(|v| v.as_object()), ) { for (k, v) in diff_v { merged_v.insert(k.clone(), v.clone()); } } } } } // Build a consolidated SignalSync (0x30) frame let merged_json = serde_json::to_vec(&merged).unwrap_or_default(); let mut frame = Vec::with_capacity(HEADER_SIZE + merged_json.len()); // Copy sync frame header but set type to SignalSync and update length frame.extend_from_slice(&sync_frame[..HEADER_SIZE]); frame[0] = 0x30; // SignalSync frame[1] = 0x02; // FLAG_KEYFRAME let len_bytes = (merged_json.len() as u32).to_le_bytes(); frame[12] = len_bytes[0]; frame[13] = len_bytes[1]; 
frame[14] = len_bytes[2]; frame[15] = len_bytes[3]; frame.truncate(HEADER_SIZE); frame.extend_from_slice(&merged_json); msgs.push(frame); } else { // Fallback: can't parse, send raw sync + all diffs msgs.push(sync_frame.clone()); for diff in &self.pending_signal_diffs { msgs.push(diff.clone()); } } } else { msgs.push(sync_frame.clone()); } } else { // No diffs accumulated — just send the sync msgs.push(sync_frame.clone()); } } else if !self.pending_signal_diffs.is_empty() { // No sync frame but we have diffs — merge diffs into one frame let mut merged = serde_json::Map::new(); let mut merged_v = serde_json::Map::new(); for diff_frame in &self.pending_signal_diffs { if diff_frame.len() < HEADER_SIZE { continue; } let diff_pl_len = u32::from_le_bytes([ diff_frame[12], diff_frame[13], diff_frame[14], diff_frame[15], ]) as usize; let diff_end = HEADER_SIZE + diff_pl_len.min(diff_frame.len() - HEADER_SIZE); let diff_payload = &diff_frame[HEADER_SIZE..diff_end]; if let Ok(diff_obj) = serde_json::from_slice::(diff_payload) { if let Some(diff_map) = diff_obj.as_object() { for (k, v) in diff_map { if k == "_v" { if let Some(v_map) = v.as_object() { for (vk, vv) in v_map { merged_v.insert(vk.clone(), vv.clone()); } } } else if k != "_pid" { merged.insert(k.clone(), v.clone()); } } } } } if !merged_v.is_empty() { merged.insert("_v".to_string(), serde_json::Value::Object(merged_v)); } let merged_json = serde_json::to_vec(&serde_json::Value::Object(merged)).unwrap_or_default(); let frame = crate::codec::signal_sync_frame(0, 0, &merged_json); msgs.push(frame); } // Pixel keyframe (if available) if let Some(ref kf) = self.last_keyframe { msgs.push(kf.clone()); } msgs } /// Clear all cached state. #[allow(dead_code)] fn clear(&mut self) { self.last_keyframe = None; self.last_signal_sync = None; self.pending_signal_diffs.clear(); self.replay_buffer.clear(); } /// Returns true if this cache has any state. 
fn has_state(&self) -> bool { self.last_keyframe.is_some() || self.last_signal_sync.is_some() || !self.pending_signal_diffs.is_empty() || !self.replay_buffer.is_empty() } /// Get the replay buffer frames for time-travel playback. /// Returns frames from `start_index` onwards. pub fn replay_frames(&self, start_index: usize) -> &[Vec] { if start_index < self.replay_buffer.len() { &self.replay_buffer[start_index..] } else { &[] } } /// Total number of frames in the replay buffer. pub fn replay_len(&self) -> usize { self.replay_buffer.len() } } /// Per-channel state — each named channel is an independent stream. struct ChannelState { /// Broadcast channel: source → all receivers (frames) frame_tx: broadcast::Sender>, /// Channel: receivers → source (input events) input_tx: mpsc::Sender>, input_rx: Option>>, /// Broadcast channel for WebRTC signaling (SDP/ICE as text) signaling_tx: broadcast::Sender, /// Live stats for this channel stats: RelayStats, /// Cached state for late-joining receivers cache: StateCache, /// When the source last disconnected (for reconnect grace period) source_disconnect_time: Option, /// Max receivers for this channel max_receivers: usize, /// Cached schema announcement (0x32 payload) from source schema: Option>, } impl ChannelState { fn new(frame_buffer_size: usize, max_receivers: usize, replay_depth: usize) -> Self { let (frame_tx, _) = broadcast::channel(frame_buffer_size); let (input_tx, input_rx) = mpsc::channel(256); let (signaling_tx, _) = broadcast::channel(64); Self { frame_tx, input_tx, input_rx: Some(input_rx), signaling_tx, stats: RelayStats::default(), cache: StateCache { replay_depth, ..StateCache::default() }, source_disconnect_time: None, max_receivers, schema: None, } } /// Returns true if this channel is idle (no source, no receivers, no cache). 
fn is_idle(&self) -> bool { !self.stats.source_connected && self.stats.connected_receivers == 0 && !self.cache.has_state() } /// Returns true if the source reconnect grace period has expired. fn grace_period_expired(&self, grace_secs: u64) -> bool { if let Some(disconnect_time) = self.source_disconnect_time { disconnect_time.elapsed() > Duration::from_secs(grace_secs) } else { false } } } // ─── v1.2: Auth-Gated Channel ─── /// Per-channel authentication state. /// When `required` is true, sources must send a valid `Auth` frame before /// the relay will forward their frames. #[derive(Debug, Clone)] pub struct ChannelAuth { required: bool, /// Pre-shared key for this channel (empty = open) key: Vec, /// Authenticated source addresses (by string identifier) authenticated: Vec, } impl ChannelAuth { pub fn open() -> Self { ChannelAuth { required: false, key: Vec::new(), authenticated: Vec::new() } } pub fn with_key(key: &[u8]) -> Self { ChannelAuth { required: true, key: key.to_vec(), authenticated: Vec::new() } } /// Check if a source is authenticated. pub fn is_authenticated(&self, source_id: &str) -> bool { if !self.required { return true; } self.authenticated.iter().any(|s| s == source_id) } /// Attempt to authenticate a source with a token. /// Returns true if authentication succeeded. pub fn authenticate(&mut self, source_id: &str, token: &[u8]) -> bool { if !self.required { return true; } if token == self.key.as_slice() { if !self.authenticated.iter().any(|s| s == source_id) { self.authenticated.push(source_id.to_string()); } true } else { false } } /// Revoke authentication for a source. pub fn revoke(&mut self, source_id: &str) { self.authenticated.retain(|s| s != source_id); } pub fn is_required(&self) -> bool { self.required } pub fn authenticated_count(&self) -> usize { self.authenticated.len() } } /// Shared relay state — holds all channels. struct RelayState { /// Named channels: "default", "main", "player1", etc. 
channels: HashMap>>, /// Frame buffer size for new channels frame_buffer_size: usize, /// Max receivers per channel max_receivers: usize, /// Max channels max_channels: usize, /// Server start time start_time: Instant, /// Replay depth for new channels replay_depth: usize, /// Recording directory (None = disabled) recording_dir: Option, } impl RelayState { fn new(frame_buffer_size: usize, max_receivers: usize, max_channels: usize, replay_depth: usize, recording_dir: Option) -> Self { Self { channels: HashMap::new(), frame_buffer_size, max_receivers, max_channels, start_time: Instant::now(), replay_depth, recording_dir, } } /// Get or create a channel by name. Returns None if at max channels. fn get_or_create_channel(&mut self, name: &str) -> Option>> { if self.channels.contains_key(name) { return Some(self.channels[name].clone()); } if self.channels.len() >= self.max_channels { eprintln!("[relay] Max channels ({}) reached, rejecting: {}", self.max_channels, name); return None; } let channel = Arc::new(RwLock::new(ChannelState::new( self.frame_buffer_size, self.max_receivers, self.replay_depth, ))); self.channels.insert(name.to_string(), channel.clone()); Some(channel) } } /// Parsed connection role from the WebSocket URI path. #[derive(Debug, Clone)] enum ConnectionRole { Source(String), // channel name Receiver(String), // channel name Signaling(String), // channel name — WebRTC signaling Peer(String), // channel name — bidirectional sync Meta(String), // channel name — HTTP introspection } /// Parse the WebSocket URI path to determine connection role and channel. 
///
/// The first path segment selects the role (`source`, `stream`, `signal`,
/// `peer`, `meta`), and everything after it is the channel name. Leading and
/// trailing slashes are ignored, so `/stream/` and `/stream/lobby/` resolve to
/// `"default"` and `"lobby"` rather than to `""` and `"lobby/"`. Any other
/// path, including the legacy `/`, maps to a receiver on `"default"`.
fn parse_path(path: &str) -> ConnectionRole {
    let trimmed = path.trim_matches('/');
    let (kind, rest) = trimmed.split_once('/').unwrap_or((trimmed, ""));
    let name = rest.trim_matches('/');
    let channel = (if name.is_empty() { "default" } else { name }).to_string();
    match kind {
        "source" => ConnectionRole::Source(channel),
        "stream" => ConnectionRole::Receiver(channel),
        "signal" => ConnectionRole::Signaling(channel),
        "peer" => ConnectionRole::Peer(channel),
        "meta" => ConnectionRole::Meta(channel),
        _ => ConnectionRole::Receiver("default".to_string()),
    }
}

/// Check if a channel name matches a wildcard pattern (channel groups).
///
/// - An exact name matches itself.
/// - `*` matches every channel.
/// - `prefix/*` matches any channel strictly below `prefix/`. For example,
///   `games/*` matches `games/chess` and `games/chess/1`, but not `gamesroom`
///   or `games`.
pub fn channel_matches(pattern: &str, channel: &str) -> bool {
    if pattern == "*" || pattern == channel {
        return true;
    }
    match pattern.strip_suffix('*') {
        // Keep the trailing '/' in the prefix so `games/*` cannot match `gamesroom`.
        Some(prefix) if prefix.ends_with('/') => {
            channel.len() > prefix.len() && channel.starts_with(prefix)
        }
        _ => false,
    }
}

/// Find all channels matching a wildcard pattern (test helper).
#[cfg(test)]
fn find_matching_channels(state: &RelayState, pattern: &str) -> Vec<String> {
    state
        .channels
        .keys()
        .filter(|name| channel_matches(pattern, name))
        .cloned()
        .collect()
}

/// Run the WebSocket relay server.
pub async fn run_relay(config: RelayConfig) -> Result<(), Box> { let listener = TcpListener::bind(&config.addr).await?; eprintln!("╔══════════════════════════════════════════════════╗"); eprintln!("║ DreamStack Bitstream Relay v0.3.0 ║"); eprintln!("║ ║"); eprintln!("║ Source: ws://{}/source/{{name}} ║", config.addr); eprintln!("║ Receiver: ws://{}/stream/{{name}} ║", config.addr); eprintln!("║ Signal: ws://{}/signal/{{name}} ║", config.addr); eprintln!("║ ║"); eprintln!("║ Max receivers/ch: {:>4} ║", config.max_receivers); eprintln!("║ Max channels: {:>4} ║", config.max_channels); eprintln!("║ Reconnect grace: {:>4}s ║", config.source_reconnect_grace_secs); eprintln!("╚══════════════════════════════════════════════════╝"); let grace_secs = config.source_reconnect_grace_secs; let replay_depth = config.replay_depth; let recording_dir = config.recording_dir.clone(); let federation_upstreams = config.federation_upstreams.clone(); let state = Arc::new(RwLock::new(RelayState::new( config.frame_buffer_size, config.max_receivers, config.max_channels, config.replay_depth, config.recording_dir.clone(), ))); // Log the feature status if replay_depth > 0 { eprintln!("║ Replay depth: {:>4} frames ║", replay_depth); } if let Some(ref dir) = recording_dir { eprintln!("║ Recording to: {} ║", dir); // Ensure recording directory exists std::fs::create_dir_all(dir).unwrap_or_else(|e| { eprintln!("[relay] Warning: could not create recording dir {}: {}", dir, e); }); } if !federation_upstreams.is_empty() { eprintln!("║ Federation: {} upstream(s) ║", federation_upstreams.len()); } // Background: periodic stats + channel GC { let state = state.clone(); let gc_interval = config.channel_gc_interval_secs; tokio::spawn(async move { let mut tick = interval(Duration::from_secs(gc_interval.min(30))); let mut gc_counter: u64 = 0; loop { tick.tick().await; gc_counter += 1; let s = state.read().await; let uptime = s.start_time.elapsed().as_secs(); // Stats logging for (name, channel) in 
&s.channels { let cs = channel.read().await; if cs.stats.source_connected || cs.stats.connected_receivers > 0 { eprintln!( "[relay:{name}] up={uptime}s frames={} bytes={} receivers={}/{} peak={} signal_diffs={} cached={}", cs.stats.frames_relayed, cs.stats.bytes_relayed, cs.stats.connected_receivers, cs.max_receivers, cs.stats.peak_receivers, cs.stats.signal_diffs_sent, cs.cache.has_state(), ); } } drop(s); // Channel GC — every gc_interval ticks if gc_counter % (gc_interval / gc_interval.min(30)).max(1) == 0 { let mut s = state.write().await; let before = s.channels.len(); let mut to_remove = Vec::new(); for (name, channel) in &s.channels { let cs = channel.read().await; if cs.is_idle() || cs.grace_period_expired(grace_secs) { to_remove.push(name.clone()); } } for name in &to_remove { s.channels.remove(name); } if !to_remove.is_empty() { eprintln!( "[relay] GC: removed {} idle channel(s) ({} → {}): {:?}", to_remove.len(), before, s.channels.len(), to_remove ); } } } }); } // Background: federation forwarding to upstream relays for upstream_url in &federation_upstreams { let url = upstream_url.clone(); let state = state.clone(); tokio::spawn(async move { let mut backoff = Duration::from_secs(1); loop { eprintln!("[relay:federation] Connecting to upstream: {url}"); match tokio_tungstenite::connect_async(&url).await { Ok((mut ws, _)) => { eprintln!("[relay:federation] Connected to {url}"); backoff = Duration::from_secs(1); // Subscribe to the default channel and forward frames let frame_rx = { let s = state.read().await; if let Some(ch) = s.channels.get("default") { let cs = ch.read().await; Some(cs.frame_tx.subscribe()) } else { None } }; if let Some(mut rx) = frame_rx { while let Ok(frame) = rx.recv().await { use futures_util::SinkExt; if ws.send(Message::Binary(frame.into())).await.is_err() { break; } } } eprintln!("[relay:federation] Disconnected from {url}"); } Err(e) => { eprintln!("[relay:federation] Failed to connect to {url}: {e}"); } } // Exponential 
backoff on reconnect tokio::time::sleep(backoff).await; backoff = (backoff * 2).min(Duration::from_secs(30)); } }); } while let Ok((stream, addr)) = listener.accept().await { let state = state.clone(); let keepalive_interval = config.keepalive_interval_secs; let keepalive_timeout = config.keepalive_timeout_secs; tokio::spawn(async move { // Peek at the HTTP request to check for /meta requests // Meta requests get a raw HTTP JSON response, not a WS upgrade let mut peek_buf = [0u8; 512]; let n = match stream.peek(&mut peek_buf).await { Ok(n) => n, Err(_) => return, }; let request_line = String::from_utf8_lossy(&peek_buf[..n]); // Extract path from "GET /meta/name HTTP/1.1" if let Some(path) = request_line.lines().next().and_then(|line| { let parts: Vec<&str> = line.split_whitespace().collect(); if parts.len() >= 2 { Some(parts[1]) } else { None } }) { if path.starts_with("/meta") { handle_meta_http(stream, addr, &state, path).await; return; } } handle_connection(stream, addr, state, keepalive_interval, keepalive_timeout).await; }); } Ok(()) } /// Handle an HTTP /meta/{channel} request — return channel stats as JSON. /// /// This is NOT a WebSocket connection. We read the full HTTP request, /// then respond with an HTTP 200 + JSON body. 
async fn handle_meta_http( mut stream: TcpStream, addr: SocketAddr, state: &Arc>, path: &str, ) { use tokio::io::AsyncWriteExt; use tokio::io::AsyncReadExt; // Drain the HTTP request from the socket let mut buf = vec![0u8; 4096]; let _ = stream.read(&mut buf).await; // Parse channel name from path let channel_name = path.trim_start_matches("/meta/").trim_matches('/'); let channel_name = if channel_name.is_empty() { "__all__" } else { channel_name }; let s = state.read().await; let uptime = s.start_time.elapsed().as_secs(); let body = if channel_name == "__all__" { // Return stats for ALL channels let mut channels = serde_json::Map::new(); for (name, channel) in &s.channels { let cs = channel.read().await; channels.insert(name.clone(), serde_json::json!({ "source_connected": cs.stats.source_connected, "receivers": cs.stats.connected_receivers, "peak_receivers": cs.stats.peak_receivers, "frames_relayed": cs.stats.frames_relayed, "bytes_relayed": cs.stats.bytes_relayed, "signal_diffs": cs.stats.signal_diffs_sent, "keyframes": cs.stats.keyframes_sent, "total_connections": cs.stats.total_connections, "rejected_connections": cs.stats.rejected_connections, "cached": cs.cache.has_state(), "pending_diffs": cs.cache.pending_signal_diffs.len(), })); } serde_json::json!({ "uptime_secs": uptime, "channel_count": s.channels.len(), "channels": serde_json::Value::Object(channels), }).to_string() } else if let Some(channel) = s.channels.get(channel_name) { // Return stats for a specific channel let cs = channel.read().await; // Extract current signal state if available let mut current_state = serde_json::Value::Null; if let Some(ref sync_frame) = cs.cache.last_signal_sync { if sync_frame.len() >= HEADER_SIZE { let pl_len = u32::from_le_bytes([ sync_frame[12], sync_frame[13], sync_frame[14], sync_frame[15], ]) as usize; let end = HEADER_SIZE + pl_len.min(sync_frame.len() - HEADER_SIZE); if let Ok(val) = serde_json::from_slice::(&sync_frame[HEADER_SIZE..end]) { current_state = val; } 
} } serde_json::json!({ "channel": channel_name, "uptime_secs": uptime, "source_connected": cs.stats.source_connected, "receivers": cs.stats.connected_receivers, "max_receivers": cs.max_receivers, "peak_receivers": cs.stats.peak_receivers, "frames_relayed": cs.stats.frames_relayed, "bytes_relayed": cs.stats.bytes_relayed, "signal_diffs": cs.stats.signal_diffs_sent, "keyframes": cs.stats.keyframes_sent, "total_connections": cs.stats.total_connections, "rejected_connections": cs.stats.rejected_connections, "cached": cs.cache.has_state(), "pending_diffs": cs.cache.pending_signal_diffs.len(), "schema": cs.schema.as_ref().and_then(|s| { if s.len() >= HEADER_SIZE { let pl_len = u32::from_le_bytes([s[12], s[13], s[14], s[15]]) as usize; let end = HEADER_SIZE + pl_len.min(s.len() - HEADER_SIZE); serde_json::from_slice::(&s[HEADER_SIZE..end]).ok() } else { None } }), "current_state": current_state, }).to_string() } else { serde_json::json!({ "error": "channel not found", "channel": channel_name, }).to_string() }; // Write HTTP response let response = format!( "HTTP/1.1 200 OK\r\n\ Content-Type: application/json\r\n\ Access-Control-Allow-Origin: *\r\n\ Content-Length: {}\r\n\ Connection: close\r\n\ \r\n\ {}", body.len(), body ); let _ = stream.write_all(response.as_bytes()).await; eprintln!("[relay:meta] {addr} → {path}"); } async fn handle_connection( stream: TcpStream, addr: SocketAddr, state: Arc>, keepalive_interval: u64, _keepalive_timeout: u64, ) { // Extract the URI path during handshake to determine role let mut role = ConnectionRole::Receiver("default".to_string()); // Peek at the path to check for meta requests (handle via raw HTTP, not WS) let ws_stream = match tokio_tungstenite::accept_hdr_async( stream, |req: &tokio_tungstenite::tungstenite::handshake::server::Request, res: tokio_tungstenite::tungstenite::handshake::server::Response| { role = parse_path(req.uri().path()); Ok(res) }, ) .await { Ok(ws) => ws, Err(e) => { // Meta requests won't have an Upgrade 
header, so they fail the handshake // That's fine — they were handled inline if detected. if !matches!(role, ConnectionRole::Meta(_)) { eprintln!("[relay] WebSocket handshake failed from {}: {}", addr, e); } return; } }; // Handle meta requests separately (they don't use WebSocket) if let ConnectionRole::Meta(ref _channel_name) = role { // Meta requests fail the WS handshake, so we'll never reach here // But just in case, close the connection return; } // Get or create the channel let channel_name = match &role { ConnectionRole::Source(n) | ConnectionRole::Receiver(n) | ConnectionRole::Signaling(n) | ConnectionRole::Peer(n) | ConnectionRole::Meta(n) => n.clone(), }; let channel = { let mut s = state.write().await; match s.get_or_create_channel(&channel_name) { Some(ch) => ch, None => { eprintln!("[relay] Rejected connection from {addr}: max channels reached"); return; } } }; // Check receiver limit if let ConnectionRole::Receiver(_) = &role { let mut cs = channel.write().await; if cs.stats.connected_receivers >= cs.max_receivers { eprintln!( "[relay:{channel_name}] Rejected receiver {addr}: at capacity ({}/{})", cs.stats.connected_receivers, cs.max_receivers ); cs.stats.rejected_connections += 1; return; } } // Legacy fallback: if path was `/` and no source exists, treat as source let role = match role { ConnectionRole::Receiver(ref name) if name == "default" => { let cs = channel.read().await; if !cs.stats.source_connected { ConnectionRole::Source("default".to_string()) } else { role } } _ => role, }; match role { ConnectionRole::Source(ref _name) => { eprintln!("[relay:{channel_name}] Source connected: {addr}"); // Get recording_dir from relay state let recording_dir = { let s = state.read().await; s.recording_dir.clone() }; handle_source(ws_stream, addr, channel, &channel_name, keepalive_interval, recording_dir).await; } ConnectionRole::Receiver(ref _name) => { eprintln!("[relay:{channel_name}] Receiver connected: {addr}"); handle_receiver(ws_stream, addr, 
channel, &channel_name).await;
        }
        ConnectionRole::Signaling(ref _name) => {
            eprintln!("[relay:{channel_name}] Signaling peer connected: {addr}");
            handle_signaling(ws_stream, addr, channel, &channel_name).await;
        }
        ConnectionRole::Peer(ref _name) => {
            eprintln!("[relay:{channel_name}] Peer connected: {addr}");
            handle_peer(ws_stream, addr, channel, &channel_name).await;
        }
        ConnectionRole::Meta(_) => {
            // Handled before WS handshake — should never reach here
        }
    }
}

/// Serve a source connection.
///
/// Binary frames from the source update the channel's stats and state cache,
/// are optionally appended to a recording file, and are broadcast to every
/// subscriber of the channel's frame channel. Input events queued by
/// receivers are forwarded back to the source.
///
/// On connect the channel is marked as having a live source. If a previous
/// source disconnected, the cached state is reused. On disconnect the cache is
/// kept, the disconnect time is recorded, and a fresh input channel is
/// installed for the next source.
///
/// * `keepalive_interval` — ping period in seconds. It is clamped to at least 1,
///   because `tokio::time::interval` panics on a zero period.
/// * `recording_dir` — when set, every binary frame is appended to
///   `{dir}/{channel_name with '/' replaced by '_'}.dsrec` as
///   `[u32 LE length][frame bytes]`. Recording stops for this connection on
///   the first write error, so a torn record is never followed by more data.
async fn handle_source(
    ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
    addr: SocketAddr,
    channel: Arc<RwLock<ChannelState>>,
    channel_name: &str,
    keepalive_interval: u64,
    recording_dir: Option<String>,
) {
    let (mut ws_sink, mut ws_source) = ws_stream.split();

    // Mark source as connected; if reconnecting, preserve cache.
    // The input receiver and the frame sender are taken under the same lock.
    let (input_rx, frame_tx) = {
        let mut cs = channel.write().await;
        if cs.source_disconnect_time.is_some() {
            eprintln!(
                "[relay:{channel_name}] Source reconnected — cache preserved ({} diffs, sync={})",
                cs.cache.pending_signal_diffs.len(),
                cs.cache.last_signal_sync.is_some(),
            );
        }
        cs.stats.source_connected = true;
        cs.stats.total_connections += 1;
        cs.source_disconnect_time = None;
        (cs.input_rx.take(), cs.frame_tx.clone())
    };

    // Forward input events from receivers → source. The handle is kept so that
    // the task, which owns the socket's write half and the input receiver, is
    // stopped when this source disconnects. Without the abort it would linger
    // until its next failed send.
    let input_task = match input_rx {
        Some(mut input_rx) => {
            let channel_clone = channel.clone();
            Some(tokio::spawn(async move {
                while let Some(input_bytes) = input_rx.recv().await {
                    let msg = Message::Binary(input_bytes.into());
                    if ws_sink.send(msg).await.is_err() {
                        break;
                    }
                    let mut cs = channel_clone.write().await;
                    cs.stats.inputs_relayed += 1;
                }
            }))
        }
        None => None,
    };

    // Keepalive: periodic pings broadcast on the channel's frame channel
    let channel_ping = channel.clone();
    let ping_task = tokio::spawn(async move {
        // `interval` panics on a zero period, so never go below one second.
        let mut tick = interval(Duration::from_secs(keepalive_interval.max(1)));
        let mut seq = 0u16;
        loop {
            tick.tick().await;
            let cs = channel_ping.read().await;
            if !cs.stats.source_connected {
                break;
            }
            let _ = cs.frame_tx.send(crate::codec::ping(seq, 0));
            drop(cs);
            seq = seq.wrapping_add(1);
        }
    });

    // Open recording file if configured
    let mut recording_file = if let Some(ref dir) = recording_dir {
        let path = format!("{}/{}.dsrec", dir, channel_name.replace('/', "_"));
        match tokio::fs::OpenOptions::new()
            .create(true)
            .append(true)
            .open(&path)
            .await
        {
            Ok(f) => {
                eprintln!("[relay:{channel_name}] Recording to {path}");
                Some(f)
            }
            Err(e) => {
                eprintln!("[relay:{channel_name}] Warning: could not open recording file {path}: {e}");
                None
            }
        }
    } else {
        None
    };

    // Receive frames from source → broadcast to receivers
    let channel_name_owned = channel_name.to_string();
    while let Some(Ok(msg)) = ws_source.next().await {
        if let Message::Binary(data) = msg {
            let data_vec: Vec<u8> = data.into();
            {
                let mut cs = channel.write().await;
                cs.stats.frames_relayed += 1;
                cs.stats.bytes_relayed += data_vec.len() as u64;
                // Update state cache
                cs.cache.process_frame(&data_vec);
                // Track frame-type-specific stats
                if data_vec.len() >= HEADER_SIZE {
                    match FrameType::from_u8(data_vec[0]) {
                        Some(FrameType::SignalDiff) => cs.stats.signal_diffs_sent += 1,
                        Some(FrameType::Pixels)
                        | Some(FrameType::Keyframe)
                        | Some(FrameType::SignalSync) => cs.stats.keyframes_sent += 1,
                        _ => {}
                    }
                    let ts = u32::from_le_bytes([
                        data_vec[4],
                        data_vec[5],
                        data_vec[6],
                        data_vec[7],
                    ]);
                    cs.stats.last_frame_timestamp = ts;
                }
            }

            // Record frame to disk (length-delimited: [u32 len][frame bytes]).
            // The first failure disables recording for this connection.
            let record_err = match recording_file.as_mut() {
                Some(file) => append_recorded_frame(file, &data_vec).await.err(),
                None => None,
            };
            if let Some(e) = record_err {
                eprintln!(
                    "[relay:{channel_name_owned}] Warning: recording write failed, recording stopped: {e}"
                );
                recording_file = None;
            }

            // Broadcast to all receivers on this channel
            let _ = frame_tx.send(data_vec);
        }
    }

    // Source disconnected — stop helper tasks, start grace period, keep cache
    ping_task.abort();
    if let Some(task) = input_task {
        task.abort();
    }
    if let Some(mut file) = recording_file {
        use tokio::io::AsyncWriteExt;
        // tokio::fs::File completes writes in the background; flush so the
        // tail of the recording reaches the OS and any error is reported.
        if let Err(e) = file.flush().await {
            eprintln!("[relay:{channel_name_owned}] Warning: could not flush recording file: {e}");
        }
    }
    eprintln!("[relay:{channel_name_owned}] Source disconnected: {addr} (cache preserved for reconnect)");
    let mut cs = channel.write().await;
    cs.stats.source_connected = false;
    cs.source_disconnect_time = Some(Instant::now());
    // Recreate input channel for next source connection
    let (new_input_tx, new_input_rx) = mpsc::channel(256);
    cs.input_tx = new_input_tx;
    cs.input_rx = Some(new_input_rx);
}

/// Append one `[u32 LE length][frame bytes]` record to a recording file.
///
/// # Errors
/// Returns the underlying I/O error. Returns `InvalidInput` if the frame is too
/// long for a `u32` length prefix. That case is rejected rather than silently
/// truncating the length with an `as` cast.
async fn append_recorded_frame(file: &mut tokio::fs::File, frame: &[u8]) -> std::io::Result<()> {
    use tokio::io::AsyncWriteExt;
    if frame.len() > u32::MAX as usize {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "frame too large for a u32 length prefix",
        ));
    }
    file.write_all(&(frame.len() as u32).to_le_bytes()).await?;
    file.write_all(frame).await
}

/// Serve a receiver connection.
///
/// On connect, the receiver is counted and subscribed to the channel's frame
/// broadcast. It is then sent the cached schema frame (if any), followed by
/// the cache's catchup messages. After that:
/// - Broadcast frames are forwarded to the receiver. Signal frames (type bytes
///   0x30/0x31) are reduced to the keys chosen by the receiver's most recent
///   SubscribeFilter (0x33), plus `_pid` and `_v`.
/// - SubscribeFilter frames update that selection and are not forwarded.
/// - Any other binary frame is forwarded to the source as an input event. This
///   is best-effort; see [`forward_receiver_input`].
async fn handle_receiver(
    ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
    addr: SocketAddr,
    channel: Arc<RwLock<ChannelState>>,
    channel_name: &str,
) {
    let (mut ws_sink, mut ws_source) = ws_stream.split();

    // Register, subscribe, and snapshot catchup, schema and input sender under one lock
    let (mut frame_rx, catchup_msgs, schema_frame, mut input_tx) = {
        let mut cs = channel.write().await;
        cs.stats.connected_receivers += 1;
        cs.stats.total_connections += 1;
        if cs.stats.connected_receivers > cs.stats.peak_receivers {
            cs.stats.peak_receivers = cs.stats.connected_receivers;
        }
        let rx = cs.frame_tx.subscribe();
        let catchup = cs.cache.catchup_messages();
        let schema = cs.schema.clone();
        let input = cs.input_tx.clone();
        (rx, catchup, schema, input)
    };

    let channel_name_owned = channel_name.to_string();

    // Send cached schema (0x32) first so receiver knows what signals are available
    if let Some(schema_bytes) = schema_frame {
        let _ = ws_sink.send(Message::Binary(schema_bytes.into())).await;
    }

    // Send cached state to late-joining receiver
    if !catchup_msgs.is_empty() {
        eprintln!(
            "[relay:{channel_name_owned}] Sending {} catchup messages to {addr}",
            catchup_msgs.len(),
        );
        for msg_bytes in catchup_msgs {
            let msg = Message::Binary(msg_bytes.into());
            if ws_sink.send(msg).await.is_err() {
                eprintln!("[relay:{channel_name_owned}] Failed to send catchup to {addr}");
                let mut cs = channel.write().await;
                cs.stats.connected_receivers = cs.stats.connected_receivers.saturating_sub(1);
                return;
            }
        }
    }

    // Per-receiver signal filter (set via 0x33 SubscribeFilter)
    let filter: Arc<RwLock<Option<std::collections::HashSet<String>>>> =
        Arc::new(RwLock::new(None));
    let filter_for_send = filter.clone();

    // Forward frames from broadcast → this receiver (with optional filtering)
    let send_task = tokio::spawn(async move {
        loop {
            match frame_rx.recv().await {
                Ok(frame_bytes) => {
                    let is_signal_frame = frame_bytes.len() >= 16
                        && (frame_bytes[0] == 0x30 || frame_bytes[0] == 0x31);
                    // The filter guard is a temporary of this statement, so it
                    // is released before the socket send below.
                    let filtered = if is_signal_frame {
                        filter_for_send
                            .read()
                            .await
                            .as_ref()
                            .and_then(|wanted| filter_signal_frame_keys(&frame_bytes, wanted))
                    } else {
                        None
                    };
                    // Frames that cannot be filtered (no filter, bad payload) go out unchanged
                    let outgoing = filtered.unwrap_or(frame_bytes);
                    if ws_sink.send(Message::Binary(outgoing.into())).await.is_err() {
                        break;
                    }
                }
                Err(broadcast::error::RecvError::Lagged(n)) => {
                    eprintln!("[relay] Receiver {addr} lagged by {n} frames");
                }
                Err(_) => break,
            }
        }
    });

    // Handle incoming messages from receiver (input events + 0x33 filter)
    while let Some(Ok(msg)) = ws_source.next().await {
        if let Message::Binary(data) = msg {
            let data_vec: Vec<u8> = data.into();
            // SubscribeFilter (0x33): update this receiver's selection
            if data_vec.len() >= 16 && data_vec[0] == 0x33 {
                if let Some(wanted) = parse_subscribe_filter_frame(&data_vec) {
                    eprintln!(
                        "[relay:{channel_name_owned}] Receiver {addr} subscribed to: {:?}",
                        wanted
                    );
                    *filter.write().await = Some(wanted);
                }
                continue; // Don't forward filter frames to source
            }
            // Forward input events to source
            forward_receiver_input(&channel, &mut input_tx, data_vec).await;
        }
    }

    // Receiver disconnected
    send_task.abort();
    eprintln!("[relay:{channel_name_owned}] Receiver disconnected: {addr}");
    let mut cs = channel.write().await;
    cs.stats.connected_receivers = cs.stats.connected_receivers.saturating_sub(1);
}

/// Queue a receiver input event for the channel's source without blocking the
/// receiver's read loop.
///
/// - Queue full (no source is draining it, or the source is backpressured):
///   the event is dropped. An awaiting `send` could otherwise stall this
///   receiver indefinitely while no source is connected.
/// - Queue closed (the source that owned it disconnected): the sender is
///   refreshed from the channel's current `input_tx` and the event is retried
///   once. This way, receivers that joined before a source reconnect keep
///   reaching the new source.
async fn forward_receiver_input(
    channel: &RwLock<ChannelState>,
    input_tx: &mut mpsc::Sender<Vec<u8>>,
    data: Vec<u8>,
) {
    match input_tx.try_send(data) {
        Ok(()) | Err(mpsc::error::TrySendError::Full(_)) => {}
        Err(mpsc::error::TrySendError::Closed(data)) => {
            *input_tx = channel.read().await.input_tx.clone();
            let _ = input_tx.try_send(data);
        }
    }
}

/// Return the payload of a relay frame: the bytes following the 16-byte
/// prefix, with their length read as a little-endian `u32` from bytes 12..16.
///
/// Returns `None` if the frame is too short for the length field or shorter
/// than the declared payload. The end offset uses `checked_add`, so a hostile
/// length cannot overflow `usize` on 32-bit targets.
fn frame_json_payload(frame: &[u8]) -> Option<&[u8]> {
    let len = frame.get(12..16)?;
    let payload_len = u32::from_le_bytes([len[0], len[1], len[2], len[3]]) as usize;
    frame.get(16..16usize.checked_add(payload_len)?)
}

/// Parse a SubscribeFilter frame and return the strings listed in its JSON
/// `"select"` array. Non-string entries are ignored.
///
/// Returns `None` if the payload is truncated, is not JSON, or has no
/// `"select"` array.
fn parse_subscribe_filter_frame(frame: &[u8]) -> Option<std::collections::HashSet<String>> {
    let payload = frame_json_payload(frame)?;
    let obj: serde_json::Value = serde_json::from_slice(payload).ok()?;
    let select = obj.get("select")?.as_array()?;
    Some(
        select
            .iter()
            .filter_map(|v| v.as_str().map(|s| s.to_string()))
            .collect(),
    )
}

/// Rebuild a signal frame, keeping only `_pid`, `_v` and the `wanted` keys of
/// its JSON object payload.
///
/// The first 12 bytes of the original frame are copied unchanged, followed by
/// the new payload length and the re-encoded payload. Returns `None` if the
/// payload is truncated, is not JSON, or is not a JSON object. In that case
/// the caller forwards the frame unchanged.
fn filter_signal_frame_keys(
    frame: &[u8],
    wanted: &std::collections::HashSet<String>,
) -> Option<Vec<u8>> {
    let payload = frame_json_payload(frame)?;
    let mut obj: serde_json::Value = serde_json::from_slice(payload).ok()?;
    let map = obj.as_object_mut()?;
    let unwanted: Vec<String> = map
        .keys()
        .filter(|k| k.as_str() != "_pid" && k.as_str() != "_v" && !wanted.contains(k.as_str()))
        .cloned()
        .collect();
    for k in &unwanted {
        map.remove(k);
    }
    // Re-encode
    let new_payload = serde_json::to_vec(&obj).unwrap_or_default();
    let mut new_frame = Vec::with_capacity(16 + new_payload.len());
    new_frame.extend_from_slice(&frame[..12]);
    new_frame.extend_from_slice(&(new_payload.len() as u32).to_le_bytes());
    new_frame.extend_from_slice(&new_payload);
    Some(new_frame)
}

/// Handle a WebRTC signaling connection.
///
/// Signaling peers exchange messages such as SDP offers/answers and ICE
/// candidates as WebSocket text frames. The relay does not inspect them. Every
/// text message is published on the channel's signaling broadcast, which every
/// signaling connection on the channel is subscribed to, including the sender.
/// Clients therefore receive their own messages back and must ignore them.
/// Non-text frames are ignored.
async fn handle_signaling(
    ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
    addr: SocketAddr,
    channel: Arc<RwLock<ChannelState>>,
    channel_name: &str,
) {
    let (mut ws_sink, mut ws_source) = ws_stream.split();

    // Subscribe to signaling broadcast and take a publisher under one read lock
    let (mut sig_rx, sig_tx) = {
        let cs = channel.read().await;
        (cs.signaling_tx.subscribe(), cs.signaling_tx.clone())
    };

    // Forward signaling messages from broadcast → this peer
    let channel_name_owned = channel_name.to_string();
    let send_task = tokio::spawn(async move {
        loop {
            match sig_rx.recv().await {
                Ok(msg_text) => {
                    let msg = Message::Text(msg_text.into());
                    if ws_sink.send(msg).await.is_err() {
                        break;
                    }
                }
                Err(broadcast::error::RecvError::Lagged(n)) => {
                    eprintln!("[relay:{channel_name_owned}] Signaling peer lagged by {n} messages");
                }
                Err(_) => break,
            }
        }
    });

    // Forward signaling messages from this peer → broadcast (sender included)
    while let Some(Ok(msg)) = ws_source.next().await {
        if let Message::Text(text) = msg {
            let text_str: String = text.into();
            let _ = sig_tx.send(text_str);
        }
    }

    send_task.abort();
    eprintln!("[relay:{channel_name}] Signaling peer disconnected: {addr}");
}

/// Handle a peer connection for bidirectional sync.
///
/// Peers are equal. Every binary frame a peer sends updates the channel cache
/// and is broadcast on the channel's frame channel, to which every peer
/// (including the sender) is subscribed. There is no source/receiver
/// distinction. Schema frames (0x32) are cached on the channel instead of
/// being broadcast. Peers are counted in `connected_receivers`.
async fn handle_peer(
    ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
    addr: SocketAddr,
    channel: Arc<RwLock<ChannelState>>,
    channel_name: &str,
) {
    let (mut ws_sink, mut ws_source) = ws_stream.split();

    // Subscribe, count the peer and snapshot catchup state under one lock, so
    // no frame is both in the snapshot and delivered again by the broadcast.
    let (mut frame_rx, frame_tx, catchup_msgs) = {
        let mut cs = channel.write().await;
        cs.stats.connected_receivers += 1;
        cs.stats.total_connections += 1;
        cs.stats.peak_receivers = cs.stats.peak_receivers.max(cs.stats.connected_receivers);
        let catchup = if cs.cache.has_state() {
            Some(cs.cache.catchup_messages())
        } else {
            None
        };
        (cs.frame_tx.subscribe(), cs.frame_tx.clone(), catchup)
    };

    // Send catchup state to late-joining peers. The channel lock has already
    // been released, so a slow peer cannot stall writers on this channel.
    if let Some(msgs) = catchup_msgs {
        eprintln!("[relay:{channel_name}] Sending {} catchup messages to peer {addr}", msgs.len());
        for msg in msgs {
            if ws_sink.send(Message::Binary(msg.into())).await.is_err() {
                break;
            }
        }
    }

    // Forward frames from broadcast → this peer
    let channel_name_owned = channel_name.to_string();
    let send_task = tokio::spawn(async move {
        loop {
            match frame_rx.recv().await {
                Ok(frame_data) => {
                    if ws_sink.send(Message::Binary(frame_data.into())).await.is_err() {
                        break;
                    }
                }
                Err(broadcast::error::RecvError::Lagged(n)) => {
                    eprintln!("[relay:{channel_name_owned}] Peer lagged by {n} frames");
                }
                Err(_) => break,
            }
        }
    });

    // Forward frames from this peer → broadcast to all peers
    while let Some(Ok(msg)) = ws_source.next().await {
        if let Message::Binary(data) = msg {
            let data_vec: Vec<u8> = data.into();

            // Intercept schema announcement (0x32) — cache, don't broadcast
            if data_vec.len() >= 16 && data_vec[0] == 0x32 {
                if frame_json_payload(&data_vec).is_some() {
                    let mut cs = channel.write().await;
                    cs.schema = Some(data_vec);
                }
                continue; // Don't broadcast schema frames
            }

            // Cache for late joiners
            {
                let mut cs = channel.write().await;
                cs.cache.process_frame(&data_vec);
                cs.stats.frames_relayed += 1;
                cs.stats.bytes_relayed += data_vec.len() as u64;
            }

            // Broadcast to all peers (they'll filter out their own msgs by content)
            let _ = frame_tx.send(data_vec);
        }
    }

    send_task.abort();
    {
        let mut cs = channel.write().await;
        cs.stats.connected_receivers = cs.stats.connected_receivers.saturating_sub(1);
    }
    eprintln!("[relay:{channel_name}] Peer disconnected: {addr}");
}

// ─── Tests ───

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config() {
        let config = RelayConfig::default();
        assert_eq!(config.addr, "0.0.0.0:9100".parse::<SocketAddr>().unwrap());
        assert_eq!(config.max_receivers, 64);
        assert_eq!(config.frame_buffer_size, 16);
        assert_eq!(config.keepalive_interval_secs, 10);
        assert_eq!(config.max_channels, 256);
        assert_eq!(config.channel_gc_interval_secs, 60);
        assert_eq!(config.source_reconnect_grace_secs, 30);
    }

    #[test]
    fn stats_default() {
        let stats = RelayStats::default();
        assert_eq!(stats.frames_relayed, 0);
        assert_eq!(stats.bytes_relayed, 0);
        assert!(!stats.source_connected);
        assert_eq!(stats.keyframes_sent, 0);
        assert_eq!(stats.signal_diffs_sent, 0);
        assert_eq!(stats.peak_receivers, 0);
        assert_eq!(stats.total_connections, 0);
        assert_eq!(stats.rejected_connections, 0);
    }

    #[test]
    fn state_cache_signal_sync() {
        let mut cache = StateCache::default();

        // Simulate a SignalSync frame
        let sync_json = br#"{"count":0}"#;
        let sync_msg = crate::codec::signal_sync_frame(0, 100, sync_json);
        cache.process_frame(&sync_msg);

        assert!(cache.last_signal_sync.is_some());
        assert!(cache.has_state());
        assert_eq!(cache.pending_signal_diffs.len(), 0);

        // Simulate a SignalDiff
        let diff_json = br#"{"count":1}"#;
        let diff_msg = crate::codec::signal_diff_frame(1, 200, diff_json);
        cache.process_frame(&diff_msg);

        assert_eq!(cache.pending_signal_diffs.len(), 1);

        // Catchup should contain ONE merged frame (sync + diff consolidated)
        let catchup = cache.catchup_messages();
        assert_eq!(catchup.len(), 1);

        // Verify the merged payload contains the updated count
        let frame = &catchup[0];
        assert!(frame.len() >= HEADER_SIZE);
        let payload_len = u32::from_le_bytes([frame[12], frame[13], frame[14], frame[15]]) as usize;
        let payload = &frame[HEADER_SIZE..HEADER_SIZE + payload_len];
        let merged: serde_json::Value = serde_json::from_slice(payload).unwrap();
        assert_eq!(merged["count"], 1); // diff applied on top of sync
    }

    #[test]
    fn state_cache_signal_sync_resets_diffs() {
        let mut cache = StateCache::default();

        // Add some diffs
        for i in 0..5 {
            let json = format!(r#"{{"count":{}}}"#, i);
            let msg = crate::codec::signal_diff_frame(i, i as u32 * 100, json.as_bytes());
            cache.process_frame(&msg);
        }
        assert_eq!(cache.pending_signal_diffs.len(), 5);

        // A new sync should clear diffs
        let sync = crate::codec::signal_sync_frame(10, 1000, b"{}");
        cache.process_frame(&sync);
        assert_eq!(cache.pending_signal_diffs.len(), 0);
        assert!(cache.last_signal_sync.is_some());
    }

    #[test]
    fn state_cache_keyframe() {
        let mut cache = StateCache::default();

        // Simulate a keyframe pixel frame
        let kf = crate::codec::pixel_frame(0, 100, 10, 10, &vec![0xFF; 400]);
        cache.process_frame(&kf);

        assert!(cache.last_keyframe.is_some());
        assert!(cache.has_state());

        let catchup = cache.catchup_messages();
        assert_eq!(catchup.len(), 1);
    }

    #[test]
    fn state_cache_diff_cap() {
        let mut cache = StateCache::default();

        // Add more than 1000 diffs — should auto-trim
        for i in 0..1100u16 {
            let json = format!(r#"{{"n":{}}}"#, i);
            let msg = crate::codec::signal_diff_frame(i, i as u32, json.as_bytes());
            cache.process_frame(&msg);
        }

        // Should have been trimmed: 1100 - 500 = 600 remaining after first trim at 1001
        assert!(cache.pending_signal_diffs.len() <= 600);
    }

    #[test]
    fn state_cache_clear() {
        let mut cache = StateCache::default();
        let sync = crate::codec::signal_sync_frame(0, 0, b"{}");
        cache.process_frame(&sync);
        assert!(cache.has_state());

        cache.clear();
        assert!(!cache.has_state());
        assert!(cache.last_signal_sync.is_none());
    }

    // ─── Path Routing Tests ───

    #[test]
    fn parse_path_source_default() {
        match parse_path("/source") {
            ConnectionRole::Source(name) => assert_eq!(name, "default"),
            _ => panic!("Expected Source"),
        }
    }

    #[test]
    fn parse_path_source_named() {
        match parse_path("/source/main") {
            ConnectionRole::Source(name) => assert_eq!(name, "main"),
            _ => panic!("Expected Source"),
        }
    }

    #[test]
    fn parse_path_stream_default() {
        match parse_path("/stream") {
            ConnectionRole::Receiver(name) => assert_eq!(name, "default"),
            _ => panic!("Expected Receiver"),
        }
    }

    #[test]
    fn parse_path_stream_named() {
        match parse_path("/stream/player1") {
            ConnectionRole::Receiver(name) => assert_eq!(name, "player1"),
            _ => panic!("Expected Receiver"),
        }
    }

    #[test]
    fn parse_path_signal_default() {
        match parse_path("/signal") {
            ConnectionRole::Signaling(name) => assert_eq!(name, "default"),
            _ => panic!("Expected Signaling"),
        }
    }

    #[test]
    fn parse_path_signal_named() {
        match parse_path("/signal/main") {
            ConnectionRole::Signaling(name) => assert_eq!(name, "main"),
            _ => panic!("Expected Signaling"),
        }
    }

    #[test]
    fn parse_path_legacy_root() {
        match parse_path("/") {
            ConnectionRole::Receiver(name) => assert_eq!(name, "default"),
            _ => panic!("Expected Receiver for legacy path"),
        }
    }

    // ─── v0.14: Peer + Meta Path Tests ───

    #[test]
    fn parse_path_peer_default() {
        match parse_path("/peer") {
            ConnectionRole::Peer(name) => assert_eq!(name, "default"),
            _ => panic!("Expected Peer"),
        }
    }

    #[test]
    fn parse_path_peer_named() {
        match parse_path("/peer/room42") {
            ConnectionRole::Peer(name) => assert_eq!(name, "room42"),
            _ => panic!("Expected Peer"),
        }
    }

    #[test]
    fn parse_path_meta_default() {
        match parse_path("/meta") {
            ConnectionRole::Meta(name) => assert_eq!(name, "default"),
            _ => panic!("Expected Meta"),
        }
    }

    #[test]
    fn parse_path_meta_named() {
        match parse_path("/meta/stats") {
            ConnectionRole::Meta(name) => assert_eq!(name, "stats"),
            _ => panic!("Expected Meta"),
        }
    }

    // ─── Channel State Tests ───

    #[test]
    fn channel_state_creation() {
        let mut state = RelayState::new(16, 64, 256, 0, None);
        let ch1 = state.get_or_create_channel("main").unwrap();
        let ch2 = state.get_or_create_channel("player1").unwrap();
        let ch1_again = state.get_or_create_channel("main").unwrap();

        assert_eq!(state.channels.len(), 2);
        assert!(Arc::ptr_eq(&ch1, &ch1_again));
        assert!(!Arc::ptr_eq(&ch1, &ch2));
    }

    #[test]
    fn channel_max_limit() {
        let mut state = RelayState::new(16, 64, 2, 0, None);
        assert!(state.get_or_create_channel("a").is_some());
        assert!(state.get_or_create_channel("b").is_some());
        assert!(state.get_or_create_channel("c").is_none()); // max reached
        assert!(state.get_or_create_channel("a").is_some()); // existing OK
    }

    #[test]
    fn channel_idle_detection() {
        let cs = ChannelState::new(16, 64, 0);
        assert!(cs.is_idle()); // no source, no receivers, no cache
    }

    #[test]
    fn channel_not_idle_with_cache() {
        let mut cs = ChannelState::new(16, 64, 0);
        let sync = crate::codec::signal_sync_frame(0, 0, b"{}");
        cs.cache.process_frame(&sync);
        assert!(!cs.is_idle()); // has cached state
    }

    #[test]
    fn channel_not_idle_with_source() {
        let mut cs = ChannelState::new(16, 64, 0);
        cs.stats.source_connected = true;
        assert!(!cs.is_idle());
    }

    #[test]
    fn channel_not_idle_with_receivers() {
        let mut cs = ChannelState::new(16, 64, 0);
        cs.stats.connected_receivers = 1;
        assert!(!cs.is_idle());
    }

    #[test]
    fn grace_period_not_expired_initially() {
        let cs = ChannelState::new(16, 64, 0);
        assert!(!cs.grace_period_expired(30));
    }

    #[test]
    fn grace_period_expired_after_disconnect() {
        let mut cs = ChannelState::new(16, 64, 0);
        cs.source_disconnect_time = Some(Instant::now() - Duration::from_secs(60));
        assert!(cs.grace_period_expired(30));
    }

    // ─── Diff Merging Tests ───

    #[test]
    fn catchup_merges_multiple_diffs() {
        let mut cache = StateCache::default();

        // Start with sync: { count: 0, name: "test" }
        let sync = crate::codec::signal_sync_frame(0, 0, br#"{"count":0,"name":"test"}"#);
        cache.process_frame(&sync);

        // Apply 3 diffs
        let d1 = crate::codec::signal_diff_frame(1, 100, br#"{"count":1}"#);
        let d2 = crate::codec::signal_diff_frame(2, 200, br#"{"count":2}"#);
        let d3 = crate::codec::signal_diff_frame(3, 300, br#"{"count":3,"name":"updated"}"#);
        cache.process_frame(&d1);
        cache.process_frame(&d2);
        cache.process_frame(&d3);

        // Should produce 1 merged frame, not 4
        let catchup = cache.catchup_messages();
        assert_eq!(catchup.len(), 1);

        // Verify merged state
        let frame = &catchup[0];
        let payload_len = u32::from_le_bytes([frame[12], frame[13], frame[14], frame[15]]) as usize;
        let payload = &frame[HEADER_SIZE..HEADER_SIZE + payload_len];
        let merged: serde_json::Value = serde_json::from_slice(payload).unwrap();
        assert_eq!(merged["count"], 3);
        assert_eq!(merged["name"], "updated");
    }

    #[test]
    fn catchup_diffs_only_no_sync() {
        let mut cache = StateCache::default();

        // Only diffs, no initial sync (first connection scenario)
        let d1 = crate::codec::signal_diff_frame(0, 0, br#"{"mood":"happy"}"#);
        let d2 = crate::codec::signal_diff_frame(1, 100, br#"{"energy":75}"#);
        cache.process_frame(&d1);
        cache.process_frame(&d2);

        // Should produce 1 merged frame
        let catchup = cache.catchup_messages();
        assert_eq!(catchup.len(), 1);

        let frame = &catchup[0];
        let payload_len = u32::from_le_bytes([frame[12], frame[13], frame[14], frame[15]]) as usize;
        let payload = &frame[HEADER_SIZE..HEADER_SIZE + payload_len];
        let merged: serde_json::Value = serde_json::from_slice(payload).unwrap();
        assert_eq!(merged["mood"], "happy");
        assert_eq!(merged["energy"], 75);
    }

    #[test]
    fn catchup_preserves_version_counters() {
        let mut cache = StateCache::default();

        let sync = crate::codec::signal_sync_frame(0, 0, br#"{"count":0,"_v":{"count":0}}"#);
        cache.process_frame(&sync);

        let d1 = crate::codec::signal_diff_frame(1, 100, br#"{"count":5,"_v":{"count":3}}"#);
        cache.process_frame(&d1);

        let catchup = cache.catchup_messages();
        assert_eq!(catchup.len(), 1);

        let frame = &catchup[0];
        let payload_len = u32::from_le_bytes([frame[12], frame[13], frame[14], frame[15]]) as usize;
        let payload = &frame[HEADER_SIZE..HEADER_SIZE + payload_len];
        let merged: serde_json::Value = serde_json::from_slice(payload).unwrap();
        assert_eq!(merged["count"], 5);
        assert_eq!(merged["_v"]["count"], 3); // version preserved from diff
    }

    // ─── Channel Wildcard Tests ───

    #[test]
    fn channel_matches_exact() {
        assert!(channel_matches("games/chess", "games/chess"));
        assert!(!channel_matches("games/chess", "games/go"));
    }

    #[test]
    fn channel_matches_wildcard() {
        assert!(channel_matches("games/*", "games/chess"));
        assert!(channel_matches("games/*", "games/go"));
        assert!(!channel_matches("games/*", "other/chess"));
        assert!(!channel_matches("games/*", "games")); // no trailing segment
    }

    #[test]
    fn channel_matches_star_all() {
        assert!(channel_matches("*", "anything"));
        assert!(channel_matches("*", "games/chess"));
    }

    #[test]
    fn find_matching_channels_works() {
        let mut state = RelayState::new(16, 64, 256, 0, None);
        state.get_or_create_channel("games/chess").unwrap();
        state.get_or_create_channel("games/go").unwrap();
        state.get_or_create_channel("chat/main").unwrap();

        let mut matches = find_matching_channels(&state, "games/*");
        matches.sort();
        assert_eq!(matches, vec!["games/chess", "games/go"]);
    }

    // ─── Replay Buffer Tests ───

    #[test]
    fn replay_buffer_stores_frames() {
        let mut cache = StateCache {
            replay_depth: 3,
            ..Default::default()
        };
        let f1 = crate::codec::signal_diff_frame(0, 0, b"{\"a\":1}");
        let f2 = crate::codec::signal_diff_frame(1, 100, b"{\"a\":2}");
        let f3 = crate::codec::signal_diff_frame(2, 200, b"{\"a\":3}");
        cache.process_frame(&f1);
        cache.process_frame(&f2);
        cache.process_frame(&f3);
        assert_eq!(cache.replay_len(), 3);
        assert_eq!(cache.replay_frames(0).len(), 3);
        assert_eq!(cache.replay_frames(1).len(), 2);
        assert_eq!(cache.replay_frames(3).len(), 0);
    }

    #[test]
    fn replay_buffer_evicts_oldest() {
        let mut cache = StateCache {
            replay_depth: 2,
            ..Default::default()
        };
        let f1 = crate::codec::signal_diff_frame(0, 0, b"{\"a\":1}");
        let f2 = crate::codec::signal_diff_frame(1, 100, b"{\"a\":2}");
        let f3 = crate::codec::signal_diff_frame(2, 200, b"{\"a\":3}");
        cache.process_frame(&f1);
        cache.process_frame(&f2);
        cache.process_frame(&f3);
        assert_eq!(cache.replay_len(), 2);
        // First frame should be evicted
        assert_eq!(cache.replay_buffer[0], f2);
        assert_eq!(cache.replay_buffer[1], f3);
    }

    #[test]
    fn replay_depth_propagates_to_channel() {
        let mut state = RelayState::new(16, 64, 256, 100, None);
        let ch = state.get_or_create_channel("test").unwrap();
        let cs = ch.try_read().unwrap();
        assert_eq!(cs.cache.replay_depth, 100);
    }

    #[test]
    fn replay_disabled_when_zero() {
        let mut cache = StateCache::default();
        assert_eq!(cache.replay_depth, 0);
        let f1 = crate::codec::signal_diff_frame(0, 0, b"{\"a\":1}");
        cache.process_frame(&f1);
        assert_eq!(cache.replay_len(), 0); // Nothing stored when depth=0
    }

    // ─── v1.2: Channel Auth Tests ───

    #[test]
    fn channel_auth_open_allows_all() {
        let auth = ChannelAuth::open();
        assert!(auth.is_authenticated("any-source"));
        assert!(!auth.is_required());
    }

    #[test]
    fn channel_auth_rejects_unauthed() {
        let auth = ChannelAuth::with_key(b"secret-123");
        assert!(auth.is_required());
        assert!(!auth.is_authenticated("source-1"));
    }

    #[test]
    fn channel_auth_allows_authed() {
        let mut auth = ChannelAuth::with_key(b"secret-123");
        assert!(!auth.authenticate("source-1", b"wrong-key"));
        assert!(!auth.is_authenticated("source-1"));
        assert!(auth.authenticate("source-1", b"secret-123"));
        assert!(auth.is_authenticated("source-1"));
        assert_eq!(auth.authenticated_count(), 1);
    }

    #[test]
    fn channel_auth_revoke() {
        let mut auth = ChannelAuth::with_key(b"key");
        auth.authenticate("src", b"key");
        assert!(auth.is_authenticated("src"));
        auth.revoke("src");
        assert!(!auth.is_authenticated("src"));
        assert_eq!(auth.authenticated_count(), 0);
    }

    // ─── Frame Helper Tests ───

    /// Build a frame with a 16-byte prefix: type byte at 0, payload length
    /// (u32 LE) at 12..16, then the payload.
    fn raw_frame(frame_type: u8, payload: &[u8]) -> Vec<u8> {
        let mut frame = vec![0u8; 16];
        frame[0] = frame_type;
        frame[12..16].copy_from_slice(&(payload.len() as u32).to_le_bytes());
        frame.extend_from_slice(payload);
        frame
    }

    #[test]
    fn frame_json_payload_bounds() {
        let frame = raw_frame(0x31, br#"{"a":1}"#);
        assert_eq!(frame_json_payload(&frame), Some(&br#"{"a":1}"#[..]));
        // Declared length longer than the frame
        assert_eq!(frame_json_payload(&frame[..frame.len() - 1]), None);
        // Too short to hold the length field
        assert_eq!(frame_json_payload(&frame[..10]), None);
        // Maximal declared length must not overflow
        let mut huge = vec![0u8; 16];
        huge[12..16].copy_from_slice(&u32::MAX.to_le_bytes());
        assert_eq!(frame_json_payload(&huge), None);
    }

    #[test]
    fn subscribe_filter_parses_select() {
        let frame = raw_frame(0x33, br#"{"select":["a","b",3]}"#);
        let wanted = parse_subscribe_filter_frame(&frame).unwrap();
        assert_eq!(wanted.len(), 2);
        assert!(wanted.contains("a") && wanted.contains("b"));
        assert!(parse_subscribe_filter_frame(&raw_frame(0x33, br#"{"other":1}"#)).is_none());
        assert!(parse_subscribe_filter_frame(&raw_frame(0x33, b"not json")).is_none());
    }

    #[test]
    fn signal_filter_keeps_selected_and_internal_keys() {
        let mut frame = raw_frame(0x31, br#"{"a":1,"b":2,"_pid":"p","_v":{"a":1}}"#);
        frame[4] = 7; // header bytes before the length field must be preserved
        let wanted: std::collections::HashSet<String> =
            std::iter::once("a".to_string()).collect();

        let out = filter_signal_frame_keys(&frame, &wanted).unwrap();
        assert_eq!(&out[..12], &frame[..12]);
        let payload = frame_json_payload(&out).unwrap();
        let obj: serde_json::Value = serde_json::from_slice(payload).unwrap();
        assert_eq!(obj["a"], 1);
        assert_eq!(obj["_pid"], "p");
        assert_eq!(obj["_v"]["a"], 1);
        assert!(obj.get("b").is_none());

        // Non-object payloads are not rewritten
        assert!(filter_signal_frame_keys(&raw_frame(0x31, b"[1,2]"), &wanted).is_none());
    }
}