From 428c55b6eed3536bb228924c6fb0ad6cea6d6d4b Mon Sep 17 00:00:00 2001 From: Marmare314 <49279081+Marmare314@users.noreply.github.com> Date: Sun, 16 Apr 2023 14:34:57 +0200 Subject: [PATCH] refactor SyntaxKind::Pattern (#831) --- src/eval/mod.rs | 24 ++++---- src/ide/highlight.rs | 3 +- src/syntax/ast.rs | 137 ++++++++++++++++++++++++------------------- src/syntax/kind.rs | 9 +-- src/syntax/parser.rs | 14 ++--- 5 files changed, 98 insertions(+), 89 deletions(-) diff --git a/src/eval/mod.rs b/src/eval/mod.rs index 68d8148f2..b8220112f 100644 --- a/src/eval/mod.rs +++ b/src/eval/mod.rs @@ -1179,16 +1179,16 @@ impl Eval for ast::Closure { impl ast::Pattern { // Destruct the given value into the pattern. pub fn define(&self, vm: &mut Vm, value: Value) -> SourceResult { - match self.kind() { - ast::PatternKind::Ident(ident) => { - vm.define(ident, value); + match self { + ast::Pattern::Ident(ident) => { + vm.define(ident.clone(), value); Ok(Value::None) } - ast::PatternKind::Destructure(pattern) => { + ast::Pattern::Destructuring(destruct) => { match value { Value::Array(value) => { let mut i = 0; - for p in &pattern { + for p in destruct.bindings() { match p { ast::DestructuringKind::Ident(ident) => { let Ok(v) = value.at(i) else { @@ -1198,7 +1198,7 @@ impl ast::Pattern { i += 1; } ast::DestructuringKind::Sink(ident) => { - (1 + value.len() as usize).checked_sub(pattern.len()).and_then(|sink_size| { + (1 + value.len() as usize).checked_sub(destruct.bindings().count()).and_then(|sink_size| { let Ok(sink) = value.slice(i, Some(i + sink_size as i64)) else { return None; }; @@ -1224,10 +1224,10 @@ impl ast::Pattern { Value::Dict(value) => { let mut sink = None; let mut used = HashSet::new(); - for p in &pattern { + for p in destruct.bindings() { match p { ast::DestructuringKind::Ident(ident) => { - let Ok(v) = value.at(ident) else { + let Ok(v) = value.at(&ident) else { bail!(ident.span(), "destructuring key not found in dictionary"); }; vm.define(ident.clone(), v.clone()); @@ -1237,7 +1237,7 @@ impl ast::Pattern { sink = ident.clone() } ast::DestructuringKind::Named(key, ident) => { - let Ok(v) = value.at(key) else { + let Ok(v) = value.at(&key) else { bail!(ident.span(), "destructuring key not found in dictionary"); }; vm.define(ident.clone(), v.clone()); @@ -1451,8 +1451,8 @@ impl Eval for ast::ForLoop { let iter = self.iter().eval(vm)?; let pattern = self.pattern(); - match (pattern.kind(), iter.clone()) { - (ast::PatternKind::Ident(_), Value::Str(string)) => { + match (&pattern, iter.clone()) { + (ast::Pattern::Ident(_), Value::Str(string)) => { // Iterate over graphemes of string. iter!(for pattern in string.as_str().graphemes(true)); } @@ -1464,7 +1464,7 @@ impl Eval for ast::ForLoop { // Iterate over values of array. iter!(for pattern in array); } - (ast::PatternKind::Ident(_), _) => { + (ast::Pattern::Ident(_), _) => { bail!(self.iter().span(), "cannot loop over {}", iter.type_name()); } (_, _) => { diff --git a/src/ide/highlight.rs b/src/ide/highlight.rs index 9180aaee7..259d34c36 100644 --- a/src/ide/highlight.rs +++ b/src/ide/highlight.rs @@ -239,14 +239,13 @@ pub fn highlight(node: &LinkedNode) -> Option { SyntaxKind::Conditional => None, SyntaxKind::WhileLoop => None, SyntaxKind::ForLoop => None, - SyntaxKind::ForPattern => None, SyntaxKind::ModuleImport => None, SyntaxKind::ImportItems => None, SyntaxKind::ModuleInclude => None, SyntaxKind::LoopBreak => None, SyntaxKind::LoopContinue => None, SyntaxKind::FuncReturn => None, - SyntaxKind::Pattern => None, + SyntaxKind::Destructuring => None, SyntaxKind::LineComment => Some(Tag::Comment), SyntaxKind::BlockComment => Some(Tag::Comment), diff --git a/src/syntax/ast.rs b/src/syntax/ast.rs index e492297ac..3c01bdffa 100644 --- a/src/syntax/ast.rs +++ b/src/syntax/ast.rs @@ -1533,10 +1533,7 @@ impl Closure { /// /// This only exists if you use the function syntax sugar: `let f(x) = y`. pub fn name(&self) -> Option { - match self.0.cast_first_match::()?.kind() { - PatternKind::Ident(ident) => Some(ident), - _ => Option::None, - } + self.0.children().next()?.cast() } /// The parameter bindings. @@ -1594,16 +1591,7 @@ impl AstNode for Param { node! { /// A destructuring pattern: `x` or `(x, _, ..y)`. - Pattern -} - -/// The kind of a pattern. -#[derive(Debug, Clone, Hash)] -pub enum PatternKind { - /// A single identifier: `x`. - Ident(Ident), - /// A destructuring pattern: `(x, _, ..y)`. - Destructure(Vec), + Destructuring } /// The kind of an element in a destructuring pattern. @@ -1617,57 +1605,74 @@ pub enum DestructuringKind { Named(Ident, Ident), } -impl Pattern { - /// The kind of the pattern. - pub fn kind(&self) -> PatternKind { - if self - .0 - .children() - .map(SyntaxNode::kind) - .skip_while(|&kind| kind == SyntaxKind::LeftParen) - .take_while(|&kind| kind != SyntaxKind::RightParen) - .eq([SyntaxKind::Ident]) - { - return PatternKind::Ident(self.0.cast_first_match().unwrap_or_default()); - } - - let mut bindings = Vec::new(); - for child in self.0.children() { - match child.kind() { - SyntaxKind::Ident => { - bindings - .push(DestructuringKind::Ident(child.cast().unwrap_or_default())); - } - SyntaxKind::Spread => { - bindings.push(DestructuringKind::Sink(child.cast_first_match())); - } - SyntaxKind::Named => { - let mut filtered = child.children().filter_map(SyntaxNode::cast); - let key = filtered.next().unwrap_or_default(); - let ident = filtered.next().unwrap_or_default(); - bindings.push(DestructuringKind::Named(key, ident)); - } - _ => (), +impl Destructuring { + /// The bindings of the destructuring. + pub fn bindings(&self) -> impl Iterator + '_ { + self.0.children().filter_map(|child| match child.kind() { + SyntaxKind::Ident => { + Some(DestructuringKind::Ident(child.cast().unwrap_or_default())) } - } - - PatternKind::Destructure(bindings) + SyntaxKind::Spread => Some(DestructuringKind::Sink(child.cast_first_match())), + SyntaxKind::Named => { + let mut filtered = child.children().filter_map(SyntaxNode::cast); + let key = filtered.next().unwrap_or_default(); + let ident = filtered.next().unwrap_or_default(); + Some(DestructuringKind::Named(key, ident)) + } + _ => Option::None, + }) } // Returns a list of all identifiers in the pattern. - pub fn idents(&self) -> Vec { - match self.kind() { - PatternKind::Ident(ident) => vec![ident], - PatternKind::Destructure(bindings) => bindings - .into_iter() - .filter_map(|binding| match binding { - DestructuringKind::Ident(ident) => Some(ident), - DestructuringKind::Sink(ident) => ident, - DestructuringKind::Named(_, ident) => Some(ident), - }) - .collect(), + pub fn idents(&self) -> impl Iterator + '_ { + self.bindings().into_iter().filter_map(|binding| match binding { + DestructuringKind::Ident(ident) => Some(ident), + DestructuringKind::Sink(ident) => ident, + DestructuringKind::Named(_, ident) => Some(ident), + }) + } +} + +/// The kind of a pattern. +#[derive(Debug, Clone, Hash)] +pub enum Pattern { + /// A single identifier: `x`. + Ident(Ident), + /// A destructuring pattern: `(x, _, ..y)`. + Destructuring(Destructuring), +} + +impl AstNode for Pattern { + fn from_untyped(node: &SyntaxNode) -> Option { + match node.kind() { + SyntaxKind::Ident => node.cast().map(Self::Ident), + SyntaxKind::Destructuring => node.cast().map(Self::Destructuring), + _ => Option::None, } } + + fn as_untyped(&self) -> &SyntaxNode { + match self { + Self::Ident(v) => v.as_untyped(), + Self::Destructuring(v) => v.as_untyped(), + } + } +} + +impl Pattern { + // Returns a list of all identifiers in the pattern. + pub fn idents(&self) -> Vec { + match self { + Pattern::Ident(ident) => vec![ident.clone()], + Pattern::Destructuring(destruct) => destruct.idents().collect(), + } + } +} + +impl Default for Pattern { + fn default() -> Self { + Self::Ident(Ident::default()) + } } node! { @@ -1675,6 +1680,7 @@ node! { LetBinding } +#[derive(Debug)] pub enum LetBindingKind { /// A normal binding: `let x = 1`. Normal(Pattern), @@ -1713,7 +1719,12 @@ impl LetBinding { /// The expression the binding is initialized with. pub fn init(&self) -> Option { match self.kind() { - LetBindingKind::Normal(_) => self.0.cast_last_match(), + LetBindingKind::Normal(Pattern::Ident(_)) => { + self.0.children().filter_map(SyntaxNode::cast).nth(1) + } + LetBindingKind::Normal(Pattern::Destructuring(_)) => { + self.0.cast_first_match() + } LetBindingKind::Closure(_) => self.0.cast_first_match(), } } @@ -1821,7 +1832,11 @@ impl ForLoop { /// The expression to iterate over. pub fn iter(&self) -> Expr { - self.0.cast_first_match().unwrap_or_default() + self.0 + .children() + .skip_while(|&c| c.kind() != SyntaxKind::In) + .find_map(SyntaxNode::cast) + .unwrap_or_default() } /// The expression to evaluate for each iteration. diff --git a/src/syntax/kind.rs b/src/syntax/kind.rs index fcde2bb47..d35901b09 100644 --- a/src/syntax/kind.rs +++ b/src/syntax/kind.rs @@ -230,8 +230,6 @@ pub enum SyntaxKind { WhileLoop, /// A for loop: `for x in y { z }`. ForLoop, - /// A for loop's destructuring pattern: `x` or `x, y`. - ForPattern, /// A module import: `import a, b, c from "utils.typ"`. ModuleImport, /// Items to import from a module: `a, b, c`. @@ -244,8 +242,8 @@ pub enum SyntaxKind { LoopContinue, /// A return from a function: `return`, `return x + 1`. FuncReturn, - /// A destructuring pattern: `x`, `(x, _, ..y)`. - Pattern, + /// A destructuring pattern: `(x, _, ..y)`. + Destructuring, /// A line comment: `// ...`. LineComment, @@ -425,14 +423,13 @@ impl SyntaxKind { Self::Conditional => "`if` expression", Self::WhileLoop => "while-loop expression", Self::ForLoop => "for-loop expression", - Self::ForPattern => "for-loop destructuring pattern", Self::ModuleImport => "`import` expression", Self::ImportItems => "import items", Self::ModuleInclude => "`include` expression", Self::LoopBreak => "`break` expression", Self::LoopContinue => "`continue` expression", Self::FuncReturn => "`return` expression", - Self::Pattern => "destructuring pattern", + Self::Destructuring => "destructuring pattern", Self::LineComment => "line comment", Self::BlockComment => "block comment", Self::Error => "syntax error", diff --git a/src/syntax/parser.rs b/src/syntax/parser.rs index 7c05eebcd..e68074045 100644 --- a/src/syntax/parser.rs +++ b/src/syntax/parser.rs @@ -839,7 +839,7 @@ fn args(p: &mut Parser) { } enum PatternKind { - Normal, + Ident, Destructuring, } @@ -849,18 +849,16 @@ fn pattern(p: &mut Parser) -> PatternKind { if p.at(SyntaxKind::LeftParen) { let kind = collection(p, false); validate_destruct_pattern(p, m); - p.wrap(m, SyntaxKind::Pattern); if kind == SyntaxKind::Parenthesized { - PatternKind::Normal + PatternKind::Ident } else { + p.wrap(m, SyntaxKind::Destructuring); PatternKind::Destructuring } } else { - if p.expect(SyntaxKind::Ident) { - p.wrap(m, SyntaxKind::Pattern); - } - PatternKind::Normal + p.expect(SyntaxKind::Ident); + PatternKind::Ident } } @@ -872,7 +870,7 @@ fn let_binding(p: &mut Parser) { let mut closure = false; let mut destructuring = false; match pattern(p) { - PatternKind::Normal => { + PatternKind::Ident => { closure = p.directly_at(SyntaxKind::LeftParen); if closure { let m3 = p.marker();