From 2036663ed25b5885a87eb3a80caec3fa2e258d77 Mon Sep 17 00:00:00 2001 From: Laurenz Date: Wed, 27 Jan 2021 15:05:18 +0100 Subject: [PATCH] =?UTF-8?q?Capture=20variables=20in=20templates=20?= =?UTF-8?q?=F0=9F=94=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/eval/context.rs | 2 +- src/eval/mod.rs | 38 ++++--- src/eval/scope.rs | 10 +- src/eval/template.rs | 56 ++++++++++ src/eval/value.rs | 13 ++- src/syntax/expr.rs | 7 ++ src/syntax/mod.rs | 1 + src/syntax/visit.rs | 185 +++++++++++++++++++++++++++++++++ tests/lang/typ/expressions.typ | 2 +- 9 files changed, 291 insertions(+), 23 deletions(-) create mode 100644 src/eval/template.rs create mode 100644 src/syntax/visit.rs diff --git a/src/eval/context.rs b/src/eval/context.rs index a998bbdcb..fd7e264fe 100644 --- a/src/eval/context.rs +++ b/src/eval/context.rs @@ -40,7 +40,7 @@ impl<'a> EvalContext<'a> { pub fn new(env: &'a mut Env, scope: &'a Scope, state: State) -> Self { Self { env, - scopes: Scopes::new(scope), + scopes: Scopes::new(Some(scope)), state, groups: vec![], inner: vec![], diff --git a/src/eval/mod.rs b/src/eval/mod.rs index 13d242f80..9e45a67ba 100644 --- a/src/eval/mod.rs +++ b/src/eval/mod.rs @@ -7,6 +7,7 @@ mod context; mod ops; mod scope; mod state; +mod template; pub use call::*; pub use context::*; @@ -174,7 +175,7 @@ impl Eval for Spanned<&Expr> { Expr::Str(v) => Value::Str(v.clone()), Expr::Array(v) => Value::Array(v.with_span(self.span).eval(ctx)), Expr::Dict(v) => Value::Dict(v.with_span(self.span).eval(ctx)), - Expr::Template(v) => Value::Template(v.clone()), + Expr::Template(v) => v.with_span(self.span).eval(ctx), Expr::Group(v) => v.eval(ctx), Expr::Block(v) => v.with_span(self.span).eval(ctx), Expr::Call(v) => v.with_span(self.span).eval(ctx), @@ -183,6 +184,7 @@ impl Eval for Spanned<&Expr> { Expr::Let(v) => v.with_span(self.span).eval(ctx), Expr::If(v) => v.with_span(self.span).eval(ctx), Expr::For(v) => v.with_span(self.span).eval(ctx), + Expr::CapturedValue(v) => v.clone(), } } } @@ -327,18 +329,28 @@ impl Spanned<&ExprBinary> { let rhs = self.v.rhs.eval(ctx); let span = self.v.lhs.span; - if let Expr::Ident(id) = &self.v.lhs.v { - if let Some(slot) = ctx.scopes.get_mut(id) { - let lhs = std::mem::replace(slot, Value::None); - *slot = op(lhs, rhs); - return Value::None; - } else if ctx.scopes.is_const(id) { - ctx.diag(error!(span, "cannot assign to constant")); - } else { - ctx.diag(error!(span, "unknown variable")); + match &self.v.lhs.v { + Expr::Ident(id) => { + if let Some(slot) = ctx.scopes.get_mut(id) { + *slot = op(std::mem::take(slot), rhs); + return Value::None; + } else if ctx.scopes.is_const(id) { + ctx.diag(error!(span, "cannot assign to a constant")); + } else { + ctx.diag(error!(span, "unknown variable")); + } + } + + Expr::CapturedValue(_) => { + ctx.diag(error!( + span, + "cannot assign to captured expression in a template", + )); + } + + _ => { + ctx.diag(error!(span, "cannot assign to this expression")); } - } else { - ctx.diag(error!(span, "cannot assign to this expression")); } Value::Error @@ -421,7 +433,7 @@ impl Eval for Spanned<&ExprFor> { (ForPattern::KeyValue(..), Value::Str(_)) | (ForPattern::KeyValue(..), Value::Array(_)) => { - ctx.diag(error!(self.v.pat.span, "mismatched pattern",)); + ctx.diag(error!(self.v.pat.span, "mismatched pattern")); } (_, Value::Error) => {} diff --git a/src/eval/scope.rs b/src/eval/scope.rs index 1ed34f866..9c966a242 100644 --- a/src/eval/scope.rs +++ b/src/eval/scope.rs @@ -5,19 +5,19 @@ use std::iter; use super::Value; /// A stack of scopes. -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Default, Clone, PartialEq)] pub struct Scopes<'a> { /// The active scope. top: Scope, /// The stack of lower scopes. scopes: Vec, /// The base scope. - base: &'a Scope, + base: Option<&'a Scope>, } impl<'a> Scopes<'a> { /// Create a new hierarchy of scopes. - pub fn new(base: &'a Scope) -> Self { + pub fn new(base: Option<&'a Scope>) -> Self { Self { top: Scope::new(), scopes: vec![], base } } @@ -43,7 +43,7 @@ impl<'a> Scopes<'a> { pub fn get(&self, var: &str) -> Option<&Value> { iter::once(&self.top) .chain(self.scopes.iter().rev()) - .chain(iter::once(self.base)) + .chain(self.base.into_iter()) .find_map(|scope| scope.get(var)) } @@ -58,7 +58,7 @@ impl<'a> Scopes<'a> { /// /// Defaults to `false` if the variable does not exist. pub fn is_const(&self, var: &str) -> bool { - self.base.get(var).is_some() + self.base.map_or(false, |base| base.get(var).is_some()) } } diff --git a/src/eval/template.rs b/src/eval/template.rs new file mode 100644 index 000000000..040685f89 --- /dev/null +++ b/src/eval/template.rs @@ -0,0 +1,56 @@ +use super::*; +use crate::syntax::visit::*; + +impl Eval for Spanned<&ExprTemplate> { + type Output = Value; + + fn eval(self, ctx: &mut EvalContext) -> Self::Output { + let mut template = self.v.clone(); + let mut visitor = CapturesVisitor::new(ctx); + visitor.visit_template(&mut template); + Value::Template(template) + } +} + +/// A visitor that replaces all captured variables with their values. +struct CapturesVisitor<'a> { + external: &'a Scopes<'a>, + internal: Scopes<'a>, +} + +impl<'a> CapturesVisitor<'a> { + fn new(ctx: &'a EvalContext) -> Self { + Self { + external: &ctx.scopes, + internal: Scopes::default(), + } + } +} + +impl<'a> Visitor<'a> for CapturesVisitor<'a> { + fn visit_scope_pre(&mut self) { + self.internal.push(); + } + + fn visit_scope_post(&mut self) { + self.internal.pop(); + } + + fn visit_def(&mut self, id: &mut Ident) { + self.internal.define(id.as_str(), Value::None); + } + + fn visit_expr(&mut self, expr: &'a mut Expr) { + if let Expr::Ident(ident) = expr { + // Find out whether the identifier is not locally defined, but + // captured, and if so, replace it with it's value. + if self.internal.get(ident).is_none() { + if let Some(value) = self.external.get(ident) { + *expr = Expr::CapturedValue(value.clone()); + } + } + } else { + walk_expr(self, expr); + } + } +} diff --git a/src/eval/value.rs b/src/eval/value.rs index 6fa702060..6e838f6ca 100644 --- a/src/eval/value.rs +++ b/src/eval/value.rs @@ -87,7 +87,14 @@ impl Eval for &Value { ctx.push(ctx.make_text_node(match self { Value::None => return, Value::Str(s) => s.clone(), - Value::Template(tree) => return tree.eval(ctx), + Value::Template(tree) => { + // We do not want to allow the template access to the current + // scopes. + let prev = std::mem::take(&mut ctx.scopes); + tree.eval(ctx); + ctx.scopes = prev; + return; + } other => pretty(other), })); } @@ -195,7 +202,7 @@ impl Deref for ValueFunc { impl Pretty for ValueFunc { fn pretty(&self, p: &mut Printer) { - write!(p, "(function {})", self.name).unwrap(); + p.push_str(&self.name); } } @@ -515,7 +522,7 @@ mod tests { test_pretty(Color::Rgba(RgbaColor::new(1, 1, 1, 0xff)), "#010101"); test_pretty("hello", r#""hello""#); test_pretty(vec![Spanned::zero(Node::Strong)], "[*]"); - test_pretty(ValueFunc::new("nil", |_, _| Value::None), "(function nil)"); + test_pretty(ValueFunc::new("nil", |_, _| Value::None), "nil"); test_pretty(ValueAny::new(1), "1"); test_pretty(Value::Error, "(error)"); } diff --git a/src/syntax/expr.rs b/src/syntax/expr.rs index 15c53cc0e..ca55bdd0a 100644 --- a/src/syntax/expr.rs +++ b/src/syntax/expr.rs @@ -1,5 +1,6 @@ use super::*; use crate::color::RgbaColor; +use crate::eval::Value; use crate::geom::{AngularUnit, LengthUnit}; /// An expression. @@ -50,6 +51,11 @@ pub enum Expr { If(ExprIf), /// A for expression: `#for x #in y { z }`. For(ExprFor), + /// A captured value. + /// + /// This node is never created by parsing. It only results from an in-place + /// transformation of an identifier to a captured value. + CapturedValue(Value), } impl Pretty for Expr { @@ -86,6 +92,7 @@ impl Pretty for Expr { Self::Let(v) => v.pretty(p), Self::If(v) => v.pretty(p), Self::For(v) => v.pretty(p), + Self::CapturedValue(v) => v.pretty(p), } } } diff --git a/src/syntax/mod.rs b/src/syntax/mod.rs index 0b2ac06f9..16e691a9f 100644 --- a/src/syntax/mod.rs +++ b/src/syntax/mod.rs @@ -5,6 +5,7 @@ mod ident; mod node; mod span; mod token; +pub mod visit; pub use expr::*; pub use ident::*; diff --git a/src/syntax/visit.rs b/src/syntax/visit.rs new file mode 100644 index 000000000..e9e5dad7b --- /dev/null +++ b/src/syntax/visit.rs @@ -0,0 +1,185 @@ +//! Syntax tree traversal. + +use super::*; + +/// Visits syntax tree nodes in a depth-first manner. +pub trait Visitor<'a>: Sized { + /// Visit a variable definition. + fn visit_def(&mut self, _ident: &'a mut Ident) {} + + /// Visit the start of a scope. + fn visit_scope_pre(&mut self) {} + + /// Visit the end of a scope. + fn visit_scope_post(&mut self) {} + + fn visit_node(&mut self, node: &'a mut Node) { + walk_node(self, node) + } + fn visit_expr(&mut self, expr: &'a mut Expr) { + walk_expr(self, expr) + } + fn visit_array(&mut self, array: &'a mut ExprArray) { + walk_array(self, array) + } + fn visit_dict(&mut self, dict: &'a mut ExprDict) { + walk_dict(self, dict) + } + fn visit_template(&mut self, template: &'a mut ExprTemplate) { + walk_template(self, template) + } + fn visit_group(&mut self, group: &'a mut ExprGroup) { + walk_group(self, group) + } + fn visit_block(&mut self, block: &'a mut ExprBlock) { + walk_block(self, block) + } + fn visit_binary(&mut self, binary: &'a mut ExprBinary) { + walk_binary(self, binary) + } + fn visit_unary(&mut self, unary: &'a mut ExprUnary) { + walk_unary(self, unary) + } + fn visit_call(&mut self, call: &'a mut ExprCall) { + walk_call(self, call) + } + fn visit_arg(&mut self, arg: &'a mut Argument) { + walk_arg(self, arg) + } + fn visit_let(&mut self, expr_let: &'a mut ExprLet) { + walk_let(self, expr_let) + } + fn visit_if(&mut self, expr_if: &'a mut ExprIf) { + walk_if(self, expr_if) + } + fn visit_for(&mut self, expr_for: &'a mut ExprFor) { + walk_for(self, expr_for) + } +} + +pub fn walk_node<'a, V: Visitor<'a>>(v: &mut V, node: &'a mut Node) { + match node { + Node::Strong => {} + Node::Emph => {} + Node::Space => {} + Node::Linebreak => {} + Node::Parbreak => {} + Node::Text(_) => {} + Node::Heading(_) => {} + Node::Raw(_) => {} + Node::Expr(expr) => v.visit_expr(expr), + } +} + +pub fn walk_expr<'a, V: Visitor<'a>>(v: &mut V, expr: &'a mut Expr) { + match expr { + Expr::None => {} + Expr::Ident(_) => {} + Expr::Bool(_) => {} + Expr::Int(_) => {} + Expr::Float(_) => {} + Expr::Length(_, _) => {} + Expr::Angle(_, _) => {} + Expr::Percent(_) => {} + Expr::Color(_) => {} + Expr::Str(_) => {} + Expr::Array(e) => v.visit_array(e), + Expr::Dict(e) => v.visit_dict(e), + Expr::Template(e) => v.visit_template(e), + Expr::Group(e) => v.visit_group(e), + Expr::Block(e) => v.visit_block(e), + Expr::Unary(e) => v.visit_unary(e), + Expr::Binary(e) => v.visit_binary(e), + Expr::Call(e) => v.visit_call(e), + Expr::Let(e) => v.visit_let(e), + Expr::If(e) => v.visit_if(e), + Expr::For(e) => v.visit_for(e), + Expr::CapturedValue(_) => {} + } +} + +pub fn walk_array<'a, V: Visitor<'a>>(v: &mut V, array: &'a mut ExprArray) { + for expr in array { + v.visit_expr(&mut expr.v); + } +} + +pub fn walk_dict<'a, V: Visitor<'a>>(v: &mut V, dict: &'a mut ExprDict) { + for named in dict { + v.visit_expr(&mut named.expr.v); + } +} + +pub fn walk_template<'a, V: Visitor<'a>>(v: &mut V, template: &'a mut ExprTemplate) { + v.visit_scope_pre(); + for node in template { + v.visit_node(&mut node.v); + } + v.visit_scope_post(); +} + +pub fn walk_group<'a, V: Visitor<'a>>(v: &mut V, group: &'a mut ExprGroup) { + v.visit_expr(&mut group.v); +} + +pub fn walk_block<'a, V: Visitor<'a>>(v: &mut V, block: &'a mut ExprBlock) { + if block.scopes { + v.visit_scope_pre(); + } + for expr in &mut block.exprs { + v.visit_expr(&mut expr.v); + } + if block.scopes { + v.visit_scope_post(); + } +} + +pub fn walk_binary<'a, V: Visitor<'a>>(v: &mut V, binary: &'a mut ExprBinary) { + v.visit_expr(&mut binary.lhs.v); + v.visit_expr(&mut binary.rhs.v); +} + +pub fn walk_unary<'a, V: Visitor<'a>>(v: &mut V, unary: &'a mut ExprUnary) { + v.visit_expr(&mut unary.expr.v); +} + +pub fn walk_call<'a, V: Visitor<'a>>(v: &mut V, call: &'a mut ExprCall) { + v.visit_expr(&mut call.callee.v); + for arg in &mut call.args.v { + v.visit_arg(arg); + } +} + +pub fn walk_arg<'a, V: Visitor<'a>>(v: &mut V, arg: &'a mut Argument) { + match arg { + Argument::Pos(expr) => v.visit_expr(&mut expr.v), + Argument::Named(named) => v.visit_expr(&mut named.expr.v), + } +} + +pub fn walk_let<'a, V: Visitor<'a>>(v: &mut V, expr_let: &'a mut ExprLet) { + v.visit_def(&mut expr_let.pat.v); + if let Some(init) = &mut expr_let.init { + v.visit_expr(&mut init.v); + } +} + +pub fn walk_if<'a, V: Visitor<'a>>(v: &mut V, expr_if: &'a mut ExprIf) { + v.visit_expr(&mut expr_if.condition.v); + v.visit_expr(&mut expr_if.if_body.v); + if let Some(body) = &mut expr_if.else_body { + v.visit_expr(&mut body.v); + } +} + +pub fn walk_for<'a, V: Visitor<'a>>(v: &mut V, expr_for: &'a mut ExprFor) { + match &mut expr_for.pat.v { + ForPattern::Value(value) => v.visit_def(value), + ForPattern::KeyValue(key, value) => { + v.visit_def(key); + v.visit_def(value); + } + } + v.visit_expr(&mut expr_for.iter.v); + v.visit_expr(&mut expr_for.body.v); +} diff --git a/tests/lang/typ/expressions.typ b/tests/lang/typ/expressions.typ index 49b8910db..da6c6f4f8 100644 --- a/tests/lang/typ/expressions.typ +++ b/tests/lang/typ/expressions.typ @@ -86,7 +86,7 @@ // Error: 1:3-1:8 cannot assign to this expression { 1 + 2 = 3} -// Error: 1:3-1:6 cannot assign to constant +// Error: 1:3-1:6 cannot assign to a constant { box = "hi" } // Works if we define box before (since then it doesn't resolve to the standard