diff --git a/Cargo.toml b/Cargo.toml index 7bcccff..bcda22f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,8 @@ members = [ "compiler/ds-parser", "compiler/ds-analyzer", "compiler/ds-codegen", + "compiler/ds-layout", + "compiler/ds-types", "compiler/ds-cli", ] @@ -16,3 +18,5 @@ license = "MIT" ds-parser = { path = "compiler/ds-parser" } ds-analyzer = { path = "compiler/ds-analyzer" } ds-codegen = { path = "compiler/ds-codegen" } +ds-layout = { path = "compiler/ds-layout" } +ds-types = { path = "compiler/ds-types" } diff --git a/compiler/ds-layout/Cargo.toml b/compiler/ds-layout/Cargo.toml new file mode 100644 index 0000000..e399a43 --- /dev/null +++ b/compiler/ds-layout/Cargo.toml @@ -0,0 +1,6 @@ +[package] +name = "ds-layout" +version = "0.1.0" +edition = "2021" + +[dependencies] diff --git a/compiler/ds-layout/src/lib.rs b/compiler/ds-layout/src/lib.rs new file mode 100644 index 0000000..e41dbe1 --- /dev/null +++ b/compiler/ds-layout/src/lib.rs @@ -0,0 +1,7 @@ +/// DreamStack Layout — Cassowary-inspired constraint solver for UI layout. +pub mod solver; + +pub use solver::{ + LayoutSolver, Variable, Constraint, ConstraintKind, + Strength, LayoutRect, Term, +}; diff --git a/compiler/ds-layout/src/solver.rs b/compiler/ds-layout/src/solver.rs new file mode 100644 index 0000000..4b429b5 --- /dev/null +++ b/compiler/ds-layout/src/solver.rs @@ -0,0 +1,452 @@ +/// Cassowary-inspired constraint solver for DreamStack layout. +/// +/// Solves systems of linear constraints using Gaussian elimination +/// with strength-based priority (Required > Strong > Medium > Weak). +/// +/// Each variable represents a layout dimension (x, y, width, height). +/// Constraints express relationships between these dimensions. +/// +/// Performance target: solve 500 constraints in <2ms. + +use std::collections::HashMap; + +// ─── Core types ───────────────────────────────────────────── + +/// A layout variable representing a single dimension of an element. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct Variable(pub u32); + +static NEXT_VAR_ID: std::sync::atomic::AtomicU32 = std::sync::atomic::AtomicU32::new(1); + +impl Variable { + pub fn new() -> Self { + Variable(NEXT_VAR_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed)) + } +} + +impl Default for Variable { + fn default() -> Self { + Self::new() + } +} + +/// A term in a linear expression: coefficient * variable. +#[derive(Debug, Clone)] +pub struct Term { + pub var: Variable, + pub coefficient: f64, +} + +/// Constraint strength — determines priority when constraints conflict. +#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)] +pub enum Strength { + Weak = 1, + Medium = 100, + Strong = 1000, + Required = 1_000_000, +} + +/// What kind of constraint. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum ConstraintKind { + Eq, // expression == 0 + Lte, // expression <= 0 + Gte, // expression >= 0 +} + +/// A linear constraint: sum_i(coeff_i * var_i) + constant 0 +#[derive(Debug, Clone)] +pub struct Constraint { + pub terms: Vec, + pub constant: f64, + pub kind: ConstraintKind, + pub strength: Strength, +} + +impl Constraint { + /// var == val + pub fn eq_const(var: Variable, val: f64, strength: Strength) -> Self { + // var - val == 0 + Constraint { + terms: vec![Term { var, coefficient: 1.0 }], + constant: -val, + kind: ConstraintKind::Eq, + strength, + } + } + + /// a == b + pub fn eq(a: Variable, b: Variable, strength: Strength) -> Self { + // a - b == 0 + Constraint { + terms: vec![ + Term { var: a, coefficient: 1.0 }, + Term { var: b, coefficient: -1.0 }, + ], + constant: 0.0, + kind: ConstraintKind::Eq, + strength, + } + } + + /// a + b == val + pub fn sum_eq(a: Variable, b: Variable, val: f64, strength: Strength) -> Self { + Constraint { + terms: vec![ + Term { var: a, coefficient: 1.0 }, + Term { var: b, coefficient: 1.0 }, + ], + constant: -val, + kind: ConstraintKind::Eq, + strength, + } + } + + /// var >= val + pub fn gte_const(var: Variable, val: f64, strength: Strength) -> Self { + // var - val >= 0 + Constraint { + terms: vec![Term { var, coefficient: 1.0 }], + constant: -val, + kind: ConstraintKind::Gte, + strength, + } + } + + /// var <= val + pub fn lte_const(var: Variable, val: f64, strength: Strength) -> Self { + // var - val <= 0 + Constraint { + terms: vec![Term { var, coefficient: 1.0 }], + constant: -val, + kind: ConstraintKind::Lte, + strength, + } + } + + /// a * ratio == b (i.e., ratio*a - b == 0) + pub fn ratio(a: Variable, b: Variable, ratio: f64, strength: Strength) -> Self { + Constraint { + terms: vec![ + Term { var: a, coefficient: ratio }, + Term { var: b, coefficient: -1.0 }, + ], + constant: 0.0, + kind: ConstraintKind::Eq, + strength, + } + } +} + +/// Absolute layout rect for a resolved element. +#[derive(Debug, Clone, Copy, Default)] +pub struct LayoutRect { + pub x: f64, + pub y: f64, + pub width: f64, + pub height: f64, +} + +// ─── Solver ───────────────────────────────────────────────── + +/// The constraint solver. +/// +/// Uses Gaussian elimination: each equality constraint is reduced to +/// `subject = expression_of_other_vars + constant`. Existing definitions +/// are substituted in, so the system converges to concrete values. +/// Inequality constraints are handled by clamping. + +pub struct LayoutSolver { + constraints: Vec, + /// Solved variable definitions: var → (coefficients_of_other_vars, constant). + /// When fully reduced, the HashMap of coefficients is empty, and constant is the value. + definitions: HashMap, f64)>, +} + +impl LayoutSolver { + pub fn new() -> Self { + LayoutSolver { + constraints: Vec::new(), + definitions: HashMap::new(), + } + } + + pub fn add_constraint(&mut self, constraint: Constraint) { + self.constraints.push(constraint); + } + + /// Solve all constraints. + pub fn solve(&mut self) { + self.definitions.clear(); + + // Sort constraints: required first (highest strength first) + let mut constraints = self.constraints.clone(); + constraints.sort_by(|a, b| (b.strength as i32).cmp(&(a.strength as i32))); + + for c in constraints { + match c.kind { + ConstraintKind::Eq => self.process_eq(&c), + ConstraintKind::Gte => self.process_gte(&c), + ConstraintKind::Lte => self.process_lte(&c), + } + } + + // Iteratively resolve until all definitions are concrete + for _ in 0..20 { + let mut changed = false; + let vars: Vec = self.definitions.keys().cloned().collect(); + for var in vars { + let (coeffs, constant) = self.definitions[&var].clone(); + let mut new_coeffs: HashMap = HashMap::new(); + let mut new_constant = constant; + + for (&dep_var, &coeff) in &coeffs { + if let Some((dep_coeffs, dep_const)) = self.definitions.get(&dep_var) { + // Substitute: dep_var = dep_coeffs + dep_const + new_constant += coeff * dep_const; + for (&v, &c) in dep_coeffs { + *new_coeffs.entry(v).or_insert(0.0) += coeff * c; + } + changed = true; + } else { + *new_coeffs.entry(dep_var).or_insert(0.0) += coeff; + } + } + + // Clean near-zero coefficients + new_coeffs.retain(|_, c| c.abs() > 1e-10); + + self.definitions.insert(var, (new_coeffs, new_constant)); + } + if !changed { + break; + } + } + } + + fn process_eq(&mut self, c: &Constraint) { + // Build: sum_i(coeff_i * var_i) + constant == 0 + // Substitute known definitions + let mut coeffs: HashMap = HashMap::new(); + let mut constant = c.constant; + + for term in &c.terms { + if let Some((def_coeffs, def_const)) = self.definitions.get(&term.var) { + // Variable already defined — substitute + constant += term.coefficient * def_const; + for (&v, &c) in def_coeffs { + *coeffs.entry(v).or_insert(0.0) += term.coefficient * c; + } + } else { + *coeffs.entry(term.var).or_insert(0.0) += term.coefficient; + } + } + + // Clean near-zero + coeffs.retain(|_, c| c.abs() > 1e-10); + + // Pick a subject variable to solve for + if let Some((&subject, &subject_coeff)) = coeffs.iter().find(|(v, c)| { + c.abs() > 1e-10 && !self.definitions.contains_key(v) + }).or_else(|| coeffs.iter().find(|(_, c)| c.abs() > 1e-10)) { + // Solve: subject_coeff * subject + rest == 0 + // subject = -rest / subject_coeff - constant / subject_coeff + let inv = -1.0 / subject_coeff; + let mut result_coeffs: HashMap = HashMap::new(); + for (&v, &c) in &coeffs { + if v != subject { + result_coeffs.insert(v, c * inv); + } + } + let result_constant = constant * inv; + + self.definitions.insert(subject, (result_coeffs, result_constant)); + } + } + + fn process_gte(&mut self, c: &Constraint) { + // var + constant >= 0 → var >= -constant + // If var is unsolved, set its minimum + if c.terms.len() == 1 { + let term = &c.terms[0]; + let min_val = -c.constant / term.coefficient; + + if let Some((_coeffs, const_val)) = self.definitions.get_mut(&term.var) { + if *const_val < min_val { + *const_val = min_val; + } + } else { + // Not yet defined — set to minimum + self.definitions.insert(term.var, (HashMap::new(), min_val)); + } + } else { + // Multi-variable inequality — convert to equality at boundary + self.process_eq(c); + } + } + + fn process_lte(&mut self, c: &Constraint) { + if c.terms.len() == 1 { + let term = &c.terms[0]; + let max_val = -c.constant / term.coefficient; + + if let Some((_coeffs, const_val)) = self.definitions.get_mut(&term.var) { + if *const_val > max_val { + *const_val = max_val; + } + } else { + self.definitions.insert(term.var, (HashMap::new(), max_val)); + } + } else { + self.process_eq(c); + } + } + + /// Get the resolved value of a variable. + pub fn get_value(&self, var: Variable) -> f64 { + if let Some((coeffs, constant)) = self.definitions.get(&var) { + if coeffs.is_empty() { + return *constant; + } + // Still has dependencies — try to resolve them + let mut val = *constant; + for (&dep, &coeff) in coeffs { + val += coeff * self.get_value(dep); + } + val + } else { + 0.0 + } + } + + /// Resolve a layout rect from four variables. + pub fn get_rect(&self, x: Variable, y: Variable, w: Variable, h: Variable) -> LayoutRect { + LayoutRect { + x: self.get_value(x), + y: self.get_value(y), + width: self.get_value(w), + height: self.get_value(h), + } + } + + /// Clear all constraints and solutions. + pub fn reset(&mut self) { + self.constraints.clear(); + self.definitions.clear(); + } +} + +impl Default for LayoutSolver { + fn default() -> Self { + Self::new() + } +} + +// ─── Tests ────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_simple_eq() { + let mut solver = LayoutSolver::new(); + let x = Variable::new(); + solver.add_constraint(Constraint::eq_const(x, 100.0, Strength::Required)); + solver.solve(); + assert!((solver.get_value(x) - 100.0).abs() < 0.01, "x = {}", solver.get_value(x)); + } + + #[test] + fn test_two_vars_eq() { + let mut solver = LayoutSolver::new(); + let x = Variable::new(); + let w = Variable::new(); + solver.add_constraint(Constraint::eq_const(x, 50.0, Strength::Required)); + solver.add_constraint(Constraint::eq(w, x, Strength::Required)); + solver.solve(); + assert!((solver.get_value(x) - 50.0).abs() < 0.01, "x = {}", solver.get_value(x)); + assert!((solver.get_value(w) - 50.0).abs() < 0.01, "w = {}", solver.get_value(w)); + } + + #[test] + fn test_sum_constraint() { + let mut solver = LayoutSolver::new(); + let x = Variable::new(); + let w = Variable::new(); + solver.add_constraint(Constraint::eq_const(x, 100.0, Strength::Required)); + solver.add_constraint(Constraint::sum_eq(x, w, 500.0, Strength::Required)); + solver.solve(); + assert!((solver.get_value(x) - 100.0).abs() < 0.01, "x = {}", solver.get_value(x)); + assert!((solver.get_value(w) - 400.0).abs() < 0.01, "w = {}", solver.get_value(w)); + } + + #[test] + fn test_layout_rect() { + let mut solver = LayoutSolver::new(); + let x = Variable::new(); + let y = Variable::new(); + let w = Variable::new(); + let h = Variable::new(); + + solver.add_constraint(Constraint::eq_const(x, 10.0, Strength::Required)); + solver.add_constraint(Constraint::eq_const(y, 20.0, Strength::Required)); + solver.add_constraint(Constraint::eq_const(w, 300.0, Strength::Required)); + solver.add_constraint(Constraint::eq_const(h, 400.0, Strength::Required)); + solver.solve(); + + let rect = solver.get_rect(x, y, w, h); + assert!((rect.x - 10.0).abs() < 0.01); + assert!((rect.y - 20.0).abs() < 0.01); + assert!((rect.width - 300.0).abs() < 0.01); + assert!((rect.height - 400.0).abs() < 0.01); + } + + #[test] + fn test_gte_constraint() { + let mut solver = LayoutSolver::new(); + let w = Variable::new(); + solver.add_constraint(Constraint::gte_const(w, 50.0, Strength::Required)); + solver.solve(); + let val = solver.get_value(w); + assert!(val >= 49.99, "w should be >= 50, got {}", val); + } + + #[test] + fn test_ratio_constraint() { + let mut solver = LayoutSolver::new(); + let w = Variable::new(); + let h = Variable::new(); + solver.add_constraint(Constraint::eq_const(w, 200.0, Strength::Required)); + solver.add_constraint(Constraint::ratio(w, h, 0.5, Strength::Required)); + solver.solve(); + assert!((solver.get_value(w) - 200.0).abs() < 0.01, "w = {}", solver.get_value(w)); + assert!((solver.get_value(h) - 100.0).abs() < 0.01, "h = {}", solver.get_value(h)); + } + + #[test] + fn test_three_panel_layout() { + let mut solver = LayoutSolver::new(); + + let sidebar_x = Variable::new(); + let sidebar_w = Variable::new(); + let main_x = Variable::new(); + let main_w = Variable::new(); + + // sidebar starts at 0, width 200 + solver.add_constraint(Constraint::eq_const(sidebar_x, 0.0, Strength::Required)); + solver.add_constraint(Constraint::eq_const(sidebar_w, 200.0, Strength::Required)); + // main starts where sidebar ends: main_x = sidebar_x + sidebar_w = 200 + solver.add_constraint(Constraint::eq_const(main_x, 200.0, Strength::Required)); + // main_x + main_w = 1000 (total width) + solver.add_constraint(Constraint::sum_eq(main_x, main_w, 1000.0, Strength::Required)); + + solver.solve(); + + assert!((solver.get_value(sidebar_x) - 0.0).abs() < 0.01); + assert!((solver.get_value(sidebar_w) - 200.0).abs() < 0.01); + assert!((solver.get_value(main_x) - 200.0).abs() < 0.01); + assert!((solver.get_value(main_w) - 800.0).abs() < 0.01, + "main_w = {}", solver.get_value(main_w)); + } +} diff --git a/compiler/ds-types/Cargo.toml b/compiler/ds-types/Cargo.toml new file mode 100644 index 0000000..91153c6 --- /dev/null +++ b/compiler/ds-types/Cargo.toml @@ -0,0 +1,7 @@ +[package] +name = "ds-types" +version = "0.1.0" +edition = "2021" + +[dependencies] +ds-parser = { path = "../ds-parser" } diff --git a/compiler/ds-types/src/checker.rs b/compiler/ds-types/src/checker.rs new file mode 100644 index 0000000..1d69206 --- /dev/null +++ b/compiler/ds-types/src/checker.rs @@ -0,0 +1,575 @@ +/// Type checker — Hindley-Milner with signal-awareness and effect tracking. +/// +/// Walks the DreamStack AST and: +/// 1. Infers types for all expressions +/// 2. Checks signal usage (Signal vs Derived vs plain T) +/// 3. Tracks effect perform/handle pairs +/// 4. Ensures views only contain valid UI expressions + +use std::collections::HashMap; + +use ds_parser::{Program, Declaration, LetDecl, ViewDecl, Expr, BinOp, UnaryOp}; +use crate::types::{Type, TypeVar, EffectType}; +use crate::errors::{TypeError, TypeErrorKind}; + +/// The type checker. +pub struct TypeChecker { + /// Type environment: variable name → type. + env: HashMap, + /// Accumulated errors (error-recovery: keep going after first error). + errors: Vec, + /// Effect stack: currently handled effects. + effect_handlers: Vec>, + /// Next type variable ID. + next_tv: u32, + /// Type substitutions from unification. + substitutions: HashMap, + /// Currently inside a view block? + in_view: bool, +} + +impl TypeChecker { + pub fn new() -> Self { + TypeChecker { + env: HashMap::new(), + errors: Vec::new(), + effect_handlers: Vec::new(), + next_tv: 0, + substitutions: HashMap::new(), + in_view: false, + } + } + + /// Generate a fresh type variable. + fn fresh_tv(&mut self) -> Type { + let tv = TypeVar(self.next_tv); + self.next_tv += 1; + Type::Var(tv) + } + + /// Record a type error. + fn error(&mut self, kind: TypeErrorKind) { + self.errors.push(TypeError::new(kind)); + } + + /// Check an entire program. + pub fn check_program(&mut self, program: &Program) { + // First pass: register all let declarations + for decl in &program.declarations { + if let Declaration::Let(let_decl) = decl { + let ty = self.infer_expr(&let_decl.value); + // Heuristic: if name ends conventionally or is assigned + // a literal, mark as Signal; otherwise just let-bound T. + // For now: any `let` with a literal is a source signal; + // derivations (expressions involving other identifiers) become Derived. + let is_source = matches!( + let_decl.value, + Expr::IntLit(_) | Expr::FloatLit(_) | Expr::StringLit(_) | Expr::BoolLit(_) + ); + let final_ty = if is_source { + Type::Signal(Box::new(ty)) + } else { + Type::Derived(Box::new(ty)) + }; + self.env.insert(let_decl.name.clone(), final_ty); + } + } + + // Register effect declarations + for decl in &program.declarations { + if let Declaration::Effect(eff) = decl { + let fn_type = Type::Fn { + params: eff.params.iter().map(|_| self.fresh_tv()).collect(), + ret: Box::new(self.fresh_tv()), + effects: vec![EffectType::Custom(eff.name.clone())], + }; + self.env.insert(eff.name.clone(), fn_type); + } + } + + // Register handlers + for decl in &program.declarations { + if let Declaration::OnHandler(handler) = decl { + let fn_type = Type::Fn { + params: vec![], + ret: Box::new(Type::Unit), + effects: vec![EffectType::Dom], + }; + self.env.insert(handler.event.clone(), fn_type); + } + } + + // Second pass: check views + for decl in &program.declarations { + if let Declaration::View(view) = decl { + self.check_view(view); + } + } + } + + /// Check a view declaration. + fn check_view(&mut self, view: &ViewDecl) { + self.in_view = true; + self.check_view_expr(&view.body); + self.in_view = false; + } + + /// Check that an expression is valid inside a view. + fn check_view_expr(&mut self, expr: &Expr) { + match expr { + Expr::Element(el) => { + // Check child arguments and props + for arg in &el.args { + self.check_view_expr(arg); + } + for (_, val) in &el.props { + self.check_view_expr(val); + } + } + Expr::Container(container) => { + for child in &container.children { + self.check_view_expr(child); + } + } + Expr::StringLit(_) => { /* always valid */ } + Expr::Ident(name) => { + if !self.env.contains_key(name) { + self.error(TypeErrorKind::UnboundVariable { + name: name.clone(), + }); + } + } + Expr::When(condition, body) => { + let cond_type = self.infer_expr(condition); + let inner = cond_type.unwrap_reactive().clone(); + if inner != Type::Bool && !matches!(inner, Type::Var(_)) && inner != Type::Error { + self.error(TypeErrorKind::Mismatch { + expected: Type::Bool, + found: cond_type, + context: "Condition in `when` must be a boolean".to_string(), + }); + } + self.check_view_expr(body); + } + Expr::If(condition, then_branch, else_branch) => { + let cond_type = self.infer_expr(condition); + let inner = cond_type.unwrap_reactive().clone(); + if inner != Type::Bool && !matches!(inner, Type::Var(_)) && inner != Type::Error { + self.error(TypeErrorKind::Mismatch { + expected: Type::Bool, + found: cond_type, + context: "Condition in if must be a boolean".to_string(), + }); + } + self.check_view_expr(then_branch); + self.check_view_expr(else_branch); + } + Expr::Block(exprs) => { + for e in exprs { + self.check_view_expr(e); + } + } + _ => { + let _ = self.infer_expr(expr); + } + } + } + + /// Infer the type of an expression. + fn infer_expr(&mut self, expr: &Expr) -> Type { + match expr { + Expr::IntLit(_) => Type::Int, + Expr::FloatLit(_) => Type::Float, + Expr::StringLit(_) => Type::String, + Expr::BoolLit(_) => Type::Bool, + + Expr::Ident(name) => { + if let Some(ty) = self.env.get(name) { + ty.clone() + } else { + self.error(TypeErrorKind::UnboundVariable { name: name.clone() }); + Type::Error + } + } + + Expr::DotAccess(obj, field) => { + let obj_ty = self.infer_expr(obj); + match &obj_ty { + Type::Record(fields) => { + if let Some(ty) = fields.get(field) { + ty.clone() + } else { + self.error(TypeErrorKind::MissingField { + field: field.clone(), + record_type: obj_ty.clone(), + }); + Type::Error + } + } + _ => self.fresh_tv(), // could be a dot-access on a signal + } + } + + Expr::BinOp(left, op, right) => { + let left_ty = self.infer_expr(left); + let right_ty = self.infer_expr(right); + + let left_inner = left_ty.unwrap_reactive(); + let right_inner = right_ty.unwrap_reactive(); + + match op { + BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div | BinOp::Mod => { + if *left_inner == Type::Int && *right_inner == Type::Int { + Type::Int + } else if matches!(left_inner, Type::Int | Type::Float) + && matches!(right_inner, Type::Int | Type::Float) + { + Type::Float + } else if *left_inner == Type::String && matches!(op, BinOp::Add) { + Type::String + } else if *left_inner == Type::Error || *right_inner == Type::Error { + Type::Error + } else { + self.error(TypeErrorKind::Mismatch { + expected: Type::Int, + found: right_ty.clone(), + context: format!("Both sides of {:?} must be numeric", op), + }); + Type::Error + } + } + BinOp::Eq | BinOp::Neq | BinOp::Lt | BinOp::Gt | BinOp::Lte | BinOp::Gte => Type::Bool, + BinOp::And | BinOp::Or => { + if *left_inner != Type::Bool && !matches!(left_inner, Type::Var(_)) { + self.error(TypeErrorKind::Mismatch { + expected: Type::Bool, + found: left_ty.clone(), + context: format!("Left side of {:?} must be Bool", op), + }); + } + Type::Bool + } + } + } + + Expr::UnaryOp(op, operand) => { + let ty = self.infer_expr(operand); + let inner = ty.unwrap_reactive(); + match op { + UnaryOp::Neg => { + if matches!(inner, Type::Int | Type::Float) { + inner.clone() + } else { + self.error(TypeErrorKind::Mismatch { + expected: Type::Int, + found: ty.clone(), + context: "Unary `-` requires a numeric type".to_string(), + }); + Type::Error + } + } + UnaryOp::Not => { + if *inner == Type::Bool || matches!(inner, Type::Var(_)) { + Type::Bool + } else { + self.error(TypeErrorKind::Mismatch { + expected: Type::Bool, + found: ty.clone(), + context: "Unary `!` requires a Bool".to_string(), + }); + Type::Error + } + } + } + } + + Expr::Call(name, args) => { + if let Some(fn_ty) = self.env.get(name).cloned() { + match &fn_ty { + Type::Fn { params, ret, effects } => { + if args.len() != params.len() { + self.error(TypeErrorKind::ArityMismatch { + function: name.clone(), + expected: params.len(), + found: args.len(), + }); + } + for eff in effects { + if *eff != EffectType::Pure && !self.is_effect_handled(eff) { + self.error(TypeErrorKind::UnhandledEffect { + effect: format!("{:?}", eff), + function: name.clone(), + }); + } + } + *ret.clone() + } + _ => { + for arg in args { + self.infer_expr(arg); + } + self.fresh_tv() + } + } + } else { + self.error(TypeErrorKind::UnboundVariable { name: name.clone() }); + Type::Error + } + } + + Expr::Lambda(params, body) => { + let mut param_types = Vec::new(); + for p in params { + let tv = self.fresh_tv(); + self.env.insert(p.clone(), tv.clone()); + param_types.push(tv); + } + let ret = self.infer_expr(body); + Type::Fn { + params: param_types, + ret: Box::new(ret), + effects: vec![EffectType::Pure], + } + } + + Expr::Element(_) | Expr::Container(_) => { + if !self.in_view { + self.error(TypeErrorKind::ViewOutsideBlock { + expr: "UI element".to_string(), + }); + } + Type::View + } + + Expr::Perform(name, _args) => { + // An effect performance — check it's handled + if !self.is_effect_handled(&EffectType::Custom(name.clone())) { + self.error(TypeErrorKind::UnhandledEffect { + effect: name.clone(), + function: "".to_string(), + }); + } + self.fresh_tv() + } + + Expr::StreamFrom(_source) => { + Type::Stream(Box::new(self.fresh_tv())) + } + + Expr::Spring(_props) => { + Type::Spring(Box::new(Type::Float)) + } + + Expr::Record(fields) => { + let mut field_types = HashMap::new(); + for (name, expr) in fields { + field_types.insert(name.clone(), self.infer_expr(expr)); + } + Type::Record(field_types) + } + + Expr::List(items) => { + if items.is_empty() { + Type::Array(Box::new(self.fresh_tv())) + } else { + let first_ty = self.infer_expr(&items[0]); + // Could unify all items, but simplified for now + Type::Array(Box::new(first_ty)) + } + } + + Expr::If(_, then_br, _) => { + self.infer_expr(then_br) + } + + Expr::When(_, body) => { + self.infer_expr(body) + } + + Expr::Block(exprs) => { + let mut last_ty = Type::Unit; + for e in exprs { + last_ty = self.infer_expr(e); + } + last_ty + } + + Expr::Pipe(left, right) => { + let _input_ty = self.infer_expr(left); + self.infer_expr(right) + } + + Expr::Match(expr, arms) => { + let _ = self.infer_expr(expr); + if arms.is_empty() { + Type::Unit + } else { + self.infer_expr(&arms[0].body) + } + } + + Expr::Assign(_, _, value) => { + self.infer_expr(value); + Type::Unit + } + } + } + + /// Check if an effect is currently handled. + fn is_effect_handled(&self, effect: &EffectType) -> bool { + for handlers in &self.effect_handlers { + if handlers.contains(effect) { + return true; + } + } + false + } + + /// Get the accumulated errors. + pub fn errors(&self) -> &[TypeError] { + &self.errors + } + + /// Get the resolved type environment. + pub fn type_env(&self) -> &HashMap { + &self.env + } + + /// Check if there are any errors. + pub fn has_errors(&self) -> bool { + !self.errors.is_empty() + } + + /// Format all errors for display. + pub fn display_errors(&self) -> String { + self.errors.iter() + .map(|e| e.display()) + .collect::>() + .join("\n") + } +} + +impl Default for TypeChecker { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use ds_parser::{Declaration, LetDecl, ViewDecl, Expr, Span, Container, ContainerKind, Element}; + + fn span() -> Span { + Span { start: 0, end: 0, line: 0 } + } + + fn make_program(decls: Vec) -> Program { + Program { declarations: decls } + } + + #[test] + fn test_signal_types() { + let mut checker = TypeChecker::new(); + let program = make_program(vec![ + Declaration::Let(LetDecl { + name: "count".to_string(), + value: Expr::IntLit(0), + span: span(), + }), + ]); + checker.check_program(&program); + assert!(!checker.has_errors(), "Errors: {}", checker.display_errors()); + + let count_ty = checker.type_env().get("count").unwrap(); + assert!(count_ty.is_reactive()); + assert_eq!(count_ty.display(), "Signal"); + } + + #[test] + fn test_derived_types() { + let mut checker = TypeChecker::new(); + let program = make_program(vec![ + Declaration::Let(LetDecl { + name: "count".to_string(), + value: Expr::IntLit(0), + span: span(), + }), + Declaration::Let(LetDecl { + name: "doubled".to_string(), + value: Expr::BinOp( + Box::new(Expr::Ident("count".to_string())), + BinOp::Mul, + Box::new(Expr::IntLit(2)), + ), + span: span(), + }), + ]); + checker.check_program(&program); + assert!(!checker.has_errors(), "Errors: {}", checker.display_errors()); + + let doubled_ty = checker.type_env().get("doubled").unwrap(); + assert_eq!(doubled_ty.display(), "Derived"); + } + + #[test] + fn test_unbound_variable() { + let mut checker = TypeChecker::new(); + let program = make_program(vec![ + Declaration::View(ViewDecl { + name: "main".to_string(), + params: vec![], + body: Expr::Container(Container { + kind: ContainerKind::Column, + children: vec![ + Expr::Ident("nonexistent".to_string()), + ], + props: vec![], + }), + span: span(), + }), + ]); + checker.check_program(&program); + assert!(checker.has_errors()); + + let msg = checker.display_errors(); + assert!(msg.contains("UNBOUND VARIABLE")); + assert!(msg.contains("nonexistent")); + } + + #[test] + fn test_type_display_in_env() { + let mut checker = TypeChecker::new(); + let program = make_program(vec![ + Declaration::Let(LetDecl { + name: "name".to_string(), + value: Expr::StringLit(ds_parser::StringLit { + segments: vec![ds_parser::StringSegment::Literal("hello".to_string())], + }), + span: span(), + }), + ]); + checker.check_program(&program); + assert!(!checker.has_errors()); + + let name_ty = checker.type_env().get("name").unwrap(); + assert_eq!(name_ty.display(), "Signal"); + } + + #[test] + fn test_view_outside_block_error() { + let mut checker = TypeChecker::new(); + // Manually add condition: an element expression outside a view + let el_expr = Expr::Element(Element { + tag: "button".to_string(), + args: vec![], + props: vec![], + modifiers: vec![], + }); + let ty = checker.infer_expr(&el_expr); + assert_eq!(ty, Type::View); + assert!(checker.has_errors()); + let msg = checker.display_errors(); + assert!(msg.contains("VIEW OUTSIDE BLOCK")); + } +} diff --git a/compiler/ds-types/src/errors.rs b/compiler/ds-types/src/errors.rs new file mode 100644 index 0000000..0bd0446 --- /dev/null +++ b/compiler/ds-types/src/errors.rs @@ -0,0 +1,218 @@ +/// Type error reporting — inspired by Elm's famously helpful errors. + +use crate::types::Type; + +/// A type error with context for helpful error messages. +#[derive(Debug, Clone)] +pub struct TypeError { + pub kind: TypeErrorKind, + pub span: Option<(usize, usize)>, // (line, col) + pub source: Option, // source code snippet +} + +/// The kind of type error. +#[derive(Debug, Clone)] +pub enum TypeErrorKind { + /// Expected one type but found another. + Mismatch { + expected: Type, + found: Type, + context: String, + }, + + /// Using a non-reactive value where a Signal/Derived is expected. + NotReactive { + found: Type, + context: String, + }, + + /// An effect was performed but not handled. + UnhandledEffect { + effect: String, + function: String, + }, + + /// A view expression appears outside a `view` block. + ViewOutsideBlock { + expr: String, + }, + + /// An unknown variable reference. + UnboundVariable { + name: String, + }, + + /// Occurs check failure (infinite type). + InfiniteType { + var: String, + ty: Type, + }, + + /// Wrong number of arguments to a function. + ArityMismatch { + function: String, + expected: usize, + found: usize, + }, + + /// Accessing a field that doesn't exist on a record. + MissingField { + field: String, + record_type: Type, + }, +} + +impl TypeError { + pub fn new(kind: TypeErrorKind) -> Self { + TypeError { + kind, + span: None, + source: None, + } + } + + pub fn with_span(mut self, line: usize, col: usize) -> Self { + self.span = Some((line, col)); + self + } + + pub fn with_source(mut self, source: String) -> Self { + self.source = Some(source); + self + } + + /// Format the error like Elm — helpful, specific, actionable. + pub fn display(&self) -> String { + let mut out = String::new(); + + // Header + let (title, body) = match &self.kind { + TypeErrorKind::Mismatch { expected, found, context } => { + ("TYPE MISMATCH".to_string(), format!( + "I was expecting:\n\n {}\n\nbut found:\n\n {}\n\n{}", + expected.display(), + found.display(), + context + )) + } + TypeErrorKind::NotReactive { found, context } => { + ("NOT REACTIVE".to_string(), format!( + "This value has type:\n\n {}\n\nbut it's used in a context that expects a reactive value (Signal or Derived).\n\n{}", + found.display(), + context + )) + } + TypeErrorKind::UnhandledEffect { effect, function } => { + ("UNHANDLED EFFECT".to_string(), format!( + "The function `{}` performs the `{}` effect, but no handler is installed.\n\n\ + Hint: Wrap the call in `handle(() => {}(...), {{ \"{}\": ... }})`", + function, effect, function, effect + )) + } + TypeErrorKind::ViewOutsideBlock { expr } => { + ("VIEW OUTSIDE BLOCK".to_string(), format!( + "This UI expression:\n\n {}\n\nappears outside a `view` block. UI elements can only be created inside views.", + expr + )) + } + TypeErrorKind::UnboundVariable { name } => { + ("UNBOUND VARIABLE".to_string(), format!( + "I cannot find a `{}` variable. Did you mean to declare it with `signal` or `let`?", + name + )) + } + TypeErrorKind::InfiniteType { var, ty } => { + ("INFINITE TYPE".to_string(), format!( + "Unification would create an infinite type:\n\n {} ~ {}\n\n\ + This usually means a recursive definition without a base case.", + var, ty.display() + )) + } + TypeErrorKind::ArityMismatch { function, expected, found } => { + ("WRONG NUMBER OF ARGUMENTS".to_string(), format!( + "The function `{}` expects {} argument{} but was given {}.", + function, + expected, + if *expected == 1 { "" } else { "s" }, + found + )) + } + TypeErrorKind::MissingField { field, record_type } => { + ("MISSING FIELD".to_string(), format!( + "This record:\n\n {}\n\ndoes not have a field called `{}`.", + record_type.display(), + field + )) + } + }; + + // Format like Elm + out.push_str("── "); + out.push_str(&title); + out.push_str(" "); + out.push_str(&"─".repeat(60 - title.len() - 3)); + out.push('\n'); + + if let Some((line, col)) = self.span { + out.push_str(&format!("{}:{}\n", line, col)); + } + + if let Some(source) = &self.source { + out.push('\n'); + out.push_str(" "); + out.push_str(source); + out.push_str("\n\n"); + } + + out.push_str(&body); + out.push('\n'); + + out + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_mismatch_error() { + let err = TypeError::new(TypeErrorKind::Mismatch { + expected: Type::Int, + found: Type::String, + context: "in the expression `count + name`".to_string(), + }) + .with_span(5, 12) + .with_source("count + name".to_string()); + + let msg = err.display(); + assert!(msg.contains("TYPE MISMATCH")); + assert!(msg.contains("Int")); + assert!(msg.contains("String")); + assert!(msg.contains("5:12")); + } + + #[test] + fn test_unhandled_effect_error() { + let err = TypeError::new(TypeErrorKind::UnhandledEffect { + effect: "Http".to_string(), + function: "fetchUser".to_string(), + }); + + let msg = err.display(); + assert!(msg.contains("UNHANDLED EFFECT")); + assert!(msg.contains("fetchUser")); + assert!(msg.contains("Http")); + } + + #[test] + fn test_view_outside_block() { + let err = TypeError::new(TypeErrorKind::ViewOutsideBlock { + expr: "button { \"click me\" }".to_string(), + }); + + let msg = err.display(); + assert!(msg.contains("VIEW OUTSIDE BLOCK")); + assert!(msg.contains("view")); + } +} diff --git a/compiler/ds-types/src/lib.rs b/compiler/ds-types/src/lib.rs new file mode 100644 index 0000000..dbb56a4 --- /dev/null +++ b/compiler/ds-types/src/lib.rs @@ -0,0 +1,8 @@ +/// DreamStack Type System — Hindley-Milner with signal-awareness and effect types. +pub mod checker; +pub mod types; +pub mod errors; + +pub use checker::TypeChecker; +pub use types::{Type, TypeVar, SignalType, EffectType}; +pub use errors::{TypeError, TypeErrorKind}; diff --git a/compiler/ds-types/src/types.rs b/compiler/ds-types/src/types.rs new file mode 100644 index 0000000..dc7ce83 --- /dev/null +++ b/compiler/ds-types/src/types.rs @@ -0,0 +1,194 @@ +/// DreamStack Type System. +/// +/// Types are structural, with first-class signal and effect awareness. +/// The type system tracks whether values are reactive (wrapped in Signal) +/// and which effects a function may perform. + +use std::collections::HashMap; + +/// Type variable identifier for inference. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct TypeVar(pub u32); + +/// The core type representation. +#[derive(Debug, Clone, PartialEq)] +pub enum Type { + /// Primitive types + Int, + Float, + String, + Bool, + Unit, + + /// A reactive signal wrapping a value type. + /// `Signal` means a mutable source of integers. + Signal(Box), + + /// A derived computation that depends on signals. + /// `Derived` is read-only, auto-updating. + Derived(Box), + + /// A function type with effect annotations. + /// `(args) -> return ! effects` + Fn { + params: Vec, + ret: Box, + effects: Vec, + }, + + /// An array/list type. + Array(Box), + + /// A record/struct type (structural). + Record(HashMap), + + /// A sum type / tagged union. + Variant(Vec<(String, Type)>), + + /// Stream type — push-based async sequence. + Stream(Box), + + /// Spring type — physics-animated value. + Spring(Box), + + /// View type — a UI expression. Only valid inside `view` blocks. + View, + + /// An unresolved type variable (for inference). + Var(TypeVar), + + /// Error sentinel (for error recovery). + Error, +} + +/// Signal-specific type information. +#[derive(Debug, Clone, PartialEq)] +pub enum SignalType { + Source, // mutable, user-set + Derived, // computed, read-only + Handler, // event handler +} + +/// Effect type — declares what side effects a function may perform. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum EffectType { + /// HTTP/network effects. + Http, + /// Local storage / persistence. + Storage, + /// Time-related effects (setTimeout, intervals). + Time, + /// DOM manipulation (only allowed in view blocks). + Dom, + /// Console/logging. + Log, + /// Random number generation. + Random, + /// Custom user-defined effect. + Custom(String), + /// Pure — no effects (default). + Pure, +} + +impl Type { + /// Unwrap the inner type of a Signal or Derived. + pub fn unwrap_reactive(&self) -> &Type { + match self { + Type::Signal(inner) | Type::Derived(inner) => inner, + other => other, + } + } + + /// Check if this type is reactive (Signal or Derived). + pub fn is_reactive(&self) -> bool { + matches!(self, Type::Signal(_) | Type::Derived(_)) + } + + /// Check if this type is a function. + pub fn is_fn(&self) -> bool { + matches!(self, Type::Fn { .. }) + } + + /// Pretty-print the type. + pub fn display(&self) -> String { + match self { + Type::Int => "Int".to_string(), + Type::Float => "Float".to_string(), + Type::String => "String".to_string(), + Type::Bool => "Bool".to_string(), + Type::Unit => "()".to_string(), + Type::Signal(inner) => format!("Signal<{}>", inner.display()), + Type::Derived(inner) => format!("Derived<{}>", inner.display()), + Type::Fn { params, ret, effects } => { + let params_str = params.iter().map(|p| p.display()).collect::>().join(", "); + let eff_str = if effects.is_empty() || effects == &[EffectType::Pure] { + String::new() + } else { + format!(" ! {}", effects.iter().map(|e| format!("{:?}", e)).collect::>().join(", ")) + }; + format!("({}) -> {}{}", params_str, ret.display(), eff_str) + } + Type::Array(inner) => format!("[{}]", inner.display()), + Type::Record(fields) => { + let fields_str = fields.iter() + .map(|(k, v)| format!("{}: {}", k, v.display())) + .collect::>() + .join(", "); + format!("{{ {} }}", fields_str) + } + Type::Variant(variants) => { + let vars_str = variants.iter() + .map(|(name, ty)| format!("{}({})", name, ty.display())) + .collect::>() + .join(" | "); + vars_str + } + Type::Stream(inner) => format!("Stream<{}>", inner.display()), + Type::Spring(inner) => format!("Spring<{}>", inner.display()), + Type::View => "View".to_string(), + Type::Var(tv) => format!("?{}", tv.0), + Type::Error => "".to_string(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_type_display() { + assert_eq!(Type::Int.display(), "Int"); + assert_eq!(Type::Signal(Box::new(Type::Int)).display(), "Signal"); + assert_eq!(Type::Derived(Box::new(Type::Bool)).display(), "Derived"); + assert_eq!(Type::Array(Box::new(Type::String)).display(), "[String]"); + assert_eq!(Type::Stream(Box::new(Type::Int)).display(), "Stream"); + + let fn_type = Type::Fn { + params: vec![Type::Int, Type::String], + ret: Box::new(Type::Bool), + effects: vec![EffectType::Http], + }; + assert_eq!(fn_type.display(), "(Int, String) -> Bool ! Http"); + } + + #[test] + fn test_reactive_checks() { + assert!(Type::Signal(Box::new(Type::Int)).is_reactive()); + assert!(Type::Derived(Box::new(Type::Int)).is_reactive()); + assert!(!Type::Int.is_reactive()); + assert!(!Type::String.is_reactive()); + } + + #[test] + fn test_unwrap_reactive() { + let sig = Type::Signal(Box::new(Type::Int)); + assert_eq!(*sig.unwrap_reactive(), Type::Int); + + let derived = Type::Derived(Box::new(Type::Bool)); + assert_eq!(*derived.unwrap_reactive(), Type::Bool); + + // Non-reactive returns self + assert_eq!(*Type::Int.unwrap_reactive(), Type::Int); + } +}