typst/src/syntax/node.rs
2023-02-17 10:01:40 +01:00

913 lines
27 KiB
Rust

use std::fmt::{self, Debug, Display, Formatter};
use std::ops::{Deref, Range};
use std::rc::Rc;
use std::sync::Arc;
use super::ast::AstNode;
use super::{SourceId, Span, SyntaxKind};
use crate::diag::SourceError;
use crate::util::EcoString;
/// A node in the untyped syntax tree.
#[derive(Clone, Eq, PartialEq, Hash)]
pub struct SyntaxNode(Repr);
/// The three internal representations.
#[derive(Clone, Eq, PartialEq, Hash)]
enum Repr {
/// A leaf node.
Leaf(LeafNode),
/// A reference-counted inner node.
Inner(Arc<InnerNode>),
/// An error node.
Error(Arc<ErrorNode>),
}
impl SyntaxNode {
/// Create a new leaf node.
pub fn leaf(kind: SyntaxKind, text: impl Into<EcoString>) -> Self {
Self(Repr::Leaf(LeafNode::new(kind, text)))
}
/// Create a new inner node with children.
pub fn inner(kind: SyntaxKind, children: Vec<SyntaxNode>) -> Self {
Self(Repr::Inner(Arc::new(InnerNode::new(kind, children))))
}
/// Create a new error node.
pub fn error(
message: impl Into<EcoString>,
text: impl Into<EcoString>,
pos: ErrorPos,
) -> Self {
Self(Repr::Error(Arc::new(ErrorNode::new(message, text, pos))))
}
/// The type of the node.
pub fn kind(&self) -> SyntaxKind {
match &self.0 {
Repr::Leaf(leaf) => leaf.kind,
Repr::Inner(inner) => inner.kind,
Repr::Error(_) => SyntaxKind::Error,
}
}
/// The byte length of the node in the source text.
pub fn len(&self) -> usize {
match &self.0 {
Repr::Leaf(leaf) => leaf.len(),
Repr::Inner(inner) => inner.len,
Repr::Error(error) => error.len(),
}
}
/// The span of the node.
pub fn span(&self) -> Span {
match &self.0 {
Repr::Leaf(leaf) => leaf.span,
Repr::Inner(inner) => inner.span,
Repr::Error(error) => error.span,
}
}
/// The text of the node if it is a leaf node.
///
/// Returns the empty string if this is an inner node.
pub fn text(&self) -> &EcoString {
static EMPTY: EcoString = EcoString::new();
match &self.0 {
Repr::Leaf(leaf) => &leaf.text,
Repr::Error(error) => &error.text,
Repr::Inner(_) => &EMPTY,
}
}
/// Extract the text from the node.
///
/// Builds the string if this is an inner node.
pub fn into_text(self) -> EcoString {
match self.0 {
Repr::Leaf(leaf) => leaf.text,
Repr::Error(error) => error.text.clone(),
Repr::Inner(node) => {
node.children.iter().cloned().map(Self::into_text).collect()
}
}
}
/// The node's children.
pub fn children(&self) -> std::slice::Iter<'_, SyntaxNode> {
match &self.0 {
Repr::Leaf(_) | Repr::Error(_) => [].iter(),
Repr::Inner(inner) => inner.children.iter(),
}
}
/// Whether the node can be cast to the given AST node.
pub fn is<T: AstNode>(&self) -> bool {
self.cast::<T>().is_some()
}
/// Try to convert the node to a typed AST node.
pub fn cast<T: AstNode>(&self) -> Option<T> {
T::from_untyped(self)
}
/// Cast the first child that can cast to the AST type `T`.
pub fn cast_first_match<T: AstNode>(&self) -> Option<T> {
self.children().find_map(Self::cast)
}
/// Cast the last child that can cast to the AST type `T`.
pub fn cast_last_match<T: AstNode>(&self) -> Option<T> {
self.children().rev().find_map(Self::cast)
}
/// Whether the node or its children contain an error.
pub fn erroneous(&self) -> bool {
match &self.0 {
Repr::Leaf(_) => false,
Repr::Inner(node) => node.erroneous,
Repr::Error(_) => true,
}
}
/// The error messages for this node and its descendants.
pub fn errors(&self) -> Vec<SourceError> {
if !self.erroneous() {
return vec![];
}
if let Repr::Error(error) = &self.0 {
vec![SourceError::new(error.span, error.message.clone()).with_pos(error.pos)]
} else {
self.children()
.filter(|node| node.erroneous())
.flat_map(|node| node.errors())
.collect()
}
}
/// Set a synthetic span for the node and all its descendants.
pub fn synthesize(&mut self, span: Span) {
match &mut self.0 {
Repr::Leaf(leaf) => leaf.span = span,
Repr::Inner(inner) => Arc::make_mut(inner).synthesize(span),
Repr::Error(error) => Arc::make_mut(error).span = span,
}
}
}
impl SyntaxNode {
/// Mark this node as erroneous.
pub(super) fn make_erroneous(&mut self) {
if let Repr::Inner(inner) = &mut self.0 {
Arc::make_mut(inner).erroneous = true;
}
}
/// Convert the child to another kind.
#[track_caller]
pub(super) fn convert_to_kind(&mut self, kind: SyntaxKind) {
debug_assert!(!kind.is_error());
match &mut self.0 {
Repr::Leaf(leaf) => leaf.kind = kind,
Repr::Inner(inner) => Arc::make_mut(inner).kind = kind,
Repr::Error(_) => panic!("cannot convert error"),
}
}
/// Convert the child to an error.
pub(super) fn convert_to_error(&mut self, message: impl Into<EcoString>) {
let text = std::mem::take(self).into_text();
*self = SyntaxNode::error(message, text, ErrorPos::Full);
}
/// Assign spans to each node.
pub(super) fn numberize(
&mut self,
id: SourceId,
within: Range<u64>,
) -> NumberingResult {
if within.start >= within.end {
return Err(Unnumberable);
}
let mid = Span::new(id, (within.start + within.end) / 2);
match &mut self.0 {
Repr::Leaf(leaf) => leaf.span = mid,
Repr::Inner(inner) => Arc::make_mut(inner).numberize(id, None, within)?,
Repr::Error(error) => Arc::make_mut(error).span = mid,
}
Ok(())
}
/// If the span points into this node, convert it to a byte range.
pub(super) fn range(&self, span: Span, offset: usize) -> Option<Range<usize>> {
match &self.0 {
Repr::Inner(inner) => inner.range(span, offset),
_ => (self.span() == span).then(|| offset..offset + self.len()),
}
}
/// Whether this is a leaf node.
pub(super) fn is_leaf(&self) -> bool {
matches!(self.0, Repr::Leaf(_))
}
/// The number of descendants, including the node itself.
pub(super) fn descendants(&self) -> usize {
match &self.0 {
Repr::Leaf(_) | Repr::Error(_) => 1,
Repr::Inner(inner) => inner.descendants,
}
}
/// The node's children, mutably.
pub(super) fn children_mut(&mut self) -> &mut [SyntaxNode] {
match &mut self.0 {
Repr::Leaf(_) | Repr::Error(_) => &mut [],
Repr::Inner(inner) => &mut Arc::make_mut(inner).children,
}
}
/// Replaces a range of children with a replacement.
///
/// May have mutated the children if it returns `Err(_)`.
pub(super) fn replace_children(
&mut self,
range: Range<usize>,
replacement: Vec<SyntaxNode>,
) -> NumberingResult {
if let Repr::Inner(inner) = &mut self.0 {
Arc::make_mut(inner).replace_children(range, replacement)?;
}
Ok(())
}
/// Update this node after changes were made to one of its children.
pub(super) fn update_parent(
&mut self,
prev_len: usize,
new_len: usize,
prev_descendants: usize,
new_descendants: usize,
) {
if let Repr::Inner(inner) = &mut self.0 {
Arc::make_mut(inner).update_parent(
prev_len,
new_len,
prev_descendants,
new_descendants,
);
}
}
/// The upper bound of assigned numbers in this subtree.
pub(super) fn upper(&self) -> u64 {
match &self.0 {
Repr::Inner(inner) => inner.upper,
Repr::Leaf(leaf) => leaf.span.number() + 1,
Repr::Error(error) => error.span.number() + 1,
}
}
}
impl Debug for SyntaxNode {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
match &self.0 {
Repr::Inner(node) => node.fmt(f),
Repr::Leaf(node) => node.fmt(f),
Repr::Error(node) => node.fmt(f),
}
}
}
impl Default for SyntaxNode {
fn default() -> Self {
Self::error("", "", ErrorPos::Full)
}
}
/// A leaf node in the untyped syntax tree.
#[derive(Clone, Eq, PartialEq, Hash)]
struct LeafNode {
/// What kind of node this is (each kind would have its own struct in a
/// strongly typed AST).
kind: SyntaxKind,
/// The source text of the node.
text: EcoString,
/// The node's span.
span: Span,
}
impl LeafNode {
/// Create a new leaf node.
#[track_caller]
fn new(kind: SyntaxKind, text: impl Into<EcoString>) -> Self {
debug_assert!(!kind.is_error());
Self { kind, text: text.into(), span: Span::detached() }
}
/// The byte length of the node in the source text.
fn len(&self) -> usize {
self.text.len()
}
}
impl Debug for LeafNode {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "{:?}: {:?}", self.kind, self.text)
}
}
/// An inner node in the untyped syntax tree.
#[derive(Clone, Eq, PartialEq, Hash)]
struct InnerNode {
/// What kind of node this is (each kind would have its own struct in a
/// strongly typed AST).
kind: SyntaxKind,
/// The byte length of the node in the source.
len: usize,
/// The node's span.
span: Span,
/// The number of nodes in the whole subtree, including this node.
descendants: usize,
/// Whether this node or any of its children are erroneous.
erroneous: bool,
/// The upper bound of this node's numbering range.
upper: u64,
/// This node's children, losslessly make up this node.
children: Vec<SyntaxNode>,
}
impl InnerNode {
/// Create a new inner node with the given kind and children.
#[track_caller]
fn new(kind: SyntaxKind, children: Vec<SyntaxNode>) -> Self {
debug_assert!(!kind.is_error());
let mut len = 0;
let mut descendants = 1;
let mut erroneous = false;
for child in &children {
len += child.len();
descendants += child.descendants();
erroneous |= child.erroneous();
}
Self {
kind,
len,
span: Span::detached(),
descendants,
erroneous,
upper: 0,
children,
}
}
/// Set a synthetic span for the node and all its descendants.
fn synthesize(&mut self, span: Span) {
self.span = span;
self.upper = span.number();
for child in &mut self.children {
child.synthesize(span);
}
}
/// Assign span numbers `within` an interval to this node's subtree or just
/// a `range` of its children.
fn numberize(
&mut self,
id: SourceId,
range: Option<Range<usize>>,
within: Range<u64>,
) -> NumberingResult {
// Determine how many nodes we will number.
let descendants = match &range {
Some(range) if range.is_empty() => return Ok(()),
Some(range) => self.children[range.clone()]
.iter()
.map(SyntaxNode::descendants)
.sum::<usize>(),
None => self.descendants,
};
// Determine the distance between two neighbouring assigned numbers. If
// possible, we try to fit all numbers into the left half of `within`
// so that there is space for future insertions.
let space = within.end - within.start;
let mut stride = space / (2 * descendants as u64);
if stride == 0 {
stride = space / self.descendants as u64;
if stride == 0 {
return Err(Unnumberable);
}
}
// Number the node itself.
let mut start = within.start;
if range.is_none() {
let end = start + stride;
self.span = Span::new(id, (start + end) / 2);
self.upper = within.end;
start = end;
}
// Number the children.
let len = self.children.len();
for child in &mut self.children[range.unwrap_or(0..len)] {
let end = start + child.descendants() as u64 * stride;
child.numberize(id, start..end)?;
start = end;
}
Ok(())
}
/// If the span points into this node, convert it to a byte range.
fn range(&self, span: Span, mut offset: usize) -> Option<Range<usize>> {
// Check whether we found it.
if span == self.span {
return Some(offset..offset + self.len);
}
// The parent of a subtree has a smaller span number than all of its
// descendants. Therefore, we can bail out early if the target span's
// number is smaller than our number.
if span.number() < self.span.number() {
return None;
}
let mut children = self.children.iter().peekable();
while let Some(child) = children.next() {
// Every node in this child's subtree has a smaller span number than
// the next sibling. Therefore we only need to recurse if the next
// sibling's span number is larger than the target span's number.
if children
.peek()
.map_or(true, |next| next.span().number() > span.number())
{
if let Some(range) = child.range(span, offset) {
return Some(range);
}
}
offset += child.len();
}
None
}
/// Replaces a range of children with a replacement.
///
/// May have mutated the children if it returns `Err(_)`.
fn replace_children(
&mut self,
mut range: Range<usize>,
replacement: Vec<SyntaxNode>,
) -> NumberingResult {
let superseded = &self.children[range.clone()];
// Compute the new byte length.
self.len = self.len + replacement.iter().map(SyntaxNode::len).sum::<usize>()
- superseded.iter().map(SyntaxNode::len).sum::<usize>();
// Compute the new number of descendants.
self.descendants = self.descendants
+ replacement.iter().map(SyntaxNode::descendants).sum::<usize>()
- superseded.iter().map(SyntaxNode::descendants).sum::<usize>();
// Determine whether we're still erroneous after the replacement. That's
// the case if
// - any of the new nodes is erroneous,
// - or if we were erroneous before due to a non-superseded node.
self.erroneous = replacement.iter().any(SyntaxNode::erroneous)
|| (self.erroneous
&& (self.children[..range.start].iter().any(SyntaxNode::erroneous))
|| self.children[range.end..].iter().any(SyntaxNode::erroneous));
// Perform the replacement.
let replacement_count = replacement.len();
self.children.splice(range.clone(), replacement);
range.end = range.start + replacement_count;
// Renumber the new children. Retries until it works, taking
// exponentially more children into account.
let mut left = 0;
let mut right = 0;
let max_left = range.start;
let max_right = self.children.len() - range.end;
loop {
let renumber = range.start - left..range.end + right;
// The minimum assignable number is either
// - the upper bound of the node right before the to-be-renumbered
// children,
// - or this inner node's span number plus one if renumbering starts
// at the first child.
let start_number = renumber
.start
.checked_sub(1)
.and_then(|i| self.children.get(i))
.map_or(self.span.number() + 1, |child| child.upper());
// The upper bound for renumbering is either
// - the span number of the first child after the to-be-renumbered
// children,
// - or this node's upper bound if renumbering ends behind the last
// child.
let end_number = self
.children
.get(renumber.end)
.map_or(self.upper, |next| next.span().number());
// Try to renumber.
let within = start_number..end_number;
let id = self.span.source();
if self.numberize(id, Some(renumber), within).is_ok() {
return Ok(());
}
// If it didn't even work with all children, we give up.
if left == max_left && right == max_right {
return Err(Unnumberable);
}
// Exponential expansion to both sides.
left = (left + 1).next_power_of_two().min(max_left);
right = (right + 1).next_power_of_two().min(max_right);
}
}
/// Update this node after changes were made to one of its children.
fn update_parent(
&mut self,
prev_len: usize,
new_len: usize,
prev_descendants: usize,
new_descendants: usize,
) {
self.len = self.len + new_len - prev_len;
self.descendants = self.descendants + new_descendants - prev_descendants;
self.erroneous = self.children.iter().any(SyntaxNode::erroneous);
}
}
impl Debug for InnerNode {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "{:?}: {}", self.kind, self.len)?;
if !self.children.is_empty() {
f.write_str(" ")?;
f.debug_list().entries(&self.children).finish()?;
}
Ok(())
}
}
/// An error node in the untyped syntax tree.
#[derive(Clone, Eq, PartialEq, Hash)]
struct ErrorNode {
/// The error message.
message: EcoString,
/// The source text of the node.
text: EcoString,
/// Where in the node an error should be annotated.
pos: ErrorPos,
/// The node's span.
span: Span,
}
impl ErrorNode {
/// Create new error node.
fn new(
message: impl Into<EcoString>,
text: impl Into<EcoString>,
pos: ErrorPos,
) -> Self {
Self {
message: message.into(),
text: text.into(),
pos,
span: Span::detached(),
}
}
/// The byte length of the node in the source text.
fn len(&self) -> usize {
self.text.len()
}
}
impl Debug for ErrorNode {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "Error: {:?} ({})", self.text, self.message)
}
}
/// Where in a node an error should be annotated,
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum ErrorPos {
/// Over the full width of the node.
Full,
/// At the start of the node.
Start,
/// At the end of the node.
End,
}
/// A syntax node in a context.
///
/// Knows its exact offset in the file and provides access to its
/// children, parent and siblings.
///
/// **Note that all sibling and leaf accessors skip over trivia!**
#[derive(Clone)]
pub struct LinkedNode<'a> {
node: &'a SyntaxNode,
parent: Option<Rc<Self>>,
index: usize,
offset: usize,
}
impl<'a> LinkedNode<'a> {
/// Start a new traversal at a root node.
pub fn new(root: &'a SyntaxNode) -> Self {
Self { node: root, parent: None, index: 0, offset: 0 }
}
/// Get the contained syntax node.
pub fn get(&self) -> &'a SyntaxNode {
self.node
}
/// The index of this node in its parent's children list.
pub fn index(&self) -> usize {
self.index
}
/// The absolute byte offset of this node in the source file.
pub fn offset(&self) -> usize {
self.offset
}
/// The byte range of this node in the source file.
pub fn range(&self) -> Range<usize> {
self.offset..self.offset + self.node.len()
}
/// An iterator over this node's children.
pub fn children(&self) -> LinkedChildren<'a> {
LinkedChildren {
parent: Rc::new(self.clone()),
iter: self.node.children().enumerate(),
front: self.offset,
back: self.offset + self.len(),
}
}
}
/// Access to parents and siblings.
impl<'a> LinkedNode<'a> {
/// Get this node's parent.
pub fn parent(&self) -> Option<&Self> {
self.parent.as_deref()
}
/// Get the first previous non-trivia sibling node.
pub fn prev_sibling(&self) -> Option<Self> {
let parent = self.parent()?;
let index = self.index.checked_sub(1)?;
let node = parent.node.children().nth(index)?;
let offset = self.offset - node.len();
let prev = Self { node, parent: self.parent.clone(), index, offset };
if prev.kind().is_trivia() {
prev.prev_sibling()
} else {
Some(prev)
}
}
/// Get the next non-trivia sibling node.
pub fn next_sibling(&self) -> Option<Self> {
let parent = self.parent()?;
let index = self.index.checked_add(1)?;
let node = parent.node.children().nth(index)?;
let offset = self.offset + self.node.len();
let next = Self { node, parent: self.parent.clone(), index, offset };
if next.kind().is_trivia() {
next.next_sibling()
} else {
Some(next)
}
}
/// Get the kind of this node's parent.
pub fn parent_kind(&self) -> Option<SyntaxKind> {
Some(self.parent()?.node.kind())
}
/// Get the kind of this node's first previous non-trivia sibling.
pub fn prev_sibling_kind(&self) -> Option<SyntaxKind> {
Some(self.prev_sibling()?.node.kind())
}
/// Get the kind of this node's next non-trivia sibling.
pub fn next_sibling_kind(&self) -> Option<SyntaxKind> {
Some(self.next_sibling()?.node.kind())
}
}
/// Access to leafs.
impl<'a> LinkedNode<'a> {
/// Get the rightmost non-trivia leaf before this node.
pub fn prev_leaf(&self) -> Option<Self> {
let mut node = self.clone();
while let Some(prev) = node.prev_sibling() {
if let Some(leaf) = prev.rightmost_leaf() {
return Some(leaf);
}
node = prev;
}
self.parent()?.prev_leaf()
}
/// Find the leftmost contained non-trivia leaf.
pub fn leftmost_leaf(&self) -> Option<Self> {
if self.is_leaf() && !self.kind().is_trivia() && !self.kind().is_error() {
return Some(self.clone());
}
for child in self.children() {
if let Some(leaf) = child.leftmost_leaf() {
return Some(leaf);
}
}
None
}
/// Get the leaf at the specified byte offset.
pub fn leaf_at(&self, cursor: usize) -> Option<Self> {
if self.node.children().len() == 0 && cursor <= self.offset + self.len() {
return Some(self.clone());
}
let mut offset = self.offset;
let count = self.node.children().len();
for (i, child) in self.children().enumerate() {
let len = child.len();
if (offset < cursor && cursor <= offset + len)
|| (offset == cursor && i + 1 == count)
{
return child.leaf_at(cursor);
}
offset += len;
}
None
}
/// Find the rightmost contained non-trivia leaf.
pub fn rightmost_leaf(&self) -> Option<Self> {
if self.is_leaf() && !self.kind().is_trivia() {
return Some(self.clone());
}
for child in self.children().rev() {
if let Some(leaf) = child.rightmost_leaf() {
return Some(leaf);
}
}
None
}
/// Get the leftmost non-trivia leaf after this node.
pub fn next_leaf(&self) -> Option<Self> {
let mut node = self.clone();
while let Some(next) = node.next_sibling() {
if let Some(leaf) = next.leftmost_leaf() {
return Some(leaf);
}
node = next;
}
self.parent()?.next_leaf()
}
}
impl Deref for LinkedNode<'_> {
type Target = SyntaxNode;
fn deref(&self) -> &Self::Target {
self.get()
}
}
impl Debug for LinkedNode<'_> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
self.node.fmt(f)
}
}
/// An iterator over the children of a linked node.
pub struct LinkedChildren<'a> {
parent: Rc<LinkedNode<'a>>,
iter: std::iter::Enumerate<std::slice::Iter<'a, SyntaxNode>>,
front: usize,
back: usize,
}
impl<'a> Iterator for LinkedChildren<'a> {
type Item = LinkedNode<'a>;
fn next(&mut self) -> Option<Self::Item> {
self.iter.next().map(|(index, node)| {
let offset = self.front;
self.front += node.len();
LinkedNode {
node,
parent: Some(self.parent.clone()),
index,
offset,
}
})
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.iter.size_hint()
}
}
impl DoubleEndedIterator for LinkedChildren<'_> {
fn next_back(&mut self) -> Option<Self::Item> {
self.iter.next_back().map(|(index, node)| {
self.back -= node.len();
LinkedNode {
node,
parent: Some(self.parent.clone()),
index,
offset: self.back,
}
})
}
}
impl ExactSizeIterator for LinkedChildren<'_> {}
/// Result of numbering a node within an interval.
pub(super) type NumberingResult = Result<(), Unnumberable>;
/// Indicates that a node cannot be numbered within a given interval.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(super) struct Unnumberable;
impl Display for Unnumberable {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
f.pad("cannot number within this interval")
}
}
impl std::error::Error for Unnumberable {}
#[cfg(test)]
mod tests {
use super::*;
use crate::syntax::Source;
#[test]
fn test_linked_node() {
let source = Source::detached("#set text(12pt, red)");
// Find "text".
let node = LinkedNode::new(source.root()).leaf_at(7).unwrap();
assert_eq!(node.offset(), 5);
assert_eq!(node.text(), "text");
// Go back to "#set". Skips the space.
let prev = node.prev_sibling().unwrap();
assert_eq!(prev.offset(), 1);
assert_eq!(prev.text(), "set");
}
#[test]
fn test_linked_node_non_trivia_leaf() {
let source = Source::detached("#set fun(12pt, red)");
let leaf = LinkedNode::new(source.root()).leaf_at(6).unwrap();
let prev = leaf.prev_leaf().unwrap();
assert_eq!(leaf.text(), "fun");
assert_eq!(prev.text(), "set");
let source = Source::detached("#let x = 10");
let leaf = LinkedNode::new(source.root()).leaf_at(9).unwrap();
let prev = leaf.prev_leaf().unwrap();
let next = leaf.next_leaf().unwrap();
assert_eq!(prev.text(), "=");
assert_eq!(leaf.text(), " ");
assert_eq!(next.text(), "10");
}
}