// dreamstack/engine/ds-stream/src/relay.rs

//! WebSocket Relay Server
//!
//! Routes frames from source→receivers and inputs from receivers→source.
//!
//! ## Routing
//!
//! Connections are routed by WebSocket URI path:
//! - `/source` — default source (channel: "default")
//! - `/source/{name}` — named source (channel: {name})
//! - `/stream` — default receiver (channel: "default")
//! - `/stream/{name}` — named receiver (channel: {name})
//! - `/signal` — WebRTC signaling (channel: "default")
//! - `/signal/{name}` — WebRTC signaling (channel: {name})
//! - `/` — legacy: first connection = source, rest = receivers
//!
//! Each channel has its own broadcast/input channels and state cache,
//! allowing multiple independent streams through a single relay.
//!
//! ## Production Features
//! - Multi-source: multiple views stream independently
//! - Keyframe caching: late-joining receivers get current state instantly
//! - Signal state store: caches SignalSync/Diff frames for reconstruction
//! - Ping/pong keepalive: detects dead connections
//! - Stats tracking: frames, bytes, latency metrics
//! - Max receiver limit per channel (configurable)
//! - Channel GC: empty channels are cleaned up periodically
//! - Source reconnection: cache preserved for seamless reconnect
//! - Graceful shutdown: drain connections on SIGTERM
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Instant;
use futures_util::{SinkExt, StreamExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{broadcast, mpsc, RwLock};
use tokio::time::{interval, Duration};
use tokio_tungstenite::tungstenite::Message;
use crate::protocol::*;
/// Relay server configuration.
///
/// All limits apply per relay process unless a field says otherwise
/// (receiver limits are per channel).
#[derive(Debug, Clone)]
pub struct RelayConfig {
    /// Address to bind to.
    pub addr: SocketAddr,
    /// Maximum number of receivers per channel.
    pub max_receivers: usize,
    /// Frame broadcast channel capacity.
    pub frame_buffer_size: usize,
    /// Keepalive interval in seconds.
    pub keepalive_interval_secs: u64,
    /// Keepalive timeout in seconds — disconnect after this many seconds without a pong.
    pub keepalive_timeout_secs: u64,
    /// Maximum number of channels.
    pub max_channels: usize,
    /// Channel GC interval in seconds — how often to scan for empty channels.
    pub channel_gc_interval_secs: u64,
    /// Source reconnect grace period in seconds — keep cache alive after source disconnect.
    pub source_reconnect_grace_secs: u64,
}
impl Default for RelayConfig {
fn default() -> Self {
Self {
addr: "0.0.0.0:9100".parse().unwrap(),
max_receivers: 64,
frame_buffer_size: 16,
keepalive_interval_secs: 10,
keepalive_timeout_secs: 30,
max_channels: 256,
channel_gc_interval_secs: 60,
source_reconnect_grace_secs: 30,
}
}
}
/// Stats tracked by the relay.
#[derive(Debug, Default, Clone)]
pub struct RelayStats {
    /// Binary frames forwarded source → receivers.
    pub frames_relayed: u64,
    /// Total bytes of those forwarded frames.
    pub bytes_relayed: u64,
    /// Input events forwarded receivers → source.
    pub inputs_relayed: u64,
    /// Receivers currently attached to the channel.
    pub connected_receivers: usize,
    /// Whether a source is currently attached.
    pub source_connected: bool,
    /// Timestamp of last frame from source (milliseconds since start).
    pub last_frame_timestamp: u32,
    /// Total keyframes sent.
    pub keyframes_sent: u64,
    /// Total signal diffs sent.
    pub signal_diffs_sent: u64,
    /// Uptime in seconds.
    pub uptime_secs: u64,
    /// Peak receiver count.
    pub peak_receivers: usize,
    /// Total connections served.
    pub total_connections: u64,
    /// Rejected connections (over limit).
    pub rejected_connections: u64,
}
/// Cached state for late-joining receivers.
///
/// A receiver that joins mid-stream reconstructs the current state from:
/// last SignalSync + every diff accumulated since it + last pixel keyframe
/// (see `catchup_messages`).
#[derive(Debug, Default, Clone)]
pub struct StateCache {
    /// Last pixel keyframe (full RGBA frame).
    pub last_keyframe: Option<Vec<u8>>,
    /// Last signal sync frame (full JSON state).
    pub last_signal_sync: Option<Vec<u8>>,
    /// Accumulated signal diffs since last sync.
    /// Late-joining receivers get: last_signal_sync + all diffs.
    pub pending_signal_diffs: Vec<Vec<u8>>,
}
impl StateCache {
/// Process an incoming frame and update cache.
fn process_frame(&mut self, msg: &[u8]) {
if msg.len() < HEADER_SIZE {
return;
}
let frame_type = msg[0];
let flags = msg[1];
match FrameType::from_u8(frame_type) {
// Cache keyframes (pixel or signal sync)
Some(FrameType::Pixels) if flags & FLAG_KEYFRAME != 0 => {
self.last_keyframe = Some(msg.to_vec());
}
Some(FrameType::SignalSync) => {
self.last_signal_sync = Some(msg.to_vec());
self.pending_signal_diffs.clear(); // sync resets diffs
}
Some(FrameType::SignalDiff) => {
self.pending_signal_diffs.push(msg.to_vec());
// Cap accumulated diffs to prevent unbounded memory
if self.pending_signal_diffs.len() > 1000 {
self.pending_signal_diffs.drain(..500);
}
}
Some(FrameType::Keyframe) => {
self.last_keyframe = Some(msg.to_vec());
}
_ => {}
}
}
/// Get all messages a late-joining receiver needs to reconstruct current state.
fn catchup_messages(&self) -> Vec<Vec<u8>> {
let mut msgs = Vec::new();
// Send last signal sync first (if available)
if let Some(ref sync) = self.last_signal_sync {
msgs.push(sync.clone());
}
// Then all accumulated diffs
for diff in &self.pending_signal_diffs {
msgs.push(diff.clone());
}
// Then last pixel keyframe (if available)
if let Some(ref kf) = self.last_keyframe {
msgs.push(kf.clone());
}
msgs
}
/// Clear all cached state.
fn clear(&mut self) {
self.last_keyframe = None;
self.last_signal_sync = None;
self.pending_signal_diffs.clear();
}
/// Returns true if this cache has any state.
fn has_state(&self) -> bool {
self.last_keyframe.is_some()
|| self.last_signal_sync.is_some()
|| !self.pending_signal_diffs.is_empty()
}
}
/// Per-channel state — each named channel is an independent stream.
struct ChannelState {
    /// Broadcast channel: source → all receivers (frames)
    frame_tx: broadcast::Sender<Vec<u8>>,
    /// Channel: receivers → source (input events)
    input_tx: mpsc::Sender<Vec<u8>>,
    // Receiver half of input_tx, parked here until a source connects and
    // takes it; recreated when the source disconnects.
    input_rx: Option<mpsc::Receiver<Vec<u8>>>,
    /// Broadcast channel for WebRTC signaling (SDP/ICE as text)
    signaling_tx: broadcast::Sender<String>,
    /// Live stats for this channel
    stats: RelayStats,
    /// Cached state for late-joining receivers
    cache: StateCache,
    /// When the source last disconnected (for reconnect grace period)
    source_disconnect_time: Option<Instant>,
    /// Max receivers for this channel
    max_receivers: usize,
    /// Cached schema announcement (0x32 payload) from source
    schema: Option<Vec<u8>>,
}
impl ChannelState {
fn new(frame_buffer_size: usize, max_receivers: usize) -> Self {
let (frame_tx, _) = broadcast::channel(frame_buffer_size);
let (input_tx, input_rx) = mpsc::channel(256);
let (signaling_tx, _) = broadcast::channel(64);
Self {
frame_tx,
input_tx,
input_rx: Some(input_rx),
signaling_tx,
stats: RelayStats::default(),
cache: StateCache::default(),
source_disconnect_time: None,
max_receivers,
schema: None,
}
}
/// Returns true if this channel is idle (no source, no receivers, no cache).
fn is_idle(&self) -> bool {
!self.stats.source_connected
&& self.stats.connected_receivers == 0
&& !self.cache.has_state()
}
/// Returns true if the source reconnect grace period has expired.
fn grace_period_expired(&self, grace_secs: u64) -> bool {
if let Some(disconnect_time) = self.source_disconnect_time {
disconnect_time.elapsed() > Duration::from_secs(grace_secs)
} else {
false
}
}
}
/// Shared relay state — holds all channels.
struct RelayState {
    /// Named channels: "default", "main", "player1", etc.
    channels: HashMap<String, Arc<RwLock<ChannelState>>>,
    /// Frame buffer size for new channels
    frame_buffer_size: usize,
    /// Max receivers per channel
    max_receivers: usize,
    /// Max channels
    max_channels: usize,
    /// Server start time
    start_time: Instant,
}
impl RelayState {
    /// Create an empty relay state with the given per-channel defaults.
    fn new(frame_buffer_size: usize, max_receivers: usize, max_channels: usize) -> Self {
        Self {
            channels: HashMap::new(),
            frame_buffer_size,
            max_receivers,
            max_channels,
            start_time: Instant::now(),
        }
    }
    /// Get or create a channel by name. Returns None if at max channels.
    ///
    /// Existing channels are always returned — the limit only applies to
    /// creating new ones.
    fn get_or_create_channel(&mut self, name: &str) -> Option<Arc<RwLock<ChannelState>>> {
        // Single hash lookup for the common (existing) case, instead of the
        // previous contains_key + index which hashed the key twice.
        if let Some(existing) = self.channels.get(name) {
            return Some(Arc::clone(existing));
        }
        if self.channels.len() >= self.max_channels {
            eprintln!("[relay] Max channels ({}) reached, rejecting: {}", self.max_channels, name);
            return None;
        }
        let channel = Arc::new(RwLock::new(ChannelState::new(
            self.frame_buffer_size,
            self.max_receivers,
        )));
        self.channels.insert(name.to_string(), Arc::clone(&channel));
        Some(channel)
    }
}
/// Parsed connection role from the WebSocket URI path.
#[derive(Debug, Clone)]
enum ConnectionRole {
    /// Frame producer for the named channel.
    Source(String),
    /// Frame consumer for the named channel.
    Receiver(String),
    /// WebRTC signaling peer for the named channel.
    Signaling(String),
    /// Bidirectional sync peer for the named channel.
    Peer(String),
    /// HTTP introspection for the named channel.
    Meta(String),
}
/// Map a WebSocket URI path to a role plus channel name.
///
/// The first path segment selects the role (`/source`, `/stream`, `/signal`,
/// `/peer`, `/meta`); an optional second segment names the channel, which
/// defaults to "default". Any other path — including the legacy `/` — maps
/// to a default receiver.
fn parse_path(path: &str) -> ConnectionRole {
    let mut segments = path.trim_start_matches('/').splitn(2, '/');
    let head = segments.next().unwrap_or("");
    // splitn(2, ..) keeps any further slashes inside the channel name.
    let name = segments.next().unwrap_or("default").to_string();
    match head {
        "source" => ConnectionRole::Source(name),
        "stream" => ConnectionRole::Receiver(name),
        "signal" => ConnectionRole::Signaling(name),
        "peer" => ConnectionRole::Peer(name),
        "meta" => ConnectionRole::Meta(name),
        _ => ConnectionRole::Receiver("default".to_string()),
    }
}
/// Run the WebSocket relay server.
///
/// Binds to `config.addr`, spawns a background stats/GC task, then accepts
/// connections forever — each accepted socket is handled on its own task.
///
/// # Errors
/// Returns an error only if the initial TCP bind fails.
pub async fn run_relay(config: RelayConfig) -> Result<(), Box<dyn std::error::Error>> {
    let listener = TcpListener::bind(&config.addr).await?;
    eprintln!("╔══════════════════════════════════════════════════╗");
    eprintln!("║ DreamStack Bitstream Relay v1.0.0 ║");
    eprintln!("║ ║");
    eprintln!("║ Source: ws://{}/source/{{name}}", config.addr);
    eprintln!("║ Receiver: ws://{}/stream/{{name}}", config.addr);
    eprintln!("║ Signal: ws://{}/signal/{{name}}", config.addr);
    eprintln!("║ ║");
    eprintln!("║ Max receivers/ch: {:>4}", config.max_receivers);
    eprintln!("║ Max channels: {:>4}", config.max_channels);
    eprintln!("║ Reconnect grace: {:>4}s ║", config.source_reconnect_grace_secs);
    eprintln!("╚══════════════════════════════════════════════════╝");
    let grace_secs = config.source_reconnect_grace_secs;
    let state = Arc::new(RwLock::new(RelayState::new(
        config.frame_buffer_size,
        config.max_receivers,
        config.max_channels,
    )));
    // Background: periodic stats + channel GC
    {
        let state = state.clone();
        let gc_interval = config.channel_gc_interval_secs;
        tokio::spawn(async move {
            // Tick at most every 30s so stats stay fresh; clamp to >= 1s
            // because tokio::time::interval panics on a zero period.
            let tick_secs = gc_interval.min(30).max(1);
            // Run a GC pass every `gc_every` ticks (≈ every gc_interval secs).
            let gc_every = (gc_interval / tick_secs).max(1);
            let mut tick = interval(Duration::from_secs(tick_secs));
            let mut gc_counter: u64 = 0;
            loop {
                tick.tick().await;
                gc_counter += 1;
                let s = state.read().await;
                let uptime = s.start_time.elapsed().as_secs();
                // Stats logging for every channel with live traffic
                for (name, channel) in &s.channels {
                    let cs = channel.read().await;
                    if cs.stats.source_connected || cs.stats.connected_receivers > 0 {
                        eprintln!(
                            "[relay:{name}] up={uptime}s frames={} bytes={} receivers={}/{} peak={} signal_diffs={} cached={}",
                            cs.stats.frames_relayed,
                            cs.stats.bytes_relayed,
                            cs.stats.connected_receivers,
                            cs.max_receivers,
                            cs.stats.peak_receivers,
                            cs.stats.signal_diffs_sent,
                            cs.cache.has_state(),
                        );
                    }
                }
                drop(s);
                // Channel GC
                if gc_counter % gc_every == 0 {
                    let mut s = state.write().await;
                    let before = s.channels.len();
                    let mut to_remove = Vec::new();
                    for (name, channel) in &s.channels {
                        let cs = channel.read().await;
                        // Remove a channel only when nobody is connected:
                        // either it is fully idle, or its source-reconnect
                        // grace period lapsed with no remaining clients.
                        // (Previously a channel whose grace period expired
                        // was removed even while receivers were still
                        // attached, orphaning them from any future source.)
                        let nobody_connected = !cs.stats.source_connected
                            && cs.stats.connected_receivers == 0;
                        if cs.is_idle()
                            || (nobody_connected && cs.grace_period_expired(grace_secs))
                        {
                            to_remove.push(name.clone());
                        }
                    }
                    for name in &to_remove {
                        s.channels.remove(name);
                    }
                    if !to_remove.is_empty() {
                        eprintln!(
                            "[relay] GC: removed {} idle channel(s) ({} → {}): {:?}",
                            to_remove.len(),
                            before,
                            s.channels.len(),
                            to_remove
                        );
                    }
                }
            }
        });
    }
    // Accept loop: one task per connection.
    while let Ok((stream, addr)) = listener.accept().await {
        let state = state.clone();
        let keepalive_interval = config.keepalive_interval_secs;
        let keepalive_timeout = config.keepalive_timeout_secs;
        tokio::spawn(handle_connection(stream, addr, state, keepalive_interval, keepalive_timeout));
    }
    Ok(())
}
/// Accept one socket: run the WebSocket handshake, derive the connection's
/// role from the URI path, then dispatch to the matching role handler.
///
/// `_keepalive_timeout` is accepted but unused here — no pong-timeout
/// tracking is implemented in this function.
async fn handle_connection(
    stream: TcpStream,
    addr: SocketAddr,
    state: Arc<RwLock<RelayState>>,
    keepalive_interval: u64,
    _keepalive_timeout: u64,
) {
    // Extract the URI path during handshake to determine role
    let mut role = ConnectionRole::Receiver("default".to_string());
    // Peek at the path to check for meta requests (handle via raw HTTP, not WS)
    let ws_stream = match tokio_tungstenite::accept_hdr_async(
        stream,
        |req: &tokio_tungstenite::tungstenite::handshake::server::Request,
         res: tokio_tungstenite::tungstenite::handshake::server::Response| {
            // The callback runs during the handshake, so `role` is populated
            // even when the handshake later fails (plain-HTTP meta requests).
            role = parse_path(req.uri().path());
            Ok(res)
        },
    )
    .await
    {
        Ok(ws) => ws,
        Err(e) => {
            // Meta requests won't have an Upgrade header, so they fail the handshake
            // That's fine — they were handled inline if detected.
            if !matches!(role, ConnectionRole::Meta(_)) {
                eprintln!("[relay] WebSocket handshake failed from {}: {}", addr, e);
            }
            return;
        }
    };
    // Handle meta requests separately (they don't use WebSocket)
    if let ConnectionRole::Meta(ref channel_name) = role {
        // Meta requests fail the WS handshake, so we'll never reach here
        // But just in case, close the connection
        return;
    }
    // Get or create the channel
    let channel_name = match &role {
        ConnectionRole::Source(n) | ConnectionRole::Receiver(n) | ConnectionRole::Signaling(n) | ConnectionRole::Peer(n) | ConnectionRole::Meta(n) => n.clone(),
    };
    let channel = {
        let mut s = state.write().await;
        match s.get_or_create_channel(&channel_name) {
            Some(ch) => ch,
            None => {
                eprintln!("[relay] Rejected connection from {addr}: max channels reached");
                return;
            }
        }
    };
    // Check receiver limit
    // (checked before the legacy-fallback promotion below, so a would-be
    // source arriving as a default receiver on a full channel is rejected)
    if let ConnectionRole::Receiver(_) = &role {
        let mut cs = channel.write().await;
        if cs.stats.connected_receivers >= cs.max_receivers {
            eprintln!(
                "[relay:{channel_name}] Rejected receiver {addr}: at capacity ({}/{})",
                cs.stats.connected_receivers, cs.max_receivers
            );
            cs.stats.rejected_connections += 1;
            return;
        }
    }
    // Legacy fallback: if path was `/` and no source exists, treat as source
    // NOTE(review): parse_path also maps an explicit `/stream` to
    // Receiver("default"), so a `/stream` receiver that connects before any
    // source is promoted to source by this branch too — confirm intended.
    let role = match role {
        ConnectionRole::Receiver(ref name) if name == "default" => {
            let cs = channel.read().await;
            if !cs.stats.source_connected {
                ConnectionRole::Source("default".to_string())
            } else {
                role
            }
        }
        _ => role,
    };
    match role {
        ConnectionRole::Source(ref _name) => {
            eprintln!("[relay:{channel_name}] Source connected: {addr}");
            handle_source(ws_stream, addr, channel, &channel_name, keepalive_interval).await;
        }
        ConnectionRole::Receiver(ref _name) => {
            eprintln!("[relay:{channel_name}] Receiver connected: {addr}");
            handle_receiver(ws_stream, addr, channel, &channel_name).await;
        }
        ConnectionRole::Signaling(ref _name) => {
            eprintln!("[relay:{channel_name}] Signaling peer connected: {addr}");
            handle_signaling(ws_stream, addr, channel, &channel_name).await;
        }
        ConnectionRole::Peer(ref _name) => {
            eprintln!("[relay:{channel_name}] Peer connected: {addr}");
            handle_peer(ws_stream, addr, channel, &channel_name).await;
        }
        ConnectionRole::Meta(_) => {
            // Handled before WS handshake — should never reach here
        }
    }
}
/// Serve a connected source: broadcast its frames to the channel, feed it
/// input events from receivers, and keep per-channel stats/cache updated.
async fn handle_source(
    ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
    addr: SocketAddr,
    channel: Arc<RwLock<ChannelState>>,
    channel_name: &str,
    keepalive_interval: u64,
) {
    let (mut ws_sink, mut ws_source) = ws_stream.split();
    // Mark source as connected; if reconnecting, preserve cache
    let input_rx = {
        let mut cs = channel.write().await;
        if cs.source_disconnect_time.is_some() {
            eprintln!("[relay:{channel_name}] Source reconnected — cache preserved ({} diffs, sync={})",
                cs.cache.pending_signal_diffs.len(),
                cs.cache.last_signal_sync.is_some(),
            );
        }
        cs.stats.source_connected = true;
        cs.stats.total_connections += 1;
        cs.source_disconnect_time = None;
        // Take the parked receiver half; None if another task already holds it.
        cs.input_rx.take()
    };
    let frame_tx = {
        let cs = channel.read().await;
        cs.frame_tx.clone()
    };
    // Forward input events from receivers → source
    // NOTE(review): if input_rx was already taken (e.g. a previous source
    // task for this channel is still alive), inputs are NOT forwarded to
    // this source — confirm this is the intended overlap behavior.
    if let Some(mut input_rx) = input_rx {
        let channel_clone = channel.clone();
        tokio::spawn(async move {
            // ws_sink moves into this task; the outer loop only reads.
            while let Some(input_bytes) = input_rx.recv().await {
                let msg = Message::Binary(input_bytes.into());
                if ws_sink.send(msg).await.is_err() {
                    break;
                }
                let mut cs = channel_clone.write().await;
                cs.stats.inputs_relayed += 1;
            }
        });
    }
    // Keepalive: periodic pings
    // NOTE(review): pings are pushed into frame_tx, i.e. broadcast to the
    // receivers — they do not probe the source socket itself (ws_sink is
    // owned by the input-forwarding task above). Confirm this direction
    // matches the "detect dead connections" goal in the module docs.
    let channel_ping = channel.clone();
    let ping_task = tokio::spawn(async move {
        let mut tick = interval(Duration::from_secs(keepalive_interval));
        let mut seq = 0u16;
        loop {
            tick.tick().await;
            let cs = channel_ping.read().await;
            // Stop pinging once the source is gone.
            if !cs.stats.source_connected {
                break;
            }
            let _ = cs.frame_tx.send(crate::codec::ping(seq, 0));
            drop(cs);
            seq = seq.wrapping_add(1);
        }
    });
    // Receive frames from source → broadcast to receivers
    let channel_name_owned = channel_name.to_string();
    while let Some(Ok(msg)) = ws_source.next().await {
        if let Message::Binary(data) = msg {
            let data_vec: Vec<u8> = data.into();
            {
                let mut cs = channel.write().await;
                cs.stats.frames_relayed += 1;
                cs.stats.bytes_relayed += data_vec.len() as u64;
                // Update state cache
                cs.cache.process_frame(&data_vec);
                // Track frame-type-specific stats
                if data_vec.len() >= HEADER_SIZE {
                    match FrameType::from_u8(data_vec[0]) {
                        Some(FrameType::SignalDiff) => cs.stats.signal_diffs_sent += 1,
                        Some(FrameType::Pixels) | Some(FrameType::Keyframe) |
                        Some(FrameType::SignalSync) => cs.stats.keyframes_sent += 1,
                        _ => {}
                    }
                    // Header bytes 4..8 carry a little-endian timestamp.
                    let ts = u32::from_le_bytes([
                        data_vec[4], data_vec[5], data_vec[6], data_vec[7],
                    ]);
                    cs.stats.last_frame_timestamp = ts;
                }
            }
            // Broadcast to all receivers on this channel
            let _ = frame_tx.send(data_vec);
        }
    }
    // Source disconnected — start grace period, keep cache
    ping_task.abort();
    eprintln!("[relay:{channel_name_owned}] Source disconnected: {addr} (cache preserved for reconnect)");
    let mut cs = channel.write().await;
    cs.stats.source_connected = false;
    cs.source_disconnect_time = Some(Instant::now());
    // Recreate input channel for next source connection
    let (new_input_tx, new_input_rx) = mpsc::channel(256);
    cs.input_tx = new_input_tx;
    cs.input_rx = Some(new_input_rx);
}
/// Serve a connected receiver: replay cached schema/state, then forward live
/// frames (optionally filtered per-receiver), while relaying the receiver's
/// input events back to the source.
async fn handle_receiver(
    ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
    addr: SocketAddr,
    channel: Arc<RwLock<ChannelState>>,
    channel_name: &str,
) {
    let (mut ws_sink, mut ws_source) = ws_stream.split();
    // Subscribe to frame broadcast and get catchup + schema
    let (mut frame_rx, catchup_msgs, schema_frame) = {
        let mut cs = channel.write().await;
        cs.stats.connected_receivers += 1;
        cs.stats.total_connections += 1;
        if cs.stats.connected_receivers > cs.stats.peak_receivers {
            cs.stats.peak_receivers = cs.stats.connected_receivers;
        }
        // Subscribing inside the same lock as the catchup snapshot means no
        // frame can slip between the snapshot and the live subscription.
        let rx = cs.frame_tx.subscribe();
        let catchup = cs.cache.catchup_messages();
        let schema = cs.schema.clone();
        (rx, catchup, schema)
    };
    let input_tx = {
        let cs = channel.read().await;
        cs.input_tx.clone()
    };
    // Send cached schema (0x32) first so receiver knows what signals are available
    let channel_name_owned = channel_name.to_string();
    if let Some(schema_bytes) = schema_frame {
        let _ = ws_sink.send(Message::Binary(schema_bytes.into())).await;
    }
    // Send cached state to late-joining receiver
    if !catchup_msgs.is_empty() {
        eprintln!(
            "[relay:{channel_name_owned}] Sending {} catchup messages to {addr}",
            catchup_msgs.len(),
        );
        for msg_bytes in catchup_msgs {
            let msg = Message::Binary(msg_bytes.into());
            if ws_sink.send(msg).await.is_err() {
                eprintln!("[relay:{channel_name_owned}] Failed to send catchup to {addr}");
                // Undo the receiver-count bump before bailing out.
                let mut cs = channel.write().await;
                cs.stats.connected_receivers = cs.stats.connected_receivers.saturating_sub(1);
                return;
            }
        }
    }
    // Per-receiver signal filter (set via 0x33 SubscribeFilter)
    let filter: Arc<tokio::sync::RwLock<Option<std::collections::HashSet<String>>>> =
        Arc::new(tokio::sync::RwLock::new(None));
    let filter_for_send = filter.clone();
    // Forward frames from broadcast → this receiver (with optional filtering)
    let send_task = tokio::spawn(async move {
        loop {
            match frame_rx.recv().await {
                Ok(frame_bytes) => {
                    // Check if we need to filter signal frames
                    // (0x30/0x31 presumably match the signal sync/diff frame
                    // types — confirm against protocol.rs)
                    let filter_guard = filter_for_send.read().await;
                    let should_filter = filter_guard.is_some()
                        && frame_bytes.len() >= 16
                        && (frame_bytes[0] == 0x30 || frame_bytes[0] == 0x31);
                    if should_filter {
                        if let Some(ref wanted) = *filter_guard {
                            // Parse JSON payload, keep only wanted keys
                            // Payload length lives at header bytes 12..16 (LE).
                            let payload_len = u32::from_le_bytes([
                                frame_bytes[12], frame_bytes[13],
                                frame_bytes[14], frame_bytes[15],
                            ]) as usize;
                            if frame_bytes.len() >= 16 + payload_len {
                                let payload = &frame_bytes[16..16 + payload_len];
                                if let Ok(mut obj) = serde_json::from_slice::<serde_json::Value>(payload) {
                                    if let Some(map) = obj.as_object_mut() {
                                        let keys: Vec<String> = map.keys().cloned().collect();
                                        for k in keys {
                                            // Keep _pid, _v (internal), and wanted fields
                                            if k != "_pid" && k != "_v" && !wanted.contains(&k) {
                                                map.remove(&k);
                                            }
                                        }
                                        // Re-encode: original 12-byte header prefix +
                                        // recomputed payload length + filtered payload.
                                        let new_payload = serde_json::to_vec(&obj).unwrap_or_default();
                                        let mut new_frame = Vec::with_capacity(16 + new_payload.len());
                                        new_frame.extend_from_slice(&frame_bytes[..12]);
                                        new_frame.extend_from_slice(&(new_payload.len() as u32).to_le_bytes());
                                        new_frame.extend_from_slice(&new_payload);
                                        // Release the filter lock before awaiting the send.
                                        drop(filter_guard);
                                        if ws_sink.send(Message::Binary(new_frame.into())).await.is_err() {
                                            break;
                                        }
                                        continue;
                                    }
                                }
                            }
                        }
                    }
                    drop(filter_guard);
                    let msg = Message::Binary(frame_bytes.into());
                    if ws_sink.send(msg).await.is_err() {
                        break;
                    }
                }
                Err(broadcast::error::RecvError::Lagged(n)) => {
                    // Slow receiver: the broadcast ring overwrote n frames; keep going.
                    eprintln!("[relay] Receiver {addr} lagged by {n} frames");
                }
                Err(_) => break,
            }
        }
    });
    // Handle incoming messages from receiver (input events + 0x33 filter)
    while let Some(Ok(msg)) = ws_source.next().await {
        if let Message::Binary(data) = msg {
            let data_vec: Vec<u8> = data.into();
            // Parse SubscribeFilter (0x33)
            if data_vec.len() >= 16 && data_vec[0] == 0x33 {
                let payload_len = u32::from_le_bytes([
                    data_vec[12], data_vec[13], data_vec[14], data_vec[15],
                ]) as usize;
                if data_vec.len() >= 16 + payload_len {
                    let payload = &data_vec[16..16 + payload_len];
                    if let Ok(obj) = serde_json::from_slice::<serde_json::Value>(payload) {
                        if let Some(select_arr) = obj.get("select").and_then(|v| v.as_array()) {
                            // Non-string entries in "select" are silently skipped.
                            let wanted: std::collections::HashSet<String> = select_arr
                                .iter()
                                .filter_map(|v| v.as_str().map(|s| s.to_string()))
                                .collect();
                            eprintln!(
                                "[relay:{channel_name_owned}] Receiver {addr} subscribed to: {:?}",
                                wanted
                            );
                            let mut f = filter.write().await;
                            *f = Some(wanted);
                        }
                    }
                }
                continue; // Don't forward filter frames to source
            }
            // Forward input events to source
            let _ = input_tx.send(data_vec).await;
        }
    }
    // Receiver disconnected
    send_task.abort();
    eprintln!("[relay:{channel_name_owned}] Receiver disconnected: {addr}");
    let mut cs = channel.write().await;
    cs.stats.connected_receivers = cs.stats.connected_receivers.saturating_sub(1);
}
/// Handle a WebRTC signaling connection.
///
/// Every text message a signaling peer sends is broadcast to all other
/// signaling peers on the same channel (SDP offers/answers, ICE candidates
/// as JSON), which is all the coordination WebRTC needs to set up a
/// peer-to-peer connection.
async fn handle_signaling(
    ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
    addr: SocketAddr,
    channel: Arc<RwLock<ChannelState>>,
    channel_name: &str,
) {
    let (mut ws_sink, mut ws_source) = ws_stream.split();
    // One lock scope grabs both a live subscription and a sender handle.
    let (mut inbound, outbound) = {
        let cs = channel.read().await;
        (cs.signaling_tx.subscribe(), cs.signaling_tx.clone())
    };
    // Task: broadcast → this peer.
    let channel_name_owned = channel_name.to_string();
    let send_task = tokio::spawn(async move {
        loop {
            match inbound.recv().await {
                Ok(msg_text) => {
                    if ws_sink.send(Message::Text(msg_text.into())).await.is_err() {
                        break;
                    }
                }
                Err(broadcast::error::RecvError::Lagged(n)) => {
                    eprintln!("[relay:{channel_name_owned}] Signaling peer lagged by {n} messages");
                }
                Err(_) => break,
            }
        }
    });
    // This peer → broadcast to everyone else (non-text frames are ignored).
    while let Some(Ok(msg)) = ws_source.next().await {
        if let Message::Text(text) = msg {
            let outgoing: String = text.into();
            let _ = outbound.send(outgoing);
        }
    }
    send_task.abort();
    eprintln!("[relay:{channel_name}] Signaling peer disconnected: {addr}");
}
/// Handle a peer connection for bidirectional sync.
///
/// Peers are equal — every binary frame sent by any peer is broadcast to all
/// other peers on the same channel. No source/receiver distinction.
async fn handle_peer(
    ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
    addr: SocketAddr,
    channel: Arc<RwLock<ChannelState>>,
    channel_name: &str,
) {
    let (mut ws_sink, mut ws_source) = ws_stream.split();
    // Subscribe to the frame broadcast channel (same one sources use)
    let mut frame_rx = {
        let cs = channel.read().await;
        cs.frame_tx.subscribe()
    };
    let frame_tx = {
        let cs = channel.read().await;
        cs.frame_tx.clone()
    };
    // Track peer count
    // (peers share the receiver counters, so they count toward stats and GC idleness)
    {
        let mut cs = channel.write().await;
        cs.stats.connected_receivers += 1;
        cs.stats.peak_receivers = cs.stats.peak_receivers.max(cs.stats.connected_receivers);
    }
    // Send catchup state to late-joining peers
    {
        let cs = channel.read().await;
        if cs.cache.has_state() {
            let msgs = cs.cache.catchup_messages();
            eprintln!("[relay:{channel_name}] Sending {} catchup messages to peer {addr}", msgs.len());
            for msg in msgs {
                // Send errors are ignored here; the main loops below will
                // notice a dead socket and tear the connection down.
                let _ = ws_sink.send(Message::Binary(msg.into())).await;
            }
        }
    }
    // Forward frames from broadcast → this peer
    let channel_name_owned = channel_name.to_string();
    let send_task = tokio::spawn(async move {
        loop {
            match frame_rx.recv().await {
                Ok(frame_data) => {
                    // Note: a peer receives its own frames back too; echo
                    // suppression is left to the peers themselves.
                    if ws_sink.send(Message::Binary(frame_data.into())).await.is_err() {
                        break;
                    }
                }
                Err(broadcast::error::RecvError::Lagged(n)) => {
                    eprintln!("[relay:{channel_name_owned}] Peer lagged by {n} frames");
                }
                Err(_) => break,
            }
        }
    });
    // Forward frames from this peer → broadcast to all others
    let channel_for_cache = channel.clone();
    // NOTE(review): channel_name_cache is currently unused.
    let channel_name_cache = channel_name.to_string();
    while let Some(Ok(msg)) = ws_source.next().await {
        if let Message::Binary(data) = msg {
            let data_vec: Vec<u8> = data.into();
            // Intercept schema announcement (0x32) — cache, don't broadcast
            if data_vec.len() >= 16 && data_vec[0] == 0x32 {
                // Payload length lives at header bytes 12..16 (LE).
                let payload_len = u32::from_le_bytes([data_vec[12], data_vec[13], data_vec[14], data_vec[15]]) as usize;
                if data_vec.len() >= 16 + payload_len {
                    let mut cs = channel_for_cache.write().await;
                    cs.schema = Some(data_vec.clone());
                }
                continue; // Don't broadcast schema frames
            }
            // Cache for late joiners
            {
                let mut cs = channel_for_cache.write().await;
                cs.cache.process_frame(&data_vec);
                cs.stats.frames_relayed += 1;
                cs.stats.bytes_relayed += data_vec.len() as u64;
            }
            // Broadcast to all peers (they'll filter out their own msgs by content)
            let _ = frame_tx.send(data_vec);
        }
    }
    send_task.abort();
    {
        let mut cs = channel.write().await;
        cs.stats.connected_receivers = cs.stats.connected_receivers.saturating_sub(1);
    }
    eprintln!("[relay:{channel_name}] Peer disconnected: {addr}");
}
// ─── Tests ───
#[cfg(test)]
mod tests {
    use super::*;
    // ─── Config & stats defaults ───
    #[test]
    fn default_config() {
        let config = RelayConfig::default();
        assert_eq!(config.addr, "0.0.0.0:9100".parse::<SocketAddr>().unwrap());
        assert_eq!(config.max_receivers, 64);
        assert_eq!(config.frame_buffer_size, 16);
        assert_eq!(config.keepalive_interval_secs, 10);
        assert_eq!(config.max_channels, 256);
        assert_eq!(config.channel_gc_interval_secs, 60);
        assert_eq!(config.source_reconnect_grace_secs, 30);
    }
    #[test]
    fn stats_default() {
        let stats = RelayStats::default();
        assert_eq!(stats.frames_relayed, 0);
        assert_eq!(stats.bytes_relayed, 0);
        assert!(!stats.source_connected);
        assert_eq!(stats.keyframes_sent, 0);
        assert_eq!(stats.signal_diffs_sent, 0);
        assert_eq!(stats.peak_receivers, 0);
        assert_eq!(stats.total_connections, 0);
        assert_eq!(stats.rejected_connections, 0);
    }
    // ─── StateCache behavior ───
    #[test]
    fn state_cache_signal_sync() {
        let mut cache = StateCache::default();
        // Simulate a SignalSync frame
        let sync_json = br#"{"count":0}"#;
        let sync_msg = crate::codec::signal_sync_frame(0, 100, sync_json);
        cache.process_frame(&sync_msg);
        assert!(cache.last_signal_sync.is_some());
        assert!(cache.has_state());
        assert_eq!(cache.pending_signal_diffs.len(), 0);
        // Simulate a SignalDiff
        let diff_json = br#"{"count":1}"#;
        let diff_msg = crate::codec::signal_diff_frame(1, 200, diff_json);
        cache.process_frame(&diff_msg);
        assert_eq!(cache.pending_signal_diffs.len(), 1);
        // Catchup should contain sync + diff
        let catchup = cache.catchup_messages();
        assert_eq!(catchup.len(), 2);
    }
    #[test]
    fn state_cache_signal_sync_resets_diffs() {
        let mut cache = StateCache::default();
        // Add some diffs
        for i in 0..5 {
            let json = format!(r#"{{"count":{}}}"#, i);
            let msg = crate::codec::signal_diff_frame(i, i as u32 * 100, json.as_bytes());
            cache.process_frame(&msg);
        }
        assert_eq!(cache.pending_signal_diffs.len(), 5);
        // A new sync should clear diffs
        let sync = crate::codec::signal_sync_frame(10, 1000, b"{}");
        cache.process_frame(&sync);
        assert_eq!(cache.pending_signal_diffs.len(), 0);
        assert!(cache.last_signal_sync.is_some());
    }
    #[test]
    fn state_cache_keyframe() {
        let mut cache = StateCache::default();
        // Simulate a keyframe pixel frame
        let kf = crate::codec::pixel_frame(0, 100, 10, 10, &vec![0xFF; 400]);
        cache.process_frame(&kf);
        assert!(cache.last_keyframe.is_some());
        assert!(cache.has_state());
        let catchup = cache.catchup_messages();
        assert_eq!(catchup.len(), 1);
    }
    #[test]
    fn state_cache_diff_cap() {
        let mut cache = StateCache::default();
        // Add more than 1000 diffs — should auto-trim
        for i in 0..1100u16 {
            let json = format!(r#"{{"n":{}}}"#, i);
            let msg = crate::codec::signal_diff_frame(i, i as u32, json.as_bytes());
            cache.process_frame(&msg);
        }
        // Should have been trimmed: 1100 - 500 = 600 remaining after first trim at 1001
        assert!(cache.pending_signal_diffs.len() <= 600);
    }
    #[test]
    fn state_cache_clear() {
        let mut cache = StateCache::default();
        let sync = crate::codec::signal_sync_frame(0, 0, b"{}");
        cache.process_frame(&sync);
        assert!(cache.has_state());
        cache.clear();
        assert!(!cache.has_state());
        assert!(cache.last_signal_sync.is_none());
    }
    // ─── Path Routing Tests ───
    #[test]
    fn parse_path_source_default() {
        match parse_path("/source") {
            ConnectionRole::Source(name) => assert_eq!(name, "default"),
            _ => panic!("Expected Source"),
        }
    }
    #[test]
    fn parse_path_source_named() {
        match parse_path("/source/main") {
            ConnectionRole::Source(name) => assert_eq!(name, "main"),
            _ => panic!("Expected Source"),
        }
    }
    #[test]
    fn parse_path_stream_default() {
        match parse_path("/stream") {
            ConnectionRole::Receiver(name) => assert_eq!(name, "default"),
            _ => panic!("Expected Receiver"),
        }
    }
    #[test]
    fn parse_path_stream_named() {
        match parse_path("/stream/player1") {
            ConnectionRole::Receiver(name) => assert_eq!(name, "player1"),
            _ => panic!("Expected Receiver"),
        }
    }
    #[test]
    fn parse_path_signal_default() {
        match parse_path("/signal") {
            ConnectionRole::Signaling(name) => assert_eq!(name, "default"),
            _ => panic!("Expected Signaling"),
        }
    }
    #[test]
    fn parse_path_signal_named() {
        match parse_path("/signal/main") {
            ConnectionRole::Signaling(name) => assert_eq!(name, "main"),
            _ => panic!("Expected Signaling"),
        }
    }
    #[test]
    fn parse_path_legacy_root() {
        // `/` falls through the match's catch-all to a default receiver.
        match parse_path("/") {
            ConnectionRole::Receiver(name) => assert_eq!(name, "default"),
            _ => panic!("Expected Receiver for legacy path"),
        }
    }
    // ─── Channel State Tests ───
    #[test]
    fn channel_state_creation() {
        let mut state = RelayState::new(16, 64, 256);
        let ch1 = state.get_or_create_channel("main").unwrap();
        let ch2 = state.get_or_create_channel("player1").unwrap();
        let ch1_again = state.get_or_create_channel("main").unwrap();
        assert_eq!(state.channels.len(), 2);
        // Same name must return the same underlying channel.
        assert!(Arc::ptr_eq(&ch1, &ch1_again));
        assert!(!Arc::ptr_eq(&ch1, &ch2));
    }
    #[test]
    fn channel_max_limit() {
        let mut state = RelayState::new(16, 64, 2);
        assert!(state.get_or_create_channel("a").is_some());
        assert!(state.get_or_create_channel("b").is_some());
        assert!(state.get_or_create_channel("c").is_none()); // max reached
        assert!(state.get_or_create_channel("a").is_some()); // existing OK
    }
    #[test]
    fn channel_idle_detection() {
        let cs = ChannelState::new(16, 64);
        assert!(cs.is_idle()); // no source, no receivers, no cache
    }
    #[test]
    fn channel_not_idle_with_cache() {
        let mut cs = ChannelState::new(16, 64);
        let sync = crate::codec::signal_sync_frame(0, 0, b"{}");
        cs.cache.process_frame(&sync);
        assert!(!cs.is_idle()); // has cached state
    }
    #[test]
    fn channel_not_idle_with_source() {
        let mut cs = ChannelState::new(16, 64);
        cs.stats.source_connected = true;
        assert!(!cs.is_idle());
    }
    #[test]
    fn channel_not_idle_with_receivers() {
        let mut cs = ChannelState::new(16, 64);
        cs.stats.connected_receivers = 1;
        assert!(!cs.is_idle());
    }
    #[test]
    fn grace_period_not_expired_initially() {
        let cs = ChannelState::new(16, 64);
        assert!(!cs.grace_period_expired(30));
    }
    #[test]
    fn grace_period_expired_after_disconnect() {
        let mut cs = ChannelState::new(16, 64);
        // Backdate the disconnect well past the 30s grace window.
        cs.source_disconnect_time = Some(Instant::now() - Duration::from_secs(60));
        assert!(cs.grace_period_expired(30));
    }
}