256 lines
7.9 KiB
Rust
256 lines
7.9 KiB
Rust
|
|
//! WebSocket Relay Server
|
||
|
|
//!
|
||
|
|
//! Routes frames from source→receivers and inputs from receivers→source.
|
||
|
|
//! Roles are determined by the connection path:
|
||
|
|
//! - `/source` — the frame producer (DreamStack renderer)
|
||
|
|
//! - `/stream` — frame consumers (thin receivers)
|
||
|
|
//!
|
||
|
|
//! The relay is intentionally dumb — it just forwards bytes.
|
||
|
|
//! The protocol semantics live in the source and receiver.
|
||
|
|
|
||
|
|
|
||
|
|
use std::net::SocketAddr;
|
||
|
|
use std::sync::Arc;
|
||
|
|
|
||
|
|
use futures_util::{SinkExt, StreamExt};
|
||
|
|
use tokio::net::{TcpListener, TcpStream};
|
||
|
|
use tokio::sync::{broadcast, mpsc, RwLock};
|
||
|
|
use tokio_tungstenite::tungstenite::Message;
|
||
|
|
|
||
|
|
/// Relay server configuration.
///
/// The [`Default`] configuration listens on every IPv4 interface at port 9100,
/// allows up to 64 receivers and buffers 16 frames per receiver.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RelayConfig {
    /// Address to bind to.
    pub addr: SocketAddr,
    /// Maximum number of receivers.
    ///
    /// NOTE(review): none of the connection handlers in this file read this value,
    /// so it is not currently enforced.
    pub max_receivers: usize,
    /// Frame broadcast channel capacity.
    ///
    /// Must be greater than zero: `tokio::sync::broadcast::channel` rejects a zero
    /// capacity.
    pub frame_buffer_size: usize,
}

