// dreamstack/engine/ds-stream/src/relay.rs (503 lines, 17 KiB, Rust)

//! WebSocket Relay Server
//!
//! Routes frames from source→receivers and inputs from receivers→source.
//! Roles are determined by connection order:
//! - First connection = source (frame producer)
//! - Subsequent connections = receivers (frame consumers)
//!
//! Features:
//! - Keyframe caching: late-joining receivers get current state instantly
//! - Signal state store: caches SignalSync/Diff frames for reconstruction
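//! - Keepalive pings: periodically broadcast to receivers while a source is connected
//!   (the configured keepalive timeout is not enforced yet)
//! - Stats tracking: frame, byte, input and receiver counters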
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Instant;
use futures_util::{SinkExt, StreamExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{broadcast, mpsc, RwLock};
use tokio::time::{interval, Duration};
use tokio_tungstenite::tungstenite::Message;
use crate::protocol::*;
/// Relay server configuration.
///
/// Derives `Debug` and `Clone` so a configuration can be logged and reused
/// (for example to start several relays from one template).
#[derive(Debug, Clone)]
pub struct RelayConfig {
    /// Address to bind to.
    pub addr: SocketAddr,
    /// Maximum number of receivers.
    pub max_receivers: usize,
    /// Frame broadcast channel capacity.
    pub frame_buffer_size: usize,
    /// Keepalive interval in seconds.
    pub keepalive_interval_secs: u64,
    /// Keepalive timeout in seconds — disconnect after this many seconds without a pong.
    pub keepalive_timeout_secs: u64,
}
impl Default for RelayConfig {
fn default() -> Self {
Self {
addr: "0.0.0.0:9100".parse().unwrap(),
max_receivers: 64,
frame_buffer_size: 16,
keepalive_interval_secs: 10,
keepalive_timeout_secs: 30,
}
}
}
/// Counters and gauges maintained by the relay while it runs.
#[derive(Clone, Debug, Default)]
pub struct RelayStats {
    /// Number of binary frames received from the source.
    pub frames_relayed: u64,
    /// Sum of the sizes, in bytes, of those frames.
    pub bytes_relayed: u64,
    /// Number of input messages forwarded from receivers to the source.
    pub inputs_relayed: u64,
    /// Receivers currently connected.
    pub connected_receivers: usize,
    /// Whether a source is currently connected.
    pub source_connected: bool,
    /// Timestamp of last frame from source (milliseconds since start).
    pub last_frame_timestamp: u32,
    /// Total keyframes sent.
    pub keyframes_sent: u64,
    /// Total signal diffs sent.
    pub signal_diffs_sent: u64,
    /// Uptime in seconds.
    pub uptime_secs: u64,
}
/// Snapshot of the stream kept so a receiver that joins late can catch up.
#[derive(Clone, Debug, Default)]
pub struct StateCache {
    /// Most recent pixel keyframe (full RGBA frame), if one has been seen.
    pub last_keyframe: Option<Vec<u8>>,
    /// Most recent signal sync frame (full JSON state), if one has been seen.
    pub last_signal_sync: Option<Vec<u8>>,
    /// Signal diffs accumulated since the last sync, oldest first.
    /// A late-joining receiver is sent `last_signal_sync` followed by these.
    pub pending_signal_diffs: Vec<Vec<u8>>,
}
impl StateCache {
/// Process an incoming frame and update cache.
fn process_frame(&mut self, msg: &[u8]) {
if msg.len() < HEADER_SIZE {
return;
}
let frame_type = msg[0];
let flags = msg[1];
match FrameType::from_u8(frame_type) {
// Cache keyframes (pixel or signal sync)
Some(FrameType::Pixels) if flags & FLAG_KEYFRAME != 0 => {
self.last_keyframe = Some(msg.to_vec());
}
Some(FrameType::SignalSync) => {
self.last_signal_sync = Some(msg.to_vec());
self.pending_signal_diffs.clear(); // sync resets diffs
}
Some(FrameType::SignalDiff) => {
self.pending_signal_diffs.push(msg.to_vec());
// Cap accumulated diffs to prevent unbounded memory
if self.pending_signal_diffs.len() > 1000 {
self.pending_signal_diffs.drain(..500);
}
}
Some(FrameType::Keyframe) => {
self.last_keyframe = Some(msg.to_vec());
}
_ => {}
}
}
/// Get all messages a late-joining receiver needs to reconstruct current state.
fn catchup_messages(&self) -> Vec<Vec<u8>> {
let mut msgs = Vec::new();
// Send last signal sync first (if available)
if let Some(ref sync) = self.last_signal_sync {
msgs.push(sync.clone());
}
// Then all accumulated diffs
for diff in &self.pending_signal_diffs {
msgs.push(diff.clone());
}
// Then last pixel keyframe (if available)
if let Some(ref kf) = self.last_keyframe {
msgs.push(kf.clone());
}
msgs
}
}
/// State shared by every connection task, guarded by one `RwLock`.
struct RelayState {
    /// Broadcast sender used to fan frames out from the source to all receivers.
    frame_tx: broadcast::Sender<Vec<u8>>,
    /// Sender cloned by each receiver to push input events towards the source.
    input_tx: mpsc::Sender<Vec<u8>>,
    /// Receiving end of the input channel; taken (`Option::take`) by the source handler.
    input_rx: Option<mpsc::Receiver<Vec<u8>>>,
    /// Live stats.
    stats: RelayStats,
    /// Cached state replayed to late-joining receivers.
    cache: StateCache,
    /// Instant at which the relay state was created; used to report uptime.
    start_time: Instant,
}
/// Run the WebSocket relay server.
///
/// Binds a TCP listener on `config.addr`, prints the startup banner, creates
/// the shared relay state, starts a background task that logs stats every
/// 30 seconds while anything is connected, and then accepts connections
/// forever, spawning one [`handle_connection`] task per socket.
///
/// # Errors
///
/// Returns an error if the listener cannot be bound. A failed `accept` is
/// logged and retried after a short pause rather than ending the server: the
/// previous `while let Ok(..)` loop returned `Ok(())` on the first transient
/// accept error, silently stopping the relay.
pub async fn run_relay(config: RelayConfig) -> Result<(), Box<dyn std::error::Error>> {
    let listener = TcpListener::bind(&config.addr).await?;

    // NOTE: the banner shows `/source` and `/stream` paths, but roles are
    // assigned by connection order in `handle_connection`, not by path.
    eprintln!("╔══════════════════════════════════════════════════╗");
    eprintln!("║ DreamStack Bitstream Relay v0.2.0 ║");
    eprintln!("║ ║");
    eprintln!("║ Source: ws://{}/source ║", config.addr);
    eprintln!("║ Receiver: ws://{}/stream ║", config.addr);
    eprintln!("║ ║");
    eprintln!("║ Features: keyframe cache, keepalive, RLE ║");
    eprintln!("╚══════════════════════════════════════════════════╝");

    // `broadcast::channel` panics on a capacity of zero, so clamp to at least 1.
    let (frame_tx, _) = broadcast::channel(config.frame_buffer_size.max(1));
    let (input_tx, input_rx) = mpsc::channel(256);
    let state = Arc::new(RwLock::new(RelayState {
        frame_tx,
        input_tx,
        input_rx: Some(input_rx),
        stats: RelayStats::default(),
        cache: StateCache::default(),
        start_time: Instant::now(),
    }));

    // Background: periodic stats logging.
    {
        let state = state.clone();
        tokio::spawn(async move {
            let mut tick = interval(Duration::from_secs(30));
            loop {
                tick.tick().await;
                let s = state.read().await;
                let uptime = s.start_time.elapsed().as_secs();
                // Stay quiet while the relay is idle.
                if s.stats.source_connected || s.stats.connected_receivers > 0 {
                    eprintln!(
                        "[relay] up={}s frames={} bytes={} inputs={} receivers={} signal_diffs={} cached={}",
                        uptime,
                        s.stats.frames_relayed,
                        s.stats.bytes_relayed,
                        s.stats.inputs_relayed,
                        s.stats.connected_receivers,
                        s.stats.signal_diffs_sent,
                        s.cache.last_signal_sync.is_some(),
                    );
                }
            }
        });
    }

    let keepalive_interval = config.keepalive_interval_secs;
    let keepalive_timeout = config.keepalive_timeout_secs;
    loop {
        match listener.accept().await {
            Ok((stream, addr)) => {
                tokio::spawn(handle_connection(
                    stream,
                    addr,
                    state.clone(),
                    keepalive_interval,
                    keepalive_timeout,
                ));
            }
            Err(e) => {
                // Accept errors such as an aborted handshake or a temporarily
                // exhausted descriptor table are usually transient. Pause
                // briefly so a persistent error cannot spin the loop.
                eprintln!("[relay] accept failed: {}", e);
                tokio::time::sleep(Duration::from_millis(100)).await;
            }
        }
    }
}
/// Perform the WebSocket handshake on a freshly accepted socket and dispatch
/// it to the source or receiver handler.
///
/// The first connection to arrive while no source is registered becomes the
/// source; every other connection becomes a receiver.
///
/// * `stream` / `addr` — the accepted TCP stream and its peer address.
/// * `state` — shared relay state.
/// * `keepalive_interval` — ping interval in seconds, passed to the handlers.
/// * `_keepalive_timeout` — currently unused.
async fn handle_connection(
    stream: TcpStream,
    addr: SocketAddr,
    state: Arc<RwLock<RelayState>>,
    keepalive_interval: u64,
    _keepalive_timeout: u64,
) {
    let ws_stream = match tokio_tungstenite::accept_hdr_async(
        stream,
        // The request is accepted as-is; the callback does not inspect it.
        |_req: &tokio_tungstenite::tungstenite::handshake::server::Request,
         res: tokio_tungstenite::tungstenite::handshake::server::Response| {
            Ok(res)
        },
    )
    .await
    {
        Ok(ws) => ws,
        Err(e) => {
            eprintln!("[relay] WebSocket handshake failed from {}: {}", addr, e);
            return;
        }
    };

    // First connection = source, subsequent = receivers.
    //
    // The role is claimed under the write lock in a single step. Checking the
    // flag under a read lock and setting it later left a window in which two
    // concurrent connections could both see "no source" and both become the
    // source.
    let is_source = {
        let mut s = state.write().await;
        let already_connected = std::mem::replace(&mut s.stats.source_connected, true);
        !already_connected
    };

    if is_source {
        eprintln!("[relay] Source connected: {}", addr);
        handle_source(ws_stream, addr, state, keepalive_interval).await;
    } else {
        eprintln!("[relay] Receiver connected: {}", addr);
        handle_receiver(ws_stream, addr, state, keepalive_interval).await;
    }
}
async fn handle_source(
ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
addr: SocketAddr,
state: Arc<RwLock<RelayState>>,
keepalive_interval: u64,
) {
let (mut ws_sink, mut ws_source) = ws_stream.split();
// Mark source as connected and take the input_rx
let input_rx = {
let mut s = state.write().await;
s.stats.source_connected = true;
s.input_rx.take()
};
let frame_tx = {
let s = state.read().await;
s.frame_tx.clone()
};
// Forward input events from receivers → source
if let Some(mut input_rx) = input_rx {
let state_clone = state.clone();
tokio::spawn(async move {
while let Some(input_bytes) = input_rx.recv().await {
let msg = Message::Binary(input_bytes.into());
if ws_sink.send(msg).await.is_err() {
break;
}
let mut s = state_clone.write().await;
s.stats.inputs_relayed += 1;
}
});
}
// Keepalive: periodic pings
let state_ping = state.clone();
let ping_task = tokio::spawn(async move {
let mut tick = interval(Duration::from_secs(keepalive_interval));
let mut seq = 0u16;
loop {
tick.tick().await;
let s = state_ping.read().await;
if !s.stats.source_connected {
break;
}
drop(s);
let ping_msg = crate::codec::ping(seq, 0);
let _ = state_ping.read().await.frame_tx.send(ping_msg);
seq = seq.wrapping_add(1);
}
});
// Receive frames from source → broadcast to receivers
while let Some(Ok(msg)) = ws_source.next().await {
if let Message::Binary(data) = msg {
let data_vec: Vec<u8> = data.into();
{
let mut s = state.write().await;
s.stats.frames_relayed += 1;
s.stats.bytes_relayed += data_vec.len() as u64;
// Update state cache
s.cache.process_frame(&data_vec);
// Track frame-type-specific stats
if data_vec.len() >= HEADER_SIZE {
match FrameType::from_u8(data_vec[0]) {
Some(FrameType::SignalDiff) => s.stats.signal_diffs_sent += 1,
Some(FrameType::Pixels) | Some(FrameType::Keyframe) |
Some(FrameType::SignalSync) => s.stats.keyframes_sent += 1,
_ => {}
}
let ts = u32::from_le_bytes([
data_vec[4], data_vec[5], data_vec[6], data_vec[7],
]);
s.stats.last_frame_timestamp = ts;
}
}
// Broadcast to all receivers
let _ = frame_tx.send(data_vec);
}
}
// Source disconnected
ping_task.abort();
eprintln!("[relay] Source disconnected: {}", addr);
let mut s = state.write().await;
s.stats.source_connected = false;
}
/// Serve one receiver connection.
///
/// Registers the receiver, replays the cached catch-up messages, then
/// forwards broadcast frames to the socket from a spawned task while this
/// function relays the receiver's binary messages to the source as inputs.
/// The receiver count is decremented again on every exit path.
///
/// * `ws_stream` — the upgraded WebSocket connection of the receiver.
/// * `addr` — peer address, used for logging only.
/// * `state` — shared relay state.
/// * `_keepalive_interval` — currently unused.
async fn handle_receiver(
    ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
    addr: SocketAddr,
    state: Arc<RwLock<RelayState>>,
    _keepalive_interval: u64,
) {
    let (mut outbound, mut inbound) = ws_stream.split();

    // Register, subscribe and snapshot the cache under one write lock.
    let (mut frames, backlog, inputs) = {
        let mut guard = state.write().await;
        guard.stats.connected_receivers += 1;
        let frames = guard.frame_tx.subscribe();
        let backlog = guard.cache.catchup_messages();
        (frames, backlog, guard.input_tx.clone())
    };

    // Replay cached state so a late joiner can reconstruct the current view.
    let mut caught_up = true;
    if !backlog.is_empty() {
        eprintln!(
            "[relay] Sending {} catchup messages to {}",
            backlog.len(),
            addr
        );
        for bytes in backlog {
            if outbound.send(Message::Binary(bytes.into())).await.is_err() {
                eprintln!("[relay] Failed to send catchup to {}", addr);
                caught_up = false;
                break;
            }
        }
    }

    if caught_up {
        // Broadcast → this receiver.
        let forwarder = tokio::spawn(async move {
            loop {
                let bytes = match frames.recv().await {
                    Ok(bytes) => bytes,
                    Err(broadcast::error::RecvError::Lagged(n)) => {
                        // Skipped frames are lost; keep going with the next one.
                        eprintln!("[relay] Receiver {} lagged by {} frames", addr, n);
                        continue;
                    }
                    Err(_) => break,
                };
                if outbound.send(Message::Binary(bytes.into())).await.is_err() {
                    break;
                }
            }
        });

        // This receiver → source (binary messages only).
        while let Some(Ok(msg)) = inbound.next().await {
            match msg {
                Message::Binary(data) => {
                    let _ = inputs.send(data.into()).await;
                }
                _ => {}
            }
        }

        // Receiver disconnected.
        forwarder.abort();
        eprintln!("[relay] Receiver disconnected: {}", addr);
    }

    let mut guard = state.write().await;
    guard.stats.connected_receivers = guard.stats.connected_receivers.saturating_sub(1);
}
// ─── Tests ───
#[cfg(test)]
mod tests {
    use super::*;

    /// Every default configuration value, including the keepalive timeout.
    #[test]
    fn default_config() {
        let config = RelayConfig::default();
        assert_eq!(config.addr, "0.0.0.0:9100".parse::<SocketAddr>().unwrap());
        assert_eq!(config.max_receivers, 64);
        assert_eq!(config.frame_buffer_size, 16);
        assert_eq!(config.keepalive_interval_secs, 10);
        assert_eq!(config.keepalive_timeout_secs, 30);
    }

    /// Fresh stats start at zero with no source connected.
    #[test]
    fn stats_default() {
        let stats = RelayStats::default();
        assert_eq!(stats.frames_relayed, 0);
        assert_eq!(stats.bytes_relayed, 0);
        assert!(!stats.source_connected);
        assert_eq!(stats.keyframes_sent, 0);
        assert_eq!(stats.signal_diffs_sent, 0);
    }

    /// A sync followed by a diff is replayed as sync first, then the diff.
    #[test]
    fn state_cache_signal_sync() {
        let mut cache = StateCache::default();

        // Simulate a SignalSync frame
        let sync_json = br#"{"count":0}"#;
        let sync_msg = crate::codec::signal_sync_frame(0, 100, sync_json);
        cache.process_frame(&sync_msg);
        assert!(cache.last_signal_sync.is_some());
        assert_eq!(cache.pending_signal_diffs.len(), 0);

        // Simulate a SignalDiff
        let diff_json = br#"{"count":1}"#;
        let diff_msg = crate::codec::signal_diff_frame(1, 200, diff_json);
        cache.process_frame(&diff_msg);
        assert_eq!(cache.pending_signal_diffs.len(), 1);

        // Catchup should contain sync + diff, in that order
        let catchup = cache.catchup_messages();
        assert_eq!(catchup.len(), 2);
        assert_eq!(&catchup[0][..], &sync_msg[..]);
        assert_eq!(&catchup[1][..], &diff_msg[..]);
    }

    /// A new sync discards all diffs accumulated before it.
    #[test]
    fn state_cache_signal_sync_resets_diffs() {
        let mut cache = StateCache::default();

        // Add some diffs
        for i in 0..5 {
            let json = format!(r#"{{"count":{}}}"#, i);
            let msg = crate::codec::signal_diff_frame(i, i as u32 * 100, json.as_bytes());
            cache.process_frame(&msg);
        }
        assert_eq!(cache.pending_signal_diffs.len(), 5);

        // A new sync should clear diffs
        let sync = crate::codec::signal_sync_frame(10, 1000, b"{}");
        cache.process_frame(&sync);
        assert_eq!(cache.pending_signal_diffs.len(), 0);
        assert!(cache.last_signal_sync.is_some());
    }

    /// A pixel keyframe is cached and is the only catch-up message.
    #[test]
    fn state_cache_keyframe() {
        let mut cache = StateCache::default();

        // Simulate a keyframe pixel frame
        let kf = crate::codec::pixel_frame(0, 100, 10, 10, &vec![0xFF; 400]);
        cache.process_frame(&kf);
        assert!(cache.last_keyframe.is_some());

        let catchup = cache.catchup_messages();
        assert_eq!(catchup.len(), 1);
    }

    /// The diff list is trimmed once, when the 1001st diff arrives.
    #[test]
    fn state_cache_diff_cap() {
        let mut cache = StateCache::default();

        // Add more than 1000 diffs — should auto-trim
        for i in 0..1100u16 {
            let json = format!(r#"{{"n":{}}}"#, i);
            let msg = crate::codec::signal_diff_frame(i, i as u32, json.as_bytes());
            cache.process_frame(&msg);
        }

        // The 1001st push drops the oldest 500 (leaving 501); the remaining
        // 99 pushes bring the total to exactly 600. An exact check also
        // catches a cache that trims too aggressively.
        assert_eq!(cache.pending_signal_diffs.len(), 600);
    }
}