Skip to content

Commit

Permalink
optimize more partition patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
kaikalii committed Nov 25, 2024
1 parent de23572 commit 98f81c1
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 56 deletions.
106 changes: 71 additions & 35 deletions src/algorithm/loops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,12 +252,15 @@ pub fn do_(ops: Ops, env: &mut Uiua) -> UiuaResult {
Ok(())
}

pub fn split_by(scalar: bool, env: &mut Uiua) -> UiuaResult {
pub fn split_by(f: SigNode, scalar: bool, env: &mut Uiua) -> UiuaResult {
let delim = env.pop(1)?;
let haystack = env.pop(2)?;
if haystack.rank() > 1
if f.sig.args != 1
|| haystack.rank() > 1
|| delim.rank() > 1
|| scalar && !(delim.rank() == 0 || delim.rank() == 1 && delim.row_count() == 1)
|| matches!(delim, Value::Complex(_))
|| matches!(haystack, Value::Complex(_))
{
let mask = if scalar {
delim.is_ne(haystack.clone(), 0, 0, env)?
Expand All @@ -266,43 +269,79 @@ pub fn split_by(scalar: bool, env: &mut Uiua) -> UiuaResult {
};
env.push(haystack);
env.push(mask);
return partition(
SigNode {
node: Node::Prim(Primitive::Box, 0),
sig: Signature::new(1, 1),
return partition(f, env);
}
if let Some(Primitive::Box) = f.node.as_primitive() {
let val = haystack.generic_bin_ref(
&delim,
|a, b| a.split_by(b, |data| Boxed(data.into())),
|a, b| a.split_by(b, |data| Boxed(data.into())),
|_, _| unreachable!("split by complex"),
|a, b| a.split_by(b, |data| Boxed(data.into())),
|a, b| a.split_by(b, |data| Boxed(data.into())),
|a, b| {
env.error(format!(
"Cannot split {} by {}",
a.type_name_plural(),
b.type_name_plural()
))
},
)?;
env.push(val);
} else {
let parts = haystack.generic_bin_ref(
&delim,
|a, b| a.split_by(b, Value::from),
|a, b| a.split_by(b, Value::from),
|_, _| unreachable!("split by complex"),
|a, b| a.split_by(b, Value::from),
|a, b| a.split_by(b, Value::from),
|a, b| {
env.error(format!(
"Cannot split {} by {}",
a.type_name_plural(),
b.type_name_plural()
))
},
env,
);
)?;
if let Some(Primitive::Identity) = f.node.as_primitive() {
let val = Value::from_row_values(parts, env)?;
env.push(val);
} else {
let mut outputs = multi_output(f.sig.outputs, Vec::new());
env.without_fill(|env| -> UiuaResult {
for part in parts {
env.push(part);
env.exec(f.clone())?;
for i in 0..f.sig.outputs {
outputs[i].push(env.pop("split by output")?);
}
}
Ok(())
})?;
for outputs in outputs.into_iter().rev() {
let val = Value::from_row_values(outputs, env)?;
env.push(val);
}
}
}
let val = haystack.generic_bin_ref(
&delim,
|a, b| a.split_by(b, env),
|a, b| a.split_by(b, env),
|a, b| a.split_by(b, env),
|a, b| a.split_by(b, env),
|a, b| a.split_by(b, env),
|a, b| {
env.error(format!(
"Cannot split {} by {}",
a.type_name_plural(),
b.type_name_plural()
))
},
)?;
env.push(val);
Ok(())
}

impl<T: ArrayValue> Array<T>
where
Value: From<CowSlice<T>>,
{
fn split_by(&self, delim: &Self, _env: &Uiua) -> UiuaResult<Array<Boxed>> {
fn split_by<R: Clone>(
&self,
delim: &Self,
f: impl Fn(CowSlice<T>) -> R,
) -> UiuaResult<EcoVec<R>> {
let haystack = self.data.as_slice();
let delim_slice = delim.data.as_slice();
Ok(if delim.rank() == 0 || delim.row_count() == 1 {
let mut curr = 0;
let mut data = EcoVec::new();
let mut curr = 0;
let mut data = EcoVec::new();
if delim.rank() == 0 || delim.row_count() == 1 {
let delim = &delim_slice[0];
for slice in haystack.split(|elem| elem.array_eq(delim)) {
if slice.is_empty() {
Expand All @@ -311,13 +350,10 @@ where
}
let start = curr;
let end = start + slice.len();
data.push(Boxed(self.data.slice(start..end).into()));
data.push(f(self.data.slice(start..end)));
curr = end + 1;
}
data.into()
} else {
let mut curr = 0;
let mut data = EcoVec::new();
while curr < haystack.len() {
let prev_end = haystack[curr..]
.windows(delim_slice.len())
Expand All @@ -329,11 +365,11 @@ where
curr = next_start;
continue;
}
data.push(Boxed(self.data.slice(curr..prev_end).into()));
data.push(f(self.data.slice(curr..prev_end)));
curr = next_start;
}
data.into()
})
}
Ok(data)
}
}

Expand Down
47 changes: 30 additions & 17 deletions src/compile/optimize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,43 +221,56 @@ opt!(
struct SplitByOpt;
impl Optimization for SplitByOpt {
fn match_and_replace(&self, nodes: &mut EcoVec<Node>) -> bool {
fn is_par_box(node: &Node) -> bool {
fn par_f(node: &Node) -> Option<SigNode> {
let Mod(Partition, args, _) = node else {
return false;
return None;
};
let [f] = args.as_slice() else {
return false;
return None;
};
matches!(f.node, Prim(Box, _))
Some(f.clone())
}
for i in 0..nodes.len() {
match &nodes[i..] {
[Mod(By, args, span), last, ..]
if is_par_box(last)
&& matches!(args.as_slice(), [f]
if matches!(args.as_slice(), [f]
if matches!(f.node, Prim(Ne, _))) =>
{
replace_nodes(nodes, i, 2, ImplPrim(SplitByScalar, *span));
let Some(f) = par_f(last) else {
continue;
};
replace_nodes(nodes, i, 2, ImplMod(SplitByScalar, eco_vec![f], *span));
break;
}
[Mod(By, args, span), Prim(Not, _), last, ..]
if is_par_box(last)
&& matches!(args.as_slice(), [f]
if matches!(args.as_slice(), [f]
if matches!(f.node, Prim(Mask, _))) =>
{
replace_nodes(nodes, i, 3, ImplPrim(SplitBy, *span));
let Some(f) = par_f(last) else {
continue;
};
replace_nodes(nodes, i, 3, ImplMod(SplitBy, eco_vec![f], *span));
break;
}
[Prim(Dup, span), Push(delim), Prim(Ne, _), last, ..] if is_par_box(last) => {
let new =
Node::from_iter([Push(delim.clone()), ImplPrim(SplitByScalar, *span)]);
[Prim(Dup, span), Push(delim), Prim(Ne, _), last, ..] => {
let Some(f) = par_f(last) else {
continue;
};
let new = Node::from_iter([
Push(delim.clone()),
ImplMod(SplitByScalar, eco_vec![f], *span),
]);
replace_nodes(nodes, i, 4, new);
break;
}
[Prim(Dup, span), Push(delim), Prim(Mask, _), Prim(Not, _), last, ..]
if is_par_box(last) =>
{
let new = Node::from_iter([Push(delim.clone()), ImplPrim(SplitBy, *span)]);
[Prim(Dup, span), Push(delim), Prim(Mask, _), Prim(Not, _), last, ..] => {
let Some(f) = par_f(last) else {
continue;
};
let new = Node::from_iter([
Push(delim.clone()),
ImplMod(SplitBy, eco_vec![f], *span),
]);
replace_nodes(nodes, i, 5, new);
break;
}
Expand Down
4 changes: 2 additions & 2 deletions src/primitive/defs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3690,8 +3690,8 @@ impl_primitive!(
(1, CountUnique),
(1(2)[3], AstarFirst),
(1[3], AstarPop),
(2, SplitByScalar),
(2, SplitBy),
(2[1], SplitByScalar),
(2[1], SplitBy),
// Implementation details
(1[2], RepeatWithInverse),
(2(1), ValidateType),
Expand Down
10 changes: 8 additions & 2 deletions src/primitive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1514,8 +1514,6 @@ impl ImplPrimitive {
env.push(random());
}
ImplPrimitive::CountUnique => env.monadic_ref(Value::count_unique)?,
ImplPrimitive::SplitByScalar => loops::split_by(true, env)?,
ImplPrimitive::SplitBy => loops::split_by(false, env)?,
ImplPrimitive::MatchPattern => {
let expected = env.pop(1)?;
let got = env.pop(2)?;
Expand Down Expand Up @@ -1729,6 +1727,14 @@ impl ImplPrimitive {
env.exec(g.node)?;
env.push_all(f_outputs);
}
ImplPrimitive::SplitByScalar => {
let [f] = get_ops(ops, env)?;
loops::split_by(f, true, env)?;
}
ImplPrimitive::SplitBy => {
let [f] = get_ops(ops, env)?;
loops::split_by(f, false, env)?;
}
ImplPrimitive::EachSub(n)
| ImplPrimitive::RowsSub(n)
| ImplPrimitive::InventorySub(n) => {
Expand Down

0 comments on commit 98f81c1

Please sign in to comment.