From eb6708105ff89b8f80cd05e7c501f3fab688f19d Mon Sep 17 00:00:00 2001 From: 0xb-s <145866191+0xb-s@users.noreply.github.com> Date: Wed, 18 Sep 2024 06:38:43 -0700 Subject: [PATCH 01/11] Update lib.rs --- src/lib.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index 33d4675..e548546 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,5 +3,9 @@ pub mod parser; pub mod format; pub mod simplify; + pub mod token; pub mod tokenizer; + +pub mod solver; +mod tests; From 363418b9d4b6f34d7e39eaa3a5740e7669bf0f1a Mon Sep 17 00:00:00 2001 From: 0xb-s <145866191+0xb-s@users.noreply.github.com> Date: Wed, 18 Sep 2024 06:39:29 -0700 Subject: [PATCH 02/11] Update ast.rs --- src/ast.rs | 39 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/src/ast.rs b/src/ast.rs index 040f2f3..58b4b6e 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -1,6 +1,7 @@ // src/ast.rs - -#[derive(Debug, Clone)] +use std::fmt; +use std::ops::Div; +#[derive(Debug, Clone, PartialEq)] pub enum Expr { Number(f64), Variable(String), @@ -8,4 +9,38 @@ pub enum Expr { Sub(Box, Box), Mul(Box, Box), Div(Box, Box), + Pow(Box, Box), + Undefined, +} + +impl Div for Expr { + type Output = Self; + + fn div(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (Expr::Mul(ll, lr), Expr::Mul(rl, rr)) if *ll == *rl => Expr::Div(lr, rr), + + (l, r) => Expr::Div(Box::new(l), Box::new(r)), + } + } +} +impl fmt::Display for Expr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Expr::Number(n) => { + if n.fract() == 0.0 { + write!(f, "{}", n.round()) + } else { + write!(f, "{}", n) + } + } + Expr::Variable(var) => write!(f, "{}", var), + Expr::Add(lhs, rhs) => write!(f, "({} + {})", lhs, rhs), + Expr::Sub(lhs, rhs) => write!(f, "({} - {})", lhs, rhs), + Expr::Mul(lhs, rhs) => write!(f, "({} * {})", lhs, rhs), + Expr::Div(lhs, rhs) => write!(f, "({} / {})", lhs, rhs), + Expr::Undefined => write!(f, "undefined"), + Expr::Pow(_, _) => todo!(), + } + } } From ddc156d02979de7f285d7bba026633d472850874 Mon Sep 17 00:00:00 2001 From: 0xb-s <145866191+0xb-s@users.noreply.github.com> Date: Wed, 18 Sep 2024 06:39:48 -0700 Subject: [PATCH 03/11] Update format.rs --- src/format.rs | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/src/format.rs b/src/format.rs index a73d7e2..5168d36 100644 --- a/src/format.rs +++ b/src/format.rs @@ -1,5 +1,3 @@ -// src/format.rs - use crate::ast::Expr; pub fn format_expr(expr: &Expr) -> String { @@ -26,14 +24,36 @@ pub fn format_expr(expr: &Expr) -> String { Expr::Div(lhs, rhs) => { format!("({})/({})", format_expr(lhs), format_expr(rhs)) } + Expr::Pow(lhs, rhs) => { + format!("{}^{}", format_factor(lhs), format_factor(rhs)) + } + Expr::Undefined => "undefined".to_string(), } } fn format_factor(expr: &Expr) -> String { match expr { - Expr::Add(_, _) | Expr::Sub(_, _) => { + Expr::Number(n) => { + if n.fract() == 0.0 { + format!("{}", n.round()) + } else { + format!("{}", n) + } + } + Expr::Variable(var) => var.clone(), + Expr::Mul(lhs, rhs) => { + let left = format_factor(lhs); + let right = format_factor(rhs); + format!("{}{}", left, right) + } + Expr::Div(lhs, rhs) => { + format!("({})/({})", format_expr(lhs), format_expr(rhs)) + } + Expr::Pow(lhs, rhs) => { + format!("({})^({})", format_expr(lhs), format_expr(rhs)) + } + Expr::Add(_, _) | Expr::Sub(_, _) | Expr::Undefined => { format!("({})", format_expr(expr)) } - _ => format_expr(expr), } } From 4afac0d7ef8b0c07c04fd63fc6c4c543deab7111 Mon Sep 17 00:00:00 2001 From: 0xb-s <145866191+0xb-s@users.noreply.github.com> Date: Wed, 18 Sep 2024 06:40:13 -0700 Subject: [PATCH 04/11] Update parser.rs --- src/parser.rs | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/src/parser.rs b/src/parser.rs index 8b00785..2c17fda 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -1,5 +1,3 @@ -// src/parser.rs - use crate::ast::Expr; use crate::token::Token; @@ -43,24 +41,24 @@ impl Parser { } fn parse_multiplication(&mut self) -> Result { - let mut expr = self.parse_factor()?; + let mut expr = self.parse_exponentiation()?; loop { if let Some(token) = self.current_token() { match token { Token::Star => { self.advance(); - let rhs = self.parse_factor()?; + let rhs = self.parse_exponentiation()?; expr = Expr::Mul(Box::new(expr), Box::new(rhs)); } Token::Slash => { self.advance(); - let rhs = self.parse_factor()?; + let rhs = self.parse_exponentiation()?; expr = Expr::Div(Box::new(expr), Box::new(rhs)); } - // Handle implicit multiplication + // Handle implicit multiplication (e.g., 2x, (a)(b)) Token::Number(_) | Token::Variable(_) | Token::LParen => { - let rhs = self.parse_factor()?; + let rhs = self.parse_exponentiation()?; expr = Expr::Mul(Box::new(expr), Box::new(rhs)); } _ => break, @@ -73,6 +71,21 @@ impl Parser { Ok(expr) } + fn parse_exponentiation(&mut self) -> Result { + let expr = self.parse_factor()?; + self.parse_exponentiation_rhs(expr) + } + + fn parse_exponentiation_rhs(&mut self, left: Expr) -> Result { + if let Some(Token::Pow) = self.current_token() { + self.advance(); + let right = self.parse_exponentiation()?; + Ok(Expr::Pow(Box::new(left), Box::new(right))) + } else { + Ok(left) + } + } + fn parse_factor(&mut self) -> Result { if let Some(token) = self.current_token() { match token { From b187ad0b342f0e8628002c801f9bd5343f581947 Mon Sep 17 00:00:00 2001 From: 0xb-s <145866191+0xb-s@users.noreply.github.com> Date: Wed, 18 Sep 2024 06:42:52 -0700 Subject: [PATCH 05/11] Update simplify.rs --- src/simplify.rs | 583 +++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 526 insertions(+), 57 deletions(-) diff --git a/src/simplify.rs b/src/simplify.rs index b816796..be465c8 100644 --- a/src/simplify.rs +++ b/src/simplify.rs @@ -1,120 +1,589 @@ // src/simplify.rs use crate::ast::Expr; +use crate::solver::Solver; use std::collections::HashMap; +/// Type alias for variable exponents. +/// Uses `i32` to represent integer exponents. +pub type VarExponents = HashMap; +/// Builds an expression from a map of variables to their exponents. +/// Positive exponents are placed in the numerator, negative in the denominator. +fn build_expr_from_var_map(var_map: &VarExponents) -> Expr { + let mut numerator_terms = Vec::new(); + let mut denominator_terms = Vec::new(); + for (var, exp) in var_map { + if *exp > 0 { + for _ in 0..*exp { + numerator_terms.push(Expr::Variable(var.clone())); + } + } else if *exp < 0 { + for _ in 0..(-*exp) { + denominator_terms.push(Expr::Variable(var.clone())); + } + } + } + + // Build numerator expression + let numerator = if numerator_terms.is_empty() { + Expr::Number(1.0) + } else { + let mut expr = numerator_terms.remove(0); + for term in &numerator_terms { + expr = Expr::Mul(Box::new(expr), Box::new(term.clone())); + } + expr + }; + + // Build denominator expression + let denominator = if denominator_terms.is_empty() { + Expr::Number(1.0) + } else { + let mut expr = denominator_terms.remove(0); + for term in &denominator_terms { + expr = Expr::Mul(Box::new(expr), Box::new(term.clone())); + } + expr + }; + + // Construct the final expression + if denominator_terms.is_empty() && numerator_terms.is_empty() && numerator == Expr::Number(1.0) + { + Expr::Number(1.0) + } else if denominator_terms.is_empty() { + numerator + } else if numerator == Expr::Number(1.0) { + denominator + } else { + Expr::Div(Box::new(numerator), Box::new(denominator)) + } +} + +/// Simplifies a given mathematical expression by expanding, collecting like terms, +/// and reducing coefficients iteratively until no further simplifications can be made. pub fn simplify(expr: Expr) -> Expr { - let expanded_expr = expand_expr(&expr); - collect_like_terms(&expanded_expr) + let mut current = expr; + let max_iterations = 100; // Prevent infinite loops TODO change this + let mut iterations = 0; + + while iterations < max_iterations { + let expanded = expand_expr(¤t); + let collected = collect_like_terms(&expanded); + + if collected == current { + break; + } + + current = collected; + iterations += 1; + } + + current } +/// Expands the expression by applying distributive laws where applicable. fn expand_expr(expr: &Expr) -> Expr { match expr { - Expr::Add(lhs, rhs) => { - let left = expand_expr(lhs); - let right = expand_expr(rhs); - Expr::Add(Box::new(left), Box::new(right)) - } - Expr::Sub(lhs, rhs) => { - let left = expand_expr(lhs); - let right = expand_expr(rhs); - Expr::Sub(Box::new(left), Box::new(right)) + Expr::Add(lhs, rhs) => Expr::Add(Box::new(expand_expr(lhs)), Box::new(expand_expr(rhs))), + Expr::Pow(base, exponent) => { + let expanded_base = expand_expr(base); + let expanded_exponent = expand_expr(exponent); + Expr::Pow(Box::new(expanded_base), Box::new(expanded_exponent)) } + Expr::Sub(lhs, rhs) => Expr::Sub(Box::new(expand_expr(lhs)), Box::new(expand_expr(rhs))), Expr::Mul(lhs, rhs) => { let left = expand_expr(lhs); let right = expand_expr(rhs); - match (left, right) { - // Distribute multiplication over addition + match (&left, &right) { + // Distribute multiplication over addition: a * (b + c) = a*b + a*c (Expr::Add(a, b), c) => { - let left_mul = expand_expr(&Expr::Mul(a, Box::new(c.clone()))); - let right_mul = expand_expr(&Expr::Mul(b, Box::new(c))); - Expr::Add(Box::new(left_mul), Box::new(right_mul)) + let left_mul = Expr::Mul(a.clone(), Box::new(c.clone())); + let right_mul = Expr::Mul(b.clone(), Box::new(c.clone())); + Expr::Add( + Box::new(expand_expr(&left_mul)), + Box::new(expand_expr(&right_mul)), + ) + } + // Distribute multiplication over subtraction: a * (b - c) = a*b - a*c + (Expr::Sub(a, b), c) => { + let left_mul = Expr::Mul(a.clone(), Box::new(c.clone())); + let right_mul = Expr::Mul(b.clone(), Box::new(c.clone())); + Expr::Sub( + Box::new(expand_expr(&left_mul)), + Box::new(expand_expr(&right_mul)), + ) } (a, Expr::Add(b, c)) => { - let left_mul = expand_expr(&Expr::Mul(Box::new(a.clone()), b)); - let right_mul = expand_expr(&Expr::Mul(Box::new(a), c)); - Expr::Add(Box::new(left_mul), Box::new(right_mul)) + let left_mul = Expr::Mul(Box::new(a.clone()), b.clone()); + let right_mul = Expr::Mul(Box::new(a.clone()), c.clone()); + Expr::Add( + Box::new(expand_expr(&left_mul)), + Box::new(expand_expr(&right_mul)), + ) } - (a, b) => Expr::Mul(Box::new(a), Box::new(b)), + (a, Expr::Sub(b, c)) => { + let left_mul = Expr::Mul(Box::new(a.clone()), b.clone()); + let right_mul = Expr::Mul(Box::new(a.clone()), c.clone()); + Expr::Sub( + Box::new(expand_expr(&left_mul)), + Box::new(expand_expr(&right_mul)), + ) + } + + _ => Expr::Mul(Box::new(left), Box::new(right)), } } + // Expr::Div(lhs, rhs) => { + // match **lhs { + // // Distribute division over addition and subtraction + // Expr::Add(ref left, ref right) => Expr::Add( + // Box::new(Expr::Div( + // Box::new(expand_expr(left)), + // Box::new(expand_expr(rhs).clone()), + // )), + // Box::new(Expr::Div( + // Box::new(expand_expr(right)), + // Box::new(expand_expr(rhs).clone()), + // )), + // ), + // Expr::Sub(ref left, ref right) => Expr::Sub( + // Box::new(Expr::Div( + // Box::new(expand_expr(left)), + // Box::new(expand_expr(rhs).clone()), + // )), + // Box::new(Expr::Div( + // Box::new(expand_expr(right)), + // Box::new(expand_expr(rhs).clone()), + // )), + // ), + // // For multiplication and other operations, proceed as before + // _ => Expr::Div(Box::new(expand_expr(lhs)), Box::new(expand_expr(rhs))), + // } + // } + Expr::Div(lhs, rhs) => { + // Ensure the expansion covers the distribution over multiplication + let left_expanded = expand_expr(lhs); + let right_expanded = expand_expr(rhs); + match (left_expanded, right_expanded) { + // handle divisions where both sides involve multiplication + (Expr::Mul(l_lhs, l_rhs), Expr::Mul(r_lhs, r_rhs)) => { + // Compare and simplify directly if variables and coefficients allow + // TODO make this make + if l_lhs == r_lhs { + Expr::Div(l_rhs, r_rhs) + } else if l_rhs == r_rhs { + Expr::Div(l_lhs, r_lhs) + } else { + // Default back to division if not directly simplifiable + // better way TODO + Expr::Div( + Box::new(Expr::Mul(l_lhs, l_rhs)), + Box::new(Expr::Mul(r_lhs, r_rhs)), + ) + } + } + // Default division expansion + (l, r) => Expr::Div(Box::new(l), Box::new(r)), + } + } + // Expr::Div(lhs, rhs) => Expr::Div(Box::new(expand_expr(lhs)), Box::new(expand_expr(rhs))), _ => expr.clone(), } } + +/// Collects like terms in the expression and combines their coefficients. +/// This version now cancels common terms in division. fn collect_like_terms(expr: &Expr) -> Expr { - let mut terms = HashMap::new(); - collect_terms(expr, &mut terms, 1.0); + let mut terms_map: HashMap = HashMap::new(); + collect_terms(expr, &mut terms_map, 1.0, &mut HashMap::new()); - // Build the simplified expression from collected terms - let mut exprs = vec![]; + // Check for undefined expressions + if terms_map.contains_key("Undefined") { + return Expr::Undefined; + } - for (var, coeff) in terms { + // Separate positive and negative terms + let mut positive_terms = Vec::new(); + let mut negative_terms = Vec::new(); + + for (var_key, coeff) in terms_map { if coeff == 0.0 { continue; } - let term = if var.is_empty() { + if coeff > 0.0 { + positive_terms.push((var_key, coeff)); + } else { + negative_terms.push((var_key, -coeff)); + } + } + + // Build the expression + let mut exprs = Vec::new(); + + // Add positive terms + for (var_key, coeff) in positive_terms { + let term_expr = if var_key.is_empty() { Expr::Number(coeff) - } else if coeff == 1.0 { - Expr::Variable(var) } else { - Expr::Mul(Box::new(Expr::Number(coeff)), Box::new(Expr::Variable(var))) + // Parse the variable key back into variables and exponents + let factors = var_key.split('*').filter(|s| !s.is_empty()); + let mut expr_term: Option = None; + + for factor in factors { + let parts: Vec<&str> = factor.split('^').collect(); + let var = parts[0].to_string(); + let exp: i32 = parts.get(1).and_then(|e| e.parse().ok()).unwrap_or(1); + + let var_expr = Expr::Variable(var.clone()); + let powered_var = if exp == 1 { + var_expr + } else { + + let mut expr_power = Expr::Variable(var.clone()); + for _ in 1..exp { + expr_power = + Expr::Mul(Box::new(expr_power), Box::new(Expr::Variable(var.clone()))); + } + expr_power + }; + + expr_term = Some(match expr_term { + None => powered_var, + Some(existing) => Expr::Mul(Box::new(existing), Box::new(powered_var)), + }); + } + + + if coeff == 1.0 { + expr_term.unwrap() + } else { + Expr::Mul(Box::new(Expr::Number(coeff)), Box::new(expr_term.unwrap())) + } }; - exprs.push(term); + exprs.push(term_expr); } - // Combine terms into a single expression - if exprs.is_empty() { - Expr::Number(0.0) + // Start with the first positive term or zero + let mut result = if let Some(first) = exprs.pop() { + first } else { - let mut result = exprs.remove(0); - for e in exprs { - result = Expr::Add(Box::new(result), Box::new(e)); - } - result + Expr::Number(0.0) + }; + + // Add remaining positive terms + for term in exprs { + result = Expr::Add(Box::new(result), Box::new(term)); + } + + // Subtract negative terms + for (var_key, coeff) in negative_terms { + let term_expr = if var_key.is_empty() { + Expr::Number(coeff) + } else { + // Parse the variable key back into variables and exponents + let factors = var_key.split('*').filter(|s| !s.is_empty()); + let mut expr_term: Option = None; + + for factor in factors { + let parts: Vec<&str> = factor.split('^').collect(); + let var = parts[0].to_string(); + let exp: i32 = parts.get(1).and_then(|e| e.parse().ok()).unwrap_or(1); + + let var_expr = Expr::Variable(var.clone()); + let powered_var = if exp == 1 { + var_expr + } else { + // Handle exponents by repeated multiplication + let mut expr_power = Expr::Variable(var.clone()); + for _ in 1..exp { + expr_power = + Expr::Mul(Box::new(expr_power), Box::new(Expr::Variable(var.clone()))); + } + expr_power + }; + + expr_term = Some(match expr_term { + None => powered_var, + Some(existing) => Expr::Mul(Box::new(existing), Box::new(powered_var)), + }); + } + + + if coeff == 1.0 { + expr_term.unwrap() + } else { + Expr::Mul(Box::new(Expr::Number(coeff)), Box::new(expr_term.unwrap())) + } + }; + + result = Expr::Sub(Box::new(result), Box::new(term_expr)); } + + // Further simplify the combined expression + reduce_coefficients(result) +} + +/// Parses a serialized variable key into a `VarExponents` map. +/// Example: "x^2*y^-1" -> {"x": 2, "y": -1} +fn parse_var_key(var_key: &str) -> VarExponents { + let mut var_map = HashMap::new(); + if var_key.is_empty() { + return var_map; + } + for factor in var_key.split('*') { + let parts: Vec<&str> = factor.split('^').collect(); + let var = parts[0].to_string(); + let exp: i32 = parts.get(1).and_then(|e| e.parse().ok()).unwrap_or(1); + *var_map.entry(var).or_insert(0) += exp; + } + var_map } -fn collect_terms(expr: &Expr, terms: &mut HashMap, coeff: f64) { +/// Recursively collects terms from the expression and populates the `terms` map. +/// - `expr`: The current expression to process. +/// - `terms`: A map from variable keys to their coefficients. +/// - `coeff`: The current coefficient multiplier. +/// - `current_vars`: The current mapping of variables to their exponents. +fn collect_terms( + expr: &Expr, + terms: &mut HashMap, + coeff: f64, + _current_vars: &mut VarExponents, // todo: use this in future +) { match expr { + Expr::Pow(base, exponent) => { + if let Expr::Number(exp) = **exponent { + if exp.fract() != 0.0 || exp < 0.0 { + *terms.entry("Undefined".to_string()).or_insert(0.0) += + coeff * std::f64::INFINITY; + return; + } + + let exp_u32 = exp as u32; + let mut base_vars = VarExponents::new(); + // Use `traverse_expr_vars` instead of `traverse_expr` + if let Err(_) = Solver::traverse_expr_vars(base, 1.0, &mut base_vars) { + *terms.entry("Undefined".to_string()).or_insert(0.0) += + coeff * std::f64::INFINITY; + return; + } + + if base_vars.len() != 1 { + *terms.entry("Undefined".to_string()).or_insert(0.0) += + coeff * std::f64::INFINITY; + return; + } + + let (var, var_exp) = base_vars.iter().next().unwrap(); + let new_exp = var_exp * exp_u32 as i32; + let var_key = format!("{}^{}", var, new_exp); + + *terms.entry(var_key).or_insert(0.0) += coeff; + } else { + *terms.entry("Undefined".to_string()).or_insert(0.0) += coeff * std::f64::INFINITY; + } + } Expr::Add(lhs, rhs) => { - collect_terms(lhs, terms, coeff); - collect_terms(rhs, terms, coeff); + collect_terms(lhs, terms, coeff, _current_vars); + collect_terms(rhs, terms, coeff, _current_vars); } Expr::Sub(lhs, rhs) => { - collect_terms(lhs, terms, coeff); - collect_terms(rhs, terms, -coeff); + collect_terms(lhs, terms, coeff, _current_vars); + collect_terms(rhs, terms, -coeff, _current_vars); } Expr::Mul(lhs, rhs) => { - let mut new_coeff = coeff; - let mut vars = Vec::new(); - collect_factors(lhs, &mut new_coeff, &mut vars); - collect_factors(rhs, &mut new_coeff, &mut vars); - vars.sort(); // Ensure consistent ordering - let var_key = vars.join("*"); - *terms.entry(var_key).or_insert(0.0) += new_coeff; + let (lhs_coeff, lhs_vars) = extract_coeff_and_vars(lhs, 1.0); + let (rhs_coeff, rhs_vars) = extract_coeff_and_vars(rhs, 1.0); + let total_coeff = coeff * lhs_coeff * rhs_coeff; + + // Merge exponents + let mut merged_vars = lhs_vars.clone(); + for (var, exp) in rhs_vars.iter() { + *merged_vars.entry(var.clone()).or_insert(0) += exp; + } + + let var_key = serialize_vars(&merged_vars); + *terms.entry(var_key).or_insert(0.0) += total_coeff; + } + Expr::Div(lhs, rhs) => { + // Extract numerator coefficients and variables + let (numerator_coeff, numerator_vars) = extract_coeff_and_vars(lhs, 1.0); + // Extract denominator coefficients and variables + let (denominator_coeff, denominator_vars) = extract_coeff_and_vars(rhs, 1.0); + + // Prevent division by zero + if denominator_coeff == 0.0 { + // Represent division by zero as undefined + *terms.entry("Undefined".to_string()).or_insert(0.0) += coeff * std::f64::INFINITY; + return; + } + + // Compute the total coefficient + let total_coeff = coeff * numerator_coeff / denominator_coeff; + + // Adjust variables: numerator_vars - denominator_vars + let mut adjusted_vars = numerator_vars.clone(); + for (var, exp) in denominator_vars.iter() { + *adjusted_vars.entry(var.clone()).or_insert(0) -= exp; + } + + let var_key = serialize_vars(&adjusted_vars); + *terms.entry(var_key).or_insert(0.0) += total_coeff; } Expr::Number(n) => { *terms.entry(String::new()).or_insert(0.0) += coeff * n; } Expr::Variable(var) => { - *terms.entry(var.clone()).or_insert(0.0) += coeff; + let mut vars = HashMap::new(); + *vars.entry(var.clone()).or_insert(0) += 1; + let var_key = serialize_vars(&vars); + *terms.entry(var_key).or_insert(0.0) += coeff; + } + Expr::Undefined => { + // Represent undefined expressions + *terms.entry("Undefined".to_string()).or_insert(0.0) += coeff * std::f64::INFINITY; } - _ => {} } } -fn collect_factors(expr: &Expr, coeff: &mut f64, vars: &mut Vec) { +/// Extracts the coefficient and variables from an expression. +/// Returns a tuple containing the coefficient and a map of variables to their exponents. +fn extract_coeff_and_vars(expr: &Expr, current_coeff: f64) -> (f64, VarExponents) { + let mut coeff = current_coeff; + let mut vars = HashMap::new(); + match expr { Expr::Number(n) => { - *coeff *= *n; + coeff *= n; } Expr::Variable(var) => { - vars.push(var.clone()); + *vars.entry(var.clone()).or_insert(0) += 1; } Expr::Mul(lhs, rhs) => { - collect_factors(lhs, coeff, vars); - collect_factors(rhs, coeff, vars); + let (lhs_coeff, mut lhs_vars) = extract_coeff_and_vars(lhs, 1.0); + let (rhs_coeff, mut rhs_vars) = extract_coeff_and_vars(rhs, 1.0); + coeff *= lhs_coeff * rhs_coeff; + + // Merge exponents + for (var, exp) in rhs_vars.drain() { + *lhs_vars.entry(var.clone()).or_insert(0) += exp; + } + + vars = lhs_vars; + } + Expr::Div(lhs, rhs) => { + let (lhs_coeff, lhs_vars) = extract_coeff_and_vars(lhs, 1.0); + let (rhs_coeff, rhs_vars) = extract_coeff_and_vars(rhs, 1.0); + + if rhs_coeff == 0.0 { + // Represent division by zero as undefined + coeff *= std::f64::INFINITY; + } else { + coeff *= lhs_coeff / rhs_coeff; + + // Adjust exponents: lhs_vars - rhs_vars + for (var, exp) in rhs_vars.iter() { + *vars.entry(var.clone()).or_insert(0) -= exp; + } + } + } + Expr::Undefined => { + // Represent undefined expressions + coeff *= std::f64::INFINITY; } _ => {} } + + (coeff, vars) +} + +/// Serializes the variables and their exponents into a sorted, canonical string key. +/// Example: {"x": 2, "y": 1} -> "x^2*y" +fn serialize_vars(vars: &VarExponents) -> String { + let mut serialized = Vec::new(); + let mut sorted_vars: Vec<_> = vars.iter().collect(); + sorted_vars.sort_by(|a, b| a.0.cmp(b.0)); + + for (var, exp) in sorted_vars { + if *exp != 0 { + if *exp == 1 { + serialized.push(var.clone()); + } else { + serialized.push(format!("{}^{}", var, exp)); + } + } + } + + serialized.join("*") +} + +/// Reduces coefficients in the expression by combining numerical terms. +fn reduce_coefficients(expr: Expr) -> Expr { + match expr { + Expr::Add(lhs, rhs) => { + let left = reduce_coefficients(*lhs); + let right = reduce_coefficients(*rhs); + match (left.clone(), right.clone()) { + (Expr::Number(a), Expr::Number(b)) => Expr::Number(a + b), + (Expr::Number(a), _) if a == 0.0 => right, + (_, Expr::Number(b)) if b == 0.0 => left, + (Expr::Undefined, _) | (_, Expr::Undefined) => Expr::Undefined, + (a, b) => Expr::Add(Box::new(a), Box::new(b)), + } + } + Expr::Sub(lhs, rhs) => { + let left = reduce_coefficients(*lhs); + let right = reduce_coefficients(*rhs); + match (left, right.clone()) { + (Expr::Number(a), Expr::Number(b)) => Expr::Number(a - b), + (a, Expr::Number(b)) if b == 0.0 => a, + (Expr::Number(a), _) if a == 0.0 => { + Expr::Sub(Box::new(Expr::Number(0.0)), Box::new(right)) + } + (Expr::Undefined, _) | (_, Expr::Undefined) => Expr::Undefined, + (a, b) => Expr::Sub(Box::new(a), Box::new(b)), + } + } + Expr::Mul(lhs, rhs) => { + let left = reduce_coefficients(*lhs); + let right = reduce_coefficients(*rhs); + match (left, right) { + (Expr::Number(a), Expr::Number(b)) => Expr::Number(a * b), + (Expr::Number(0.0), _) | (_, Expr::Number(0.0)) => Expr::Number(0.0), + (Expr::Number(1.0), b) => b, + (a, Expr::Number(1.0)) => a, + (Expr::Undefined, _) | (_, Expr::Undefined) => Expr::Undefined, + (a, b) => Expr::Mul(Box::new(a), Box::new(b)), + } + } + Expr::Pow(base, exponent) => { + let reduced_base = reduce_coefficients(*base); + let reduced_exponent = reduce_coefficients(*exponent); + match (reduced_base, reduced_exponent) { + (Expr::Number(b), Expr::Number(e)) if e.fract() == 0.0 => Expr::Number(b.powf(e)), + (Expr::Undefined, _) | (_, Expr::Undefined) => Expr::Undefined, + (b, e) => Expr::Pow(Box::new(b), Box::new(e)), + } + } + Expr::Div(lhs, rhs) => { + let left = reduce_coefficients(*lhs); + let right = reduce_coefficients(*rhs); + match (left.clone(), right.clone()) { + (Expr::Number(a), Expr::Number(b)) if b != 0.0 => Expr::Number(a / b), + (Expr::Number(_), Expr::Number(b)) if b == 0.0 => Expr::Undefined, + (Expr::Mul(ll, lr), Expr::Mul(rl, rr)) if ll == rl => *lr / *rr, + (Expr::Mul(ll, lr), Expr::Mul(rl, rr)) if lr == rr => *ll / *rl, + (l, Expr::Number(r)) if r != 0.0 => { + Expr::Mul(Box::new(l), Box::new(Expr::Number(1.0 / r))) + } + + (a, Expr::Number(1.0)) => a, + (Expr::Undefined, _) | (_, Expr::Undefined) => Expr::Undefined, + (a, b) => Expr::Div(Box::new(a), Box::new(b)), + _ => Expr::Div(Box::new(left), Box::new(right)), + } + } + Expr::Undefined => Expr::Undefined, + _ => expr, + } } From 40ced708d51ba5f276b552187f561c4e9a0eca16 Mon Sep 17 00:00:00 2001 From: 0xb-s <145866191+0xb-s@users.noreply.github.com> Date: Wed, 18 Sep 2024 06:43:16 -0700 Subject: [PATCH 06/11] Create solver.rs --- src/solver.rs | 655 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 655 insertions(+) create mode 100644 src/solver.rs diff --git a/src/solver.rs b/src/solver.rs new file mode 100644 index 0000000..f1dc872 --- /dev/null +++ b/src/solver.rs @@ -0,0 +1,655 @@ +// src/solver.rs + +use crate::ast::Expr; +use crate::format::format_expr; +use crate::parser::Parser; +use crate::simplify::simplify; +use crate::simplify::VarExponents; +use crate::tokenizer::tokenize; +use std::collections::HashMap; +use std::fmt; +/// Represents a mathematical equation with left-hand side (lhs) and right-hand side (rhs). +#[derive(Debug, Clone, PartialEq)] +pub struct Equation { + pub lhs: Expr, + pub rhs: Expr, +} + +impl Equation { + /// Creates a new Equation. + pub fn new(lhs: Expr, rhs: Expr) -> Self { + Equation { lhs, rhs } + } +} + +impl fmt::Display for Equation { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{} = {}", self.lhs, self.rhs) + } +} + +/// Represents the solution to an equation. +#[derive(Debug, Clone, PartialEq)] +pub enum Solution { + NoSolution, + InfiniteSolutions, + SingleSolution(f64), + MultipleSolutions(Vec), + Undefined, +} + +impl fmt::Display for Solution { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Solution::NoSolution => write!(f, "No solution"), + Solution::InfiniteSolutions => write!(f, "Infinite solutions"), + Solution::SingleSolution(val) => write!(f, "x = {}", val), + Solution::MultipleSolutions(vals) => { + write!(f, "Solutions: {:?}", vals) + } + Solution::Undefined => write!(f, "Undefined"), + } + } +} + +/// Represents a polynomial with coefficients mapped to their corresponding exponents. +#[derive(Debug, Clone, PartialEq)] +pub struct Polynomial { + /// Maps exponents to their coefficients. + /// Example: {0: 2.0, 1: -3.0, 2: 1.0} represents 1x² - 3x + 2 + pub terms: HashMap, +} + +impl Polynomial { + /// Creates a new, empty Polynomial. + pub fn new() -> Self { + Polynomial { + terms: HashMap::new(), + } + } + + /// Adds a term to the polynomial. + pub fn add_term(&mut self, exponent: u32, coefficient: f64) { + *self.terms.entry(exponent).or_insert(0.0) += coefficient; + // Remove the term if the coefficient becomes zero + if let Some(coeff) = self.terms.get(&exponent) { + if *coeff == 0.0 { + self.terms.remove(&exponent); + } + } + } + + /// Retrieves the coefficient for a given exponent. + pub fn get_coefficient(&self, exponent: u32) -> f64 { + *self.terms.get(&exponent).unwrap_or(&0.0) + } + + /// Determines the degree of the polynomial. + pub fn degree(&self) -> u32 { + self.terms.keys().cloned().max().unwrap_or(0) + } + + /// Normalizes the polynomial by removing zero coefficients. + pub fn normalize(&mut self) { + self.terms.retain(|_, &mut coeff| coeff != 0.0); + } +} + +/// Errors that can occur during polynomial parsing and solving. +#[derive(Debug, Clone)] +pub enum SolverError { + InvalidEquation, + NonPolynomial, + UnsupportedDegree(u32), + DivisionByZero, + Undefined, + ParsingError(String), +} + +impl fmt::Display for SolverError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SolverError::InvalidEquation => write!(f, "Invalid equation format."), + SolverError::NonPolynomial => write!(f, "The equation is not a polynomial."), + SolverError::UnsupportedDegree(degree) => { + write!(f, "Unsupported polynomial degree: {}", degree) + } + SolverError::DivisionByZero => write!(f, "Division by zero encountered."), + SolverError::Undefined => write!(f, "The equation is undefined."), + SolverError::ParsingError(msg) => write!(f, "Parsing error: {}", msg), + } + } +} + +impl std::error::Error for SolverError {} + +/// Represents the type of polynomial equation. +#[derive(Debug, Clone, PartialEq)] +pub enum PolynomialType { + Linear, // Degree 1 + Quadratic, // Degree 2 + Cubic, // Degree 3 +} + +/// Main Solver struct employing the Builder Pattern. +pub struct Solver { + equation: Equation, +} + +impl Solver { + /// Creates a new Solver with the given Equation. + pub fn builder() -> SolverBuilder { + SolverBuilder::new() + } + + /// Solves the equation and returns the solution. + pub fn solve(&self) -> Result { + // Step 1: Parse the equation into a Polynomial (lhs - rhs = 0) + let mut polynomial = Self::parse_equation(&self.equation)?; + + // Step 2: Normalize the polynomial + polynomial.normalize(); + + // Step 3: Determine the degree + let degree = polynomial.degree(); + + // Step 4: Solve based on degree + match degree { + 0 => { + // Constant equation: c = 0 + let c = polynomial.get_coefficient(0); + if c == 0.0 { + Ok(Solution::InfiniteSolutions) + } else { + Ok(Solution::NoSolution) + } + } + 1 => { + // Linear equation: ax + b = 0 + Self::solve_linear(&polynomial) + } + 2 => { + // Quadratic equation: ax² + bx + c = 0 + Self::solve_quadratic(&polynomial) + } + 3 => { + // Cubic equation: ax³ + bx² + cx + d = 0 + Self::solve_cubic(&polynomial) + } + _ => Err(SolverError::UnsupportedDegree(degree)), + } + } + + /// Parses the Equation into a Polynomial by subtracting rhs from lhs. + fn parse_equation(equation: &Equation) -> Result { + let mut polynomial = Polynomial::new(); + + // Parse lhs and add terms + let lhs_terms = Self::collect_terms(&equation.lhs)?; + for (exp, coeff) in lhs_terms { + polynomial.add_term(exp, coeff); + } + + // Parse rhs and subtract terms + let rhs_terms = Self::collect_terms(&equation.rhs)?; + for (exp, coeff) in rhs_terms { + polynomial.add_term(exp, -coeff); + } + + // Normalize the polynomial + polynomial.normalize(); + + Ok(polynomial) + } + + /// Collects terms from an Expr and returns a HashMap of exponents to coefficients. + fn collect_terms(expr: &Expr) -> Result, SolverError> { + let mut terms: HashMap = HashMap::new(); + Self::traverse_expr(expr, 1.0, &mut terms)?; + Ok(terms) // Return `terms` of type `HashMap` + } + + pub fn traverse_expr( + expr: &Expr, + multiplier: f64, + terms: &mut HashMap, + ) -> Result<(), SolverError> { + match expr { + Expr::Pow(base, exponent) => { + // Ensure the exponent is a number + if let Expr::Number(exp) = **exponent { + if exp.fract() != 0.0 || exp < 0.0 { + return Err(SolverError::NonPolynomial); + } + + let exp_u32 = exp as u32; + + // Traverse the base expression with the updated multiplier + let mut base_terms = HashMap::new(); + Self::traverse_expr(base, multiplier, &mut base_terms)?; + + // Multiply the exponents by exp_u32 + for (existing_exp, coeff) in base_terms { + let new_exp = existing_exp * exp_u32; + *terms.entry(new_exp).or_insert(0.0) += coeff; + } + } else { + return Err(SolverError::NonPolynomial); + } + } + Expr::Add(lhs, rhs) => { + Self::traverse_expr(lhs, multiplier, terms)?; + Self::traverse_expr(rhs, multiplier, terms)?; + } + Expr::Sub(lhs, rhs) => { + Self::traverse_expr(lhs, multiplier, terms)?; + Self::traverse_expr(rhs, -multiplier, terms)?; + } + Expr::Mul(lhs, rhs) => { + let mut lhs_terms = HashMap::new(); + let mut rhs_terms = HashMap::new(); + Self::traverse_expr(lhs, multiplier, &mut lhs_terms)?; + Self::traverse_expr(rhs, multiplier, &mut rhs_terms)?; + + let mut product_terms = HashMap::new(); + for (exp1, coeff1) in lhs_terms { + for (exp2, coeff2) in &rhs_terms { + let combined_exp = exp1 + *exp2; + *product_terms.entry(combined_exp).or_insert(0.0) += coeff1 * coeff2; + } + } + + for (exp, coeff) in product_terms { + *terms.entry(exp).or_insert(0.0) += coeff; + } + } + Expr::Div(lhs, rhs) => { + return Err(SolverError::ParsingError("Unsupported".to_string())); + } + Expr::Variable(_) => { + *terms.entry(1).or_insert(0.0) += multiplier; + } + Expr::Number(n) => { + // Constant term with exponent 0 + *terms.entry(0).or_insert(0.0) += multiplier * n; + } + _ => { + // Handle other expression types or return an error + return Err(SolverError::NonPolynomial); + } + } + Ok(()) + } + pub fn traverse_expr_vars( + expr: &Expr, + multiplier: f64, + vars: &mut VarExponents, + ) -> Result<(), SolverError> { + match expr { + Expr::Pow(base, exponent) => { + if let Expr::Number(exp) = **exponent { + if exp.fract() != 0.0 || exp < 0.0 { + return Err(SolverError::NonPolynomial); + } + + let exp_u32 = exp as u32; + + let mut base_vars = VarExponents::new(); + Self::traverse_expr_vars(base, multiplier, &mut base_vars)?; + + // Ensure only one variable is present + if base_vars.len() != 1 { + return Err(SolverError::NonPolynomial); + } + + let (var, var_exp) = base_vars.iter().next().unwrap(); + let new_exp = var_exp * exp_u32 as i32; + let var_key = format!("{}^{}", var, new_exp); + + *vars.entry(var_key).or_insert(0) += 1; + } else { + return Err(SolverError::NonPolynomial); + } + } + Expr::Add(lhs, rhs) => { + Self::traverse_expr_vars(lhs, multiplier, vars)?; + Self::traverse_expr_vars(rhs, multiplier, vars)?; + } + Expr::Sub(lhs, rhs) => { + Self::traverse_expr_vars(lhs, multiplier, vars)?; + Self::traverse_expr_vars(rhs, -multiplier, vars)?; + } + Expr::Mul(lhs, rhs) => { + // For simplicity, handle multiplication by constants only + // More complex handling (e.g., multiple variables) can be implemented as needed + if let Expr::Number(n) = **lhs { + Self::traverse_expr_vars(rhs, multiplier * n, vars)?; + } else if let Expr::Number(n) = **rhs { + Self::traverse_expr_vars(lhs, multiplier * n, vars)?; + } else { + return Err(SolverError::NonPolynomial); + } + } + Expr::Div(lhs, rhs) => { + // Handle division by constants only + if let Expr::Number(n) = **rhs { + Self::traverse_expr_vars(lhs, multiplier / n, vars)?; + } else { + return Err(SolverError::NonPolynomial); + } + } + Expr::Variable(var) => { + *vars.entry(var.clone()).or_insert(0) += multiplier as i32; + } + Expr::Number(n) => { + // Constants can be treated as variables with exponent 0 + *vars.entry("constant".to_string()).or_insert(0) += 1; + } + _ => { + // Handle other expression types or return an error + return Err(SolverError::NonPolynomial); + } + } + Ok(()) + } + + /// traverse expressions for solver (collecting polynomial terms) + pub fn traverse_expr_solver( + expr: &Expr, + multiplier: f64, + terms: &mut HashMap, + ) -> Result<(), SolverError> { + match expr { + Expr::Pow(base, exponent) => { + // Handle exponents correctly here + if let Expr::Number(exp) = **exponent { + if exp.fract() != 0.0 || exp < 0.0 { + return Err(SolverError::NonPolynomial); + } + + let exp_u32 = exp as u32; + + // Traverse the base expression + let mut base_terms = HashMap::new(); + Self::traverse_expr_solver(base, multiplier, &mut base_terms)?; + + // Multiply the exponents by exp_u32 + for (existing_exp, coeff) in base_terms { + let new_exp = existing_exp * exp_u32; + *terms.entry(new_exp).or_insert(0.0) += coeff; + } + } else { + return Err(SolverError::NonPolynomial); + } + } + Expr::Add(lhs, rhs) => { + Self::traverse_expr_solver(lhs, multiplier, terms)?; + Self::traverse_expr_solver(rhs, multiplier, terms)?; + } + Expr::Sub(lhs, rhs) => { + Self::traverse_expr_solver(lhs, multiplier, terms)?; + Self::traverse_expr_solver(rhs, -multiplier, terms)?; + } + Expr::Mul(lhs, rhs) => { + let mut lhs_terms = HashMap::new(); + let mut rhs_terms = HashMap::new(); + Self::traverse_expr_solver(lhs, multiplier, &mut lhs_terms)?; + Self::traverse_expr_solver(rhs, multiplier, &mut rhs_terms)?; + + let mut product_terms = HashMap::new(); + for (exp1, coeff1) in lhs_terms { + for (exp2, coeff2) in &rhs_terms { + let combined_exp = exp1 + exp2; + *product_terms.entry(combined_exp).or_insert(0.0) += coeff1 * coeff2; + } + } + + for (exp, coeff) in product_terms { + *terms.entry(exp).or_insert(0.0) += coeff; + } + } + Expr::Div(lhs, rhs) => { + // Handle division by constants only + if let Expr::Number(n) = **rhs { + Self::traverse_expr_solver(lhs, multiplier / n, terms)?; + } else { + return Err(SolverError::NonPolynomial); + } + } + Expr::Variable(_) => { + // Assume single variable 'x' with exponent 1 + *terms.entry(1).or_insert(0.0) += multiplier; + } + Expr::Number(n) => { + // Constant term with exponent 0 + *terms.entry(0).or_insert(0.0) += multiplier * n; + } + _ => { + // Handle other expression types or return an error + return Err(SolverError::NonPolynomial); + } + } + Ok(()) + } + + /// Helper function to collect terms from a sub-expression. + fn _collect_terms_expr(expr: &Expr) -> Result, SolverError> { + let mut terms = HashMap::new(); + Self::traverse_expr(expr, 1.0, &mut terms)?; + Ok(terms) + } + + /// Solves a linear equation ax + b = 0. + fn solve_linear(polynomial: &Polynomial) -> Result { + let a = polynomial.get_coefficient(1); + let b = polynomial.get_coefficient(0); + + if a == 0.0 { + if b == 0.0 { + Ok(Solution::InfiniteSolutions) + } else { + Ok(Solution::NoSolution) + } + } else { + let x = -b / a; + Ok(Solution::SingleSolution(x)) + } + } + + /// Solves a quadratic equation ax² + bx + c = 0 using the discriminant. + fn solve_quadratic(polynomial: &Polynomial) -> Result { + let a = polynomial.get_coefficient(2); + let b = polynomial.get_coefficient(1); + let c = polynomial.get_coefficient(0); + + if a == 0.0 { + // Degenerates to linear equation + let linear_poly = Polynomial { + terms: { + let mut terms = HashMap::new(); + terms.insert(1, b); + terms.insert(0, c); + terms + }, + }; + return Self::solve_linear(&linear_poly); + } + + let discriminant = b * b - 4.0 * a * c; + + if discriminant > 0.0 { + let sqrt_d = discriminant.sqrt(); + let x1 = (-b + sqrt_d) / (2.0 * a); + let x2 = (-b - sqrt_d) / (2.0 * a); + Ok(Solution::MultipleSolutions(vec![x1, x2])) + } else if discriminant == 0.0 { + let x = -b / (2.0 * a); + Ok(Solution::SingleSolution(x)) + } else { + // TODO: complex solution + Ok(Solution::Undefined) + } + } + + /// Solves a cubic equation ax³ + bx² + cx + d = 0 using Cardano's method. + fn solve_cubic(polynomial: &Polynomial) -> Result { + let a = polynomial.get_coefficient(3); + let b = polynomial.get_coefficient(2); + let c = polynomial.get_coefficient(1); + let d = polynomial.get_coefficient(0); + + if a == 0.0 { + // Degenerates to quadratic equation + let quadratic_poly = Polynomial { + terms: { + let mut terms = HashMap::new(); + terms.insert(2, b); + terms.insert(1, c); + terms.insert(0, d); + terms + }, + }; + return Self::solve_quadratic(&quadratic_poly); + } + + // Normalize coefficients + let b_norm = b / a; + let c_norm = c / a; + let d_norm = d / a; + + // Depressed cubic: t^3 + pt + q = 0 + let p = c_norm - b_norm * b_norm / 3.0; + let q = 2.0 * b_norm.powi(3) / 27.0 - b_norm * c_norm / 3.0 + d_norm; + + let discriminant = (q / 2.0).powi(2) + (p / 3.0).powi(3); + + if discriminant > 0.0 { + // One real root + let sqrt_d = discriminant.sqrt(); + let u = ((-q / 2.0) + sqrt_d).cbrt(); + let v = ((-q / 2.0) - sqrt_d).cbrt(); + let t = u + v; + let x = t - b_norm / 3.0; + Ok(Solution::SingleSolution(x)) + } else if discriminant == 0.0 { + // All roots real, at least two equal + let u = (-q / 2.0).cbrt(); + let t1 = 2.0 * u; + let t2 = -u; + let x1 = t1 - b_norm / 3.0; + let x2 = t2 - b_norm / 3.0; + if t1 == t2 { + Ok(Solution::SingleSolution(x1)) + } else { + Ok(Solution::MultipleSolutions(vec![x1, x2])) + } + } else { + // Three distinct real roots + let phi = ((-q / 2.0) / ((-p / 3.0).powf(1.5))).acos(); + let t1 = 2.0 * ((-p / 3.0).sqrt()) * phi.cos(); + let t2 = 2.0 * ((-p / 3.0).sqrt()) * (phi + 2.0 * std::f64::consts::PI / 3.0).cos(); + let t3 = 2.0 * ((-p / 3.0).sqrt()) * (phi + 4.0 * std::f64::consts::PI / 3.0).cos(); + let x1 = t1 - b_norm / 3.0; + let x2 = t2 - b_norm / 3.0; + let x3 = t3 - b_norm / 3.0; + Ok(Solution::MultipleSolutions(vec![x1, x2, x3])) + } + } +} + +/// Builder for the Solver struct. +pub struct SolverBuilder { + equation: Option, +} + +impl SolverBuilder { + /// Creates a new SolverBuilder instance. + pub fn new() -> Self { + SolverBuilder { equation: None } + } + + /// Sets the equation to be solved. + pub fn equation(mut self, lhs: Expr, rhs: Expr) -> Self { + self.equation = Some(Equation::new(lhs, rhs)); + self + } + + /// Builds the Solver instance. + pub fn build(self) -> Result { + match self.equation { + Some(eq) => Ok(Solver { equation: eq }), + None => Err(SolverError::InvalidEquation), + } + } +} +/// Solves a mathematical equation given as a string. +/// +/// # Arguments +/// +/// * `equation_str` - A string slice that holds the equation, e.g., "2x + 3 = 7" +/// +/// # Returns +/// +/// * `Ok(Solution)` if the equation is solved successfully. +/// * `Err(SolverError)` if an error occurs during solving. +pub fn solve_equation(equation_str: &str) -> Result { + // Step 1: Split the equation string on '=' + let parts: Vec<&str> = equation_str.split('=').collect(); + if parts.len() != 2 { + return Err(SolverError::InvalidEquation); + } + + let lhs_str = parts[0].trim(); + let rhs_str = parts[1].trim(); + + // Step 2: Tokenize and parse the left-hand side (lhs) + let lhs_expr = + parse_expression(lhs_str).map_err(|e| SolverError::ParsingError(format!("LHS: {}", e)))?; + + // Step 3: Tokenize and parse the right-hand side (rhs) + let rhs_expr = + parse_expression(rhs_str).map_err(|e| SolverError::ParsingError(format!("RHS: {}", e)))?; + + // TODO: remove this later + println!("Parsed LHS Expression: {:?}", lhs_expr); + println!("Simplified LHS Expression: {}", format_expr(&lhs_expr)); + println!("Parsed RHS Expression: {:?}", rhs_expr); + println!("Simplified RHS Expression: {}", format_expr(&rhs_expr)); + + // Step 4: Build the Solver + let solver = Solver::builder().equation(lhs_expr, rhs_expr).build()?; + + // Step 5: Solve the equation + let solution = solver.solve()?; + + Ok(solution) +} + +/// Parses and simplifies a mathematical expression from a string. +/// +/// # Arguments +/// +/// * `expr_str` - A string slice that holds the expression, e.g., "2x + 3" +/// +/// # Returns +/// +/// * `Ok(Expr)` if the expression is parsed and simplified successfully. +/// * `Err(String)` if an error occurs during tokenization or parsing. +fn parse_expression(expr_str: &str) -> Result { + // Tokenize the expression + let tokens = tokenize(expr_str).map_err(|e| format!("Tokenization error: {}", e))?; + println!("Tokens: {:?}", tokens); + + // Parse the tokens into an expression (AST) + let mut parser = Parser::new(tokens); + let expr = parser + .parse_expression() + .map_err(|e| format!("Parsing error: {}", e))?; + println!("Parsed Expression: {:?}", expr); + + // Simplify the expression + let simplified_expr = simplify(expr); + println!("Simplified Expression: {}", format_expr(&simplified_expr)); + + Ok(simplified_expr) +} From 4ad1d31b5637054ee807732451f729137df935cd Mon Sep 17 00:00:00 2001 From: 0xb-s <145866191+0xb-s@users.noreply.github.com> Date: Wed, 18 Sep 2024 06:43:34 -0700 Subject: [PATCH 07/11] Create tests.rs --- src/tests.rs | 573 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 573 insertions(+) create mode 100644 src/tests.rs diff --git a/src/tests.rs b/src/tests.rs new file mode 100644 index 0000000..0ec128c --- /dev/null +++ b/src/tests.rs @@ -0,0 +1,573 @@ +#[cfg(test)] +mod tests { + + use crate::ast::Expr; + use crate::simplify::simplify; + + #[test] + fn test_simplify_constant() { + // 2 should remain 2 + let expr = Expr::Number(2.0); + assert_eq!(simplify(expr), Expr::Number(2.0)); + } + + #[test] + fn test_simplify_variable() { + // x should remain x + let expr = Expr::Variable("x".to_string()); + assert_eq!(simplify(expr), Expr::Variable("x".to_string())); + } + + #[test] + fn test_simplify_addition() { + // 2 + 3 should simplify to 5 + let expr = Expr::Add(Box::new(Expr::Number(2.0)), Box::new(Expr::Number(3.0))); + assert_eq!(simplify(expr), Expr::Number(5.0)); + } + + #[test] + fn test_simplify_subtraction() { + // 5 - 3 should simplify to 2 + let expr = Expr::Sub(Box::new(Expr::Number(5.0)), Box::new(Expr::Number(3.0))); + assert_eq!(simplify(expr), Expr::Number(2.0)); + } + + #[test] + fn test_simplify_multiplication() { + // 2 * 3 should simplify to 6 + let expr = Expr::Mul(Box::new(Expr::Number(2.0)), Box::new(Expr::Number(3.0))); + assert_eq!(simplify(expr), Expr::Number(6.0)); + } + + #[test] + fn test_simplify_division() { + // 6 / 3 should simplify to 2 + let expr = Expr::Div(Box::new(Expr::Number(6.0)), Box::new(Expr::Number(3.0))); + assert_eq!(simplify(expr), Expr::Number(2.0)); + } + + #[test] + fn test_simplify_multiplication_by_zero() { + // x * 0 should simplify to 0 + let expr = Expr::Mul( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Number(0.0)), + ); + assert_eq!(simplify(expr), Expr::Number(0.0)); + } + + #[test] + fn test_simplify_multiplication_by_one() { + // x * 1 should simplify to x + let expr = Expr::Mul( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Number(1.0)), + ); + assert_eq!(simplify(expr), Expr::Variable("x".to_string())); + } + + #[test] + fn test_simplify_exponents_via_multiplication() { + // x * x should simplify to x * x (representing x^2) + let expr = Expr::Mul( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Variable("x".to_string())), + ); + assert_eq!( + simplify(expr), + Expr::Mul( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Variable("x".to_string())) + ) + ); + + // x * x * x should simplify to x * x * x (representing x^3) + let expr = Expr::Mul( + Box::new(Expr::Mul( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Variable("x".to_string())), + )), + Box::new(Expr::Variable("x".to_string())), + ); + assert_eq!( + simplify(expr), + Expr::Mul( + Box::new(Expr::Mul( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Variable("x".to_string())) + )), + Box::new(Expr::Variable("x".to_string())) + ) + ); + } + + #[test] + fn test_simplify_zero_variable_expression() { + // 0 + x should simplify to x + let expr = Expr::Add( + Box::new(Expr::Number(0.0)), + Box::new(Expr::Variable("x".to_string())), + ); + assert_eq!(simplify(expr), Expr::Variable("x".to_string())); + + // 0 * x should simplify to 0 + let expr = Expr::Mul( + Box::new(Expr::Number(0.0)), + Box::new(Expr::Variable("x".to_string())), + ); + assert_eq!(simplify(expr), Expr::Number(0.0)); + } + + #[test] + fn test_simplify_multiple_variables_with_exponents() { + // (x * y) * (x * y) should simplify to x^2 * y^2 + let expr = Expr::Mul( + Box::new(Expr::Mul( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Variable("y".to_string())), + )), + Box::new(Expr::Mul( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Variable("y".to_string())), + )), + ); + assert_eq!( + simplify(expr), + Expr::Mul( + Box::new(Expr::Mul( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Variable("x".to_string())) + )), + Box::new(Expr::Mul( + Box::new(Expr::Variable("y".to_string())), + Box::new(Expr::Variable("y".to_string())) + )) + ) + ); + } + + #[test] + fn test_simplify_expression_with_constants_and_multiple_operations() { + // (3 + 2) * (x - x) should simplify to 0 + let expr = Expr::Mul( + Box::new(Expr::Add( + Box::new(Expr::Number(3.0)), + Box::new(Expr::Number(2.0)), + )), + Box::new(Expr::Sub( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Variable("x".to_string())), + )), + ); + assert_eq!(simplify(expr), Expr::Number(0.0)); + } + + #[test] + fn test_simplify_expression_with_nested_operations() { + // ((x + x) * (y + y)) / (2 * x * y) should simplify to 2 + let expr = Expr::Div( + Box::new(Expr::Mul( + Box::new(Expr::Add( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Variable("x".to_string())), + )), + Box::new(Expr::Add( + Box::new(Expr::Variable("y".to_string())), + Box::new(Expr::Variable("y".to_string())), + )), + )), + Box::new(Expr::Mul( + Box::new(Expr::Number(2.0)), + Box::new(Expr::Mul( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Variable("y".to_string())), + )), + )), + ); + assert_eq!(simplify(expr), Expr::Number(2.0)); + } + + #[test] + fn test_simplify_expression_with_multiple_nested_divisions() { + // (((2 * x) / y) / (x / y)) should simplify to 2 + let expr = Expr::Div( + Box::new(Expr::Div( + Box::new(Expr::Mul( + Box::new(Expr::Number(2.0)), + Box::new(Expr::Variable("x".to_string())), + )), + Box::new(Expr::Variable("y".to_string())), + )), + Box::new(Expr::Div( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Variable("y".to_string())), + )), + ); + assert_eq!(simplify(expr), Expr::Number(2.0)); + } + + #[test] + fn test_simplify_expression_with_multiple_variables_and_constants() { + // (4 * x * y) / (2 * x) should simplify to 2 * y + let expr = Expr::Div( + Box::new(Expr::Mul( + Box::new(Expr::Mul( + Box::new(Expr::Number(4.0)), + Box::new(Expr::Variable("x".to_string())), + )), + Box::new(Expr::Variable("y".to_string())), + )), + Box::new(Expr::Mul( + Box::new(Expr::Number(2.0)), + Box::new(Expr::Variable("x".to_string())), + )), + ); + assert_eq!( + simplify(expr), + Expr::Mul( + Box::new(Expr::Number(2.0)), + Box::new(Expr::Variable("y".to_string())) + ) + ); + } + + #[test] + fn test_simplify_expression_with_zero_result() { + // 5 - (2 + 3) should simplify to 0 + let expr = Expr::Sub( + Box::new(Expr::Number(5.0)), + Box::new(Expr::Add( + Box::new(Expr::Number(2.0)), + Box::new(Expr::Number(3.0)), + )), + ); + assert_eq!(simplify(expr), Expr::Number(0.0)); + } + + #[test] + fn test_simplify_expression_with_distributive_property() { + // 2 * (x + 3) should simplify to 2*x + 6 + let expr = Expr::Mul( + Box::new(Expr::Number(2.0)), + Box::new(Expr::Add( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Number(3.0)), + )), + ); + assert_eq!( + simplify(expr), + Expr::Add( + Box::new(Expr::Mul( + Box::new(Expr::Number(2.0)), + Box::new(Expr::Variable("x".to_string())) + )), + Box::new(Expr::Number(6.0)), + ) + ); + } + + #[test] + fn test_simplify_expression_with_multiple_distributive_applications() { + // (x + y) * (x - y) should simplify to x^2 - y^2 + let expr = Expr::Mul( + Box::new(Expr::Add( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Variable("y".to_string())), + )), + Box::new(Expr::Sub( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Variable("y".to_string())), + )), + ); + assert_eq!( + simplify(expr), + Expr::Sub( + Box::new(Expr::Mul( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Variable("x".to_string())) + )), + Box::new(Expr::Mul( + Box::new(Expr::Variable("y".to_string())), + Box::new(Expr::Variable("y".to_string())) + )) + ) + ); + } + + #[test] + fn test_simplify_expression_with_negative_exponents() { + // (x / y) * (y / x) should simplify to 1 + let expr = Expr::Mul( + Box::new(Expr::Div( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Variable("y".to_string())), + )), + Box::new(Expr::Div( + Box::new(Expr::Variable("y".to_string())), + Box::new(Expr::Variable("x".to_string())), + )), + ); + assert_eq!(simplify(expr), Expr::Number(1.0)); + } + + #[test] + fn test_simplify_expression_with_multiple_nested_operations() { + // ((x + 2) * (x - 2)) / (x^2 - 4) should simplify to 1 + // Since x^2 - 4 is (x + 2)(x - 2) + let expr = Expr::Div( + Box::new(Expr::Mul( + Box::new(Expr::Add( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Number(2.0)), + )), + Box::new(Expr::Sub( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Number(2.0)), + )), + )), + Box::new(Expr::Sub( + Box::new(Expr::Mul( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Variable("x".to_string())), + )), + Box::new(Expr::Number(4.0)), + )), + ); + assert_eq!(simplify(expr), Expr::Number(1.0)); + } + + #[test] + fn test_simplify_expression_with_complex_nested_operations() { + // ((3 * x) + (2 * y)) - ((x) + (2 * y)) should simplify to 2 * x + let expr = Expr::Sub( + Box::new(Expr::Add( + Box::new(Expr::Mul( + Box::new(Expr::Number(3.0)), + Box::new(Expr::Variable("x".to_string())), + )), + Box::new(Expr::Mul( + Box::new(Expr::Number(2.0)), + Box::new(Expr::Variable("y".to_string())), + )), + )), + Box::new(Expr::Add( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Mul( + Box::new(Expr::Number(2.0)), + Box::new(Expr::Variable("y".to_string())), + )), + )), + ); + assert_eq!( + simplify(expr), + Expr::Mul( + Box::new(Expr::Number(2.0)), + Box::new(Expr::Variable("x".to_string())) + ) + ); + } + + #[test] + fn test_simplify_expression_with_nested_multiplications_and_divisions() { + // (2 * (x / y)) * (3 * (y / x)) should simplify to 6 + let expr = Expr::Mul( + Box::new(Expr::Mul( + Box::new(Expr::Number(2.0)), + Box::new(Expr::Div( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Variable("y".to_string())), + )), + )), + Box::new(Expr::Mul( + Box::new(Expr::Number(3.0)), + Box::new(Expr::Div( + Box::new(Expr::Variable("y".to_string())), + Box::new(Expr::Variable("x".to_string())), + )), + )), + ); + assert_eq!(simplify(expr), Expr::Number(6.0)); + } + + #[test] + fn test_simplify_expression_with_multiple_like_terms() { + // 2x + 2x + 2x should simplify to 6x + let expr = Expr::Add( + Box::new(Expr::Add( + Box::new(Expr::Mul( + Box::new(Expr::Number(2.0)), + Box::new(Expr::Variable("x".to_string())), + )), + Box::new(Expr::Mul( + Box::new(Expr::Number(2.0)), + Box::new(Expr::Variable("x".to_string())), + )), + )), + Box::new(Expr::Mul( + Box::new(Expr::Number(2.0)), + Box::new(Expr::Variable("x".to_string())), + )), + ); + assert_eq!( + simplify(expr), + Expr::Mul( + Box::new(Expr::Number(6.0)), + Box::new(Expr::Variable("x".to_string())) + ) + ); + } + + #[test] + fn test_simplify_variable_division() { + // (2 * x) / x should simplify to 2 + let expr = Expr::Div( + Box::new(Expr::Mul( + Box::new(Expr::Number(2.0)), + Box::new(Expr::Variable("x".to_string())), + )), + Box::new(Expr::Variable("x".to_string())), + ); + assert_eq!(simplify(expr), Expr::Number(2.0)); + } + + #[test] + fn test_simplify_addition_of_variables() { + // x + x should simplify to 2 * x + let expr = Expr::Add( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Variable("x".to_string())), + ); + assert_eq!( + simplify(expr), + Expr::Mul( + Box::new(Expr::Number(2.0)), + Box::new(Expr::Variable("x".to_string())) + ) + ); + } + + #[test] + fn test_simplify_subtraction_of_variables() { + // x - x should simplify to 0 + let expr = Expr::Sub( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Variable("x".to_string())), + ); + assert_eq!(simplify(expr), Expr::Number(0.0)); + } + + #[test] + fn test_simplify_multiplication_of_variables() { + // x * x should simplify to x^2 + let expr = Expr::Mul( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Variable("x".to_string())), + ); + // Expected output: x^2 (which would need to be represented in Expr) + assert_eq!( + simplify(expr), + Expr::Mul( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Variable("x".to_string())) + ) + ); + } + + #[test] + fn test_simplify_nested_operations() { + // (x * 2) + (x * 3) should simplify to 5 * x + let expr = Expr::Add( + Box::new(Expr::Mul( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Number(2.0)), + )), + Box::new(Expr::Mul( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Number(3.0)), + )), + ); + assert_eq!( + simplify(expr), + Expr::Mul( + Box::new(Expr::Number(5.0)), + Box::new(Expr::Variable("x".to_string())) + ) + ); + } + + #[test] + fn test_simplify_complex_expression() { + // (2 * x) + (x * 3) - (5 * x) should simplify to 0 + let expr = Expr::Sub( + Box::new(Expr::Add( + Box::new(Expr::Mul( + Box::new(Expr::Number(2.0)), + Box::new(Expr::Variable("x".to_string())), + )), + Box::new(Expr::Mul( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Number(3.0)), + )), + )), + Box::new(Expr::Mul( + Box::new(Expr::Number(5.0)), + Box::new(Expr::Variable("x".to_string())), + )), + ); + assert_eq!(simplify(expr), Expr::Number(0.0)); + } + + #[test] + fn test_simplify_multiple_variables() { + // (2 * x * y) / (x * y) should simplify to 2 + let expr = Expr::Div( + Box::new(Expr::Mul( + Box::new(Expr::Mul( + Box::new(Expr::Number(2.0)), + Box::new(Expr::Variable("x".to_string())), + )), + Box::new(Expr::Variable("y".to_string())), + )), + Box::new(Expr::Mul( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Variable("y".to_string())), + )), + ); + assert_eq!(simplify(expr), Expr::Number(2.0)); + } + + #[test] + fn test_simplify_with_constants_and_variables() { + // (3 * x) / (2 * x) should simplify to 3/2 + let expr = Expr::Div( + Box::new(Expr::Mul( + Box::new(Expr::Number(3.0)), + Box::new(Expr::Variable("x".to_string())), + )), + Box::new(Expr::Mul( + Box::new(Expr::Number(2.0)), + Box::new(Expr::Variable("x".to_string())), + )), + ); + assert_eq!(simplify(expr), Expr::Number(3.0 / 2.0)); + } + + #[test] + fn test_simplify_nested_division() { + // ((2 * x) / y) / (x / y) should simplify to 2 + let expr = Expr::Div( + Box::new(Expr::Div( + Box::new(Expr::Mul( + Box::new(Expr::Number(2.0)), + Box::new(Expr::Variable("x".to_string())), + )), + Box::new(Expr::Variable("y".to_string())), + )), + Box::new(Expr::Div( + Box::new(Expr::Variable("x".to_string())), + Box::new(Expr::Variable("y".to_string())), + )), + ); + assert_eq!(simplify(expr), Expr::Number(2.0)); + } +} From f5dfda33ae3ae9c9780b0894c8639ff92d753b4e Mon Sep 17 00:00:00 2001 From: 0xb-s <145866191+0xb-s@users.noreply.github.com> Date: Wed, 18 Sep 2024 06:43:47 -0700 Subject: [PATCH 08/11] Update token.rs --- src/token.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/token.rs b/src/token.rs index 52b6ff1..c535281 100644 --- a/src/token.rs +++ b/src/token.rs @@ -6,6 +6,7 @@ pub enum Token { Minus, Star, Slash, + Pow, LParen, RParen, } From 5df0ebb938c629317af4b930c98b2372dacd673d Mon Sep 17 00:00:00 2001 From: 0xb-s <145866191+0xb-s@users.noreply.github.com> Date: Wed, 18 Sep 2024 06:44:00 -0700 Subject: [PATCH 09/11] Update tokenizer.rs --- src/tokenizer.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/tokenizer.rs b/src/tokenizer.rs index 35a5803..63fd187 100644 --- a/src/tokenizer.rs +++ b/src/tokenizer.rs @@ -22,6 +22,10 @@ pub fn tokenize(input: &str) -> Result, String> { tokens.push(Token::Minus); idx += 1; } + '^' => { + tokens.push(Token::Pow); + idx += 1; + } '*' => { tokens.push(Token::Star); idx += 1; From e17b808dee4bafd24393381246c9b5e4c1acdec1 Mon Sep 17 00:00:00 2001 From: 0xb-s <145866191+0xb-s@users.noreply.github.com> Date: Wed, 18 Sep 2024 06:44:14 -0700 Subject: [PATCH 10/11] Update Cargo.toml --- Cargo.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index 11a7182..3d9a90f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,3 +10,5 @@ edition = "2021" [[example]] name = "foo" +[[example]] +name = "bar" From ac6f273debb869075e1ed72753a63e008b1e3c03 Mon Sep 17 00:00:00 2001 From: 0xb-s <145866191+0xb-s@users.noreply.github.com> Date: Wed, 18 Sep 2024 06:44:35 -0700 Subject: [PATCH 11/11] Create bar.rs --- examples/bar.rs | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 examples/bar.rs diff --git a/examples/bar.rs b/examples/bar.rs new file mode 100644 index 0000000..0a37cb4 --- /dev/null +++ b/examples/bar.rs @@ -0,0 +1,21 @@ +use math::solver::solve_equation; + +fn main() { + let equations = vec![ + "2x + 3 = 7", + "x/0 = 5", + "x^2 - 4 = 0", + "x^3 - 1 = 0", + "5 = 5", + "3 = 7", + "x^4 + 1 = 0", + ]; + + for equation_str in equations { + println!("\nSolving Equation: {}", equation_str); + match solve_equation(equation_str) { + Ok(solution) => println!("Solution: {}", solution), + Err(e) => println!("Error: {}", e), + } + } +}