dreamstack/engine/ds-stream/src/relay.rs

256 lines
7.9 KiB
Rust
Raw Normal View History

//! WebSocket Relay Server
//!
//! Routes frames from source→receivers and inputs from receivers→source.
//! Roles are determined by the connection path:
//! - `/source` — the frame producer (DreamStack renderer)
//! - `/stream` — frame consumers (thin receivers)
//!
//! The relay is intentionally dumb — it just forwards bytes.
//! The protocol semantics live in the source and receiver.
use std::net::SocketAddr;
use std::sync::Arc;
use futures_util::{SinkExt, StreamExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{broadcast, mpsc, RwLock};
use tokio_tungstenite::tungstenite::Message;
/// Relay server configuration.
///
/// Start from [`RelayConfig::default`] and override individual fields with
/// struct-update syntax, e.g.
/// `RelayConfig { max_receivers: 8, ..RelayConfig::default() }`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RelayConfig {
    /// Address the listening socket binds to.
    pub addr: SocketAddr,
    /// Maximum number of receivers.
    ///
    /// NOTE(review): nothing in this module reads this field yet, so the limit
    /// does not appear to be enforced — confirm before relying on it.
    pub max_receivers: usize,
    /// Capacity of the frame broadcast channel, i.e. how many frames are
    /// retained for receivers; a receiver that falls further behind than this
    /// observes a "lagged" error and skips ahead.
    pub frame_buffer_size: usize,
}

