dreamstack/compiler/ds-analyzer/src/signal_graph.rs

536 lines
18 KiB
Rust
Raw Normal View History

//! Signal graph extraction — the core of DreamStack's compile-time reactivity.
//!
//! Walks the AST and builds a directed dependency graph of signals (intended
//! to be acyclic; cycles are not currently rejected):
//! - Source signals: `let count = 0` (mutable, user-controlled)
//! - Derived signals: `let doubled = count * 2` (computed, auto-tracked)
//! - Handlers: event handlers that mutate signals
//! - Effects: DOM bindings that update when their dependencies change
use ds_parser::{Program, Declaration, Expr, BinOp, Container, Element, LetDecl, ViewDecl};
use std::collections::{HashMap, HashSet};
/// The complete signal dependency graph for a program.
///
/// Build one with `SignalGraph::from_program`; `Default` yields an empty graph.
#[derive(Debug, Clone, Default)]
pub struct SignalGraph {
    /// All nodes. A node's `id` equals its index in this vector.
    pub nodes: Vec<SignalNode>,
    /// Maps each `let` signal name to its node id. Handler nodes are not
    /// listed here. If a name is declared twice, the later declaration wins.
    pub name_to_id: HashMap<String, usize>,
}
/// A node in the signal graph.
///
/// Created by `SignalGraph::from_program`, one per `let` declaration and one
/// per event-handler declaration.
#[derive(Debug, Clone)]
pub struct SignalNode {
    /// Index of this node in `SignalGraph::nodes`.
    pub id: usize,
    /// The `let` binding name, or `handler_<event>` for handler nodes.
    pub name: String,
    /// Whether this node is a source, derived, or handler node.
    pub kind: SignalKind,
    /// Edges to the names this node depends on. For `let` signals these are
    /// the identifiers referenced by the value; for handlers, the signals the
    /// handler mutates.
    pub dependencies: Vec<Dependency>,
    /// Literal initial value, when the `let` value is a simple literal.
    pub initial_value: Option<InitialValue>,
    /// Set on source signals when the program contains a `stream` declaration.
    pub streamable: bool,
}
/// Classification of a [`SignalNode`].
#[derive(Debug, Clone)]
pub enum SignalKind {
    /// Mutable source signal: `let count = 0`.
    /// Its value references no identifiers.
    Source,
    /// Computed derived signal: `let doubled = count * 2`.
    /// Its value references at least one identifier.
    Derived,
    /// An event handler that mutates signals
    Handler {
        /// Event name taken from the handler declaration.
        event: String,
        /// Assignments found in the handler body, in body order.
        mutations: Vec<Mutation>
    },
}
/// What a handler does to a signal.
#[derive(Debug, Clone)]
pub struct Mutation {
    /// Name of the assigned signal (only bare-identifier targets are recorded).
    pub target: String,
    /// The assignment operator together with its right-hand side.
    pub op: MutationOp,
}
/// The operation a handler applies to its target signal.
///
/// Each payload is the right-hand-side expression rendered with its `Debug`
/// formatting (as produced by `format!("{value:?}")`), not the original
/// source text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MutationOp {
    /// From `AssignOp::Set` (plain assignment).
    Set(String),
    /// From `AssignOp::AddAssign` (e.g. `count += 1`).
    AddAssign(String),
    /// From `AssignOp::SubAssign`.
    SubAssign(String),
}
/// A dependency edge in the signal graph.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Dependency {
    /// Identifier as it appears in the referencing expression.
    pub signal_name: String,
    /// Index of the referenced node in `SignalGraph::nodes`.
    ///
    /// Resolved after all nodes are registered. Stays `None` when the name
    /// does not match any `let` signal.
    pub signal_id: Option<usize>,
}
/// Inferred initial value for source signals.
///
/// Only produced when a `let` value is a literal. A string qualifies only if
/// it consists of a single literal segment.
#[derive(Debug, Clone, PartialEq)]
pub enum InitialValue {
    /// Integer literal.
    Int(i64),
    /// Floating-point literal.
    Float(f64),
    /// Boolean literal.
    Bool(bool),
    /// Non-interpolated string literal.
    String(String),
}
/// Analyzed view information.
#[derive(Debug, Clone)]
pub struct AnalyzedView {
    /// The view's declared name.
    pub name: String,
    /// Bindings found in the view body, in depth-first pre-order.
    pub bindings: Vec<DomBinding>,
}
/// A reactive DOM binding extracted from a view.
#[derive(Debug, Clone)]
pub struct DomBinding {
    /// What kind of DOM node or behaviour this binding describes.
    pub kind: BindingKind,
    /// Identifiers the binding reads; empty for static containers and text.
    pub dependencies: Vec<String>,
}
/// The kind of a [`DomBinding`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BindingKind {
    /// `text label` — text content bound to a signal
    TextContent {
        /// The identifier passed as the element argument.
        signal: String,
    },
    /// `button "+" { click: count += 1 }` — event handler on an element
    EventHandler {
        /// Tag of the element that carries the handler (e.g. `button`).
        element_tag: String,
        /// Prop name: one of `click`, `input`, `change`, `submit`, `keydown`, `keyup`.
        event: String,
        /// `Debug` rendering of the handler expression (not its source text).
        action: String,
    },
    /// `when cond -> body` — conditional mount/unmount
    Conditional {
        /// Sorted, deduplicated identifiers referenced by the condition.
        condition_signals: Vec<String>,
    },
    /// `column [ ... ]` — static container
    StaticContainer {
        /// Container name: `column`, `row`, `stack`, `panel`, `list`, `form`,
        /// `scene`, or the name of a custom container.
        kind: String,
        /// Number of direct children.
        child_count: usize,
    },
    /// Static text with no binding
    StaticText {
        /// The string's first literal segment.
        text: String,
    },
}
/// Static description of the streamable signals, for receiver reconstruction.
///
/// Built by `SignalGraph::signal_manifest`, which includes only nodes whose
/// `streamable` flag is set.
#[derive(Debug, Clone)]
pub struct SignalManifest {
    /// One entry per streamable node, in node-id order.
    pub signals: Vec<ManifestEntry>,
}
/// One signal described in a [`SignalManifest`].
#[derive(Debug, Clone)]
pub struct ManifestEntry {
    /// Signal name, copied from the graph node.
    pub name: String,
    /// Signal classification, copied from the graph node.
    pub kind: SignalKind,
    /// Literal initial value, if one was inferred.
    pub initial: Option<InitialValue>,
    /// NOTE(review): presumably flags spring-animated signals — confirm.
    /// `SignalGraph::signal_manifest` currently always sets this to `false`.
    pub is_spring: bool,
}
impl SignalGraph {
/// Build a signal graph from a parsed program.
pub fn from_program(program: &Program) -> Self {
let mut graph = SignalGraph {
nodes: Vec::new(),
name_to_id: HashMap::new(),
};
// First pass: register all let declarations as signals
for decl in &program.declarations {
if let Declaration::Let(let_decl) = decl {
let deps = extract_dependencies(&let_decl.value);
let kind = if deps.is_empty() {
SignalKind::Source
} else {
SignalKind::Derived
};
let initial = match &let_decl.value {
Expr::IntLit(n) => Some(InitialValue::Int(*n)),
Expr::FloatLit(n) => Some(InitialValue::Float(*n)),
Expr::BoolLit(b) => Some(InitialValue::Bool(*b)),
Expr::StringLit(s) => {
if s.segments.len() == 1 {
if let ds_parser::StringSegment::Literal(text) = &s.segments[0] {
Some(InitialValue::String(text.clone()))
} else {
None
}
} else {
None
}
}
_ => None,
};
let id = graph.nodes.len();
let dependencies: Vec<Dependency> = deps.into_iter()
.map(|name| Dependency { signal_name: name, signal_id: None })
.collect();
graph.name_to_id.insert(let_decl.name.clone(), id);
graph.nodes.push(SignalNode {
id,
name: let_decl.name.clone(),
kind,
dependencies,
initial_value: initial,
streamable: false,
});
}
}
// Detect stream declarations and mark source signals as streamable
let has_stream = program.declarations.iter()
.any(|d| matches!(d, Declaration::Stream(_)));
if has_stream {
for node in &mut graph.nodes {
if matches!(node.kind, SignalKind::Source) {
node.streamable = true;
}
}
}
// Second pass: register event handlers
for decl in &program.declarations {
if let Declaration::OnHandler(handler) = decl {
let mutations = extract_mutations(&handler.body);
let deps: Vec<String> = mutations.iter().map(|m| m.target.clone()).collect();
let id = graph.nodes.len();
graph.nodes.push(SignalNode {
id,
name: format!("handler_{}", handler.event),
kind: SignalKind::Handler {
event: handler.event.clone(),
mutations,
},
dependencies: deps.into_iter()
.map(|name| Dependency { signal_name: name, signal_id: None })
.collect(),
initial_value: None,
streamable: false,
});
}
}
// Third pass: resolve dependency IDs
let name_map = graph.name_to_id.clone();
for node in &mut graph.nodes {
for dep in &mut node.dependencies {
dep.signal_id = name_map.get(&dep.signal_name).copied();
}
}
graph
}
/// Generate a manifest for receivers to know how to reconstruct the signal state.
pub fn signal_manifest(&self) -> SignalManifest {
SignalManifest {
signals: self.nodes.iter()
.filter(|n| n.streamable)
.map(|n| ManifestEntry {
name: n.name.clone(),
kind: n.kind.clone(),
initial: n.initial_value.clone(),
is_spring: false,
})
.collect()
}
}
/// Analyze views and extract DOM bindings.
pub fn analyze_views(program: &Program) -> Vec<AnalyzedView> {
let mut views = Vec::new();
for decl in &program.declarations {
if let Declaration::View(view) = decl {
let bindings = extract_bindings(&view.body);
views.push(AnalyzedView {
name: view.name.clone(),
bindings,
});
}
}
views
}
/// Get topological order for signal propagation.
pub fn topological_order(&self) -> Vec<usize> {
let mut visited = HashSet::new();
let mut order = Vec::new();
for node in &self.nodes {
if !visited.contains(&node.id) {
self.topo_visit(node.id, &mut visited, &mut order);
}
}
order
}
fn topo_visit(&self, id: usize, visited: &mut HashSet<usize>, order: &mut Vec<usize>) {
if visited.contains(&id) {
return;
}
visited.insert(id);
for dep in &self.nodes[id].dependencies {
if let Some(dep_id) = dep.signal_id {
self.topo_visit(dep_id, visited, order);
}
}
order.push(id);
}
}
/// Extract the distinct identifiers referenced anywhere in `expr`.
///
/// The result is sorted and free of duplicates, so it does not depend on
/// traversal order.
fn extract_dependencies(expr: &Expr) -> Vec<String> {
    let mut names = Vec::new();
    collect_deps(expr, &mut names);
    // Equal `String`s are indistinguishable, so an unstable sort gives the
    // same result as a stable one.
    names.sort_unstable();
    names.dedup();
    names
}
/// Recursively push every identifier referenced by `expr` onto `deps`.
///
/// Duplicates are pushed as-is; `extract_dependencies` sorts and dedups them.
/// Only `Expr::Ident` contributes names. A field access contributes only the
/// identifiers of its base expression.
///
/// NOTE(review): two kinds of locally bound names are not filtered out:
/// - identifiers that are a lambda's own parameters;
/// - names bound by match-arm patterns, if patterns can bind. Only arm bodies
///   are visited.
///
/// Such names are reported as dependencies. The callee of `Expr::Call` is not
/// visited.
fn collect_deps(expr: &Expr, deps: &mut Vec<String>) {
    match expr {
        Expr::Ident(name) => deps.push(name.clone()),
        Expr::DotAccess(base, _) => collect_deps(base, deps),
        Expr::BinOp(left, _, right) => {
            collect_deps(left, deps);
            collect_deps(right, deps);
        }
        Expr::UnaryOp(_, inner) => collect_deps(inner, deps),
        Expr::Call(_, args) => {
            for arg in args {
                collect_deps(arg, deps);
            }
        }
        Expr::If(cond, then_b, else_b) => {
            collect_deps(cond, deps);
            collect_deps(then_b, deps);
            collect_deps(else_b, deps);
        }
        Expr::Pipe(left, right) => {
            collect_deps(left, deps);
            collect_deps(right, deps);
        }
        Expr::Container(c) => {
            for child in &c.children {
                collect_deps(child, deps);
            }
        }
        Expr::Element(el) => {
            for arg in &el.args {
                collect_deps(arg, deps);
            }
            for (_, val) in &el.props {
                collect_deps(val, deps);
            }
        }
        Expr::Record(fields) => {
            for (_, val) in fields {
                collect_deps(val, deps);
            }
        }
        Expr::List(items) => {
            for item in items {
                collect_deps(item, deps);
            }
        }
        // Multi-expression bodies (the same shape `extract_mutations` walks).
        // Without this arm, every signal read inside a block went untracked.
        Expr::Block(exprs) => {
            for e in exprs {
                collect_deps(e, deps);
            }
        }
        Expr::When(cond, body, else_body) => {
            collect_deps(cond, deps);
            collect_deps(body, deps);
            if let Some(eb) = else_body {
                collect_deps(eb, deps);
            }
        }
        Expr::Match(scrutinee, arms) => {
            collect_deps(scrutinee, deps);
            for arm in arms {
                collect_deps(&arm.body, deps);
            }
        }
        // The assignment target counts as a dependency as well as the value.
        Expr::Assign(target, _, value) => {
            collect_deps(target, deps);
            collect_deps(value, deps);
        }
        Expr::Lambda(_, body) => collect_deps(body, deps),
        Expr::StringLit(s) => {
            for seg in &s.segments {
                if let ds_parser::StringSegment::Interpolation(expr) = seg {
                    collect_deps(expr, deps);
                }
            }
        }
        // Literals and any other expression kinds reference no signals.
        _ => {}
    }
}
/// Collect the signal writes performed by a handler body (e.g. `count += 1`).
///
/// - An assignment whose target is a bare identifier yields one [`Mutation`].
///   Its right-hand side is stored as `Debug`-formatted text.
/// - A block yields the mutations of its expressions, in order and recursively.
/// - Anything else yields nothing, including assignments to non-identifier
///   targets.
fn extract_mutations(expr: &Expr) -> Vec<Mutation> {
    match expr {
        Expr::Block(exprs) => exprs.iter().flat_map(|e| extract_mutations(e)).collect(),
        Expr::Assign(target, op, value) => match target.as_ref() {
            Expr::Ident(name) => {
                let rendered = format!("{value:?}");
                let kind = match op {
                    ds_parser::AssignOp::Set => MutationOp::Set(rendered),
                    ds_parser::AssignOp::AddAssign => MutationOp::AddAssign(rendered),
                    ds_parser::AssignOp::SubAssign => MutationOp::SubAssign(rendered),
                };
                vec![Mutation {
                    target: name.clone(),
                    op: kind,
                }]
            }
            _ => Vec::new(),
        },
        _ => Vec::new(),
    }
}
/// Extract DOM bindings from a view body, in depth-first pre-order.
fn extract_bindings(expr: &Expr) -> Vec<DomBinding> {
    let mut collected = Vec::new();
    collect_bindings(expr, &mut collected);
    collected
}
fn collect_bindings(expr: &Expr, bindings: &mut Vec<DomBinding>) {
match expr {
Expr::Container(c) => {
let kind_str = match &c.kind {
ds_parser::ContainerKind::Column => "column",
ds_parser::ContainerKind::Row => "row",
ds_parser::ContainerKind::Stack => "stack",
ds_parser::ContainerKind::Panel => "panel",
ds_parser::ContainerKind::List => "list",
ds_parser::ContainerKind::Form => "form",
ds_parser::ContainerKind::Scene => "scene",
ds_parser::ContainerKind::Custom(s) => s,
};
bindings.push(DomBinding {
kind: BindingKind::StaticContainer {
kind: kind_str.to_string(),
child_count: c.children.len(),
},
dependencies: Vec::new(),
});
for child in &c.children {
collect_bindings(child, bindings);
}
}
Expr::Element(el) => {
// Check if any arg is an identifier (signal binding)
for arg in &el.args {
match arg {
Expr::Ident(name) => {
bindings.push(DomBinding {
kind: BindingKind::TextContent { signal: name.clone() },
dependencies: vec![name.clone()],
});
}
Expr::StringLit(s) => {
if let Some(ds_parser::StringSegment::Literal(text)) = s.segments.first() {
bindings.push(DomBinding {
kind: BindingKind::StaticText { text: text.clone() },
dependencies: Vec::new(),
});
}
}
_ => {}
}
}
// Check props for event handlers
for (key, val) in &el.props {
if matches!(key.as_str(), "click" | "input" | "change" | "submit" | "keydown" | "keyup") {
let action = format!("{val:?}");
let deps = extract_dependencies(val);
bindings.push(DomBinding {
kind: BindingKind::EventHandler {
element_tag: el.tag.clone(),
event: key.clone(),
action,
},
dependencies: deps,
});
}
}
}
Expr::When(cond, body, else_body) => {
let deps = extract_dependencies(cond);
bindings.push(DomBinding {
kind: BindingKind::Conditional { condition_signals: deps.clone() },
dependencies: deps,
});
collect_bindings(body, bindings);
if let Some(eb) = else_body {
collect_bindings(eb, bindings);
}
}
_ => {}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use ds_parser::{Lexer, Parser};

    /// Lex, parse, and analyze `src`; panics if `src` does not parse.
    fn analyze(src: &str) -> (SignalGraph, Vec<AnalyzedView>) {
        let mut lexer = Lexer::new(src);
        let tokens = lexer.tokenize();
        let mut parser = Parser::new(tokens);
        let program = parser.parse_program().expect("parse failed");
        let graph = SignalGraph::from_program(&program);
        let views = SignalGraph::analyze_views(&program);
        (graph, views)
    }

    #[test]
    fn test_source_signal() {
        let (graph, _) = analyze("let count = 0");
        assert_eq!(graph.nodes.len(), 1);
        assert!(matches!(graph.nodes[0].kind, SignalKind::Source));
        assert_eq!(graph.nodes[0].name, "count");
    }

    #[test]
    fn test_derived_signal() {
        let (graph, _) = analyze("let count = 0\nlet doubled = count * 2");
        assert_eq!(graph.nodes.len(), 2);
        assert!(matches!(graph.nodes[0].kind, SignalKind::Source));
        assert!(matches!(graph.nodes[1].kind, SignalKind::Derived));
        assert_eq!(graph.nodes[1].dependencies[0].signal_name, "count");
        assert_eq!(graph.nodes[1].dependencies[0].signal_id, Some(0));
    }

    #[test]
    fn test_topological_order() {
        let (graph, _) = analyze("let count = 0\nlet doubled = count * 2");
        let order = graph.topological_order();
        // count (id=0) should come before doubled (id=1)
        let pos_count = order.iter().position(|&id| id == 0).unwrap();
        let pos_doubled = order.iter().position(|&id| id == 1).unwrap();
        assert!(pos_count < pos_doubled);
    }

    #[test]
    fn test_forward_reference_topological_order() {
        // Declared consumer-first: c (id 0) -> b (id 1) -> a (id 2).
        let (graph, _) = analyze("let c = b * 2\nlet b = a + 1\nlet a = 0");
        assert_eq!(graph.nodes[0].dependencies[0].signal_id, Some(1));
        assert_eq!(graph.nodes[1].dependencies[0].signal_id, Some(2));
        assert_eq!(graph.topological_order(), vec![2, 1, 0]);
    }

    #[test]
    fn test_literal_initial_values() {
        let (graph, _) = analyze("let count = 5\nlet label = \"hi\"\nlet doubled = count * 2");
        assert!(matches!(graph.nodes[0].initial_value, Some(InitialValue::Int(5))));
        assert!(matches!(
            &graph.nodes[1].initial_value,
            Some(InitialValue::String(s)) if s == "hi"
        ));
        assert!(graph.nodes[2].initial_value.is_none());
    }

    #[test]
    fn test_view_bindings() {
        let (_, views) = analyze(
            r#"let label = "hi"
view counter =
column [
text label
button "+" { click: count += 1 }
]"#
        );
        assert_eq!(views.len(), 1);
        assert_eq!(views[0].name, "counter");
        // Should have: container, text binding, static text, event handler
        assert!(views[0].bindings.len() >= 3);
    }

    #[test]
    fn test_streamable_signals() {
        let (graph, _) = analyze(
            "stream main on \"ws://localhost:9100\"\nlet count = 0\nview main = column [ text \"hello\" ]"
        );
        let count_node = graph.nodes.iter().find(|n| n.name == "count").unwrap();
        assert!(count_node.streamable, "source signal should be streamable when stream decl present");
    }

    #[test]
    fn test_not_streamable_without_decl() {
        let (graph, _) = analyze("let count = 0\nview main = column [ text \"hi\" ]");
        let count_node = graph.nodes.iter().find(|n| n.name == "count").unwrap();
        assert!(!count_node.streamable, "signals should not be streamable without stream decl");
    }

    #[test]
    fn test_manifest_lists_only_streamable_sources() {
        let (graph, _) = analyze(
            "stream main on \"ws://localhost:9100\"\nlet count = 0\nlet doubled = count * 2\nview main = column [ text \"hello\" ]",
        );
        let manifest = graph.signal_manifest();
        let names: Vec<&str> = manifest.signals.iter().map(|e| e.name.as_str()).collect();
        assert_eq!(names, ["count"]);
        assert!(manifest.signals.iter().all(|e| !e.is_spring));
    }
}