From 0c74290519ee999002e9cc99ba3d272d68e1014f Mon Sep 17 00:00:00 2001 From: Laurenz Date: Thu, 8 Jul 2021 23:36:20 +0200 Subject: [PATCH] Compare functions and templates by identity --- src/eval/mod.rs | 5 ++- src/eval/ops.rs | 97 +++++++++++++++++++----------------------- src/eval/value.rs | 22 ++++------ src/exec/mod.rs | 2 +- tests/typ/code/ops.typ | 12 ++++++ 5 files changed, 69 insertions(+), 69 deletions(-) diff --git a/src/eval/mod.rs b/src/eval/mod.rs index ba2de8c7a..4992f70cc 100644 --- a/src/eval/mod.rs +++ b/src/eval/mod.rs @@ -229,7 +229,10 @@ impl Eval for Rc { let mut visitor = ExprVisitor { ctx, map: ExprMap::new() }; visitor.visit_tree(self); - vec![TemplateNode::Tree { tree: Rc::clone(self), map: visitor.map }] + Rc::new(vec![TemplateNode::Tree { + tree: Rc::clone(self), + map: visitor.map, + }]) } } diff --git a/src/eval/ops.rs b/src/eval/ops.rs index 01044842c..3b48140c9 100644 --- a/src/eval/ops.rs +++ b/src/eval/ops.rs @@ -1,35 +1,9 @@ use std::cmp::Ordering::*; +use std::rc::Rc; use super::{TemplateNode, Value}; use Value::*; -/// Join a value with another value. -pub fn join(lhs: Value, rhs: Value) -> Result { - Ok(match (lhs, rhs) { - (_, Error) => Error, - (Error, _) => Error, - - (a, None) => a, - (None, b) => b, - - (Str(a), Str(b)) => Str(a + &b), - (Array(a), Array(b)) => Array(concat(a, b)), - (Dict(a), Dict(b)) => Dict(concat(a, b)), - - (Template(a), Template(b)) => Template(concat(a, b)), - (Template(mut a), Str(b)) => Template({ - a.push(TemplateNode::Str(b)); - a - }), - (Str(a), Template(mut b)) => Template({ - b.insert(0, TemplateNode::Str(a)); - b - }), - - (a, _) => return Err(a), - }) -} - /// Apply the plus operator to a value. pub fn pos(value: Value) -> Value { match value { @@ -82,21 +56,7 @@ pub fn add(lhs: Value, rhs: Value) -> Value { (Fractional(a), Fractional(b)) => Fractional(a + b), - (Str(a), Str(b)) => Str(a + &b), - (Array(a), Array(b)) => Array(concat(a, b)), - (Dict(a), Dict(b)) => Dict(concat(a, b)), - - (Template(a), Template(b)) => Template(concat(a, b)), - (Template(mut a), Str(b)) => Template({ - a.push(TemplateNode::Str(b)); - a - }), - (Str(a), Template(mut b)) => Template({ - b.insert(0, TemplateNode::Str(a)); - b - }), - - _ => Error, + (a, b) => concat(a, b).unwrap_or(Value::Error), } } @@ -130,6 +90,11 @@ pub fn sub(lhs: Value, rhs: Value) -> Value { /// Compute the product of two values. pub fn mul(lhs: Value, rhs: Value) -> Value { + fn repeat(vec: Vec, n: usize) -> Vec { + let len = n * vec.len(); + vec.into_iter().cycle().take(len).collect() + } + match (lhs, rhs) { (Int(a), Int(b)) => Int(a * b), (Int(a), Float(b)) => Float(a as f64 * b), @@ -258,17 +223,43 @@ pub fn range(lhs: Value, rhs: Value) -> Value { } } -/// Concatenate two collections. -fn concat(mut a: T, b: T) -> T -where - T: Extend + IntoIterator, -{ - a.extend(b); - a +/// Join a value with another value. +pub fn join(lhs: Value, rhs: Value) -> Result { + Ok(match (lhs, rhs) { + (_, Error) => Error, + (Error, _) => Error, + + (a, None) => a, + (None, b) => b, + + (a, b) => return concat(a, b), + }) } -/// Repeat a vector `n` times. -fn repeat(vec: Vec, n: usize) -> Vec { - let len = n * vec.len(); - vec.into_iter().cycle().take(len).collect() +/// Concatentate two values. +fn concat(lhs: Value, rhs: Value) -> Result { + Ok(match (lhs, rhs) { + (Str(a), Str(b)) => Str(a + &b), + (Array(mut a), Array(b)) => Array({ + a.extend(b); + a + }), + (Dict(mut a), Dict(b)) => Dict({ + a.extend(b); + a + }), + (Template(mut a), Template(b)) => Template({ + Rc::make_mut(&mut a).extend(b.iter().cloned()); + a + }), + (Template(mut a), Str(b)) => Template({ + Rc::make_mut(&mut a).push(TemplateNode::Str(b)); + a + }), + (Str(a), Template(mut b)) => Template({ + Rc::make_mut(&mut b).insert(0, TemplateNode::Str(a)); + b + }), + (a, _) => return Err(a), + }) } diff --git a/src/eval/value.rs b/src/eval/value.rs index 472df0ea1..07552c845 100644 --- a/src/eval/value.rs +++ b/src/eval/value.rs @@ -59,7 +59,7 @@ impl Value { where F: Fn(&mut ExecContext) + 'static, { - Self::Template(vec![TemplateNode::Func(TemplateFunc::new(f))]) + Self::Template(Rc::new(vec![TemplateNode::Func(TemplateFunc::new(f))])) } /// The name of the stored value's type. @@ -102,6 +102,7 @@ impl Value { a.len() == b.len() && a.iter().all(|(k, x)| b.get(k).map_or(false, |y| x.eq(y))) } + (Self::Template(a), Self::Template(b)) => Rc::ptr_eq(a, b), (a, b) => a == b, } } @@ -153,7 +154,7 @@ pub type ArrayValue = Vec; pub type DictValue = BTreeMap; /// A template value: `[*Hi* there]`. -pub type TemplateValue = Vec; +pub type TemplateValue = Rc>; /// One chunk of a template. /// @@ -177,7 +178,6 @@ pub enum TemplateNode { impl PartialEq for TemplateNode { fn eq(&self, _: &Self) -> bool { - // TODO: Figure out what we want here. false } } @@ -205,13 +205,6 @@ impl TemplateFunc { } } -impl PartialEq for TemplateFunc { - fn eq(&self, _: &Self) -> bool { - // TODO: Figure out what we want here. - false - } -} - impl Deref for TemplateFunc { type Target = dyn Fn(&mut ExecContext); @@ -232,6 +225,7 @@ pub struct FuncValue { /// The string is boxed to make the whole struct fit into 24 bytes, so that /// a [`Value`] fits into 32 bytes. name: Option>, + /// The closure that defines the function. f: Rc Value>, } @@ -251,9 +245,9 @@ impl FuncValue { } impl PartialEq for FuncValue { - fn eq(&self, _: &Self) -> bool { - // TODO: Figure out what we want here. - false + fn eq(&self, other: &Self) -> bool { + // We cast to thin pointers because we don't want to compare vtables. + Rc::as_ptr(&self.f) as *const () == Rc::as_ptr(&other.f) as *const () } } @@ -620,7 +614,7 @@ primitive! { DictValue: "dictionary", Value::Dict } primitive! { TemplateValue: "template", Value::Template, - Value::Str(v) => vec![TemplateNode::Str(v)], + Value::Str(v) => Rc::new(vec![TemplateNode::Str(v)]), } primitive! { FuncValue: "function", Value::Func } diff --git a/src/exec/mod.rs b/src/exec/mod.rs index 8e369d127..fc829676f 100644 --- a/src/exec/mod.rs +++ b/src/exec/mod.rs @@ -152,7 +152,7 @@ impl Exec for Value { impl Exec for TemplateValue { fn exec(&self, ctx: &mut ExecContext) { - for node in self { + for node in self.iter() { node.exec(ctx); } } diff --git a/tests/typ/code/ops.typ b/tests/typ/code/ops.typ index 6d788df19..58fd957fc 100644 --- a/tests/typ/code/ops.typ +++ b/tests/typ/code/ops.typ @@ -112,6 +112,7 @@ --- // Test equality operators. +// Most things compare by value. #test(1 == "hi", false) #test(1 == 1.0, true) #test(30% == 30% + 0cm, true) @@ -124,6 +125,17 @@ #test((a: 2 - 1.0, b: 2) == (b: 2, a: 1), true) #test("a" != "a", false) +// Functions compare by identity. +#test(test == test, true) +#test((() => {}) == (() => {}), false) + +// Templates also compare by identity. +#let t = [a] +#test(t == t, true) +#test([] == [], false) +#test([] == [a], false) +#test([a] == [a], false) + --- // Test comparison operators.