impl Default for RelayConfig {
    /// All IPv4 interfaces on port 9100, up to 64 receivers, 16 buffered frames.
    fn default() -> Self {
        Self {
            // Built directly rather than parsed from a string, so there is no
            // runtime `unwrap` in the default path.
            addr: SocketAddr::from(([0, 0, 0, 0], 9100)),
            max_receivers: 64,
            frame_buffer_size: 16,
        }
    }
}
/// Stats tracked by the relay.
///
/// `Default` yields all-zero counters with no source connected.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct RelayStats {
    /// Number of binary frames received from the source and broadcast.
    pub frames_relayed: u64,
    /// Total payload bytes of those frames.
    pub bytes_relayed: u64,
    /// Number of input events successfully forwarded to the source.
    pub inputs_relayed: u64,
    /// Receivers currently connected.
    pub connected_receivers: usize,
    /// Whether a source connection currently holds the source slot.
    pub source_connected: bool,
}
/// State shared by every connection task (wrapped in `Arc<RwLock<_>>` by `run_relay`).
struct RelayState {
    /// Counters and flags reported as [`RelayStats`].
    stats: RelayStats,
    /// Fan-out of frames from the source to every subscribed receiver.
    frame_tx: broadcast::Sender<Vec<u8>>,
    /// Producer side of the receiver → source input queue; each receiver clones it.
    input_tx: mpsc::Sender<Vec<u8>>,
    /// Consumer side of the input queue; set to `None` when a source handler takes it.
    input_rx: Option<mpsc::Receiver<Vec<u8>>>,
}
/// Run the WebSocket relay server.
///
/// Binds `config.addr`, prints a banner, then accepts connections forever,
/// spawning one task per connection.
///
/// # Errors
///
/// Returns an error if binding fails, or if `accept` fails with an error that
/// is not specific to a single incoming connection. Per-connection failures
/// (aborted, reset, interrupted) are logged and the server keeps accepting,
/// instead of silently shutting down.
pub async fn run_relay(config: RelayConfig) -> Result<(), Box<dyn std::error::Error>> {
    /// Capacity of the receiver → source input queue.
    const INPUT_CHANNEL_CAPACITY: usize = 256;

    let listener = TcpListener::bind(&config.addr).await?;
    eprintln!("╔══════════════════════════════════════════════════╗");
    eprintln!("║ DreamStack Bitstream Relay v0.1.0 ║");
    eprintln!("║ ║");
    eprintln!("║ Source: ws://{}/source ║", config.addr);
    eprintln!("║ Receiver: ws://{}/stream ║", config.addr);
    eprintln!("╚══════════════════════════════════════════════════╝");

    // `broadcast::channel` panics on a capacity of 0, so clamp to at least 1.
    let (frame_tx, _) = broadcast::channel(config.frame_buffer_size.max(1));
    let (input_tx, input_rx) = mpsc::channel(INPUT_CHANNEL_CAPACITY);
    let state = Arc::new(RwLock::new(RelayState {
        frame_tx,
        input_tx,
        input_rx: Some(input_rx),
        stats: RelayStats::default(),
    }));

    loop {
        match listener.accept().await {
            Ok((stream, addr)) => {
                tokio::spawn(handle_connection(stream, addr, Arc::clone(&state)));
            }
            // These concern only the one pending connection; keep serving others.
            Err(e)
                if matches!(
                    e.kind(),
                    std::io::ErrorKind::ConnectionAborted
                        | std::io::ErrorKind::ConnectionReset
                        | std::io::ErrorKind::Interrupted
                ) =>
            {
                eprintln!("[relay] Failed to accept connection: {}", e);
            }
            Err(e) => return Err(e.into()),
        }
    }
}
/// Completes the WebSocket handshake for one TCP connection and dispatches it
/// to the source or receiver handler.
///
/// The role comes from the request path captured during the handshake:
/// - `/source` claims the single source slot; if a source is already
///   connected, the new connection is dropped.
/// - `/stream` is always a receiver.
/// - any other path keeps the original proof-of-concept rule: the connection
///   becomes the source if the slot is free, otherwise a receiver.
async fn handle_connection(
    stream: TcpStream,
    addr: SocketAddr,
    state: Arc<RwLock<RelayState>>,
) {
    // Filled in by the handshake callback, the only place the request URI is visible.
    let mut path = String::new();
    let ws_stream = match tokio_tungstenite::accept_hdr_async(
        stream,
        |req: &tokio_tungstenite::tungstenite::handshake::server::Request,
         res: tokio_tungstenite::tungstenite::handshake::server::Response| {
            path = req.uri().path().to_owned();
            Ok(res)
        },
    )
    .await
    {
        Ok(ws) => ws,
        Err(e) => {
            eprintln!("[relay] WebSocket handshake failed from {}: {}", addr, e);
            return;
        }
    };

    let is_source = match path.as_str() {
        "/stream" => false,
        "/source" => {
            if !try_claim_source(&state).await {
                eprintln!(
                    "[relay] Rejecting source {}: a source is already connected",
                    addr
                );
                // Dropping `ws_stream` closes the connection.
                return;
            }
            true
        }
        // Unknown path: legacy first-connection-wins behaviour.
        _ => try_claim_source(&state).await,
    };

    if is_source {
        eprintln!("[relay] Source connected: {}", addr);
        handle_source(ws_stream, addr, state).await;
    } else {
        eprintln!("[relay] Receiver connected: {}", addr);
        handle_receiver(ws_stream, addr, state).await;
    }
}

