use std::f64::consts::E;
use std::hash::Hash;
use std::num::ParseFloatError;
use std::str::FromStr;

/// A complex number stored as a pair of `f64` components.
#[derive(Clone, Copy, Default, Debug)]
pub struct Complex {
    /// Real component.
    pub real: f64,
    /// Imaginary component (the coefficient of `i`).
    pub imag: f64,
}

impl Complex {
    /// Builds a complex number from its real and imaginary components.
    pub fn new(real: f64, imag: f64) -> Self {
        Self { real, imag }
    }

    /// Returns `true` when the imaginary component equals zero (`0.0` or `-0.0`).
    pub fn is_real(&self) -> bool {
        self.imag == 0.0
    }

    /// Returns the modulus `sqrt(real² + imag²)`.
    ///
    /// Uses `hypot` so that squaring large components cannot overflow to
    /// infinity before the square root is taken.
    pub fn abs(&self) -> f64 {
        self.real.hypot(self.imag)
    }

    /// Raises `self` to the power `rhs`, using the principal argument of `self`.
    ///
    /// * When both operands are real and `f64::powf` produces a usable value,
    ///   that value is returned with a zero imaginary part.
    /// * A negative real base with a finite fractional real exponent (for which
    ///   `powf` yields NaN) is evaluated with the complex formula instead, so
    ///   `(-4)^0.5` is approximately `2i`.
    /// * A zero base with a complex exponent whose real part is positive gives
    ///   zero.
    pub fn pow(&self, rhs: Self) -> Self {
        if self.is_real() && rhs.is_real() {
            let direct = self.real.powf(rhs.real);
            // Only leave the real fast path when powf failed purely because the
            // base is negative; every other result is kept as-is.
            let needs_complex = direct.is_nan() && self.real < 0.0 && rhs.real.is_finite();
            if !needs_complex {
                return Complex::new(direct, 0.0);
            }
        }

        let modulus = self.abs();
        if modulus == 0.0 && rhs.real > 0.0 {
            // ln(0) is -inf, which would turn the trigonometric terms below
            // into NaN even though the result is simply zero.
            return Complex::new(0.0, 0.0);
        }

        // atan2 selects the correct quadrant; atan(imag / real) is off by pi
        // whenever the real part is negative and is NaN for 0/0.
        let theta = self.imag.atan2(self.real);
        let ln_modulus = modulus.ln();

        // z^w = exp(w * ln z) with ln z = ln|z| + i*theta.
        let angle = ln_modulus * rhs.imag + theta * rhs.real;
        let magnitude = E.powf(ln_modulus * rhs.real - rhs.imag * theta);
        Complex::new(magnitude * angle.cos(), magnitude * angle.sin())
    }

    /// Returns the principal square root, i.e. `self.pow(0.5)`.
    pub fn sqrt(&self) -> Self {
        self.pow(Complex::new(0.5, 0.0))
    }

    /// Returns the complex cosine:
    /// `cos(a + bi) = cos(a)·cosh(b) − i·sin(a)·sinh(b)`.
    pub fn cos(&self) -> Self {
        if self.is_real() {
            // Keeps the imaginary part an exact positive zero for real inputs.
            return Complex::new(self.real.cos(), 0.0);
        }
        Complex::new(
            self.real.cos() * self.imag.cosh(),
            -(self.real.sin() * self.imag.sinh()),
        )
    }

    /// Returns the complex sine:
    /// `sin(a + bi) = sin(a)·cosh(b) + i·cos(a)·sinh(b)`.
    pub fn sin(&self) -> Self {
        if self.is_real() {
            // Keeps the imaginary part an exact positive zero for real inputs.
            return Complex::new(self.real.sin(), 0.0);
        }
        Complex::new(
            self.real.sin() * self.imag.cosh(),
            self.real.cos() * self.imag.sinh(),
        )
    }

    /// Returns the tangent, computed as `sin(self) / cos(self)`.
    pub fn tan(&self) -> Self {
        self.sin() / self.cos()
    }
}

/// Formats as `a`, `i`, `bi`, `a + i` or `a + bi` depending on which
/// components are zero and whether the imaginary coefficient is exactly one.
impl std::fmt::Display for Complex {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let unit_imag = self.imag == 1.0;
        if self.imag == 0.0 {
            write!(f, "{}", self.real)
        } else if self.real == 0.0 {
            if unit_imag {
                write!(f, "i")
            } else {
                write!(f, "{}i", self.imag)
            }
        } else if unit_imag {
            write!(f, "{} + i", self.real)
        } else {
            write!(f, "{} + {}i", self.real, self.imag)
        }
    }
}

/// Component-wise `f64` equality: `0.0 == -0.0`, and NaN components never
/// compare equal.
impl PartialEq for Complex {
    fn eq(&self, other: &Self) -> bool {
        self.real == other.real && self.imag == other.imag
    }
}

