diff --git a/crates/typst-library/src/foundations/dict.rs b/crates/typst-library/src/foundations/dict.rs index c93670c1d..a4afbb759 100644 --- a/crates/typst-library/src/foundations/dict.rs +++ b/crates/typst-library/src/foundations/dict.rs @@ -3,15 +3,17 @@ use std::hash::{Hash, Hasher}; use std::ops::{Add, AddAssign}; use std::sync::Arc; +use comemo::Tracked; use ecow::{eco_format, EcoString}; use indexmap::IndexMap; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use typst_syntax::is_ident; use typst_utils::ArcExt; -use crate::diag::{Hint, HintedStrResult, StrResult}; +use crate::diag::{Hint, HintedStrResult, SourceResult, StrResult}; +use crate::engine::Engine; use crate::foundations::{ - array, cast, func, repr, scope, ty, Array, Module, Repr, Str, Value, + array, cast, func, repr, scope, ty, Array, Context, Func, Module, Repr, Str, Value, }; /// Create a new [`Dict`] from key-value pairs. @@ -254,6 +256,101 @@ impl Dict { .map(|(k, v)| Value::Array(array![k.clone(), v.clone()])) .collect() } + + /// Produces a new dictionary or array by transforming each key-value pair with the given function. + /// + /// If the mapper function returns a pair (array of length 2), the result will be a new dictionary. + /// Otherwise, the result will be an array containing all mapped values. + /// + /// ```example + /// #let prices = (apples: 2, oranges: 3, bananas: 1.5) + /// #prices.map(pair => pair.at(0).len()) + /// #prices.map((key, value) => (key, value * 1.1)) + /// ``` + #[func] + pub fn map( + self, + engine: &mut Engine, + context: Tracked, + /// The function to apply to each key-value pair. + /// The function can either take a single parameter (receiving a pair as array of length 2), + /// or two parameters (receiving key and value separately). + /// Parameters exceeding two will be ignored. + mapper: Func, + ) -> SourceResult { + let mut dict_result = IndexMap::new(); + let mut array_result = Vec::new(); + let mut is_dict = true; + + // try to check the number of parameters, if not, use array form + let mut first_pair = true; + let mut use_single_arg = false; + + for (key, value) in self { + let mapped = if first_pair { + // try two calling ways for the first pair + first_pair = false; + + // try to call with two parameters + let result = mapper.call( + engine, + context, + [Value::Str(key.clone()), value.clone()], + ); + + // if failed, try to call with one parameter + if result.is_err() { + use_single_arg = true; + mapper.call( + engine, + context, + [Value::Array(array![Value::Str(key.clone()), value])], + )? + } else { + result? + } + } else if use_single_arg { + // try to call with one parameter + mapper.call( + engine, + context, + [Value::Array(array![Value::Str(key.clone()), value])], + )? + } else { + // try to call with two parameters + mapper.call(engine, context, [Value::Str(key.clone()), value.clone()])? + }; + + // check if the result is a dictionary key-value pair + if let Value::Array(arr) = &mapped { + if arr.len() == 2 { + if let Value::Str(k) = &arr.as_slice()[0] { + if is_dict { + dict_result.insert(k.clone(), arr.as_slice()[1].clone()); + continue; + } + } + } + } + + // if the result is not a key-value pair, switch the result type to array + if is_dict { + is_dict = false; + // convert the collected dictionary result to array items + for (k, v) in dict_result.drain(..) { + array_result.push(Value::Array(array![Value::Str(k), v])); + } + } + + array_result.push(mapped); + } + + if is_dict { + Ok(Value::Dict(Dict::from(dict_result))) + } else { + Ok(Value::Array(array_result.into_iter().collect())) + } + } } /// A value that can be cast to dictionary. diff --git a/tests/suite/foundations/dict.typ b/tests/suite/foundations/dict.typ index af9ad5e1a..8a2f5ffe5 100644 --- a/tests/suite/foundations/dict.typ +++ b/tests/suite/foundations/dict.typ @@ -23,6 +23,34 @@ test(world, "world") } +--- dict-map --- +// Test the map function +#let dict = (a: 1, b: 2, c: 3) + +// test map return new dict +#test( + dict.map((key, value) => (key, value * 2)), + (a: 2, b: 4, c: 6) +) + +// test map empty dict +#test( + (:).map((key, value) => (key, value * 2)), + (:) +) + +// test map return array +#test( + dict.map(pair => pair.at(0) + ": " + str(pair.at(1))), + ("a: 1", "b: 2", "c: 3") +) + +// test map return array(different return type) +#test( + dict.map((key, value) => if value > 1 { (key, value * 2) } else { "key smaller than 1: " + key }), + ("key smaller than 1: a", ("b", 4), ("c", 6)) +) + --- dict-missing-field --- // Error: 6-13 dictionary does not contain key "invalid" #(:).invalid