diff --git a/compiler/ds-types/src/checker.rs b/compiler/ds-types/src/checker.rs index 9749b95..0fd859c 100644 --- a/compiler/ds-types/src/checker.rs +++ b/compiler/ds-types/src/checker.rs @@ -57,9 +57,16 @@ impl TypeChecker { /// Check an entire program. pub fn check_program(&mut self, program: &Program) { - // Pass 0: register type aliases + // Pass 0: register type aliases (with cycle detection) for decl in &program.declarations { if let Declaration::TypeAlias(alias) = decl { + // Check for direct self-reference cycle + if Self::type_expr_references(&alias.definition, &alias.name) { + self.error(TypeErrorKind::TypeAliasCycle { + name: alias.name.clone(), + }); + continue; + } let resolved = self.resolve_type_expr(&alias.definition); self.type_aliases.insert(alias.name.clone(), resolved); } @@ -253,6 +260,12 @@ impl TypeChecker { fn ast_to_pred_expr(expr: &Expr) -> PredicateExpr { match expr { Expr::Ident(name) if name == "value" => PredicateExpr::Value, + Expr::Ident(name) => { + // Non-value ident in a predicate — treat as an unresolvable reference. + // This allows predicates like `value > min_threshold` where + // min_threshold won't be statically evaluated. + PredicateExpr::Call(name.clone(), vec![]) // model as zero-arg "call" + } Expr::IntLit(n) => PredicateExpr::IntLit(*n), Expr::FloatLit(f) => PredicateExpr::FloatLit(*f), Expr::StringLit(s) => { @@ -269,7 +282,11 @@ impl TypeChecker { .collect(); PredicateExpr::Call(name.clone(), pred_args) } - _ => PredicateExpr::Value, // fallback + _ => { + // Unrecognized expression — model as opaque. This will prevent + // static evaluation (returns None) rather than silently succeeding. + PredicateExpr::Call("".to_string(), vec![]) + } } } @@ -280,18 +297,26 @@ impl TypeChecker { &mut self, predicate: &Predicate, value_expr: &Expr, - var_name: &str, + _var_name: &str, type_ann: &TypeExpr, ) { // Try to get a static value from the expression let static_val = match value_expr { Expr::IntLit(n) => Some(PredicateExpr::IntLit(*n)), Expr::FloatLit(f) => Some(PredicateExpr::FloatLit(*f)), + Expr::BoolLit(b) => Some(PredicateExpr::BoolLit(*b)), + Expr::StringLit(s) => { + let text: String = s.segments.iter().map(|seg| match seg { + ds_parser::StringSegment::Literal(l) => l.clone(), + _ => String::new(), + }).collect(); + Some(PredicateExpr::StringLit(text)) + } _ => None, }; - if let Some(val) = static_val { - if let Some(result) = predicate.evaluate_static(&val) { + if let Some(ref val) = static_val { + if let Some(result) = predicate.evaluate_static(val) { if !result { // Static violation — compile-time error let type_name = match type_ann { @@ -302,10 +327,20 @@ impl TypeChecker { }, _ => "".to_string(), }; + let value_display = match value_expr { + Expr::IntLit(n) => n.to_string(), + Expr::FloatLit(f) => f.to_string(), + Expr::BoolLit(b) => b.to_string(), + Expr::StringLit(_) => match val { + PredicateExpr::StringLit(s) => format!("\"{}\"", s), + _ => format!("{:?}", value_expr), + }, + _ => format!("{:?}", value_expr), + }; self.error(TypeErrorKind::RefinementViolation { type_name, predicate: predicate.display(), - value: format!("{:?}", value_expr), + value: value_display, }); } } @@ -313,6 +348,17 @@ impl TypeChecker { // Dynamic values: accepted (codegen emits runtime guard) } + /// Check if a TypeExpr references a given name (for cycle detection). + fn type_expr_references(type_expr: &TypeExpr, name: &str) -> bool { + match type_expr { + TypeExpr::Named(n) => n == name, + TypeExpr::Generic(n, params) => { + n == name || params.iter().any(|p| Self::type_expr_references(p, name)) + } + TypeExpr::Refined { base, .. } => Self::type_expr_references(base, name), + } + } + /// Check a view declaration. fn check_view(&mut self, view: &ViewDecl) { self.in_view = true; @@ -917,4 +963,48 @@ mod tests { checker.check_program(&program); assert!(!checker.has_errors(), "Errors: {}", checker.display_errors()); } + + #[test] + fn test_type_alias_cycle_detection() { + let mut checker = TypeChecker::new(); + // type Foo = Foo (self-referential cycle) + let program = make_program(vec![ + Declaration::TypeAlias(ds_parser::TypeAliasDecl { + name: "Foo".to_string(), + definition: ds_parser::TypeExpr::Named("Foo".to_string()), + span: span(), + }), + ]); + checker.check_program(&program); + assert!(checker.has_errors()); + let msg = checker.display_errors(); + assert!(msg.contains("TYPE ALIAS CYCLE"), "Expected cycle error, got: {}", msg); + assert!(msg.contains("Foo")); + } + + #[test] + fn test_refinement_violation_shows_readable_value() { + let mut checker = TypeChecker::new(); + // let x: Int where value > 0 = -42 + let program = make_program(vec![ + Declaration::Let(LetDecl { + name: "x".to_string(), + type_annotation: Some(ds_parser::TypeExpr::Refined { + base: Box::new(ds_parser::TypeExpr::Named("Int".to_string())), + predicate: Box::new(Expr::BinOp( + Box::new(Expr::Ident("value".to_string())), + BinOp::Gt, + Box::new(Expr::IntLit(0)), + )), + }), + value: Expr::IntLit(-42), + span: span(), + }), + ]); + checker.check_program(&program); + assert!(checker.has_errors()); + let msg = checker.display_errors(); + assert!(msg.contains("-42"), "Error should show the literal value, got: {}", msg); + assert!(msg.contains("REFINEMENT VIOLATED")); + } } diff --git a/compiler/ds-types/src/types.rs b/compiler/ds-types/src/types.rs index 0ae9a58..d3d39c6 100644 --- a/compiler/ds-types/src/types.rs +++ b/compiler/ds-types/src/types.rs @@ -176,11 +176,18 @@ impl Predicate { Some(lv <= rv) } Predicate::Eq(l, r) => { + // Prefer exact integer comparison to avoid f64 precision issues + if let Some((lv, rv)) = Self::try_resolve_ints(l, r, value) { + return Some(lv == rv); + } let lv = Self::resolve_expr(l, value)?; let rv = Self::resolve_expr(r, value)?; Some(lv == rv) } Predicate::Neq(l, r) => { + if let Some((lv, rv)) = Self::try_resolve_ints(l, r, value) { + return Some(lv != rv); + } let lv = Self::resolve_expr(l, value)?; let rv = Self::resolve_expr(r, value)?; Some(lv != rv) @@ -215,6 +222,27 @@ impl Predicate { _ => None, } } + + /// Try to resolve both as integers for exact comparison (avoids f64 precision loss). + fn try_resolve_ints(l: &PredicateExpr, r: &PredicateExpr, value: &PredicateExpr) -> Option<(i64, i64)> { + let lv = match l { + PredicateExpr::Value => match value { + PredicateExpr::IntLit(n) => Some(*n), + _ => None, + }, + PredicateExpr::IntLit(n) => Some(*n), + _ => None, + }?; + let rv = match r { + PredicateExpr::Value => match value { + PredicateExpr::IntLit(n) => Some(*n), + _ => None, + }, + PredicateExpr::IntLit(n) => Some(*n), + _ => None, + }?; + Some((lv, rv)) + } } impl PredicateExpr { @@ -238,6 +266,15 @@ impl Type { pub fn unwrap_reactive(&self) -> &Type { match self { Type::Signal(inner) | Type::Derived(inner) => inner, + Type::Refined { base, .. } => base.unwrap_reactive(), + other => other, + } + } + + /// Strip refinement wrapper, returning the base type. + pub fn unwrap_refined(&self) -> &Type { + match self { + Type::Refined { base, .. } => base.unwrap_refined(), other => other, } } @@ -338,4 +375,106 @@ mod tests { // Non-reactive returns self assert_eq!(*Type::Int.unwrap_reactive(), Type::Int); } + + #[test] + fn test_predicate_display() { + let pred = Predicate::Gt( + Box::new(PredicateExpr::Value), + Box::new(PredicateExpr::IntLit(0)), + ); + assert_eq!(pred.display(), "value > 0"); + + let compound = Predicate::And( + Box::new(Predicate::Gte( + Box::new(PredicateExpr::Value), + Box::new(PredicateExpr::IntLit(0)), + )), + Box::new(Predicate::Lte( + Box::new(PredicateExpr::Value), + Box::new(PredicateExpr::IntLit(100)), + )), + ); + assert_eq!(compound.display(), "value >= 0 and value <= 100"); + } + + #[test] + fn test_evaluate_static_int() { + let pred = Predicate::Gt( + Box::new(PredicateExpr::Value), + Box::new(PredicateExpr::IntLit(0)), + ); + assert_eq!(pred.evaluate_static(&PredicateExpr::IntLit(5)), Some(true)); + assert_eq!(pred.evaluate_static(&PredicateExpr::IntLit(0)), Some(false)); + assert_eq!(pred.evaluate_static(&PredicateExpr::IntLit(-1)), Some(false)); + } + + #[test] + fn test_evaluate_static_eq_integer_precision() { + // Test that integer equality uses exact comparison, not f64 + let pred = Predicate::Eq( + Box::new(PredicateExpr::Value), + Box::new(PredicateExpr::IntLit(42)), + ); + assert_eq!(pred.evaluate_static(&PredicateExpr::IntLit(42)), Some(true)); + assert_eq!(pred.evaluate_static(&PredicateExpr::IntLit(43)), Some(false)); + } + + #[test] + fn test_evaluate_static_compound() { + // value >= 0 and value <= 100 + let pred = Predicate::And( + Box::new(Predicate::Gte( + Box::new(PredicateExpr::Value), + Box::new(PredicateExpr::IntLit(0)), + )), + Box::new(Predicate::Lte( + Box::new(PredicateExpr::Value), + Box::new(PredicateExpr::IntLit(100)), + )), + ); + assert_eq!(pred.evaluate_static(&PredicateExpr::IntLit(50)), Some(true)); + assert_eq!(pred.evaluate_static(&PredicateExpr::IntLit(0)), Some(true)); + assert_eq!(pred.evaluate_static(&PredicateExpr::IntLit(100)), Some(true)); + assert_eq!(pred.evaluate_static(&PredicateExpr::IntLit(-1)), Some(false)); + assert_eq!(pred.evaluate_static(&PredicateExpr::IntLit(101)), Some(false)); + } + + #[test] + fn test_unwrap_refined() { + let refined = Type::Refined { + base: Box::new(Type::Int), + predicate: Predicate::Gt( + Box::new(PredicateExpr::Value), + Box::new(PredicateExpr::IntLit(0)), + ), + }; + assert_eq!(*refined.unwrap_refined(), Type::Int); + // Non-refined returns self + assert_eq!(*Type::Int.unwrap_refined(), Type::Int); + } + + #[test] + fn test_refined_type_display() { + let refined = Type::Refined { + base: Box::new(Type::Int), + predicate: Predicate::Gt( + Box::new(PredicateExpr::Value), + Box::new(PredicateExpr::IntLit(0)), + ), + }; + assert_eq!(refined.display(), "Int where value > 0"); + } + + #[test] + fn test_refined_unwrap_reactive() { + // Refined wrapping a non-reactive type + let refined = Type::Refined { + base: Box::new(Type::Int), + predicate: Predicate::Gt( + Box::new(PredicateExpr::Value), + Box::new(PredicateExpr::IntLit(0)), + ), + }; + assert_eq!(*refined.unwrap_reactive(), Type::Int); + } }