feat: complete type system — HM unification, signal-aware types, effect scoping
Type System Completion: - Add unify() with occurs check for proper type variable binding - Add apply_subst() to chase type variables through substitution map - Add SignalInfo/SignalClass for graph-based signal classification - Add check_program_with_signals() accepting optional signal graph data - Push Dom effect handler scope automatically when checking view blocks - Wire unification into BinOp, comparison, and logical operator inference - Include List/Record literals in source signal heuristic Tests: 34 ds-types tests (up from 11), 159 workspace total, 0 failures
This commit is contained in:
parent
ebf11889a3
commit
8fb2214ac0
1 changed files with 349 additions and 38 deletions
|
|
@ -1,9 +1,9 @@
|
|||
/// Type checker — Hindley-Milner with signal-awareness and effect tracking.
|
||||
///
|
||||
/// Walks the DreamStack AST and:
|
||||
/// 1. Infers types for all expressions
|
||||
/// 2. Checks signal usage (Signal<T> vs Derived<T> vs plain T)
|
||||
/// 3. Tracks effect perform/handle pairs
|
||||
/// 1. Infers types for all expressions via HM unification
|
||||
/// 2. Classifies signals using the signal graph (Source vs Derived)
|
||||
/// 3. Tracks effect perform/handle pairs with scoped handlers
|
||||
/// 4. Ensures views only contain valid UI expressions
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
|
@ -12,6 +12,19 @@ use ds_parser::{Program, Declaration, LetDecl, ViewDecl, Expr, BinOp, UnaryOp, T
|
|||
use crate::types::{Type, TypeVar, EffectType, Predicate, PredicateExpr};
|
||||
use crate::errors::{TypeError, TypeErrorKind};
|
||||
|
||||
/// Signal classification from the analyzer (mirrors ds_analyzer::SignalKind).
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum SignalClass {
|
||||
Source,
|
||||
Derived,
|
||||
}
|
||||
|
||||
/// Minimal signal graph info for type checking.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct SignalInfo {
|
||||
pub signals: HashMap<String, SignalClass>,
|
||||
}
|
||||
|
||||
/// The type checker.
|
||||
pub struct TypeChecker {
|
||||
/// Type environment: variable name → type.
|
||||
|
|
@ -55,8 +68,159 @@ impl TypeChecker {
|
|||
self.errors.push(TypeError::new(kind));
|
||||
}
|
||||
|
||||
/// Check an entire program.
|
||||
// ── Hindley-Milner Unification ──────────────────────────────────
|
||||
|
||||
/// Apply substitutions to a type, chasing type variables to their bindings.
|
||||
pub fn apply_subst(&self, ty: &Type) -> Type {
|
||||
match ty {
|
||||
Type::Var(tv) => {
|
||||
if let Some(bound) = self.substitutions.get(tv) {
|
||||
self.apply_subst(bound)
|
||||
} else {
|
||||
ty.clone()
|
||||
}
|
||||
}
|
||||
Type::Signal(inner) => Type::Signal(Box::new(self.apply_subst(inner))),
|
||||
Type::Derived(inner) => Type::Derived(Box::new(self.apply_subst(inner))),
|
||||
Type::Array(inner) => Type::Array(Box::new(self.apply_subst(inner))),
|
||||
Type::Stream(inner) => Type::Stream(Box::new(self.apply_subst(inner))),
|
||||
Type::Spring(inner) => Type::Spring(Box::new(self.apply_subst(inner))),
|
||||
Type::Fn { params, ret, effects } => Type::Fn {
|
||||
params: params.iter().map(|p| self.apply_subst(p)).collect(),
|
||||
ret: Box::new(self.apply_subst(ret)),
|
||||
effects: effects.clone(),
|
||||
},
|
||||
Type::Record(fields) => Type::Record(
|
||||
fields.iter().map(|(k, v)| (k.clone(), self.apply_subst(v))).collect(),
|
||||
),
|
||||
Type::Refined { base, predicate } => Type::Refined {
|
||||
base: Box::new(self.apply_subst(base)),
|
||||
predicate: predicate.clone(),
|
||||
},
|
||||
_ => ty.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Occurs check: does `tv` appear free in `ty`?
|
||||
fn occurs_check(&self, tv: TypeVar, ty: &Type) -> bool {
|
||||
let resolved = self.apply_subst(ty);
|
||||
match &resolved {
|
||||
Type::Var(other) => *other == tv,
|
||||
Type::Signal(inner) | Type::Derived(inner) | Type::Array(inner)
|
||||
| Type::Stream(inner) | Type::Spring(inner) => self.occurs_check(tv, inner),
|
||||
Type::Fn { params, ret, .. } => {
|
||||
params.iter().any(|p| self.occurs_check(tv, p)) || self.occurs_check(tv, ret)
|
||||
}
|
||||
Type::Record(fields) => fields.values().any(|v| self.occurs_check(tv, v)),
|
||||
Type::Refined { base, .. } => self.occurs_check(tv, base),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Unify two types, recording substitutions. Returns Ok or an error kind.
|
||||
fn unify(&mut self, t1: &Type, t2: &Type) -> Result<(), TypeErrorKind> {
|
||||
let a = self.apply_subst(t1);
|
||||
let b = self.apply_subst(t2);
|
||||
|
||||
if a == b {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
match (&a, &b) {
|
||||
// Type variable binds to anything (with occurs check)
|
||||
(Type::Var(tv), _) => {
|
||||
if self.occurs_check(*tv, &b) {
|
||||
return Err(TypeErrorKind::InfiniteType {
|
||||
var: format!("?{}", tv.0),
|
||||
ty: b.clone(),
|
||||
});
|
||||
}
|
||||
self.substitutions.insert(*tv, b);
|
||||
Ok(())
|
||||
}
|
||||
(_, Type::Var(tv)) => {
|
||||
if self.occurs_check(*tv, &a) {
|
||||
return Err(TypeErrorKind::InfiniteType {
|
||||
var: format!("?{}", tv.0),
|
||||
ty: a.clone(),
|
||||
});
|
||||
}
|
||||
self.substitutions.insert(*tv, a);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Numeric coercion: Int unifies with Float
|
||||
(Type::Int, Type::Float) | (Type::Float, Type::Int) => Ok(()),
|
||||
|
||||
// Structural: unify inner types
|
||||
(Type::Signal(a_inner), Type::Signal(b_inner)) => self.unify(a_inner, b_inner),
|
||||
(Type::Derived(a_inner), Type::Derived(b_inner)) => self.unify(a_inner, b_inner),
|
||||
(Type::Array(a_inner), Type::Array(b_inner)) => self.unify(a_inner, b_inner),
|
||||
(Type::Stream(a_inner), Type::Stream(b_inner)) => self.unify(a_inner, b_inner),
|
||||
(Type::Spring(a_inner), Type::Spring(b_inner)) => self.unify(a_inner, b_inner),
|
||||
|
||||
// Signal<T> can unify with Derived<T> (both are reactive wrappers)
|
||||
(Type::Signal(a_inner), Type::Derived(b_inner))
|
||||
| (Type::Derived(a_inner), Type::Signal(b_inner)) => self.unify(a_inner, b_inner),
|
||||
|
||||
// Functions: unify params and return, effects must match
|
||||
(Type::Fn { params: p1, ret: r1, .. }, Type::Fn { params: p2, ret: r2, .. }) => {
|
||||
if p1.len() != p2.len() {
|
||||
return Err(TypeErrorKind::Mismatch {
|
||||
expected: a.clone(),
|
||||
found: b.clone(),
|
||||
context: format!("Function arity: expected {} params, found {}", p1.len(), p2.len()),
|
||||
});
|
||||
}
|
||||
for (pa, pb) in p1.iter().zip(p2.iter()) {
|
||||
self.unify(pa, pb)?;
|
||||
}
|
||||
self.unify(r1, r2)
|
||||
}
|
||||
|
||||
// Records: unify shared fields
|
||||
(Type::Record(f1), Type::Record(f2)) => {
|
||||
for (key, ty1) in f1 {
|
||||
if let Some(ty2) = f2.get(key) {
|
||||
self.unify(ty1, ty2)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Refinement: unify base types
|
||||
(Type::Refined { base: b1, .. }, _) => self.unify(b1, &b),
|
||||
(_, Type::Refined { base: b2, .. }) => self.unify(&a, b2),
|
||||
|
||||
// Error recovery
|
||||
(Type::Error, _) | (_, Type::Error) => Ok(()),
|
||||
|
||||
// Mismatch
|
||||
_ => Err(TypeErrorKind::Mismatch {
|
||||
expected: a.clone(),
|
||||
found: b.clone(),
|
||||
context: "Type unification failed".to_string(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Push an effect handler scope.
|
||||
fn push_effect_scope(&mut self, effects: Vec<EffectType>) {
|
||||
self.effect_handlers.push(effects);
|
||||
}
|
||||
|
||||
/// Pop an effect handler scope.
|
||||
fn pop_effect_scope(&mut self) {
|
||||
self.effect_handlers.pop();
|
||||
}
|
||||
|
||||
/// Check an entire program. Optionally accepts signal graph info for accurate classification.
|
||||
pub fn check_program(&mut self, program: &Program) {
|
||||
self.check_program_with_signals(program, None);
|
||||
}
|
||||
|
||||
/// Check a program with signal graph classification data.
|
||||
pub fn check_program_with_signals(&mut self, program: &Program, signal_info: Option<&SignalInfo>) {
|
||||
// Pass 0: register type aliases (with cycle detection)
|
||||
for decl in &program.declarations {
|
||||
if let Declaration::TypeAlias(alias) = decl {
|
||||
|
|
@ -114,16 +278,25 @@ impl TypeChecker {
|
|||
}
|
||||
}
|
||||
|
||||
// Heuristic: if name ends conventionally or is assigned
|
||||
// a literal, mark as Signal<T>; otherwise just let-bound T.
|
||||
let is_source = matches!(
|
||||
let_decl.value,
|
||||
Expr::IntLit(_) | Expr::FloatLit(_) | Expr::StringLit(_) | Expr::BoolLit(_)
|
||||
);
|
||||
let final_ty = if is_source {
|
||||
Type::Signal(Box::new(inferred_ty))
|
||||
// Use signal graph classification if available, else heuristic
|
||||
let final_ty = if let Some(info) = signal_info {
|
||||
match info.signals.get(&let_decl.name) {
|
||||
Some(SignalClass::Source) => Type::Signal(Box::new(inferred_ty)),
|
||||
Some(SignalClass::Derived) => Type::Derived(Box::new(inferred_ty)),
|
||||
None => inferred_ty, // Not a signal — plain value
|
||||
}
|
||||
} else {
|
||||
Type::Derived(Box::new(inferred_ty))
|
||||
// Heuristic fallback: literal init = Source, expression = Derived
|
||||
let is_source = matches!(
|
||||
let_decl.value,
|
||||
Expr::IntLit(_) | Expr::FloatLit(_) | Expr::StringLit(_) | Expr::BoolLit(_)
|
||||
| Expr::List(_) | Expr::Record(_)
|
||||
);
|
||||
if is_source {
|
||||
Type::Signal(Box::new(inferred_ty))
|
||||
} else {
|
||||
Type::Derived(Box::new(inferred_ty))
|
||||
}
|
||||
};
|
||||
self.env.insert(let_decl.name.clone(), final_ty);
|
||||
}
|
||||
|
|
@ -362,7 +535,10 @@ impl TypeChecker {
|
|||
/// Check a view declaration.
|
||||
fn check_view(&mut self, view: &ViewDecl) {
|
||||
self.in_view = true;
|
||||
// Dom effects are automatically handled inside view blocks
|
||||
self.push_effect_scope(vec![EffectType::Dom]);
|
||||
self.check_view_expr(&view.body);
|
||||
self.pop_effect_scope();
|
||||
self.in_view = false;
|
||||
}
|
||||
|
||||
|
|
@ -466,39 +642,46 @@ impl TypeChecker {
|
|||
let left_ty = self.infer_expr(left);
|
||||
let right_ty = self.infer_expr(right);
|
||||
|
||||
let left_inner = left_ty.unwrap_reactive();
|
||||
let right_inner = right_ty.unwrap_reactive();
|
||||
let left_inner = self.apply_subst(left_ty.unwrap_reactive());
|
||||
let right_inner = self.apply_subst(right_ty.unwrap_reactive());
|
||||
|
||||
match op {
|
||||
BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div | BinOp::Mod => {
|
||||
if *left_inner == Type::Int && *right_inner == Type::Int {
|
||||
Type::Int
|
||||
} else if matches!(left_inner, Type::Int | Type::Float)
|
||||
&& matches!(right_inner, Type::Int | Type::Float)
|
||||
{
|
||||
Type::Float
|
||||
} else if *left_inner == Type::String && matches!(op, BinOp::Add) {
|
||||
// Try to unify operands for numeric ops
|
||||
if let Err(_) = self.unify(&left_inner, &right_inner) {
|
||||
// Check for string concat
|
||||
if left_inner == Type::String && matches!(op, BinOp::Add) {
|
||||
return Type::String;
|
||||
}
|
||||
if left_inner != Type::Error && right_inner != Type::Error {
|
||||
self.error(TypeErrorKind::Mismatch {
|
||||
expected: left_inner.clone(),
|
||||
found: right_inner.clone(),
|
||||
context: format!("Both sides of {:?} must be the same numeric type", op),
|
||||
});
|
||||
}
|
||||
return Type::Error;
|
||||
}
|
||||
// Result type: Int*Int=Int, anything with Float=Float
|
||||
let resolved = self.apply_subst(&left_inner);
|
||||
if resolved == Type::String && matches!(op, BinOp::Add) {
|
||||
Type::String
|
||||
} else if *left_inner == Type::Error || *right_inner == Type::Error {
|
||||
Type::Error
|
||||
} else if matches!(resolved, Type::Float) || matches!(self.apply_subst(&right_inner), Type::Float) {
|
||||
Type::Float
|
||||
} else if matches!(resolved, Type::Int) {
|
||||
Type::Int
|
||||
} else {
|
||||
self.error(TypeErrorKind::Mismatch {
|
||||
expected: Type::Int,
|
||||
found: right_ty.clone(),
|
||||
context: format!("Both sides of {:?} must be numeric", op),
|
||||
});
|
||||
Type::Error
|
||||
resolved
|
||||
}
|
||||
}
|
||||
BinOp::Eq | BinOp::Neq | BinOp::Lt | BinOp::Gt | BinOp::Lte | BinOp::Gte => Type::Bool,
|
||||
BinOp::Eq | BinOp::Neq | BinOp::Lt | BinOp::Gt | BinOp::Lte | BinOp::Gte => {
|
||||
// Unify operands (comparisons require same type)
|
||||
let _ = self.unify(&left_inner, &right_inner);
|
||||
Type::Bool
|
||||
}
|
||||
BinOp::And | BinOp::Or => {
|
||||
if *left_inner != Type::Bool && !matches!(left_inner, Type::Var(_)) {
|
||||
self.error(TypeErrorKind::Mismatch {
|
||||
expected: Type::Bool,
|
||||
found: left_ty.clone(),
|
||||
context: format!("Left side of {:?} must be Bool", op),
|
||||
});
|
||||
}
|
||||
let _ = self.unify(&left_inner, &Type::Bool);
|
||||
let _ = self.unify(&right_inner, &Type::Bool);
|
||||
Type::Bool
|
||||
}
|
||||
}
|
||||
|
|
@ -1026,4 +1209,132 @@ mod tests {
|
|||
assert!(msg.contains("-42"), "Error should show the literal value, got: {}", msg);
|
||||
assert!(msg.contains("REFINEMENT VIOLATED"));
|
||||
}
|
||||
|
||||
// ── Hindley-Milner Unification tests ──────────────────────────
|
||||
|
||||
#[test]
|
||||
fn test_unify_same_types() {
|
||||
let mut checker = TypeChecker::new();
|
||||
assert!(checker.unify(&Type::Int, &Type::Int).is_ok());
|
||||
assert!(checker.unify(&Type::String, &Type::String).is_ok());
|
||||
assert!(checker.unify(&Type::Bool, &Type::Bool).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unify_type_variable() {
|
||||
let mut checker = TypeChecker::new();
|
||||
let tv = checker.fresh_tv();
|
||||
assert!(checker.unify(&tv, &Type::Int).is_ok());
|
||||
assert_eq!(checker.apply_subst(&tv), Type::Int);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unify_mismatch() {
|
||||
let mut checker = TypeChecker::new();
|
||||
assert!(checker.unify(&Type::Int, &Type::String).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unify_numeric_coercion() {
|
||||
let mut checker = TypeChecker::new();
|
||||
assert!(checker.unify(&Type::Int, &Type::Float).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unify_signal_types() {
|
||||
let mut checker = TypeChecker::new();
|
||||
assert!(checker.unify(
|
||||
&Type::Signal(Box::new(Type::Int)),
|
||||
&Type::Signal(Box::new(Type::Int)),
|
||||
).is_ok());
|
||||
// Signal<Int> vs Signal<String> should fail
|
||||
assert!(checker.unify(
|
||||
&Type::Signal(Box::new(Type::Int)),
|
||||
&Type::Signal(Box::new(Type::String)),
|
||||
).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_occurs_check_infinite_type() {
|
||||
let mut checker = TypeChecker::new();
|
||||
let tv = checker.fresh_tv();
|
||||
let array_of_tv = Type::Array(Box::new(tv.clone()));
|
||||
let result = checker.unify(&tv, &array_of_tv);
|
||||
assert!(result.is_err());
|
||||
match result.unwrap_err() {
|
||||
TypeErrorKind::InfiniteType { .. } => { /* expected */ }
|
||||
other => panic!("Expected InfiniteType, got {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
// ── Signal Graph Classification tests ─────────────────────────
|
||||
|
||||
#[test]
|
||||
fn test_signal_graph_classification() {
|
||||
let mut checker = TypeChecker::new();
|
||||
let mut info = SignalInfo::default();
|
||||
info.signals.insert("count".to_string(), SignalClass::Source);
|
||||
info.signals.insert("doubled".to_string(), SignalClass::Derived);
|
||||
|
||||
let program = make_program(vec![
|
||||
Declaration::Let(LetDecl {
|
||||
name: "count".to_string(),
|
||||
type_annotation: None,
|
||||
value: Expr::IntLit(0),
|
||||
span: span(),
|
||||
}),
|
||||
Declaration::Let(LetDecl {
|
||||
name: "doubled".to_string(),
|
||||
type_annotation: None,
|
||||
value: Expr::BinOp(
|
||||
Box::new(Expr::Ident("count".to_string())),
|
||||
BinOp::Mul,
|
||||
Box::new(Expr::IntLit(2)),
|
||||
),
|
||||
span: span(),
|
||||
}),
|
||||
]);
|
||||
checker.check_program_with_signals(&program, Some(&info));
|
||||
assert!(!checker.has_errors(), "Errors: {}", checker.display_errors());
|
||||
|
||||
let count_ty = checker.type_env().get("count").unwrap();
|
||||
assert_eq!(count_ty.display(), "Signal<Int>");
|
||||
|
||||
let doubled_ty = checker.type_env().get("doubled").unwrap();
|
||||
assert_eq!(doubled_ty.display(), "Derived<Int>");
|
||||
}
|
||||
|
||||
// ── Effect Scope tests ────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn test_effect_handler_in_view() {
|
||||
let mut checker = TypeChecker::new();
|
||||
// Dom effect should be auto-handled inside view
|
||||
checker.push_effect_scope(vec![EffectType::Dom]);
|
||||
assert!(checker.is_effect_handled(&EffectType::Dom));
|
||||
checker.pop_effect_scope();
|
||||
assert!(!checker.is_effect_handled(&EffectType::Dom));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_list_type_unification() {
|
||||
let mut checker = TypeChecker::new();
|
||||
let program = make_program(vec![
|
||||
Declaration::Let(LetDecl {
|
||||
name: "items".to_string(),
|
||||
type_annotation: None,
|
||||
value: Expr::List(vec![
|
||||
Expr::IntLit(1),
|
||||
Expr::IntLit(2),
|
||||
Expr::IntLit(3),
|
||||
]),
|
||||
span: span(),
|
||||
}),
|
||||
]);
|
||||
checker.check_program(&program);
|
||||
assert!(!checker.has_errors(), "Errors: {}", checker.display_errors());
|
||||
|
||||
let items_ty = checker.type_env().get("items").unwrap();
|
||||
assert_eq!(items_ty.display(), "Signal<[Int]>");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue