diff --git a/compiler/ds-codegen/src/js_emitter.rs b/compiler/ds-codegen/src/js_emitter.rs index 1242bfd..69a0f34 100644 --- a/compiler/ds-codegen/src/js_emitter.rs +++ b/compiler/ds-codegen/src/js_emitter.rs @@ -141,6 +141,55 @@ impl JsEmitter { } } + // Phase 1b: Emit runtime refinement guards + // Collect type aliases from program + let mut type_aliases: std::collections::HashMap = std::collections::HashMap::new(); + for decl in &program.declarations { + if let Declaration::TypeAlias(alias) = decl { + type_aliases.insert(alias.name.clone(), &alias.definition); + } + } + + let mut guards_emitted = false; + for decl in &program.declarations { + if let Declaration::Let(let_decl) = decl { + // Skip literals — they're statically checked by the type checker + if matches!(let_decl.value, + Expr::IntLit(_) | Expr::FloatLit(_) | Expr::StringLit(_) | Expr::BoolLit(_) + ) { + continue; + } + + if let Some(ref type_ann) = let_decl.type_annotation { + // Resolve type annotation to find refinement predicate + let resolved = match type_ann { + TypeExpr::Named(name) => type_aliases.get(name).copied(), + TypeExpr::Refined { .. } => Some(type_ann), + _ => None, + }; + + if let Some(TypeExpr::Refined { predicate, .. }) = resolved { + if !guards_emitted { + self.emit_line(""); + self.emit_line("// ── Refinement Guards ──"); + guards_emitted = true; + } + let type_name = match type_ann { + TypeExpr::Named(n) => n.clone(), + _ => "refined type".to_string(), + }; + let js_pred = Self::predicate_to_js(predicate, &let_decl.name); + self.emit_line(&format!( + "if (!({js_pred})) throw new Error(\"Refinement violated: `{name}` must satisfy {type_name}\");", + js_pred = js_pred, + name = let_decl.name, + type_name = type_name, + )); + } + } + } + } + self.emit_line(""); // Phase 2a: Component functions @@ -981,6 +1030,63 @@ impl JsEmitter { None } + /// Convert a predicate expression from a `where` clause into a JavaScript boolean expression. + /// The `value` identifier is replaced with `signal_name.value` to read the signal's current value. + fn predicate_to_js(expr: &Expr, signal_name: &str) -> String { + match expr { + Expr::Ident(name) if name == "value" => format!("{}.value", signal_name), + Expr::Ident(name) => name.clone(), + Expr::IntLit(n) => format!("{}", n), + Expr::FloatLit(f) => format!("{}", f), + Expr::BoolLit(b) => format!("{}", b), + Expr::StringLit(s) => { + let text: String = s.segments.iter().map(|seg| match seg { + StringSegment::Literal(l) => l.clone(), + _ => String::new(), + }).collect(); + format!("\"{}\"", text) + } + Expr::BinOp(left, op, right) => { + let l = Self::predicate_to_js(left, signal_name); + let r = Self::predicate_to_js(right, signal_name); + let op_str = match op { + BinOp::Gt => ">", + BinOp::Gte => ">=", + BinOp::Lt => "<", + BinOp::Lte => "<=", + BinOp::Eq => "===", + BinOp::Neq => "!==", + BinOp::And => "&&", + BinOp::Or => "||", + BinOp::Add => "+", + BinOp::Sub => "-", + BinOp::Mul => "*", + BinOp::Div => "/", + BinOp::Mod => "%", + }; + format!("({} {} {})", l, op_str, r) + } + Expr::UnaryOp(UnaryOp::Not, inner) => { + format!("!({})", Self::predicate_to_js(inner, signal_name)) + } + Expr::UnaryOp(UnaryOp::Neg, inner) => { + format!("-({})", Self::predicate_to_js(inner, signal_name)) + } + Expr::Call(name, args) => { + let js_args: Vec = args.iter() + .map(|a| Self::predicate_to_js(a, signal_name)) + .collect(); + // Map common predicate functions to JS equivalents + match name.as_str() { + "len" => format!("{}.length", js_args.first().unwrap_or(&"null".to_string())), + "contains" if js_args.len() == 2 => format!("{}.includes({})", js_args[0], js_args[1]), + _ => format!("{}({})", name, js_args.join(", ")), + } + } + _ => format!("{}.value", signal_name), // fallback + } + } + fn is_signal_ref(&self, expr: &str) -> bool { // Must start with a letter/underscore (not a digit) and contain only ident chars !expr.is_empty() diff --git a/compiler/ds-parser/src/ast.rs b/compiler/ds-parser/src/ast.rs index 8c61fda..95e4c1f 100644 --- a/compiler/ds-parser/src/ast.rs +++ b/compiler/ds-parser/src/ast.rs @@ -34,6 +34,8 @@ pub enum Declaration { Import(ImportDecl), /// `export let count = 0`, `export component Card(...) = ...` Export(String, Box), + /// `type PositiveInt = Int where value > 0` + TypeAlias(TypeAliasDecl), } /// `import { Card, Button } from "./components"` @@ -48,6 +50,7 @@ pub struct ImportDecl { #[derive(Debug, Clone)] pub struct LetDecl { pub name: String, + pub type_annotation: Option, pub value: Expr, pub span: Span, } @@ -159,6 +162,19 @@ pub struct Param { pub enum TypeExpr { Named(String), Generic(String, Vec), + /// Refinement type: `Int where value > 0` + Refined { + base: Box, + predicate: Box, + }, +} + +/// `type PositiveInt = Int where value > 0` +#[derive(Debug, Clone)] +pub struct TypeAliasDecl { + pub name: String, + pub definition: TypeExpr, + pub span: Span, } /// Expressions — the core of the language. diff --git a/compiler/ds-parser/src/lexer.rs b/compiler/ds-parser/src/lexer.rs index 503cd26..a0942d2 100644 --- a/compiler/ds-parser/src/lexer.rs +++ b/compiler/ds-parser/src/lexer.rs @@ -57,6 +57,8 @@ pub enum TokenKind { Every, Import, Export, + Type, + Where, // Operators Plus, @@ -334,6 +336,8 @@ impl Lexer { "every" => TokenKind::Every, "import" => TokenKind::Import, "export" => TokenKind::Export, + "type" => TokenKind::Type, + "where" => TokenKind::Where, _ => TokenKind::Ident(ident.clone()), }; diff --git a/compiler/ds-parser/src/parser.rs b/compiler/ds-parser/src/parser.rs index 6a94f38..940f2a2 100644 --- a/compiler/ds-parser/src/parser.rs +++ b/compiler/ds-parser/src/parser.rs @@ -101,13 +101,14 @@ impl Parser { TokenKind::Every => self.parse_every_decl(), TokenKind::Import => self.parse_import_decl(), TokenKind::Export => self.parse_export_decl(), + TokenKind::Type => self.parse_type_alias_decl(), // Expression statement: `log("hello")`, `push(items, x)` TokenKind::Ident(_) => { let expr = self.parse_expr()?; Ok(Declaration::ExprStatement(expr)) } _ => Err(self.error(format!( - "expected declaration (let, view, effect, on, component, route, constrain, stream, every), got {:?}", + "expected declaration (let, view, effect, on, component, route, constrain, stream, every, type), got {:?}", self.peek() ))), } @@ -208,16 +209,80 @@ impl Parser { let line = self.current_token().line; self.advance(); // consume 'let' let name = self.expect_ident()?; + + // Optional type annotation: `let name: Type = value` + let type_annotation = if self.check(&TokenKind::Colon) { + self.advance(); // consume ':' + Some(self.parse_type_expr()?) + } else { + None + }; + self.expect(&TokenKind::Eq)?; let value = self.parse_expr()?; Ok(Declaration::Let(LetDecl { name, + type_annotation, value, span: Span { start: 0, end: 0, line }, })) } + /// Parse a type alias: `type PositiveInt = Int where value > 0` + fn parse_type_alias_decl(&mut self) -> Result { + let line = self.current_token().line; + self.advance(); // consume 'type' + let name = self.expect_ident()?; + self.expect(&TokenKind::Eq)?; + let definition = self.parse_type_expr()?; + + Ok(Declaration::TypeAlias(TypeAliasDecl { + name, + definition, + span: Span { start: 0, end: 0, line }, + })) + } + + /// Parse a type expression: `Int`, `Array`, `Int where value > 0` + fn parse_type_expr(&mut self) -> Result { + // Parse base type name + let name = match self.peek().clone() { + TokenKind::Ident(n) => { self.advance(); n } + _ => return Err(self.error(format!("expected type name, got {:?}", self.peek()))), + }; + + // Optional generic params: `` + let base = if self.check(&TokenKind::Lt) { + self.advance(); // < + let mut params = Vec::new(); + loop { + params.push(self.parse_type_expr()?); + if self.check(&TokenKind::Comma) { + self.advance(); + } else { + break; + } + } + self.expect(&TokenKind::Gt)?; + TypeExpr::Generic(name, params) + } else { + TypeExpr::Named(name) + }; + + // Optional `where` predicate + if self.check(&TokenKind::Where) { + self.advance(); // consume 'where' + let predicate = self.parse_expr()?; + Ok(TypeExpr::Refined { + base: Box::new(base), + predicate: Box::new(predicate), + }) + } else { + Ok(base) + } + } + fn parse_view_decl(&mut self) -> Result { let line = self.current_token().line; self.advance(); // consume 'view' @@ -468,23 +533,8 @@ impl Parser { Ok(params) } - fn parse_type_expr(&mut self) -> Result { - let name = self.expect_ident()?; - if self.check(&TokenKind::Lt) { - self.advance(); - let mut type_args = Vec::new(); - while !self.check(&TokenKind::Gt) && !self.is_at_end() { - type_args.push(self.parse_type_expr()?); - if self.check(&TokenKind::Comma) { - self.advance(); - } - } - self.expect(&TokenKind::Gt)?; - Ok(TypeExpr::Generic(name, type_args)) - } else { - Ok(TypeExpr::Named(name)) - } - } + + // ── Expressions ───────────────────────────────────── diff --git a/compiler/ds-types/src/checker.rs b/compiler/ds-types/src/checker.rs index 8ab8577..9749b95 100644 --- a/compiler/ds-types/src/checker.rs +++ b/compiler/ds-types/src/checker.rs @@ -8,8 +8,8 @@ use std::collections::HashMap; -use ds_parser::{Program, Declaration, LetDecl, ViewDecl, Expr, BinOp, UnaryOp}; -use crate::types::{Type, TypeVar, EffectType}; +use ds_parser::{Program, Declaration, LetDecl, ViewDecl, Expr, BinOp, UnaryOp, TypeExpr, TypeAliasDecl}; +use crate::types::{Type, TypeVar, EffectType, Predicate, PredicateExpr}; use crate::errors::{TypeError, TypeErrorKind}; /// The type checker. @@ -26,6 +26,8 @@ pub struct TypeChecker { substitutions: HashMap, /// Currently inside a view block? in_view: bool, + /// Type alias registry: name → resolved Type. + type_aliases: HashMap, } impl TypeChecker { @@ -37,6 +39,7 @@ impl TypeChecker { next_tv: 0, substitutions: HashMap::new(), in_view: false, + type_aliases: HashMap::new(), } } @@ -54,22 +57,66 @@ impl TypeChecker { /// Check an entire program. pub fn check_program(&mut self, program: &Program) { + // Pass 0: register type aliases + for decl in &program.declarations { + if let Declaration::TypeAlias(alias) = decl { + let resolved = self.resolve_type_expr(&alias.definition); + self.type_aliases.insert(alias.name.clone(), resolved); + } + } + // First pass: register all let declarations for decl in &program.declarations { if let Declaration::Let(let_decl) = decl { - let ty = self.infer_expr(&let_decl.value); + let inferred_ty = self.infer_expr(&let_decl.value); + + // If there's a type annotation, resolve and check it + if let Some(ref type_ann) = let_decl.type_annotation { + let declared_ty = self.resolve_type_expr(type_ann); + + // Check base type compatibility + let base_declared = match &declared_ty { + Type::Refined { base, .. } => base.as_ref(), + other => other, + }; + + let base_inferred = inferred_ty.unwrap_reactive(); + if *base_declared != *base_inferred + && !matches!(base_declared, Type::Var(_)) + && !matches!(base_inferred, Type::Var(_)) + && *base_declared != Type::Error + && *base_inferred != Type::Error + // Allow Int/Float coercion + && !(matches!(base_declared, Type::Float) && matches!(base_inferred, Type::Int)) + { + self.error(TypeErrorKind::Mismatch { + expected: base_declared.clone(), + found: base_inferred.clone(), + context: format!("in declaration of `{}`", let_decl.name), + }); + } + + // Check refinement predicate if present + if let Type::Refined { predicate, .. } = &declared_ty { + self.check_refinement( + predicate, + &let_decl.value, + &let_decl.name, + type_ann, + ); + } + } + // Heuristic: if name ends conventionally or is assigned // a literal, mark as Signal; otherwise just let-bound T. - // For now: any `let` with a literal is a source signal; - // derivations (expressions involving other identifiers) become Derived. let is_source = matches!( let_decl.value, Expr::IntLit(_) | Expr::FloatLit(_) | Expr::StringLit(_) | Expr::BoolLit(_) ); let final_ty = if is_source { - Type::Signal(Box::new(ty)) + Type::Signal(Box::new(inferred_ty)) } else { - Type::Derived(Box::new(ty)) + Type::Derived(Box::new(inferred_ty)) }; self.env.insert(let_decl.name.clone(), final_ty); } @@ -107,6 +154,165 @@ impl TypeChecker { } } + /// Resolve a TypeExpr from the AST into a semantic Type. + fn resolve_type_expr(&self, type_expr: &TypeExpr) -> Type { + match type_expr { + TypeExpr::Named(name) => { + // Check type aliases first + if let Some(resolved) = self.type_aliases.get(name) { + return resolved.clone(); + } + // Built-in type names + match name.as_str() { + "Int" => Type::Int, + "Float" => Type::Float, + "String" => Type::String, + "Bool" => Type::Bool, + "View" => Type::View, + _ => Type::Named(name.clone()), + } + } + TypeExpr::Generic(name, params) => { + let resolved_params: Vec = params.iter() + .map(|p| self.resolve_type_expr(p)) + .collect(); + match name.as_str() { + "Signal" if resolved_params.len() == 1 => { + Type::Signal(Box::new(resolved_params.into_iter().next().unwrap())) + } + "Array" if resolved_params.len() == 1 => { + Type::Array(Box::new(resolved_params.into_iter().next().unwrap())) + } + "Stream" if resolved_params.len() == 1 => { + Type::Stream(Box::new(resolved_params.into_iter().next().unwrap())) + } + _ => Type::Named(name.clone()), + } + } + TypeExpr::Refined { base, predicate } => { + let base_type = self.resolve_type_expr(base); + let pred = Self::ast_to_predicate(predicate); + Type::Refined { + base: Box::new(base_type), + predicate: pred, + } + } + } + } + + /// Convert an AST expression (from `where` clause) into a semantic Predicate. + fn ast_to_predicate(expr: &Expr) -> Predicate { + match expr { + Expr::BinOp(left, op, right) => { + match op { + BinOp::And => { + let l = Self::ast_to_predicate(left); + let r = Self::ast_to_predicate(right); + Predicate::And(Box::new(l), Box::new(r)) + } + BinOp::Or => { + let l = Self::ast_to_predicate(left); + let r = Self::ast_to_predicate(right); + Predicate::Or(Box::new(l), Box::new(r)) + } + BinOp::Gt => Predicate::Gt( + Box::new(Self::ast_to_pred_expr(left)), + Box::new(Self::ast_to_pred_expr(right)), + ), + BinOp::Gte => Predicate::Gte( + Box::new(Self::ast_to_pred_expr(left)), + Box::new(Self::ast_to_pred_expr(right)), + ), + BinOp::Lt => Predicate::Lt( + Box::new(Self::ast_to_pred_expr(left)), + Box::new(Self::ast_to_pred_expr(right)), + ), + BinOp::Lte => Predicate::Lte( + Box::new(Self::ast_to_pred_expr(left)), + Box::new(Self::ast_to_pred_expr(right)), + ), + BinOp::Eq => Predicate::Eq( + Box::new(Self::ast_to_pred_expr(left)), + Box::new(Self::ast_to_pred_expr(right)), + ), + BinOp::Neq => Predicate::Neq( + Box::new(Self::ast_to_pred_expr(left)), + Box::new(Self::ast_to_pred_expr(right)), + ), + _ => Predicate::Expr(format!("{:?}", expr)), + } + } + Expr::UnaryOp(UnaryOp::Not, inner) => { + Predicate::Not(Box::new(Self::ast_to_predicate(inner))) + } + _ => Predicate::Expr(format!("{:?}", expr)), + } + } + + /// Convert an AST expression into a predicate sub-expression. + fn ast_to_pred_expr(expr: &Expr) -> PredicateExpr { + match expr { + Expr::Ident(name) if name == "value" => PredicateExpr::Value, + Expr::IntLit(n) => PredicateExpr::IntLit(*n), + Expr::FloatLit(f) => PredicateExpr::FloatLit(*f), + Expr::StringLit(s) => { + let text: String = s.segments.iter().map(|seg| match seg { + ds_parser::StringSegment::Literal(l) => l.clone(), + _ => String::new(), + }).collect(); + PredicateExpr::StringLit(text) + } + Expr::BoolLit(b) => PredicateExpr::BoolLit(*b), + Expr::Call(name, args) => { + let pred_args: Vec = args.iter() + .map(|a| Self::ast_to_pred_expr(a)) + .collect(); + PredicateExpr::Call(name.clone(), pred_args) + } + _ => PredicateExpr::Value, // fallback + } + } + + /// Check a refinement predicate against a value expression. + /// For literals: evaluate statically and report error if violated. + /// For dynamic expressions: accept (runtime guard will be emitted by codegen). + fn check_refinement( + &mut self, + predicate: &Predicate, + value_expr: &Expr, + var_name: &str, + type_ann: &TypeExpr, + ) { + // Try to get a static value from the expression + let static_val = match value_expr { + Expr::IntLit(n) => Some(PredicateExpr::IntLit(*n)), + Expr::FloatLit(f) => Some(PredicateExpr::FloatLit(*f)), + _ => None, + }; + + if let Some(val) = static_val { + if let Some(result) = predicate.evaluate_static(&val) { + if !result { + // Static violation — compile-time error + let type_name = match type_ann { + TypeExpr::Named(n) => n.clone(), + TypeExpr::Refined { base, .. } => match base.as_ref() { + TypeExpr::Named(n) => format!("{} where ...", n), + _ => "".to_string(), + }, + _ => "".to_string(), + }; + self.error(TypeErrorKind::RefinementViolation { + type_name, + predicate: predicate.display(), + value: format!("{:?}", value_expr), + }); + } + } + } + // Dynamic values: accepted (codegen emits runtime guard) + } + /// Check a view declaration. fn check_view(&mut self, view: &ViewDecl) { self.in_view = true; @@ -498,6 +704,7 @@ mod tests { let program = make_program(vec![ Declaration::Let(LetDecl { name: "count".to_string(), + type_annotation: None, value: Expr::IntLit(0), span: span(), }), @@ -516,11 +723,13 @@ mod tests { let program = make_program(vec![ Declaration::Let(LetDecl { name: "count".to_string(), + type_annotation: None, value: Expr::IntLit(0), span: span(), }), Declaration::Let(LetDecl { name: "doubled".to_string(), + type_annotation: None, value: Expr::BinOp( Box::new(Expr::Ident("count".to_string())), BinOp::Mul, @@ -567,6 +776,7 @@ mod tests { let program = make_program(vec![ Declaration::Let(LetDecl { name: "name".to_string(), + type_annotation: None, value: Expr::StringLit(ds_parser::StringLit { segments: vec![ds_parser::StringSegment::Literal("hello".to_string())], }), @@ -596,4 +806,115 @@ mod tests { let msg = checker.display_errors(); assert!(msg.contains("VIEW OUTSIDE BLOCK")); } + + #[test] + fn test_type_annotation_basic() { + let mut checker = TypeChecker::new(); + let program = make_program(vec![ + Declaration::Let(LetDecl { + name: "count".to_string(), + type_annotation: Some(ds_parser::TypeExpr::Named("Int".to_string())), + value: Expr::IntLit(42), + span: span(), + }), + ]); + checker.check_program(&program); + assert!(!checker.has_errors(), "Errors: {}", checker.display_errors()); + } + + #[test] + fn test_type_annotation_mismatch() { + let mut checker = TypeChecker::new(); + let program = make_program(vec![ + Declaration::Let(LetDecl { + name: "count".to_string(), + type_annotation: Some(ds_parser::TypeExpr::Named("Int".to_string())), + value: Expr::StringLit(ds_parser::StringLit { + segments: vec![ds_parser::StringSegment::Literal("oops".to_string())], + }), + span: span(), + }), + ]); + checker.check_program(&program); + assert!(checker.has_errors()); + let msg = checker.display_errors(); + assert!(msg.contains("TYPE MISMATCH")); + } + + #[test] + fn test_refinement_passes_static() { + let mut checker = TypeChecker::new(); + // let count: Int where value > 0 = 5 + let program = make_program(vec![ + Declaration::Let(LetDecl { + name: "count".to_string(), + type_annotation: Some(ds_parser::TypeExpr::Refined { + base: Box::new(ds_parser::TypeExpr::Named("Int".to_string())), + predicate: Box::new(Expr::BinOp( + Box::new(Expr::Ident("value".to_string())), + BinOp::Gt, + Box::new(Expr::IntLit(0)), + )), + }), + value: Expr::IntLit(5), + span: span(), + }), + ]); + checker.check_program(&program); + assert!(!checker.has_errors(), "Errors: {}", checker.display_errors()); + } + + #[test] + fn test_refinement_violation_static() { + let mut checker = TypeChecker::new(); + // let count: Int where value > 0 = -1 + let program = make_program(vec![ + Declaration::Let(LetDecl { + name: "count".to_string(), + type_annotation: Some(ds_parser::TypeExpr::Refined { + base: Box::new(ds_parser::TypeExpr::Named("Int".to_string())), + predicate: Box::new(Expr::BinOp( + Box::new(Expr::Ident("value".to_string())), + BinOp::Gt, + Box::new(Expr::IntLit(0)), + )), + }), + value: Expr::IntLit(-1), + span: span(), + }), + ]); + checker.check_program(&program); + assert!(checker.has_errors()); + let msg = checker.display_errors(); + assert!(msg.contains("REFINEMENT VIOLATED"), "Expected REFINEMENT VIOLATED, got: {}", msg); + } + + #[test] + fn test_type_alias_with_refinement() { + let mut checker = TypeChecker::new(); + // type PositiveInt = Int where value > 0 + // let count: PositiveInt = 5 + let program = make_program(vec![ + Declaration::TypeAlias(ds_parser::TypeAliasDecl { + name: "PositiveInt".to_string(), + definition: ds_parser::TypeExpr::Refined { + base: Box::new(ds_parser::TypeExpr::Named("Int".to_string())), + predicate: Box::new(Expr::BinOp( + Box::new(Expr::Ident("value".to_string())), + BinOp::Gt, + Box::new(Expr::IntLit(0)), + )), + }, + span: span(), + }), + Declaration::Let(LetDecl { + name: "count".to_string(), + type_annotation: Some(ds_parser::TypeExpr::Named("PositiveInt".to_string())), + value: Expr::IntLit(5), + span: span(), + }), + ]); + checker.check_program(&program); + assert!(!checker.has_errors(), "Errors: {}", checker.display_errors()); + } } diff --git a/compiler/ds-types/src/errors.rs b/compiler/ds-types/src/errors.rs index 0bd0446..6491ebf 100644 --- a/compiler/ds-types/src/errors.rs +++ b/compiler/ds-types/src/errors.rs @@ -60,6 +60,18 @@ pub enum TypeErrorKind { field: String, record_type: Type, }, + + /// A refinement predicate was violated at compile time. + RefinementViolation { + type_name: String, + predicate: String, + value: String, + }, + + /// Circular type alias definition. + TypeAliasCycle { + name: String, + }, } impl TypeError { @@ -144,6 +156,21 @@ impl TypeError { field )) } + TypeErrorKind::RefinementViolation { type_name, predicate, value } => { + ("REFINEMENT VIOLATED".to_string(), format!( + "The value `{}` does not satisfy the refinement type `{}`.\n\n\ + The predicate `{}` is not satisfied.\n\n\ + Hint: Ensure the value meets the constraint before assignment.", + value, type_name, predicate + )) + } + TypeErrorKind::TypeAliasCycle { name } => { + ("TYPE ALIAS CYCLE".to_string(), format!( + "The type alias `{}` refers to itself, creating an infinite loop.\n\n\ + Break the cycle by using a concrete base type.", + name + )) + } }; // Format like Elm diff --git a/compiler/ds-types/src/lib.rs b/compiler/ds-types/src/lib.rs index dbb56a4..15c785f 100644 --- a/compiler/ds-types/src/lib.rs +++ b/compiler/ds-types/src/lib.rs @@ -4,5 +4,5 @@ pub mod types; pub mod errors; pub use checker::TypeChecker; -pub use types::{Type, TypeVar, SignalType, EffectType}; +pub use types::{Type, TypeVar, SignalType, EffectType, Predicate, PredicateExpr}; pub use errors::{TypeError, TypeErrorKind}; diff --git a/compiler/ds-types/src/types.rs b/compiler/ds-types/src/types.rs index dc7ce83..0ae9a58 100644 --- a/compiler/ds-types/src/types.rs +++ b/compiler/ds-types/src/types.rs @@ -57,6 +57,16 @@ pub enum Type { /// An unresolved type variable (for inference). Var(TypeVar), + /// Refinement type: base type + predicate constraint. + /// `Int where value > 0` becomes `Refined { base: Int, predicate: Gt(Value, IntLit(0)) }` + Refined { + base: Box, + predicate: Predicate, + }, + + /// Named type alias reference (resolved during checking). + Named(String), + /// Error sentinel (for error recovery). Error, } @@ -90,6 +100,139 @@ pub enum EffectType { Pure, } +/// A semantic predicate for refinement types. +#[derive(Debug, Clone, PartialEq)] +pub enum Predicate { + Gt(Box, Box), + Gte(Box, Box), + Lt(Box, Box), + Lte(Box, Box), + Eq(Box, Box), + Neq(Box, Box), + And(Box, Box), + Or(Box, Box), + Not(Box), + /// A call like `len(value) > 0` or `contains(value, "@")` + Call(String, Vec), + /// Raw expression that couldn't be further decomposed + Expr(String), +} + +/// Predicate sub-expression. +#[derive(Debug, Clone, PartialEq)] +pub enum PredicateExpr { + /// The refined value itself (the `value` keyword in `where value > 0`) + Value, + IntLit(i64), + FloatLit(f64), + StringLit(String), + BoolLit(bool), + Call(String, Vec), +} + +impl Predicate { + /// Pretty-print the predicate for error messages. + pub fn display(&self) -> String { + match self { + Predicate::Gt(l, r) => format!("{} > {}", l.display(), r.display()), + Predicate::Gte(l, r) => format!("{} >= {}", l.display(), r.display()), + Predicate::Lt(l, r) => format!("{} < {}", l.display(), r.display()), + Predicate::Lte(l, r) => format!("{} <= {}", l.display(), r.display()), + Predicate::Eq(l, r) => format!("{} == {}", l.display(), r.display()), + Predicate::Neq(l, r) => format!("{} != {}", l.display(), r.display()), + Predicate::And(l, r) => format!("{} and {}", l.display(), r.display()), + Predicate::Or(l, r) => format!("{} or {}", l.display(), r.display()), + Predicate::Not(p) => format!("not {}", p.display()), + Predicate::Call(name, args) => { + let args_str = args.iter().map(|a| a.display()).collect::>().join(", "); + format!("{}({})", name, args_str) + } + Predicate::Expr(s) => s.clone(), + } + } + + /// Try to evaluate the predicate statically with a concrete value. + /// Returns Some(true/false) if evaluable, None if dynamic. + pub fn evaluate_static(&self, value: &PredicateExpr) -> Option { + match self { + Predicate::Gt(l, r) => { + let lv = Self::resolve_expr(l, value)?; + let rv = Self::resolve_expr(r, value)?; + Some(lv > rv) + } + Predicate::Gte(l, r) => { + let lv = Self::resolve_expr(l, value)?; + let rv = Self::resolve_expr(r, value)?; + Some(lv >= rv) + } + Predicate::Lt(l, r) => { + let lv = Self::resolve_expr(l, value)?; + let rv = Self::resolve_expr(r, value)?; + Some(lv < rv) + } + Predicate::Lte(l, r) => { + let lv = Self::resolve_expr(l, value)?; + let rv = Self::resolve_expr(r, value)?; + Some(lv <= rv) + } + Predicate::Eq(l, r) => { + let lv = Self::resolve_expr(l, value)?; + let rv = Self::resolve_expr(r, value)?; + Some(lv == rv) + } + Predicate::Neq(l, r) => { + let lv = Self::resolve_expr(l, value)?; + let rv = Self::resolve_expr(r, value)?; + Some(lv != rv) + } + Predicate::And(l, r) => { + let lv = l.evaluate_static(value)?; + let rv = r.evaluate_static(value)?; + Some(lv && rv) + } + Predicate::Or(l, r) => { + let lv = l.evaluate_static(value)?; + let rv = r.evaluate_static(value)?; + Some(lv || rv) + } + Predicate::Not(p) => { + let v = p.evaluate_static(value)?; + Some(!v) + } + _ => None, // calls and raw expressions can't be statically evaluated + } + } + + fn resolve_expr(expr: &PredicateExpr, value: &PredicateExpr) -> Option { + match expr { + PredicateExpr::Value => match value { + PredicateExpr::IntLit(n) => Some(*n as f64), + PredicateExpr::FloatLit(f) => Some(*f), + _ => None, + }, + PredicateExpr::IntLit(n) => Some(*n as f64), + PredicateExpr::FloatLit(f) => Some(*f), + _ => None, + } + } +} + +impl PredicateExpr { + pub fn display(&self) -> String { + match self { + PredicateExpr::Value => "value".to_string(), + PredicateExpr::IntLit(n) => n.to_string(), + PredicateExpr::FloatLit(f) => f.to_string(), + PredicateExpr::StringLit(s) => format!("\"{}\"", s), + PredicateExpr::BoolLit(b) => b.to_string(), + PredicateExpr::Call(name, args) => { + let args_str = args.iter().map(|a| a.display()).collect::>().join(", "); + format!("{}({})", name, args_str) + } + } + } +} + impl Type { /// Unwrap the inner type of a Signal or Derived. pub fn unwrap_reactive(&self) -> &Type { @@ -147,6 +290,10 @@ impl Type { Type::Spring(inner) => format!("Spring<{}>", inner.display()), Type::View => "View".to_string(), Type::Var(tv) => format!("?{}", tv.0), + Type::Refined { base, predicate } => { + format!("{} where {}", base.display(), predicate.display()) + } + Type::Named(name) => name.clone(), Type::Error => "".to_string(), } } diff --git a/examples/refined-types.ds b/examples/refined-types.ds new file mode 100644 index 0000000..d471e40 --- /dev/null +++ b/examples/refined-types.ds @@ -0,0 +1,27 @@ +-- DreamStack Dependent Types Example +-- Demonstrates refinement types, type aliases, and type annotations + +-- Type aliases with refinement predicates +type PositiveInt = Int where value > 0 +type Percentage = Float where value >= 0.0 + +-- Basic type annotations +let count: Int = 0 +let name: String = "hello" + +-- Annotated with refinement type alias +let priority: PositiveInt = 1 +let progress: Percentage = 75.0 + +-- Derived values don't need annotation (inferred) +let doubled = count * 2 +let greeting = "Welcome, {name}!" + +view main = column [ + text "Count: {count}" + text "Priority: {priority}" + text "Progress: {progress}%" + text greeting + button "+" { click: count += 1 } + button "-" { click: count -= 1 } +]