/// Hashes the bit patterns of both components.
///
/// Zero is normalised first: `0.0` and `-0.0` compare equal under `PartialEq`,
/// so they must feed identical bytes to the hasher.
impl Hash for Complex {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        fn canonical_bits(x: f64) -> u64 {
            if x == 0.0 {
                0.0_f64.to_bits()
            } else {
                x.to_bits()
            }
        }
        canonical_bits(self.real).hash(state);
        canonical_bits(self.imag).hash(state);
    }
}

impl Eq for Complex {}

/// Error returned when a string cannot be parsed as a [`Complex`].
#[derive(Debug)]
pub struct ComplexParseError;

impl From<ParseFloatError> for ComplexParseError {
    fn from(_: ParseFloatError) -> Self {
        ComplexParseError
    }
}

/// Parses `a`, `bi`, `i`, `-i`, `a+bi`, `a+i` and `bi+a`; spaces are ignored.
///
/// # Errors
///
/// Returns [`ComplexParseError`] when a component is not a valid `f64`, or when
/// the imaginary term of a sum does not end in `i`. Malformed input never
/// panics.
impl FromStr for Complex {
    type Err = ComplexParseError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        /// Parses an imaginary term such as `2.5i`, `i` or `-i` into its coefficient.
        fn imag_coefficient(term: &str) -> Result<f64, ComplexParseError> {
            match term.strip_suffix('i') {
                None => Err(ComplexParseError),
                Some("") => Ok(1.0),
                Some("-") => Ok(-1.0),
                Some(coefficient) => Ok(coefficient.parse::<f64>()?),
            }
        }

        let compact = s.replace(' ', "");
        let parsed = match compact.split_once('+') {
            // "bi+a": the imaginary term comes first.
            Some((lhs, rhs)) if lhs.ends_with('i') => {
                Complex::new(rhs.parse::<f64>()?, imag_coefficient(lhs)?)
            }
            // "a+bi"
            Some((lhs, rhs)) => Complex::new(lhs.parse::<f64>()?, imag_coefficient(rhs)?),
            // Purely imaginary: "bi", "i", "-i".
            None if compact.ends_with('i') => Complex::new(0.0, imag_coefficient(&compact)?),
            // Purely real.
            None => Complex::new(compact.parse::<f64>()?, 0.0),
        };
        Ok(parsed)
    }
}

/// Converts a real number into a complex number with a zero imaginary part.
impl From<f64> for Complex {
    fn from(val: f64) -> Self {
        Complex::new(val, 0.0)
    }
}

/// Adds two borrowed values by copying them and delegating to the by-value impl.
impl std::ops::Add for &Complex {
    type Output = Complex;

    fn add(self, rhs: Self) -> Self::Output {
        *self + *rhs
    }
}

/// Component-wise addition.
impl std::ops::Add for Complex {
    type Output = Complex;

    fn add(self, rhs: Self) -> Self::Output {
        Complex::new(self.real + rhs.real, self.imag + rhs.imag)
    }
}

/// Subtracts two borrowed values by copying them and delegating to the by-value impl.
impl std::ops::Sub for &Complex {
    type Output = Complex;

    fn sub(self, rhs: Self) -> Self::Output {
        *self - *rhs
    }
}

/// Component-wise subtraction.
impl std::ops::Sub for Complex {
    type Output = Complex;

    fn sub(self, rhs: Self) -> Self::Output {
        Complex::new(self.real - rhs.real, self.imag - rhs.imag)
    }
}

/// Multiplies two borrowed values by copying them and delegating to the by-value impl.
impl std::ops::Mul for &Complex {
    type Output = Complex;

    fn mul(self, rhs: Self) -> Self::Output {
        *self * *rhs
    }
}

/// Complex multiplication: `(a + bi)(c + di) = (ac − bd) + (ad + bc)i`.
impl std::ops::Mul for Complex {
    type Output = Complex;

    fn mul(self, rhs: Self) -> Self::Output {
        Complex::new(
            self.real * rhs.real - self.imag * rhs.imag,
            self.real * rhs.imag + self.imag * rhs.real,
        )
    }
}

/// Divides two borrowed values by copying them and delegating to the by-value impl.
impl std::ops::Div for &Complex {
    type Output = Complex;

    fn div(self, rhs: Self) -> Self::Output {
        *self / *rhs
    }
}

/// Complex division: multiplies by the conjugate of `rhs` and divides by
/// `|rhs|²`. A zero divisor follows `f64` rules and yields NaN or infinite
/// components.
impl std::ops::Div for Complex {
    type Output = Complex;

    fn div(self, rhs: Self) -> Self::Output {
        let norm_sq = rhs.real * rhs.real + rhs.imag * rhs.imag;
        Complex::new(
            (self.real * rhs.real + self.imag * rhs.imag) / norm_sq,
            (self.imag * rhs.real - self.real * rhs.imag) / norm_sq,
        )
    }
}