From cfb3b1a2709107f0f06f89ea25cabc939cec15e5 Mon Sep 17 00:00:00 2001 From: Ian Wrzesinski <133046678+wrzian@users.noreply.github.com> Date: Wed, 26 Feb 2025 15:10:36 -0500 Subject: [PATCH] Improve clarity of `ast.rs` for newcomers to the codebase (#5784) Co-authored-by: PgBiel <9021226+PgBiel@users.noreply.github.com> Co-authored-by: T0mstone <39707032+T0mstone@users.noreply.github.com> --- crates/typst-eval/src/call.rs | 8 +- crates/typst-eval/src/code.rs | 36 +- crates/typst-eval/src/markup.rs | 4 +- crates/typst-eval/src/rules.rs | 2 +- crates/typst-ide/src/complete.rs | 8 +- crates/typst-ide/src/matchers.rs | 4 +- crates/typst-ide/src/tooltip.rs | 2 +- crates/typst-syntax/src/ast.rs | 641 ++++++++++++++++++------------- crates/typst-syntax/src/node.rs | 21 - 9 files changed, 415 insertions(+), 311 deletions(-) diff --git a/crates/typst-eval/src/call.rs b/crates/typst-eval/src/call.rs index c68bef963..1ca7b4b8f 100644 --- a/crates/typst-eval/src/call.rs +++ b/crates/typst-eval/src/call.rs @@ -466,7 +466,7 @@ impl<'a> CapturesVisitor<'a> { } // Code and content blocks create a scope. - Some(ast::Expr::Code(_) | ast::Expr::Content(_)) => { + Some(ast::Expr::CodeBlock(_) | ast::Expr::ContentBlock(_)) => { self.internal.enter(); for child in node.children() { self.visit(child); @@ -516,7 +516,7 @@ impl<'a> CapturesVisitor<'a> { // A let expression contains a binding, but that binding is only // active after the body is evaluated. - Some(ast::Expr::Let(expr)) => { + Some(ast::Expr::LetBinding(expr)) => { if let Some(init) = expr.init() { self.visit(init.to_untyped()); } @@ -529,7 +529,7 @@ impl<'a> CapturesVisitor<'a> { // A for loop contains one or two bindings in its pattern. These are // active after the iterable is evaluated but before the body is // evaluated. - Some(ast::Expr::For(expr)) => { + Some(ast::Expr::ForLoop(expr)) => { self.visit(expr.iterable().to_untyped()); self.internal.enter(); @@ -544,7 +544,7 @@ impl<'a> CapturesVisitor<'a> { // An import contains items, but these are active only after the // path is evaluated. - Some(ast::Expr::Import(expr)) => { + Some(ast::Expr::ModuleImport(expr)) => { self.visit(expr.source().to_untyped()); if let Some(ast::Imports::Items(items)) = expr.imports() { for item in items.iter() { diff --git a/crates/typst-eval/src/code.rs b/crates/typst-eval/src/code.rs index a7b6b6f90..9078418e4 100644 --- a/crates/typst-eval/src/code.rs +++ b/crates/typst-eval/src/code.rs @@ -30,7 +30,7 @@ fn eval_code<'a>( while let Some(expr) = exprs.next() { let span = expr.span(); let value = match expr { - ast::Expr::Set(set) => { + ast::Expr::SetRule(set) => { let styles = set.eval(vm)?; if vm.flow.is_some() { break; @@ -39,7 +39,7 @@ fn eval_code<'a>( let tail = eval_code(vm, exprs)?.display(); Value::Content(tail.styled_with_map(styles)) } - ast::Expr::Show(show) => { + ast::Expr::ShowRule(show) => { let recipe = show.eval(vm)?; if vm.flow.is_some() { break; @@ -94,9 +94,9 @@ impl Eval for ast::Expr<'_> { Self::Label(v) => v.eval(vm), Self::Ref(v) => v.eval(vm).map(Value::Content), Self::Heading(v) => v.eval(vm).map(Value::Content), - Self::List(v) => v.eval(vm).map(Value::Content), - Self::Enum(v) => v.eval(vm).map(Value::Content), - Self::Term(v) => v.eval(vm).map(Value::Content), + Self::ListItem(v) => v.eval(vm).map(Value::Content), + Self::EnumItem(v) => v.eval(vm).map(Value::Content), + Self::TermItem(v) => v.eval(vm).map(Value::Content), Self::Equation(v) => v.eval(vm).map(Value::Content), Self::Math(v) => v.eval(vm).map(Value::Content), Self::MathText(v) => v.eval(vm).map(Value::Content), @@ -116,8 +116,8 @@ impl Eval for ast::Expr<'_> { Self::Float(v) => v.eval(vm), Self::Numeric(v) => v.eval(vm), Self::Str(v) => v.eval(vm), - Self::Code(v) => v.eval(vm), - Self::Content(v) => v.eval(vm).map(Value::Content), + Self::CodeBlock(v) => v.eval(vm), + Self::ContentBlock(v) => v.eval(vm).map(Value::Content), Self::Array(v) => v.eval(vm).map(Value::Array), Self::Dict(v) => v.eval(vm).map(Value::Dict), Self::Parenthesized(v) => v.eval(vm), @@ -126,19 +126,19 @@ impl Eval for ast::Expr<'_> { Self::Closure(v) => v.eval(vm), Self::Unary(v) => v.eval(vm), Self::Binary(v) => v.eval(vm), - Self::Let(v) => v.eval(vm), - Self::DestructAssign(v) => v.eval(vm), - Self::Set(_) => bail!(forbidden("set")), - Self::Show(_) => bail!(forbidden("show")), + Self::LetBinding(v) => v.eval(vm), + Self::DestructAssignment(v) => v.eval(vm), + Self::SetRule(_) => bail!(forbidden("set")), + Self::ShowRule(_) => bail!(forbidden("show")), Self::Contextual(v) => v.eval(vm).map(Value::Content), Self::Conditional(v) => v.eval(vm), - Self::While(v) => v.eval(vm), - Self::For(v) => v.eval(vm), - Self::Import(v) => v.eval(vm), - Self::Include(v) => v.eval(vm).map(Value::Content), - Self::Break(v) => v.eval(vm), - Self::Continue(v) => v.eval(vm), - Self::Return(v) => v.eval(vm), + Self::WhileLoop(v) => v.eval(vm), + Self::ForLoop(v) => v.eval(vm), + Self::ModuleImport(v) => v.eval(vm), + Self::ModuleInclude(v) => v.eval(vm).map(Value::Content), + Self::LoopBreak(v) => v.eval(vm), + Self::LoopContinue(v) => v.eval(vm), + Self::FuncReturn(v) => v.eval(vm), }? .spanned(span); diff --git a/crates/typst-eval/src/markup.rs b/crates/typst-eval/src/markup.rs index 3a5ebe1fc..5beefa912 100644 --- a/crates/typst-eval/src/markup.rs +++ b/crates/typst-eval/src/markup.rs @@ -33,7 +33,7 @@ fn eval_markup<'a>( while let Some(expr) = exprs.next() { match expr { - ast::Expr::Set(set) => { + ast::Expr::SetRule(set) => { let styles = set.eval(vm)?; if vm.flow.is_some() { break; @@ -41,7 +41,7 @@ fn eval_markup<'a>( seq.push(eval_markup(vm, exprs)?.styled_with_map(styles)) } - ast::Expr::Show(show) => { + ast::Expr::ShowRule(show) => { let recipe = show.eval(vm)?; if vm.flow.is_some() { break; diff --git a/crates/typst-eval/src/rules.rs b/crates/typst-eval/src/rules.rs index 646354d4b..f4c1563f3 100644 --- a/crates/typst-eval/src/rules.rs +++ b/crates/typst-eval/src/rules.rs @@ -45,7 +45,7 @@ impl Eval for ast::ShowRule<'_> { let transform = self.transform(); let transform = match transform { - ast::Expr::Set(set) => Transformation::Style(set.eval(vm)?), + ast::Expr::SetRule(set) => Transformation::Style(set.eval(vm)?), expr => expr.eval(vm)?.cast::().at(transform.span())?, }; diff --git a/crates/typst-ide/src/complete.rs b/crates/typst-ide/src/complete.rs index e3d777115..91fa53f9a 100644 --- a/crates/typst-ide/src/complete.rs +++ b/crates/typst-ide/src/complete.rs @@ -517,7 +517,7 @@ fn complete_imports(ctx: &mut CompletionContext) -> bool { // "#import "path.typ": a, b, |". if_chain! { if let Some(prev) = ctx.leaf.prev_sibling(); - if let Some(ast::Expr::Import(import)) = prev.get().cast(); + if let Some(ast::Expr::ModuleImport(import)) = prev.get().cast(); if let Some(ast::Imports::Items(items)) = import.imports(); if let Some(source) = prev.children().find(|child| child.is::()); then { @@ -536,7 +536,7 @@ fn complete_imports(ctx: &mut CompletionContext) -> bool { if let Some(grand) = parent.parent(); if grand.kind() == SyntaxKind::ImportItems; if let Some(great) = grand.parent(); - if let Some(ast::Expr::Import(import)) = great.get().cast(); + if let Some(ast::Expr::ModuleImport(import)) = great.get().cast(); if let Some(ast::Imports::Items(items)) = import.imports(); if let Some(source) = great.children().find(|child| child.is::()); then { @@ -677,10 +677,10 @@ fn complete_params(ctx: &mut CompletionContext) -> bool { if let Some(args) = parent.get().cast::(); if let Some(grand) = parent.parent(); if let Some(expr) = grand.get().cast::(); - let set = matches!(expr, ast::Expr::Set(_)); + let set = matches!(expr, ast::Expr::SetRule(_)); if let Some(callee) = match expr { ast::Expr::FuncCall(call) => Some(call.callee()), - ast::Expr::Set(set) => Some(set.target()), + ast::Expr::SetRule(set) => Some(set.target()), _ => None, }; then { diff --git a/crates/typst-ide/src/matchers.rs b/crates/typst-ide/src/matchers.rs index 270d2f43c..93fdc5dd5 100644 --- a/crates/typst-ide/src/matchers.rs +++ b/crates/typst-ide/src/matchers.rs @@ -232,7 +232,9 @@ pub fn deref_target(node: LinkedNode) -> Option> { ast::Expr::FuncCall(call) => { DerefTarget::Callee(expr_node.find(call.callee().span())?) } - ast::Expr::Set(set) => DerefTarget::Callee(expr_node.find(set.target().span())?), + ast::Expr::SetRule(set) => { + DerefTarget::Callee(expr_node.find(set.target().span())?) + } ast::Expr::Ident(_) | ast::Expr::MathIdent(_) | ast::Expr::FieldAccess(_) => { DerefTarget::VarAccess(expr_node) } diff --git a/crates/typst-ide/src/tooltip.rs b/crates/typst-ide/src/tooltip.rs index cfb977733..cbfffe530 100644 --- a/crates/typst-ide/src/tooltip.rs +++ b/crates/typst-ide/src/tooltip.rs @@ -201,7 +201,7 @@ fn named_param_tooltip(world: &dyn IdeWorld, leaf: &LinkedNode) -> Option(); if let Some(ast::Expr::Ident(callee)) = match expr { ast::Expr::FuncCall(call) => Some(call.callee()), - ast::Expr::Set(set) => Some(set.target()), + ast::Expr::SetRule(set) => Some(set.target()), _ => None, }; diff --git a/crates/typst-syntax/src/ast.rs b/crates/typst-syntax/src/ast.rs index 640138e77..f79e65982 100644 --- a/crates/typst-syntax/src/ast.rs +++ b/crates/typst-syntax/src/ast.rs @@ -1,6 +1,81 @@ -//! A typed layer over the untyped syntax tree. -//! -//! The AST is rooted in the [`Markup`] node. +/*! +# Abstract Syntax Tree Interface + +Typst's Abstract Syntax Tree (AST) is a lazy, typed view over the untyped +Concrete Syntax Tree (CST) and is rooted in the [`Markup`] node. + +## The AST is a View + +Most AST nodes are wrapper structs around [`SyntaxNode`] pointers. This summary +will use a running example of the [`Raw`] node type, which is declared (after +macro expansion) as: `struct Raw<'a>(&'a SyntaxNode);`. + +[`SyntaxNode`]s are generated by the parser and constitute the Concrete Syntax +Tree (CST). The CST is _concrete_ because it has the property that an in-order +tree traversal will recreate the text of the source file exactly. + +[`SyntaxNode`]s in the CST contain their [`SyntaxKind`], but don't themselves +provide access to the semantic meaning of their contents. That semantic meaning +is available through the Abstract Syntax Tree by iterating over CST nodes and +inspecting their contents. The format is prepared ahead-of-time by the parser so +that this module can unpack the abstract meaning from the CST's structure. + +Raw nodes are parsed by recognizing paired backtick delimiters, which you will +find as CST nodes with the [`RawDelim`] kind. However, the AST doesn't include +these delimiters because it _abstracts_ over the backticks. Instead, the parent +raw node will only use its child [`RawDelim`] CST nodes to determine whether the +element is a block or inline. + +## The AST is Typed + +AST nodes all implement the [`AstNode`] trait, but nodes can also implement +their own unique methods. These unique methods are the "real" interface of the +AST, and provide access to the abstract, semantic, representation of each kind +of node. For example, the [`Raw`] node provides 3 methods that specify its +abstract representation: [`Raw::lines()`] returns the raw text as an iterator of +lines, [`Raw::lang()`] provides the optionally present [`RawLang`] language tag, +and [`Raw::block()`] gives a bool for whether the raw element is a block or +inline. + +This semantic information is unavailable in the CST. Only by converting a CST +node to an AST struct will Rust let you call a method of that struct. This is a +safe interface because the only way to create an AST node outside this file is +to call [`AstNode::from_untyped`]. The `node!` macro implements `from_untyped` +by checking the node's kind before constructing it, returning `Some()` only if +the kind matches. So we know that it will have the expected children underneath, +otherwise the parser wouldn't have produced this node. + +## The AST is rooted in the [`Markup`] node + +The AST is rooted in the [`Markup`] node, which provides only one method: +[`Markup::exprs`]. This returns an iterator of the main [`Expr`] enum. [`Expr`] +is important because it contains the majority of expressions that Typst will +evaluate. Not just markup, but also math and code expressions. Not all +expression types are available from the parser at every step, but this does +decrease the amount of wrapper enums needed in the AST (and this file is long +enough already). + +Expressions also branch off into the remaining tree. You can view enums in this +file as edges on a graph: areas where the tree has paths from one type to +another (accessed through methods), then structs are the nodes of the graph, +providing methods that return enums, etc. etc. + +## The AST is Lazy + +Being lazy means that the untyped CST nodes are converted to typed AST nodes +only as the tree is traversed. If we parse a file and a raw block is contained +in a branch of an if-statement that we don't take, then we won't pay the cost of +creating an iterator over the lines or checking whether it was a block or +inline (although it will still be parsed into nodes). + +This is also a factor of the current "tree-interpreter" evaluation model. A +bytecode interpreter might instead eagerly convert the AST into bytecode, but it +would still traverse using this lazy interface. While the tree-interpreter +evaluation is straightforward and easy to add new features onto, it has to +re-traverse the AST every time a function is evaluated. A bytecode interpreter +using the lazy interface would only need to traverse each node once, improving +throughput at the cost of initial latency and development flexibility. +*/ use std::num::NonZeroUsize; use std::ops::Deref; @@ -27,8 +102,55 @@ pub trait AstNode<'a>: Sized { } } +// A generic interface for converting untyped nodes into typed AST nodes. +impl SyntaxNode { + /// Whether the node can be cast to the given AST node. + pub fn is<'a, T: AstNode<'a>>(&'a self) -> bool { + self.cast::().is_some() + } + + /// Try to convert the node to a typed AST node. + pub fn cast<'a, T: AstNode<'a>>(&'a self) -> Option { + T::from_untyped(self) + } + + /// Find the first child that can cast to the AST type `T`. + fn try_cast_first<'a, T: AstNode<'a>>(&'a self) -> Option { + self.children().find_map(Self::cast) + } + + /// Find the last child that can cast to the AST type `T`. + fn try_cast_last<'a, T: AstNode<'a>>(&'a self) -> Option { + self.children().rev().find_map(Self::cast) + } + + /// Get the first child of AST type `T` or a placeholder if none. + fn cast_first<'a, T: AstNode<'a> + Default>(&'a self) -> T { + self.try_cast_first().unwrap_or_default() + } + + /// Get the last child of AST type `T` or a placeholder if none. + fn cast_last<'a, T: AstNode<'a> + Default>(&'a self) -> T { + self.try_cast_last().unwrap_or_default() + } +} + +/// Implements [`AstNode`] for a struct whose name matches a [`SyntaxKind`] +/// variant. +/// +/// The struct becomes a wrapper around a [`SyntaxNode`] pointer, and the +/// implementation of [`AstNode::from_untyped`] checks that the pointer's kind +/// matches when converting, returning `Some` or `None` respectively. +/// +/// The generated struct is the basis for typed accessor methods for properties +/// of this AST node. For example, the [`Raw`] struct has methods for accessing +/// its content by lines, its optional language tag, and whether the raw element +/// is inline or a block. These methods are accessible only _after_ a +/// `SyntaxNode` is coerced to the `Raw` struct type (via `from_untyped`), +/// guaranteeing their implementations will work with the expected structure. macro_rules! node { - ($(#[$attr:meta])* $name:ident) => { + ($(#[$attr:meta])* struct $name:ident) => { + // Create the struct as a wrapper around a `SyntaxNode` reference. #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] #[repr(transparent)] $(#[$attr])* @@ -63,7 +185,7 @@ macro_rules! node { node! { /// The syntactical root capable of representing a full parsed document. - Markup + struct Markup } impl<'a> Markup<'a> { @@ -117,11 +239,11 @@ pub enum Expr<'a> { /// A section heading: `= Introduction`. Heading(Heading<'a>), /// An item in a bullet list: `- ...`. - List(ListItem<'a>), + ListItem(ListItem<'a>), /// An item in an enumeration (numbered list): `+ ...` or `1. ...`. - Enum(EnumItem<'a>), + EnumItem(EnumItem<'a>), /// An item in a term list: `/ Term: Details`. - Term(TermItem<'a>), + TermItem(TermItem<'a>), /// A mathematical equation: `$x$`, `$ x^2 $`. Equation(Equation<'a>), /// The contents of a mathematical equation: `x^2 + 1`. @@ -161,9 +283,9 @@ pub enum Expr<'a> { /// A quoted string: `"..."`. Str(Str<'a>), /// A code block: `{ let x = 1; x + 2 }`. - Code(CodeBlock<'a>), + CodeBlock(CodeBlock<'a>), /// A content block: `[*Hi* there!]`. - Content(ContentBlock<'a>), + ContentBlock(ContentBlock<'a>), /// A grouped expression: `(1 + 2)`. Parenthesized(Parenthesized<'a>), /// An array: `(1, "hi", 12cm)`. @@ -181,37 +303,37 @@ pub enum Expr<'a> { /// A closure: `(x, y) => z`. Closure(Closure<'a>), /// A let binding: `let x = 1`. - Let(LetBinding<'a>), + LetBinding(LetBinding<'a>), /// A destructuring assignment: `(x, y) = (1, 2)`. - DestructAssign(DestructAssignment<'a>), + DestructAssignment(DestructAssignment<'a>), /// A set rule: `set text(...)`. - Set(SetRule<'a>), + SetRule(SetRule<'a>), /// A show rule: `show heading: it => emph(it.body)`. - Show(ShowRule<'a>), + ShowRule(ShowRule<'a>), /// A contextual expression: `context text.lang`. Contextual(Contextual<'a>), /// An if-else conditional: `if x { y } else { z }`. Conditional(Conditional<'a>), /// A while loop: `while x { y }`. - While(WhileLoop<'a>), + WhileLoop(WhileLoop<'a>), /// A for loop: `for x in y { z }`. - For(ForLoop<'a>), + ForLoop(ForLoop<'a>), /// A module import: `import "utils.typ": a, b, c`. - Import(ModuleImport<'a>), + ModuleImport(ModuleImport<'a>), /// A module include: `include "chapter1.typ"`. - Include(ModuleInclude<'a>), + ModuleInclude(ModuleInclude<'a>), /// A break from a loop: `break`. - Break(LoopBreak<'a>), + LoopBreak(LoopBreak<'a>), /// A continue in a loop: `continue`. - Continue(LoopContinue<'a>), + LoopContinue(LoopContinue<'a>), /// A return from a function: `return`, `return x + 1`. - Return(FuncReturn<'a>), + FuncReturn(FuncReturn<'a>), } impl<'a> Expr<'a> { fn cast_with_space(node: &'a SyntaxNode) -> Option { match node.kind() { - SyntaxKind::Space => node.cast().map(Self::Space), + SyntaxKind::Space => Some(Self::Space(Space(node))), _ => Self::from_untyped(node), } } @@ -220,64 +342,69 @@ impl<'a> Expr<'a> { impl<'a> AstNode<'a> for Expr<'a> { fn from_untyped(node: &'a SyntaxNode) -> Option { match node.kind() { - SyntaxKind::Linebreak => node.cast().map(Self::Linebreak), - SyntaxKind::Parbreak => node.cast().map(Self::Parbreak), - SyntaxKind::Text => node.cast().map(Self::Text), - SyntaxKind::Escape => node.cast().map(Self::Escape), - SyntaxKind::Shorthand => node.cast().map(Self::Shorthand), - SyntaxKind::SmartQuote => node.cast().map(Self::SmartQuote), - SyntaxKind::Strong => node.cast().map(Self::Strong), - SyntaxKind::Emph => node.cast().map(Self::Emph), - SyntaxKind::Raw => node.cast().map(Self::Raw), - SyntaxKind::Link => node.cast().map(Self::Link), - SyntaxKind::Label => node.cast().map(Self::Label), - SyntaxKind::Ref => node.cast().map(Self::Ref), - SyntaxKind::Heading => node.cast().map(Self::Heading), - SyntaxKind::ListItem => node.cast().map(Self::List), - SyntaxKind::EnumItem => node.cast().map(Self::Enum), - SyntaxKind::TermItem => node.cast().map(Self::Term), - SyntaxKind::Equation => node.cast().map(Self::Equation), - SyntaxKind::Math => node.cast().map(Self::Math), - SyntaxKind::MathText => node.cast().map(Self::MathText), - SyntaxKind::MathIdent => node.cast().map(Self::MathIdent), - SyntaxKind::MathShorthand => node.cast().map(Self::MathShorthand), - SyntaxKind::MathAlignPoint => node.cast().map(Self::MathAlignPoint), - SyntaxKind::MathDelimited => node.cast().map(Self::MathDelimited), - SyntaxKind::MathAttach => node.cast().map(Self::MathAttach), - SyntaxKind::MathPrimes => node.cast().map(Self::MathPrimes), - SyntaxKind::MathFrac => node.cast().map(Self::MathFrac), - SyntaxKind::MathRoot => node.cast().map(Self::MathRoot), - SyntaxKind::Ident => node.cast().map(Self::Ident), - SyntaxKind::None => node.cast().map(Self::None), - SyntaxKind::Auto => node.cast().map(Self::Auto), - SyntaxKind::Bool => node.cast().map(Self::Bool), - SyntaxKind::Int => node.cast().map(Self::Int), - SyntaxKind::Float => node.cast().map(Self::Float), - SyntaxKind::Numeric => node.cast().map(Self::Numeric), - SyntaxKind::Str => node.cast().map(Self::Str), - SyntaxKind::CodeBlock => node.cast().map(Self::Code), - SyntaxKind::ContentBlock => node.cast().map(Self::Content), - SyntaxKind::Parenthesized => node.cast().map(Self::Parenthesized), - SyntaxKind::Array => node.cast().map(Self::Array), - SyntaxKind::Dict => node.cast().map(Self::Dict), - SyntaxKind::Unary => node.cast().map(Self::Unary), - SyntaxKind::Binary => node.cast().map(Self::Binary), - SyntaxKind::FieldAccess => node.cast().map(Self::FieldAccess), - SyntaxKind::FuncCall => node.cast().map(Self::FuncCall), - SyntaxKind::Closure => node.cast().map(Self::Closure), - SyntaxKind::LetBinding => node.cast().map(Self::Let), - SyntaxKind::DestructAssignment => node.cast().map(Self::DestructAssign), - SyntaxKind::SetRule => node.cast().map(Self::Set), - SyntaxKind::ShowRule => node.cast().map(Self::Show), - SyntaxKind::Contextual => node.cast().map(Self::Contextual), - SyntaxKind::Conditional => node.cast().map(Self::Conditional), - SyntaxKind::WhileLoop => node.cast().map(Self::While), - SyntaxKind::ForLoop => node.cast().map(Self::For), - SyntaxKind::ModuleImport => node.cast().map(Self::Import), - SyntaxKind::ModuleInclude => node.cast().map(Self::Include), - SyntaxKind::LoopBreak => node.cast().map(Self::Break), - SyntaxKind::LoopContinue => node.cast().map(Self::Continue), - SyntaxKind::FuncReturn => node.cast().map(Self::Return), + SyntaxKind::Space => Option::None, // Skipped unless using `cast_with_space`. + SyntaxKind::Linebreak => Some(Self::Linebreak(Linebreak(node))), + SyntaxKind::Parbreak => Some(Self::Parbreak(Parbreak(node))), + SyntaxKind::Text => Some(Self::Text(Text(node))), + SyntaxKind::Escape => Some(Self::Escape(Escape(node))), + SyntaxKind::Shorthand => Some(Self::Shorthand(Shorthand(node))), + SyntaxKind::SmartQuote => Some(Self::SmartQuote(SmartQuote(node))), + SyntaxKind::Strong => Some(Self::Strong(Strong(node))), + SyntaxKind::Emph => Some(Self::Emph(Emph(node))), + SyntaxKind::Raw => Some(Self::Raw(Raw(node))), + SyntaxKind::Link => Some(Self::Link(Link(node))), + SyntaxKind::Label => Some(Self::Label(Label(node))), + SyntaxKind::Ref => Some(Self::Ref(Ref(node))), + SyntaxKind::Heading => Some(Self::Heading(Heading(node))), + SyntaxKind::ListItem => Some(Self::ListItem(ListItem(node))), + SyntaxKind::EnumItem => Some(Self::EnumItem(EnumItem(node))), + SyntaxKind::TermItem => Some(Self::TermItem(TermItem(node))), + SyntaxKind::Equation => Some(Self::Equation(Equation(node))), + SyntaxKind::Math => Some(Self::Math(Math(node))), + SyntaxKind::MathText => Some(Self::MathText(MathText(node))), + SyntaxKind::MathIdent => Some(Self::MathIdent(MathIdent(node))), + SyntaxKind::MathShorthand => Some(Self::MathShorthand(MathShorthand(node))), + SyntaxKind::MathAlignPoint => { + Some(Self::MathAlignPoint(MathAlignPoint(node))) + } + SyntaxKind::MathDelimited => Some(Self::MathDelimited(MathDelimited(node))), + SyntaxKind::MathAttach => Some(Self::MathAttach(MathAttach(node))), + SyntaxKind::MathPrimes => Some(Self::MathPrimes(MathPrimes(node))), + SyntaxKind::MathFrac => Some(Self::MathFrac(MathFrac(node))), + SyntaxKind::MathRoot => Some(Self::MathRoot(MathRoot(node))), + SyntaxKind::Ident => Some(Self::Ident(Ident(node))), + SyntaxKind::None => Some(Self::None(None(node))), + SyntaxKind::Auto => Some(Self::Auto(Auto(node))), + SyntaxKind::Bool => Some(Self::Bool(Bool(node))), + SyntaxKind::Int => Some(Self::Int(Int(node))), + SyntaxKind::Float => Some(Self::Float(Float(node))), + SyntaxKind::Numeric => Some(Self::Numeric(Numeric(node))), + SyntaxKind::Str => Some(Self::Str(Str(node))), + SyntaxKind::CodeBlock => Some(Self::CodeBlock(CodeBlock(node))), + SyntaxKind::ContentBlock => Some(Self::ContentBlock(ContentBlock(node))), + SyntaxKind::Parenthesized => Some(Self::Parenthesized(Parenthesized(node))), + SyntaxKind::Array => Some(Self::Array(Array(node))), + SyntaxKind::Dict => Some(Self::Dict(Dict(node))), + SyntaxKind::Unary => Some(Self::Unary(Unary(node))), + SyntaxKind::Binary => Some(Self::Binary(Binary(node))), + SyntaxKind::FieldAccess => Some(Self::FieldAccess(FieldAccess(node))), + SyntaxKind::FuncCall => Some(Self::FuncCall(FuncCall(node))), + SyntaxKind::Closure => Some(Self::Closure(Closure(node))), + SyntaxKind::LetBinding => Some(Self::LetBinding(LetBinding(node))), + SyntaxKind::DestructAssignment => { + Some(Self::DestructAssignment(DestructAssignment(node))) + } + SyntaxKind::SetRule => Some(Self::SetRule(SetRule(node))), + SyntaxKind::ShowRule => Some(Self::ShowRule(ShowRule(node))), + SyntaxKind::Contextual => Some(Self::Contextual(Contextual(node))), + SyntaxKind::Conditional => Some(Self::Conditional(Conditional(node))), + SyntaxKind::WhileLoop => Some(Self::WhileLoop(WhileLoop(node))), + SyntaxKind::ForLoop => Some(Self::ForLoop(ForLoop(node))), + SyntaxKind::ModuleImport => Some(Self::ModuleImport(ModuleImport(node))), + SyntaxKind::ModuleInclude => Some(Self::ModuleInclude(ModuleInclude(node))), + SyntaxKind::LoopBreak => Some(Self::LoopBreak(LoopBreak(node))), + SyntaxKind::LoopContinue => Some(Self::LoopContinue(LoopContinue(node))), + SyntaxKind::FuncReturn => Some(Self::FuncReturn(FuncReturn(node))), _ => Option::None, } } @@ -298,9 +425,9 @@ impl<'a> AstNode<'a> for Expr<'a> { Self::Label(v) => v.to_untyped(), Self::Ref(v) => v.to_untyped(), Self::Heading(v) => v.to_untyped(), - Self::List(v) => v.to_untyped(), - Self::Enum(v) => v.to_untyped(), - Self::Term(v) => v.to_untyped(), + Self::ListItem(v) => v.to_untyped(), + Self::EnumItem(v) => v.to_untyped(), + Self::TermItem(v) => v.to_untyped(), Self::Equation(v) => v.to_untyped(), Self::Math(v) => v.to_untyped(), Self::MathText(v) => v.to_untyped(), @@ -320,8 +447,8 @@ impl<'a> AstNode<'a> for Expr<'a> { Self::Float(v) => v.to_untyped(), Self::Numeric(v) => v.to_untyped(), Self::Str(v) => v.to_untyped(), - Self::Code(v) => v.to_untyped(), - Self::Content(v) => v.to_untyped(), + Self::CodeBlock(v) => v.to_untyped(), + Self::ContentBlock(v) => v.to_untyped(), Self::Array(v) => v.to_untyped(), Self::Dict(v) => v.to_untyped(), Self::Parenthesized(v) => v.to_untyped(), @@ -330,19 +457,19 @@ impl<'a> AstNode<'a> for Expr<'a> { Self::FieldAccess(v) => v.to_untyped(), Self::FuncCall(v) => v.to_untyped(), Self::Closure(v) => v.to_untyped(), - Self::Let(v) => v.to_untyped(), - Self::DestructAssign(v) => v.to_untyped(), - Self::Set(v) => v.to_untyped(), - Self::Show(v) => v.to_untyped(), + Self::LetBinding(v) => v.to_untyped(), + Self::DestructAssignment(v) => v.to_untyped(), + Self::SetRule(v) => v.to_untyped(), + Self::ShowRule(v) => v.to_untyped(), Self::Contextual(v) => v.to_untyped(), Self::Conditional(v) => v.to_untyped(), - Self::While(v) => v.to_untyped(), - Self::For(v) => v.to_untyped(), - Self::Import(v) => v.to_untyped(), - Self::Include(v) => v.to_untyped(), - Self::Break(v) => v.to_untyped(), - Self::Continue(v) => v.to_untyped(), - Self::Return(v) => v.to_untyped(), + Self::WhileLoop(v) => v.to_untyped(), + Self::ForLoop(v) => v.to_untyped(), + Self::ModuleImport(v) => v.to_untyped(), + Self::ModuleInclude(v) => v.to_untyped(), + Self::LoopBreak(v) => v.to_untyped(), + Self::LoopContinue(v) => v.to_untyped(), + Self::FuncReturn(v) => v.to_untyped(), } } } @@ -360,25 +487,25 @@ impl Expr<'_> { | Self::Float(_) | Self::Numeric(_) | Self::Str(_) - | Self::Code(_) - | Self::Content(_) + | Self::CodeBlock(_) + | Self::ContentBlock(_) | Self::Array(_) | Self::Dict(_) | Self::Parenthesized(_) | Self::FieldAccess(_) | Self::FuncCall(_) - | Self::Let(_) - | Self::Set(_) - | Self::Show(_) + | Self::LetBinding(_) + | Self::SetRule(_) + | Self::ShowRule(_) | Self::Contextual(_) | Self::Conditional(_) - | Self::While(_) - | Self::For(_) - | Self::Import(_) - | Self::Include(_) - | Self::Break(_) - | Self::Continue(_) - | Self::Return(_) + | Self::WhileLoop(_) + | Self::ForLoop(_) + | Self::ModuleImport(_) + | Self::ModuleInclude(_) + | Self::LoopBreak(_) + | Self::LoopContinue(_) + | Self::FuncReturn(_) ) } @@ -405,7 +532,7 @@ impl Default for Expr<'_> { node! { /// Plain text without markup. - Text + struct Text } impl<'a> Text<'a> { @@ -418,22 +545,22 @@ impl<'a> Text<'a> { node! { /// Whitespace in markup or math. Has at most one newline in markup, as more /// indicate a paragraph break. - Space + struct Space } node! { /// A forced line break: `\`. - Linebreak + struct Linebreak } node! { /// A paragraph break, indicated by one or multiple blank lines. - Parbreak + struct Parbreak } node! { /// An escape sequence: `\#`, `\u{1F5FA}`. - Escape + struct Escape } impl Escape<'_> { @@ -456,7 +583,7 @@ impl Escape<'_> { node! { /// A shorthand for a unicode codepoint. For example, `~` for a non-breaking /// space or `-?` for a soft hyphen. - Shorthand + struct Shorthand } impl Shorthand<'_> { @@ -482,7 +609,7 @@ impl Shorthand<'_> { node! { /// A smart quote: `'` or `"`. - SmartQuote + struct SmartQuote } impl SmartQuote<'_> { @@ -494,31 +621,31 @@ impl SmartQuote<'_> { node! { /// Strong content: `*Strong*`. - Strong + struct Strong } impl<'a> Strong<'a> { /// The contents of the strong node. pub fn body(self) -> Markup<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } } node! { /// Emphasized content: `_Emphasized_`. - Emph + struct Emph } impl<'a> Emph<'a> { /// The contents of the emphasis node. pub fn body(self) -> Markup<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } } node! { /// Raw text with optional syntax highlighting: `` `...` ``. - Raw + struct Raw } impl<'a> Raw<'a> { @@ -530,18 +657,18 @@ impl<'a> Raw<'a> { /// An optional identifier specifying the language to syntax-highlight in. pub fn lang(self) -> Option> { // Only blocky literals are supposed to contain a language. - let delim: RawDelim = self.0.cast_first_match()?; + let delim: RawDelim = self.0.try_cast_first()?; if delim.0.len() < 3 { return Option::None; } - self.0.cast_first_match() + self.0.try_cast_first() } /// Whether the raw text should be displayed in a separate block. pub fn block(self) -> bool { self.0 - .cast_first_match() + .try_cast_first() .is_some_and(|delim: RawDelim| delim.0.len() >= 3) && self.0.children().any(|e| { e.kind() == SyntaxKind::RawTrimmed && e.text().chars().any(is_newline) @@ -551,7 +678,7 @@ impl<'a> Raw<'a> { node! { /// A language tag at the start of raw element: ``typ ``. - RawLang + struct RawLang } impl<'a> RawLang<'a> { @@ -563,12 +690,12 @@ impl<'a> RawLang<'a> { node! { /// A raw delimiter in single or 3+ backticks: `` ` ``. - RawDelim + struct RawDelim } node! { /// A hyperlink: `https://typst.org`. - Link + struct Link } impl<'a> Link<'a> { @@ -580,7 +707,7 @@ impl<'a> Link<'a> { node! { /// A label: ``. - Label + struct Label } impl<'a> Label<'a> { @@ -592,7 +719,7 @@ impl<'a> Label<'a> { node! { /// A reference: `@target`, `@target[..]`. - Ref + struct Ref } impl<'a> Ref<'a> { @@ -607,19 +734,19 @@ impl<'a> Ref<'a> { /// Get the supplement. pub fn supplement(self) -> Option> { - self.0.cast_last_match() + self.0.try_cast_last() } } node! { /// A section heading: `= Introduction`. - Heading + struct Heading } impl<'a> Heading<'a> { /// The contents of the heading. pub fn body(self) -> Markup<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } /// The section depth (number of equals signs). @@ -634,19 +761,19 @@ impl<'a> Heading<'a> { node! { /// An item in a bullet list: `- ...`. - ListItem + struct ListItem } impl<'a> ListItem<'a> { /// The contents of the list item. pub fn body(self) -> Markup<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } } node! { /// An item in an enumeration (numbered list): `+ ...` or `1. ...`. - EnumItem + struct EnumItem } impl<'a> EnumItem<'a> { @@ -660,36 +787,36 @@ impl<'a> EnumItem<'a> { /// The contents of the list item. pub fn body(self) -> Markup<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } } node! { /// An item in a term list: `/ Term: Details`. - TermItem + struct TermItem } impl<'a> TermItem<'a> { /// The term described by the item. pub fn term(self) -> Markup<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } /// The description of the term. pub fn description(self) -> Markup<'a> { - self.0.cast_last_match().unwrap_or_default() + self.0.cast_last() } } node! { /// A mathematical equation: `$x$`, `$ x^2 $`. - Equation + struct Equation } impl<'a> Equation<'a> { /// The contained math. pub fn body(self) -> Math<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } /// Whether the equation should be displayed as a separate block. @@ -703,7 +830,7 @@ impl<'a> Equation<'a> { node! { /// The contents of a mathematical equation: `x^2 + 1`. - Math + struct Math } impl<'a> Math<'a> { @@ -715,7 +842,7 @@ impl<'a> Math<'a> { node! { /// A lone text fragment in math: `x`, `25`, `3.1415`, `=`, `[`. - MathText + struct MathText } /// The underlying text kind. @@ -743,7 +870,7 @@ impl<'a> MathText<'a> { node! { /// An identifier in math: `pi`. - MathIdent + struct MathIdent } impl<'a> MathIdent<'a> { @@ -770,7 +897,7 @@ impl Deref for MathIdent<'_> { node! { /// A shorthand for a unicode codepoint in math: `a <= b`. - MathShorthand + struct MathShorthand } impl MathShorthand<'_> { @@ -828,40 +955,40 @@ impl MathShorthand<'_> { node! { /// An alignment point in math: `&`. - MathAlignPoint + struct MathAlignPoint } node! { /// Matched delimiters in math: `[x + y]`. - MathDelimited + struct MathDelimited } impl<'a> MathDelimited<'a> { /// The opening delimiter. pub fn open(self) -> Expr<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } /// The contents, including the delimiters. pub fn body(self) -> Math<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } /// The closing delimiter. pub fn close(self) -> Expr<'a> { - self.0.cast_last_match().unwrap_or_default() + self.0.cast_last() } } node! { /// A base with optional attachments in math: `a_1^2`. - MathAttach + struct MathAttach } impl<'a> MathAttach<'a> { /// The base, to which things are attached. pub fn base(self) -> Expr<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } /// The bottom attachment. @@ -892,7 +1019,7 @@ impl<'a> MathAttach<'a> { node! { /// Grouped primes in math: `a'''`. - MathPrimes + struct MathPrimes } impl MathPrimes<'_> { @@ -907,24 +1034,24 @@ impl MathPrimes<'_> { node! { /// A fraction in math: `x/2` - MathFrac + struct MathFrac } impl<'a> MathFrac<'a> { /// The numerator. pub fn num(self) -> Expr<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } /// The denominator. pub fn denom(self) -> Expr<'a> { - self.0.cast_last_match().unwrap_or_default() + self.0.cast_last() } } node! { /// A root in math: `√x`, `∛x` or `∜x`. - MathRoot + struct MathRoot } impl<'a> MathRoot<'a> { @@ -940,13 +1067,13 @@ impl<'a> MathRoot<'a> { /// The radicand. pub fn radicand(self) -> Expr<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } } node! { /// An identifier: `it`. - Ident + struct Ident } impl<'a> Ident<'a> { @@ -973,17 +1100,17 @@ impl Deref for Ident<'_> { node! { /// The `none` literal. - None + struct None } node! { /// The `auto` literal. - Auto + struct Auto } node! { /// A boolean: `true`, `false`. - Bool + struct Bool } impl Bool<'_> { @@ -995,7 +1122,7 @@ impl Bool<'_> { node! { /// An integer: `120`. - Int + struct Int } impl Int<'_> { @@ -1017,7 +1144,7 @@ impl Int<'_> { node! { /// A floating-point number: `1.2`, `10e-4`. - Float + struct Float } impl Float<'_> { @@ -1029,7 +1156,7 @@ impl Float<'_> { node! { /// A numeric value with a unit: `12pt`, `3cm`, `2em`, `90deg`, `50%`. - Numeric + struct Numeric } impl Numeric<'_> { @@ -1086,7 +1213,7 @@ pub enum Unit { node! { /// A quoted string: `"..."`. - Str + struct Str } impl Str<'_> { @@ -1136,19 +1263,19 @@ impl Str<'_> { node! { /// A code block: `{ let x = 1; x + 2 }`. - CodeBlock + struct CodeBlock } impl<'a> CodeBlock<'a> { /// The contained code. pub fn body(self) -> Code<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } } node! { /// The body of a code block. - Code + struct Code } impl<'a> Code<'a> { @@ -1160,19 +1287,19 @@ impl<'a> Code<'a> { node! { /// A content block: `[*Hi* there!]`. - ContentBlock + struct ContentBlock } impl<'a> ContentBlock<'a> { /// The contained markup. pub fn body(self) -> Markup<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } } node! { /// A grouped expression: `(1 + 2)`. - Parenthesized + struct Parenthesized } impl<'a> Parenthesized<'a> { @@ -1180,20 +1307,20 @@ impl<'a> Parenthesized<'a> { /// /// Should only be accessed if this is contained in an `Expr`. pub fn expr(self) -> Expr<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } /// The wrapped pattern. /// /// Should only be accessed if this is contained in a `Pattern`. pub fn pattern(self) -> Pattern<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } } node! { /// An array: `(1, "hi", 12cm)`. - Array + struct Array } impl<'a> Array<'a> { @@ -1215,7 +1342,7 @@ pub enum ArrayItem<'a> { impl<'a> AstNode<'a> for ArrayItem<'a> { fn from_untyped(node: &'a SyntaxNode) -> Option { match node.kind() { - SyntaxKind::Spread => node.cast().map(Self::Spread), + SyntaxKind::Spread => Some(Self::Spread(Spread(node))), _ => node.cast().map(Self::Pos), } } @@ -1230,7 +1357,7 @@ impl<'a> AstNode<'a> for ArrayItem<'a> { node! { /// A dictionary: `(thickness: 3pt, dash: "solid")`. - Dict + struct Dict } impl<'a> Dict<'a> { @@ -1254,9 +1381,9 @@ pub enum DictItem<'a> { impl<'a> AstNode<'a> for DictItem<'a> { fn from_untyped(node: &'a SyntaxNode) -> Option { match node.kind() { - SyntaxKind::Named => node.cast().map(Self::Named), - SyntaxKind::Keyed => node.cast().map(Self::Keyed), - SyntaxKind::Spread => node.cast().map(Self::Spread), + SyntaxKind::Named => Some(Self::Named(Named(node))), + SyntaxKind::Keyed => Some(Self::Keyed(Keyed(node))), + SyntaxKind::Spread => Some(Self::Spread(Spread(node))), _ => Option::None, } } @@ -1272,13 +1399,13 @@ impl<'a> AstNode<'a> for DictItem<'a> { node! { /// A named pair: `thickness: 3pt`. - Named + struct Named } impl<'a> Named<'a> { /// The name: `thickness`. pub fn name(self) -> Ident<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } /// The right-hand side of the pair: `3pt`. @@ -1286,7 +1413,7 @@ impl<'a> Named<'a> { /// This should only be accessed if this `Named` is contained in a /// `DictItem`, `Arg`, or `Param`. pub fn expr(self) -> Expr<'a> { - self.0.cast_last_match().unwrap_or_default() + self.0.cast_last() } /// The right-hand side of the pair as a pattern. @@ -1294,19 +1421,19 @@ impl<'a> Named<'a> { /// This should only be accessed if this `Named` is contained in a /// `Destructuring`. pub fn pattern(self) -> Pattern<'a> { - self.0.cast_last_match().unwrap_or_default() + self.0.cast_last() } } node! { /// A keyed pair: `"spacy key": true`. - Keyed + struct Keyed } impl<'a> Keyed<'a> { /// The key: `"spacy key"`. pub fn key(self) -> Expr<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } /// The right-hand side of the pair: `true`. @@ -1314,13 +1441,13 @@ impl<'a> Keyed<'a> { /// This should only be accessed if this `Keyed` is contained in a /// `DictItem`. pub fn expr(self) -> Expr<'a> { - self.0.cast_last_match().unwrap_or_default() + self.0.cast_last() } } node! { /// A spread: `..x` or `..x.at(0)`. - Spread + struct Spread } impl<'a> Spread<'a> { @@ -1329,7 +1456,7 @@ impl<'a> Spread<'a> { /// This should only be accessed if this `Spread` is contained in an /// `ArrayItem`, `DictItem`, or `Arg`. pub fn expr(self) -> Expr<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } /// The sink identifier, if present. @@ -1337,7 +1464,7 @@ impl<'a> Spread<'a> { /// This should only be accessed if this `Spread` is contained in a /// `Param` or binding `DestructuringItem`. pub fn sink_ident(self) -> Option> { - self.0.cast_first_match() + self.0.try_cast_first() } /// The sink expressions, if present. @@ -1345,13 +1472,13 @@ impl<'a> Spread<'a> { /// This should only be accessed if this `Spread` is contained in a /// `DestructuringItem`. pub fn sink_expr(self) -> Option> { - self.0.cast_first_match() + self.0.try_cast_first() } } node! { /// A unary operation: `-x`. - Unary + struct Unary } impl<'a> Unary<'a> { @@ -1365,7 +1492,7 @@ impl<'a> Unary<'a> { /// The expression to operate on: `x`. pub fn expr(self) -> Expr<'a> { - self.0.cast_last_match().unwrap_or_default() + self.0.cast_last() } } @@ -1411,7 +1538,7 @@ impl UnOp { node! { /// A binary operation: `a + b`. - Binary + struct Binary } impl<'a> Binary<'a> { @@ -1433,12 +1560,12 @@ impl<'a> Binary<'a> { /// The left-hand side of the operation: `a`. pub fn lhs(self) -> Expr<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } /// The right-hand side of the operation: `b`. pub fn rhs(self) -> Expr<'a> { - self.0.cast_last_match().unwrap_or_default() + self.0.cast_last() } } @@ -1598,41 +1725,41 @@ pub enum Assoc { node! { /// A field access: `properties.age`. - FieldAccess + struct FieldAccess } impl<'a> FieldAccess<'a> { /// The expression to access the field on. pub fn target(self) -> Expr<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } /// The name of the field. pub fn field(self) -> Ident<'a> { - self.0.cast_last_match().unwrap_or_default() + self.0.cast_last() } } node! { /// An invocation of a function or method: `f(x, y)`. - FuncCall + struct FuncCall } impl<'a> FuncCall<'a> { /// The function to call. pub fn callee(self) -> Expr<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } /// The arguments to the function. pub fn args(self) -> Args<'a> { - self.0.cast_last_match().unwrap_or_default() + self.0.cast_last() } } node! { /// A function call's argument list: `(12pt, y)`. - Args + struct Args } impl<'a> Args<'a> { @@ -1666,8 +1793,8 @@ pub enum Arg<'a> { impl<'a> AstNode<'a> for Arg<'a> { fn from_untyped(node: &'a SyntaxNode) -> Option { match node.kind() { - SyntaxKind::Named => node.cast().map(Self::Named), - SyntaxKind::Spread => node.cast().map(Self::Spread), + SyntaxKind::Named => Some(Self::Named(Named(node))), + SyntaxKind::Spread => Some(Self::Spread(Spread(node))), _ => node.cast().map(Self::Pos), } } @@ -1683,7 +1810,7 @@ impl<'a> AstNode<'a> for Arg<'a> { node! { /// A closure: `(x, y) => z`. - Closure + struct Closure } impl<'a> Closure<'a> { @@ -1696,18 +1823,18 @@ impl<'a> Closure<'a> { /// The parameter bindings. pub fn params(self) -> Params<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } /// The body of the closure. pub fn body(self) -> Expr<'a> { - self.0.cast_last_match().unwrap_or_default() + self.0.cast_last() } } node! { /// A closure's parameters: `(x, y)`. - Params + struct Params } impl<'a> Params<'a> { @@ -1731,8 +1858,8 @@ pub enum Param<'a> { impl<'a> AstNode<'a> for Param<'a> { fn from_untyped(node: &'a SyntaxNode) -> Option { match node.kind() { - SyntaxKind::Named => node.cast().map(Self::Named), - SyntaxKind::Spread => node.cast().map(Self::Spread), + SyntaxKind::Named => Some(Self::Named(Named(node))), + SyntaxKind::Spread => Some(Self::Spread(Spread(node))), _ => node.cast().map(Self::Pos), } } @@ -1762,9 +1889,9 @@ pub enum Pattern<'a> { impl<'a> AstNode<'a> for Pattern<'a> { fn from_untyped(node: &'a SyntaxNode) -> Option { match node.kind() { - SyntaxKind::Underscore => node.cast().map(Self::Placeholder), - SyntaxKind::Parenthesized => node.cast().map(Self::Parenthesized), - SyntaxKind::Destructuring => node.cast().map(Self::Destructuring), + SyntaxKind::Underscore => Some(Self::Placeholder(Underscore(node))), + SyntaxKind::Parenthesized => Some(Self::Parenthesized(Parenthesized(node))), + SyntaxKind::Destructuring => Some(Self::Destructuring(Destructuring(node))), _ => node.cast().map(Self::Normal), } } @@ -1799,12 +1926,12 @@ impl Default for Pattern<'_> { node! { /// An underscore: `_` - Underscore + struct Underscore } node! { /// A destructuring pattern: `x` or `(x, _, ..y)`. - Destructuring + struct Destructuring } impl<'a> Destructuring<'a> { @@ -1841,8 +1968,8 @@ pub enum DestructuringItem<'a> { impl<'a> AstNode<'a> for DestructuringItem<'a> { fn from_untyped(node: &'a SyntaxNode) -> Option { match node.kind() { - SyntaxKind::Named => node.cast().map(Self::Named), - SyntaxKind::Spread => node.cast().map(Self::Spread), + SyntaxKind::Named => Some(Self::Named(Named(node))), + SyntaxKind::Spread => Some(Self::Spread(Spread(node))), _ => node.cast().map(Self::Pattern), } } @@ -1858,7 +1985,7 @@ impl<'a> AstNode<'a> for DestructuringItem<'a> { node! { /// A let binding: `let x = 1`. - LetBinding + struct LetBinding } /// The kind of a let binding, either a normal one or a closure. @@ -1883,11 +2010,11 @@ impl<'a> LetBindingKind<'a> { impl<'a> LetBinding<'a> { /// The kind of the let binding. pub fn kind(self) -> LetBindingKind<'a> { - match self.0.cast_first_match::() { - Some(Pattern::Normal(Expr::Closure(closure))) => { + match self.0.cast_first() { + Pattern::Normal(Expr::Closure(closure)) => { LetBindingKind::Closure(closure.name().unwrap_or_default()) } - pattern => LetBindingKind::Normal(pattern.unwrap_or_default()), + pattern => LetBindingKind::Normal(pattern), } } @@ -1897,43 +2024,43 @@ impl<'a> LetBinding<'a> { LetBindingKind::Normal(Pattern::Normal(_) | Pattern::Parenthesized(_)) => { self.0.children().filter_map(SyntaxNode::cast).nth(1) } - LetBindingKind::Normal(_) => self.0.cast_first_match(), - LetBindingKind::Closure(_) => self.0.cast_first_match(), + LetBindingKind::Normal(_) => self.0.try_cast_first(), + LetBindingKind::Closure(_) => self.0.try_cast_first(), } } } node! { /// An assignment expression `(x, y) = (1, 2)`. - DestructAssignment + struct DestructAssignment } impl<'a> DestructAssignment<'a> { /// The pattern of the assignment. pub fn pattern(self) -> Pattern<'a> { - self.0.cast_first_match::().unwrap_or_default() + self.0.cast_first() } /// The expression that is assigned. pub fn value(self) -> Expr<'a> { - self.0.cast_last_match().unwrap_or_default() + self.0.cast_last() } } node! { /// A set rule: `set text(...)`. - SetRule + struct SetRule } impl<'a> SetRule<'a> { /// The function to set style properties for. pub fn target(self) -> Expr<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } /// The style properties to set. pub fn args(self) -> Args<'a> { - self.0.cast_last_match().unwrap_or_default() + self.0.cast_last() } /// A condition under which the set rule applies. @@ -1947,7 +2074,7 @@ impl<'a> SetRule<'a> { node! { /// A show rule: `show heading: it => emph(it.body)`. - ShowRule + struct ShowRule } impl<'a> ShowRule<'a> { @@ -1962,31 +2089,31 @@ impl<'a> ShowRule<'a> { /// The transformation recipe. pub fn transform(self) -> Expr<'a> { - self.0.cast_last_match().unwrap_or_default() + self.0.cast_last() } } node! { /// A contextual expression: `context text.lang`. - Contextual + struct Contextual } impl<'a> Contextual<'a> { /// The expression which depends on the context. pub fn body(self) -> Expr<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } } node! { /// An if-else conditional: `if x { y } else { z }`. - Conditional + struct Conditional } impl<'a> Conditional<'a> { /// The condition which selects the body to evaluate. pub fn condition(self) -> Expr<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } /// The expression to evaluate if the condition is true. @@ -2006,30 +2133,30 @@ impl<'a> Conditional<'a> { node! { /// A while loop: `while x { y }`. - WhileLoop + struct WhileLoop } impl<'a> WhileLoop<'a> { /// The condition which selects whether to evaluate the body. pub fn condition(self) -> Expr<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } /// The expression to evaluate while the condition is true. pub fn body(self) -> Expr<'a> { - self.0.cast_last_match().unwrap_or_default() + self.0.cast_last() } } node! { /// A for loop: `for x in y { z }`. - ForLoop + struct ForLoop } impl<'a> ForLoop<'a> { /// The pattern to assign to. pub fn pattern(self) -> Pattern<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } /// The expression to iterate over. @@ -2043,19 +2170,19 @@ impl<'a> ForLoop<'a> { /// The expression to evaluate for each iteration. pub fn body(self) -> Expr<'a> { - self.0.cast_last_match().unwrap_or_default() + self.0.cast_last() } } node! { /// A module import: `import "utils.typ": a, b, c`. - ModuleImport + struct ModuleImport } impl<'a> ModuleImport<'a> { /// The module or path from which the items should be imported. pub fn source(self) -> Expr<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } /// The items to be imported. @@ -2135,7 +2262,7 @@ pub enum Imports<'a> { node! { /// Items to import from a module: `a, b, c`. - ImportItems + struct ImportItems } impl<'a> ImportItems<'a> { @@ -2151,7 +2278,7 @@ impl<'a> ImportItems<'a> { node! { /// A path to a submodule's imported name: `a.b.c`. - ImportItemPath + struct ImportItemPath } impl<'a> ImportItemPath<'a> { @@ -2162,7 +2289,7 @@ impl<'a> ImportItemPath<'a> { /// The name of the imported item. This is the last segment in the path. pub fn name(self) -> Ident<'a> { - self.iter().last().unwrap_or_default() + self.0.cast_last() } } @@ -2207,13 +2334,13 @@ impl<'a> ImportItem<'a> { node! { /// A renamed import item: `a as d` - RenamedImportItem + struct RenamedImportItem } impl<'a> RenamedImportItem<'a> { /// The path to the imported item. pub fn path(self) -> ImportItemPath<'a> { - self.0.cast_first_match().unwrap_or_default() + self.0.cast_first() } /// The original name of the imported item (`a` in `a as d` or `c.b.a as d`). @@ -2223,45 +2350,41 @@ impl<'a> RenamedImportItem<'a> { /// The new name of the imported item (`d` in `a as d`). pub fn new_name(self) -> Ident<'a> { - self.0 - .children() - .filter_map(SyntaxNode::cast) - .last() - .unwrap_or_default() + self.0.cast_last() } } node! { /// A module include: `include "chapter1.typ"`. - ModuleInclude + struct ModuleInclude } impl<'a> ModuleInclude<'a> { /// The module or path from which the content should be included. pub fn source(self) -> Expr<'a> { - self.0.cast_last_match().unwrap_or_default() + self.0.cast_last() } } node! { /// A break from a loop: `break`. - LoopBreak + struct LoopBreak } node! { /// A continue in a loop: `continue`. - LoopContinue + struct LoopContinue } node! { /// A return from a function: `return`, `return x + 1`. - FuncReturn + struct FuncReturn } impl<'a> FuncReturn<'a> { /// The expression to return. pub fn body(self) -> Option> { - self.0.cast_last_match() + self.0.try_cast_last() } } diff --git a/crates/typst-syntax/src/node.rs b/crates/typst-syntax/src/node.rs index fde2eaca0..948657ca4 100644 --- a/crates/typst-syntax/src/node.rs +++ b/crates/typst-syntax/src/node.rs @@ -5,7 +5,6 @@ use std::sync::Arc; use ecow::{eco_format, eco_vec, EcoString, EcoVec}; -use crate::ast::AstNode; use crate::{FileId, Span, SyntaxKind}; /// A node in the untyped syntax tree. @@ -119,26 +118,6 @@ impl SyntaxNode { } } - /// Whether the node can be cast to the given AST node. - pub fn is<'a, T: AstNode<'a>>(&'a self) -> bool { - self.cast::().is_some() - } - - /// Try to convert the node to a typed AST node. - pub fn cast<'a, T: AstNode<'a>>(&'a self) -> Option { - T::from_untyped(self) - } - - /// Cast the first child that can cast to the AST type `T`. - pub fn cast_first_match<'a, T: AstNode<'a>>(&'a self) -> Option { - self.children().find_map(Self::cast) - } - - /// Cast the last child that can cast to the AST type `T`. - pub fn cast_last_match<'a, T: AstNode<'a>>(&'a self) -> Option { - self.children().rev().find_map(Self::cast) - } - /// Whether the node or its children contain an error. pub fn erroneous(&self) -> bool { match &self.0 {