diff --git a/library/src/compute/calc.rs b/library/src/compute/calc.rs index 76f58c919..c480cb640 100644 --- a/library/src/compute/calc.rs +++ b/library/src/compute/calc.rs @@ -1,5 +1,6 @@ //! Calculations and processing of numeric values. +use std::cmp; use std::cmp::Ordering; use std::ops::Rem; @@ -23,6 +24,9 @@ pub fn module() -> Module { scope.define("cosh", cosh); scope.define("tanh", tanh); scope.define("log", log); + scope.define("fact", fact); + scope.define("perm", perm); + scope.define("binom", binom); scope.define("floor", floor); scope.define("ceil", ceil); scope.define("round", round); @@ -404,6 +408,128 @@ pub fn log( Value::Float(result) } +/// Calculate the factorial of a number. +/// +/// ## Example +/// ```example +/// #calc.fact(5) +/// ``` +/// +/// Display: Factorial +/// Category: calculate +/// Returns: integer +#[func] +pub fn fact( + /// The number whose factorial to calculate. Must be positive. + number: Spanned, +) -> Value { + let result = factorial_range(1, number.v).and_then(|r| i64::try_from(r).ok()); + + match result { + None => bail!(number.span, "the factorial result is too large"), + Some(s) => Value::Int(s), + } +} + +/// Calculates the product of a range of numbers. Used to calculate permutations. +/// Returns None if the result is larger than `u64::MAX` +fn factorial_range(start: u64, end: u64) -> Option { + // By convention + if end + 1 < start { + return Some(0); + } + + let mut count: u64 = 1; + let real_start: u64 = cmp::max(1, start); + + for i in real_start..=end { + count = count.checked_mul(i)?; + } + Some(count) +} + +/// Calculate a permutation. +/// +/// ## Example +/// ```example +/// #calc.perm(10,5) +/// ``` +/// +/// Display: Permutation +/// Category: calculate +/// Returns: integer +#[func] +pub fn perm( + /// The base number. Must be positive. + base: Spanned, + /// The number of permutations. Must be positive. + numbers: Spanned, +) -> Value { + let base_parsed = base.v; + let numbers_parsed = numbers.v; + + let result = if base_parsed + 1 > numbers_parsed { + factorial_range(base_parsed - numbers_parsed + 1, base_parsed) + .and_then(|value| i64::try_from(value).ok()) + } else { + // By convention + Some(0) + }; + + match result { + None => bail!(base.span, "the permutation result is too large"), + Some(s) => Value::Int(s), + } +} + +/// Calculate a binomial coefficient. +/// +/// ## Example +/// ```example +/// #calc.binom(10,5) +/// ``` +/// +/// Display: Permutation +/// Category: calculate +/// Returns: integer +#[func] +pub fn binom( + /// The upper coefficient. Must be positive + n: Spanned, + /// The lower coefficient. Must be positive. + k: Spanned, +) -> Value { + let result = binomial(n.v, k.v).and_then(|raw| i64::try_from(raw).ok()); + + match result { + None => bail!(n.span, "the binomial result is too large"), + Some(r) => Value::Int(r), + } +} + +/// Calculates a binomial coefficient, with `n` the upper coefficient and `k` the lower coefficient. +/// Returns `None` if the result is larger than `u64::MAX` +fn binomial(n: u64, k: u64) -> Option { + if k > n { + return Some(0); + } + + // By symmetry + let real_k = cmp::min(n - k, k); + + if real_k == 0 { + return Some(1); + } + + let mut result: u64 = 1; + + for i in 0..real_k { + result = result.checked_mul(n - i).and_then(|r| r.checked_div(i + 1))?; + } + + Some(result) +} + /// Round a number down to the nearest integer. /// /// If the number is already an integer, it is returned unchanged. diff --git a/src/eval/cast.rs b/src/eval/cast.rs index 7b3066338..e2ae115fd 100644 --- a/src/eval/cast.rs +++ b/src/eval/cast.rs @@ -114,6 +114,21 @@ cast_to_value! { v: usize => Value::Int(v as i64) } +cast_from_value! { + u64, + int: i64 => int.try_into().map_err(|_| { + if int < 0 { + "number must be at least zero" + } else { + "number too large" + } + })?, +} + +cast_to_value! { + v: u64 => Value::Int(v as i64) +} + cast_from_value! { NonZeroUsize, int: i64 => int diff --git a/tests/typ/compute/calc.typ b/tests/typ/compute/calc.typ index 18c2e2c96..9d52355c8 100644 --- a/tests/typ/compute/calc.typ +++ b/tests/typ/compute/calc.typ @@ -114,6 +114,34 @@ // Error: 11-13 the result is not a real number #calc.log(10, base: -1) +--- +// Test the `fact` function. +#test(calc.fact(0), 1) +#test(calc.fact(5), 120) + +--- +// Error: 12-14 the factorial result is too large +#calc.fact(21) + +--- +// Test the `perm` function. +#test(calc.perm(0, 0), 1) +#test(calc.perm(5, 3), 60) +#test(calc.perm(5, 5), 120) +#test(calc.perm(5, 6), 0) + +--- +// Error: 12-14 the permutation result is too large +#calc.perm(21, 21) + +--- +// Test the `binom` function. +#test(calc.binom(0, 0), 1) +#test(calc.binom(5, 3), 10) +#test(calc.binom(5, 5), 1) +#test(calc.binom(5, 6), 0) +#test(calc.binom(6, 2), 15) + --- // Error: 10-12 expected at least one value #calc.min()