v2.0 — Pipeline Architecture - Frame, CodecResult, Codec trait, Pipeline builder - 6 adapters: Passthrough, Dedup, Compress, Pacer, Slicer, Stats v2.1 — Multi-frame & new codecs - CodecOutput::Many fan-out, EncryptCodec, FilterCodec - Codec::reset(), encode_all/decode_all, real SlicerCodec chunking v2.2 — Observability & reassembly - PipelineResult (frames+errors+consumed), StageMetric - ReassemblyCodec, ConditionalCodec, Pipeline presets & metrics v2.3 — Integrity & rate control - ChecksumCodec (CRC32), RateLimitCodec (token bucket), TagCodec - Pipeline::chain(), Pipeline::describe() 13 codec adapters, 474 tests (all green, 0 regressions)
1848 lines
70 KiB
Rust
1848 lines
70 KiB
Rust
//! WebSocket Relay Server
|
|
//!
|
|
//! Routes frames from source→receivers and inputs from receivers→source.
|
|
//!
|
|
//! ## Routing
|
|
//!
|
|
//! Connections are routed by WebSocket URI path:
|
|
//! - `/source` — default source (channel: "default")
|
|
//! - `/source/{name}` — named source (channel: {name})
|
|
//! - `/stream` — default receiver (channel: "default")
|
|
//! - `/stream/{name}` — named receiver (channel: {name})
|
|
//! - `/signal` — WebRTC signaling (channel: "default")
|
|
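//! - `/signal/{name}` — WebRTC signaling (channel: {name})
//! - `/peer` / `/peer/{name}` — bidirectional peer sync (channel: "default" / {name})
//! - `/meta/{name}` — plain HTTP (not WebSocket) JSON stats for channel {name};
//!   `/meta/` returns stats for every channel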
|
|
//! - `/` — legacy: first connection = source, rest = receivers
|
|
//!
|
|
//! Each channel has its own broadcast/input channels and state cache,
|
|
//! allowing multiple independent streams through a single relay.
|
|
//!
|
|
//! ## Production Features
|
|
//! - Multi-source: multiple views stream independently
|
|
//! - Keyframe caching: late-joining receivers get current state instantly
|
|
//! - Signal state store: caches SignalSync/Diff frames for reconstruction
|
|
//! - Ping/pong keepalive: detects dead connections
|
|
//! - Stats tracking: frames, bytes, latency metrics
|
|
//! - Max receiver limit per channel (configurable)
|
|
//! - Channel GC: empty channels are cleaned up periodically
|
|
//! - Source reconnection: cache preserved for seamless reconnect
|
|
//! - Graceful shutdown: drain connections on SIGTERM
|
|
|
|
use std::collections::HashMap;
|
|
use std::net::SocketAddr;
|
|
use std::sync::Arc;
|
|
use std::time::Instant;
|
|
|
|
use futures_util::{SinkExt, StreamExt};
|
|
use tokio::net::{TcpListener, TcpStream};
|
|
use tokio::sync::{broadcast, mpsc, RwLock};
|
|
use tokio::time::{interval, Duration};
|
|
use tokio_tungstenite::tungstenite::Message;
|
|
|
|
use crate::protocol::*;
|
|
|
|
/// Relay server configuration.
///
/// Start from [`RelayConfig::default`] and override individual fields with
/// struct-update syntax, e.g. `RelayConfig { max_receivers: 8, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct RelayConfig {
    /// Address the TCP listener binds to.
    pub addr: SocketAddr,
    /// Maximum number of receivers per channel.
    pub max_receivers: usize,
    /// Capacity of each channel's frame broadcast queue.
    pub frame_buffer_size: usize,
    /// Keepalive interval in seconds.
    pub keepalive_interval_secs: u64,
    /// Keepalive timeout in seconds, intended as "disconnect after this many
    /// seconds without a pong".
    ///
    /// NOTE(review): `handle_connection` receives this value as
    /// `_keepalive_timeout` and does not use it; confirm whether it is enforced.
    pub keepalive_timeout_secs: u64,
    /// Maximum number of channels.
    pub max_channels: usize,
    /// Channel GC interval in seconds: how often to scan for empty channels.
    pub channel_gc_interval_secs: u64,
    /// Source reconnect grace period in seconds. The channel cache is kept
    /// alive this long after the source disconnects.
    pub source_reconnect_grace_secs: u64,
    /// Replay depth: number of frames kept in the ring buffer for time-travel replay.
    /// `0` disables replay (catchup-only); any value `> 0` enables it.
    pub replay_depth: usize,
    /// Upstream relay URLs for federation. Frames are forwarded to these relays.
    pub federation_upstreams: Vec<String>,
    /// Recording directory. When set, incoming frames are written to disk.
    pub recording_dir: Option<String>,
}

impl Default for RelayConfig {
    /// Listens on `0.0.0.0:9100` with 64 receivers/channel, 256 channels,
    /// a 16-frame broadcast buffer, replay disabled, and no federation or recording.
    fn default() -> Self {
        Self {
            // Built directly instead of parsing a string, so there is no `unwrap`.
            addr: SocketAddr::from(([0, 0, 0, 0], 9100)),
            max_receivers: 64,
            frame_buffer_size: 16,
            keepalive_interval_secs: 10,
            keepalive_timeout_secs: 30,
            max_channels: 256,
            channel_gc_interval_secs: 60,
            source_reconnect_grace_secs: 30,
            replay_depth: 0,
            federation_upstreams: Vec::new(),
            recording_dir: None,
        }
    }
}
|
|
|
|
/// Counters kept for each channel. They are reported by the relay's
/// periodic log line and by the `/meta` HTTP endpoint.
#[derive(Debug, Default, Clone)]
pub struct RelayStats {
    /// Frames relayed on this channel.
    pub frames_relayed: u64,
    /// Bytes relayed on this channel.
    pub bytes_relayed: u64,
    /// Input events relayed toward the source.
    pub inputs_relayed: u64,
    /// Receivers currently attached.
    pub connected_receivers: usize,
    /// Whether a source is currently attached.
    pub source_connected: bool,
    /// Timestamp of the most recent source frame, in milliseconds since start.
    pub last_frame_timestamp: u32,
    /// Running total of keyframes sent.
    pub keyframes_sent: u64,
    /// Running total of signal diffs sent.
    pub signal_diffs_sent: u64,
    /// Uptime, in seconds.
    pub uptime_secs: u64,
    /// Highest receiver count observed.
    pub peak_receivers: usize,
    /// Running total of connections served.
    pub total_connections: u64,
    /// Connections turned away because a limit was reached.
    pub rejected_connections: u64,
}
|
|
|
|
/// Snapshot state kept per channel so that a receiver joining late can be
/// brought up to date without waiting for the next full frame.
#[derive(Debug, Default, Clone)]
pub struct StateCache {
    /// Most recent pixel keyframe (full RGBA frame), if any.
    pub last_keyframe: Option<Vec<u8>>,
    /// Most recent signal sync frame (full JSON state), if any.
    pub last_signal_sync: Option<Vec<u8>>,
    /// Signal diffs received since `last_signal_sync`. A late joiner needs the
    /// sync plus every one of these diffs to reconstruct the current state.
    pub pending_signal_diffs: Vec<Vec<u8>>,
    /// Ring buffer of the most recent frames, used for time-travel replay.
    /// Receivers can request historical frames only while `replay_depth > 0`.
    pub replay_buffer: Vec<Vec<u8>>,
    /// Capacity of `replay_buffer`. `0` turns replay off.
    pub replay_depth: usize,
}
|
|
|
|
impl StateCache {
|
|
/// Process an incoming frame and update cache.
|
|
fn process_frame(&mut self, msg: &[u8]) {
|
|
if msg.len() < HEADER_SIZE {
|
|
return;
|
|
}
|
|
let frame_type = msg[0];
|
|
let flags = msg[1];
|
|
|
|
// Add to replay ring buffer if enabled
|
|
if self.replay_depth > 0 {
|
|
self.replay_buffer.push(msg.to_vec());
|
|
if self.replay_buffer.len() > self.replay_depth {
|
|
self.replay_buffer.remove(0);
|
|
}
|
|
}
|
|
|
|
match FrameType::from_u8(frame_type) {
|
|
// Cache keyframes (pixel or signal sync)
|
|
Some(FrameType::Pixels) if flags & FLAG_KEYFRAME != 0 => {
|
|
self.last_keyframe = Some(msg.to_vec());
|
|
}
|
|
Some(FrameType::SignalSync) => {
|
|
self.last_signal_sync = Some(msg.to_vec());
|
|
self.pending_signal_diffs.clear(); // sync resets diffs
|
|
}
|
|
Some(FrameType::SignalDiff) => {
|
|
self.pending_signal_diffs.push(msg.to_vec());
|
|
// Cap accumulated diffs to prevent unbounded memory
|
|
if self.pending_signal_diffs.len() > 1000 {
|
|
self.pending_signal_diffs.drain(..500);
|
|
}
|
|
}
|
|
Some(FrameType::Keyframe) => {
|
|
self.last_keyframe = Some(msg.to_vec());
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
/// Get all messages a late-joining receiver needs to reconstruct current state.
|
|
///
|
|
/// Instead of replaying hundreds of individual diffs, merges all accumulated
|
|
/// diffs into the last sync frame to produce ONE consolidated state message.
|
|
/// This dramatically reduces catchup time and bandwidth.
|
|
fn catchup_messages(&self) -> Vec<Vec<u8>> {
|
|
let mut msgs = Vec::new();
|
|
|
|
// Strategy: merge all signal diffs into last sync to build a consolidated frame
|
|
if let Some(ref sync_frame) = self.last_signal_sync {
|
|
if !self.pending_signal_diffs.is_empty() {
|
|
// Parse the sync frame's JSON payload
|
|
if sync_frame.len() >= HEADER_SIZE {
|
|
let payload_len = u32::from_le_bytes([
|
|
sync_frame[12], sync_frame[13], sync_frame[14], sync_frame[15],
|
|
]) as usize;
|
|
let sync_end = HEADER_SIZE + payload_len.min(sync_frame.len() - HEADER_SIZE);
|
|
let sync_payload = &sync_frame[HEADER_SIZE..sync_end];
|
|
|
|
if let Ok(mut merged) = serde_json::from_slice::<serde_json::Value>(sync_payload) {
|
|
// Apply each diff on top
|
|
for diff_frame in &self.pending_signal_diffs {
|
|
if diff_frame.len() < HEADER_SIZE { continue; }
|
|
let diff_pl_len = u32::from_le_bytes([
|
|
diff_frame[12], diff_frame[13], diff_frame[14], diff_frame[15],
|
|
]) as usize;
|
|
let diff_end = HEADER_SIZE + diff_pl_len.min(diff_frame.len() - HEADER_SIZE);
|
|
let diff_payload = &diff_frame[HEADER_SIZE..diff_end];
|
|
if let Ok(diff_obj) = serde_json::from_slice::<serde_json::Value>(diff_payload) {
|
|
if let (Some(merged_map), Some(diff_map)) = (merged.as_object_mut(), diff_obj.as_object()) {
|
|
for (k, v) in diff_map {
|
|
if k != "_pid" && k != "_v" {
|
|
merged_map.insert(k.clone(), v.clone());
|
|
}
|
|
}
|
|
// Merge version counters
|
|
if let (Some(merged_v), Some(diff_v)) = (
|
|
merged_map.get_mut("_v").and_then(|v| v.as_object_mut()),
|
|
diff_map.get("_v").and_then(|v| v.as_object()),
|
|
) {
|
|
for (k, v) in diff_v {
|
|
merged_v.insert(k.clone(), v.clone());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
// Build a consolidated SignalSync (0x30) frame
|
|
let merged_json = serde_json::to_vec(&merged).unwrap_or_default();
|
|
let mut frame = Vec::with_capacity(HEADER_SIZE + merged_json.len());
|
|
// Copy sync frame header but set type to SignalSync and update length
|
|
frame.extend_from_slice(&sync_frame[..HEADER_SIZE]);
|
|
frame[0] = 0x30; // SignalSync
|
|
frame[1] = 0x02; // FLAG_KEYFRAME
|
|
let len_bytes = (merged_json.len() as u32).to_le_bytes();
|
|
frame[12] = len_bytes[0];
|
|
frame[13] = len_bytes[1];
|
|
frame[14] = len_bytes[2];
|
|
frame[15] = len_bytes[3];
|
|
frame.truncate(HEADER_SIZE);
|
|
frame.extend_from_slice(&merged_json);
|
|
msgs.push(frame);
|
|
} else {
|
|
// Fallback: can't parse, send raw sync + all diffs
|
|
msgs.push(sync_frame.clone());
|
|
for diff in &self.pending_signal_diffs {
|
|
msgs.push(diff.clone());
|
|
}
|
|
}
|
|
} else {
|
|
msgs.push(sync_frame.clone());
|
|
}
|
|
} else {
|
|
// No diffs accumulated — just send the sync
|
|
msgs.push(sync_frame.clone());
|
|
}
|
|
} else if !self.pending_signal_diffs.is_empty() {
|
|
// No sync frame but we have diffs — merge diffs into one frame
|
|
let mut merged = serde_json::Map::new();
|
|
let mut merged_v = serde_json::Map::new();
|
|
for diff_frame in &self.pending_signal_diffs {
|
|
if diff_frame.len() < HEADER_SIZE { continue; }
|
|
let diff_pl_len = u32::from_le_bytes([
|
|
diff_frame[12], diff_frame[13], diff_frame[14], diff_frame[15],
|
|
]) as usize;
|
|
let diff_end = HEADER_SIZE + diff_pl_len.min(diff_frame.len() - HEADER_SIZE);
|
|
let diff_payload = &diff_frame[HEADER_SIZE..diff_end];
|
|
if let Ok(diff_obj) = serde_json::from_slice::<serde_json::Value>(diff_payload) {
|
|
if let Some(diff_map) = diff_obj.as_object() {
|
|
for (k, v) in diff_map {
|
|
if k == "_v" {
|
|
if let Some(v_map) = v.as_object() {
|
|
for (vk, vv) in v_map {
|
|
merged_v.insert(vk.clone(), vv.clone());
|
|
}
|
|
}
|
|
} else if k != "_pid" {
|
|
merged.insert(k.clone(), v.clone());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if !merged_v.is_empty() {
|
|
merged.insert("_v".to_string(), serde_json::Value::Object(merged_v));
|
|
}
|
|
let merged_json = serde_json::to_vec(&serde_json::Value::Object(merged)).unwrap_or_default();
|
|
let frame = crate::codec::signal_sync_frame(0, 0, &merged_json);
|
|
msgs.push(frame);
|
|
}
|
|
|
|
// Pixel keyframe (if available)
|
|
if let Some(ref kf) = self.last_keyframe {
|
|
msgs.push(kf.clone());
|
|
}
|
|
msgs
|
|
}
|
|
|
|
/// Clear all cached state.
|
|
#[allow(dead_code)]
|
|
fn clear(&mut self) {
|
|
self.last_keyframe = None;
|
|
self.last_signal_sync = None;
|
|
self.pending_signal_diffs.clear();
|
|
self.replay_buffer.clear();
|
|
}
|
|
|
|
/// Returns true if this cache has any state.
|
|
fn has_state(&self) -> bool {
|
|
self.last_keyframe.is_some()
|
|
|| self.last_signal_sync.is_some()
|
|
|| !self.pending_signal_diffs.is_empty()
|
|
|| !self.replay_buffer.is_empty()
|
|
}
|
|
|
|
/// Get the replay buffer frames for time-travel playback.
|
|
/// Returns frames from `start_index` onwards.
|
|
pub fn replay_frames(&self, start_index: usize) -> &[Vec<u8>] {
|
|
if start_index < self.replay_buffer.len() {
|
|
&self.replay_buffer[start_index..]
|
|
} else {
|
|
&[]
|
|
}
|
|
}
|
|
|
|
/// Total number of frames in the replay buffer.
|
|
pub fn replay_len(&self) -> usize {
|
|
self.replay_buffer.len()
|
|
}
|
|
}
|
|
|
|
/// Per-channel state — each named channel is an independent stream.
|
|
struct ChannelState {
|
|
/// Broadcast channel: source → all receivers (frames)
|
|
frame_tx: broadcast::Sender<Vec<u8>>,
|
|
/// Channel: receivers → source (input events)
|
|
input_tx: mpsc::Sender<Vec<u8>>,
|
|
input_rx: Option<mpsc::Receiver<Vec<u8>>>,
|
|
/// Broadcast channel for WebRTC signaling (SDP/ICE as text)
|
|
signaling_tx: broadcast::Sender<String>,
|
|
/// Live stats for this channel
|
|
stats: RelayStats,
|
|
/// Cached state for late-joining receivers
|
|
cache: StateCache,
|
|
/// When the source last disconnected (for reconnect grace period)
|
|
source_disconnect_time: Option<Instant>,
|
|
/// Max receivers for this channel
|
|
max_receivers: usize,
|
|
/// Cached schema announcement (0x32 payload) from source
|
|
schema: Option<Vec<u8>>,
|
|
}
|
|
|
|
impl ChannelState {
|
|
fn new(frame_buffer_size: usize, max_receivers: usize, replay_depth: usize) -> Self {
|
|
let (frame_tx, _) = broadcast::channel(frame_buffer_size);
|
|
let (input_tx, input_rx) = mpsc::channel(256);
|
|
let (signaling_tx, _) = broadcast::channel(64);
|
|
Self {
|
|
frame_tx,
|
|
input_tx,
|
|
input_rx: Some(input_rx),
|
|
signaling_tx,
|
|
stats: RelayStats::default(),
|
|
cache: StateCache {
|
|
replay_depth,
|
|
..StateCache::default()
|
|
},
|
|
source_disconnect_time: None,
|
|
max_receivers,
|
|
schema: None,
|
|
}
|
|
}
|
|
|
|
/// Returns true if this channel is idle (no source, no receivers, no cache).
|
|
fn is_idle(&self) -> bool {
|
|
!self.stats.source_connected
|
|
&& self.stats.connected_receivers == 0
|
|
&& !self.cache.has_state()
|
|
}
|
|
|
|
/// Returns true if the source reconnect grace period has expired.
|
|
fn grace_period_expired(&self, grace_secs: u64) -> bool {
|
|
if let Some(disconnect_time) = self.source_disconnect_time {
|
|
disconnect_time.elapsed() > Duration::from_secs(grace_secs)
|
|
} else {
|
|
false
|
|
}
|
|
}
|
|
}
|
|
|
|
// ─── v1.2: Auth-Gated Channel ───
|
|
|
|
/// Per-channel authentication state.
///
/// When authentication is required, a source counts as authenticated only
/// after presenting the channel's pre-shared key through
/// [`ChannelAuth::authenticate`]. An open channel ([`ChannelAuth::open`])
/// treats every source as authenticated.
///
/// `Debug` is implemented by hand so the pre-shared key never ends up in logs.
#[derive(Clone)]
pub struct ChannelAuth {
    /// Whether authentication is enforced.
    required: bool,
    /// Pre-shared key. Empty for `open()` channels. Note that
    /// `with_key(b"")` still requires authentication, with an empty token.
    key: Vec<u8>,
    /// Identifiers of sources that authenticated successfully (no duplicates).
    authenticated: Vec<String>,
}

impl std::fmt::Debug for ChannelAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ChannelAuth")
            .field("required", &self.required)
            .field("key", &"<redacted>")
            .field("authenticated", &self.authenticated)
            .finish()
    }
}

impl ChannelAuth {
    /// A channel that does not require authentication.
    pub fn open() -> Self {
        ChannelAuth { required: false, key: Vec::new(), authenticated: Vec::new() }
    }

    /// A channel that requires sources to present `key`.
    pub fn with_key(key: &[u8]) -> Self {
        ChannelAuth { required: true, key: key.to_vec(), authenticated: Vec::new() }
    }

    /// Check if a source is authenticated (always true for open channels).
    pub fn is_authenticated(&self, source_id: &str) -> bool {
        !self.required || self.authenticated.iter().any(|s| s == source_id)
    }

    /// Attempt to authenticate a source with a token.
    ///
    /// Returns true on success. Open channels always succeed and record nothing.
    /// Re-authenticating an already-authenticated source is a successful no-op.
    pub fn authenticate(&mut self, source_id: &str, token: &[u8]) -> bool {
        if !self.required {
            return true;
        }
        if !Self::constant_time_eq(token, &self.key) {
            return false;
        }
        if !self.authenticated.iter().any(|s| s == source_id) {
            self.authenticated.push(source_id.to_string());
        }
        true
    }

    /// Revoke authentication for a source. Unknown ids are ignored.
    pub fn revoke(&mut self, source_id: &str) {
        self.authenticated.retain(|s| s != source_id);
    }

    /// Whether this channel enforces authentication.
    pub fn is_required(&self) -> bool {
        self.required
    }

    /// Number of currently authenticated sources.
    pub fn authenticated_count(&self) -> usize {
        self.authenticated.len()
    }

    /// Compares two byte strings without stopping at the first mismatch.
    ///
    /// The token comes from the network, so a short-circuiting `==` would let
    /// an attacker recover the key byte by byte from response timing. Here the
    /// work depends only on the input lengths. The length itself is not hidden.
    fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
        if a.len() != b.len() {
            return false;
        }
        let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
        std::hint::black_box(diff) == 0
    }
}
|
|
|
|
/// Shared relay state — holds all channels.
|
|
struct RelayState {
|
|
/// Named channels: "default", "main", "player1", etc.
|
|
channels: HashMap<String, Arc<RwLock<ChannelState>>>,
|
|
/// Frame buffer size for new channels
|
|
frame_buffer_size: usize,
|
|
/// Max receivers per channel
|
|
max_receivers: usize,
|
|
/// Max channels
|
|
max_channels: usize,
|
|
/// Server start time
|
|
start_time: Instant,
|
|
/// Replay depth for new channels
|
|
replay_depth: usize,
|
|
/// Recording directory (None = disabled)
|
|
recording_dir: Option<String>,
|
|
}
|
|
|
|
impl RelayState {
|
|
fn new(frame_buffer_size: usize, max_receivers: usize, max_channels: usize, replay_depth: usize, recording_dir: Option<String>) -> Self {
|
|
Self {
|
|
channels: HashMap::new(),
|
|
frame_buffer_size,
|
|
max_receivers,
|
|
max_channels,
|
|
start_time: Instant::now(),
|
|
replay_depth,
|
|
recording_dir,
|
|
}
|
|
}
|
|
|
|
/// Get or create a channel by name. Returns None if at max channels.
|
|
fn get_or_create_channel(&mut self, name: &str) -> Option<Arc<RwLock<ChannelState>>> {
|
|
if self.channels.contains_key(name) {
|
|
return Some(self.channels[name].clone());
|
|
}
|
|
if self.channels.len() >= self.max_channels {
|
|
eprintln!("[relay] Max channels ({}) reached, rejecting: {}", self.max_channels, name);
|
|
return None;
|
|
}
|
|
let channel = Arc::new(RwLock::new(ChannelState::new(
|
|
self.frame_buffer_size,
|
|
self.max_receivers,
|
|
self.replay_depth,
|
|
)));
|
|
self.channels.insert(name.to_string(), channel.clone());
|
|
Some(channel)
|
|
}
|
|
}
|
|
|
|
/// Role a client takes, parsed from its request path. Every variant carries
/// the name of the channel the connection is bound to.
#[derive(Debug, Clone, PartialEq, Eq)]
enum ConnectionRole {
    /// `/source[/name]`: the channel's frame producer.
    Source(String),
    /// `/stream[/name]`: a frame consumer (also the fallback for unknown paths).
    Receiver(String),
    /// `/signal[/name]`: WebRTC signaling.
    Signaling(String),
    /// `/peer[/name]`: bidirectional sync.
    Peer(String),
    /// `/meta[/name]`: HTTP introspection.
    Meta(String),
}

/// Parse the WebSocket URI path to determine connection role and channel.
///
/// The first path segment selects the role. Everything after it, with
/// surrounding slashes removed, is the channel name, so `games/chess` works
/// for channel groups. An empty name (`/source`, `/source/`) means
/// `"default"`, and `/stream/main/` and `/stream/main` name the same channel.
/// Unrecognised paths, including `/`, map to `Receiver("default")`.
fn parse_path(path: &str) -> ConnectionRole {
    let trimmed = path.trim_matches('/');
    let (kind, name) = match trimmed.split_once('/') {
        Some((kind, rest)) => (kind, rest.trim_matches('/')),
        None => (trimmed, ""),
    };
    let channel = if name.is_empty() { "default" } else { name }.to_string();
    match kind {
        "source" => ConnectionRole::Source(channel),
        "stream" => ConnectionRole::Receiver(channel),
        "signal" => ConnectionRole::Signaling(channel),
        "peer" => ConnectionRole::Peer(channel),
        "meta" => ConnectionRole::Meta(channel),
        _ => ConnectionRole::Receiver("default".to_string()),
    }
}
|
|
|
|
/// Check if a channel name matches a pattern (channel groups).
///
/// - `"*"` matches every channel.
/// - `"prefix/*"` matches any channel nested under `prefix/`: `games/*`
///   matches `games/chess` but neither `games` itself nor `gameshow`.
/// - Any other pattern must equal the channel name exactly.
pub fn channel_matches(pattern: &str, channel: &str) -> bool {
    if pattern == "*" || pattern == channel {
        return true;
    }
    match pattern.strip_suffix("/*") {
        // Require the `/` separator after the prefix so sibling names that
        // merely share a prefix (`gameshow` vs `games`) are not matched.
        Some(prefix) => matches!(channel.strip_prefix(prefix), Some(rest) if rest.starts_with('/')),
        None => false,
    }
}
|
|
|
|
/// Collects the names of all registered channels that `pattern` matches,
/// following the rules of `channel_matches`. The order follows the map's
/// iteration order.
#[cfg(test)]
fn find_matching_channels(state: &RelayState, pattern: &str) -> Vec<String> {
    let mut matching = Vec::new();
    for name in state.channels.keys() {
        if channel_matches(pattern, name) {
            matching.push(name.clone());
        }
    }
    matching
}
|
|
|
|
/// Run the WebSocket relay server.
|
|
pub async fn run_relay(config: RelayConfig) -> Result<(), Box<dyn std::error::Error>> {
|
|
let listener = TcpListener::bind(&config.addr).await?;
|
|
eprintln!("╔══════════════════════════════════════════════════╗");
|
|
eprintln!("║ DreamStack Bitstream Relay v0.3.0 ║");
|
|
eprintln!("║ ║");
|
|
eprintln!("║ Source: ws://{}/source/{{name}} ║", config.addr);
|
|
eprintln!("║ Receiver: ws://{}/stream/{{name}} ║", config.addr);
|
|
eprintln!("║ Signal: ws://{}/signal/{{name}} ║", config.addr);
|
|
eprintln!("║ ║");
|
|
eprintln!("║ Max receivers/ch: {:>4} ║", config.max_receivers);
|
|
eprintln!("║ Max channels: {:>4} ║", config.max_channels);
|
|
eprintln!("║ Reconnect grace: {:>4}s ║", config.source_reconnect_grace_secs);
|
|
eprintln!("╚══════════════════════════════════════════════════╝");
|
|
|
|
let grace_secs = config.source_reconnect_grace_secs;
|
|
let replay_depth = config.replay_depth;
|
|
let recording_dir = config.recording_dir.clone();
|
|
let federation_upstreams = config.federation_upstreams.clone();
|
|
let state = Arc::new(RwLock::new(RelayState::new(
|
|
config.frame_buffer_size,
|
|
config.max_receivers,
|
|
config.max_channels,
|
|
config.replay_depth,
|
|
config.recording_dir.clone(),
|
|
)));
|
|
|
|
// Log the feature status
|
|
if replay_depth > 0 {
|
|
eprintln!("║ Replay depth: {:>4} frames ║", replay_depth);
|
|
}
|
|
if let Some(ref dir) = recording_dir {
|
|
eprintln!("║ Recording to: {} ║", dir);
|
|
// Ensure recording directory exists
|
|
std::fs::create_dir_all(dir).unwrap_or_else(|e| {
|
|
eprintln!("[relay] Warning: could not create recording dir {}: {}", dir, e);
|
|
});
|
|
}
|
|
if !federation_upstreams.is_empty() {
|
|
eprintln!("║ Federation: {} upstream(s) ║", federation_upstreams.len());
|
|
}
|
|
|
|
// Background: periodic stats + channel GC
|
|
{
|
|
let state = state.clone();
|
|
let gc_interval = config.channel_gc_interval_secs;
|
|
tokio::spawn(async move {
|
|
let mut tick = interval(Duration::from_secs(gc_interval.min(30)));
|
|
let mut gc_counter: u64 = 0;
|
|
loop {
|
|
tick.tick().await;
|
|
gc_counter += 1;
|
|
|
|
let s = state.read().await;
|
|
let uptime = s.start_time.elapsed().as_secs();
|
|
|
|
// Stats logging
|
|
for (name, channel) in &s.channels {
|
|
let cs = channel.read().await;
|
|
if cs.stats.source_connected || cs.stats.connected_receivers > 0 {
|
|
eprintln!(
|
|
"[relay:{name}] up={uptime}s frames={} bytes={} receivers={}/{} peak={} signal_diffs={} cached={}",
|
|
cs.stats.frames_relayed,
|
|
cs.stats.bytes_relayed,
|
|
cs.stats.connected_receivers,
|
|
cs.max_receivers,
|
|
cs.stats.peak_receivers,
|
|
cs.stats.signal_diffs_sent,
|
|
cs.cache.has_state(),
|
|
);
|
|
}
|
|
}
|
|
drop(s);
|
|
|
|
// Channel GC — every gc_interval ticks
|
|
if gc_counter % (gc_interval / gc_interval.min(30)).max(1) == 0 {
|
|
let mut s = state.write().await;
|
|
let before = s.channels.len();
|
|
let mut to_remove = Vec::new();
|
|
|
|
for (name, channel) in &s.channels {
|
|
let cs = channel.read().await;
|
|
if cs.is_idle() || cs.grace_period_expired(grace_secs) {
|
|
to_remove.push(name.clone());
|
|
}
|
|
}
|
|
|
|
for name in &to_remove {
|
|
s.channels.remove(name);
|
|
}
|
|
|
|
if !to_remove.is_empty() {
|
|
eprintln!(
|
|
"[relay] GC: removed {} idle channel(s) ({} → {}): {:?}",
|
|
to_remove.len(),
|
|
before,
|
|
s.channels.len(),
|
|
to_remove
|
|
);
|
|
}
|
|
}
|
|
}
|
|
});
|
|
}
|
|
|
|
// Background: federation forwarding to upstream relays
|
|
for upstream_url in &federation_upstreams {
|
|
let url = upstream_url.clone();
|
|
let state = state.clone();
|
|
tokio::spawn(async move {
|
|
let mut backoff = Duration::from_secs(1);
|
|
loop {
|
|
eprintln!("[relay:federation] Connecting to upstream: {url}");
|
|
match tokio_tungstenite::connect_async(&url).await {
|
|
Ok((mut ws, _)) => {
|
|
eprintln!("[relay:federation] Connected to {url}");
|
|
backoff = Duration::from_secs(1);
|
|
|
|
// Subscribe to the default channel and forward frames
|
|
let frame_rx = {
|
|
let s = state.read().await;
|
|
if let Some(ch) = s.channels.get("default") {
|
|
let cs = ch.read().await;
|
|
Some(cs.frame_tx.subscribe())
|
|
} else {
|
|
None
|
|
}
|
|
};
|
|
|
|
if let Some(mut rx) = frame_rx {
|
|
while let Ok(frame) = rx.recv().await {
|
|
use futures_util::SinkExt;
|
|
if ws.send(Message::Binary(frame.into())).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
eprintln!("[relay:federation] Disconnected from {url}");
|
|
}
|
|
Err(e) => {
|
|
eprintln!("[relay:federation] Failed to connect to {url}: {e}");
|
|
}
|
|
}
|
|
// Exponential backoff on reconnect
|
|
tokio::time::sleep(backoff).await;
|
|
backoff = (backoff * 2).min(Duration::from_secs(30));
|
|
}
|
|
});
|
|
}
|
|
|
|
while let Ok((stream, addr)) = listener.accept().await {
|
|
let state = state.clone();
|
|
let keepalive_interval = config.keepalive_interval_secs;
|
|
let keepalive_timeout = config.keepalive_timeout_secs;
|
|
|
|
tokio::spawn(async move {
|
|
// Peek at the HTTP request to check for /meta requests
|
|
// Meta requests get a raw HTTP JSON response, not a WS upgrade
|
|
let mut peek_buf = [0u8; 512];
|
|
let n = match stream.peek(&mut peek_buf).await {
|
|
Ok(n) => n,
|
|
Err(_) => return,
|
|
};
|
|
let request_line = String::from_utf8_lossy(&peek_buf[..n]);
|
|
|
|
// Extract path from "GET /meta/name HTTP/1.1"
|
|
if let Some(path) = request_line.lines().next().and_then(|line| {
|
|
let parts: Vec<&str> = line.split_whitespace().collect();
|
|
if parts.len() >= 2 { Some(parts[1]) } else { None }
|
|
}) {
|
|
if path.starts_with("/meta") {
|
|
handle_meta_http(stream, addr, &state, path).await;
|
|
return;
|
|
}
|
|
}
|
|
|
|
handle_connection(stream, addr, state, keepalive_interval, keepalive_timeout).await;
|
|
});
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
/// Handle an HTTP /meta/{channel} request — return channel stats as JSON.
|
|
///
|
|
/// This is NOT a WebSocket connection. We read the full HTTP request,
|
|
/// then respond with an HTTP 200 + JSON body.
|
|
async fn handle_meta_http(
|
|
mut stream: TcpStream,
|
|
addr: SocketAddr,
|
|
state: &Arc<RwLock<RelayState>>,
|
|
path: &str,
|
|
) {
|
|
use tokio::io::AsyncWriteExt;
|
|
use tokio::io::AsyncReadExt;
|
|
|
|
// Drain the HTTP request from the socket
|
|
let mut buf = vec![0u8; 4096];
|
|
let _ = stream.read(&mut buf).await;
|
|
|
|
// Parse channel name from path
|
|
let channel_name = path.trim_start_matches("/meta/").trim_matches('/');
|
|
let channel_name = if channel_name.is_empty() { "__all__" } else { channel_name };
|
|
|
|
let s = state.read().await;
|
|
let uptime = s.start_time.elapsed().as_secs();
|
|
|
|
let body = if channel_name == "__all__" {
|
|
// Return stats for ALL channels
|
|
let mut channels = serde_json::Map::new();
|
|
for (name, channel) in &s.channels {
|
|
let cs = channel.read().await;
|
|
channels.insert(name.clone(), serde_json::json!({
|
|
"source_connected": cs.stats.source_connected,
|
|
"receivers": cs.stats.connected_receivers,
|
|
"peak_receivers": cs.stats.peak_receivers,
|
|
"frames_relayed": cs.stats.frames_relayed,
|
|
"bytes_relayed": cs.stats.bytes_relayed,
|
|
"signal_diffs": cs.stats.signal_diffs_sent,
|
|
"keyframes": cs.stats.keyframes_sent,
|
|
"total_connections": cs.stats.total_connections,
|
|
"rejected_connections": cs.stats.rejected_connections,
|
|
"cached": cs.cache.has_state(),
|
|
"pending_diffs": cs.cache.pending_signal_diffs.len(),
|
|
}));
|
|
}
|
|
serde_json::json!({
|
|
"uptime_secs": uptime,
|
|
"channel_count": s.channels.len(),
|
|
"channels": serde_json::Value::Object(channels),
|
|
}).to_string()
|
|
} else if let Some(channel) = s.channels.get(channel_name) {
|
|
// Return stats for a specific channel
|
|
let cs = channel.read().await;
|
|
|
|
// Extract current signal state if available
|
|
let mut current_state = serde_json::Value::Null;
|
|
if let Some(ref sync_frame) = cs.cache.last_signal_sync {
|
|
if sync_frame.len() >= HEADER_SIZE {
|
|
let pl_len = u32::from_le_bytes([
|
|
sync_frame[12], sync_frame[13], sync_frame[14], sync_frame[15],
|
|
]) as usize;
|
|
let end = HEADER_SIZE + pl_len.min(sync_frame.len() - HEADER_SIZE);
|
|
if let Ok(val) = serde_json::from_slice::<serde_json::Value>(&sync_frame[HEADER_SIZE..end]) {
|
|
current_state = val;
|
|
}
|
|
}
|
|
}
|
|
|
|
serde_json::json!({
|
|
"channel": channel_name,
|
|
"uptime_secs": uptime,
|
|
"source_connected": cs.stats.source_connected,
|
|
"receivers": cs.stats.connected_receivers,
|
|
"max_receivers": cs.max_receivers,
|
|
"peak_receivers": cs.stats.peak_receivers,
|
|
"frames_relayed": cs.stats.frames_relayed,
|
|
"bytes_relayed": cs.stats.bytes_relayed,
|
|
"signal_diffs": cs.stats.signal_diffs_sent,
|
|
"keyframes": cs.stats.keyframes_sent,
|
|
"total_connections": cs.stats.total_connections,
|
|
"rejected_connections": cs.stats.rejected_connections,
|
|
"cached": cs.cache.has_state(),
|
|
"pending_diffs": cs.cache.pending_signal_diffs.len(),
|
|
"schema": cs.schema.as_ref().and_then(|s| {
|
|
if s.len() >= HEADER_SIZE {
|
|
let pl_len = u32::from_le_bytes([s[12], s[13], s[14], s[15]]) as usize;
|
|
let end = HEADER_SIZE + pl_len.min(s.len() - HEADER_SIZE);
|
|
serde_json::from_slice::<serde_json::Value>(&s[HEADER_SIZE..end]).ok()
|
|
} else { None }
|
|
}),
|
|
"current_state": current_state,
|
|
}).to_string()
|
|
} else {
|
|
serde_json::json!({
|
|
"error": "channel not found",
|
|
"channel": channel_name,
|
|
}).to_string()
|
|
};
|
|
|
|
// Write HTTP response
|
|
let response = format!(
|
|
"HTTP/1.1 200 OK\r\n\
|
|
Content-Type: application/json\r\n\
|
|
Access-Control-Allow-Origin: *\r\n\
|
|
Content-Length: {}\r\n\
|
|
Connection: close\r\n\
|
|
\r\n\
|
|
{}",
|
|
body.len(),
|
|
body
|
|
);
|
|
let _ = stream.write_all(response.as_bytes()).await;
|
|
eprintln!("[relay:meta] {addr} → {path}");
|
|
}
|
|
|
|
async fn handle_connection(
|
|
stream: TcpStream,
|
|
addr: SocketAddr,
|
|
state: Arc<RwLock<RelayState>>,
|
|
keepalive_interval: u64,
|
|
_keepalive_timeout: u64,
|
|
) {
|
|
// Extract the URI path during handshake to determine role
|
|
let mut role = ConnectionRole::Receiver("default".to_string());
|
|
|
|
// Peek at the path to check for meta requests (handle via raw HTTP, not WS)
|
|
let ws_stream = match tokio_tungstenite::accept_hdr_async(
|
|
stream,
|
|
|req: &tokio_tungstenite::tungstenite::handshake::server::Request,
|
|
res: tokio_tungstenite::tungstenite::handshake::server::Response| {
|
|
role = parse_path(req.uri().path());
|
|
Ok(res)
|
|
},
|
|
)
|
|
.await
|
|
{
|
|
Ok(ws) => ws,
|
|
Err(e) => {
|
|
// Meta requests won't have an Upgrade header, so they fail the handshake
|
|
// That's fine — they were handled inline if detected.
|
|
if !matches!(role, ConnectionRole::Meta(_)) {
|
|
eprintln!("[relay] WebSocket handshake failed from {}: {}", addr, e);
|
|
}
|
|
return;
|
|
}
|
|
};
|
|
|
|
// Handle meta requests separately (they don't use WebSocket)
|
|
if let ConnectionRole::Meta(ref _channel_name) = role {
|
|
// Meta requests fail the WS handshake, so we'll never reach here
|
|
// But just in case, close the connection
|
|
return;
|
|
}
|
|
|
|
// Get or create the channel
|
|
let channel_name = match &role {
|
|
ConnectionRole::Source(n) | ConnectionRole::Receiver(n) | ConnectionRole::Signaling(n) | ConnectionRole::Peer(n) | ConnectionRole::Meta(n) => n.clone(),
|
|
};
|
|
|
|
let channel = {
|
|
let mut s = state.write().await;
|
|
match s.get_or_create_channel(&channel_name) {
|
|
Some(ch) => ch,
|
|
None => {
|
|
eprintln!("[relay] Rejected connection from {addr}: max channels reached");
|
|
return;
|
|
}
|
|
}
|
|
};
|
|
|
|
// Check receiver limit
|
|
if let ConnectionRole::Receiver(_) = &role {
|
|
let mut cs = channel.write().await;
|
|
if cs.stats.connected_receivers >= cs.max_receivers {
|
|
eprintln!(
|
|
"[relay:{channel_name}] Rejected receiver {addr}: at capacity ({}/{})",
|
|
cs.stats.connected_receivers, cs.max_receivers
|
|
);
|
|
cs.stats.rejected_connections += 1;
|
|
return;
|
|
}
|
|
}
|
|
|
|
// Legacy fallback: if path was `/` and no source exists, treat as source
|
|
let role = match role {
|
|
ConnectionRole::Receiver(ref name) if name == "default" => {
|
|
let cs = channel.read().await;
|
|
if !cs.stats.source_connected {
|
|
ConnectionRole::Source("default".to_string())
|
|
} else {
|
|
role
|
|
}
|
|
}
|
|
_ => role,
|
|
};
|
|
|
|
match role {
|
|
ConnectionRole::Source(ref _name) => {
|
|
eprintln!("[relay:{channel_name}] Source connected: {addr}");
|
|
// Get recording_dir from relay state
|
|
let recording_dir = {
|
|
let s = state.read().await;
|
|
s.recording_dir.clone()
|
|
};
|
|
handle_source(ws_stream, addr, channel, &channel_name, keepalive_interval, recording_dir).await;
|
|
}
|
|
ConnectionRole::Receiver(ref _name) => {
|
|
eprintln!("[relay:{channel_name}] Receiver connected: {addr}");
|
|
handle_receiver(ws_stream, addr, channel, &channel_name).await;
|
|
}
|
|
ConnectionRole::Signaling(ref _name) => {
|
|
eprintln!("[relay:{channel_name}] Signaling peer connected: {addr}");
|
|
handle_signaling(ws_stream, addr, channel, &channel_name).await;
|
|
}
|
|
ConnectionRole::Peer(ref _name) => {
|
|
eprintln!("[relay:{channel_name}] Peer connected: {addr}");
|
|
handle_peer(ws_stream, addr, channel, &channel_name).await;
|
|
}
|
|
ConnectionRole::Meta(_) => {
|
|
// Handled before WS handshake — should never reach here
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn handle_source(
|
|
ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
|
|
addr: SocketAddr,
|
|
channel: Arc<RwLock<ChannelState>>,
|
|
channel_name: &str,
|
|
keepalive_interval: u64,
|
|
recording_dir: Option<String>,
|
|
) {
|
|
let (mut ws_sink, mut ws_source) = ws_stream.split();
|
|
|
|
// Mark source as connected; if reconnecting, preserve cache
|
|
let input_rx = {
|
|
let mut cs = channel.write().await;
|
|
if cs.source_disconnect_time.is_some() {
|
|
eprintln!("[relay:{channel_name}] Source reconnected — cache preserved ({} diffs, sync={})",
|
|
cs.cache.pending_signal_diffs.len(),
|
|
cs.cache.last_signal_sync.is_some(),
|
|
);
|
|
}
|
|
cs.stats.source_connected = true;
|
|
cs.stats.total_connections += 1;
|
|
cs.source_disconnect_time = None;
|
|
cs.input_rx.take()
|
|
};
|
|
|
|
let frame_tx = {
|
|
let cs = channel.read().await;
|
|
cs.frame_tx.clone()
|
|
};
|
|
|
|
// Forward input events from receivers → source
|
|
if let Some(mut input_rx) = input_rx {
|
|
let channel_clone = channel.clone();
|
|
tokio::spawn(async move {
|
|
while let Some(input_bytes) = input_rx.recv().await {
|
|
let msg = Message::Binary(input_bytes.into());
|
|
if ws_sink.send(msg).await.is_err() {
|
|
break;
|
|
}
|
|
let mut cs = channel_clone.write().await;
|
|
cs.stats.inputs_relayed += 1;
|
|
}
|
|
});
|
|
}
|
|
|
|
// Keepalive: periodic pings
|
|
let channel_ping = channel.clone();
|
|
let ping_task = tokio::spawn(async move {
|
|
let mut tick = interval(Duration::from_secs(keepalive_interval));
|
|
let mut seq = 0u16;
|
|
loop {
|
|
tick.tick().await;
|
|
let cs = channel_ping.read().await;
|
|
if !cs.stats.source_connected {
|
|
break;
|
|
}
|
|
let _ = cs.frame_tx.send(crate::codec::ping(seq, 0));
|
|
drop(cs);
|
|
seq = seq.wrapping_add(1);
|
|
}
|
|
});
|
|
|
|
// Open recording file if configured
|
|
let mut recording_file = if let Some(ref dir) = recording_dir {
|
|
let path = format!("{}/{}.dsrec", dir, channel_name.replace('/', "_"));
|
|
match tokio::fs::OpenOptions::new()
|
|
.create(true)
|
|
.append(true)
|
|
.open(&path)
|
|
.await
|
|
{
|
|
Ok(f) => {
|
|
eprintln!("[relay:{channel_name}] Recording to {path}");
|
|
Some(f)
|
|
}
|
|
Err(e) => {
|
|
eprintln!("[relay:{channel_name}] Warning: could not open recording file {path}: {e}");
|
|
None
|
|
}
|
|
}
|
|
} else {
|
|
None
|
|
};
|
|
|
|
// Receive frames from source → broadcast to receivers
|
|
let channel_name_owned = channel_name.to_string();
|
|
while let Some(Ok(msg)) = ws_source.next().await {
|
|
if let Message::Binary(data) = msg {
|
|
let data_vec: Vec<u8> = data.into();
|
|
{
|
|
let mut cs = channel.write().await;
|
|
cs.stats.frames_relayed += 1;
|
|
cs.stats.bytes_relayed += data_vec.len() as u64;
|
|
|
|
// Update state cache
|
|
cs.cache.process_frame(&data_vec);
|
|
|
|
// Track frame-type-specific stats
|
|
if data_vec.len() >= HEADER_SIZE {
|
|
match FrameType::from_u8(data_vec[0]) {
|
|
Some(FrameType::SignalDiff) => cs.stats.signal_diffs_sent += 1,
|
|
Some(FrameType::Pixels) | Some(FrameType::Keyframe) |
|
|
Some(FrameType::SignalSync) => cs.stats.keyframes_sent += 1,
|
|
_ => {}
|
|
}
|
|
let ts = u32::from_le_bytes([
|
|
data_vec[4], data_vec[5], data_vec[6], data_vec[7],
|
|
]);
|
|
cs.stats.last_frame_timestamp = ts;
|
|
}
|
|
}
|
|
|
|
// Record frame to disk (length-delimited: [u32 len][frame bytes])
|
|
if let Some(ref mut file) = recording_file {
|
|
use tokio::io::AsyncWriteExt;
|
|
let len_bytes = (data_vec.len() as u32).to_le_bytes();
|
|
let _ = file.write_all(&len_bytes).await;
|
|
let _ = file.write_all(&data_vec).await;
|
|
}
|
|
|
|
// Broadcast to all receivers on this channel
|
|
let _ = frame_tx.send(data_vec);
|
|
}
|
|
}
|
|
|
|
// Source disconnected — start grace period, keep cache
|
|
ping_task.abort();
|
|
eprintln!("[relay:{channel_name_owned}] Source disconnected: {addr} (cache preserved for reconnect)");
|
|
let mut cs = channel.write().await;
|
|
cs.stats.source_connected = false;
|
|
cs.source_disconnect_time = Some(Instant::now());
|
|
// Recreate input channel for next source connection
|
|
let (new_input_tx, new_input_rx) = mpsc::channel(256);
|
|
cs.input_tx = new_input_tx;
|
|
cs.input_rx = Some(new_input_rx);
|
|
}
|
|
|
|
async fn handle_receiver(
|
|
ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
|
|
addr: SocketAddr,
|
|
channel: Arc<RwLock<ChannelState>>,
|
|
channel_name: &str,
|
|
) {
|
|
let (mut ws_sink, mut ws_source) = ws_stream.split();
|
|
|
|
// Subscribe to frame broadcast and get catchup + schema
|
|
let (mut frame_rx, catchup_msgs, schema_frame) = {
|
|
let mut cs = channel.write().await;
|
|
cs.stats.connected_receivers += 1;
|
|
cs.stats.total_connections += 1;
|
|
if cs.stats.connected_receivers > cs.stats.peak_receivers {
|
|
cs.stats.peak_receivers = cs.stats.connected_receivers;
|
|
}
|
|
let rx = cs.frame_tx.subscribe();
|
|
let catchup = cs.cache.catchup_messages();
|
|
let schema = cs.schema.clone();
|
|
(rx, catchup, schema)
|
|
};
|
|
|
|
let input_tx = {
|
|
let cs = channel.read().await;
|
|
cs.input_tx.clone()
|
|
};
|
|
|
|
// Send cached schema (0x32) first so receiver knows what signals are available
|
|
let channel_name_owned = channel_name.to_string();
|
|
if let Some(schema_bytes) = schema_frame {
|
|
let _ = ws_sink.send(Message::Binary(schema_bytes.into())).await;
|
|
}
|
|
|
|
// Send cached state to late-joining receiver
|
|
if !catchup_msgs.is_empty() {
|
|
eprintln!(
|
|
"[relay:{channel_name_owned}] Sending {} catchup messages to {addr}",
|
|
catchup_msgs.len(),
|
|
);
|
|
for msg_bytes in catchup_msgs {
|
|
let msg = Message::Binary(msg_bytes.into());
|
|
if ws_sink.send(msg).await.is_err() {
|
|
eprintln!("[relay:{channel_name_owned}] Failed to send catchup to {addr}");
|
|
let mut cs = channel.write().await;
|
|
cs.stats.connected_receivers = cs.stats.connected_receivers.saturating_sub(1);
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Per-receiver signal filter (set via 0x33 SubscribeFilter)
|
|
let filter: Arc<tokio::sync::RwLock<Option<std::collections::HashSet<String>>>> =
|
|
Arc::new(tokio::sync::RwLock::new(None));
|
|
let filter_for_send = filter.clone();
|
|
|
|
// Forward frames from broadcast → this receiver (with optional filtering)
|
|
let send_task = tokio::spawn(async move {
|
|
loop {
|
|
match frame_rx.recv().await {
|
|
Ok(frame_bytes) => {
|
|
// Check if we need to filter signal frames
|
|
let filter_guard = filter_for_send.read().await;
|
|
let should_filter = filter_guard.is_some()
|
|
&& frame_bytes.len() >= 16
|
|
&& (frame_bytes[0] == 0x30 || frame_bytes[0] == 0x31);
|
|
|
|
if should_filter {
|
|
if let Some(ref wanted) = *filter_guard {
|
|
// Parse JSON payload, keep only wanted keys
|
|
let payload_len = u32::from_le_bytes([
|
|
frame_bytes[12], frame_bytes[13],
|
|
frame_bytes[14], frame_bytes[15],
|
|
]) as usize;
|
|
if frame_bytes.len() >= 16 + payload_len {
|
|
let payload = &frame_bytes[16..16 + payload_len];
|
|
if let Ok(mut obj) = serde_json::from_slice::<serde_json::Value>(payload) {
|
|
if let Some(map) = obj.as_object_mut() {
|
|
let keys: Vec<String> = map.keys().cloned().collect();
|
|
for k in keys {
|
|
// Keep _pid, _v (internal), and wanted fields
|
|
if k != "_pid" && k != "_v" && !wanted.contains(&k) {
|
|
map.remove(&k);
|
|
}
|
|
}
|
|
// Re-encode
|
|
let new_payload = serde_json::to_vec(&obj).unwrap_or_default();
|
|
let mut new_frame = Vec::with_capacity(16 + new_payload.len());
|
|
new_frame.extend_from_slice(&frame_bytes[..12]);
|
|
new_frame.extend_from_slice(&(new_payload.len() as u32).to_le_bytes());
|
|
new_frame.extend_from_slice(&new_payload);
|
|
drop(filter_guard);
|
|
if ws_sink.send(Message::Binary(new_frame.into())).await.is_err() {
|
|
break;
|
|
}
|
|
continue;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
drop(filter_guard);
|
|
let msg = Message::Binary(frame_bytes.into());
|
|
if ws_sink.send(msg).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
Err(broadcast::error::RecvError::Lagged(n)) => {
|
|
eprintln!("[relay] Receiver {addr} lagged by {n} frames");
|
|
}
|
|
Err(_) => break,
|
|
}
|
|
}
|
|
});
|
|
|
|
// Handle incoming messages from receiver (input events + 0x33 filter)
|
|
while let Some(Ok(msg)) = ws_source.next().await {
|
|
if let Message::Binary(data) = msg {
|
|
let data_vec: Vec<u8> = data.into();
|
|
// Parse SubscribeFilter (0x33)
|
|
if data_vec.len() >= 16 && data_vec[0] == 0x33 {
|
|
let payload_len = u32::from_le_bytes([
|
|
data_vec[12], data_vec[13], data_vec[14], data_vec[15],
|
|
]) as usize;
|
|
if data_vec.len() >= 16 + payload_len {
|
|
let payload = &data_vec[16..16 + payload_len];
|
|
if let Ok(obj) = serde_json::from_slice::<serde_json::Value>(payload) {
|
|
if let Some(select_arr) = obj.get("select").and_then(|v| v.as_array()) {
|
|
let wanted: std::collections::HashSet<String> = select_arr
|
|
.iter()
|
|
.filter_map(|v| v.as_str().map(|s| s.to_string()))
|
|
.collect();
|
|
eprintln!(
|
|
"[relay:{channel_name_owned}] Receiver {addr} subscribed to: {:?}",
|
|
wanted
|
|
);
|
|
let mut f = filter.write().await;
|
|
*f = Some(wanted);
|
|
}
|
|
}
|
|
}
|
|
continue; // Don't forward filter frames to source
|
|
}
|
|
// Forward input events to source
|
|
let _ = input_tx.send(data_vec).await;
|
|
}
|
|
}
|
|
|
|
// Receiver disconnected
|
|
send_task.abort();
|
|
eprintln!("[relay:{channel_name_owned}] Receiver disconnected: {addr}");
|
|
let mut cs = channel.write().await;
|
|
cs.stats.connected_receivers = cs.stats.connected_receivers.saturating_sub(1);
|
|
}
|
|
|
|
/// Handle a WebRTC signaling connection.
|
|
///
|
|
/// Signaling peers exchange JSON messages (SDP offers/answers, ICE candidates)
|
|
/// over WebSocket. The relay broadcasts all text messages to all other peers
|
|
/// on the same channel, enabling peer-to-peer WebRTC setup.
|
|
async fn handle_signaling(
|
|
ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
|
|
addr: SocketAddr,
|
|
channel: Arc<RwLock<ChannelState>>,
|
|
channel_name: &str,
|
|
) {
|
|
let (mut ws_sink, mut ws_source) = ws_stream.split();
|
|
|
|
// Subscribe to signaling broadcast
|
|
let mut sig_rx = {
|
|
let cs = channel.read().await;
|
|
cs.signaling_tx.subscribe()
|
|
};
|
|
|
|
let sig_tx = {
|
|
let cs = channel.read().await;
|
|
cs.signaling_tx.clone()
|
|
};
|
|
|
|
// Forward signaling messages from broadcast → this peer
|
|
let channel_name_owned = channel_name.to_string();
|
|
let send_task = tokio::spawn(async move {
|
|
loop {
|
|
match sig_rx.recv().await {
|
|
Ok(msg_text) => {
|
|
let msg = Message::Text(msg_text.into());
|
|
if ws_sink.send(msg).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
Err(broadcast::error::RecvError::Lagged(n)) => {
|
|
eprintln!("[relay:{channel_name_owned}] Signaling peer lagged by {n} messages");
|
|
}
|
|
Err(_) => break,
|
|
}
|
|
}
|
|
});
|
|
|
|
// Forward signaling messages from this peer → broadcast to others
|
|
while let Some(Ok(msg)) = ws_source.next().await {
|
|
if let Message::Text(text) = msg {
|
|
let text_str: String = text.into();
|
|
let _ = sig_tx.send(text_str);
|
|
}
|
|
}
|
|
|
|
send_task.abort();
|
|
eprintln!("[relay:{channel_name}] Signaling peer disconnected: {addr}");
|
|
}
|
|
|
|
/// Handle a peer connection for bidirectional sync.
|
|
///
|
|
/// Peers are equal — every binary frame sent by any peer is broadcast to all
|
|
/// other peers on the same channel. No source/receiver distinction.
|
|
async fn handle_peer(
|
|
ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
|
|
addr: SocketAddr,
|
|
channel: Arc<RwLock<ChannelState>>,
|
|
channel_name: &str,
|
|
) {
|
|
let (mut ws_sink, mut ws_source) = ws_stream.split();
|
|
|
|
// Subscribe to the frame broadcast channel (same one sources use)
|
|
let mut frame_rx = {
|
|
let cs = channel.read().await;
|
|
cs.frame_tx.subscribe()
|
|
};
|
|
|
|
let frame_tx = {
|
|
let cs = channel.read().await;
|
|
cs.frame_tx.clone()
|
|
};
|
|
|
|
// Track peer count
|
|
{
|
|
let mut cs = channel.write().await;
|
|
cs.stats.connected_receivers += 1;
|
|
cs.stats.peak_receivers = cs.stats.peak_receivers.max(cs.stats.connected_receivers);
|
|
}
|
|
|
|
// Send catchup state to late-joining peers
|
|
{
|
|
let cs = channel.read().await;
|
|
if cs.cache.has_state() {
|
|
let msgs = cs.cache.catchup_messages();
|
|
eprintln!("[relay:{channel_name}] Sending {} catchup messages to peer {addr}", msgs.len());
|
|
for msg in msgs {
|
|
let _ = ws_sink.send(Message::Binary(msg.into())).await;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Forward frames from broadcast → this peer
|
|
let channel_name_owned = channel_name.to_string();
|
|
let send_task = tokio::spawn(async move {
|
|
loop {
|
|
match frame_rx.recv().await {
|
|
Ok(frame_data) => {
|
|
if ws_sink.send(Message::Binary(frame_data.into())).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
Err(broadcast::error::RecvError::Lagged(n)) => {
|
|
eprintln!("[relay:{channel_name_owned}] Peer lagged by {n} frames");
|
|
}
|
|
Err(_) => break,
|
|
}
|
|
}
|
|
});
|
|
|
|
// Forward frames from this peer → broadcast to all others
|
|
let channel_for_cache = channel.clone();
|
|
let _channel_name_cache = channel_name.to_string();
|
|
while let Some(Ok(msg)) = ws_source.next().await {
|
|
if let Message::Binary(data) = msg {
|
|
let data_vec: Vec<u8> = data.into();
|
|
// Intercept schema announcement (0x32) — cache, don't broadcast
|
|
if data_vec.len() >= 16 && data_vec[0] == 0x32 {
|
|
let payload_len = u32::from_le_bytes([data_vec[12], data_vec[13], data_vec[14], data_vec[15]]) as usize;
|
|
if data_vec.len() >= 16 + payload_len {
|
|
let mut cs = channel_for_cache.write().await;
|
|
cs.schema = Some(data_vec.clone());
|
|
}
|
|
continue; // Don't broadcast schema frames
|
|
}
|
|
// Cache for late joiners
|
|
{
|
|
let mut cs = channel_for_cache.write().await;
|
|
cs.cache.process_frame(&data_vec);
|
|
cs.stats.frames_relayed += 1;
|
|
cs.stats.bytes_relayed += data_vec.len() as u64;
|
|
}
|
|
// Broadcast to all peers (they'll filter out their own msgs by content)
|
|
let _ = frame_tx.send(data_vec);
|
|
}
|
|
}
|
|
|
|
send_task.abort();
|
|
{
|
|
let mut cs = channel.write().await;
|
|
cs.stats.connected_receivers = cs.stats.connected_receivers.saturating_sub(1);
|
|
}
|
|
eprintln!("[relay:{channel_name}] Peer disconnected: {addr}");
|
|
}
|
|
|
|
// ─── Tests ───
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Decode the JSON payload of a relay frame.
    ///
    /// The payload length is a little-endian `u32` at bytes 12..16, and the
    /// payload follows the header.
    fn payload_json(frame: &[u8]) -> serde_json::Value {
        assert!(frame.len() >= HEADER_SIZE, "frame is shorter than its header");
        let payload_len = u32::from_le_bytes([frame[12], frame[13], frame[14], frame[15]]) as usize;
        serde_json::from_slice(&frame[HEADER_SIZE..HEADER_SIZE + payload_len])
            .expect("catch-up payload should be valid JSON")
    }

    /// Assert that `cache` yields exactly one (merged) catch-up frame and decode it.
    fn single_catchup_json(cache: &StateCache) -> serde_json::Value {
        let catchup = cache.catchup_messages();
        assert_eq!(catchup.len(), 1, "expected exactly one merged catch-up frame");
        payload_json(&catchup[0])
    }

    #[test]
    fn default_config() {
        let config = RelayConfig::default();
        assert_eq!(config.addr, "0.0.0.0:9100".parse::<SocketAddr>().unwrap());
        assert_eq!(config.max_receivers, 64);
        assert_eq!(config.frame_buffer_size, 16);
        assert_eq!(config.keepalive_interval_secs, 10);
        assert_eq!(config.max_channels, 256);
        assert_eq!(config.channel_gc_interval_secs, 60);
        assert_eq!(config.source_reconnect_grace_secs, 30);
    }

    #[test]
    fn stats_default() {
        let stats = RelayStats::default();
        assert_eq!(stats.frames_relayed, 0);
        assert_eq!(stats.bytes_relayed, 0);
        assert!(!stats.source_connected);
        assert_eq!(stats.keyframes_sent, 0);
        assert_eq!(stats.signal_diffs_sent, 0);
        assert_eq!(stats.peak_receivers, 0);
        assert_eq!(stats.total_connections, 0);
        assert_eq!(stats.rejected_connections, 0);
    }

    #[test]
    fn state_cache_signal_sync() {
        let mut cache = StateCache::default();

        // Simulate a SignalSync frame
        let sync_msg = crate::codec::signal_sync_frame(0, 100, br#"{"count":0}"#);
        cache.process_frame(&sync_msg);
        assert!(cache.last_signal_sync.is_some());
        assert!(cache.has_state());
        assert_eq!(cache.pending_signal_diffs.len(), 0);

        // Simulate a SignalDiff
        let diff_msg = crate::codec::signal_diff_frame(1, 200, br#"{"count":1}"#);
        cache.process_frame(&diff_msg);
        assert_eq!(cache.pending_signal_diffs.len(), 1);

        // Catchup should contain ONE merged frame with the diff applied on top of the sync
        let merged = single_catchup_json(&cache);
        assert_eq!(merged["count"], 1);
    }

    #[test]
    fn state_cache_signal_sync_resets_diffs() {
        let mut cache = StateCache::default();

        // Add some diffs
        for i in 0..5 {
            let json = format!(r#"{{"count":{}}}"#, i);
            let msg = crate::codec::signal_diff_frame(i, i as u32 * 100, json.as_bytes());
            cache.process_frame(&msg);
        }
        assert_eq!(cache.pending_signal_diffs.len(), 5);

        // A new sync should clear diffs
        let sync = crate::codec::signal_sync_frame(10, 1000, b"{}");
        cache.process_frame(&sync);
        assert_eq!(cache.pending_signal_diffs.len(), 0);
        assert!(cache.last_signal_sync.is_some());
    }

    #[test]
    fn state_cache_keyframe() {
        let mut cache = StateCache::default();

        // Simulate a keyframe pixel frame
        let kf = crate::codec::pixel_frame(0, 100, 10, 10, &vec![0xFF; 400]);
        cache.process_frame(&kf);

        assert!(cache.last_keyframe.is_some());
        assert!(cache.has_state());
        assert_eq!(cache.catchup_messages().len(), 1);
    }

    #[test]
    fn state_cache_diff_cap() {
        let mut cache = StateCache::default();

        // Add more than 1000 diffs — should auto-trim
        for i in 0..1100u16 {
            let json = format!(r#"{{"n":{}}}"#, i);
            let msg = crate::codec::signal_diff_frame(i, i as u32, json.as_bytes());
            cache.process_frame(&msg);
        }
        // The first trim happens at the 1001st diff (dropping 500), so at most
        // 1100 - 500 = 600 diffs can remain.
        assert!(cache.pending_signal_diffs.len() <= 600);
    }

    #[test]
    fn state_cache_clear() {
        let mut cache = StateCache::default();
        let sync = crate::codec::signal_sync_frame(0, 0, b"{}");
        cache.process_frame(&sync);
        assert!(cache.has_state());
        cache.clear();
        assert!(!cache.has_state());
        assert!(cache.last_signal_sync.is_none());
    }

    // ─── Path Routing Tests ───

    #[test]
    fn parse_path_source_default() {
        match parse_path("/source") {
            ConnectionRole::Source(name) => assert_eq!(name, "default"),
            _ => panic!("Expected Source"),
        }
    }

    #[test]
    fn parse_path_source_named() {
        match parse_path("/source/main") {
            ConnectionRole::Source(name) => assert_eq!(name, "main"),
            _ => panic!("Expected Source"),
        }
    }

    #[test]
    fn parse_path_stream_default() {
        match parse_path("/stream") {
            ConnectionRole::Receiver(name) => assert_eq!(name, "default"),
            _ => panic!("Expected Receiver"),
        }
    }

    #[test]
    fn parse_path_stream_named() {
        match parse_path("/stream/player1") {
            ConnectionRole::Receiver(name) => assert_eq!(name, "player1"),
            _ => panic!("Expected Receiver"),
        }
    }

    #[test]
    fn parse_path_signal_default() {
        match parse_path("/signal") {
            ConnectionRole::Signaling(name) => assert_eq!(name, "default"),
            _ => panic!("Expected Signaling"),
        }
    }

    #[test]
    fn parse_path_signal_named() {
        match parse_path("/signal/main") {
            ConnectionRole::Signaling(name) => assert_eq!(name, "main"),
            _ => panic!("Expected Signaling"),
        }
    }

    #[test]
    fn parse_path_legacy_root() {
        match parse_path("/") {
            ConnectionRole::Receiver(name) => assert_eq!(name, "default"),
            _ => panic!("Expected Receiver for legacy path"),
        }
    }

    // ─── v0.14: Peer + Meta Path Tests ───

    #[test]
    fn parse_path_peer_default() {
        match parse_path("/peer") {
            ConnectionRole::Peer(name) => assert_eq!(name, "default"),
            _ => panic!("Expected Peer"),
        }
    }

    #[test]
    fn parse_path_peer_named() {
        match parse_path("/peer/room42") {
            ConnectionRole::Peer(name) => assert_eq!(name, "room42"),
            _ => panic!("Expected Peer"),
        }
    }

    #[test]
    fn parse_path_meta_default() {
        match parse_path("/meta") {
            ConnectionRole::Meta(name) => assert_eq!(name, "default"),
            _ => panic!("Expected Meta"),
        }
    }

    #[test]
    fn parse_path_meta_named() {
        match parse_path("/meta/stats") {
            ConnectionRole::Meta(name) => assert_eq!(name, "stats"),
            _ => panic!("Expected Meta"),
        }
    }

    // ─── Channel State Tests ───

    #[test]
    fn channel_state_creation() {
        let mut state = RelayState::new(16, 64, 256, 0, None);
        let ch1 = state.get_or_create_channel("main").unwrap();
        let ch2 = state.get_or_create_channel("player1").unwrap();
        let ch1_again = state.get_or_create_channel("main").unwrap();

        assert_eq!(state.channels.len(), 2);
        assert!(Arc::ptr_eq(&ch1, &ch1_again));
        assert!(!Arc::ptr_eq(&ch1, &ch2));
    }

    #[test]
    fn channel_max_limit() {
        let mut state = RelayState::new(16, 64, 2, 0, None);
        assert!(state.get_or_create_channel("a").is_some());
        assert!(state.get_or_create_channel("b").is_some());
        assert!(state.get_or_create_channel("c").is_none()); // max reached
        assert!(state.get_or_create_channel("a").is_some()); // existing OK
    }

    #[test]
    fn channel_idle_detection() {
        let cs = ChannelState::new(16, 64, 0);
        assert!(cs.is_idle()); // no source, no receivers, no cache
    }

    #[test]
    fn channel_not_idle_with_cache() {
        let mut cs = ChannelState::new(16, 64, 0);
        let sync = crate::codec::signal_sync_frame(0, 0, b"{}");
        cs.cache.process_frame(&sync);
        assert!(!cs.is_idle()); // has cached state
    }

    #[test]
    fn channel_not_idle_with_source() {
        let mut cs = ChannelState::new(16, 64, 0);
        cs.stats.source_connected = true;
        assert!(!cs.is_idle());
    }

    #[test]
    fn channel_not_idle_with_receivers() {
        let mut cs = ChannelState::new(16, 64, 0);
        cs.stats.connected_receivers = 1;
        assert!(!cs.is_idle());
    }

    #[test]
    fn grace_period_not_expired_initially() {
        let cs = ChannelState::new(16, 64, 0);
        assert!(!cs.grace_period_expired(30));
    }

    #[test]
    fn grace_period_expired_after_disconnect() {
        let mut cs = ChannelState::new(16, 64, 0);
        cs.source_disconnect_time = Some(Instant::now() - Duration::from_secs(60));
        assert!(cs.grace_period_expired(30));
    }

    // ─── Diff Merging Tests ───

    #[test]
    fn catchup_merges_multiple_diffs() {
        let mut cache = StateCache::default();

        // Start with sync: { count: 0, name: "test" }
        let sync = crate::codec::signal_sync_frame(0, 0, br#"{"count":0,"name":"test"}"#);
        cache.process_frame(&sync);

        // Apply 3 diffs
        let diffs = [
            crate::codec::signal_diff_frame(1, 100, br#"{"count":1}"#),
            crate::codec::signal_diff_frame(2, 200, br#"{"count":2}"#),
            crate::codec::signal_diff_frame(3, 300, br#"{"count":3,"name":"updated"}"#),
        ];
        for diff in &diffs {
            cache.process_frame(diff);
        }

        // Should produce 1 merged frame, not 4
        let merged = single_catchup_json(&cache);
        assert_eq!(merged["count"], 3);
        assert_eq!(merged["name"], "updated");
    }

    #[test]
    fn catchup_diffs_only_no_sync() {
        let mut cache = StateCache::default();

        // Only diffs, no initial sync (first connection scenario)
        let d1 = crate::codec::signal_diff_frame(0, 0, br#"{"mood":"happy"}"#);
        let d2 = crate::codec::signal_diff_frame(1, 100, br#"{"energy":75}"#);
        cache.process_frame(&d1);
        cache.process_frame(&d2);

        // Should produce 1 merged frame
        let merged = single_catchup_json(&cache);
        assert_eq!(merged["mood"], "happy");
        assert_eq!(merged["energy"], 75);
    }

    #[test]
    fn catchup_preserves_version_counters() {
        let mut cache = StateCache::default();

        let sync = crate::codec::signal_sync_frame(0, 0, br#"{"count":0,"_v":{"count":0}}"#);
        cache.process_frame(&sync);

        let d1 = crate::codec::signal_diff_frame(1, 100, br#"{"count":5,"_v":{"count":3}}"#);
        cache.process_frame(&d1);

        let merged = single_catchup_json(&cache);
        assert_eq!(merged["count"], 5);
        assert_eq!(merged["_v"]["count"], 3); // version preserved from diff
    }

    // ─── Channel Wildcard Tests ───

    #[test]
    fn channel_matches_exact() {
        assert!(channel_matches("games/chess", "games/chess"));
        assert!(!channel_matches("games/chess", "games/go"));
    }

    #[test]
    fn channel_matches_wildcard() {
        assert!(channel_matches("games/*", "games/chess"));
        assert!(channel_matches("games/*", "games/go"));
        assert!(!channel_matches("games/*", "other/chess"));
        assert!(!channel_matches("games/*", "games")); // no trailing segment
    }

    #[test]
    fn channel_matches_star_all() {
        assert!(channel_matches("*", "anything"));
        assert!(channel_matches("*", "games/chess"));
    }

    #[test]
    fn find_matching_channels_works() {
        let mut state = RelayState::new(16, 64, 256, 0, None);
        state.get_or_create_channel("games/chess").unwrap();
        state.get_or_create_channel("games/go").unwrap();
        state.get_or_create_channel("chat/main").unwrap();
        let mut matches = find_matching_channels(&state, "games/*");
        matches.sort();
        assert_eq!(matches, vec!["games/chess", "games/go"]);
    }

    // ─── Replay Buffer Tests ───

    #[test]
    fn replay_buffer_stores_frames() {
        let mut cache = StateCache { replay_depth: 3, ..Default::default() };
        let f1 = crate::codec::signal_diff_frame(0, 0, b"{\"a\":1}");
        let f2 = crate::codec::signal_diff_frame(1, 100, b"{\"a\":2}");
        let f3 = crate::codec::signal_diff_frame(2, 200, b"{\"a\":3}");
        cache.process_frame(&f1);
        cache.process_frame(&f2);
        cache.process_frame(&f3);
        assert_eq!(cache.replay_len(), 3);
        assert_eq!(cache.replay_frames(0).len(), 3);
        assert_eq!(cache.replay_frames(1).len(), 2);
        assert_eq!(cache.replay_frames(3).len(), 0);
    }

    #[test]
    fn replay_buffer_evicts_oldest() {
        let mut cache = StateCache { replay_depth: 2, ..Default::default() };
        let f1 = crate::codec::signal_diff_frame(0, 0, b"{\"a\":1}");
        let f2 = crate::codec::signal_diff_frame(1, 100, b"{\"a\":2}");
        let f3 = crate::codec::signal_diff_frame(2, 200, b"{\"a\":3}");
        cache.process_frame(&f1);
        cache.process_frame(&f2);
        cache.process_frame(&f3);
        assert_eq!(cache.replay_len(), 2);
        // First frame should be evicted
        assert_eq!(cache.replay_buffer[0], f2);
        assert_eq!(cache.replay_buffer[1], f3);
    }

    #[test]
    fn replay_depth_propagates_to_channel() {
        let mut state = RelayState::new(16, 64, 256, 100, None);
        let ch = state.get_or_create_channel("test").unwrap();
        let cs = ch.try_read().unwrap();
        assert_eq!(cs.cache.replay_depth, 100);
    }

    #[test]
    fn replay_disabled_when_zero() {
        let mut cache = StateCache::default();
        assert_eq!(cache.replay_depth, 0);
        let f1 = crate::codec::signal_diff_frame(0, 0, b"{\"a\":1}");
        cache.process_frame(&f1);
        assert_eq!(cache.replay_len(), 0); // Nothing stored when depth=0
    }

    // ─── v1.2: Channel Auth Tests ───

    #[test]
    fn channel_auth_open_allows_all() {
        let auth = ChannelAuth::open();
        assert!(auth.is_authenticated("any-source"));
        assert!(!auth.is_required());
    }

    #[test]
    fn channel_auth_rejects_unauthed() {
        let auth = ChannelAuth::with_key(b"secret-123");
        assert!(auth.is_required());
        assert!(!auth.is_authenticated("source-1"));
    }

    #[test]
    fn channel_auth_allows_authed() {
        let mut auth = ChannelAuth::with_key(b"secret-123");
        assert!(!auth.authenticate("source-1", b"wrong-key"));
        assert!(!auth.is_authenticated("source-1"));
        assert!(auth.authenticate("source-1", b"secret-123"));
        assert!(auth.is_authenticated("source-1"));
        assert_eq!(auth.authenticated_count(), 1);
    }

    #[test]
    fn channel_auth_revoke() {
        let mut auth = ChannelAuth::with_key(b"key");
        auth.authenticate("src", b"key");
        assert!(auth.is_authenticated("src"));
        auth.revoke("src");
        assert!(!auth.is_authenticated("src"));
        assert_eq!(auth.authenticated_count(), 0);
    }
}
|