From 68c6160a14be4e77b98cd704d9e641a03aefa332 Mon Sep 17 00:00:00 2001 From: Laurenz Date: Fri, 30 Dec 2022 09:48:30 +0100 Subject: [PATCH] Recursion with max depth --- src/model/eval.rs | 17 ++++++++++--- src/model/func.rs | 20 ++++++++++++--- src/syntax/ast.rs | 2 +- tests/typ/compiler/let.typ | 4 +-- tests/typ/compiler/recursion.typ | 42 ++++++++++++++++++++++++++++++++ 5 files changed, 74 insertions(+), 11 deletions(-) create mode 100644 tests/typ/compiler/recursion.typ diff --git a/src/model/eval.rs b/src/model/eval.rs index ce1739eda..54007e76c 100644 --- a/src/model/eval.rs +++ b/src/model/eval.rs @@ -20,6 +20,9 @@ use crate::syntax::{ast, Source, SourceId, Span, Spanned, SyntaxKind, SyntaxNode use crate::util::{format_eco, EcoString, PathExt}; use crate::World; +const MAX_ITERATIONS: usize = 10_000; +const MAX_CALL_DEPTH: usize = 256; + /// Evaluate a source file and return the resulting module. #[comemo::memoize] pub fn eval( @@ -41,7 +44,7 @@ pub fn eval( // Evaluate the module. let route = unsafe { Route::insert(route, id) }; let scopes = Scopes::new(Some(&library.scope)); - let mut vm = Vm::new(world, route.track(), id, scopes); + let mut vm = Vm::new(world, route.track(), id, scopes, 0); let result = source.ast()?.eval(&mut vm); // Handle control flow. @@ -70,6 +73,8 @@ pub struct Vm<'a> { pub(super) flow: Option, /// The stack of scopes. pub(super) scopes: Scopes<'a>, + /// The current call depth. + pub(super) depth: usize, } impl<'a> Vm<'a> { @@ -79,6 +84,7 @@ impl<'a> Vm<'a> { route: Tracked<'a, Route>, location: SourceId, scopes: Scopes<'a>, + depth: usize, ) -> Self { Self { world, @@ -87,6 +93,7 @@ impl<'a> Vm<'a> { location, flow: None, scopes, + depth, } } @@ -787,6 +794,10 @@ impl Eval for ast::FuncCall { type Output = Value; fn eval(&self, vm: &mut Vm) -> SourceResult { + if vm.depth >= MAX_CALL_DEPTH { + bail!(self.span(), "maximum function call depth exceeded"); + } + let callee = self.callee().eval(vm)?; let args = self.args().eval(vm)?; @@ -994,8 +1005,6 @@ impl Eval for ast::WhileLoop { type Output = Value; fn eval(&self, vm: &mut Vm) -> SourceResult { - const MAX_ITERS: usize = 10_000; - let flow = vm.flow.take(); let mut output = Value::None; let mut i = 0; @@ -1009,7 +1018,7 @@ impl Eval for ast::WhileLoop { && !can_diverge(body.as_untyped()) { bail!(condition.span(), "condition is always true"); - } else if i >= MAX_ITERS { + } else if i >= MAX_ITERATIONS { bail!(self.span(), "loop seems to be infinite"); } diff --git a/src/model/func.rs b/src/model/func.rs index f04b864e3..3fb8f4d4d 100644 --- a/src/model/func.rs +++ b/src/model/func.rs @@ -96,7 +96,7 @@ impl Func { pub fn call(&self, vm: &Vm, mut args: Args) -> SourceResult { let value = match self.0.as_ref() { Repr::Native(native) => (native.func)(vm, &mut args)?, - Repr::Closure(closure) => closure.call(vm, &mut args)?, + Repr::Closure(closure) => closure.call(vm, self, &mut args)?, Repr::With(wrapped, applied) => { args.items.splice(..0, applied.items.iter().cloned()); return wrapped.call(vm, args); @@ -115,7 +115,7 @@ impl Func { let route = Route::default(); let id = SourceId::detached(); let scopes = Scopes::new(None); - let vm = Vm::new(world, route.track(), id, scopes); + let vm = Vm::new(world, route.track(), id, scopes, 0); self.call(&vm, args) } @@ -274,12 +274,17 @@ pub(super) struct Closure { impl Closure { /// Call the function in the context with the arguments. - fn call(&self, vm: &Vm, args: &mut Args) -> SourceResult { + fn call(&self, vm: &Vm, this: &Func, args: &mut Args) -> SourceResult { // Don't leak the scopes from the call site. Instead, we use the scope // of captured variables we collected earlier. let mut scopes = Scopes::new(None); scopes.top = self.captured.clone(); + // Provide the closure itself for recursive calls. + if let Some(name) = &self.name { + scopes.top.define(name.clone(), Value::Func(this.clone())); + } + // Parse the arguments according to the parameter list. for (param, default) in &self.params { scopes.top.define( @@ -304,7 +309,7 @@ impl Closure { let route = if detached { fresh.track() } else { vm.route }; // Evaluate the body. - let mut sub = Vm::new(vm.world, route, self.location, scopes); + let mut sub = Vm::new(vm.world, route, self.location, scopes, vm.depth + 1); let result = self.body.eval(&mut sub); // Handle control flow. @@ -378,6 +383,10 @@ impl<'a> CapturesVisitor<'a> { } } + if let Some(name) = expr.name() { + self.bind(name); + } + for param in expr.params() { match param { ast::Param::Pos(ident) => self.bind(ident), @@ -456,6 +465,7 @@ mod tests { #[track_caller] fn test(text: &str, result: &[&str]) { let mut scopes = Scopes::new(None); + scopes.top.define("f", 0); scopes.top.define("x", 0); scopes.top.define("y", 0); scopes.top.define("z", 0); @@ -477,6 +487,8 @@ mod tests { test("#let x = x", &["x"]); test("#let x; {x + y}", &["y"]); test("#let f(x, y) = x + y", &[]); + test("#let f(x, y) = f", &[]); + test("#let f = (x, y) => f", &["f"]); // Closure with different kinds of params. test("{(x, y) => x + z}", &["z"]); diff --git a/src/syntax/ast.rs b/src/syntax/ast.rs index d2b19ee3e..6483f7cc7 100644 --- a/src/syntax/ast.rs +++ b/src/syntax/ast.rs @@ -1277,7 +1277,7 @@ impl Closure { /// /// This only exists if you use the function syntax sugar: `let f(x) = y`. pub fn name(&self) -> Option { - self.0.cast_first_child() + self.0.children().next()?.cast() } /// The parameter bindings. diff --git a/tests/typ/compiler/let.typ b/tests/typ/compiler/let.typ index c3be64a5d..d4f9510ab 100644 --- a/tests/typ/compiler/let.typ +++ b/tests/typ/compiler/let.typ @@ -11,8 +11,8 @@ // Syntax sugar for function definitions. #let fill = conifer -#let rect(body) = rect(width: 2cm, fill: fill, inset: 5pt, body) -#rect[Hi!] +#let f(body) = rect(width: 2cm, fill: fill, inset: 5pt, body) +#f[Hi!] --- // Termination. diff --git a/tests/typ/compiler/recursion.typ b/tests/typ/compiler/recursion.typ new file mode 100644 index 000000000..ae214631a --- /dev/null +++ b/tests/typ/compiler/recursion.typ @@ -0,0 +1,42 @@ +// Test recursive function calls. +// Ref: false + +--- +// Test with named function. +#let fib(n) = { + if n <= 2 { + 1 + } else { + fib(n - 1) + fib(n - 2) + } +} + +#test(fib(10), 55) + +--- +// Test with unnamed function. +// Error: 17-18 unknown variable +#let f = (n) => f(n - 1) +#f(10) + +--- +// Test capturing with named function. +#let f = 10 +#let f() = f +#test(type(f()), "function") + +--- +// Test capturing with unnamed function. +#let f = 10 +#let f = () => f +#test(type(f()), "integer") + +--- +// Error: 15-21 maximum function call depth exceeded +#let rec(n) = rec(n) + 1 +#rec(1) + +--- +#let f(x) = "hello" +#let f(x) = if x != none { f(none) } else { "world" } +#test(f(1), "world")