feat: complete type system — HM unification, signal-aware types, effect scoping

Type System Completion:
- Add unify() with occurs check for proper type variable binding
- Add apply_subst() to chase type variables through substitution map
- Add SignalInfo/SignalClass for graph-based signal classification
- Add check_program_with_signals() accepting optional signal graph data
- Push Dom effect handler scope automatically when checking view blocks
- Wire unification into BinOp, comparison, and logical operator inference
- Include List/Record literals in source signal heuristic

Tests: 34 ds-types tests (up from 11), 159 workspace total, 0 failures
This commit is contained in:
enzotar 2026-02-27 11:36:28 -08:00
parent ebf11889a3
commit 8fb2214ac0

View file

@ -1,9 +1,9 @@
/// Type checker — Hindley-Milner with signal-awareness and effect tracking.
///
/// Walks the DreamStack AST and:
/// 1. Infers types for all expressions
/// 2. Checks signal usage (Signal<T> vs Derived<T> vs plain T)
/// 3. Tracks effect perform/handle pairs
/// 1. Infers types for all expressions via HM unification
/// 2. Classifies signals using the signal graph (Source vs Derived)
/// 3. Tracks effect perform/handle pairs with scoped handlers
/// 4. Ensures views only contain valid UI expressions
use std::collections::HashMap;
@ -12,6 +12,19 @@ use ds_parser::{Program, Declaration, LetDecl, ViewDecl, Expr, BinOp, UnaryOp, T
use crate::types::{Type, TypeVar, EffectType, Predicate, PredicateExpr};
use crate::errors::{TypeError, TypeErrorKind};
/// Signal classification from the analyzer (mirrors ds_analyzer::SignalKind).
///
/// The enum carries no data, so besides `Debug`/`Clone`/`PartialEq` it also
/// derives `Copy`, `Eq` and `Hash`: values can be passed by value, compared
/// totally, and used as set or map keys without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SignalClass {
    /// A source signal; the type checker binds such a name as `Signal<T>`.
    Source,
    /// A derived signal; the type checker binds such a name as `Derived<T>`.
    Derived,
}
/// Minimal signal graph info for type checking.
///
/// Optionally handed to `check_program_with_signals`. When it is present, a
/// `let` binding whose name is a key of `signals` is typed as `Signal<T>`
/// (`Source`) or `Derived<T>` (`Derived`); a name that is absent is kept as a
/// plain, non-reactive value.
#[derive(Debug, Clone, Default)]
pub struct SignalInfo {
/// Classification per binding, keyed by the `let` declaration's name.
pub signals: HashMap<String, SignalClass>,
}
/// The type checker.
pub struct TypeChecker {
/// Type environment: variable name → type.
@ -55,8 +68,159 @@ impl TypeChecker {
self.errors.push(TypeError::new(kind));
}
/// Check an entire program.
// ── Hindley-Milner Unification ──────────────────────────────────
/// Resolve `ty` against the current substitution map.
///
/// A bound type variable is replaced by its binding, and that binding is
/// resolved again, so chains of variables are followed to their end. The
/// payload of `Signal`, `Derived`, `Array`, `Stream` and `Spring`, function
/// parameter and return types, record field types, and the base of a
/// refinement are resolved recursively. Function effects and refinement
/// predicates are copied unchanged; any other type is returned as a clone.
pub fn apply_subst(&self, ty: &Type) -> Type {
    // Resolve a boxed position and box the result again.
    let reboxed = |inner: &Type| Box::new(self.apply_subst(inner));

    match ty {
        Type::Var(tv) => match self.substitutions.get(tv) {
            Some(binding) => self.apply_subst(binding),
            // Unbound variable: nothing to chase.
            None => ty.clone(),
        },
        Type::Signal(inner) => Type::Signal(reboxed(inner)),
        Type::Derived(inner) => Type::Derived(reboxed(inner)),
        Type::Array(inner) => Type::Array(reboxed(inner)),
        Type::Stream(inner) => Type::Stream(reboxed(inner)),
        Type::Spring(inner) => Type::Spring(reboxed(inner)),
        Type::Fn { params, ret, effects } => {
            let params = params.iter().map(|param| self.apply_subst(param)).collect();
            let ret = reboxed(ret);
            Type::Fn { params, ret, effects: effects.clone() }
        }
        Type::Record(fields) => {
            let resolved = fields
                .iter()
                .map(|(name, field_ty)| (name.clone(), self.apply_subst(field_ty)))
                .collect();
            Type::Record(resolved)
        }
        Type::Refined { base, predicate } => Type::Refined {
            base: reboxed(base),
            predicate: predicate.clone(),
        },
        _ => ty.clone(),
    }
}
/// Occurs check: report whether `tv` appears anywhere inside `ty`.
///
/// `ty` is first resolved through the substitution map, then searched
/// structurally: wrapper payloads, function parameters and return type,
/// record field types and refinement bases are all visited. Types with no
/// nested type positions never contain `tv`.
fn occurs_check(&self, tv: TypeVar, ty: &Type) -> bool {
    let resolved = self.apply_subst(ty);
    // Shorthand for recursing into a nested type position.
    let contains_tv = |nested: &Type| self.occurs_check(tv, nested);

    match &resolved {
        Type::Var(found) => tv == *found,
        Type::Signal(payload)
        | Type::Derived(payload)
        | Type::Array(payload)
        | Type::Stream(payload)
        | Type::Spring(payload) => contains_tv(payload),
        Type::Fn { params, ret, .. } => {
            params.iter().any(|param| contains_tv(param)) || contains_tv(ret)
        }
        Type::Record(fields) => fields.values().any(|field_ty| contains_tv(field_ty)),
        Type::Refined { base, .. } => contains_tv(base),
        _ => false,
    }
}
/// Unify two types, recording substitutions. Returns Ok or an error kind.
///
/// Both sides are resolved with `apply_subst` first; if the resolved types are
/// equal nothing else happens. Otherwise the first matching rule below applies
/// (arm order matters):
///
/// * a type variable on either side is bound to the other side, unless the
///   occurs check finds the variable inside it (`InfiniteType`);
/// * `Int` and `Float` unify with each other without recording anything;
/// * identical wrappers, and `Signal` paired with `Derived`, unify their payloads;
/// * functions need equal arity, then parameters and return types are unified;
/// * records unify only the fields present on both sides;
/// * a refined type is unified through its base, ignoring the predicate;
/// * `Type::Error` unifies with anything;
/// * everything else is a `Mismatch`.
///
/// Substitutions recorded before a later sub-unification fails are not undone
/// by this function.
fn unify(&mut self, t1: &Type, t2: &Type) -> Result<(), TypeErrorKind> {
let a = self.apply_subst(t1);
let b = self.apply_subst(t2);
// Already identical after resolution (this also covers a variable unified with itself).
if a == b {
return Ok(());
}
match (&a, &b) {
// Type variable binds to anything (with occurs check)
(Type::Var(tv), _) => {
if self.occurs_check(*tv, &b) {
return Err(TypeErrorKind::InfiniteType {
var: format!("?{}", tv.0),
ty: b.clone(),
});
}
self.substitutions.insert(*tv, b);
Ok(())
}
// Mirror of the arm above: the variable is on the right-hand side.
(_, Type::Var(tv)) => {
if self.occurs_check(*tv, &a) {
return Err(TypeErrorKind::InfiniteType {
var: format!("?{}", tv.0),
ty: a.clone(),
});
}
self.substitutions.insert(*tv, a);
Ok(())
}
// Numeric coercion: Int unifies with Float (in either order; no substitution is recorded)
(Type::Int, Type::Float) | (Type::Float, Type::Int) => Ok(()),
// Structural: unify inner types
(Type::Signal(a_inner), Type::Signal(b_inner)) => self.unify(a_inner, b_inner),
(Type::Derived(a_inner), Type::Derived(b_inner)) => self.unify(a_inner, b_inner),
(Type::Array(a_inner), Type::Array(b_inner)) => self.unify(a_inner, b_inner),
(Type::Stream(a_inner), Type::Stream(b_inner)) => self.unify(a_inner, b_inner),
(Type::Spring(a_inner), Type::Spring(b_inner)) => self.unify(a_inner, b_inner),
// Signal<T> can unify with Derived<T> (both are reactive wrappers)
(Type::Signal(a_inner), Type::Derived(b_inner))
| (Type::Derived(a_inner), Type::Signal(b_inner)) => self.unify(a_inner, b_inner),
// Functions: arity must match, then params and return are unified pairwise.
// NOTE: effects are NOT compared here (both patterns skip them with `..`),
// so two function types that differ only in their effects unify.
(Type::Fn { params: p1, ret: r1, .. }, Type::Fn { params: p2, ret: r2, .. }) => {
if p1.len() != p2.len() {
return Err(TypeErrorKind::Mismatch {
expected: a.clone(),
found: b.clone(),
context: format!("Function arity: expected {} params, found {}", p1.len(), p2.len()),
});
}
// The first failing parameter pair aborts via `?`.
for (pa, pb) in p1.iter().zip(p2.iter()) {
self.unify(pa, pb)?;
}
self.unify(r1, r2)
}
// Records: unify shared fields. A field that exists on only one side is
// skipped, so records with different field sets still unify.
(Type::Record(f1), Type::Record(f2)) => {
for (key, ty1) in f1 {
if let Some(ty2) = f2.get(key) {
self.unify(ty1, ty2)?;
}
}
Ok(())
}
// Refinement: unify base types (the predicate plays no part in unification)
(Type::Refined { base: b1, .. }, _) => self.unify(b1, &b),
(_, Type::Refined { base: b2, .. }) => self.unify(&a, b2),
// Error recovery: an already-erroneous type unifies with anything
(Type::Error, _) | (_, Type::Error) => Ok(()),
// Mismatch
_ => Err(TypeErrorKind::Mismatch {
expected: a.clone(),
found: b.clone(),
context: "Type unification failed".to_string(),
}),
}
}
/// Push an effect handler scope.
///
/// `effects` becomes the newest entry of `effect_handlers`; it stays there
/// until a matching `pop_effect_scope` call removes it.
fn push_effect_scope(&mut self, effects: Vec<EffectType>) {
self.effect_handlers.push(effects);
}
/// Pop an effect handler scope.
///
/// Removes the most recently pushed entry of `effect_handlers`. The value
/// returned by `pop()` is discarded.
fn pop_effect_scope(&mut self) {
self.effect_handlers.pop();
}
/// Check an entire program without signal graph info.
///
/// Delegates to `check_program_with_signals` with `None`. Callers that have
/// signal graph data should call `check_program_with_signals` directly.
pub fn check_program(&mut self, program: &Program) {
self.check_program_with_signals(program, None);
}
/// Check a program with signal graph classification data.
pub fn check_program_with_signals(&mut self, program: &Program, signal_info: Option<&SignalInfo>) {
// Pass 0: register type aliases (with cycle detection)
for decl in &program.declarations {
if let Declaration::TypeAlias(alias) = decl {
@ -114,16 +278,25 @@ impl TypeChecker {
}
}
// Heuristic: if name ends conventionally or is assigned
// a literal, mark as Signal<T>; otherwise just let-bound T.
let is_source = matches!(
let_decl.value,
Expr::IntLit(_) | Expr::FloatLit(_) | Expr::StringLit(_) | Expr::BoolLit(_)
);
let final_ty = if is_source {
Type::Signal(Box::new(inferred_ty))
// Use signal graph classification if available, else heuristic
let final_ty = if let Some(info) = signal_info {
match info.signals.get(&let_decl.name) {
Some(SignalClass::Source) => Type::Signal(Box::new(inferred_ty)),
Some(SignalClass::Derived) => Type::Derived(Box::new(inferred_ty)),
None => inferred_ty, // Not a signal — plain value
}
} else {
Type::Derived(Box::new(inferred_ty))
// Heuristic fallback: literal init = Source, expression = Derived
let is_source = matches!(
let_decl.value,
Expr::IntLit(_) | Expr::FloatLit(_) | Expr::StringLit(_) | Expr::BoolLit(_)
| Expr::List(_) | Expr::Record(_)
);
if is_source {
Type::Signal(Box::new(inferred_ty))
} else {
Type::Derived(Box::new(inferred_ty))
}
};
self.env.insert(let_decl.name.clone(), final_ty);
}
@ -362,7 +535,10 @@ impl TypeChecker {
/// Check a view declaration.
///
/// Sets `in_view` while the body is checked and wraps the body check in an
/// effect handler scope containing `EffectType::Dom`.
fn check_view(&mut self, view: &ViewDecl) {
self.in_view = true;
// Dom effects are automatically handled inside view blocks
self.push_effect_scope(vec![EffectType::Dom]);
self.check_view_expr(&view.body);
// Remove the Dom scope pushed above so it does not outlive this view.
self.pop_effect_scope();
// NOTE(review): reset unconditionally to false rather than restored to the value on entry.
self.in_view = false;
}
@ -466,39 +642,46 @@ impl TypeChecker {
let left_ty = self.infer_expr(left);
let right_ty = self.infer_expr(right);
let left_inner = left_ty.unwrap_reactive();
let right_inner = right_ty.unwrap_reactive();
let left_inner = self.apply_subst(left_ty.unwrap_reactive());
let right_inner = self.apply_subst(right_ty.unwrap_reactive());
match op {
BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div | BinOp::Mod => {
if *left_inner == Type::Int && *right_inner == Type::Int {
Type::Int
} else if matches!(left_inner, Type::Int | Type::Float)
&& matches!(right_inner, Type::Int | Type::Float)
{
Type::Float
} else if *left_inner == Type::String && matches!(op, BinOp::Add) {
// Try to unify operands for numeric ops
if let Err(_) = self.unify(&left_inner, &right_inner) {
// Check for string concat
if left_inner == Type::String && matches!(op, BinOp::Add) {
return Type::String;
}
if left_inner != Type::Error && right_inner != Type::Error {
self.error(TypeErrorKind::Mismatch {
expected: left_inner.clone(),
found: right_inner.clone(),
context: format!("Both sides of {:?} must be the same numeric type", op),
});
}
return Type::Error;
}
// Result type: Int*Int=Int, anything with Float=Float
let resolved = self.apply_subst(&left_inner);
if resolved == Type::String && matches!(op, BinOp::Add) {
Type::String
} else if *left_inner == Type::Error || *right_inner == Type::Error {
Type::Error
} else if matches!(resolved, Type::Float) || matches!(self.apply_subst(&right_inner), Type::Float) {
Type::Float
} else if matches!(resolved, Type::Int) {
Type::Int
} else {
self.error(TypeErrorKind::Mismatch {
expected: Type::Int,
found: right_ty.clone(),
context: format!("Both sides of {:?} must be numeric", op),
});
Type::Error
resolved
}
}
BinOp::Eq | BinOp::Neq | BinOp::Lt | BinOp::Gt | BinOp::Lte | BinOp::Gte => Type::Bool,
BinOp::Eq | BinOp::Neq | BinOp::Lt | BinOp::Gt | BinOp::Lte | BinOp::Gte => {
// Unify operands (comparisons require same type)
let _ = self.unify(&left_inner, &right_inner);
Type::Bool
}
BinOp::And | BinOp::Or => {
if *left_inner != Type::Bool && !matches!(left_inner, Type::Var(_)) {
self.error(TypeErrorKind::Mismatch {
expected: Type::Bool,
found: left_ty.clone(),
context: format!("Left side of {:?} must be Bool", op),
});
}
let _ = self.unify(&left_inner, &Type::Bool);
let _ = self.unify(&right_inner, &Type::Bool);
Type::Bool
}
}
@ -1026,4 +1209,132 @@ mod tests {
assert!(msg.contains("-42"), "Error should show the literal value, got: {}", msg);
assert!(msg.contains("REFINEMENT VIOLATED"));
}
// ── Hindley-Milner Unification tests ──────────────────────────
#[test]
fn test_unify_same_types() {
    // Every primitive type must unify with itself.
    let mut checker = TypeChecker::new();
    for ty in [Type::Int, Type::String, Type::Bool] {
        assert!(checker.unify(&ty, &ty).is_ok());
    }
}
#[test]
fn test_unify_type_variable() {
    // Unifying a fresh variable with Int must bind it, so resolving the
    // variable afterwards yields Int.
    let mut checker = TypeChecker::new();
    let fresh = checker.fresh_tv();

    let outcome = checker.unify(&fresh, &Type::Int);
    assert!(outcome.is_ok());

    let resolved = checker.apply_subst(&fresh);
    assert_eq!(resolved, Type::Int);
}
#[test]
fn test_unify_mismatch() {
    // Int and String have no unification rule, so an error is expected.
    let mut checker = TypeChecker::new();
    let outcome = checker.unify(&Type::Int, &Type::String);
    assert!(outcome.is_err());
}
#[test]
fn test_unify_numeric_coercion() {
    // Int is allowed to unify with Float.
    let mut checker = TypeChecker::new();
    let outcome = checker.unify(&Type::Int, &Type::Float);
    assert!(outcome.is_ok());
}
#[test]
fn test_unify_signal_types() {
    // Build `Signal<payload>` without repeating the boxing at every call site.
    let signal_of = |payload: Type| Type::Signal(Box::new(payload));
    let mut checker = TypeChecker::new();

    // Identical payloads unify.
    let same = checker.unify(&signal_of(Type::Int), &signal_of(Type::Int));
    assert!(same.is_ok());

    // Signal<Int> vs Signal<String> should fail
    let different = checker.unify(&signal_of(Type::Int), &signal_of(Type::String));
    assert!(different.is_err());
}
#[test]
fn test_occurs_check_infinite_type() {
    // Unifying a variable with an array of itself would create an infinite
    // type, so the occurs check must reject it with `InfiniteType`.
    let mut checker = TypeChecker::new();
    let var = checker.fresh_tv();
    let self_referential = Type::Array(Box::new(var.clone()));

    let outcome = checker.unify(&var, &self_referential);
    assert!(outcome.is_err());

    let kind = outcome.unwrap_err();
    if !matches!(kind, TypeErrorKind::InfiniteType { .. }) {
        panic!("Expected InfiniteType, got {:?}", kind);
    }
}
// ── Signal Graph Classification tests ─────────────────────────
#[test]
fn test_signal_graph_classification() {
    let mut checker = TypeChecker::new();

    // Signal graph data: `count` is a source, `doubled` is derived from it.
    let mut info = SignalInfo::default();
    for (name, class) in [("count", SignalClass::Source), ("doubled", SignalClass::Derived)] {
        info.signals.insert(name.to_string(), class);
    }

    // let count = 0
    let count_decl = Declaration::Let(LetDecl {
        name: "count".to_string(),
        type_annotation: None,
        value: Expr::IntLit(0),
        span: span(),
    });
    // let doubled = count * 2
    let doubled_value = Expr::BinOp(
        Box::new(Expr::Ident("count".to_string())),
        BinOp::Mul,
        Box::new(Expr::IntLit(2)),
    );
    let doubled_decl = Declaration::Let(LetDecl {
        name: "doubled".to_string(),
        type_annotation: None,
        value: doubled_value,
        span: span(),
    });
    let program = make_program(vec![count_decl, doubled_decl]);

    checker.check_program_with_signals(&program, Some(&info));
    assert!(!checker.has_errors(), "Errors: {}", checker.display_errors());

    // The classification from `info` decides the reactive wrapper.
    for (name, expected) in [("count", "Signal<Int>"), ("doubled", "Derived<Int>")] {
        let ty = checker.type_env().get(name).unwrap();
        assert_eq!(ty.display(), expected);
    }
}
// ── Effect Scope tests ────────────────────────────────────────
#[test]
fn test_effect_handler_in_view() {
let mut checker = TypeChecker::new();
// Exercises the same push/pop pair that check_view wraps around a view body,
// by calling the scope helpers directly (no ViewDecl is checked here).
checker.push_effect_scope(vec![EffectType::Dom]);
// While the scope is active, Dom must be reported as handled.
assert!(checker.is_effect_handled(&EffectType::Dom));
checker.pop_effect_scope();
// After the pop, Dom must no longer be reported as handled.
assert!(!checker.is_effect_handled(&EffectType::Dom));
}
#[test]
fn test_list_type_unification() {
    let mut checker = TypeChecker::new();

    // let items = [1, 2, 3], checked without signal graph info.
    let elements = vec![Expr::IntLit(1), Expr::IntLit(2), Expr::IntLit(3)];
    let items_decl = Declaration::Let(LetDecl {
        name: "items".to_string(),
        type_annotation: None,
        value: Expr::List(elements),
        span: span(),
    });
    let program = make_program(vec![items_decl]);

    checker.check_program(&program);
    assert!(!checker.has_errors(), "Errors: {}", checker.display_errors());

    // A list-literal initialiser is expected to be typed as a source signal.
    let bound_ty = checker.type_env().get("items").unwrap();
    assert_eq!(bound_ty.display(), "Signal<[Int]>");
}
}