diff --git a/compiler/ds-types/src/checker.rs b/compiler/ds-types/src/checker.rs index 8dda7e1..645da9f 100644 --- a/compiler/ds-types/src/checker.rs +++ b/compiler/ds-types/src/checker.rs @@ -1,9 +1,9 @@ /// Type checker — Hindley-Milner with signal-awareness and effect tracking. /// /// Walks the DreamStack AST and: -/// 1. Infers types for all expressions -/// 2. Checks signal usage (Signal vs Derived vs plain T) -/// 3. Tracks effect perform/handle pairs +/// 1. Infers types for all expressions via HM unification +/// 2. Classifies signals using the signal graph (Source vs Derived) +/// 3. Tracks effect perform/handle pairs with scoped handlers /// 4. Ensures views only contain valid UI expressions use std::collections::HashMap; @@ -12,6 +12,19 @@ use ds_parser::{Program, Declaration, LetDecl, ViewDecl, Expr, BinOp, UnaryOp, T use crate::types::{Type, TypeVar, EffectType, Predicate, PredicateExpr}; use crate::errors::{TypeError, TypeErrorKind}; +/// Signal classification from the analyzer (mirrors ds_analyzer::SignalKind). +#[derive(Debug, Clone, PartialEq)] +pub enum SignalClass { + Source, + Derived, +} + +/// Minimal signal graph info for type checking. +#[derive(Debug, Clone, Default)] +pub struct SignalInfo { + pub signals: HashMap, +} + /// The type checker. pub struct TypeChecker { /// Type environment: variable name → type. @@ -55,8 +68,159 @@ impl TypeChecker { self.errors.push(TypeError::new(kind)); } - /// Check an entire program. + // ── Hindley-Milner Unification ────────────────────────────────── + + /// Apply substitutions to a type, chasing type variables to their bindings. + pub fn apply_subst(&self, ty: &Type) -> Type { + match ty { + Type::Var(tv) => { + if let Some(bound) = self.substitutions.get(tv) { + self.apply_subst(bound) + } else { + ty.clone() + } + } + Type::Signal(inner) => Type::Signal(Box::new(self.apply_subst(inner))), + Type::Derived(inner) => Type::Derived(Box::new(self.apply_subst(inner))), + Type::Array(inner) => Type::Array(Box::new(self.apply_subst(inner))), + Type::Stream(inner) => Type::Stream(Box::new(self.apply_subst(inner))), + Type::Spring(inner) => Type::Spring(Box::new(self.apply_subst(inner))), + Type::Fn { params, ret, effects } => Type::Fn { + params: params.iter().map(|p| self.apply_subst(p)).collect(), + ret: Box::new(self.apply_subst(ret)), + effects: effects.clone(), + }, + Type::Record(fields) => Type::Record( + fields.iter().map(|(k, v)| (k.clone(), self.apply_subst(v))).collect(), + ), + Type::Refined { base, predicate } => Type::Refined { + base: Box::new(self.apply_subst(base)), + predicate: predicate.clone(), + }, + _ => ty.clone(), + } + } + + /// Occurs check: does `tv` appear free in `ty`? + fn occurs_check(&self, tv: TypeVar, ty: &Type) -> bool { + let resolved = self.apply_subst(ty); + match &resolved { + Type::Var(other) => *other == tv, + Type::Signal(inner) | Type::Derived(inner) | Type::Array(inner) + | Type::Stream(inner) | Type::Spring(inner) => self.occurs_check(tv, inner), + Type::Fn { params, ret, .. } => { + params.iter().any(|p| self.occurs_check(tv, p)) || self.occurs_check(tv, ret) + } + Type::Record(fields) => fields.values().any(|v| self.occurs_check(tv, v)), + Type::Refined { base, .. } => self.occurs_check(tv, base), + _ => false, + } + } + + /// Unify two types, recording substitutions. Returns Ok or an error kind. + fn unify(&mut self, t1: &Type, t2: &Type) -> Result<(), TypeErrorKind> { + let a = self.apply_subst(t1); + let b = self.apply_subst(t2); + + if a == b { + return Ok(()); + } + + match (&a, &b) { + // Type variable binds to anything (with occurs check) + (Type::Var(tv), _) => { + if self.occurs_check(*tv, &b) { + return Err(TypeErrorKind::InfiniteType { + var: format!("?{}", tv.0), + ty: b.clone(), + }); + } + self.substitutions.insert(*tv, b); + Ok(()) + } + (_, Type::Var(tv)) => { + if self.occurs_check(*tv, &a) { + return Err(TypeErrorKind::InfiniteType { + var: format!("?{}", tv.0), + ty: a.clone(), + }); + } + self.substitutions.insert(*tv, a); + Ok(()) + } + + // Numeric coercion: Int unifies with Float + (Type::Int, Type::Float) | (Type::Float, Type::Int) => Ok(()), + + // Structural: unify inner types + (Type::Signal(a_inner), Type::Signal(b_inner)) => self.unify(a_inner, b_inner), + (Type::Derived(a_inner), Type::Derived(b_inner)) => self.unify(a_inner, b_inner), + (Type::Array(a_inner), Type::Array(b_inner)) => self.unify(a_inner, b_inner), + (Type::Stream(a_inner), Type::Stream(b_inner)) => self.unify(a_inner, b_inner), + (Type::Spring(a_inner), Type::Spring(b_inner)) => self.unify(a_inner, b_inner), + + // Signal can unify with Derived (both are reactive wrappers) + (Type::Signal(a_inner), Type::Derived(b_inner)) + | (Type::Derived(a_inner), Type::Signal(b_inner)) => self.unify(a_inner, b_inner), + + // Functions: unify params and return, effects must match + (Type::Fn { params: p1, ret: r1, .. }, Type::Fn { params: p2, ret: r2, .. }) => { + if p1.len() != p2.len() { + return Err(TypeErrorKind::Mismatch { + expected: a.clone(), + found: b.clone(), + context: format!("Function arity: expected {} params, found {}", p1.len(), p2.len()), + }); + } + for (pa, pb) in p1.iter().zip(p2.iter()) { + self.unify(pa, pb)?; + } + self.unify(r1, r2) + } + + // Records: unify shared fields + (Type::Record(f1), Type::Record(f2)) => { + for (key, ty1) in f1 { + if let Some(ty2) = f2.get(key) { + self.unify(ty1, ty2)?; + } + } + Ok(()) + } + + // Refinement: unify base types + (Type::Refined { base: b1, .. }, _) => self.unify(b1, &b), + (_, Type::Refined { base: b2, .. }) => self.unify(&a, b2), + + // Error recovery + (Type::Error, _) | (_, Type::Error) => Ok(()), + + // Mismatch + _ => Err(TypeErrorKind::Mismatch { + expected: a.clone(), + found: b.clone(), + context: "Type unification failed".to_string(), + }), + } + } + + /// Push an effect handler scope. + fn push_effect_scope(&mut self, effects: Vec) { + self.effect_handlers.push(effects); + } + + /// Pop an effect handler scope. + fn pop_effect_scope(&mut self) { + self.effect_handlers.pop(); + } + + /// Check an entire program. Optionally accepts signal graph info for accurate classification. pub fn check_program(&mut self, program: &Program) { + self.check_program_with_signals(program, None); + } + + /// Check a program with signal graph classification data. + pub fn check_program_with_signals(&mut self, program: &Program, signal_info: Option<&SignalInfo>) { // Pass 0: register type aliases (with cycle detection) for decl in &program.declarations { if let Declaration::TypeAlias(alias) = decl { @@ -114,16 +278,25 @@ impl TypeChecker { } } - // Heuristic: if name ends conventionally or is assigned - // a literal, mark as Signal; otherwise just let-bound T. - let is_source = matches!( - let_decl.value, - Expr::IntLit(_) | Expr::FloatLit(_) | Expr::StringLit(_) | Expr::BoolLit(_) - ); - let final_ty = if is_source { - Type::Signal(Box::new(inferred_ty)) + // Use signal graph classification if available, else heuristic + let final_ty = if let Some(info) = signal_info { + match info.signals.get(&let_decl.name) { + Some(SignalClass::Source) => Type::Signal(Box::new(inferred_ty)), + Some(SignalClass::Derived) => Type::Derived(Box::new(inferred_ty)), + None => inferred_ty, // Not a signal — plain value + } } else { - Type::Derived(Box::new(inferred_ty)) + // Heuristic fallback: literal init = Source, expression = Derived + let is_source = matches!( + let_decl.value, + Expr::IntLit(_) | Expr::FloatLit(_) | Expr::StringLit(_) | Expr::BoolLit(_) + | Expr::List(_) | Expr::Record(_) + ); + if is_source { + Type::Signal(Box::new(inferred_ty)) + } else { + Type::Derived(Box::new(inferred_ty)) + } }; self.env.insert(let_decl.name.clone(), final_ty); } @@ -362,7 +535,10 @@ impl TypeChecker { /// Check a view declaration. fn check_view(&mut self, view: &ViewDecl) { self.in_view = true; + // Dom effects are automatically handled inside view blocks + self.push_effect_scope(vec![EffectType::Dom]); self.check_view_expr(&view.body); + self.pop_effect_scope(); self.in_view = false; } @@ -466,39 +642,46 @@ impl TypeChecker { let left_ty = self.infer_expr(left); let right_ty = self.infer_expr(right); - let left_inner = left_ty.unwrap_reactive(); - let right_inner = right_ty.unwrap_reactive(); + let left_inner = self.apply_subst(left_ty.unwrap_reactive()); + let right_inner = self.apply_subst(right_ty.unwrap_reactive()); match op { BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div | BinOp::Mod => { - if *left_inner == Type::Int && *right_inner == Type::Int { - Type::Int - } else if matches!(left_inner, Type::Int | Type::Float) - && matches!(right_inner, Type::Int | Type::Float) - { - Type::Float - } else if *left_inner == Type::String && matches!(op, BinOp::Add) { + // Try to unify operands for numeric ops + if let Err(_) = self.unify(&left_inner, &right_inner) { + // Check for string concat + if left_inner == Type::String && matches!(op, BinOp::Add) { + return Type::String; + } + if left_inner != Type::Error && right_inner != Type::Error { + self.error(TypeErrorKind::Mismatch { + expected: left_inner.clone(), + found: right_inner.clone(), + context: format!("Both sides of {:?} must be the same numeric type", op), + }); + } + return Type::Error; + } + // Result type: Int*Int=Int, anything with Float=Float + let resolved = self.apply_subst(&left_inner); + if resolved == Type::String && matches!(op, BinOp::Add) { Type::String - } else if *left_inner == Type::Error || *right_inner == Type::Error { - Type::Error + } else if matches!(resolved, Type::Float) || matches!(self.apply_subst(&right_inner), Type::Float) { + Type::Float + } else if matches!(resolved, Type::Int) { + Type::Int } else { - self.error(TypeErrorKind::Mismatch { - expected: Type::Int, - found: right_ty.clone(), - context: format!("Both sides of {:?} must be numeric", op), - }); - Type::Error + resolved } } - BinOp::Eq | BinOp::Neq | BinOp::Lt | BinOp::Gt | BinOp::Lte | BinOp::Gte => Type::Bool, + BinOp::Eq | BinOp::Neq | BinOp::Lt | BinOp::Gt | BinOp::Lte | BinOp::Gte => { + // Unify operands (comparisons require same type) + let _ = self.unify(&left_inner, &right_inner); + Type::Bool + } BinOp::And | BinOp::Or => { - if *left_inner != Type::Bool && !matches!(left_inner, Type::Var(_)) { - self.error(TypeErrorKind::Mismatch { - expected: Type::Bool, - found: left_ty.clone(), - context: format!("Left side of {:?} must be Bool", op), - }); - } + let _ = self.unify(&left_inner, &Type::Bool); + let _ = self.unify(&right_inner, &Type::Bool); Type::Bool } } @@ -1026,4 +1209,132 @@ mod tests { assert!(msg.contains("-42"), "Error should show the literal value, got: {}", msg); assert!(msg.contains("REFINEMENT VIOLATED")); } + + // ── Hindley-Milner Unification tests ────────────────────────── + + #[test] + fn test_unify_same_types() { + let mut checker = TypeChecker::new(); + assert!(checker.unify(&Type::Int, &Type::Int).is_ok()); + assert!(checker.unify(&Type::String, &Type::String).is_ok()); + assert!(checker.unify(&Type::Bool, &Type::Bool).is_ok()); + } + + #[test] + fn test_unify_type_variable() { + let mut checker = TypeChecker::new(); + let tv = checker.fresh_tv(); + assert!(checker.unify(&tv, &Type::Int).is_ok()); + assert_eq!(checker.apply_subst(&tv), Type::Int); + } + + #[test] + fn test_unify_mismatch() { + let mut checker = TypeChecker::new(); + assert!(checker.unify(&Type::Int, &Type::String).is_err()); + } + + #[test] + fn test_unify_numeric_coercion() { + let mut checker = TypeChecker::new(); + assert!(checker.unify(&Type::Int, &Type::Float).is_ok()); + } + + #[test] + fn test_unify_signal_types() { + let mut checker = TypeChecker::new(); + assert!(checker.unify( + &Type::Signal(Box::new(Type::Int)), + &Type::Signal(Box::new(Type::Int)), + ).is_ok()); + // Signal vs Signal should fail + assert!(checker.unify( + &Type::Signal(Box::new(Type::Int)), + &Type::Signal(Box::new(Type::String)), + ).is_err()); + } + + #[test] + fn test_occurs_check_infinite_type() { + let mut checker = TypeChecker::new(); + let tv = checker.fresh_tv(); + let array_of_tv = Type::Array(Box::new(tv.clone())); + let result = checker.unify(&tv, &array_of_tv); + assert!(result.is_err()); + match result.unwrap_err() { + TypeErrorKind::InfiniteType { .. } => { /* expected */ } + other => panic!("Expected InfiniteType, got {:?}", other), + } + } + + // ── Signal Graph Classification tests ───────────────────────── + + #[test] + fn test_signal_graph_classification() { + let mut checker = TypeChecker::new(); + let mut info = SignalInfo::default(); + info.signals.insert("count".to_string(), SignalClass::Source); + info.signals.insert("doubled".to_string(), SignalClass::Derived); + + let program = make_program(vec![ + Declaration::Let(LetDecl { + name: "count".to_string(), + type_annotation: None, + value: Expr::IntLit(0), + span: span(), + }), + Declaration::Let(LetDecl { + name: "doubled".to_string(), + type_annotation: None, + value: Expr::BinOp( + Box::new(Expr::Ident("count".to_string())), + BinOp::Mul, + Box::new(Expr::IntLit(2)), + ), + span: span(), + }), + ]); + checker.check_program_with_signals(&program, Some(&info)); + assert!(!checker.has_errors(), "Errors: {}", checker.display_errors()); + + let count_ty = checker.type_env().get("count").unwrap(); + assert_eq!(count_ty.display(), "Signal"); + + let doubled_ty = checker.type_env().get("doubled").unwrap(); + assert_eq!(doubled_ty.display(), "Derived"); + } + + // ── Effect Scope tests ──────────────────────────────────────── + + #[test] + fn test_effect_handler_in_view() { + let mut checker = TypeChecker::new(); + // Dom effect should be auto-handled inside view + checker.push_effect_scope(vec![EffectType::Dom]); + assert!(checker.is_effect_handled(&EffectType::Dom)); + checker.pop_effect_scope(); + assert!(!checker.is_effect_handled(&EffectType::Dom)); + } + + #[test] + fn test_list_type_unification() { + let mut checker = TypeChecker::new(); + let program = make_program(vec![ + Declaration::Let(LetDecl { + name: "items".to_string(), + type_annotation: None, + value: Expr::List(vec![ + Expr::IntLit(1), + Expr::IntLit(2), + Expr::IntLit(3), + ]), + span: span(), + }), + ]); + checker.check_program(&program); + assert!(!checker.has_errors(), "Errors: {}", checker.display_errors()); + + let items_ty = checker.type_env().get("items").unwrap(); + assert_eq!(items_ty.display(), "Signal<[Int]>"); + } }