/// Atomically claims the single source slot.
///
/// Returns `true` and marks the source as connected if the slot was free;
/// returns `false` if a source already holds it. Checking and setting under
/// one write lock prevents two simultaneous connections from both becoming
/// the source.
async fn try_claim_source(state: &RwLock<RelayState>) -> bool {
    let mut s = state.write().await;
    if s.stats.source_connected {
        false
    } else {
        s.stats.source_connected = true;
        true
    }
}
async fn handle_source(
ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
addr: SocketAddr,
state: Arc<RwLock<RelayState>>,
) {
let (mut ws_sink, mut ws_source) = ws_stream.split();
// Mark source as connected and take the input_rx
let input_rx = {
let mut s = state.write().await;
s.stats.source_connected = true;
s.input_rx.take()
};
let frame_tx = {
let s = state.read().await;
s.frame_tx.clone()
};
// Forward input events from receivers → source
if let Some(mut input_rx) = input_rx {
let state_clone = state.clone();
tokio::spawn(async move {
while let Some(input_bytes) = input_rx.recv().await {
let msg = Message::Binary(input_bytes.into());
if ws_sink.send(msg).await.is_err() {
break;
}
let mut s = state_clone.write().await;
s.stats.inputs_relayed += 1;
}
});
}
// Receive frames from source → broadcast to receivers
while let Some(Ok(msg)) = ws_source.next().await {
if let Message::Binary(data) = msg {
let data_vec: Vec<u8> = data.into();
{
let mut s = state.write().await;
s.stats.frames_relayed += 1;
s.stats.bytes_relayed += data_vec.len() as u64;
}
// Broadcast to all receivers (ignore send errors = no receivers)
let _ = frame_tx.send(data_vec);
}
}
// Source disconnected
eprintln!("[relay] Source disconnected: {}", addr);
let mut s = state.write().await;
s.stats.source_connected = false;
}
/// Serves one receiver connection until it disconnects.
///
/// A spawned task forwards every broadcast frame to this receiver, while
/// binary messages sent by the receiver are queued for the source as input
/// events. `connected_receivers` is incremented on entry and decremented on
/// exit.
async fn handle_receiver(
    ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
    addr: SocketAddr,
    state: Arc<RwLock<RelayState>>,
) {
    let (mut ws_sink, mut ws_source) = ws_stream.split();

    // Register, subscribe to frames, and grab the input sender under one lock.
    let (mut frame_rx, input_tx) = {
        let mut s = state.write().await;
        s.stats.connected_receivers += 1;
        (s.frame_tx.subscribe(), s.input_tx.clone())
    };

    // Broadcast → this receiver. `addr` is `Copy`, so the task gets its own copy.
    let send_task = tokio::spawn(async move {
        loop {
            match frame_rx.recv().await {
                Ok(frame_bytes) => {
                    let msg = Message::Binary(frame_bytes.into());
                    if ws_sink.send(msg).await.is_err() {
                        break;
                    }
                }
                // Slow receiver: the channel already skipped ahead; keep going.
                Err(broadcast::error::RecvError::Lagged(n)) => {
                    eprintln!("[relay] Receiver {} lagged by {} frames", addr, n);
                }
                Err(_) => break,
            }
        }
    });

    // This receiver → source. Inputs are best-effort: `try_send` drops the event
    // when the queue is full (e.g. no source is draining it) instead of
    // blocking. With `send().await`, this loop would stall, stop reading the
    // socket, never notice the disconnect, and leak the task.
    while let Some(Ok(msg)) = ws_source.next().await {
        if let Message::Binary(data) = msg {
            let _ = input_tx.try_send(data.into());
        }
    }

    // Receiver disconnected
    send_task.abort();
    eprintln!("[relay] Receiver disconnected: {}", addr);
    let mut s = state.write().await;
    s.stats.connected_receivers = s.stats.connected_receivers.saturating_sub(1);
}
// ─── Tests ───
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config() {
        let config = RelayConfig::default();
        assert_eq!(config.addr, "0.0.0.0:9100".parse::<SocketAddr>().unwrap());
        assert!(config.addr.ip().is_unspecified());
        assert_eq!(config.addr.port(), 9100);
        assert_eq!(config.max_receivers, 64);
        assert_eq!(config.frame_buffer_size, 16);
    }

    #[test]
    fn stats_default() {
        // Every counter starts at zero and no source is connected.
        let stats = RelayStats::default();
        assert_eq!(stats.frames_relayed, 0);
        assert_eq!(stats.bytes_relayed, 0);
        assert_eq!(stats.inputs_relayed, 0);
        assert_eq!(stats.connected_receivers, 0);
        assert!(!stats.source_connected);
    }
}