diff --git a/src/eval/ops.rs b/src/eval/ops.rs index da3432a26..69a0b02b2 100644 --- a/src/eval/ops.rs +++ b/src/eval/ops.rs @@ -1,4 +1,6 @@ -use super::{ArrayValue, DictValue, TemplateNode, Value}; +use std::cmp::Ordering::*; + +use super::{TemplateNode, Value}; use Value::*; /// Apply the plus operator to a value. @@ -164,61 +166,28 @@ pub fn or(lhs: Value, rhs: Value) -> Value { /// Compute whether two values are equal. pub fn eq(lhs: Value, rhs: Value) -> Value { - Bool(value_eq(&lhs, &rhs)) + Bool(lhs.eq(&rhs)) } /// Compute whether two values are equal. pub fn neq(lhs: Value, rhs: Value) -> Value { - Bool(!value_eq(&lhs, &rhs)) -} - -/// Recursively compute whether two values are equal. -fn value_eq(lhs: &Value, rhs: &Value) -> bool { - match (lhs, rhs) { - (&Int(a), &Float(b)) => a as f64 == b, - (&Float(a), &Int(b)) => a == b as f64, - (&Length(a), &Linear(b)) => a == b.abs && b.rel.is_zero(), - (&Relative(a), &Linear(b)) => a == b.rel && b.abs.is_zero(), - (&Linear(a), &Length(b)) => a.abs == b && a.rel.is_zero(), - (&Linear(a), &Relative(b)) => a.rel == b && a.abs.is_zero(), - (Array(a), Array(b)) => array_eq(a, b), - (Dict(a), Dict(b)) => dict_eq(a, b), - (a, b) => a == b, - } -} - -/// Compute whether two arrays are equal. -fn array_eq(a: &ArrayValue, b: &ArrayValue) -> bool { - a.len() == b.len() && a.iter().zip(b).all(|(x, y)| value_eq(x, y)) -} - -/// Compute whether two dictionaries are equal. -fn dict_eq(a: &DictValue, b: &DictValue) -> bool { - a.len() == b.len() - && a.iter().all(|(k, x)| b.get(k).map_or(false, |y| value_eq(x, y))) + Bool(!lhs.eq(&rhs)) } macro_rules! comparison { - ($name:ident, $op:tt) => { + ($name:ident, $($pat:tt)*) => { /// Compute how a value compares with another value. pub fn $name(lhs: Value, rhs: Value) -> Value { - match (lhs, rhs) { - (Int(a), Int(b)) => Bool(a $op b), - (Int(a), Float(b)) => Bool((a as f64) $op b), - (Float(a), Int(b)) => Bool(a $op b as f64), - (Float(a), Float(b)) => Bool(a $op b), - (Angle(a), Angle(b)) => Bool(a $op b), - (Length(a), Length(b)) => Bool(a $op b), - _ => Error, - } + lhs.cmp(&rhs) + .map_or(Value::Error, |x| Value::Bool(matches!(x, $($pat)*))) } }; } -comparison!(lt, <); -comparison!(leq, <=); -comparison!(gt, >); -comparison!(geq, >=); +comparison!(lt, Less); +comparison!(leq, Less | Equal); +comparison!(gt, Greater); +comparison!(geq, Greater | Equal); /// Concatenate two collections. fn concat(mut a: T, b: T) -> T diff --git a/src/eval/value.rs b/src/eval/value.rs index 288e5ed74..84701b3da 100644 --- a/src/eval/value.rs +++ b/src/eval/value.rs @@ -1,4 +1,5 @@ use std::any::Any; +use std::cmp::Ordering; use std::collections::BTreeMap; use std::fmt::{self, Debug, Display, Formatter}; use std::ops::Deref; @@ -57,6 +58,14 @@ impl Value { Self::Template(vec![TemplateNode::Func(TemplateFunc::new(name, f))]) } + /// Try to cast the value into a specific type. + pub fn cast(self) -> CastResult + where + T: Cast, + { + T::cast(self) + } + /// The name of the stored value's type. pub fn type_name(&self) -> &'static str { match self { @@ -79,12 +88,37 @@ impl Value { } } - /// Try to cast the value into a specific type. - pub fn cast(self) -> CastResult - where - T: Cast, - { - T::cast(self) + /// Recursively compute whether two values are equal. + pub fn eq(&self, rhs: &Self) -> bool { + match (self, rhs) { + (&Self::Int(a), &Self::Float(b)) => a as f64 == b, + (&Self::Float(a), &Self::Int(b)) => a == b as f64, + (&Self::Length(a), &Self::Linear(b)) => a == b.abs && b.rel.is_zero(), + (&Self::Relative(a), &Self::Linear(b)) => a == b.rel && b.abs.is_zero(), + (&Self::Linear(a), &Self::Length(b)) => a.abs == b && a.rel.is_zero(), + (&Self::Linear(a), &Self::Relative(b)) => a.rel == b && a.abs.is_zero(), + (Self::Array(a), Self::Array(b)) => { + a.len() == b.len() && a.iter().zip(b).all(|(x, y)| x.eq(y)) + } + (Self::Dict(a), Self::Dict(b)) => { + a.len() == b.len() + && a.iter().all(|(k, x)| b.get(k).map_or(false, |y| x.eq(y))) + } + (a, b) => a == b, + } + } + + /// Compare a value with another value. + pub fn cmp(&self, rhs: &Self) -> Option { + match (self, rhs) { + (Self::Int(a), Self::Int(b)) => a.partial_cmp(b), + (Self::Int(a), Self::Float(b)) => (*a as f64).partial_cmp(b), + (Self::Float(a), Self::Int(b)) => a.partial_cmp(&(*b as f64)), + (Self::Float(a), Self::Float(b)) => a.partial_cmp(b), + (Self::Angle(a), Self::Angle(b)) => a.partial_cmp(b), + (Self::Length(a), Self::Length(b)) => a.partial_cmp(b), + _ => None, + } } }