feat(relay): multi-source routing — /source/{name} and /stream/{name}

- ChannelState: per-channel broadcast, input mpsc, cache
- RelayState: HashMap<String, Arc<RwLock<ChannelState>>>
- parse_path: /source/{n}, /stream/{n}, legacy / fallback
- Per-channel stats logging with channel name prefix
- Backward compatible: plain /source and /stream use 'default' channel
- 44 ds-stream tests, 95 total workspace (0 failures)
This commit is contained in:
enzotar 2026-02-25 14:50:39 -08:00
parent 2b2b4ffaec
commit 0ed76513a6

View file

@ -1,16 +1,27 @@
//! WebSocket Relay Server //! WebSocket Relay Server
//! //!
//! Routes frames from source→receivers and inputs from receivers→source. //! Routes frames from source→receivers and inputs from receivers→source.
//! Roles are determined by connection order:
//! - First connection = source (frame producer)
//! - Subsequent connections = receivers (frame consumers)
//! //!
//! Features: //! ## Routing
//!
//! Connections are routed by WebSocket URI path:
//! - `/source` — default source (channel: "default")
//! - `/source/{name}` — named source (channel: {name})
//! - `/stream` — default receiver (channel: "default")
//! - `/stream/{name}` — named receiver (channel: {name})
//! - `/` — legacy: first connection = source, rest = receivers
//!
//! Each channel has its own broadcast/input channels and state cache,
//! allowing multiple independent streams through a single relay.
//!
//! ## Features
//! - Multi-source: multiple views stream independently
//! - Keyframe caching: late-joining receivers get current state instantly //! - Keyframe caching: late-joining receivers get current state instantly
//! - Signal state store: caches SignalSync/Diff frames for reconstruction //! - Signal state store: caches SignalSync/Diff frames for reconstruction
//! - Ping/pong keepalive: detects dead connections //! - Ping/pong keepalive: detects dead connections
//! - Stats tracking: frames, bytes, latency metrics //! - Stats tracking: frames, bytes, latency metrics
use std::collections::HashMap;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
@ -27,7 +38,7 @@ use crate::protocol::*;
pub struct RelayConfig { pub struct RelayConfig {
/// Address to bind to. /// Address to bind to.
pub addr: SocketAddr, pub addr: SocketAddr,
/// Maximum number of receivers. /// Maximum number of receivers per channel.
pub max_receivers: usize, pub max_receivers: usize,
/// Frame broadcast channel capacity. /// Frame broadcast channel capacity.
pub frame_buffer_size: usize, pub frame_buffer_size: usize,
@ -130,44 +141,93 @@ impl StateCache {
} }
} }
/// Shared relay state. /// Per-channel state — each named channel is an independent stream.
struct RelayState { struct ChannelState {
/// Broadcast channel: source → all receivers (frames) /// Broadcast channel: source → all receivers (frames)
frame_tx: broadcast::Sender<Vec<u8>>, frame_tx: broadcast::Sender<Vec<u8>>,
/// Channel: receivers → source (input events) /// Channel: receivers → source (input events)
input_tx: mpsc::Sender<Vec<u8>>, input_tx: mpsc::Sender<Vec<u8>>,
input_rx: Option<mpsc::Receiver<Vec<u8>>>, input_rx: Option<mpsc::Receiver<Vec<u8>>>,
/// Live stats /// Live stats for this channel
stats: RelayStats, stats: RelayStats,
/// Cached state for late-joining receivers /// Cached state for late-joining receivers
cache: StateCache, cache: StateCache,
}
impl ChannelState {
fn new(frame_buffer_size: usize) -> Self {
let (frame_tx, _) = broadcast::channel(frame_buffer_size);
let (input_tx, input_rx) = mpsc::channel(256);
Self {
frame_tx,
input_tx,
input_rx: Some(input_rx),
stats: RelayStats::default(),
cache: StateCache::default(),
}
}
}
/// Shared relay state — holds all channels.
struct RelayState {
/// Named channels: "default", "main", "player1", etc.
channels: HashMap<String, Arc<RwLock<ChannelState>>>,
/// Frame buffer size for new channels
frame_buffer_size: usize,
/// Server start time /// Server start time
start_time: Instant, start_time: Instant,
} }
impl RelayState {
fn new(frame_buffer_size: usize) -> Self {
Self {
channels: HashMap::new(),
frame_buffer_size,
start_time: Instant::now(),
}
}
/// Get or create a channel by name.
fn get_or_create_channel(&mut self, name: &str) -> Arc<RwLock<ChannelState>> {
self.channels
.entry(name.to_string())
.or_insert_with(|| Arc::new(RwLock::new(ChannelState::new(self.frame_buffer_size))))
.clone()
}
}
/// Parsed connection role from the WebSocket URI path.
#[derive(Debug, Clone)]
enum ConnectionRole {
Source(String), // channel name
Receiver(String), // channel name
}
/// Parse the WebSocket URI path to determine connection role and channel.
fn parse_path(path: &str) -> ConnectionRole {
let parts: Vec<&str> = path.trim_start_matches('/').splitn(2, '/').collect();
match parts.as_slice() {
["source"] => ConnectionRole::Source("default".to_string()),
["source", name] => ConnectionRole::Source(name.to_string()),
["stream"] => ConnectionRole::Receiver("default".to_string()),
["stream", name] => ConnectionRole::Receiver(name.to_string()),
_ => ConnectionRole::Receiver("default".to_string()), // legacy: `/` = receiver
}
}
/// Run the WebSocket relay server. /// Run the WebSocket relay server.
pub async fn run_relay(config: RelayConfig) -> Result<(), Box<dyn std::error::Error>> { pub async fn run_relay(config: RelayConfig) -> Result<(), Box<dyn std::error::Error>> {
let listener = TcpListener::bind(&config.addr).await?; let listener = TcpListener::bind(&config.addr).await?;
eprintln!("╔══════════════════════════════════════════════════╗"); eprintln!("╔══════════════════════════════════════════════════╗");
eprintln!("║ DreamStack Bitstream Relay v0.2.0 ║"); eprintln!("║ DreamStack Bitstream Relay v0.3.0 ║");
eprintln!("║ ║"); eprintln!("║ ║");
eprintln!("║ Source: ws://{}/source ║", config.addr); eprintln!("║ Source: ws://{}/source/{{name}}", config.addr);
eprintln!("║ Receiver: ws://{}/stream ║", config.addr); eprintln!("║ Receiver: ws://{}/stream/{{name}}", config.addr);
eprintln!("║ ║"); eprintln!("║ ║");
eprintln!("║ Features: keyframe cache, keepalive, RLE ║"); eprintln!("Multi-source, keyframe cache, keepalive, RLE");
eprintln!("╚══════════════════════════════════════════════════╝"); eprintln!("╚══════════════════════════════════════════════════╝");
let (frame_tx, _) = broadcast::channel(config.frame_buffer_size); let state = Arc::new(RwLock::new(RelayState::new(config.frame_buffer_size)));
let (input_tx, input_rx) = mpsc::channel(256);
let state = Arc::new(RwLock::new(RelayState {
frame_tx,
input_tx,
input_rx: Some(input_rx),
stats: RelayStats::default(),
cache: StateCache::default(),
start_time: Instant::now(),
}));
// Background: periodic stats logging // Background: periodic stats logging
{ {
@ -178,19 +238,21 @@ pub async fn run_relay(config: RelayConfig) -> Result<(), Box<dyn std::error::Er
tick.tick().await; tick.tick().await;
let s = state.read().await; let s = state.read().await;
let uptime = s.start_time.elapsed().as_secs(); let uptime = s.start_time.elapsed().as_secs();
if s.stats.source_connected || s.stats.connected_receivers > 0 { for (name, channel) in &s.channels {
let cs = channel.read().await;
if cs.stats.source_connected || cs.stats.connected_receivers > 0 {
eprintln!( eprintln!(
"[relay] up={}s frames={} bytes={} inputs={} receivers={} signal_diffs={} cached={}", "[relay:{name}] up={uptime}s frames={} bytes={} inputs={} receivers={} signal_diffs={} cached={}",
uptime, cs.stats.frames_relayed,
s.stats.frames_relayed, cs.stats.bytes_relayed,
s.stats.bytes_relayed, cs.stats.inputs_relayed,
s.stats.inputs_relayed, cs.stats.connected_receivers,
s.stats.connected_receivers, cs.stats.signal_diffs_sent,
s.stats.signal_diffs_sent, cs.cache.last_signal_sync.is_some(),
s.cache.last_signal_sync.is_some(),
); );
} }
} }
}
}); });
} }
@ -210,10 +272,14 @@ async fn handle_connection(
keepalive_interval: u64, keepalive_interval: u64,
_keepalive_timeout: u64, _keepalive_timeout: u64,
) { ) {
// Extract the URI path during handshake to determine role
let mut role = ConnectionRole::Receiver("default".to_string());
let ws_stream = match tokio_tungstenite::accept_hdr_async( let ws_stream = match tokio_tungstenite::accept_hdr_async(
stream, stream,
|_req: &tokio_tungstenite::tungstenite::handshake::server::Request, |req: &tokio_tungstenite::tungstenite::handshake::server::Request,
res: tokio_tungstenite::tungstenite::handshake::server::Response| { res: tokio_tungstenite::tungstenite::handshake::server::Response| {
role = parse_path(req.uri().path());
Ok(res) Ok(res)
}, },
) )
@ -226,154 +292,174 @@ async fn handle_connection(
} }
}; };
// First connection = source, subsequent = receivers // Get or create the channel
let is_source = { let (channel, channel_name) = {
let s = state.read().await; let mut s = state.write().await;
!s.stats.source_connected let name = match &role {
ConnectionRole::Source(n) | ConnectionRole::Receiver(n) => n.clone(),
};
let ch = s.get_or_create_channel(&name);
(ch, name)
}; };
if is_source { // Legacy fallback: if path was `/` and no source exists, treat as source
eprintln!("[relay] Source connected: {}", addr); let role = match role {
handle_source(ws_stream, addr, state, keepalive_interval).await; ConnectionRole::Receiver(ref name) if name == "default" => {
let cs = channel.read().await;
if !cs.stats.source_connected {
ConnectionRole::Source("default".to_string())
} else { } else {
eprintln!("[relay] Receiver connected: {}", addr); role
handle_receiver(ws_stream, addr, state, keepalive_interval).await; }
}
_ => role,
};
match role {
ConnectionRole::Source(ref _name) => {
eprintln!("[relay:{channel_name}] Source connected: {addr}");
handle_source(ws_stream, addr, channel, &channel_name, keepalive_interval).await;
}
ConnectionRole::Receiver(ref _name) => {
eprintln!("[relay:{channel_name}] Receiver connected: {addr}");
handle_receiver(ws_stream, addr, channel, &channel_name).await;
}
} }
} }
async fn handle_source( async fn handle_source(
ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>, ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
addr: SocketAddr, addr: SocketAddr,
state: Arc<RwLock<RelayState>>, channel: Arc<RwLock<ChannelState>>,
channel_name: &str,
keepalive_interval: u64, keepalive_interval: u64,
) { ) {
let (mut ws_sink, mut ws_source) = ws_stream.split(); let (mut ws_sink, mut ws_source) = ws_stream.split();
// Mark source as connected and take the input_rx // Mark source as connected and take the input_rx
let input_rx = { let input_rx = {
let mut s = state.write().await; let mut cs = channel.write().await;
s.stats.source_connected = true; cs.stats.source_connected = true;
s.input_rx.take() cs.input_rx.take()
}; };
let frame_tx = { let frame_tx = {
let s = state.read().await; let cs = channel.read().await;
s.frame_tx.clone() cs.frame_tx.clone()
}; };
// Forward input events from receivers → source // Forward input events from receivers → source
if let Some(mut input_rx) = input_rx { if let Some(mut input_rx) = input_rx {
let state_clone = state.clone(); let channel_clone = channel.clone();
tokio::spawn(async move { tokio::spawn(async move {
while let Some(input_bytes) = input_rx.recv().await { while let Some(input_bytes) = input_rx.recv().await {
let msg = Message::Binary(input_bytes.into()); let msg = Message::Binary(input_bytes.into());
if ws_sink.send(msg).await.is_err() { if ws_sink.send(msg).await.is_err() {
break; break;
} }
let mut s = state_clone.write().await; let mut cs = channel_clone.write().await;
s.stats.inputs_relayed += 1; cs.stats.inputs_relayed += 1;
} }
}); });
} }
// Keepalive: periodic pings // Keepalive: periodic pings
let state_ping = state.clone(); let channel_ping = channel.clone();
let ping_task = tokio::spawn(async move { let ping_task = tokio::spawn(async move {
let mut tick = interval(Duration::from_secs(keepalive_interval)); let mut tick = interval(Duration::from_secs(keepalive_interval));
let mut seq = 0u16; let mut seq = 0u16;
loop { loop {
tick.tick().await; tick.tick().await;
let s = state_ping.read().await; let cs = channel_ping.read().await;
if !s.stats.source_connected { if !cs.stats.source_connected {
break; break;
} }
drop(s); let _ = cs.frame_tx.send(crate::codec::ping(seq, 0));
let ping_msg = crate::codec::ping(seq, 0); drop(cs);
let _ = state_ping.read().await.frame_tx.send(ping_msg);
seq = seq.wrapping_add(1); seq = seq.wrapping_add(1);
} }
}); });
// Receive frames from source → broadcast to receivers // Receive frames from source → broadcast to receivers
let channel_name_owned = channel_name.to_string();
while let Some(Ok(msg)) = ws_source.next().await { while let Some(Ok(msg)) = ws_source.next().await {
if let Message::Binary(data) = msg { if let Message::Binary(data) = msg {
let data_vec: Vec<u8> = data.into(); let data_vec: Vec<u8> = data.into();
{ {
let mut s = state.write().await; let mut cs = channel.write().await;
s.stats.frames_relayed += 1; cs.stats.frames_relayed += 1;
s.stats.bytes_relayed += data_vec.len() as u64; cs.stats.bytes_relayed += data_vec.len() as u64;
// Update state cache // Update state cache
s.cache.process_frame(&data_vec); cs.cache.process_frame(&data_vec);
// Track frame-type-specific stats // Track frame-type-specific stats
if data_vec.len() >= HEADER_SIZE { if data_vec.len() >= HEADER_SIZE {
match FrameType::from_u8(data_vec[0]) { match FrameType::from_u8(data_vec[0]) {
Some(FrameType::SignalDiff) => s.stats.signal_diffs_sent += 1, Some(FrameType::SignalDiff) => cs.stats.signal_diffs_sent += 1,
Some(FrameType::Pixels) | Some(FrameType::Keyframe) | Some(FrameType::Pixels) | Some(FrameType::Keyframe) |
Some(FrameType::SignalSync) => s.stats.keyframes_sent += 1, Some(FrameType::SignalSync) => cs.stats.keyframes_sent += 1,
_ => {} _ => {}
} }
let ts = u32::from_le_bytes([ let ts = u32::from_le_bytes([
data_vec[4], data_vec[5], data_vec[6], data_vec[7], data_vec[4], data_vec[5], data_vec[6], data_vec[7],
]); ]);
s.stats.last_frame_timestamp = ts; cs.stats.last_frame_timestamp = ts;
} }
} }
// Broadcast to all receivers // Broadcast to all receivers on this channel
let _ = frame_tx.send(data_vec); let _ = frame_tx.send(data_vec);
} }
} }
// Source disconnected // Source disconnected
ping_task.abort(); ping_task.abort();
eprintln!("[relay] Source disconnected: {}", addr); eprintln!("[relay:{channel_name_owned}] Source disconnected: {addr}");
let mut s = state.write().await; let mut cs = channel.write().await;
s.stats.source_connected = false; cs.stats.source_connected = false;
} }
async fn handle_receiver( async fn handle_receiver(
ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>, ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
addr: SocketAddr, addr: SocketAddr,
state: Arc<RwLock<RelayState>>, channel: Arc<RwLock<ChannelState>>,
_keepalive_interval: u64, channel_name: &str,
) { ) {
let (mut ws_sink, mut ws_source) = ws_stream.split(); let (mut ws_sink, mut ws_source) = ws_stream.split();
// Subscribe to frame broadcast and get catchup messages // Subscribe to frame broadcast and get catchup messages
let (mut frame_rx, catchup_msgs) = { let (mut frame_rx, catchup_msgs) = {
let mut s = state.write().await; let mut cs = channel.write().await;
s.stats.connected_receivers += 1; cs.stats.connected_receivers += 1;
let rx = s.frame_tx.subscribe(); let rx = cs.frame_tx.subscribe();
let catchup = s.cache.catchup_messages(); let catchup = cs.cache.catchup_messages();
(rx, catchup) (rx, catchup)
}; };
let input_tx = { let input_tx = {
let s = state.read().await; let cs = channel.read().await;
s.input_tx.clone() cs.input_tx.clone()
}; };
// Send cached state to late-joining receiver // Send cached state to late-joining receiver
let channel_name_owned = channel_name.to_string();
if !catchup_msgs.is_empty() { if !catchup_msgs.is_empty() {
eprintln!( eprintln!(
"[relay] Sending {} catchup messages to {}", "[relay:{channel_name_owned}] Sending {} catchup messages to {addr}",
catchup_msgs.len(), catchup_msgs.len(),
addr
); );
for msg_bytes in catchup_msgs { for msg_bytes in catchup_msgs {
let msg = Message::Binary(msg_bytes.into()); let msg = Message::Binary(msg_bytes.into());
if ws_sink.send(msg).await.is_err() { if ws_sink.send(msg).await.is_err() {
eprintln!("[relay] Failed to send catchup to {}", addr); eprintln!("[relay:{channel_name_owned}] Failed to send catchup to {addr}");
let mut s = state.write().await; let mut cs = channel.write().await;
s.stats.connected_receivers = s.stats.connected_receivers.saturating_sub(1); cs.stats.connected_receivers = cs.stats.connected_receivers.saturating_sub(1);
return; return;
} }
} }
} }
// Forward frames from broadcast → this receiver // Forward frames from broadcast → this receiver
let addr_clone = addr;
let send_task = tokio::spawn(async move { let send_task = tokio::spawn(async move {
loop { loop {
match frame_rx.recv().await { match frame_rx.recv().await {
@ -384,7 +470,7 @@ async fn handle_receiver(
} }
} }
Err(broadcast::error::RecvError::Lagged(n)) => { Err(broadcast::error::RecvError::Lagged(n)) => {
eprintln!("[relay] Receiver {} lagged by {} frames", addr_clone, n); eprintln!("[relay] Receiver {addr} lagged by {n} frames");
} }
Err(_) => break, Err(_) => break,
} }
@ -400,9 +486,9 @@ async fn handle_receiver(
// Receiver disconnected // Receiver disconnected
send_task.abort(); send_task.abort();
eprintln!("[relay] Receiver disconnected: {}", addr); eprintln!("[relay:{channel_name_owned}] Receiver disconnected: {addr}");
let mut s = state.write().await; let mut cs = channel.write().await;
s.stats.connected_receivers = s.stats.connected_receivers.saturating_sub(1); cs.stats.connected_receivers = cs.stats.connected_receivers.saturating_sub(1);
} }
// ─── Tests ─── // ─── Tests ───
@ -499,4 +585,58 @@ mod tests {
// Should have been trimmed: 1100 - 500 = 600 remaining after first trim at 1001 // Should have been trimmed: 1100 - 500 = 600 remaining after first trim at 1001
assert!(cache.pending_signal_diffs.len() <= 600); assert!(cache.pending_signal_diffs.len() <= 600);
} }
// ─── Path Routing Tests ───
#[test]
fn parse_path_source_default() {
match parse_path("/source") {
ConnectionRole::Source(name) => assert_eq!(name, "default"),
_ => panic!("Expected Source"),
}
}
#[test]
fn parse_path_source_named() {
match parse_path("/source/main") {
ConnectionRole::Source(name) => assert_eq!(name, "main"),
_ => panic!("Expected Source"),
}
}
#[test]
fn parse_path_stream_default() {
match parse_path("/stream") {
ConnectionRole::Receiver(name) => assert_eq!(name, "default"),
_ => panic!("Expected Receiver"),
}
}
#[test]
fn parse_path_stream_named() {
match parse_path("/stream/player1") {
ConnectionRole::Receiver(name) => assert_eq!(name, "player1"),
_ => panic!("Expected Receiver"),
}
}
#[test]
fn parse_path_legacy_root() {
match parse_path("/") {
ConnectionRole::Receiver(name) => assert_eq!(name, "default"),
_ => panic!("Expected Receiver for legacy path"),
}
}
#[test]
fn channel_state_creation() {
let mut state = RelayState::new(16);
let ch1 = state.get_or_create_channel("main");
let ch2 = state.get_or_create_channel("player1");
let ch1_again = state.get_or_create_channel("main");
assert_eq!(state.channels.len(), 2);
assert!(Arc::ptr_eq(&ch1, &ch1_again));
assert!(!Arc::ptr_eq(&ch1, &ch2));
}
} }