Add map function to Dict for transforming key-value pairs and update tests (#6006)

This commit is contained in:
Wesley Yang 2025-04-03 17:45:51 +08:00
parent ed2106e28d
commit cc9b69f13f
2 changed files with 103 additions and 2 deletions

View File

@ -3,15 +3,17 @@ use std::hash::{Hash, Hasher};
use std::ops::{Add, AddAssign};
use std::sync::Arc;
use comemo::Tracked;
use ecow::{eco_format, EcoString};
use indexmap::IndexMap;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use typst_syntax::is_ident;
use typst_utils::ArcExt;
use crate::diag::{Hint, HintedStrResult, StrResult};
use crate::diag::{Hint, HintedStrResult, SourceResult, StrResult};
use crate::engine::Engine;
use crate::foundations::{
array, cast, func, repr, scope, ty, Array, Module, Repr, Str, Value,
array, cast, func, repr, scope, ty, Array, Context, Func, Module, Repr, Str, Value,
};
/// Create a new [`Dict`] from key-value pairs.
@ -254,6 +256,77 @@ impl Dict {
.map(|(k, v)| Value::Array(array![k.clone(), v.clone()]))
.collect()
}
/// Produces a new dictionary or array by transforming each key-value pair with the given function.
///
/// If the mapper function returns a pair (array of length 2), the result will be a new dictionary.
/// Otherwise, the result will be an array containing all mapped values.
///
/// ```example
/// #let prices = (apples: 2, oranges: 3, bananas: 1.5)
/// #prices.map(key => key.len()) \
/// #prices.map((key, price) => (key, price * 1.1))
/// ```
#[func]
pub fn map(
self,
engine: &mut Engine,
context: Tracked<Context>,
/// The function to apply to each key-value pair.
/// The function can either take a single parameter (receiving a pair as array of length 2),
/// or two parameters (receiving key and value separately).
mapper: Func,
) -> SourceResult<Value> {
let mut dict_result = IndexMap::new();
let mut array_result = Vec::new();
let mut is_dict = true;
// try to check the number of parameters, if not, use array form
let use_two_args = mapper.params().map_or(false, |params| params.len() >= 2);
for (key, value) in self {
// choose how to pass parameters based on the function signature
let mapped = if use_two_args {
mapper.call(engine, context, [
Value::Str(key.clone()),
value.clone(),
])?
} else {
mapper.call(engine, context, [
Value::Array(array![Value::Str(key.clone()), value]),
])?
};
// check if the result is a dictionary key-value pair
if let Value::Array(arr) = &mapped {
if arr.len() == 2 {
if let Value::Str(k) = &arr.as_slice()[0] {
if is_dict {
dict_result.insert(k.clone(), arr.as_slice()[1].clone());
continue;
}
}
}
}
// if the result is not a key-value pair, switch the result type to array
if is_dict {
is_dict = false;
// convert the collected dictionary result to array items
for (k, v) in dict_result.drain(..) {
array_result.push(Value::Array(array![Value::Str(k), v]));
}
}
array_result.push(mapped);
}
if is_dict {
Ok(Value::Dict(Dict::from(dict_result)))
} else {
Ok(Value::Array(array_result.into_iter().collect()))
}
}
}
/// A value that can be cast to dictionary.

View File

@ -23,6 +23,34 @@
test(world, "world")
}
--- dict-map ---
// Test the map function
#let dict = (a: 1, b: 2, c: 3)
// test map return new dict
#test(
dict.map(((key, value)) => (key, value * 2)),
(a: 2, b: 4, c: 6)
)
// test map empty dict
#test(
(:).map(((key, value)) => (key, value * 2)),
(:)
)
// test map return array
#test(
dict.map(pair => pair.at(0) + ": " + str(pair.at(1))),
("a: 1", "b: 2", "c: 3")
)
// test map return array(different return type)
#test(
dict.map(((key, value)) => if value > 1 { (key, value * 2) } else { "key smaller than 1: " + key }),
("key smaller than 1: a", ("b", 4), ("c", 6))
)
--- dict-missing-field ---
// Error: 6-13 dictionary does not contain key "invalid"
#(:).invalid