From ccb4753e24eefb5b8cf2acd6d25f0e2afce1c022 Mon Sep 17 00:00:00 2001 From: Laurenz Date: Mon, 30 May 2022 10:31:31 +0200 Subject: [PATCH] Find optionally takes function instead of value --- src/eval/array.rs | 59 +++++++++++++++++++++++++++++--- src/eval/methods.rs | 4 ++- src/eval/value.rs | 6 ++-- tests/typ/utility/collection.typ | 1 + 4 files changed, 61 insertions(+), 9 deletions(-) diff --git a/src/eval/array.rs b/src/eval/array.rs index 840c0aef6..fda6f3904 100644 --- a/src/eval/array.rs +++ b/src/eval/array.rs @@ -3,9 +3,9 @@ use std::fmt::{self, Debug, Formatter, Write}; use std::ops::{Add, AddAssign}; use std::sync::Arc; -use super::{ops, Args, Func, Machine, Value}; +use super::{ops, Args, Cast, Func, Machine, Value}; use crate::diag::{At, StrResult, TypResult}; -use crate::syntax::Spanned; +use crate::syntax::{Span, Spanned}; use crate::util::ArcExt; /// Create a new [`Array`] from values. @@ -120,10 +120,19 @@ impl Array { /// Transform each item in the array with a function. pub fn map(&self, vm: &mut Machine, f: Spanned) -> TypResult { + let enumerate = f.v.argc() == Some(2); Ok(self .iter() .cloned() - .map(|item| f.v.call(vm, Args::new(f.span, [item]))) + .enumerate() + .map(|(i, item)| { + let mut args = Args::new(f.span, []); + if enumerate { + args.push(f.span, Value::Int(i as i64)); + } + args.push(f.span, item); + f.v.call(vm, args) + }) .collect::>()?) } @@ -157,8 +166,14 @@ impl Array { } /// Return the index of the element if it is part of the array. - pub fn find(&self, value: Value) -> Option { - self.0.iter().position(|x| *x == value).map(|i| i as i64) + pub fn find(&self, vm: &mut Machine, target: Target) -> TypResult> { + for (i, item) in self.iter().enumerate() { + if target.matches(vm, item)? { + return Ok(Some(i as i64)); + } + } + + Ok(None) } /// Join all values in the array, optionally with separator and last @@ -304,3 +319,37 @@ impl<'a> IntoIterator for &'a Array { self.iter() } } + +/// Something that can be found. +pub enum Target { + /// A bare value. + Value(Value), + /// A function that returns a boolean. + Func(Func, Span), +} + +impl Target { + /// Whether the value is the search target. + pub fn matches(&self, vm: &mut Machine, other: &Value) -> TypResult { + match self { + Self::Value(value) => Ok(value == other), + Self::Func(f, span) => f + .call(vm, Args::new(*span, [other.clone()]))? + .cast::() + .at(*span), + } + } +} + +impl Cast> for Target { + fn is(_: &Spanned) -> bool { + true + } + + fn cast(value: Spanned) -> StrResult { + Ok(match value.v { + Value::Func(v) => Self::Func(v, value.span), + v => Self::Value(v), + }) + } +} diff --git a/src/eval/methods.rs b/src/eval/methods.rs index e8296d236..d425e007a 100644 --- a/src/eval/methods.rs +++ b/src/eval/methods.rs @@ -38,7 +38,9 @@ pub fn call( "map" => Value::Array(array.map(vm, args.expect("function")?)?), "filter" => Value::Array(array.filter(vm, args.expect("function")?)?), "flatten" => Value::Array(array.flatten()), - "find" => array.find(args.expect("value")?).map_or(Value::None, Value::Int), + "find" => array + .find(vm, args.expect("value or function")?)? + .map_or(Value::None, Value::Int), "join" => { let sep = args.eat()?; let last = args.named("last")?; diff --git a/src/eval/value.rs b/src/eval/value.rs index 9b36812ac..22f8d3cfb 100644 --- a/src/eval/value.rs +++ b/src/eval/value.rs @@ -713,9 +713,9 @@ castable! { castable! { Pattern, Expected: "function, string or regular expression", - Value::Func(func) => Pattern::Node(func.node()?), - Value::Str(text) => Pattern::text(&text), - @regex: Regex => Pattern::Regex(regex.clone()), + Value::Func(func) => Self::Node(func.node()?), + Value::Str(text) => Self::text(&text), + @regex: Regex => Self::Regex(regex.clone()), } #[cfg(test)] diff --git a/tests/typ/utility/collection.typ b/tests/typ/utility/collection.typ index 3414f0a9e..42c369061 100644 --- a/tests/typ/utility/collection.typ +++ b/tests/typ/utility/collection.typ @@ -32,6 +32,7 @@ // Test the `find` method. #test(("Hi", "❤️", "Love").find("❤️"), 1) #test(("Bye", "💘", "Apart").find("❤️"), none) +#test(("A", "B", "CDEF", "G").find(v => v.len() > 2), 2) --- // Test the `slice` method.