dreamstack/engine/ds-stream/src/relay.rs

643 lines
22 KiB
Rust
Raw Normal View History

//! WebSocket Relay Server
//!
//! Routes frames from source→receivers and inputs from receivers→source.
//!
//! ## Routing
//!
//! Connections are routed by WebSocket URI path:
//! - `/source` — default source (channel: "default")
//! - `/source/{name}` — named source (channel: {name})
//! - `/stream` — default receiver (channel: "default")
//! - `/stream/{name}` — named receiver (channel: {name})
//! - `/` (or any unrecognized path) — legacy, channel "default": a connection becomes the source if no source is connected yet, otherwise a receiver
//!
//! Each channel has its own broadcast/input channels and state cache,
//! allowing multiple independent streams through a single relay.
//!
//! ## Features
//! - Multi-source: multiple views stream independently
//! - Keyframe caching: late-joining receivers get current state instantly
//! - Signal state store: caches SignalSync/Diff frames for reconstruction
//! - Keepalive: periodic ping frames are broadcast to each channel's receivers
//! - Stats tracking: frames, bytes, latency metrics
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Instant;
use futures_util::{SinkExt, StreamExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{broadcast, mpsc, RwLock};
use tokio::time::{interval, Duration};
use tokio_tungstenite::tungstenite::Message;
use crate::protocol::*;
/// Configuration for the relay server started by `run_relay`.
pub struct RelayConfig {
    /// Socket address the relay listens on.
    pub addr: SocketAddr,
    /// Intended upper bound on receivers per channel.
    ///
    /// NOTE(review): the connection handlers in this file do not currently enforce it.
    pub max_receivers: usize,
    /// Capacity, in frames, of each channel's source → receivers broadcast buffer.
    pub frame_buffer_size: usize,
    /// Seconds between keepalive pings.
    pub keepalive_interval_secs: u64,
    /// Seconds without a pong before a connection should be dropped.
    ///
    /// NOTE(review): passed to the connection handler but not acted on there yet.
    pub keepalive_timeout_secs: u64,
}

impl Default for RelayConfig {
    /// All interfaces on port 9100, 64 receivers per channel, a 16-frame broadcast
    /// buffer, and a 10 s ping interval with a 30 s timeout.
    fn default() -> Self {
        let addr = SocketAddr::from(([0, 0, 0, 0], 9100));
        Self {
            addr,
            max_receivers: 64,
            frame_buffer_size: 16,
            keepalive_interval_secs: 10,
            keepalive_timeout_secs: 30,
        }
    }
}
/// Stats tracked by the relay.
///
/// Each channel owns one instance; the source and receiver handlers update it
/// while holding that channel's write lock.
#[derive(Debug, Default, Clone)]
pub struct RelayStats {
    /// Binary messages received from the source and broadcast to receivers.
    pub frames_relayed: u64,
    /// Total length in bytes of those messages.
    pub bytes_relayed: u64,
    /// Input events successfully written from receivers to the source socket.
    pub inputs_relayed: u64,
    /// Receivers currently connected to this channel.
    pub connected_receivers: usize,
    /// Whether a source is currently connected to this channel.
    pub source_connected: bool,
    /// Timestamp field of the last frame from the source, read as a little-endian
    /// `u32` from header bytes 4..8.
    /// NOTE(review): presumably milliseconds since stream start — confirm against the protocol.
    pub last_frame_timestamp: u32,
    /// Total keyframe-type frames received from the source.
    pub keyframes_sent: u64,
    /// Total signal diffs sent.
    pub signal_diffs_sent: u64,
    /// Uptime in seconds.
    /// NOTE(review): not written by any code in this file; the periodic stats log
    /// derives uptime from `RelayState::start_time` instead.
    pub uptime_secs: u64,
}
/// Cached state for late-joining receivers.
///
/// Updated by `process_frame` for every frame from the source and replayed to a new
/// receiver by `catchup_messages` (signal sync, then diffs, then the pixel keyframe).
/// All entries are stored as complete frames, header included.
#[derive(Debug, Default, Clone)]
pub struct StateCache {
    /// Most recent keyframe: a `Keyframe` frame, or a `Pixels` frame with `FLAG_KEYFRAME` set.
    pub last_keyframe: Option<Vec<u8>>,
    /// Most recent `SignalSync` frame (full signal state).
    pub last_signal_sync: Option<Vec<u8>>,
    /// `SignalDiff` frames received since `last_signal_sync`, oldest first.
    /// Late-joining receivers get: last_signal_sync + all diffs.
    /// Cleared by every new sync; trimmed once it grows past 1000 entries.
    pub pending_signal_diffs: Vec<Vec<u8>>,
}
impl StateCache {
/// Process an incoming frame and update cache.
fn process_frame(&mut self, msg: &[u8]) {
if msg.len() < HEADER_SIZE {
return;
}
let frame_type = msg[0];
let flags = msg[1];
match FrameType::from_u8(frame_type) {
// Cache keyframes (pixel or signal sync)
Some(FrameType::Pixels) if flags & FLAG_KEYFRAME != 0 => {
self.last_keyframe = Some(msg.to_vec());
}
Some(FrameType::SignalSync) => {
self.last_signal_sync = Some(msg.to_vec());
self.pending_signal_diffs.clear(); // sync resets diffs
}
Some(FrameType::SignalDiff) => {
self.pending_signal_diffs.push(msg.to_vec());
// Cap accumulated diffs to prevent unbounded memory
if self.pending_signal_diffs.len() > 1000 {
self.pending_signal_diffs.drain(..500);
}
}
Some(FrameType::Keyframe) => {
self.last_keyframe = Some(msg.to_vec());
}
_ => {}
}
}
/// Get all messages a late-joining receiver needs to reconstruct current state.
fn catchup_messages(&self) -> Vec<Vec<u8>> {
let mut msgs = Vec::new();
// Send last signal sync first (if available)
if let Some(ref sync) = self.last_signal_sync {
msgs.push(sync.clone());
}
// Then all accumulated diffs
for diff in &self.pending_signal_diffs {
msgs.push(diff.clone());
}
// Then last pixel keyframe (if available)
if let Some(ref kf) = self.last_keyframe {
msgs.push(kf.clone());
}
msgs
}
}
/// Per-channel state — each named channel is an independent stream.
///
/// Shared by the source and receiver handlers of one channel as an
/// `Arc<RwLock<ChannelState>>`, owned by `RelayState::channels`.
struct ChannelState {
    /// Broadcast channel: source → all receivers (frames).
    /// Receivers `subscribe()` when they connect; the source handler sends each frame here.
    frame_tx: broadcast::Sender<Vec<u8>>,
    /// Channel: receivers → source (input events).
    /// Each receiver clones this sender when it connects.
    input_tx: mpsc::Sender<Vec<u8>>,
    /// Receiving half of the input queue; `handle_source` `take()`s it when a source
    /// connects, so it is `None` once a source has taken it.
    input_rx: Option<mpsc::Receiver<Vec<u8>>>,
    /// Live stats for this channel
    stats: RelayStats,
    /// Cached state for late-joining receivers
    cache: StateCache,
}
impl ChannelState {
    /// Capacity of the receivers → source input-event queue.
    const INPUT_QUEUE_CAPACITY: usize = 256;

    /// Create an empty channel whose frame broadcast buffer holds `frame_buffer_size` frames.
    ///
    /// A `frame_buffer_size` of 0 is treated as 1: `tokio::sync::broadcast::channel`
    /// panics on a zero capacity. Without the clamp, a misconfigured `RelayConfig`
    /// would make every connection panic while creating its channel.
    fn new(frame_buffer_size: usize) -> Self {
        let (frame_tx, _) = broadcast::channel(frame_buffer_size.max(1));
        let (input_tx, input_rx) = mpsc::channel(Self::INPUT_QUEUE_CAPACITY);
        Self {
            frame_tx,
            input_tx,
            input_rx: Some(input_rx),
            stats: RelayStats::default(),
            cache: StateCache::default(),
        }
    }
}
/// Shared relay state — holds all channels.
///
/// Held by `run_relay` as `Arc<RwLock<RelayState>>`; connection handlers take the
/// write lock only to look up or create their channel.
struct RelayState {
    /// Named channels: "default", "main", "player1", etc.
    /// Channels are created on first use and never removed in this file.
    channels: HashMap<String, Arc<RwLock<ChannelState>>>,
    /// Frame buffer size for new channels
    frame_buffer_size: usize,
    /// Server start time; the periodic stats log reports uptime relative to it.
    start_time: Instant,
}
impl RelayState {
    /// Empty channel registry whose new channels buffer `frame_buffer_size` frames.
    /// The server start time is captured here.
    fn new(frame_buffer_size: usize) -> Self {
        let start_time = Instant::now();
        Self {
            frame_buffer_size,
            start_time,
            channels: HashMap::new(),
        }
    }

    /// Shared handle to the channel called `name`, creating it if this is the first
    /// request for that name. Repeated calls with one name return the same `Arc`.
    fn get_or_create_channel(&mut self, name: &str) -> Arc<RwLock<ChannelState>> {
        let buffer_len = self.frame_buffer_size;
        let slot = self
            .channels
            .entry(name.to_owned())
            .or_insert_with(|| Arc::new(RwLock::new(ChannelState::new(buffer_len))));
        Arc::clone(slot)
    }
}
/// Parsed connection role from the WebSocket URI path.
#[derive(Debug, Clone, PartialEq, Eq)]
enum ConnectionRole {
    /// Publishes frames into the named channel.
    Source(String),
    /// Subscribes to frames from the named channel.
    Receiver(String),
}

/// Parse the WebSocket URI path to determine connection role and channel.
///
/// - `/source` and `/source/{name}` map to a source.
/// - `/stream` and `/stream/{name}` map to a receiver.
/// - Anything else, including `/`, maps to a receiver on `"default"`.
///
/// Leading and trailing slashes are ignored, and an empty channel name means
/// `"default"`. So `/source/` joins the default channel rather than a channel
/// named `""`, and `/stream/x/` joins `"x"` rather than `"x/"`.
fn parse_path(path: &str) -> ConnectionRole {
    const DEFAULT_CHANNEL: &str = "default";
    let trimmed = path.trim_matches('/');
    let (route, name) = trimmed.split_once('/').unwrap_or((trimmed, ""));
    let name = name.trim_matches('/');
    let channel = if name.is_empty() { DEFAULT_CHANNEL } else { name };
    match route {
        "source" => ConnectionRole::Source(channel.to_string()),
        "stream" => ConnectionRole::Receiver(channel.to_string()),
        // Legacy `/` (and any unrecognized route): receiver on the default channel.
        _ => ConnectionRole::Receiver(DEFAULT_CHANNEL.to_string()),
    }
}
/// Run the WebSocket relay server.
pub async fn run_relay(config: RelayConfig) -> Result<(), Box<dyn std::error::Error>> {
let listener = TcpListener::bind(&config.addr).await?;
eprintln!("╔══════════════════════════════════════════════════╗");
eprintln!("║ DreamStack Bitstream Relay v0.3.0 ║");
eprintln!("║ ║");
eprintln!("║ Source: ws://{}/source/{{name}}", config.addr);
eprintln!("║ Receiver: ws://{}/stream/{{name}}", config.addr);
eprintln!("║ ║");
eprintln!("║ Multi-source, keyframe cache, keepalive, RLE ║");
eprintln!("╚══════════════════════════════════════════════════╝");
let state = Arc::new(RwLock::new(RelayState::new(config.frame_buffer_size)));
// Background: periodic stats logging
{
let state = state.clone();
tokio::spawn(async move {
let mut tick = interval(Duration::from_secs(30));
loop {
tick.tick().await;
let s = state.read().await;
let uptime = s.start_time.elapsed().as_secs();
for (name, channel) in &s.channels {
let cs = channel.read().await;
if cs.stats.source_connected || cs.stats.connected_receivers > 0 {
eprintln!(
"[relay:{name}] up={uptime}s frames={} bytes={} inputs={} receivers={} signal_diffs={} cached={}",
cs.stats.frames_relayed,
cs.stats.bytes_relayed,
cs.stats.inputs_relayed,
cs.stats.connected_receivers,
cs.stats.signal_diffs_sent,
cs.cache.last_signal_sync.is_some(),
);
}
}
}
});
}
while let Ok((stream, addr)) = listener.accept().await {
let state = state.clone();
let keepalive_interval = config.keepalive_interval_secs;
let keepalive_timeout = config.keepalive_timeout_secs;
tokio::spawn(handle_connection(stream, addr, state, keepalive_interval, keepalive_timeout));
}
Ok(())
}
/// Handle one accepted TCP connection.
///
/// Performs the WebSocket handshake, resolves the channel from the request path
/// via `parse_path`, and dispatches to `handle_source` or `handle_receiver`.
///
/// A legacy connection is one whose first path segment is neither `source` nor
/// `stream`, e.g. `/`. It becomes the source of its channel if no source is connected
/// at that moment. The check and the claim happen under one channel write lock, so two
/// simultaneous legacy connections cannot both become the source. Explicit
/// `/stream...` connections are never promoted.
///
/// `_keepalive_timeout` is accepted for interface stability but is currently unused.
async fn handle_connection(
    stream: TcpStream,
    addr: SocketAddr,
    state: Arc<RwLock<RelayState>>,
    keepalive_interval: u64,
    _keepalive_timeout: u64,
) {
    // Extract the URI path during handshake to determine role
    let mut role = ConnectionRole::Receiver("default".to_string());
    let mut is_legacy = true;
    let ws_stream = match tokio_tungstenite::accept_hdr_async(
        stream,
        |req: &tokio_tungstenite::tungstenite::handshake::server::Request,
         res: tokio_tungstenite::tungstenite::handshake::server::Response| {
            let path = req.uri().path();
            role = parse_path(path);
            is_legacy = is_legacy_path(path);
            Ok(res)
        },
    )
    .await
    {
        Ok(ws) => ws,
        Err(e) => {
            eprintln!("[relay] WebSocket handshake failed from {}: {}", addr, e);
            return;
        }
    };
    // Get or create the channel
    let (channel, channel_name) = {
        let mut s = state.write().await;
        let name = match &role {
            ConnectionRole::Source(n) | ConnectionRole::Receiver(n) => n.clone(),
        };
        let ch = s.get_or_create_channel(&name);
        (ch, name)
    };
    // Legacy fallback: a legacy-path connection becomes the source if none exists.
    let role = match role {
        ConnectionRole::Receiver(_) if is_legacy => {
            let mut cs = channel.write().await;
            if cs.stats.source_connected {
                role
            } else {
                // Claim the slot now, under the same lock as the check.
                cs.stats.source_connected = true;
                ConnectionRole::Source(channel_name.clone())
            }
        }
        other => other,
    };
    match role {
        ConnectionRole::Source(_) => {
            eprintln!("[relay:{channel_name}] Source connected: {addr}");
            handle_source(ws_stream, addr, channel, &channel_name, keepalive_interval).await;
        }
        ConnectionRole::Receiver(_) => {
            eprintln!("[relay:{channel_name}] Receiver connected: {addr}");
            handle_receiver(ws_stream, addr, channel, &channel_name).await;
        }
    }
}

/// Returns `true` when `path`'s first segment is neither `source` nor `stream`,
/// i.e. the request falls through to the legacy routing rule.
fn is_legacy_path(path: &str) -> bool {
    let first_segment = path.trim_start_matches('/').split('/').next().unwrap_or("");
    !matches!(first_segment, "source" | "stream")
}
async fn handle_source(
ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
addr: SocketAddr,
channel: Arc<RwLock<ChannelState>>,
channel_name: &str,
keepalive_interval: u64,
) {
let (mut ws_sink, mut ws_source) = ws_stream.split();
// Mark source as connected and take the input_rx
let input_rx = {
let mut cs = channel.write().await;
cs.stats.source_connected = true;
cs.input_rx.take()
};
let frame_tx = {
let cs = channel.read().await;
cs.frame_tx.clone()
};
// Forward input events from receivers → source
if let Some(mut input_rx) = input_rx {
let channel_clone = channel.clone();
tokio::spawn(async move {
while let Some(input_bytes) = input_rx.recv().await {
let msg = Message::Binary(input_bytes.into());
if ws_sink.send(msg).await.is_err() {
break;
}
let mut cs = channel_clone.write().await;
cs.stats.inputs_relayed += 1;
}
});
}
// Keepalive: periodic pings
let channel_ping = channel.clone();
let ping_task = tokio::spawn(async move {
let mut tick = interval(Duration::from_secs(keepalive_interval));
let mut seq = 0u16;
loop {
tick.tick().await;
let cs = channel_ping.read().await;
if !cs.stats.source_connected {
break;
}
let _ = cs.frame_tx.send(crate::codec::ping(seq, 0));
drop(cs);
seq = seq.wrapping_add(1);
}
});
// Receive frames from source → broadcast to receivers
let channel_name_owned = channel_name.to_string();
while let Some(Ok(msg)) = ws_source.next().await {
if let Message::Binary(data) = msg {
let data_vec: Vec<u8> = data.into();
{
let mut cs = channel.write().await;
cs.stats.frames_relayed += 1;
cs.stats.bytes_relayed += data_vec.len() as u64;
// Update state cache
cs.cache.process_frame(&data_vec);
// Track frame-type-specific stats
if data_vec.len() >= HEADER_SIZE {
match FrameType::from_u8(data_vec[0]) {
Some(FrameType::SignalDiff) => cs.stats.signal_diffs_sent += 1,
Some(FrameType::Pixels) | Some(FrameType::Keyframe) |
Some(FrameType::SignalSync) => cs.stats.keyframes_sent += 1,
_ => {}
}
let ts = u32::from_le_bytes([
data_vec[4], data_vec[5], data_vec[6], data_vec[7],
]);
cs.stats.last_frame_timestamp = ts;
}
}
// Broadcast to all receivers on this channel
let _ = frame_tx.send(data_vec);
}
}
// Source disconnected
ping_task.abort();
eprintln!("[relay:{channel_name_owned}] Source disconnected: {addr}");
let mut cs = channel.write().await;
cs.stats.source_connected = false;
}
/// Serve one receiver connection on `channel` until its socket closes.
///
/// The handler works in three steps:
/// 1. It registers the receiver, subscribes to the frame broadcast and snapshots the
///    catch-up messages under one write lock. No frame can fall between the snapshot
///    and the live subscription.
/// 2. It sends the catch-up messages (cached sync, diffs, keyframe).
/// 3. It spawns a task that forwards live frames to the socket. It then forwards binary
///    messages from the socket into the source's input queue until the socket closes.
///
/// Inputs are enqueued with `try_send`. If the queue is full (e.g. no source is
/// draining it) the event is dropped instead of blocking this read loop. A blocked
/// read loop would stop the relay from noticing the receiver disconnecting, leaking
/// the connection and its receiver count. Dropped events are reported on disconnect.
async fn handle_receiver(
    ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
    addr: SocketAddr,
    channel: Arc<RwLock<ChannelState>>,
    channel_name: &str,
) {
    let (mut ws_sink, mut ws_source) = ws_stream.split();
    // Subscribe to frame broadcast, snapshot catch-up state and grab the input sender.
    let (mut frame_rx, catchup_msgs, input_tx) = {
        let mut cs = channel.write().await;
        cs.stats.connected_receivers += 1;
        let rx = cs.frame_tx.subscribe();
        let catchup = cs.cache.catchup_messages();
        (rx, catchup, cs.input_tx.clone())
    };
    // Send cached state to late-joining receiver
    let channel_name_owned = channel_name.to_string();
    if !catchup_msgs.is_empty() {
        eprintln!(
            "[relay:{channel_name_owned}] Sending {} catchup messages to {addr}",
            catchup_msgs.len(),
        );
        for msg_bytes in catchup_msgs {
            let msg = Message::Binary(msg_bytes.into());
            if ws_sink.send(msg).await.is_err() {
                eprintln!("[relay:{channel_name_owned}] Failed to send catchup to {addr}");
                let mut cs = channel.write().await;
                cs.stats.connected_receivers = cs.stats.connected_receivers.saturating_sub(1);
                return;
            }
        }
    }
    // Forward frames from broadcast → this receiver
    let send_task = tokio::spawn(async move {
        loop {
            match frame_rx.recv().await {
                Ok(frame_bytes) => {
                    let msg = Message::Binary(frame_bytes.into());
                    if ws_sink.send(msg).await.is_err() {
                        break;
                    }
                }
                Err(broadcast::error::RecvError::Lagged(n)) => {
                    eprintln!("[relay] Receiver {addr} lagged by {n} frames");
                }
                Err(_) => break,
            }
        }
    });
    // Forward input events from this receiver → source without ever blocking the
    // read loop (see the function docs).
    let mut dropped_inputs: u64 = 0;
    while let Some(Ok(msg)) = ws_source.next().await {
        if let Message::Binary(data) = msg {
            if input_tx.try_send(data.into()).is_err() {
                dropped_inputs += 1;
            }
        }
    }
    // Receiver disconnected
    send_task.abort();
    if dropped_inputs > 0 {
        eprintln!(
            "[relay:{channel_name_owned}] Dropped {dropped_inputs} input events from {addr} (input queue full or closed)"
        );
    }
    eprintln!("[relay:{channel_name_owned}] Receiver disconnected: {addr}");
    let mut cs = channel.write().await;
    cs.stats.connected_receivers = cs.stats.connected_receivers.saturating_sub(1);
}
// ─── Tests ───
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config() {
        let config = RelayConfig::default();
        assert_eq!(config.addr, "0.0.0.0:9100".parse::<SocketAddr>().unwrap());
        assert_eq!(config.max_receivers, 64);
        assert_eq!(config.frame_buffer_size, 16);
        assert_eq!(config.keepalive_interval_secs, 10);
        assert_eq!(config.keepalive_timeout_secs, 30);
    }

    #[test]
    fn stats_default() {
        let stats = RelayStats::default();
        assert_eq!(stats.frames_relayed, 0);
        assert_eq!(stats.bytes_relayed, 0);
        assert!(!stats.source_connected);
        assert_eq!(stats.keyframes_sent, 0);
        assert_eq!(stats.signal_diffs_sent, 0);
    }

    #[test]
    fn state_cache_signal_sync() {
        let mut cache = StateCache::default();
        // Simulate a SignalSync frame
        let sync_json = br#"{"count":0}"#;
        let sync_msg = crate::codec::signal_sync_frame(0, 100, sync_json);
        cache.process_frame(&sync_msg);
        assert!(cache.last_signal_sync.is_some());
        assert_eq!(cache.pending_signal_diffs.len(), 0);
        // Simulate a SignalDiff
        let diff_json = br#"{"count":1}"#;
        let diff_msg = crate::codec::signal_diff_frame(1, 200, diff_json);
        cache.process_frame(&diff_msg);
        assert_eq!(cache.pending_signal_diffs.len(), 1);
        // Catchup should contain sync + diff
        let catchup = cache.catchup_messages();
        assert_eq!(catchup.len(), 2);
    }

    #[test]
    fn state_cache_signal_sync_resets_diffs() {
        let mut cache = StateCache::default();
        // Add some diffs
        for i in 0..5 {
            let json = format!(r#"{{"count":{}}}"#, i);
            let msg = crate::codec::signal_diff_frame(i, i as u32 * 100, json.as_bytes());
            cache.process_frame(&msg);
        }
        assert_eq!(cache.pending_signal_diffs.len(), 5);
        // A new sync should clear diffs
        let sync = crate::codec::signal_sync_frame(10, 1000, b"{}");
        cache.process_frame(&sync);
        assert_eq!(cache.pending_signal_diffs.len(), 0);
        assert!(cache.last_signal_sync.is_some());
    }

    #[test]
    fn state_cache_keyframe() {
        let mut cache = StateCache::default();
        // Simulate a keyframe pixel frame
        let kf = crate::codec::pixel_frame(0, 100, 10, 10, &vec![0xFF; 400]);
        cache.process_frame(&kf);
        assert!(cache.last_keyframe.is_some());
        let catchup = cache.catchup_messages();
        assert_eq!(catchup.len(), 1);
    }

    #[test]
    fn state_cache_catchup_order() {
        // Replay order must be: sync, then diffs, then the pixel keyframe —
        // regardless of the order the frames arrived in.
        let mut cache = StateCache::default();
        let kf = crate::codec::pixel_frame(0, 100, 10, 10, &vec![0xFF; 400]);
        let sync = crate::codec::signal_sync_frame(1, 200, b"{}");
        let diff = crate::codec::signal_diff_frame(2, 300, br#"{"n":1}"#);
        cache.process_frame(&kf);
        cache.process_frame(&sync);
        cache.process_frame(&diff);
        assert_eq!(
            cache.catchup_messages(),
            vec![sync.to_vec(), diff.to_vec(), kf.to_vec()]
        );
    }

    #[test]
    fn state_cache_diff_cap() {
        let mut cache = StateCache::default();
        // Add more than 1000 diffs — should auto-trim
        for i in 0..1100u16 {
            let json = format!(r#"{{"n":{}}}"#, i);
            let msg = crate::codec::signal_diff_frame(i, i as u32, json.as_bytes());
            cache.process_frame(&msg);
        }
        // The 1001st diff triggers a trim of the 500 oldest (leaving 501); the
        // remaining 99 pushes stay under the cap, so exactly 600 remain.
        assert_eq!(cache.pending_signal_diffs.len(), 600);
    }

    // ─── Path Routing Tests ───

    #[test]
    fn parse_path_source_default() {
        assert!(
            matches!(parse_path("/source"), ConnectionRole::Source(ref name) if name == "default"),
            "Expected Source"
        );
    }

    #[test]
    fn parse_path_source_named() {
        assert!(
            matches!(parse_path("/source/main"), ConnectionRole::Source(ref name) if name == "main"),
            "Expected Source"
        );
    }

    #[test]
    fn parse_path_stream_default() {
        assert!(
            matches!(parse_path("/stream"), ConnectionRole::Receiver(ref name) if name == "default"),
            "Expected Receiver"
        );
    }

    #[test]
    fn parse_path_stream_named() {
        assert!(
            matches!(parse_path("/stream/player1"), ConnectionRole::Receiver(ref name) if name == "player1"),
            "Expected Receiver"
        );
    }

    #[test]
    fn parse_path_legacy_root() {
        assert!(
            matches!(parse_path("/"), ConnectionRole::Receiver(ref name) if name == "default"),
            "Expected Receiver for legacy path"
        );
    }

    #[test]
    fn channel_state_creation() {
        let mut state = RelayState::new(16);
        let ch1 = state.get_or_create_channel("main");
        let ch2 = state.get_or_create_channel("player1");
        let ch1_again = state.get_or_create_channel("main");
        assert_eq!(state.channels.len(), 2);
        assert!(Arc::ptr_eq(&ch1, &ch1_again));
        assert!(!Arc::ptr_eq(&ch1, &ch2));
    }
}