dreamstack/engine/ds-stream/src/relay.rs
enzotar 35b39a1cf1 feat(ds-stream): v2.0-2.3 composable codec pipeline
v2.0 — Pipeline Architecture
  - Frame, CodecResult, Codec trait, Pipeline builder
  - 6 adapters: Passthrough, Dedup, Compress, Pacer, Slicer, Stats

v2.1 — Multi-frame & new codecs
  - CodecOutput::Many fan-out, EncryptCodec, FilterCodec
  - Codec::reset(), encode_all/decode_all, real SlicerCodec chunking

v2.2 — Observability & reassembly
  - PipelineResult (frames+errors+consumed), StageMetric
  - ReassemblyCodec, ConditionalCodec, Pipeline presets & metrics

v2.3 — Integrity & rate control
  - ChecksumCodec (CRC32), RateLimitCodec (token bucket), TagCodec
  - Pipeline::chain(), Pipeline::describe()

13 codec adapters, 474 tests (all green, 0 regressions)
2026-03-11 23:50:35 -07:00

1848 lines
70 KiB
Rust

//! WebSocket Relay Server
//!
//! Routes frames from source→receivers and inputs from receivers→source.
//!
//! ## Routing
//!
//! Connections are routed by WebSocket URI path:
//! - `/source` — default source (channel: "default")
//! - `/source/{name}` — named source (channel: {name})
//! - `/stream` — default receiver (channel: "default")
//! - `/stream/{name}` — named receiver (channel: {name})
//! - `/signal` — WebRTC signaling (channel: "default")
//! - `/signal/{name}` — WebRTC signaling (channel: {name})
//! - `/` — legacy: first connection = source, rest = receivers
//!
//! Each channel has its own broadcast/input channels and state cache,
//! allowing multiple independent streams through a single relay.
//!
//! ## Production Features
//! - Multi-source: multiple views stream independently
//! - Keyframe caching: late-joining receivers get current state instantly
//! - Signal state store: caches SignalSync/Diff frames for reconstruction
//! - Ping/pong keepalive: detects dead connections
//! - Stats tracking: frames, bytes, latency metrics
//! - Max receiver limit per channel (configurable)
//! - Channel GC: empty channels are cleaned up periodically
//! - Source reconnection: cache preserved for seamless reconnect
//! - Graceful shutdown: drain connections on SIGTERM
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Instant;
use futures_util::{SinkExt, StreamExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{broadcast, mpsc, RwLock};
use tokio::time::{interval, Duration};
use tokio_tungstenite::tungstenite::Message;
use crate::protocol::*;
/// Relay server configuration.
///
/// Construct with `RelayConfig { addr, ..Default::default() }` to override
/// individual knobs; see the `Default` impl for the stock values.
pub struct RelayConfig {
/// Address to bind to.
pub addr: SocketAddr,
/// Maximum number of receivers per channel.
pub max_receivers: usize,
/// Frame broadcast channel capacity.
/// Receivers that lag more than this many frames behind the source will
/// miss frames (tokio broadcast semantics).
pub frame_buffer_size: usize,
/// Keepalive interval in seconds.
pub keepalive_interval_secs: u64,
/// Keepalive timeout in seconds — disconnect after this many seconds without a pong.
/// NOTE(review): currently accepted but not enforced anywhere in this file.
pub keepalive_timeout_secs: u64,
/// Maximum number of channels.
pub max_channels: usize,
/// Channel GC interval in seconds — how often to scan for empty channels.
/// NOTE(review): a value of 0 would panic the GC task (zero tick period /
/// divide-by-zero) — confirm configs always keep this >= 1.
pub channel_gc_interval_secs: u64,
/// Source reconnect grace period in seconds — keep cache alive after source disconnect.
pub source_reconnect_grace_secs: u64,
/// Replay depth: number of frames to keep in ring buffer for time-travel replay.
/// Set to 0 to disable replay (default: 0, catchup-only). Set >0 for full replay.
pub replay_depth: usize,
/// Upstream relay URLs for federation — frames are forwarded to these relays.
pub federation_upstreams: Vec<String>,
/// Recording directory — if set, incoming frames are written to disk.
pub recording_dir: Option<String>,
}
impl Default for RelayConfig {
fn default() -> Self {
Self {
addr: "0.0.0.0:9100".parse().unwrap(),
max_receivers: 64,
frame_buffer_size: 16,
keepalive_interval_secs: 10,
keepalive_timeout_secs: 30,
max_channels: 256,
channel_gc_interval_secs: 60,
source_reconnect_grace_secs: 30,
replay_depth: 0,
federation_upstreams: Vec::new(),
recording_dir: None,
}
}
}
/// Stats tracked by the relay.
#[derive(Debug, Default, Clone)]
pub struct RelayStats {
pub frames_relayed: u64,
pub bytes_relayed: u64,
pub inputs_relayed: u64,
pub connected_receivers: usize,
pub source_connected: bool,
/// Timestamp of last frame from source (milliseconds since start).
pub last_frame_timestamp: u32,
/// Total keyframes sent.
pub keyframes_sent: u64,
/// Total signal diffs sent.
pub signal_diffs_sent: u64,
/// Uptime in seconds.
pub uptime_secs: u64,
/// Peak receiver count.
pub peak_receivers: usize,
/// Total connections served.
pub total_connections: u64,
/// Rejected connections (over limit).
pub rejected_connections: u64,
}
/// Cached state for late-joining receivers.
///
/// Updated by `process_frame` on every frame from the source; read by
/// `catchup_messages` when a new receiver attaches.
#[derive(Debug, Default, Clone)]
pub struct StateCache {
/// Last pixel keyframe (full RGBA frame).
pub last_keyframe: Option<Vec<u8>>,
/// Last signal sync frame (full JSON state).
pub last_signal_sync: Option<Vec<u8>>,
/// Accumulated signal diffs since last sync.
/// Late-joining receivers get: last_signal_sync + all diffs.
pub pending_signal_diffs: Vec<Vec<u8>>,
/// Replay ring buffer — stores the last N frames for time-travel replay.
/// Oldest frame first; index 0 is evicted when the buffer is full.
/// When replay_depth > 0, receivers can request historical frames.
pub replay_buffer: Vec<Vec<u8>>,
/// Maximum replay buffer depth (0 = disabled).
pub replay_depth: usize,
}
impl StateCache {
/// Process an incoming frame and update cache.
///
/// Behavior by frame type (frames shorter than one header are ignored):
/// - `Pixels` with FLAG_KEYFRAME set, or `Keyframe`: replaces `last_keyframe`.
/// - `SignalSync`: replaces `last_signal_sync` and clears pending diffs.
/// - `SignalDiff`: appended to `pending_signal_diffs` (capped below).
/// When `replay_depth > 0`, every frame also enters the replay ring buffer.
fn process_frame(&mut self, msg: &[u8]) {
if msg.len() < HEADER_SIZE {
return;
}
let frame_type = msg[0];
let flags = msg[1];
// Add to replay ring buffer if enabled
// NOTE(review): Vec::remove(0) shifts the whole buffer (O(depth)) per
// eviction; a VecDeque would be O(1) but changes the public field type.
if self.replay_depth > 0 {
self.replay_buffer.push(msg.to_vec());
if self.replay_buffer.len() > self.replay_depth {
self.replay_buffer.remove(0);
}
}
match FrameType::from_u8(frame_type) {
// Cache keyframes (pixel or signal sync)
Some(FrameType::Pixels) if flags & FLAG_KEYFRAME != 0 => {
self.last_keyframe = Some(msg.to_vec());
}
Some(FrameType::SignalSync) => {
self.last_signal_sync = Some(msg.to_vec());
self.pending_signal_diffs.clear(); // sync resets diffs
}
Some(FrameType::SignalDiff) => {
self.pending_signal_diffs.push(msg.to_vec());
// Cap accumulated diffs to prevent unbounded memory
// NOTE(review): dropping the oldest 500 diffs without a fresh sync
// means catchup state built later may miss keys those diffs set —
// confirm sources emit periodic SignalSync frames.
if self.pending_signal_diffs.len() > 1000 {
self.pending_signal_diffs.drain(..500);
}
}
Some(FrameType::Keyframe) => {
self.last_keyframe = Some(msg.to_vec());
}
_ => {}
}
}
/// Get all messages a late-joining receiver needs to reconstruct current state.
///
/// Instead of replaying hundreds of individual diffs, merges all accumulated
/// diffs into the last sync frame to produce ONE consolidated state message.
/// This dramatically reduces catchup time and bandwidth.
///
/// Frame layout assumed throughout (consistent with the rest of this file):
/// HEADER_SIZE bytes of header with a little-endian u32 payload length at
/// bytes 12..16, followed by the payload (JSON for signal frames).
/// Keys `_pid` and `_v` get special treatment: `_pid` is dropped and `_v`
/// (per-key version counters) is merged map-wise.
fn catchup_messages(&self) -> Vec<Vec<u8>> {
let mut msgs = Vec::new();
// Strategy: merge all signal diffs into last sync to build a consolidated frame
if let Some(ref sync_frame) = self.last_signal_sync {
if !self.pending_signal_diffs.is_empty() {
// Parse the sync frame's JSON payload
if sync_frame.len() >= HEADER_SIZE {
let payload_len = u32::from_le_bytes([
sync_frame[12], sync_frame[13], sync_frame[14], sync_frame[15],
]) as usize;
// Clamp to the real buffer size in case the header overstates it.
let sync_end = HEADER_SIZE + payload_len.min(sync_frame.len() - HEADER_SIZE);
let sync_payload = &sync_frame[HEADER_SIZE..sync_end];
if let Ok(mut merged) = serde_json::from_slice::<serde_json::Value>(sync_payload) {
// Apply each diff on top
for diff_frame in &self.pending_signal_diffs {
if diff_frame.len() < HEADER_SIZE { continue; }
let diff_pl_len = u32::from_le_bytes([
diff_frame[12], diff_frame[13], diff_frame[14], diff_frame[15],
]) as usize;
let diff_end = HEADER_SIZE + diff_pl_len.min(diff_frame.len() - HEADER_SIZE);
let diff_payload = &diff_frame[HEADER_SIZE..diff_end];
if let Ok(diff_obj) = serde_json::from_slice::<serde_json::Value>(diff_payload) {
if let (Some(merged_map), Some(diff_map)) = (merged.as_object_mut(), diff_obj.as_object()) {
for (k, v) in diff_map {
if k != "_pid" && k != "_v" {
merged_map.insert(k.clone(), v.clone());
}
}
// Merge version counters
if let (Some(merged_v), Some(diff_v)) = (
merged_map.get_mut("_v").and_then(|v| v.as_object_mut()),
diff_map.get("_v").and_then(|v| v.as_object()),
) {
for (k, v) in diff_v {
merged_v.insert(k.clone(), v.clone());
}
}
}
}
}
// Build a consolidated SignalSync (0x30) frame
let merged_json = serde_json::to_vec(&merged).unwrap_or_default();
let mut frame = Vec::with_capacity(HEADER_SIZE + merged_json.len());
// Copy sync frame header but set type to SignalSync and update length
frame.extend_from_slice(&sync_frame[..HEADER_SIZE]);
frame[0] = 0x30; // SignalSync
frame[1] = 0x02; // FLAG_KEYFRAME
let len_bytes = (merged_json.len() as u32).to_le_bytes();
frame[12] = len_bytes[0];
frame[13] = len_bytes[1];
frame[14] = len_bytes[2];
frame[15] = len_bytes[3];
// No-op today (frame is exactly HEADER_SIZE long here); kept as a guard.
frame.truncate(HEADER_SIZE);
frame.extend_from_slice(&merged_json);
msgs.push(frame);
} else {
// Fallback: can't parse, send raw sync + all diffs
msgs.push(sync_frame.clone());
for diff in &self.pending_signal_diffs {
msgs.push(diff.clone());
}
}
} else {
// Sync frame too short to carry a header — forward it verbatim.
msgs.push(sync_frame.clone());
}
} else {
// No diffs accumulated — just send the sync
msgs.push(sync_frame.clone());
}
} else if !self.pending_signal_diffs.is_empty() {
// No sync frame but we have diffs — merge diffs into one frame
let mut merged = serde_json::Map::new();
let mut merged_v = serde_json::Map::new();
for diff_frame in &self.pending_signal_diffs {
if diff_frame.len() < HEADER_SIZE { continue; }
let diff_pl_len = u32::from_le_bytes([
diff_frame[12], diff_frame[13], diff_frame[14], diff_frame[15],
]) as usize;
let diff_end = HEADER_SIZE + diff_pl_len.min(diff_frame.len() - HEADER_SIZE);
let diff_payload = &diff_frame[HEADER_SIZE..diff_end];
if let Ok(diff_obj) = serde_json::from_slice::<serde_json::Value>(diff_payload) {
if let Some(diff_map) = diff_obj.as_object() {
for (k, v) in diff_map {
if k == "_v" {
// Collect version counters into their own map.
if let Some(v_map) = v.as_object() {
for (vk, vv) in v_map {
merged_v.insert(vk.clone(), vv.clone());
}
}
} else if k != "_pid" {
merged.insert(k.clone(), v.clone());
}
}
}
}
}
if !merged_v.is_empty() {
merged.insert("_v".to_string(), serde_json::Value::Object(merged_v));
}
let merged_json = serde_json::to_vec(&serde_json::Value::Object(merged)).unwrap_or_default();
// Synthesize a SignalSync frame carrying the merged state.
let frame = crate::codec::signal_sync_frame(0, 0, &merged_json);
msgs.push(frame);
}
// Pixel keyframe (if available)
if let Some(ref kf) = self.last_keyframe {
msgs.push(kf.clone());
}
msgs
}
/// Clear all cached state.
#[allow(dead_code)]
fn clear(&mut self) {
self.last_keyframe = None;
self.last_signal_sync = None;
self.pending_signal_diffs.clear();
self.replay_buffer.clear();
}
/// Returns true if this cache has any state.
fn has_state(&self) -> bool {
self.last_keyframe.is_some()
|| self.last_signal_sync.is_some()
|| !self.pending_signal_diffs.is_empty()
|| !self.replay_buffer.is_empty()
}
/// Get the replay buffer frames for time-travel playback.
/// Returns frames from `start_index` onwards.
/// An out-of-range `start_index` yields an empty slice rather than panicking.
pub fn replay_frames(&self, start_index: usize) -> &[Vec<u8>] {
if start_index < self.replay_buffer.len() {
&self.replay_buffer[start_index..]
} else {
&[]
}
}
/// Total number of frames in the replay buffer.
pub fn replay_len(&self) -> usize {
self.replay_buffer.len()
}
}
/// Per-channel state — each named channel is an independent stream.
struct ChannelState {
/// Broadcast channel: source → all receivers (frames)
frame_tx: broadcast::Sender<Vec<u8>>,
/// Channel: receivers → source (input events)
input_tx: mpsc::Sender<Vec<u8>>,
/// Receiver half of the input channel; `take()`n by the active source's
/// forwarding task and recreated when that source disconnects.
input_rx: Option<mpsc::Receiver<Vec<u8>>>,
/// Broadcast channel for WebRTC signaling (SDP/ICE as text)
signaling_tx: broadcast::Sender<String>,
/// Live stats for this channel
stats: RelayStats,
/// Cached state for late-joining receivers
cache: StateCache,
/// When the source last disconnected (for reconnect grace period);
/// None while a source is connected or before the first disconnect.
source_disconnect_time: Option<Instant>,
/// Max receivers for this channel
max_receivers: usize,
/// Cached schema announcement (0x32 payload) from source
schema: Option<Vec<u8>>,
}
impl ChannelState {
fn new(frame_buffer_size: usize, max_receivers: usize, replay_depth: usize) -> Self {
let (frame_tx, _) = broadcast::channel(frame_buffer_size);
let (input_tx, input_rx) = mpsc::channel(256);
let (signaling_tx, _) = broadcast::channel(64);
Self {
frame_tx,
input_tx,
input_rx: Some(input_rx),
signaling_tx,
stats: RelayStats::default(),
cache: StateCache {
replay_depth,
..StateCache::default()
},
source_disconnect_time: None,
max_receivers,
schema: None,
}
}
/// Returns true if this channel is idle (no source, no receivers, no cache).
fn is_idle(&self) -> bool {
!self.stats.source_connected
&& self.stats.connected_receivers == 0
&& !self.cache.has_state()
}
/// Returns true if the source reconnect grace period has expired.
fn grace_period_expired(&self, grace_secs: u64) -> bool {
if let Some(disconnect_time) = self.source_disconnect_time {
disconnect_time.elapsed() > Duration::from_secs(grace_secs)
} else {
false
}
}
}
// ─── v1.2: Auth-Gated Channel ───
/// Per-channel authentication state.
/// When `required` is true, sources must send a valid `Auth` frame before
/// the relay will forward their frames.
#[derive(Debug, Clone)]
pub struct ChannelAuth {
/// Whether authentication is enforced on this channel at all.
required: bool,
/// Pre-shared key for this channel (empty = open)
key: Vec<u8>,
/// Authenticated source addresses (by string identifier)
authenticated: Vec<String>,
}
impl ChannelAuth {
    /// An open channel: every source is accepted without a key.
    pub fn open() -> Self {
        ChannelAuth { required: false, key: Vec::new(), authenticated: Vec::new() }
    }
    /// A key-gated channel: sources must present `key` before their frames
    /// are forwarded.
    pub fn with_key(key: &[u8]) -> Self {
        ChannelAuth { required: true, key: key.to_vec(), authenticated: Vec::new() }
    }
    /// Check if a source is authenticated. Open channels accept everyone.
    pub fn is_authenticated(&self, source_id: &str) -> bool {
        if !self.required { return true; }
        self.authenticated.iter().any(|s| s == source_id)
    }
    /// Attempt to authenticate a source with a token.
    /// Returns true if authentication succeeded (always true on open channels).
    ///
    /// The token is compared in constant time with respect to its contents,
    /// so a remote caller cannot recover the key byte-by-byte via timing.
    /// (Only the length check short-circuits; the key length is not treated
    /// as secret here.)
    pub fn authenticate(&mut self, source_id: &str, token: &[u8]) -> bool {
        if !self.required { return true; }
        if Self::ct_eq(token, &self.key) {
            // Record the source once; repeat authentications are idempotent.
            if !self.authenticated.iter().any(|s| s == source_id) {
                self.authenticated.push(source_id.to_string());
            }
            true
        } else {
            false
        }
    }
    /// Constant-time byte equality: OR-accumulate XOR differences so the
    /// loop never exits early on the first mismatching byte.
    fn ct_eq(a: &[u8], b: &[u8]) -> bool {
        if a.len() != b.len() {
            return false;
        }
        a.iter().zip(b.iter()).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
    }
    /// Revoke authentication for a source.
    pub fn revoke(&mut self, source_id: &str) {
        self.authenticated.retain(|s| s != source_id);
    }
    /// Whether this channel requires a key at all.
    pub fn is_required(&self) -> bool { self.required }
    /// Number of currently-authenticated sources.
    pub fn authenticated_count(&self) -> usize { self.authenticated.len() }
}
/// Shared relay state — holds all channels.
///
/// Wrapped in `Arc<RwLock<..>>` by `run_relay` and shared with every
/// connection task and background task.
struct RelayState {
/// Named channels: "default", "main", "player1", etc.
channels: HashMap<String, Arc<RwLock<ChannelState>>>,
/// Frame buffer size for new channels
frame_buffer_size: usize,
/// Max receivers per channel
max_receivers: usize,
/// Max channels
max_channels: usize,
/// Server start time
start_time: Instant,
/// Replay depth for new channels
replay_depth: usize,
/// Recording directory (None = disabled)
recording_dir: Option<String>,
}
impl RelayState {
    /// Fresh relay state with no channels; per-channel settings are stored
    /// here and applied to every channel created later.
    fn new(frame_buffer_size: usize, max_receivers: usize, max_channels: usize, replay_depth: usize, recording_dir: Option<String>) -> Self {
        Self {
            channels: HashMap::new(),
            frame_buffer_size,
            max_receivers,
            max_channels,
            start_time: Instant::now(),
            replay_depth,
            recording_dir,
        }
    }
    /// Get or create a channel by name.
    ///
    /// Returns `None` (and logs) when the channel does not exist yet and
    /// the relay is already at `max_channels`.
    fn get_or_create_channel(&mut self, name: &str) -> Option<Arc<RwLock<ChannelState>>> {
        // Single lookup on the hot path (the old contains_key + index
        // hashed the key twice).
        if let Some(existing) = self.channels.get(name) {
            return Some(existing.clone());
        }
        if self.channels.len() >= self.max_channels {
            eprintln!("[relay] Max channels ({}) reached, rejecting: {}", self.max_channels, name);
            return None;
        }
        let channel = Arc::new(RwLock::new(ChannelState::new(
            self.frame_buffer_size,
            self.max_receivers,
            self.replay_depth,
        )));
        self.channels.insert(name.to_string(), channel.clone());
        Some(channel)
    }
}
/// Parsed connection role from the WebSocket URI path.
///
/// Every variant carries the channel name extracted from the path
/// ("default" when the path had no channel segment).
#[derive(Debug, Clone)]
enum ConnectionRole {
Source(String), // channel name
Receiver(String), // channel name
Signaling(String), // channel name — WebRTC signaling
Peer(String), // channel name — bidirectional sync
Meta(String), // channel name — HTTP introspection
}
/// Map a WebSocket URI path to a connection role plus channel name.
///
/// The first path segment selects the role; the optional remainder (which
/// may itself contain `/`, enabling grouped names like "games/chess")
/// names the channel, defaulting to "default". Any unrecognized path
/// falls back to the default receiver role.
fn parse_path(path: &str) -> ConnectionRole {
    let mut segments = path.trim_start_matches('/').splitn(2, '/');
    let head = segments.next().unwrap_or("");
    let channel = segments.next().unwrap_or("default").to_string();
    match head {
        "source" => ConnectionRole::Source(channel),
        "stream" => ConnectionRole::Receiver(channel),
        "signal" => ConnectionRole::Signaling(channel),
        "peer" => ConnectionRole::Peer(channel),
        "meta" => ConnectionRole::Meta(channel),
        _ => ConnectionRole::Receiver("default".to_string()),
    }
}
/// Check if a channel name matches a wildcard pattern (channel groups).
///
/// Supported forms:
/// - exact match: `"games/chess"` matches only itself;
/// - `"*"` matches every channel;
/// - `"games/*"` matches any channel strictly below the `games/` prefix
///   (e.g. `games/chess`, `games/a/b`), but NOT `games` itself and NOT
///   sibling names that merely share the text prefix (`gamesx`).
///
/// Fix over the previous version: the old code stripped `"/*"` and then
/// prefix-compared against the bare group name, so `"games/*"` wrongly
/// matched `"gamesx"`. The `/` separator is now required.
pub fn channel_matches(pattern: &str, channel: &str) -> bool {
    if pattern == channel || pattern == "*" {
        return true;
    }
    if let Some(group) = pattern.strip_suffix("/*") {
        // Require "group/" followed by a non-empty remainder.
        return channel
            .strip_prefix(group)
            .and_then(|rest| rest.strip_prefix('/'))
            .map_or(false, |rest| !rest.is_empty());
    }
    false
}
/// Find all channels matching a wildcard pattern.
/// (Test helper; iteration order follows the HashMap and is unspecified.)
#[cfg(test)]
fn find_matching_channels(state: &RelayState, pattern: &str) -> Vec<String> {
    let mut matches = Vec::new();
    for name in state.channels.keys() {
        if channel_matches(pattern, name) {
            matches.push(name.clone());
        }
    }
    matches
}
/// Run the WebSocket relay server.
///
/// Binds the configured address, spawns a stats/GC ticker plus one
/// forwarding task per federation upstream, then accepts connections
/// forever. Each accepted socket is peeked: `/meta*` requests are answered
/// as raw HTTP; everything else goes through the WebSocket upgrade and is
/// dispatched by URI path (see module docs). Returns only on bind failure
/// or when `accept` stops yielding connections.
pub async fn run_relay(config: RelayConfig) -> Result<(), Box<dyn std::error::Error>> {
let listener = TcpListener::bind(&config.addr).await?;
eprintln!("╔══════════════════════════════════════════════════╗");
eprintln!("║ DreamStack Bitstream Relay v0.3.0 ║");
eprintln!("║ ║");
eprintln!("║ Source: ws://{}/source/{{name}}", config.addr);
eprintln!("║ Receiver: ws://{}/stream/{{name}}", config.addr);
eprintln!("║ Signal: ws://{}/signal/{{name}}", config.addr);
eprintln!("║ ║");
eprintln!("║ Max receivers/ch: {:>4}", config.max_receivers);
eprintln!("║ Max channels: {:>4}", config.max_channels);
eprintln!("║ Reconnect grace: {:>4}s ║", config.source_reconnect_grace_secs);
eprintln!("╚══════════════════════════════════════════════════╝");
let grace_secs = config.source_reconnect_grace_secs;
let replay_depth = config.replay_depth;
let recording_dir = config.recording_dir.clone();
let federation_upstreams = config.federation_upstreams.clone();
let state = Arc::new(RwLock::new(RelayState::new(
config.frame_buffer_size,
config.max_receivers,
config.max_channels,
config.replay_depth,
config.recording_dir.clone(),
)));
// Log the feature status
if replay_depth > 0 {
eprintln!("║ Replay depth: {:>4} frames ║", replay_depth);
}
if let Some(ref dir) = recording_dir {
eprintln!("║ Recording to: {}", dir);
// Ensure recording directory exists
std::fs::create_dir_all(dir).unwrap_or_else(|e| {
eprintln!("[relay] Warning: could not create recording dir {}: {}", dir, e);
});
}
if !federation_upstreams.is_empty() {
eprintln!("║ Federation: {} upstream(s) ║", federation_upstreams.len());
}
// Background: periodic stats + channel GC
// Ticks every min(gc_interval, 30) seconds; stats are logged each tick
// and a GC pass runs roughly every gc_interval seconds.
// NOTE(review): channel_gc_interval_secs = 0 would panic here — zero tick
// period for `interval`, then divide-by-zero in the modulo below; confirm
// configs keep it >= 1.
{
let state = state.clone();
let gc_interval = config.channel_gc_interval_secs;
tokio::spawn(async move {
let mut tick = interval(Duration::from_secs(gc_interval.min(30)));
let mut gc_counter: u64 = 0;
loop {
tick.tick().await;
gc_counter += 1;
let s = state.read().await;
let uptime = s.start_time.elapsed().as_secs();
// Stats logging
for (name, channel) in &s.channels {
let cs = channel.read().await;
if cs.stats.source_connected || cs.stats.connected_receivers > 0 {
eprintln!(
"[relay:{name}] up={uptime}s frames={} bytes={} receivers={}/{} peak={} signal_diffs={} cached={}",
cs.stats.frames_relayed,
cs.stats.bytes_relayed,
cs.stats.connected_receivers,
cs.max_receivers,
cs.stats.peak_receivers,
cs.stats.signal_diffs_sent,
cs.cache.has_state(),
);
}
}
drop(s);
// Channel GC — every gc_interval ticks
// NOTE(review): a channel is also removed when the source reconnect
// grace period has expired, even if receivers are still attached or
// the cache is non-empty — confirm that is the intended policy.
if gc_counter % (gc_interval / gc_interval.min(30)).max(1) == 0 {
let mut s = state.write().await;
let before = s.channels.len();
let mut to_remove = Vec::new();
for (name, channel) in &s.channels {
let cs = channel.read().await;
if cs.is_idle() || cs.grace_period_expired(grace_secs) {
to_remove.push(name.clone());
}
}
for name in &to_remove {
s.channels.remove(name);
}
if !to_remove.is_empty() {
// NOTE(review): "({}{})" prints the before/after channel counts
// with no separator (e.g. "31" for 3 → 1) — likely meant "{}→{}".
eprintln!(
"[relay] GC: removed {} idle channel(s) ({}{}): {:?}",
to_remove.len(),
before,
s.channels.len(),
to_remove
);
}
}
}
});
}
// Background: federation forwarding to upstream relays
// NOTE(review): only the "default" channel is forwarded, and only if it
// already exists when the upstream connection is (re)established; until
// then the task loops on the backoff timer.
for upstream_url in &federation_upstreams {
let url = upstream_url.clone();
let state = state.clone();
tokio::spawn(async move {
let mut backoff = Duration::from_secs(1);
loop {
eprintln!("[relay:federation] Connecting to upstream: {url}");
match tokio_tungstenite::connect_async(&url).await {
Ok((mut ws, _)) => {
eprintln!("[relay:federation] Connected to {url}");
backoff = Duration::from_secs(1);
// Subscribe to the default channel and forward frames
let frame_rx = {
let s = state.read().await;
if let Some(ch) = s.channels.get("default") {
let cs = ch.read().await;
Some(cs.frame_tx.subscribe())
} else {
None
}
};
if let Some(mut rx) = frame_rx {
while let Ok(frame) = rx.recv().await {
use futures_util::SinkExt;
if ws.send(Message::Binary(frame.into())).await.is_err() {
break;
}
}
}
eprintln!("[relay:federation] Disconnected from {url}");
}
Err(e) => {
eprintln!("[relay:federation] Failed to connect to {url}: {e}");
}
}
// Exponential backoff on reconnect
tokio::time::sleep(backoff).await;
backoff = (backoff * 2).min(Duration::from_secs(30));
}
});
}
while let Ok((stream, addr)) = listener.accept().await {
let state = state.clone();
let keepalive_interval = config.keepalive_interval_secs;
let keepalive_timeout = config.keepalive_timeout_secs;
tokio::spawn(async move {
// Peek at the HTTP request to check for /meta requests
// Meta requests get a raw HTTP JSON response, not a WS upgrade
// (peek leaves the bytes in the socket for the later WS handshake;
// only the first 512 bytes are inspected — enough for the request line).
let mut peek_buf = [0u8; 512];
let n = match stream.peek(&mut peek_buf).await {
Ok(n) => n,
Err(_) => return,
};
let request_line = String::from_utf8_lossy(&peek_buf[..n]);
// Extract path from "GET /meta/name HTTP/1.1"
if let Some(path) = request_line.lines().next().and_then(|line| {
let parts: Vec<&str> = line.split_whitespace().collect();
if parts.len() >= 2 { Some(parts[1]) } else { None }
}) {
if path.starts_with("/meta") {
handle_meta_http(stream, addr, &state, path).await;
return;
}
}
handle_connection(stream, addr, state, keepalive_interval, keepalive_timeout).await;
});
}
Ok(())
}
/// Handle an HTTP /meta/{channel} request — return channel stats as JSON.
///
/// This is NOT a WebSocket connection. We read the full HTTP request,
/// then respond with an HTTP 200 + JSON body.
///
/// `/meta` (or `/meta/`) reports every channel; `/meta/{name}` reports a
/// single channel, including its parsed current signal state and cached
/// schema when available.
/// NOTE(review): a 200 status is returned even when the channel does not
/// exist — the error is only visible in the JSON body.
async fn handle_meta_http(
mut stream: TcpStream,
addr: SocketAddr,
state: &Arc<RwLock<RelayState>>,
path: &str,
) {
use tokio::io::AsyncWriteExt;
use tokio::io::AsyncReadExt;
// Drain the HTTP request from the socket
// (single read of up to 4 KiB; anything beyond is ignored, which is fine
// since only the path — already parsed by the caller — matters).
let mut buf = vec![0u8; 4096];
let _ = stream.read(&mut buf).await;
// Parse channel name from path
let channel_name = path.trim_start_matches("/meta/").trim_matches('/');
let channel_name = if channel_name.is_empty() { "__all__" } else { channel_name };
let s = state.read().await;
let uptime = s.start_time.elapsed().as_secs();
let body = if channel_name == "__all__" {
// Return stats for ALL channels
let mut channels = serde_json::Map::new();
for (name, channel) in &s.channels {
let cs = channel.read().await;
channels.insert(name.clone(), serde_json::json!({
"source_connected": cs.stats.source_connected,
"receivers": cs.stats.connected_receivers,
"peak_receivers": cs.stats.peak_receivers,
"frames_relayed": cs.stats.frames_relayed,
"bytes_relayed": cs.stats.bytes_relayed,
"signal_diffs": cs.stats.signal_diffs_sent,
"keyframes": cs.stats.keyframes_sent,
"total_connections": cs.stats.total_connections,
"rejected_connections": cs.stats.rejected_connections,
"cached": cs.cache.has_state(),
"pending_diffs": cs.cache.pending_signal_diffs.len(),
}));
}
serde_json::json!({
"uptime_secs": uptime,
"channel_count": s.channels.len(),
"channels": serde_json::Value::Object(channels),
}).to_string()
} else if let Some(channel) = s.channels.get(channel_name) {
// Return stats for a specific channel
let cs = channel.read().await;
// Extract current signal state if available
// (payload length is read from header bytes 12..16 LE and clamped to
// the actual buffer, matching the framing used throughout this file).
let mut current_state = serde_json::Value::Null;
if let Some(ref sync_frame) = cs.cache.last_signal_sync {
if sync_frame.len() >= HEADER_SIZE {
let pl_len = u32::from_le_bytes([
sync_frame[12], sync_frame[13], sync_frame[14], sync_frame[15],
]) as usize;
let end = HEADER_SIZE + pl_len.min(sync_frame.len() - HEADER_SIZE);
if let Ok(val) = serde_json::from_slice::<serde_json::Value>(&sync_frame[HEADER_SIZE..end]) {
current_state = val;
}
}
}
serde_json::json!({
"channel": channel_name,
"uptime_secs": uptime,
"source_connected": cs.stats.source_connected,
"receivers": cs.stats.connected_receivers,
"max_receivers": cs.max_receivers,
"peak_receivers": cs.stats.peak_receivers,
"frames_relayed": cs.stats.frames_relayed,
"bytes_relayed": cs.stats.bytes_relayed,
"signal_diffs": cs.stats.signal_diffs_sent,
"keyframes": cs.stats.keyframes_sent,
"total_connections": cs.stats.total_connections,
"rejected_connections": cs.stats.rejected_connections,
"cached": cs.cache.has_state(),
"pending_diffs": cs.cache.pending_signal_diffs.len(),
"schema": cs.schema.as_ref().and_then(|s| {
if s.len() >= HEADER_SIZE {
let pl_len = u32::from_le_bytes([s[12], s[13], s[14], s[15]]) as usize;
let end = HEADER_SIZE + pl_len.min(s.len() - HEADER_SIZE);
serde_json::from_slice::<serde_json::Value>(&s[HEADER_SIZE..end]).ok()
} else { None }
}),
"current_state": current_state,
}).to_string()
} else {
serde_json::json!({
"error": "channel not found",
"channel": channel_name,
}).to_string()
};
// Write HTTP response
let response = format!(
"HTTP/1.1 200 OK\r\n\
Content-Type: application/json\r\n\
Access-Control-Allow-Origin: *\r\n\
Content-Length: {}\r\n\
Connection: close\r\n\
\r\n\
{}",
body.len(),
body
);
let _ = stream.write_all(response.as_bytes()).await;
eprintln!("[relay:meta] {addr}{path}");
}
/// Accept a WebSocket connection and dispatch it by URI path.
///
/// The role/channel is captured inside the `accept_hdr_async` header
/// callback. `/meta` requests were answered earlier as raw HTTP and simply
/// fail the WS handshake here. `_keepalive_timeout` is accepted but unused.
async fn handle_connection(
stream: TcpStream,
addr: SocketAddr,
state: Arc<RwLock<RelayState>>,
keepalive_interval: u64,
_keepalive_timeout: u64,
) {
// Extract the URI path during handshake to determine role
let mut role = ConnectionRole::Receiver("default".to_string());
// Peek at the path to check for meta requests (handle via raw HTTP, not WS)
let ws_stream = match tokio_tungstenite::accept_hdr_async(
stream,
|req: &tokio_tungstenite::tungstenite::handshake::server::Request,
res: tokio_tungstenite::tungstenite::handshake::server::Response| {
role = parse_path(req.uri().path());
Ok(res)
},
)
.await
{
Ok(ws) => ws,
Err(e) => {
// Meta requests won't have an Upgrade header, so they fail the handshake
// That's fine — they were handled inline if detected.
if !matches!(role, ConnectionRole::Meta(_)) {
eprintln!("[relay] WebSocket handshake failed from {}: {}", addr, e);
}
return;
}
};
// Handle meta requests separately (they don't use WebSocket)
if let ConnectionRole::Meta(ref _channel_name) = role {
// Meta requests fail the WS handshake, so we'll never reach here
// But just in case, close the connection
return;
}
// Get or create the channel
let channel_name = match &role {
ConnectionRole::Source(n) | ConnectionRole::Receiver(n) | ConnectionRole::Signaling(n) | ConnectionRole::Peer(n) | ConnectionRole::Meta(n) => n.clone(),
};
let channel = {
let mut s = state.write().await;
match s.get_or_create_channel(&channel_name) {
Some(ch) => ch,
None => {
eprintln!("[relay] Rejected connection from {addr}: max channels reached");
return;
}
}
};
// Check receiver limit
// NOTE(review): the counter is incremented later in handle_receiver under
// a separate lock acquisition, so a burst of simultaneous connects can
// briefly exceed max_receivers.
if let ConnectionRole::Receiver(_) = &role {
let mut cs = channel.write().await;
if cs.stats.connected_receivers >= cs.max_receivers {
eprintln!(
"[relay:{channel_name}] Rejected receiver {addr}: at capacity ({}/{})",
cs.stats.connected_receivers, cs.max_receivers
);
cs.stats.rejected_connections += 1;
return;
}
}
// Legacy fallback: if path was `/` and no source exists, treat as source
// NOTE(review): parse_path maps both `/` and `/stream` to
// Receiver("default"), so an explicit `/stream` client is also promoted
// to source when none is connected — confirm this is intended.
let role = match role {
ConnectionRole::Receiver(ref name) if name == "default" => {
let cs = channel.read().await;
if !cs.stats.source_connected {
ConnectionRole::Source("default".to_string())
} else {
role
}
}
_ => role,
};
match role {
ConnectionRole::Source(ref _name) => {
eprintln!("[relay:{channel_name}] Source connected: {addr}");
// Get recording_dir from relay state
let recording_dir = {
let s = state.read().await;
s.recording_dir.clone()
};
handle_source(ws_stream, addr, channel, &channel_name, keepalive_interval, recording_dir).await;
}
ConnectionRole::Receiver(ref _name) => {
eprintln!("[relay:{channel_name}] Receiver connected: {addr}");
handle_receiver(ws_stream, addr, channel, &channel_name).await;
}
ConnectionRole::Signaling(ref _name) => {
eprintln!("[relay:{channel_name}] Signaling peer connected: {addr}");
handle_signaling(ws_stream, addr, channel, &channel_name).await;
}
ConnectionRole::Peer(ref _name) => {
eprintln!("[relay:{channel_name}] Peer connected: {addr}");
handle_peer(ws_stream, addr, channel, &channel_name).await;
}
ConnectionRole::Meta(_) => {
// Handled before WS handshake — should never reach here
}
}
}
/// Relay loop for a connected source.
///
/// Receives binary frames from the source WebSocket, updates the channel's
/// stats and state cache, optionally appends each frame to a recording
/// file ([u32 LE length][frame bytes]), and broadcasts it to receivers.
/// A spawned task forwards receiver input events back to the source, and a
/// second task periodically broadcasts keepalive pings.
/// NOTE(review): pings go out via `frame_tx` to receivers — the source
/// socket itself is never pinged, and no pong timeout is enforced here; a
/// dead source is only detected when the WS stream ends.
async fn handle_source(
ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
addr: SocketAddr,
channel: Arc<RwLock<ChannelState>>,
channel_name: &str,
keepalive_interval: u64,
recording_dir: Option<String>,
) {
let (mut ws_sink, mut ws_source) = ws_stream.split();
// Mark source as connected; if reconnecting, preserve cache
let input_rx = {
let mut cs = channel.write().await;
if cs.source_disconnect_time.is_some() {
eprintln!("[relay:{channel_name}] Source reconnected — cache preserved ({} diffs, sync={})",
cs.cache.pending_signal_diffs.len(),
cs.cache.last_signal_sync.is_some(),
);
}
cs.stats.source_connected = true;
cs.stats.total_connections += 1;
cs.source_disconnect_time = None;
cs.input_rx.take()
};
let frame_tx = {
let cs = channel.read().await;
cs.frame_tx.clone()
};
// Forward input events from receivers → source
// (this task takes ownership of the write half; the loop below only
// reads from the source socket)
if let Some(mut input_rx) = input_rx {
let channel_clone = channel.clone();
tokio::spawn(async move {
while let Some(input_bytes) = input_rx.recv().await {
let msg = Message::Binary(input_bytes.into());
if ws_sink.send(msg).await.is_err() {
break;
}
let mut cs = channel_clone.write().await;
cs.stats.inputs_relayed += 1;
}
});
}
// Keepalive: periodic pings
// (broadcast on the frame channel; exits once source_connected is cleared)
let channel_ping = channel.clone();
let ping_task = tokio::spawn(async move {
let mut tick = interval(Duration::from_secs(keepalive_interval));
let mut seq = 0u16;
loop {
tick.tick().await;
let cs = channel_ping.read().await;
if !cs.stats.source_connected {
break;
}
let _ = cs.frame_tx.send(crate::codec::ping(seq, 0));
drop(cs);
seq = seq.wrapping_add(1);
}
});
// Open recording file if configured
// NOTE(review): opened in append mode — reconnecting sessions concatenate
// into the same {channel}.dsrec file.
let mut recording_file = if let Some(ref dir) = recording_dir {
let path = format!("{}/{}.dsrec", dir, channel_name.replace('/', "_"));
match tokio::fs::OpenOptions::new()
.create(true)
.append(true)
.open(&path)
.await
{
Ok(f) => {
eprintln!("[relay:{channel_name}] Recording to {path}");
Some(f)
}
Err(e) => {
eprintln!("[relay:{channel_name}] Warning: could not open recording file {path}: {e}");
None
}
}
} else {
None
};
// Receive frames from source → broadcast to receivers
// (non-binary messages are ignored; any stream error ends the session)
let channel_name_owned = channel_name.to_string();
while let Some(Ok(msg)) = ws_source.next().await {
if let Message::Binary(data) = msg {
let data_vec: Vec<u8> = data.into();
{
let mut cs = channel.write().await;
cs.stats.frames_relayed += 1;
cs.stats.bytes_relayed += data_vec.len() as u64;
// Update state cache
cs.cache.process_frame(&data_vec);
// Track frame-type-specific stats
if data_vec.len() >= HEADER_SIZE {
match FrameType::from_u8(data_vec[0]) {
Some(FrameType::SignalDiff) => cs.stats.signal_diffs_sent += 1,
Some(FrameType::Pixels) | Some(FrameType::Keyframe) |
Some(FrameType::SignalSync) => cs.stats.keyframes_sent += 1,
_ => {}
}
// Frame timestamp lives at header bytes 4..8 (LE).
let ts = u32::from_le_bytes([
data_vec[4], data_vec[5], data_vec[6], data_vec[7],
]);
cs.stats.last_frame_timestamp = ts;
}
}
// Record frame to disk (length-delimited: [u32 len][frame bytes])
if let Some(ref mut file) = recording_file {
use tokio::io::AsyncWriteExt;
let len_bytes = (data_vec.len() as u32).to_le_bytes();
let _ = file.write_all(&len_bytes).await;
let _ = file.write_all(&data_vec).await;
}
// Broadcast to all receivers on this channel
let _ = frame_tx.send(data_vec);
}
}
// Source disconnected — start grace period, keep cache
ping_task.abort();
eprintln!("[relay:{channel_name_owned}] Source disconnected: {addr} (cache preserved for reconnect)");
let mut cs = channel.write().await;
cs.stats.source_connected = false;
cs.source_disconnect_time = Some(Instant::now());
// Recreate input channel for next source connection
// (senders holding the old input_tx will fail; new receivers get this one)
let (new_input_tx, new_input_rx) = mpsc::channel(256);
cs.input_tx = new_input_tx;
cs.input_rx = Some(new_input_rx);
}
/// Handle a receiver ("/stream") connection on `channel`.
///
/// Lifecycle:
/// 1. Register the receiver in channel stats and subscribe to the frame
///    broadcast, snapshotting catchup state and the cached schema under a
///    single write lock (so no frame can slip between snapshot and subscribe).
/// 2. Send the cached schema frame (0x32), then the merged catchup frames,
///    so a late joiner converges to current state immediately.
/// 3. Spawn a send task forwarding broadcast frames → this receiver, with an
///    optional per-receiver signal filter applied on the way out.
/// 4. Read messages from the receiver: 0x33 SubscribeFilter frames update the
///    filter; every other binary message is forwarded to the source as input.
async fn handle_receiver(
    ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
    addr: SocketAddr,
    channel: Arc<RwLock<ChannelState>>,
    channel_name: &str,
) {
    let (mut ws_sink, mut ws_source) = ws_stream.split();
    // Subscribe to frame broadcast and get catchup + schema
    let (mut frame_rx, catchup_msgs, schema_frame) = {
        let mut cs = channel.write().await;
        cs.stats.connected_receivers += 1;
        cs.stats.total_connections += 1;
        if cs.stats.connected_receivers > cs.stats.peak_receivers {
            cs.stats.peak_receivers = cs.stats.connected_receivers;
        }
        let rx = cs.frame_tx.subscribe();
        let catchup = cs.cache.catchup_messages();
        let schema = cs.schema.clone();
        (rx, catchup, schema)
    };
    let input_tx = {
        let cs = channel.read().await;
        cs.input_tx.clone()
    };
    // Send cached schema (0x32) first so receiver knows what signals are available
    let channel_name_owned = channel_name.to_string();
    if let Some(schema_bytes) = schema_frame {
        let _ = ws_sink.send(Message::Binary(schema_bytes.into())).await;
    }
    // Send cached state to late-joining receiver
    if !catchup_msgs.is_empty() {
        eprintln!(
            "[relay:{channel_name_owned}] Sending {} catchup messages to {addr}",
            catchup_msgs.len(),
        );
        for msg_bytes in catchup_msgs {
            let msg = Message::Binary(msg_bytes.into());
            if ws_sink.send(msg).await.is_err() {
                eprintln!("[relay:{channel_name_owned}] Failed to send catchup to {addr}");
                // Undo the receiver count registered above before bailing out.
                let mut cs = channel.write().await;
                cs.stats.connected_receivers = cs.stats.connected_receivers.saturating_sub(1);
                return;
            }
        }
    }
    // Per-receiver signal filter (set via 0x33 SubscribeFilter)
    // None = pass everything; Some(set) = only the named signal keys pass.
    let filter: Arc<tokio::sync::RwLock<Option<std::collections::HashSet<String>>>> =
        Arc::new(tokio::sync::RwLock::new(None));
    let filter_for_send = filter.clone();
    // Forward frames from broadcast → this receiver (with optional filtering)
    let send_task = tokio::spawn(async move {
        loop {
            match frame_rx.recv().await {
                Ok(frame_bytes) => {
                    // Check if we need to filter signal frames
                    // (only frame types 0x30/0x31 are filterable — presumably
                    // SignalSync/SignalDiff; TODO confirm against FrameType).
                    let filter_guard = filter_for_send.read().await;
                    let should_filter = filter_guard.is_some()
                        && frame_bytes.len() >= 16
                        && (frame_bytes[0] == 0x30 || frame_bytes[0] == 0x31);
                    if should_filter {
                        if let Some(ref wanted) = *filter_guard {
                            // Parse JSON payload, keep only wanted keys.
                            // Bytes 12..16 hold the little-endian payload length.
                            let payload_len = u32::from_le_bytes([
                                frame_bytes[12], frame_bytes[13],
                                frame_bytes[14], frame_bytes[15],
                            ]) as usize;
                            if frame_bytes.len() >= 16 + payload_len {
                                let payload = &frame_bytes[16..16 + payload_len];
                                if let Ok(mut obj) = serde_json::from_slice::<serde_json::Value>(payload) {
                                    if let Some(map) = obj.as_object_mut() {
                                        let keys: Vec<String> = map.keys().cloned().collect();
                                        for k in keys {
                                            // Keep _pid, _v (internal), and wanted fields
                                            if k != "_pid" && k != "_v" && !wanted.contains(&k) {
                                                map.remove(&k);
                                            }
                                        }
                                        // Re-encode: reuse first 12 header bytes,
                                        // rewrite the length field, append payload.
                                        let new_payload = serde_json::to_vec(&obj).unwrap_or_default();
                                        let mut new_frame = Vec::with_capacity(16 + new_payload.len());
                                        new_frame.extend_from_slice(&frame_bytes[..12]);
                                        new_frame.extend_from_slice(&(new_payload.len() as u32).to_le_bytes());
                                        new_frame.extend_from_slice(&new_payload);
                                        // Release the filter lock before awaiting the send.
                                        drop(filter_guard);
                                        if ws_sink.send(Message::Binary(new_frame.into())).await.is_err() {
                                            break;
                                        }
                                        continue;
                                    }
                                }
                            }
                        }
                    }
                    // Unfiltered path (also the fallback when the payload is
                    // truncated or fails to parse as a JSON object).
                    drop(filter_guard);
                    let msg = Message::Binary(frame_bytes.into());
                    if ws_sink.send(msg).await.is_err() {
                        break;
                    }
                }
                Err(broadcast::error::RecvError::Lagged(n)) => {
                    eprintln!("[relay] Receiver {addr} lagged by {n} frames");
                }
                Err(_) => break,
            }
        }
    });
    // Handle incoming messages from receiver (input events + 0x33 filter)
    while let Some(Ok(msg)) = ws_source.next().await {
        if let Message::Binary(data) = msg {
            let data_vec: Vec<u8> = data.into();
            // Parse SubscribeFilter (0x33)
            if data_vec.len() >= 16 && data_vec[0] == 0x33 {
                let payload_len = u32::from_le_bytes([
                    data_vec[12], data_vec[13], data_vec[14], data_vec[15],
                ]) as usize;
                if data_vec.len() >= 16 + payload_len {
                    let payload = &data_vec[16..16 + payload_len];
                    if let Ok(obj) = serde_json::from_slice::<serde_json::Value>(payload) {
                        if let Some(select_arr) = obj.get("select").and_then(|v| v.as_array()) {
                            let wanted: std::collections::HashSet<String> = select_arr
                                .iter()
                                .filter_map(|v| v.as_str().map(|s| s.to_string()))
                                .collect();
                            eprintln!(
                                "[relay:{channel_name_owned}] Receiver {addr} subscribed to: {:?}",
                                wanted
                            );
                            let mut f = filter.write().await;
                            *f = Some(wanted);
                        }
                    }
                }
                continue; // Don't forward filter frames to source
            }
            // Forward input events to source
            let _ = input_tx.send(data_vec).await;
        }
    }
    // Receiver disconnected
    send_task.abort();
    eprintln!("[relay:{channel_name_owned}] Receiver disconnected: {addr}");
    let mut cs = channel.write().await;
    cs.stats.connected_receivers = cs.stats.connected_receivers.saturating_sub(1);
}
/// Handle a WebRTC signaling connection.
///
/// Signaling peers exchange JSON messages (SDP offers/answers, ICE
/// candidates) over WebSocket. Every text message received from one peer is
/// broadcast to all peers on the same channel, which is enough to bootstrap
/// a peer-to-peer WebRTC session.
async fn handle_signaling(
    ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
    addr: SocketAddr,
    channel: Arc<RwLock<ChannelState>>,
    channel_name: &str,
) {
    let (mut ws_sink, mut ws_source) = ws_stream.split();
    // Grab both ends of the signaling broadcast under a single read lock.
    let (mut sig_rx, sig_tx) = {
        let cs = channel.read().await;
        (cs.signaling_tx.subscribe(), cs.signaling_tx.clone())
    };
    let channel_name_owned = channel_name.to_string();
    // Pump broadcast → this peer until the socket closes or the channel dies.
    let send_task = tokio::spawn(async move {
        loop {
            let received = sig_rx.recv().await;
            match received {
                Ok(msg_text) => {
                    if ws_sink.send(Message::Text(msg_text.into())).await.is_err() {
                        break;
                    }
                }
                Err(broadcast::error::RecvError::Lagged(n)) => {
                    eprintln!("[relay:{channel_name_owned}] Signaling peer lagged by {n} messages");
                }
                Err(_) => break,
            }
        }
    });
    // Pump this peer → broadcast; non-text frames are ignored.
    while let Some(Ok(incoming)) = ws_source.next().await {
        match incoming {
            Message::Text(text) => {
                let text_str: String = text.into();
                let _ = sig_tx.send(text_str);
            }
            _ => {}
        }
    }
    send_task.abort();
    eprintln!("[relay:{channel_name}] Signaling peer disconnected: {addr}");
}
/// Handle a peer connection for bidirectional sync.
///
/// Peers are equal — every binary frame sent by any peer is broadcast to all
/// other peers on the same channel. No source/receiver distinction.
///
/// Schema announcement frames (type 0x32) are intercepted and cached on the
/// channel instead of being broadcast; every other binary frame updates the
/// state cache (for late joiners) and relay stats, then fans out.
async fn handle_peer(
    ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
    addr: SocketAddr,
    channel: Arc<RwLock<ChannelState>>,
    channel_name: &str,
) {
    let (mut ws_sink, mut ws_source) = ws_stream.split();
    // Subscribe to the frame broadcast channel (same one sources use) and
    // grab a sender handle — done under a single read lock instead of two.
    let (mut frame_rx, frame_tx) = {
        let cs = channel.read().await;
        (cs.frame_tx.subscribe(), cs.frame_tx.clone())
    };
    // Track peer count (peers are counted as receivers in the stats).
    // NOTE(review): unlike handle_receiver, this does not bump
    // stats.total_connections — confirm whether peers should count there too.
    {
        let mut cs = channel.write().await;
        cs.stats.connected_receivers += 1;
        cs.stats.peak_receivers = cs.stats.peak_receivers.max(cs.stats.connected_receivers);
    }
    // Send catchup state to late-joining peers
    {
        let cs = channel.read().await;
        if cs.cache.has_state() {
            let msgs = cs.cache.catchup_messages();
            eprintln!("[relay:{channel_name}] Sending {} catchup messages to peer {addr}", msgs.len());
            for msg in msgs {
                let _ = ws_sink.send(Message::Binary(msg.into())).await;
            }
        }
    }
    // Forward frames from broadcast → this peer
    let channel_name_owned = channel_name.to_string();
    let send_task = tokio::spawn(async move {
        loop {
            match frame_rx.recv().await {
                Ok(frame_data) => {
                    if ws_sink.send(Message::Binary(frame_data.into())).await.is_err() {
                        break;
                    }
                }
                Err(broadcast::error::RecvError::Lagged(n)) => {
                    eprintln!("[relay:{channel_name_owned}] Peer lagged by {n} frames");
                }
                Err(_) => break,
            }
        }
    });
    // Forward frames from this peer → broadcast to all others
    let channel_for_cache = channel.clone();
    while let Some(Ok(msg)) = ws_source.next().await {
        if let Message::Binary(data) = msg {
            let data_vec: Vec<u8> = data.into();
            // Intercept schema announcement (0x32) — cache, don't broadcast.
            if data_vec.len() >= 16 && data_vec[0] == 0x32 {
                // Bytes 12..16 hold the little-endian payload length; only
                // cache the frame when the full payload is present.
                let payload_len = u32::from_le_bytes([
                    data_vec[12], data_vec[13], data_vec[14], data_vec[15],
                ]) as usize;
                if data_vec.len() >= 16 + payload_len {
                    let mut cs = channel_for_cache.write().await;
                    cs.schema = Some(data_vec.clone());
                }
                continue; // Don't broadcast schema frames
            }
            // Cache for late joiners and account stats.
            {
                let mut cs = channel_for_cache.write().await;
                cs.cache.process_frame(&data_vec);
                cs.stats.frames_relayed += 1;
                cs.stats.bytes_relayed += data_vec.len() as u64;
            }
            // Broadcast to all peers (they'll filter out their own msgs by content)
            let _ = frame_tx.send(data_vec);
        }
    }
    send_task.abort();
    {
        let mut cs = channel.write().await;
        cs.stats.connected_receivers = cs.stats.connected_receivers.saturating_sub(1);
    }
    eprintln!("[relay:{channel_name}] Peer disconnected: {addr}");
}
// ─── Tests ───
#[cfg(test)]
mod tests {
    // Unit tests for relay config, state cache, path routing, channel
    // lifecycle, diff merging, replay buffer, wildcards, and channel auth.
    use super::*;
    #[test]
    fn default_config() {
        let config = RelayConfig::default();
        assert_eq!(config.addr, "0.0.0.0:9100".parse::<SocketAddr>().unwrap());
        assert_eq!(config.max_receivers, 64);
        assert_eq!(config.frame_buffer_size, 16);
        assert_eq!(config.keepalive_interval_secs, 10);
        assert_eq!(config.max_channels, 256);
        assert_eq!(config.channel_gc_interval_secs, 60);
        assert_eq!(config.source_reconnect_grace_secs, 30);
    }
    #[test]
    fn stats_default() {
        let stats = RelayStats::default();
        assert_eq!(stats.frames_relayed, 0);
        assert_eq!(stats.bytes_relayed, 0);
        assert!(!stats.source_connected);
        assert_eq!(stats.keyframes_sent, 0);
        assert_eq!(stats.signal_diffs_sent, 0);
        assert_eq!(stats.peak_receivers, 0);
        assert_eq!(stats.total_connections, 0);
        assert_eq!(stats.rejected_connections, 0);
    }
    #[test]
    fn state_cache_signal_sync() {
        let mut cache = StateCache::default();
        // Simulate a SignalSync frame
        let sync_json = br#"{"count":0}"#;
        let sync_msg = crate::codec::signal_sync_frame(0, 100, sync_json);
        cache.process_frame(&sync_msg);
        assert!(cache.last_signal_sync.is_some());
        assert!(cache.has_state());
        assert_eq!(cache.pending_signal_diffs.len(), 0);
        // Simulate a SignalDiff
        let diff_json = br#"{"count":1}"#;
        let diff_msg = crate::codec::signal_diff_frame(1, 200, diff_json);
        cache.process_frame(&diff_msg);
        assert_eq!(cache.pending_signal_diffs.len(), 1);
        // Catchup should contain ONE merged frame (sync + diff consolidated)
        let catchup = cache.catchup_messages();
        assert_eq!(catchup.len(), 1);
        // Verify the merged payload contains the updated count
        // (payload length lives in little-endian bytes 12..16 of the header).
        let frame = &catchup[0];
        assert!(frame.len() >= HEADER_SIZE);
        let payload_len = u32::from_le_bytes([frame[12], frame[13], frame[14], frame[15]]) as usize;
        let payload = &frame[HEADER_SIZE..HEADER_SIZE + payload_len];
        let merged: serde_json::Value = serde_json::from_slice(payload).unwrap();
        assert_eq!(merged["count"], 1); // diff applied on top of sync
    }
    #[test]
    fn state_cache_signal_sync_resets_diffs() {
        let mut cache = StateCache::default();
        // Add some diffs
        for i in 0..5 {
            let json = format!(r#"{{"count":{}}}"#, i);
            let msg = crate::codec::signal_diff_frame(i, i as u32 * 100, json.as_bytes());
            cache.process_frame(&msg);
        }
        assert_eq!(cache.pending_signal_diffs.len(), 5);
        // A new sync should clear diffs
        let sync = crate::codec::signal_sync_frame(10, 1000, b"{}");
        cache.process_frame(&sync);
        assert_eq!(cache.pending_signal_diffs.len(), 0);
        assert!(cache.last_signal_sync.is_some());
    }
    #[test]
    fn state_cache_keyframe() {
        let mut cache = StateCache::default();
        // Simulate a keyframe pixel frame
        let kf = crate::codec::pixel_frame(0, 100, 10, 10, &vec![0xFF; 400]);
        cache.process_frame(&kf);
        assert!(cache.last_keyframe.is_some());
        assert!(cache.has_state());
        let catchup = cache.catchup_messages();
        assert_eq!(catchup.len(), 1);
    }
    #[test]
    fn state_cache_diff_cap() {
        let mut cache = StateCache::default();
        // Add more than 1000 diffs — should auto-trim
        for i in 0..1100u16 {
            let json = format!(r#"{{"n":{}}}"#, i);
            let msg = crate::codec::signal_diff_frame(i, i as u32, json.as_bytes());
            cache.process_frame(&msg);
        }
        // Should have been trimmed: 1100 - 500 = 600 remaining after first trim at 1001
        // (upper-bound assertion so the exact trim schedule can evolve).
        assert!(cache.pending_signal_diffs.len() <= 600);
    }
    #[test]
    fn state_cache_clear() {
        let mut cache = StateCache::default();
        let sync = crate::codec::signal_sync_frame(0, 0, b"{}");
        cache.process_frame(&sync);
        assert!(cache.has_state());
        cache.clear();
        assert!(!cache.has_state());
        assert!(cache.last_signal_sync.is_none());
    }
    // ─── Path Routing Tests ───
    #[test]
    fn parse_path_source_default() {
        match parse_path("/source") {
            ConnectionRole::Source(name) => assert_eq!(name, "default"),
            _ => panic!("Expected Source"),
        }
    }
    #[test]
    fn parse_path_source_named() {
        match parse_path("/source/main") {
            ConnectionRole::Source(name) => assert_eq!(name, "main"),
            _ => panic!("Expected Source"),
        }
    }
    #[test]
    fn parse_path_stream_default() {
        match parse_path("/stream") {
            ConnectionRole::Receiver(name) => assert_eq!(name, "default"),
            _ => panic!("Expected Receiver"),
        }
    }
    #[test]
    fn parse_path_stream_named() {
        match parse_path("/stream/player1") {
            ConnectionRole::Receiver(name) => assert_eq!(name, "player1"),
            _ => panic!("Expected Receiver"),
        }
    }
    #[test]
    fn parse_path_signal_default() {
        match parse_path("/signal") {
            ConnectionRole::Signaling(name) => assert_eq!(name, "default"),
            _ => panic!("Expected Signaling"),
        }
    }
    #[test]
    fn parse_path_signal_named() {
        match parse_path("/signal/main") {
            ConnectionRole::Signaling(name) => assert_eq!(name, "main"),
            _ => panic!("Expected Signaling"),
        }
    }
    #[test]
    fn parse_path_legacy_root() {
        // Bare "/" falls back to the legacy default-receiver role.
        match parse_path("/") {
            ConnectionRole::Receiver(name) => assert_eq!(name, "default"),
            _ => panic!("Expected Receiver for legacy path"),
        }
    }
    // ─── v0.14: Peer + Meta Path Tests ───
    #[test]
    fn parse_path_peer_default() {
        match parse_path("/peer") {
            ConnectionRole::Peer(name) => assert_eq!(name, "default"),
            _ => panic!("Expected Peer"),
        }
    }
    #[test]
    fn parse_path_peer_named() {
        match parse_path("/peer/room42") {
            ConnectionRole::Peer(name) => assert_eq!(name, "room42"),
            _ => panic!("Expected Peer"),
        }
    }
    #[test]
    fn parse_path_meta_default() {
        match parse_path("/meta") {
            ConnectionRole::Meta(name) => assert_eq!(name, "default"),
            _ => panic!("Expected Meta"),
        }
    }
    #[test]
    fn parse_path_meta_named() {
        match parse_path("/meta/stats") {
            ConnectionRole::Meta(name) => assert_eq!(name, "stats"),
            _ => panic!("Expected Meta"),
        }
    }
    // ─── Channel State Tests ───
    #[test]
    fn channel_state_creation() {
        // Same name must return the same Arc; distinct names distinct Arcs.
        let mut state = RelayState::new(16, 64, 256, 0, None);
        let ch1 = state.get_or_create_channel("main").unwrap();
        let ch2 = state.get_or_create_channel("player1").unwrap();
        let ch1_again = state.get_or_create_channel("main").unwrap();
        assert_eq!(state.channels.len(), 2);
        assert!(Arc::ptr_eq(&ch1, &ch1_again));
        assert!(!Arc::ptr_eq(&ch1, &ch2));
    }
    #[test]
    fn channel_max_limit() {
        let mut state = RelayState::new(16, 64, 2, 0, None);
        assert!(state.get_or_create_channel("a").is_some());
        assert!(state.get_or_create_channel("b").is_some());
        assert!(state.get_or_create_channel("c").is_none()); // max reached
        assert!(state.get_or_create_channel("a").is_some()); // existing OK
    }
    #[test]
    fn channel_idle_detection() {
        let cs = ChannelState::new(16, 64, 0);
        assert!(cs.is_idle()); // no source, no receivers, no cache
    }
    #[test]
    fn channel_not_idle_with_cache() {
        let mut cs = ChannelState::new(16, 64, 0);
        let sync = crate::codec::signal_sync_frame(0, 0, b"{}");
        cs.cache.process_frame(&sync);
        assert!(!cs.is_idle()); // has cached state
    }
    #[test]
    fn channel_not_idle_with_source() {
        let mut cs = ChannelState::new(16, 64, 0);
        cs.stats.source_connected = true;
        assert!(!cs.is_idle());
    }
    #[test]
    fn channel_not_idle_with_receivers() {
        let mut cs = ChannelState::new(16, 64, 0);
        cs.stats.connected_receivers = 1;
        assert!(!cs.is_idle());
    }
    #[test]
    fn grace_period_not_expired_initially() {
        let cs = ChannelState::new(16, 64, 0);
        assert!(!cs.grace_period_expired(30));
    }
    #[test]
    fn grace_period_expired_after_disconnect() {
        let mut cs = ChannelState::new(16, 64, 0);
        // Backdate the disconnect past the 30s grace window.
        cs.source_disconnect_time = Some(Instant::now() - Duration::from_secs(60));
        assert!(cs.grace_period_expired(30));
    }
    // ─── Diff Merging Tests ───
    #[test]
    fn catchup_merges_multiple_diffs() {
        let mut cache = StateCache::default();
        // Start with sync: { count: 0, name: "test" }
        let sync = crate::codec::signal_sync_frame(0, 0, br#"{"count":0,"name":"test"}"#);
        cache.process_frame(&sync);
        // Apply 3 diffs
        let d1 = crate::codec::signal_diff_frame(1, 100, br#"{"count":1}"#);
        let d2 = crate::codec::signal_diff_frame(2, 200, br#"{"count":2}"#);
        let d3 = crate::codec::signal_diff_frame(3, 300, br#"{"count":3,"name":"updated"}"#);
        cache.process_frame(&d1);
        cache.process_frame(&d2);
        cache.process_frame(&d3);
        // Should produce 1 merged frame, not 4
        let catchup = cache.catchup_messages();
        assert_eq!(catchup.len(), 1);
        // Verify merged state
        let frame = &catchup[0];
        let payload_len = u32::from_le_bytes([frame[12], frame[13], frame[14], frame[15]]) as usize;
        let payload = &frame[HEADER_SIZE..HEADER_SIZE + payload_len];
        let merged: serde_json::Value = serde_json::from_slice(payload).unwrap();
        assert_eq!(merged["count"], 3);
        assert_eq!(merged["name"], "updated");
    }
    #[test]
    fn catchup_diffs_only_no_sync() {
        let mut cache = StateCache::default();
        // Only diffs, no initial sync (first connection scenario)
        let d1 = crate::codec::signal_diff_frame(0, 0, br#"{"mood":"happy"}"#);
        let d2 = crate::codec::signal_diff_frame(1, 100, br#"{"energy":75}"#);
        cache.process_frame(&d1);
        cache.process_frame(&d2);
        // Should produce 1 merged frame
        let catchup = cache.catchup_messages();
        assert_eq!(catchup.len(), 1);
        let frame = &catchup[0];
        let payload_len = u32::from_le_bytes([frame[12], frame[13], frame[14], frame[15]]) as usize;
        let payload = &frame[HEADER_SIZE..HEADER_SIZE + payload_len];
        let merged: serde_json::Value = serde_json::from_slice(payload).unwrap();
        assert_eq!(merged["mood"], "happy");
        assert_eq!(merged["energy"], 75);
    }
    #[test]
    fn catchup_preserves_version_counters() {
        let mut cache = StateCache::default();
        let sync = crate::codec::signal_sync_frame(0, 0, br#"{"count":0,"_v":{"count":0}}"#);
        cache.process_frame(&sync);
        let d1 = crate::codec::signal_diff_frame(1, 100, br#"{"count":5,"_v":{"count":3}}"#);
        cache.process_frame(&d1);
        let catchup = cache.catchup_messages();
        assert_eq!(catchup.len(), 1);
        let frame = &catchup[0];
        let payload_len = u32::from_le_bytes([frame[12], frame[13], frame[14], frame[15]]) as usize;
        let payload = &frame[HEADER_SIZE..HEADER_SIZE + payload_len];
        let merged: serde_json::Value = serde_json::from_slice(payload).unwrap();
        assert_eq!(merged["count"], 5);
        assert_eq!(merged["_v"]["count"], 3); // version preserved from diff
    }
    // ─── Channel Wildcard Tests ───
    #[test]
    fn channel_matches_exact() {
        assert!(channel_matches("games/chess", "games/chess"));
        assert!(!channel_matches("games/chess", "games/go"));
    }
    #[test]
    fn channel_matches_wildcard() {
        assert!(channel_matches("games/*", "games/chess"));
        assert!(channel_matches("games/*", "games/go"));
        assert!(!channel_matches("games/*", "other/chess"));
        assert!(!channel_matches("games/*", "games")); // no trailing segment
    }
    #[test]
    fn channel_matches_star_all() {
        assert!(channel_matches("*", "anything"));
        assert!(channel_matches("*", "games/chess"));
    }
    #[test]
    fn find_matching_channels_works() {
        let mut state = RelayState::new(16, 64, 256, 0, None);
        state.get_or_create_channel("games/chess").unwrap();
        state.get_or_create_channel("games/go").unwrap();
        state.get_or_create_channel("chat/main").unwrap();
        let mut matches = find_matching_channels(&state, "games/*");
        matches.sort();
        assert_eq!(matches, vec!["games/chess", "games/go"]);
    }
    // ─── Replay Buffer Tests ───
    #[test]
    fn replay_buffer_stores_frames() {
        let mut cache = StateCache { replay_depth: 3, ..Default::default() };
        let f1 = crate::codec::signal_diff_frame(0, 0, b"{\"a\":1}");
        let f2 = crate::codec::signal_diff_frame(1, 100, b"{\"a\":2}");
        let f3 = crate::codec::signal_diff_frame(2, 200, b"{\"a\":3}");
        cache.process_frame(&f1);
        cache.process_frame(&f2);
        cache.process_frame(&f3);
        assert_eq!(cache.replay_len(), 3);
        assert_eq!(cache.replay_frames(0).len(), 3);
        assert_eq!(cache.replay_frames(1).len(), 2);
        assert_eq!(cache.replay_frames(3).len(), 0);
    }
    #[test]
    fn replay_buffer_evicts_oldest() {
        let mut cache = StateCache { replay_depth: 2, ..Default::default() };
        let f1 = crate::codec::signal_diff_frame(0, 0, b"{\"a\":1}");
        let f2 = crate::codec::signal_diff_frame(1, 100, b"{\"a\":2}");
        let f3 = crate::codec::signal_diff_frame(2, 200, b"{\"a\":3}");
        cache.process_frame(&f1);
        cache.process_frame(&f2);
        cache.process_frame(&f3);
        assert_eq!(cache.replay_len(), 2);
        // First frame should be evicted
        assert_eq!(cache.replay_buffer[0], f2);
        assert_eq!(cache.replay_buffer[1], f3);
    }
    #[test]
    fn replay_depth_propagates_to_channel() {
        let mut state = RelayState::new(16, 64, 256, 100, None);
        let ch = state.get_or_create_channel("test").unwrap();
        let cs = ch.try_read().unwrap();
        assert_eq!(cs.cache.replay_depth, 100);
    }
    #[test]
    fn replay_disabled_when_zero() {
        let mut cache = StateCache::default();
        assert_eq!(cache.replay_depth, 0);
        let f1 = crate::codec::signal_diff_frame(0, 0, b"{\"a\":1}");
        cache.process_frame(&f1);
        assert_eq!(cache.replay_len(), 0); // Nothing stored when depth=0
    }
    // ─── v1.2: Channel Auth Tests ───
    #[test]
    fn channel_auth_open_allows_all() {
        let auth = ChannelAuth::open();
        assert!(auth.is_authenticated("any-source"));
        assert!(!auth.is_required());
    }
    #[test]
    fn channel_auth_rejects_unauthed() {
        let auth = ChannelAuth::with_key(b"secret-123");
        assert!(auth.is_required());
        assert!(!auth.is_authenticated("source-1"));
    }
    #[test]
    fn channel_auth_allows_authed() {
        let mut auth = ChannelAuth::with_key(b"secret-123");
        assert!(!auth.authenticate("source-1", b"wrong-key"));
        assert!(!auth.is_authenticated("source-1"));
        assert!(auth.authenticate("source-1", b"secret-123"));
        assert!(auth.is_authenticated("source-1"));
        assert_eq!(auth.authenticated_count(), 1);
    }
    #[test]
    fn channel_auth_revoke() {
        let mut auth = ChannelAuth::with_key(b"key");
        auth.authenticate("src", b"key");
        assert!(auth.is_authenticated("src"));
        auth.revoke("src");
        assert!(!auth.is_authenticated("src"));
        assert_eq!(auth.authenticated_count(), 0);
    }
}