From 77cf9aa7535a1d9481f0bd3caeea26e2b85c5019 Mon Sep 17 00:00:00 2001 From: Nathan Reiner Date: Wed, 17 Jan 2024 22:30:02 +0100 Subject: migrate from f64 to own complex --- src/complex.rs | 191 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ src/context.rs | 24 +++++-- src/expression.rs | 39 +++++------ src/function.rs | 42 ++++++++++-- src/main.rs | 65 +++++++++++++------ src/operation.rs | 9 +-- 6 files changed, 316 insertions(+), 54 deletions(-) create mode 100644 src/complex.rs diff --git a/src/complex.rs b/src/complex.rs new file mode 100644 index 0000000..642122f --- /dev/null +++ b/src/complex.rs @@ -0,0 +1,191 @@ +use std::num::ParseFloatError; +use std::str::FromStr; +use std::f64::consts::E; + +#[derive(Clone, Copy, Default)] +pub struct Complex { + pub real: f64, + pub imag: f64, +} + +impl Complex { + pub fn new(real: f64, imag: f64) -> Self { + Self { real, imag } + } + + pub fn div(&mut self, other: &Self) { + let z = other.real * other.real + other.imag * other.imag; + self.real = self.real * other.real + self.imag * other.imag; + self.real /= z; + self.imag = self.imag * other.real - self.real * other.imag; + self.imag /= z; + } + + pub fn abs(&self) -> f64 { + (self.real.powi(2) + self.imag.powi(2)).sqrt() + } + + pub fn pow(&self, rhs: Self) -> Self { + let r = self.abs(); + let x = (self.imag / self.real).atan(); + let lnr = r.ln(); + let alpha = lnr * rhs.imag + x * rhs.real; + let e_exp = E.powf(lnr * rhs.real - rhs.imag * x); + Complex::new(e_exp * alpha.cos(), e_exp * alpha.sin()) + } + + pub fn sqrt(&self) -> Self { + self.pow(Complex::new(0.5, 0.0)) + } + + pub fn cos(&self) -> Self { + let e : Complex = E.into(); + let i = Complex::new(0.0, 1.0); + + Complex::new(e.pow(&i * self).real, 0.0) + } + + pub fn sin(&self) -> Self { + let e : Complex = E.into(); + let i = Complex::new(0.0, 1.0); + + Complex::new(e.pow(&i * self).imag, 0.0) + } + + pub fn tan(&self) -> Self { + self.sin() / self.cos() + } +} + +impl std::fmt::Display for Complex { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + if self.imag == 0.0 { + write!(f, "{}", self.real) + } else if self.real == 0.0 { + if self.imag == 1.0 { + write!(f, "i") + } else { + write!(f, "{}i", self.imag) + } + } else { + if self.imag == 1.0 { + write!(f, "{} + i", self.real) + } else { + write!(f, "{} + {}i", self.real, self.imag) + } + } + } +} + +#[derive(Debug)] +pub struct ComplexParseError; + +impl From for ComplexParseError { + fn from(_: ParseFloatError) -> Self { + Self {} + } +} + +impl FromStr for Complex { + type Err = ComplexParseError; + + fn from_str(s: &str) -> Result { + let mut c = Complex::default(); + let s = s.replace(' ', ""); + if s.contains('+') { + let (a, b) = s.split_once('+').unwrap(); + if a.contains('i') { + c.imag = a[..a.len() - 1].parse()?; + c.real = b.parse()?; + } else { + c.real = a.parse()?; + c.imag = b[..b.len() - 1].parse()?; + } + } else { + if s.contains('i') { + c.imag = s[..s.len() - 1].parse()?; + } else { + c.real = s.parse()?; + } + } + + Ok(c) + } +} + +impl Into for f64 { + fn into(self) -> Complex { + Complex::new(self, 0.0) + } +} + +impl std::ops::Add for &Complex { + type Output = Complex; + + fn add(self, rhs: Self) -> Self::Output { + Complex::new(self.real + rhs.real, self.imag + rhs.imag) + } +} + +impl std::ops::Add for Complex { + type Output = Complex; + + fn add(self, rhs: Self) -> Self::Output { + &self + &rhs + } +} + +impl std::ops::Sub for &Complex { + type Output = Complex; + + fn sub(self, rhs: Self) -> Self::Output { + Complex::new(self.real - rhs.real, self.imag - rhs.imag) + } +} + +impl std::ops::Sub for Complex { + type Output = Complex; + + fn sub(self, rhs: Self) -> Self::Output { + &self - &rhs + } +} + +impl std::ops::Mul for &Complex { + type Output = Complex; + + fn mul(self, rhs: Self) -> Self::Output { + Complex::new( + self.real * rhs.real - self.imag * rhs.imag, + self.real * rhs.imag + self.imag * rhs.real, + ) + } +} + +impl std::ops::Mul for Complex { + type Output = Complex; + + fn mul(self, rhs: Self) -> Self::Output { + &self * &rhs + } +} + +impl std::ops::Div for &Complex { + type Output = Complex; + + fn div(self, rhs: Self) -> Self::Output { + let z = rhs.real * rhs.real + rhs.imag * rhs.imag; + Complex::new( + (self.real * rhs.real + self.imag * rhs.imag) / z, + (self.imag * rhs.real - self.real * rhs.imag) / z, + ) + } +} + +impl std::ops::Div for Complex { + type Output = Complex; + + fn div(self, rhs: Self) -> Self::Output { + &self / &rhs + } +} diff --git a/src/context.rs b/src/context.rs index 12bea6b..47be92b 100644 --- a/src/context.rs +++ b/src/context.rs @@ -1,11 +1,12 @@ use crate::function::Function; use crate::operation::Operation; +use crate::complex::Complex; use std::collections::HashMap; #[derive(Default)] pub struct Context { ops: Vec, - vars: HashMap, + vars: HashMap, funcs: HashMap>, } @@ -19,7 +20,7 @@ impl Context { self } - pub fn with_variables(mut self, vars: HashMap) -> Self { + pub fn with_variables(mut self, vars: HashMap) -> Self { self.vars = vars; self } @@ -33,7 +34,7 @@ impl Context { &self.ops } - pub fn variable_mut(&mut self, name: &str) -> Option<&mut f64> { + pub fn variable_mut(&mut self, name: &str) -> Option<&mut Complex> { self.vars.get_mut(name) } @@ -41,11 +42,26 @@ impl Context { self.funcs.get_mut(name) } - pub fn variable(&self, name: &str) -> Option<&f64> { + pub fn variable(&self, name: &str) -> Option<&Complex> { self.vars.get(name) } + #[allow(clippy::borrowed_box)] pub fn function(&self, name: &str) -> Option<&Box> { self.funcs.get(name) } } + +#[macro_export] +macro_rules! variables { + ({$($x:expr => $y:expr), *}) => { + { + let mut h : HashMap = HashMap::new(); + $( + h.insert($x.to_string(), $y); + )* + h + } + }; +} + diff --git a/src/expression.rs b/src/expression.rs index 5f2660b..4029f93 100644 --- a/src/expression.rs +++ b/src/expression.rs @@ -1,3 +1,4 @@ +use crate::complex::Complex; use crate::context::Context; use crate::function::FunctionArgument; use crate::string::{ContainsAndSkipBrackets, SplitMatchingBracket}; @@ -13,7 +14,7 @@ impl Expression { } } - pub fn evaluate(&self, context: &Context) -> Result { + pub fn evaluate(&self, context: &Context) -> Result { let (repr, oprepr) = { if self.repr.starts_with('(') { let (oprepr, r) = self.repr.split_on_matching_bracket(); @@ -44,30 +45,24 @@ impl Expression { let first_expr = Expression::from_string(first); let rest_expr = Expression::from_string(rest); Ok(op.evaluate(first_expr.evaluate(context)?, rest_expr.evaluate(context)?)) - } else { - if let Ok(r) = repr.parse::() { - Ok(r) - } else { - if let Some((func, args)) = repr.split_once('(') { - let mut argv = Vec::new(); + } else if let Ok(r) = repr.parse::() { + Ok(r) + } else if let Some((func, args)) = repr.split_once('(') { + let mut argv = Vec::new(); - for arg in args[..args.len() - 1].split(',') { - argv.push(Expression::from_string(arg).evaluate(context)?); - } + for arg in args[..args.len() - 1].split(',') { + argv.push(Expression::from_string(arg).evaluate(context)?); + } - if let Some(func) = context.function(func) { - Ok(func.eval(FunctionArgument::new(argv))) - } else { - Err(format!("function '{func}' not found")) - } - } else { - if let Some(res) = context.variable(&repr) { - Ok(*res) - } else { - Err(format!("variable '{repr}' not found")) - } - } + if let Some(func) = context.function(func) { + Ok(func.eval(FunctionArgument::new(argv))?) + } else { + Err(format!("function '{func}' not found")) } + } else if let Some(res) = context.variable(&repr) { + Ok(*res) + } else { + Err(format!("variable '{repr}' not found")) } } } diff --git a/src/function.rs b/src/function.rs index 547ae0e..442d987 100644 --- a/src/function.rs +++ b/src/function.rs @@ -1,17 +1,51 @@ +use crate::complex::Complex; + pub struct FunctionArgument { - args : Vec, + args : Vec, } impl FunctionArgument { - pub fn new(args: Vec) -> Self { + pub fn new(args: Vec) -> Self { Self { args } } - pub fn get(&self, i : usize) -> f64 { + pub fn get(&self, i : usize) -> Complex { self.args[i] } + + pub fn len(&self) -> usize { + self.args.len() + } } pub trait Function { - fn eval(&self, args: FunctionArgument) -> f64; + fn eval(&self, args: FunctionArgument) -> Result; +} + +#[macro_export] +macro_rules! functions { + ({$($x:expr => $y:expr), *}) => { + { + let mut h : HashMap> = HashMap::new(); + $( + h.insert($x.to_string(), Box::new($y)); + )* + h + } + }; +} + + +// Some default implementations +impl Function for T +where + T: Fn(Complex) -> Complex, +{ + fn eval(&self, args: FunctionArgument) -> Result { + if args.len() == 1 { + Ok(self(args.get(0))) + } else { + Err("too many arguments".to_string()) + } + } } diff --git a/src/main.rs b/src/main.rs index 313769f..d1e914d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,4 @@ +pub mod complex; pub mod context; pub mod expression; pub mod function; @@ -6,51 +7,75 @@ pub mod string; use std::collections::HashMap; +use complex::Complex; +use context::Context; use expression::Expression; use function::Function; use operation::Operation; -use crate::context::Context; - -fn add(a: f64, b: f64) -> f64 { +fn add(a: Complex, b: Complex) -> Complex { a + b } -fn mul(a: f64, b: f64) -> f64 { +fn sub(a: Complex, b: Complex) -> Complex { + a - b +} + +fn mul(a: Complex, b: Complex) -> Complex { a * b } -fn div(a: f64, b: f64) -> f64 { +fn div(a: Complex, b: Complex) -> Complex { a / b } -fn pow(a: f64, b: f64) -> f64 { - a.powf(b) +fn pow(a: Complex, b: Complex) -> Complex { + a.pow(b) } -fn sqrt(a: f64) -> f64 { +fn sqrt(a: Complex) -> Complex { a.sqrt() } -impl Function for T -where - T: Fn(f64) -> f64, -{ - fn eval(&self, args: function::FunctionArgument) -> f64 { - self(args.get(0)) - } +fn sin(a: Complex) -> Complex { + a.sin() +} + +fn cos(a: Complex) -> Complex { + a.cos() +} + +fn tan(a: Complex) -> Complex { + a.tan() } fn main() { - let expr = "(2 + (3 + 10) * (3 + 10)) * 2"; + let expr = "cos(3)"; - let mut funcs : HashMap> = HashMap::new(); - funcs.insert("sqrt".to_string(), Box::new(&sqrt)); let ctx: Context = Context::new() - .with_operations(opvec![('+', &add), ('*', &mul), ('/', &div), ('^', &pow)]) - .with_functions(funcs); + .with_operations(operations![{ + '+' => &add, + '-' => &sub, + '*' => &mul, + '/' => &div, + '^' => &pow + }]) + .with_functions(functions!({"sqrt" => &sqrt, "sin" => &sin, "cos" => &cos, "tan" => &tan})) + .with_variables(variables!({"x" => Complex::new(5.0, 0.0)})); + let value = Expression::from_string(expr); + match value.evaluate(&ctx) { + Ok(res) => println!("{} = {}", expr, res), + Err(err) => println!("Error: {}", err), + } + + let value = Expression::from_string("sin(3)"); + match value.evaluate(&ctx) { + Ok(res) => println!("{} = {}", expr, res), + Err(err) => println!("Error: {}", err), + } + let value = Expression::from_string("tan(3)"); match value.evaluate(&ctx) { Ok(res) => println!("{} = {}", expr, res), Err(err) => println!("Error: {}", err), diff --git a/src/operation.rs b/src/operation.rs index 861fb5a..1ee3d5c 100644 --- a/src/operation.rs +++ b/src/operation.rs @@ -1,5 +1,6 @@ +use crate::complex::Complex; -pub type Operator = dyn Fn(f64, f64) -> f64; +pub type Operator = dyn Fn(Complex, Complex) -> Complex; pub struct Operation { sign: char, @@ -15,14 +16,14 @@ impl Operation { self.sign } - pub fn evaluate(&self, a: f64, b: f64) -> f64 { + pub fn evaluate(&self, a: Complex, b: Complex) -> Complex { (self.func)(a, b) } } #[macro_export] -macro_rules! opvec { - ($(($x:expr, $y:expr)), *) => { +macro_rules! operations { + ({$($x:expr => $y:expr), *}) => { vec![$( Operation::new($x, Box::new($y)), )*] -- cgit v1.2.3-70-g09d2