impl Default for RelayConfig {
    fn default() -> Self {
        Self {
            // Built directly instead of parsing a string, so there is no `unwrap` here.
            addr: SocketAddr::from(([0, 0, 0, 0], 9100)),
            max_receivers: 64,
            frame_buffer_size: 16,
        }
    }
}
|
||
|
|
|
||
|
|
/// Counters and connection flags tracked by the relay.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare snapshots directly.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct RelayStats {
    /// Number of binary frames received from the source and handed to the broadcast channel.
    pub frames_relayed: u64,
    /// Total payload size, in bytes, of the relayed frames.
    pub bytes_relayed: u64,
    /// Number of input events successfully written back to the source.
    pub inputs_relayed: u64,
    /// Number of receivers currently connected.
    pub connected_receivers: usize,
    /// Whether a source connection currently occupies the source slot.
    pub source_connected: bool,
}
|
||
|
|
|
||
|
|
/// Shared relay state.
|
||
|
|
struct RelayState {
|
||
|
|
/// Broadcast channel: source → all receivers (frames)
|
||
|
|
frame_tx: broadcast::Sender<Vec<u8>>,
|
||
|
|
/// Channel: receivers → source (input events)
|
||
|
|
input_tx: mpsc::Sender<Vec<u8>>,
|
||
|
|
input_rx: Option<mpsc::Receiver<Vec<u8>>>,
|
||
|
|
/// Live stats
|
||
|
|
stats: RelayStats,
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Run the WebSocket relay server.
|
||
|
|
pub async fn run_relay(config: RelayConfig) -> Result<(), Box<dyn std::error::Error>> {
|
||
|
|
let listener = TcpListener::bind(&config.addr).await?;
|
||
|
|
eprintln!("╔══════════════════════════════════════════════════╗");
|
||
|
|
eprintln!("║ DreamStack Bitstream Relay v0.1.0 ║");
|
||
|
|
eprintln!("║ ║");
|
||
|
|
eprintln!("║ Source: ws://{}/source ║", config.addr);
|
||
|
|
eprintln!("║ Receiver: ws://{}/stream ║", config.addr);
|
||
|
|
eprintln!("╚══════════════════════════════════════════════════╝");
|
||
|
|
|
||
|
|
let (frame_tx, _) = broadcast::channel(config.frame_buffer_size);
|
||
|
|
let (input_tx, input_rx) = mpsc::channel(256);
|
||
|
|
|
||
|
|
let state = Arc::new(RwLock::new(RelayState {
|
||
|
|
frame_tx,
|
||
|
|
input_tx,
|
||
|
|
input_rx: Some(input_rx),
|
||
|
|
stats: RelayStats::default(),
|
||
|
|
}));
|
||
|
|
|
||
|
|
while let Ok((stream, addr)) = listener.accept().await {
|
||
|
|
let state = state.clone();
|
||
|
|
tokio::spawn(handle_connection(stream, addr, state));
|
||
|
|
}
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn handle_connection(
|
||
|
|
stream: TcpStream,
|
||
|
|
addr: SocketAddr,
|
||
|
|
state: Arc<RwLock<RelayState>>,
|
||
|
|
) {
|
||
|
|
// Peek at the HTTP upgrade request to determine role
|
||
|
|
let ws_stream = match tokio_tungstenite::accept_hdr_async(
|
||
|
|
stream,
|
||
|
|
|_req: &tokio_tungstenite::tungstenite::handshake::server::Request,
|
||
|
|
res: tokio_tungstenite::tungstenite::handshake::server::Response| {
|
||
|
|
// We'll extract the path from the URI later via a different mechanism
|
||
|
|
Ok(res)
|
||
|
|
},
|
||
|
|
)
|
||
|
|
.await
|
||
|
|
{
|
||
|
|
Ok(ws) => ws,
|
||
|
|
Err(e) => {
|
||
|
|
eprintln!("[relay] WebSocket handshake failed from {}: {}", addr, e);
|
||
|
|
return;
|
||
|
|
}
|
||
|
|
};
|
||
|
|
|
||
|
|
// For simplicity in the PoC, first connection = source, subsequent = receivers.
|
||
|
|
// A production version would parse the URI path.
|
||
|
|
let is_source = {
|
||
|
|
let s = state.read().await;
|
||
|
|
!s.stats.source_connected
|
||
|
|
};
|
||
|
|
|
||
|
|
if is_source {
|
||
|
|
eprintln!("[relay] Source connected: {}", addr);
|
||
|
|
handle_source(ws_stream, addr, state).await;
|
||
|
|
} else {
|
||
|
|
eprintln!("[relay] Receiver connected: {}", addr);
|
||
|
|
handle_receiver(ws_stream, addr, state).await;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn handle_source(
|
||
|
|
ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
|
||
|
|
addr: SocketAddr,
|
||
|
|
state: Arc<RwLock<RelayState>>,
|
||
|
|
) {
|
||
|
|
let (mut ws_sink, mut ws_source) = ws_stream.split();
|
||
|
|
|
||
|
|
// Mark source as connected and take the input_rx
|
||
|
|
let input_rx = {
|
||
|
|
let mut s = state.write().await;
|
||
|
|
s.stats.source_connected = true;
|
||
|
|
s.input_rx.take()
|
||
|
|
};
|
||
|
|
|
||
|
|
let frame_tx = {
|
||
|
|
let s = state.read().await;
|
||
|
|
s.frame_tx.clone()
|
||
|
|
};
|
||
|
|
|
||
|
|
// Forward input events from receivers → source
|
||
|
|
if let Some(mut input_rx) = input_rx {
|
||
|
|
let state_clone = state.clone();
|
||
|
|
tokio::spawn(async move {
|
||
|
|
while let Some(input_bytes) = input_rx.recv().await {
|
||
|
|
let msg = Message::Binary(input_bytes.into());
|
||
|
|
if ws_sink.send(msg).await.is_err() {
|
||
|
|
break;
|
||
|
|
}
|
||
|
|
let mut s = state_clone.write().await;
|
||
|
|
s.stats.inputs_relayed += 1;
|
||
|
|
}
|
||
|
|
});
|
||
|
|
}
|
||
|
|
|
||
|
|
// Receive frames from source → broadcast to receivers
|
||
|
|
while let Some(Ok(msg)) = ws_source.next().await {
|
||
|
|
if let Message::Binary(data) = msg {
|
||
|
|
let data_vec: Vec<u8> = data.into();
|
||
|
|
{
|
||
|
|
let mut s = state.write().await;
|
||
|
|
s.stats.frames_relayed += 1;
|
||
|
|
s.stats.bytes_relayed += data_vec.len() as u64;
|
||
|
|
}
|
||
|
|
// Broadcast to all receivers (ignore send errors = no receivers)
|
||
|
|
let _ = frame_tx.send(data_vec);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Source disconnected
|
||
|
|
eprintln!("[relay] Source disconnected: {}", addr);
|
||
|
|
let mut s = state.write().await;
|
||
|
|
s.stats.source_connected = false;
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn handle_receiver(
|
||
|
|
ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
|
||
|
|
addr: SocketAddr,
|
||
|
|
state: Arc<RwLock<RelayState>>,
|
||
|
|
) {
|
||
|
|
let (mut ws_sink, mut ws_source) = ws_stream.split();
|
||
|
|
|
||
|
|
// Subscribe to frame broadcast
|
||
|
|
let mut frame_rx = {
|
||
|
|
let mut s = state.write().await;
|
||
|
|
s.stats.connected_receivers += 1;
|
||
|
|
s.frame_tx.subscribe()
|
||
|
|
};
|
||
|
|
|
||
|
|
let input_tx = {
|
||
|
|
let s = state.read().await;
|
||
|
|
s.input_tx.clone()
|
||
|
|
};
|
||
|
|
|
||
|
|
// Forward frames from broadcast → this receiver
|
||
|
|
let _state_clone = state.clone();
|
||
|
|
let addr_clone = addr;
|
||
|
|
let send_task = tokio::spawn(async move {
|
||
|
|
loop {
|
||
|
|
match frame_rx.recv().await {
|
||
|
|
Ok(frame_bytes) => {
|
||
|
|
let msg = Message::Binary(frame_bytes.into());
|
||
|
|
if ws_sink.send(msg).await.is_err() {
|
||
|
|
break;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
Err(broadcast::error::RecvError::Lagged(n)) => {
|
||
|
|
eprintln!("[relay] Receiver {} lagged by {} frames", addr_clone, n);
|
||
|
|
}
|
||
|
|
Err(_) => break,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
});
|
||
|
|
|
||
|
|
// Forward input events from this receiver → source
|
||
|
|
while let Some(Ok(msg)) = ws_source.next().await {
|
||
|
|
if let Message::Binary(data) = msg {
|
||
|
|
let _ = input_tx.send(data.into()).await;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Receiver disconnected
|
||
|
|
send_task.abort();
|
||
|
|
eprintln!("[relay] Receiver disconnected: {}", addr);
|
||
|
|
let mut s = state.write().await;
|
||
|
|
s.stats.connected_receivers = s.stats.connected_receivers.saturating_sub(1);
|
||
|
|
}
|
||
|
|
|
||
|
|
// ─── Tests ───
|
||
|
|
|
||
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config() {
        let config = RelayConfig::default();
        assert_eq!(config.addr, "0.0.0.0:9100".parse::<SocketAddr>().unwrap());
        assert_eq!(config.max_receivers, 64);
        assert_eq!(config.frame_buffer_size, 16);
    }

    #[test]
    fn default_config_is_usable() {
        // A zero broadcast capacity is rejected by tokio, and port 0 would bind a random port.
        let config = RelayConfig::default();
        assert!(config.frame_buffer_size > 0);
        assert!(config.max_receivers > 0);
        assert_ne!(config.addr.port(), 0);
    }

    #[test]
    fn stats_default() {
        let stats = RelayStats::default();
        assert_eq!(stats.frames_relayed, 0);
        assert_eq!(stats.bytes_relayed, 0);
        assert_eq!(stats.inputs_relayed, 0);
        assert_eq!(stats.connected_receivers, 0);
        assert!(!stats.source_connected);
    }
}
|