diff --git a/library/src/math/matrix.rs b/library/src/math/matrix.rs index 72ef02531..9a0900bfb 100644 --- a/library/src/math/matrix.rs +++ b/library/src/math/matrix.rs @@ -1,13 +1,12 @@ use super::*; const ROW_GAP: Em = Em::new(0.5); +const COL_GAP: Em = Em::new(0.75); const VERTICAL_PADDING: Ratio = Ratio::new(0.1); /// # Vector /// A column vector. /// -/// _Note:_ Matrices are not yet supported. -/// /// ## Example /// ``` /// $ vec(a, b, c) dot vec(1, 2, 3) @@ -44,7 +43,72 @@ impl VecNode { impl LayoutMath for VecNode { fn layout_math(&self, ctx: &mut MathContext) -> SourceResult<()> { let delim = ctx.styles().get(Self::DELIM); - layout(ctx, &self.0, Align::Center, Some(delim.open()), Some(delim.close())) + let frame = layout_vec_body(ctx, &self.0, Align::Center)?; + layout_delimiters(ctx, frame, Some(delim.open()), Some(delim.close())) + } +} + +/// # Matrix +/// A matrix. +/// +/// ## Example +/// ``` +/// $ mat(1, 2; 3, 4) $ +/// ``` +/// +/// ## Parameters +/// - rows: Array (positional, variadic) +/// An array of arrays with the rows of the matrix. +/// +/// ## Category +/// math +#[func] +#[capable(LayoutMath)] +#[derive(Debug, Hash)] +pub struct MatNode(Vec>); + +#[node] +impl MatNode { + /// The delimiter to use. + /// + /// # Example + /// ``` + /// #set math.mat(delim: "[") + /// $ mat(1, 2; 3, 4) $ + /// ``` + pub const DELIM: Delimiter = Delimiter::Paren; + + fn construct(_: &Vm, args: &mut Args) -> SourceResult { + let mut rows = vec![]; + let mut width = 0; + + let values = args.all::>()?; + if values.iter().all(|spanned| matches!(spanned.v, Value::Content(_))) { + rows = vec![values.into_iter().map(|spanned| spanned.v.display()).collect()]; + } else { + for Spanned { v, span } in values { + let array = v.cast::().at(span)?; + let row: Vec<_> = array.into_iter().map(Value::display).collect(); + width = width.max(row.len()); + rows.push(row); + } + } + + for row in &mut rows { + if row.len() < width { + row.resize(width, Content::empty()); + } + } + + Ok(Self(rows).pack()) + } +} + +impl LayoutMath for MatNode { + fn layout_math(&self, ctx: &mut MathContext) -> SourceResult<()> { + let delim = ctx.styles().get(Self::DELIM); + let frame = layout_mat_body(ctx, &self.0)?; + layout_delimiters(ctx, frame, Some(delim.open()), Some(delim.close())) } } @@ -81,7 +145,8 @@ impl CasesNode { impl LayoutMath for CasesNode { fn layout_math(&self, ctx: &mut MathContext) -> SourceResult<()> { - layout(ctx, &self.0, Align::Left, Some('{'), None) + let frame = layout_vec_body(ctx, &self.0, Align::Left)?; + layout_delimiters(ctx, frame, Some('{'), None) } } @@ -133,29 +198,84 @@ castable! { "||" => Self::DoubleBar, } -/// Layout a matrix. -fn layout( +/// Layout the inner contents of a vector. +fn layout_vec_body( ctx: &mut MathContext, - elements: &[Content], + column: &[Content], align: Align, +) -> SourceResult { + let gap = ROW_GAP.scaled(ctx); + ctx.style(ctx.style.for_denominator()); + let mut flat = vec![]; + for element in column { + flat.push(ctx.layout_row(element)?); + } + ctx.unstyle(); + Ok(stack(ctx, flat, align, gap, 0)) +} + +/// Layout the inner contents of a matrix. +fn layout_mat_body(ctx: &mut MathContext, rows: &[Vec]) -> SourceResult { + let row_gap = ROW_GAP.scaled(ctx); + let col_gap = COL_GAP.scaled(ctx); + + let ncols = rows.first().map_or(0, |row| row.len()); + let nrows = rows.len(); + if ncols == 0 || nrows == 0 { + return Ok(Frame::new(Size::zero())); + } + + let mut rcols = vec![Abs::zero(); ncols]; + let mut rrows = vec![Abs::zero(); nrows]; + + ctx.style(ctx.style.for_denominator()); + let mut cols = vec![vec![]; ncols]; + for (row, rrow) in rows.iter().zip(&mut rrows) { + for ((cell, rcol), col) in row.iter().zip(&mut rcols).zip(&mut cols) { + let cell = ctx.layout_row(cell)?; + rcol.set_max(cell.width()); + rrow.set_max(cell.height()); + col.push(cell); + } + } + ctx.unstyle(); + + let width = rcols.iter().sum::() + col_gap * (ncols - 1) as f64; + let height = rrows.iter().sum::() + row_gap * (nrows - 1) as f64; + let size = Size::new(width, height); + + let mut frame = Frame::new(size); + let mut x = Abs::zero(); + for (col, &rcol) in cols.into_iter().zip(&rcols) { + let points = alignments(&col); + let mut y = Abs::zero(); + for (cell, &rrow) in col.into_iter().zip(&rrows) { + let cell = cell.to_aligned_frame(ctx, &points, Align::Center); + let pos = Point::new( + x + (rcol - cell.width()) / 2.0, + y + (rrow - cell.height()) / 2.0, + ); + frame.push_frame(pos, cell); + y += rrow + row_gap; + } + x += rcol + col_gap; + } + + Ok(frame) +} + +/// Layout the outer wrapper around a vector's or matrices' body. +fn layout_delimiters( + ctx: &mut MathContext, + mut frame: Frame, left: Option, right: Option, ) -> SourceResult<()> { let axis = scaled!(ctx, axis_height); - let gap = ROW_GAP.scaled(ctx); let short_fall = DELIM_SHORT_FALL.scaled(ctx); - - ctx.style(ctx.style.for_denominator()); - let mut rows = vec![]; - for element in elements { - rows.push(ctx.layout_row(element)?); - } - ctx.unstyle(); - - let mut frame = stack(ctx, rows, align, gap, 0); let height = frame.height(); let target = height + VERTICAL_PADDING.of(height); - frame.set_baseline(frame.height() / 2.0 + axis); + frame.set_baseline(height / 2.0 + axis); if let Some(left) = left { ctx.push(GlyphFragment::new(ctx, left).stretch_vertical(ctx, target, short_fall)); diff --git a/library/src/math/mod.rs b/library/src/math/mod.rs index ab67a0d3d..65baf7acd 100644 --- a/library/src/math/mod.rs +++ b/library/src/math/mod.rs @@ -74,6 +74,7 @@ pub fn module(sym: &Module) -> Module { math.def_func::("frac"); math.def_func::("binom"); math.def_func::("vec"); + math.def_func::("mat"); math.def_func::("cases"); // Roots. diff --git a/library/src/math/row.rs b/library/src/math/row.rs index f35d51c85..1bdfc1a3a 100644 --- a/library/src/math/row.rs +++ b/library/src/math/row.rs @@ -12,6 +12,11 @@ impl MathRow { self.0.iter().map(|fragment| fragment.width()).sum() } + pub fn height(&self) -> Abs { + let (ascent, descent) = self.extent(); + ascent + descent + } + pub fn push( &mut self, font_size: Abs, @@ -72,7 +77,16 @@ impl MathRow { self.0.push(fragment); } - pub fn to_frame(mut self, ctx: &MathContext) -> Frame { + pub fn to_frame(self, ctx: &MathContext) -> Frame { + self.to_aligned_frame(ctx, &[], Align::Center) + } + + pub fn to_aligned_frame( + mut self, + ctx: &MathContext, + points: &[Abs], + align: Align, + ) -> Frame { if self.0.iter().any(|frag| matches!(frag, MathFragment::Linebreak)) { let mut frame = Frame::new(Size::zero()); let fragments = std::mem::take(&mut self.0); @@ -86,7 +100,7 @@ impl MathRow { let points = alignments(&rows); for (i, row) in rows.into_iter().enumerate() { let size = frame.size_mut(); - let sub = row.to_line_frame(ctx, &points, Align::Center); + let sub = row.to_line_frame(ctx, &points, align); if i > 0 { size.y += leading; } @@ -97,14 +111,12 @@ impl MathRow { } frame } else { - self.to_line_frame(ctx, &[], Align::Center) + self.to_line_frame(ctx, points, align) } } - pub fn to_line_frame(self, ctx: &MathContext, points: &[Abs], align: Align) -> Frame { - let ascent = self.0.iter().map(MathFragment::ascent).max().unwrap_or_default(); - let descent = self.0.iter().map(MathFragment::descent).max().unwrap_or_default(); - + fn to_line_frame(self, ctx: &MathContext, points: &[Abs], align: Align) -> Frame { + let (ascent, descent) = self.extent(); let size = Size::new(Abs::zero(), ascent + descent); let mut frame = Frame::new(size); let mut x = Abs::zero(); @@ -140,6 +152,12 @@ impl MathRow { frame.size_mut().x = x; frame } + + fn extent(&self) -> (Abs, Abs) { + let ascent = self.0.iter().map(MathFragment::ascent).max().unwrap_or_default(); + let descent = self.0.iter().map(MathFragment::descent).max().unwrap_or_default(); + (ascent, descent) + } } impl From for MathRow diff --git a/library/src/math/stack.rs b/library/src/math/stack.rs index ec233cd91..3a47059c6 100644 --- a/library/src/math/stack.rs +++ b/library/src/math/stack.rs @@ -292,7 +292,7 @@ pub(super) fn stack( let points = alignments(&rows); let rows: Vec<_> = rows .into_iter() - .map(|row| row.to_line_frame(ctx, &points, align)) + .map(|row| row.to_aligned_frame(ctx, &points, align)) .collect(); for row in &rows { diff --git a/src/syntax/parser.rs b/src/syntax/parser.rs index 602d9f2cc..ad81cfa34 100644 --- a/src/syntax/parser.rs +++ b/src/syntax/parser.rs @@ -379,6 +379,8 @@ fn math_args(p: &mut Parser) { let mut namable = true; let mut named = None; + let mut has_arrays = false; + let mut array = p.marker(); let mut arg = p.marker(); while !p.eof() && !p.at(SyntaxKind::Dollar) { @@ -394,6 +396,17 @@ fn math_args(p: &mut Parser) { match p.current_text() { ")" => break, + ";" => { + maybe_wrap_in_math(p, arg, named); + p.wrap(array, SyntaxKind::Array); + p.convert(SyntaxKind::Semicolon); + array = p.marker(); + arg = p.marker(); + namable = true; + named = None; + has_arrays = true; + continue; + } "," => { maybe_wrap_in_math(p, arg, named); p.convert(SyntaxKind::Comma); @@ -418,6 +431,10 @@ fn math_args(p: &mut Parser) { maybe_wrap_in_math(p, arg, named); } + if has_arrays && array != p.marker() { + p.wrap(array, SyntaxKind::Array); + } + if p.at(SyntaxKind::Text) && p.current_text() == ")" { p.convert(SyntaxKind::RightParen); } else {